diff --git a/src/pkg/net/http/request.go b/src/pkg/net/http/request.go index 80bff9c0ec..131cb6d67e 100644 --- a/src/pkg/net/http/request.go +++ b/src/pkg/net/http/request.go @@ -390,10 +390,16 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err w = bw } - fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + _, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + if err != nil { + return err + } // Header lines - fmt.Fprintf(w, "Host: %s\r\n", host) + _, err = fmt.Fprintf(w, "Host: %s\r\n", host) + if err != nil { + return err + } // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. @@ -404,7 +410,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err } } if userAgent != "" { - fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + if err != nil { + return err + } } // Process Body,ContentLength,Close,Trailer @@ -429,7 +438,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err } } - io.WriteString(w, "\r\n") + _, err = io.WriteString(w, "\r\n") + if err != nil { + return err + } // Write body and trailer err = tw.WriteBody(w) diff --git a/src/pkg/net/http/requestwrite_test.go b/src/pkg/net/http/requestwrite_test.go index dc0e204cac..997010c2b2 100644 --- a/src/pkg/net/http/requestwrite_test.go +++ b/src/pkg/net/http/requestwrite_test.go @@ -563,3 +563,61 @@ func mustParseURL(s string) *url.URL { } return u } + +type writerFunc func([]byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { return f(p) } + +// TestRequestWriteError tests the Write err != nil checks in (*Request).write. +func TestRequestWriteError(t *testing.T) { + failAfter, writeCount := 0, 0 + errFail := errors.New("fake write failure") + + // w is the buffered io.Writer to write the request to. It + // fails exactly once on its Nth Write call, as controlled by + // failAfter. It also tracks the number of calls in + // writeCount. + w := struct { + io.ByteWriter // to avoid being wrapped by a bufio.Writer + io.Writer + }{ + nil, + writerFunc(func(p []byte) (n int, err error) { + writeCount++ + if failAfter == 0 { + err = errFail + } + failAfter-- + return len(p), err + }), + } + + req, _ := NewRequest("GET", "http://example.com/", nil) + const writeCalls = 4 // number of Write calls in current implementation + sawGood := false + for n := 0; n <= writeCalls+2; n++ { + failAfter = n + writeCount = 0 + err := req.Write(w) + var wantErr error + if n < writeCalls { + wantErr = errFail + } + if err != wantErr { + t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr) + continue + } + if err == nil { + sawGood = true + if writeCount != writeCalls { + t.Fatalf("writeCalls constant is outdated in test") + } + } + if writeCount > writeCalls || writeCount > n+1 { + t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount) + } + } + if !sawGood { + t.Fatalf("writeCalls constant is outdated in test") + } +}