1
0
mirror of https://github.com/golang/go synced 2024-11-19 02:04:42 -07:00

io: MultiReader and MultiWriter

Little helpers I've found useful.

R=adg, rsc, r, gri
CC=golang-dev
https://golang.org/cl/1764043
This commit is contained in:
Brad Fitzpatrick 2010-07-28 11:30:00 -07:00 committed by Russ Cox
parent c6e4697141
commit 719cde2c47
5 changed files with 168 additions and 0 deletions

View File

@ -7,6 +7,8 @@ include ../../Make.$(GOARCH)
TARG=io TARG=io
GOFILES=\ GOFILES=\
io.go\ io.go\
multi_reader.go\
multi_writer.go\
pipe.go\ pipe.go\
include ../../Make.pkg include ../../Make.pkg

View File

@ -0,0 +1,36 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package io
import "os"
type multiReader struct {
readers []Reader
}
func (mr *multiReader) Read(p []byte) (n int, err os.Error) {
for len(mr.readers) > 0 {
n, err = mr.readers[0].Read(p)
if n > 0 || err != os.EOF {
if err == os.EOF {
// This shouldn't happen.
// Well-behaved Readers should never
// return non-zero bytes read with an
// EOF. But if so, we clean it.
err = nil
}
return
}
mr.readers = mr.readers[1:]
}
return 0, os.EOF
}
// MultiReader returns a Reader that's the logical concatenation of
// the provided input readers. They're read sequentially. Once all
// inputs are drained, Read will return os.EOF.
func MultiReader(readers ...Reader) Reader {
return &multiReader{readers}
}

View File

@ -0,0 +1,58 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package io_test
import (
. "io"
"os"
"strings"
"testing"
)
func TestMultiReader(t *testing.T) {
var mr Reader
var buf []byte
nread := 0
withFooBar := func(tests func()) {
r1 := strings.NewReader("foo ")
r2 := strings.NewReader("bar")
mr = MultiReader(r1, r2)
buf = make([]byte, 20)
tests()
}
expectRead := func(size int, expected string, eerr os.Error) {
nread++
n, gerr := mr.Read(buf[0:size])
if n != len(expected) {
t.Errorf("#%d, expected %d bytes; got %d",
nread, len(expected), n)
}
got := string(buf[0:n])
if got != expected {
t.Errorf("#%d, expected %q; got %q",
nread, expected, got)
}
if gerr != eerr {
t.Errorf("#%d, expected error %v; got %v",
nread, eerr, gerr)
}
buf = buf[n:]
}
withFooBar(func() {
expectRead(2, "fo", nil)
expectRead(5, "o ", nil)
expectRead(5, "bar", nil)
expectRead(5, "", os.EOF)
})
withFooBar(func() {
expectRead(4, "foo ", nil)
expectRead(1, "b", nil)
expectRead(3, "ar", nil)
expectRead(1, "", os.EOF)
})
withFooBar(func() {
expectRead(5, "foo ", nil)
})
}

View File

@ -0,0 +1,31 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package io
import "os"
type multiWriter struct {
writers []Writer
}
func (t *multiWriter) Write(p []byte) (n int, err os.Error) {
for _, w := range t.writers {
n, err = w.Write(p)
if err != nil {
return
}
if n != len(p) {
err = ErrShortWrite
return
}
}
return len(p), nil
}
// MultiWriter creates a writer that duplicates its writes to all the
// provided writers, similar to the Unix tee(1) command.
func MultiWriter(writers ...Writer) Writer {
return &multiWriter{writers}
}

View File

@ -0,0 +1,41 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package io_test
import (
. "io"
"bytes"
"crypto/sha1"
"fmt"
"strings"
"testing"
)
func TestMultiWriter(t *testing.T) {
sha1 := sha1.New()
sink := new(bytes.Buffer)
mw := MultiWriter(sha1, sink)
sourceString := "My input text."
source := strings.NewReader(sourceString)
written, err := Copy(mw, source)
if written != int64(len(sourceString)) {
t.Errorf("short write of %d, not %d", written, len(sourceString))
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
sha1hex := fmt.Sprintf("%x", sha1.Sum())
if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
t.Error("incorrect sha1 value")
}
if sink.String() != sourceString {
t.Error("expected %q; got %q", sourceString, sink.String())
}
}