diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 7fdd94e05b..865dbdd508 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -387,47 +387,47 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { if err == nil { return resp, nil } - if err := checkTransportResend(err, req, pconn); err != nil { + if !pconn.shouldRetryRequest(req, err) { return nil, err } testHookRoundTripRetried() } } -// checkTransportResend checks whether a failed HTTP request can be -// resent on a new connection. The non-nil input error is the error from -// roundTrip, which might be wrapped in a beforeRespHeaderError error. -// -// The return value is either nil to retry the request, the provided -// err unmodified, or the unwrapped error inside a -// beforeRespHeaderError. -func checkTransportResend(err error, req *Request, pconn *persistConn) error { - brhErr, ok := err.(beforeRespHeaderError) - if !ok { - return err +// shouldRetryRequest reports whether we should retry sending a failed +// HTTP request on a new connection. The non-nil input error is the +// error from roundTrip. +func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { + if err == errMissingHost { + // User error. + return false } - err = brhErr.error // unwrap the custom error in case we return it - if err != errMissingHost && pconn.isReused() && req.isReplayable() { - // If we try to reuse a connection that the server is in the process of - // closing, we may end up successfully writing out our request (or a - // portion of our request) only to find a connection error when we try to - // read from (or finish writing to) the socket. - - // There can be a race between the socket pool checking whether a socket - // is still connected, receiving the FIN, and sending/reading data on a - // reused socket. If we receive the FIN between the connectedness check - // and writing/reading from the socket, we may first learn the socket is - // disconnected when we get a ERR_SOCKET_NOT_CONNECTED. This will most - // likely happen when trying to retrieve its IP address. See - // http://crbug.com/105824 for more details. - - // We resend a request only if we reused a keep-alive connection and did - // not yet receive any header data. This automatically prevents an - // infinite resend loop because we'll run out of the cached keep-alive - // connections eventually. - return nil + if !pc.isReused() { + // This was a fresh connection. There's no reason the server + // should've hung up on us. + // + // Also, if we retried now, we could loop forever + // creating new connections and retrying if the server + // is just hanging up on us because it doesn't like + // our request (as opposed to sending an error). + return false } - return err + if !req.isReplayable() { + // Don't retry non-idempotent requests. + + // TODO: swap the nothingWrittenError and isReplayable checks, + // putting the "if nothingWrittenError => return true" case + // first, per golang.org/issue/15723 + return false + } + if _, ok := err.(nothingWrittenError); ok { + // We never wrote anything, so it's safe to retry. + return true + } + if err == errServerClosedIdle || err == errServerClosedConn { + return true + } + return false // conservatively } // ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol. @@ -570,7 +570,8 @@ var ( errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") errCloseIdleConns = errors.New("http: CloseIdleConnections called") errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") - errServerClosedIdle = errors.New("http: server closed idle conn") + errServerClosedIdle = errors.New("http: server closed idle connection") + errServerClosedConn = errors.New("http: server closed connection") errIdleConnTimeout = errors.New("http: idle connection timeout") ) @@ -881,12 +882,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) { pconn := &persistConn{ - t: t, - cacheKey: cm.key(), - reqch: make(chan requestAndChan, 1), - writech: make(chan writeRequest, 1), - closech: make(chan struct{}), - writeErrCh: make(chan error, 1), + t: t, + cacheKey: cm.key(), + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + writeErrCh: make(chan error, 1), + writeLoopDone: make(chan struct{}), } tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil if tlsDial { @@ -1003,12 +1005,28 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } pconn.br = bufio.NewReader(pconn) - pconn.bw = bufio.NewWriter(pconn.conn) + pconn.bw = bufio.NewWriter(persistConnWriter{pconn}) go pconn.readLoop() go pconn.writeLoop() return pconn, nil } +// persistConnWriter is the io.Writer written to by pc.bw. +// It accumulates the number of bytes written to the underlying conn, +// so the retry logic can determine whether any bytes made it across +// the wire. +// This is exactly 1 pointer field wide so it can go into an interface +// without allocation. +type persistConnWriter struct { + pc *persistConn +} + +func (w persistConnWriter) Write(p []byte) (n int, err error) { + n, err = w.pc.conn.Write(p) + w.pc.nwrite += int64(n) + return +} + // useProxy reports whether requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. // addr is always a canonicalAddr with a host and port. @@ -1142,6 +1160,7 @@ type persistConn struct { tlsState *tls.ConnectionState br *bufio.Reader // from conn bw *bufio.Writer // to conn + nwrite int64 // bytes written reqch chan requestAndChan // written by roundTrip; read by readLoop writech chan writeRequest // written by roundTrip; read by writeLoop closech chan struct{} // closed when conn closed @@ -1154,6 +1173,8 @@ type persistConn struct { // whether or not a connection can be reused. Issue 7569. writeErrCh chan error + writeLoopDone chan struct{} // closed when write loop ends + // Both guarded by Transport.idleMu: idleAt time.Time // time it last become idle idleTimer *time.Timer // holding an AfterFunc to close it @@ -1195,7 +1216,7 @@ func (pc *persistConn) Read(p []byte) (n int, err error) { // isBroken reports whether this connection is in a known broken state. func (pc *persistConn) isBroken() bool { pc.mu.Lock() - b := pc.broken + b := pc.closed != nil pc.mu.Unlock() return b } @@ -1247,6 +1268,56 @@ func (pc *persistConn) closeConnIfStillIdle() { pc.close(errIdleConnTimeout) } +// mapRoundTripErrorFromReadLoop maps the provided readLoop error into +// the error value that should be returned from persistConn.roundTrip. +// +// The startBytesWritten value should be the value of pc.nwrite before the roundTrip +// started writing the request. +func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, err error) (out error) { + if err == nil { + return nil + } + if pc.isCanceled() { + return errRequestCanceled + } + if err == errServerClosedIdle || err == errServerClosedConn { + return err + } + if pc.isBroken() { + <-pc.writeLoopDone + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + } + return err +} + +// mapRoundTripErrorAfterClosed returns the error value to be propagated +// up to Transport.RoundTrip method when persistConn.roundTrip sees +// its pc.closech channel close, indicating the persistConn is dead. +// (after closech is closed, pc.closed is valid). +func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error { + if pc.isCanceled() { + return errRequestCanceled + } + err := pc.closed + if err == errServerClosedIdle || err == errServerClosedConn { + // Don't decorate + return err + } + + // Wait for the writeLoop goroutine to terminated, and then + // see if we actually managed to write anything. If not, we + // can retry the request. + <-pc.writeLoopDone + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) + +} + func (pc *persistConn) readLoop() { closeErr := errReadLoopExiting // default value, if not changed below defer func() { @@ -1283,9 +1354,6 @@ func (pc *persistConn) readLoop() { for alive { pc.readLimit = pc.maxHeaderResponseSize() _, err := pc.br.Peek(1) - if err != nil { - err = beforeRespHeaderError{err} - } pc.mu.Lock() if pc.numExpectedResponses == 0 { @@ -1301,12 +1369,16 @@ func (pc *persistConn) readLoop() { var resp *Response if err == nil { resp, err = pc.readResponse(rc, trace) + } else { + err = errServerClosedConn + closeErr = err } if err != nil { if pc.readLimit <= 0 { err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize()) } + // If we won't be able to retry this request later (from the // roundTrip goroutine), mark it as done now. // BEFORE the send on rc.ch, as the client might re-use the @@ -1314,7 +1386,7 @@ func (pc *persistConn) readLoop() { // t.setReqCanceler from this persistConn while the Transport // potentially spins up a different persistConn for the // caller's subsequent request. - if checkTransportResend(err, rc.req, pc) != nil { + if !pc.shouldRetryRequest(rc.req, err) { pc.t.setReqCanceler(rc.req, nil) } select { @@ -1501,24 +1573,33 @@ func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool { } } +// nothingWrittenError wraps a write errors which ended up writing zero bytes. +type nothingWrittenError struct { + error +} + func (pc *persistConn) writeLoop() { + defer close(pc.writeLoopDone) for { select { case wr := <-pc.writech: - if pc.isBroken() { - wr.ch <- errors.New("http: can't write HTTP request on broken connection") - continue - } + startBytesWritten := pc.nwrite err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh)) if err == nil { err = pc.bw.Flush() } if err != nil { - pc.markBroken() wr.req.Request.closeBody() + if pc.nwrite == startBytesWritten { + err = nothingWrittenError{err} + } } pc.writeErrCh <- err // to the body reader, which might recycle us wr.ch <- err // to the roundTrip function + if err != nil { + pc.close(err) + return + } case <-pc.closech: return } @@ -1619,12 +1700,6 @@ var ( testHookReadLoopBeforeNextRead = nop ) -// beforeRespHeaderError is used to indicate when an IO error has occurred before -// any header data was received. -type beforeRespHeaderError struct { - error -} - func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) { @@ -1680,6 +1755,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. + startBytesWritten := pc.nwrite writeErrCh := make(chan error, 1) pc.writech <- writeRequest{req, writeErrCh, continueCh} @@ -1704,7 +1780,7 @@ WaitResponse: if pc.isCanceled() { err = errRequestCanceled } - re = responseAndError{err: beforeRespHeaderError{err}} + re = responseAndError{err: err} pc.close(fmt.Errorf("write error: %v", err)) break WaitResponse } @@ -1714,22 +1790,14 @@ WaitResponse: respHeaderTimer = timer.C } case <-pc.closech: - var err error - if pc.isCanceled() { - err = errRequestCanceled - } else { - err = beforeRespHeaderError{fmt.Errorf("net/http: HTTP/1 transport connection broken: %v", pc.closed)} - } - re = responseAndError{err: err} + re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(startBytesWritten)} break WaitResponse case <-respHeaderTimer: pc.close(errTimeout) re = responseAndError{err: errTimeout} break WaitResponse case re = <-resc: - if re.err != nil && pc.isCanceled() { - re.err = errRequestCanceled - } + re.err = pc.mapRoundTripErrorFromReadLoop(startBytesWritten, re.err) break WaitResponse case <-cancelChan: pc.t.CancelRequest(req.Request) @@ -1749,15 +1817,6 @@ WaitResponse: return re.res, re.err } -// markBroken marks a connection as broken (so it's not reused). -// It differs from close in that it doesn't close the underlying -// connection for use when it's still being read. -func (pc *persistConn) markBroken() { - pc.mu.Lock() - defer pc.mu.Unlock() - pc.broken = true -} - // markReused marks this connection as having been successfully used for a // request and response. func (pc *persistConn) markReused() { diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go new file mode 100644 index 0000000000..a157d90630 --- /dev/null +++ b/src/net/http/transport_internal_test.go @@ -0,0 +1,69 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// White-box tests for transport.go (in package http instead of http_test). + +package http + +import ( + "errors" + "net" + "testing" +) + +// Issue 15446: incorrect wrapping of errors when server closes an idle connection. +func TestTransportPersistConnReadLoopEOF(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + connc := make(chan net.Conn, 1) + go func() { + defer close(connc) + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + connc <- c + }() + + tr := new(Transport) + req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil) + treq := &transportRequest{Request: req} + cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} + pc, err := tr.getConn(treq, cm) + if err != nil { + t.Fatal(err) + } + defer pc.close(errors.New("test over")) + + conn := <-connc + if conn == nil { + // Already called t.Error in the accept goroutine. + return + } + conn.Close() // simulate the server hanging up on the client + + _, err = pc.roundTrip(treq) + if err != errServerClosedConn && err != errServerClosedIdle { + t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + } + + <-pc.closech + err = pc.closed + if err != errServerClosedConn && err != errServerClosedIdle { + t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +}