diff --git a/doc/go1.16.html b/doc/go1.16.html index b3d905c168a..1694b2277db 100644 --- a/doc/go1.16.html +++ b/doc/go1.16.html @@ -539,16 +539,6 @@ func TestFoo(t *testing.T) { indefinitely.

-

- The new Conn.HandshakeContext - method allows cancellation of an in-progress handshake. The provided - context is accessible through the new - ClientHelloInfo.Context - and - CertificateRequestInfo.Context methods. Canceling the - context after the handshake has finished has no effect. -

-

Clients now return a handshake error if the server selects @@ -771,13 +761,6 @@ func TestFoo(t *testing.T) { generating a SameSite key without a value.

-

- The net/http package now passes the - Request context to - tls.Conn.HandshakeContext - when performing TLS handshakes. -

-

The Client now sends an explicit Content-Length: 0 diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index 5b68742975c..eec6e1ebbd9 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -7,7 +7,6 @@ package tls import ( "bytes" "container/list" - "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -444,16 +443,6 @@ type ClientHelloInfo struct { // config is embedded by the GetCertificate or GetConfigForClient caller, // for use with SupportsCertificate. config *Config - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *ClientHelloInfo) Context() context.Context { - return c.ctx } // CertificateRequestInfo contains information from a server's @@ -472,16 +461,6 @@ type CertificateRequestInfo struct { // Version is the TLS version that was negotiated for this connection. Version uint16 - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *CertificateRequestInfo) Context() context.Context { - return c.ctx } // RenegotiationSupport enumerates the different levels of support for TLS diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 969f357834c..72ad52c194b 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -8,7 +8,6 @@ package tls import ( "bytes" - "context" "crypto/cipher" "crypto/subtle" "crypto/x509" @@ -28,7 +27,7 @@ type Conn struct { // constant conn net.Conn isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake + handshakeFn func() error // (*Conn).clientHandshake or serverHandshake // handshakeStatus is 1 if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -1191,7 +1190,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { + if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr @@ -1374,61 +1373,8 @@ func (c *Conn) closeNotify() error { // first Read or Write will call it automatically. // // For control over canceling or setting a timeout on a handshake, use -// HandshakeContext or the Dialer's DialContext method instead. +// the Dialer's DialContext method. func (c *Conn) Handshake() error { - return c.HandshakeContext(context.Background()) -} - -// HandshakeContext runs the client or server handshake -// protocol if it has not yet been run. -// -// The provided Context must be non-nil. If the context is canceled before -// the handshake is complete, the handshake is interrupted and an error is returned. -// Once the handshake has completed, cancellation of the context will not affect the -// connection. -// -// Most uses of this package need not call HandshakeContext explicitly: the -// first Read or Write will call it automatically. -func (c *Conn) HandshakeContext(ctx context.Context) error { - // Delegate to unexported method for named return - // without confusing documented signature. - return c.handshakeContext(ctx) -} - -func (c *Conn) handshakeContext(ctx context.Context) (ret error) { - handshakeCtx, cancel := context.WithCancel(ctx) - // Note: defer this before starting the "interrupter" goroutine - // so that we can tell the difference between the input being canceled and - // this cancellation. In the former case, we need to close the connection. - defer cancel() - - // Start the "interrupter" goroutine, if this context might be canceled. - // (The background context cannot). - // - // The interrupter goroutine waits for the input context to be done and - // closes the connection if this happens before the function returns. - if ctx.Done() != nil { - done := make(chan struct{}) - interruptRes := make(chan error, 1) - defer func() { - close(done) - if ctxErr := <-interruptRes; ctxErr != nil { - // Return context error to user. - ret = ctxErr - } - }() - go func() { - select { - case <-handshakeCtx.Done(): - // Close the connection, discarding the error - _ = c.conn.Close() - interruptRes <- handshakeCtx.Err() - case <-done: - interruptRes <- nil - } - }() - } - c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -1442,7 +1388,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { c.in.Lock() defer c.in.Unlock() - c.handshakeErr = c.handshakeFn(handshakeCtx) + c.handshakeErr = c.handshakeFn() if c.handshakeErr == nil { c.handshakes++ } else { diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 92e33e71690..e684b21d527 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -6,7 +6,6 @@ package tls import ( "bytes" - "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -25,7 +24,6 @@ import ( type clientHandshakeState struct { c *Conn - ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg suite *cipherSuite @@ -136,7 +134,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { return hello, params, nil } -func (c *Conn) clientHandshake(ctx context.Context) (err error) { +func (c *Conn) clientHandshake() (err error) { if c.config == nil { c.config = defaultConfig() } @@ -200,7 +198,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ c: c, - ctx: ctx, serverHello: serverHello, hello: hello, ecdheParams: ecdheParams, @@ -215,7 +212,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { hs := &clientHandshakeState{ c: c, - ctx: ctx, serverHello: serverHello, hello: hello, session: session, @@ -544,7 +540,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { certRequested = true hs.finishedHash.Write(certReq.marshal()) - cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) + cri := certificateRequestInfoFromMsg(c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { c.sendAlert(alertInternalError) return err @@ -884,11 +880,10 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS // <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { +func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { cri := &CertificateRequestInfo{ AcceptableCAs: certReq.certificateAuthorities, Version: vers, - ctx: ctx, } var rsaAvail, ecAvail bool diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index 8889e2c8c33..12b0254123e 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -6,7 +6,6 @@ package tls import ( "bytes" - "context" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -21,7 +20,6 @@ import ( "os/exec" "path/filepath" "reflect" - "runtime" "strconv" "strings" "testing" @@ -2513,37 +2511,3 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) } } - -func TestClientHandshakeContextCancellation(t *testing.T) { - c, s := localPipe(t) - serverConfig := testConfig.Clone() - serverErr := make(chan error, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - defer close(serverErr) - defer s.Close() - conn := Server(s, serverConfig) - _, err := conn.readClientHello(ctx) - cancel() - serverErr <- err - }() - cli := Client(c, testConfig) - err := cli.HandshakeContext(ctx) - if err == nil { - t.Fatal("Client handshake did not error when the context was canceled") - } - if err != context.Canceled { - t.Errorf("Unexpected client handshake error: %v", err) - } - if err := <-serverErr; err != nil { - t.Errorf("Unexpected server error: %v", err) - } - if runtime.GOARCH == "wasm" { - t.Skip("conn.Close does not error as expected when called multiple times on WASM") - } - err = cli.Close() - if err == nil { - t.Error("Client connection was not closed when the context was canceled") - } -} diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index be37c681c6d..daa5d97fd35 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -6,7 +6,6 @@ package tls import ( "bytes" - "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -18,7 +17,6 @@ import ( type clientHandshakeStateTLS13 struct { c *Conn - ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg ecdheParams ecdheParameters @@ -557,7 +555,6 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { AcceptableCAs: hs.certReq.certificateAuthorities, SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, Version: c.vers, - ctx: hs.ctx, }) if err != nil { return err diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index 5a572a9db10..9c3e0f636ea 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -5,7 +5,6 @@ package tls import ( - "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -24,7 +23,6 @@ import ( // It's discarded once the handshake has completed. type serverHandshakeState struct { c *Conn - ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg suite *cipherSuite @@ -39,8 +37,8 @@ type serverHandshakeState struct { } // serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake(ctx context.Context) error { - clientHello, err := c.readClientHello(ctx) +func (c *Conn) serverHandshake() error { + clientHello, err := c.readClientHello() if err != nil { return err } @@ -48,7 +46,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error { if c.vers == VersionTLS13 { hs := serverHandshakeStateTLS13{ c: c, - ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -56,7 +53,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error { hs := serverHandshakeState{ c: c, - ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -128,7 +124,7 @@ func (hs *serverHandshakeState) handshake() error { } // readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { +func (c *Conn) readClientHello() (*clientHelloMsg, error) { msg, err := c.readHandshake() if err != nil { return nil, err @@ -142,7 +138,7 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { var configForClient *Config originalConfig := c.config if c.config.GetConfigForClient != nil { - chi := clientHelloInfo(ctx, c, clientHello) + chi := clientHelloInfo(c, clientHello) if configForClient, err = c.config.GetConfigForClient(chi); err != nil { c.sendAlert(alertInternalError) return nil, err @@ -224,7 +220,7 @@ func (hs *serverHandshakeState) processClientHello() error { } } - hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) + hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) @@ -832,7 +828,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { return nil } -func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { +func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { supportedVersions := clientHello.supportedVersions if len(clientHello.supportedVersions) == 0 { supportedVersions = supportedVersionsFromMax(clientHello.vers) @@ -848,6 +844,5 @@ func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) SupportedVersions: supportedVersions, Conn: c.conn, config: c.config, - ctx: ctx, } } diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index ad851b6edf0..d6bf9e439b0 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -6,7 +6,6 @@ package tls import ( "bytes" - "context" "crypto" "crypto/elliptic" "crypto/x509" @@ -18,7 +17,6 @@ import ( "os" "os/exec" "path/filepath" - "runtime" "strings" "testing" "time" @@ -40,12 +38,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa cli.writeRecord(recordTypeHandshake, m.marshal()) c.Close() }() - ctx := context.Background() conn := Server(s, serverConfig) - ch, err := conn.readClientHello(ctx) + ch, err := conn.readClientHello() hs := serverHandshakeState{ c: conn, - ctx: ctx, clientHello: ch, } if err == nil { @@ -1425,11 +1421,9 @@ func TestSNIGivenOnFailure(t *testing.T) { c.Close() }() conn := Server(s, serverConfig) - ctx := context.Background() - ch, err := conn.readClientHello(ctx) + ch, err := conn.readClientHello() hs := serverHandshakeState{ c: conn, - ctx: ctx, clientHello: ch, } if err == nil { @@ -1683,46 +1677,6 @@ func TestMultipleCertificates(t *testing.T) { } } -func TestServerHandshakeContextCancellation(t *testing.T) { - c, s := localPipe(t) - clientConfig := testConfig.Clone() - clientErr := make(chan error, 1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - defer close(clientErr) - defer c.Close() - clientHello := &clientHelloMsg{ - vers: VersionTLS10, - random: make([]byte, 32), - cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, - compressionMethods: []uint8{compressionNone}, - } - cli := Client(c, clientConfig) - _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal()) - cancel() - clientErr <- err - }() - conn := Server(s, testConfig) - err := conn.HandshakeContext(ctx) - if err == nil { - t.Fatal("Server handshake did not error when the context was canceled") - } - if err != context.Canceled { - t.Errorf("Unexpected server handshake error: %v", err) - } - if err := <-clientErr; err != nil { - t.Errorf("Unexpected client error: %v", err) - } - if runtime.GOARCH == "wasm" { - t.Skip("conn.Close does not error as expected when called multiple times on WASM") - } - err = conn.Close() - if err == nil { - t.Error("Server connection was not closed when the context was canceled") - } -} - func TestAESCipherReordering(t *testing.T) { currentAESSupport := hasAESGCMHardwareSupport defer func() { hasAESGCMHardwareSupport = currentAESSupport; initDefaultCipherSuites() }() diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index c7837d2955d..c2c288aed43 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -6,7 +6,6 @@ package tls import ( "bytes" - "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -24,7 +23,6 @@ const maxClientPSKIdentities = 5 type serverHandshakeStateTLS13 struct { c *Conn - ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg sentDummyCCS bool @@ -376,7 +374,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error { return c.sendAlert(alertMissingExtension) } - certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) + certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) diff --git a/src/crypto/tls/tls.go b/src/crypto/tls/tls.go index 19884f96e7d..a389873d32e 100644 --- a/src/crypto/tls/tls.go +++ b/src/crypto/tls/tls.go @@ -25,6 +25,7 @@ import ( "net" "os" "strings" + "time" ) // Server returns a new TLS server side connection @@ -115,16 +116,28 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (* } func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - if netDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) - defer cancel() - } + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := netDialer.Timeout if !netDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) - defer cancel() + deadlineTimeout := time.Until(netDialer.Deadline) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + // hsErrCh is non-nil if we might not wait for Handshake to complete. + var hsErrCh chan error + if timeout != 0 || ctx.Done() != nil { + hsErrCh = make(chan error, 2) + } + if timeout != 0 { + timer := time.AfterFunc(timeout, func() { + hsErrCh <- timeoutError{} + }) + defer timer.Stop() } rawConn, err := netDialer.DialContext(ctx, network, addr) @@ -151,10 +164,34 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf } conn := Client(rawConn, config) - if err := conn.HandshakeContext(ctx); err != nil { + + if hsErrCh == nil { + err = conn.Handshake() + } else { + go func() { + hsErrCh <- conn.Handshake() + }() + + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-hsErrCh: + if err != nil { + // If the error was due to the context + // closing, prefer the context's error, rather + // than some random network teardown error. + if e := ctx.Err(); e != nil { + err = e + } + } + } + } + + if err != nil { rawConn.Close() return nil, err } + return conn, nil } diff --git a/src/net/http/server.go b/src/net/http/server.go index 102e893d5f1..ad99741177d 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -1837,7 +1837,7 @@ func (c *conn) serve(ctx context.Context) { if d := c.server.WriteTimeout; d != 0 { c.rwc.SetWriteDeadline(time.Now().Add(d)) } - if err := tlsConn.HandshakeContext(ctx); err != nil { + if err := tlsConn.Handshake(); err != nil { // If the handshake failed due to the client not speaking // TLS, assume they're speaking plaintext HTTP and write a // 400 response on the TLS conn's underlying net.Conn. diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 6358c3897ec..0aa48273dd3 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1505,7 +1505,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. -func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error { +func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error { // Initiate TLS and check remote host name against certificate. cfg := cloneTLSConfig(pconn.t.TLSClientConfig) if cfg.ServerName == "" { @@ -1527,7 +1527,7 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - err := tlsConn.HandshakeContext(ctx) + err := tlsConn.Handshake() if timer != nil { timer.Stop() } @@ -1583,7 +1583,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - if err := tc.HandshakeContext(ctx); err != nil { + if err := tc.Handshake(); err != nil { go pconn.conn.Close() if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) @@ -1607,7 +1607,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { return nil, wrapErr(err) } - if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { + if err = pconn.addTLS(firstTLSHost, trace); err != nil { return nil, wrapErr(err) } } @@ -1721,7 +1721,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if cm.proxyURL != nil && cm.targetScheme == "https" { - if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { + if err := pconn.addTLS(cm.tlsHost(), trace); err != nil { return nil, err } } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 7f6e0938c20..ba85a61683a 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -3734,7 +3734,7 @@ func TestTransportDialTLSContext(t *testing.T) { if err != nil { return nil, err } - return c, c.HandshakeContext(ctx) + return c, c.Handshake() } req, err := NewRequest("GET", ts.URL, nil)