1
0
mirror of https://github.com/golang/go synced 2024-10-04 15:11:20 -06:00

net/http: on Transport body write error, wait briefly for a response

From https://github.com/golang/go/issues/11745#issuecomment-123555313 :

The http.RoundTripper interface says you get either a *Response, or an
error.

But in the case of a client writing a large request and the server
replying prematurely (e.g. 403 Forbidden) and closing the connection
without reading the request body, what does the client want? The 403
response, or the error that the body couldn't be copied?

This CL implements the aforementioned comment's option c), making the
Transport give an N millisecond advantage to responses over body write
errors.

Updates #11745

Change-Id: I4485a782505d54de6189f6856a7a1f33ce4d5e5e
Reviewed-on: https://go-review.googlesource.com/12590
Reviewed-by: Russ Cox <rsc@golang.org>
This commit is contained in:
Brad Fitzpatrick 2015-07-23 14:05:54 -07:00 committed by Russ Cox
parent a60c5366f9
commit 0c2d3f7346
2 changed files with 82 additions and 7 deletions

View File

@ -1164,6 +1164,19 @@ WaitResponse:
for {
select {
case err := <-writeErrCh:
if isSyscallWriteError(err) {
// Issue 11745. If we failed to write the request
// body, it's possible the server just heard enough
// and already wrote to us. Prioritize the server's
// response over returning a body write error.
select {
case re = <-resc:
pc.close()
break WaitResponse
case <-time.After(50 * time.Millisecond):
// Fall through.
}
}
if err != nil {
re = responseAndError{nil, err}
pc.close()
@ -1366,3 +1379,16 @@ type fakeLocker struct{}
func (fakeLocker) Lock() {}
func (fakeLocker) Unlock() {}
func isSyscallWriteError(err error) bool {
switch e := err.(type) {
case *url.Error:
return isSyscallWriteError(e.Err)
case *net.OpError:
return e.Op == "write" && isSyscallWriteError(e.Err)
case *os.SyscallError:
return e.Syscall == "write"
default:
return false
}
}

View File

@ -18,7 +18,6 @@ import (
"io/ioutil"
"log"
"net"
"net/http"
. "net/http"
"net/http/httptest"
"net/url"
@ -320,7 +319,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) {
addrSeen[r.RemoteAddr]++
if r.URL.Path == "/chunked/" {
w.WriteHeader(200)
w.(http.Flusher).Flush()
w.(Flusher).Flush()
} else {
w.Header().Set("Content-Type", strconv.Itoa(len(msg)))
w.WriteHeader(200)
@ -335,7 +334,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) {
wantLen := []int{len(msg), -1}[pi]
addrSeen = make(map[string]int)
for i := 0; i < 3; i++ {
res, err := http.Get(ts.URL + path)
res, err := Get(ts.URL + path)
if err != nil {
t.Errorf("Get %s: %v", path, err)
continue
@ -1976,7 +1975,7 @@ func TestIdleConnChannelLeak(t *testing.T) {
// then closes it.
func TestTransportClosesRequestBody(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(http.HandlerFunc(func(w ResponseWriter, r *Request) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
io.Copy(ioutil.Discard, r.Body)
}))
defer ts.Close()
@ -2346,13 +2345,13 @@ func TestTransportDialTLS(t *testing.T) {
// Test for issue 8755
// Ensure that if a proxy returns an error, it is exposed by RoundTrip
func TestRoundTripReturnsProxyError(t *testing.T) {
badProxy := func(*http.Request) (*url.URL, error) {
badProxy := func(*Request) (*url.URL, error) {
return nil, errors.New("errorMessage")
}
tr := &Transport{Proxy: badProxy}
req, _ := http.NewRequest("GET", "http://example.com", nil)
req, _ := NewRequest("GET", "http://example.com", nil)
_, err := tr.RoundTrip(req)
@ -2644,7 +2643,57 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
}
}
func wantBody(res *http.Response, err error, want string) error {
// Issue 11745.
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
const contentLengthLimit = 1024 * 1024 // 1MB
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.ContentLength >= contentLengthLimit {
w.WriteHeader(StatusBadRequest)
r.Body.Close()
return
}
w.WriteHeader(StatusOK)
}))
defer ts.Close()
fail := 0
count := 100
bigBody := strings.Repeat("a", contentLengthLimit*2)
for i := 0; i < count; i++ {
req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
if err != nil {
t.Fatal(err)
}
tr := new(Transport)
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
resp, err := client.Do(req)
if err != nil {
fail++
t.Logf("%d = %#v", i, err)
if ue, ok := err.(*url.Error); ok {
t.Logf("urlErr = %#v", ue.Err)
if ne, ok := ue.Err.(*net.OpError); ok {
t.Logf("netOpError = %#v", ne.Err)
}
}
} else {
resp.Body.Close()
if resp.StatusCode != 400 {
t.Errorf("Expected status code 400, got %v", resp.Status)
}
}
}
if fail > 0 {
t.Errorf("Failed %v out of %v\n", fail, count)
}
}
func wantBody(res *Response, err error, want string) error {
if err != nil {
return err
}