mirror of
https://github.com/golang/go
synced 2024-11-19 03:34:41 -07:00
io: consolidate multi_reader and multi_writer into a single file, multi.go
R=rsc CC=golang-dev https://golang.org/cl/1860046
This commit is contained in:
parent
55badd474b
commit
3f5966dcc0
@ -7,8 +7,7 @@ include ../../Make.$(GOARCH)
|
||||
TARG=io
|
||||
GOFILES=\
|
||||
io.go\
|
||||
multi_reader.go\
|
||||
multi_writer.go\
|
||||
multi.go\
|
||||
pipe.go\
|
||||
|
||||
include ../../Make.pkg
|
||||
|
@ -34,3 +34,27 @@ func (mr *multiReader) Read(p []byte) (n int, err os.Error) {
|
||||
func MultiReader(readers ...Reader) Reader {
|
||||
return &multiReader{readers}
|
||||
}
|
||||
|
||||
type multiWriter struct {
|
||||
writers []Writer
|
||||
}
|
||||
|
||||
func (t *multiWriter) Write(p []byte) (n int, err os.Error) {
|
||||
for _, w := range t.writers {
|
||||
n, err = w.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != len(p) {
|
||||
err = ErrShortWrite
|
||||
return
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// MultiWriter creates a writer that duplicates its writes to all the
|
||||
// provided writers, similar to the Unix tee(1) command.
|
||||
func MultiWriter(writers ...Writer) Writer {
|
||||
return &multiWriter{writers}
|
||||
}
|
@ -6,6 +6,9 @@ package io_test
|
||||
|
||||
import (
|
||||
. "io"
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -56,3 +59,30 @@ func TestMultiReader(t *testing.T) {
|
||||
expectRead(5, "foo ", nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiWriter(t *testing.T) {
|
||||
sha1 := sha1.New()
|
||||
sink := new(bytes.Buffer)
|
||||
mw := MultiWriter(sha1, sink)
|
||||
|
||||
sourceString := "My input text."
|
||||
source := strings.NewReader(sourceString)
|
||||
written, err := Copy(mw, source)
|
||||
|
||||
if written != int64(len(sourceString)) {
|
||||
t.Errorf("short write of %d, not %d", written, len(sourceString))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
sha1hex := fmt.Sprintf("%x", sha1.Sum())
|
||||
if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
|
||||
t.Error("incorrect sha1 value")
|
||||
}
|
||||
|
||||
if sink.String() != sourceString {
|
||||
t.Error("expected %q; got %q", sourceString, sink.String())
|
||||
}
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package io
|
||||
|
||||
import "os"
|
||||
|
||||
type multiWriter struct {
|
||||
writers []Writer
|
||||
}
|
||||
|
||||
func (t *multiWriter) Write(p []byte) (n int, err os.Error) {
|
||||
for _, w := range t.writers {
|
||||
n, err = w.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n != len(p) {
|
||||
err = ErrShortWrite
|
||||
return
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// MultiWriter creates a writer that duplicates its writes to all the
|
||||
// provided writers, similar to the Unix tee(1) command.
|
||||
func MultiWriter(writers ...Writer) Writer {
|
||||
return &multiWriter{writers}
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package io_test
|
||||
|
||||
import (
|
||||
. "io"
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMultiWriter(t *testing.T) {
|
||||
sha1 := sha1.New()
|
||||
sink := new(bytes.Buffer)
|
||||
mw := MultiWriter(sha1, sink)
|
||||
|
||||
sourceString := "My input text."
|
||||
source := strings.NewReader(sourceString)
|
||||
written, err := Copy(mw, source)
|
||||
|
||||
if written != int64(len(sourceString)) {
|
||||
t.Errorf("short write of %d, not %d", written, len(sourceString))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
sha1hex := fmt.Sprintf("%x", sha1.Sum())
|
||||
if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
|
||||
t.Error("incorrect sha1 value")
|
||||
}
|
||||
|
||||
if sink.String() != sourceString {
|
||||
t.Error("expected %q; got %q", sourceString, sink.String())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user