mirror of
https://github.com/golang/go
synced 2024-11-19 11:44:45 -07:00
net/http: don't ignore errors in Request.Write
LGTM=josharian, adg R=golang-codereviews, josharian, adg CC=golang-codereviews https://golang.org/cl/119110043
This commit is contained in:
parent
8cb040771b
commit
c4807f6a84
@ -390,10 +390,16 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
|
||||
w = bw
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
|
||||
_, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Header lines
|
||||
fmt.Fprintf(w, "Host: %s\r\n", host)
|
||||
_, err = fmt.Fprintf(w, "Host: %s\r\n", host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use the defaultUserAgent unless the Header contains one, which
|
||||
// may be blank to not send the header.
|
||||
@ -404,7 +410,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
|
||||
}
|
||||
}
|
||||
if userAgent != "" {
|
||||
fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
|
||||
_, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Process Body,ContentLength,Close,Trailer
|
||||
@ -429,7 +438,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
|
||||
}
|
||||
}
|
||||
|
||||
io.WriteString(w, "\r\n")
|
||||
_, err = io.WriteString(w, "\r\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write body and trailer
|
||||
err = tw.WriteBody(w)
|
||||
|
@ -563,3 +563,61 @@ func mustParseURL(s string) *url.URL {
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
type writerFunc func([]byte) (int, error)
|
||||
|
||||
func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
|
||||
|
||||
// TestRequestWriteError tests the Write err != nil checks in (*Request).write.
|
||||
func TestRequestWriteError(t *testing.T) {
|
||||
failAfter, writeCount := 0, 0
|
||||
errFail := errors.New("fake write failure")
|
||||
|
||||
// w is the buffered io.Writer to write the request to. It
|
||||
// fails exactly once on its Nth Write call, as controlled by
|
||||
// failAfter. It also tracks the number of calls in
|
||||
// writeCount.
|
||||
w := struct {
|
||||
io.ByteWriter // to avoid being wrapped by a bufio.Writer
|
||||
io.Writer
|
||||
}{
|
||||
nil,
|
||||
writerFunc(func(p []byte) (n int, err error) {
|
||||
writeCount++
|
||||
if failAfter == 0 {
|
||||
err = errFail
|
||||
}
|
||||
failAfter--
|
||||
return len(p), err
|
||||
}),
|
||||
}
|
||||
|
||||
req, _ := NewRequest("GET", "http://example.com/", nil)
|
||||
const writeCalls = 4 // number of Write calls in current implementation
|
||||
sawGood := false
|
||||
for n := 0; n <= writeCalls+2; n++ {
|
||||
failAfter = n
|
||||
writeCount = 0
|
||||
err := req.Write(w)
|
||||
var wantErr error
|
||||
if n < writeCalls {
|
||||
wantErr = errFail
|
||||
}
|
||||
if err != wantErr {
|
||||
t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr)
|
||||
continue
|
||||
}
|
||||
if err == nil {
|
||||
sawGood = true
|
||||
if writeCount != writeCalls {
|
||||
t.Fatalf("writeCalls constant is outdated in test")
|
||||
}
|
||||
}
|
||||
if writeCount > writeCalls || writeCount > n+1 {
|
||||
t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount)
|
||||
}
|
||||
}
|
||||
if !sawGood {
|
||||
t.Fatalf("writeCalls constant is outdated in test")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user