diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go index e8a93e9137..7d73dc4fc0 100644 --- a/src/net/http/transfer.go +++ b/src/net/http/transfer.go @@ -53,19 +53,6 @@ func (br *byteReader) Read(p []byte) (n int, err error) { return 1, io.EOF } -// transferBodyReader is an io.Reader that reads from tw.Body -// and records any non-EOF error in tw.bodyReadError. -// It is exactly 1 pointer wide to avoid allocations into interfaces. -type transferBodyReader struct{ tw *transferWriter } - -func (br transferBodyReader) Read(p []byte) (n int, err error) { - n, err = br.tw.Body.Read(p) - if err != nil && err != io.EOF { - br.tw.bodyReadError = err - } - return -} - // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -347,15 +334,18 @@ func (t *transferWriter) writeBody(w io.Writer) error { var err error var ncopy int64 - // Write body + // Write body. We "unwrap" the body first if it was wrapped in a + // nopCloser. This is to ensure that we can take advantage of + // OS-level optimizations in the event that the body is an + // *os.File. if t.Body != nil { - var body = transferBodyReader{t} + var body = t.unwrapBody() if chunked(t.TransferEncoding) { if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse { w = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(w) - _, err = io.Copy(cw, body) + _, err = t.doBodyCopy(cw, body) if err == nil { err = cw.Close() } @@ -364,14 +354,14 @@ func (t *transferWriter) writeBody(w io.Writer) error { if t.Method == "CONNECT" { dst = bufioFlushWriter{dst} } - ncopy, err = io.Copy(dst, body) + ncopy, err = t.doBodyCopy(dst, body) } else { - ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength)) + ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength)) if err != nil { return err } var nextra int64 - nextra, err = io.Copy(ioutil.Discard, body) + nextra, err = t.doBodyCopy(ioutil.Discard, body) ncopy += nextra } if err != nil { @@ -402,6 +392,31 @@ func (t *transferWriter) writeBody(w io.Writer) error { return err } +// doBodyCopy wraps a copy operation, with any resulting error also +// being saved in bodyReadError. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) { + n, err = io.Copy(dst, src) + if err != nil && err != io.EOF { + t.bodyReadError = err + } + return +} + +// unwrapBodyReader unwraps the body's inner reader if it's a +// nopCloser. This is to ensure that body writes sourced from local +// files (*os.File types) are properly optimized. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) unwrapBody() io.Reader { + if reflect.TypeOf(t.Body) == nopCloserType { + return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader) + } + + return t.Body +} + type transferReader struct { // Input Header Header diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go index 993ea4ef18..aa465d0600 100644 --- a/src/net/http/transfer_test.go +++ b/src/net/http/transfer_test.go @@ -7,8 +7,12 @@ package http import ( "bufio" "bytes" + "crypto/rand" + "fmt" "io" "io/ioutil" + "os" + "reflect" "strings" "testing" ) @@ -90,3 +94,187 @@ func TestDetectInMemoryReaders(t *testing.T) { } } } + +type mockTransferWriter struct { + CalledReader io.Reader + WriteCalled bool +} + +var _ io.ReaderFrom = (*mockTransferWriter)(nil) + +func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) { + w.CalledReader = r + return io.Copy(ioutil.Discard, r) +} + +func (w *mockTransferWriter) Write(p []byte) (int, error) { + w.WriteCalled = true + return ioutil.Discard.Write(p) +} + +func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { + fileType := reflect.TypeOf(&os.File{}) + bufferType := reflect.TypeOf(&bytes.Buffer{}) + + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := ioutil.TempFile("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + bodyFunc func() (io.Reader, func(), error) + method string + contentLength int64 + transferEncoding []string + limitedReader bool + expectedReader reflect.Type + expectedWrite bool + }{ + { + name: "file, non-chunked, size set", + bodyFunc: newFileFunc, + method: "PUT", + contentLength: nBytes, + limitedReader: true, + expectedReader: fileType, + }, + { + name: "file, non-chunked, size set, nopCloser wrapped", + method: "PUT", + bodyFunc: func() (io.Reader, func(), error) { + r, cleanup, err := newFileFunc() + return ioutil.NopCloser(r), cleanup, err + }, + contentLength: nBytes, + limitedReader: true, + expectedReader: fileType, + }, + { + name: "file, non-chunked, negative size", + method: "PUT", + bodyFunc: newFileFunc, + contentLength: -1, + expectedReader: fileType, + }, + { + name: "file, non-chunked, CONNECT, negative size", + method: "CONNECT", + bodyFunc: newFileFunc, + contentLength: -1, + expectedReader: fileType, + }, + { + name: "file, chunked", + method: "PUT", + bodyFunc: newFileFunc, + transferEncoding: []string{"chunked"}, + expectedWrite: true, + }, + { + name: "buffer, non-chunked, size set", + bodyFunc: newBufferFunc, + method: "PUT", + contentLength: nBytes, + limitedReader: true, + expectedReader: bufferType, + }, + { + name: "buffer, non-chunked, size set, nopCloser wrapped", + method: "PUT", + bodyFunc: func() (io.Reader, func(), error) { + r, cleanup, err := newBufferFunc() + return ioutil.NopCloser(r), cleanup, err + }, + contentLength: nBytes, + limitedReader: true, + expectedReader: bufferType, + }, + { + name: "buffer, non-chunked, negative size", + method: "PUT", + bodyFunc: newBufferFunc, + contentLength: -1, + expectedWrite: true, + }, + { + name: "buffer, non-chunked, CONNECT, negative size", + method: "CONNECT", + bodyFunc: newBufferFunc, + contentLength: -1, + expectedWrite: true, + }, + { + name: "buffer, chunked", + method: "PUT", + bodyFunc: newBufferFunc, + transferEncoding: []string{"chunked"}, + expectedWrite: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body, cleanup, err := tc.bodyFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + mw := &mockTransferWriter{} + tw := &transferWriter{ + Body: body, + ContentLength: tc.contentLength, + TransferEncoding: tc.transferEncoding, + } + + if err := tw.writeBody(mw); err != nil { + t.Fatal(err) + } + + if tc.expectedReader != nil { + if mw.CalledReader == nil { + t.Fatal("did not call ReadFrom") + } + + var actualReader reflect.Type + lr, ok := mw.CalledReader.(*io.LimitedReader) + if ok && tc.limitedReader { + actualReader = reflect.TypeOf(lr.R) + } else { + actualReader = reflect.TypeOf(mw.CalledReader) + } + + if tc.expectedReader != actualReader { + t.Fatalf("got reader %T want %T", actualReader, tc.expectedReader) + } + } + + if tc.expectedWrite && !mw.WriteCalled { + t.Fatal("did not invoke Write") + } + }) + } +}