diff --git a/doc/go1.16.html b/doc/go1.16.html index 6c4d076d502..bb920a0cb8a 100644 --- a/doc/go1.16.html +++ b/doc/go1.16.html @@ -271,6 +271,21 @@ Do not send CLs removing the interior tags from such phrases. indefinitely.

+

+ (*Conn).HandshakeContext was added to + allow the user to control cancellation of an in-progress TLS Handshake. + The context provided is propagated into the + ClientHelloInfo + and CertificateRequestInfo + structs and accessible through the new + (*ClientHelloInfo).Context + and + + (*CertificateRequestInfo).Context + methods respectively. Canceling the context after the handshake has finished + has no effect. +

+

crypto/x509

@@ -405,6 +420,13 @@ Do not send CLs removing the interior tags from such phrases. Cookies set with SameSiteDefaultMode now behave according to the current spec (no attribute is set) instead of generating a SameSite key without a value.

+ +

+ The net/http package now uses the new + (*tls.Conn).HandshakeContext + with the Request context + when performing TLS handshakes in the client or server. +

diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index 86dc0dd3b2e..1370d26fe2c 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -7,6 +7,7 @@ package tls import ( "bytes" "container/list" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -444,6 +445,16 @@ type ClientHelloInfo struct { // config is embedded by the GetCertificate or GetConfigForClient caller, // for use with SupportsCertificate. config *Config + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *ClientHelloInfo) Context() context.Context { + return c.ctx } // CertificateRequestInfo contains information from a server's @@ -462,6 +473,16 @@ type CertificateRequestInfo struct { // Version is the TLS version that was negotiated for this connection. Version uint16 + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *CertificateRequestInfo) Context() context.Context { + return c.ctx } // RenegotiationSupport enumerates the different levels of support for TLS diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index b9a1095862a..2f5d4303c25 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -8,6 +8,7 @@ package tls import ( "bytes" + "context" "crypto/cipher" "crypto/subtle" "crypto/x509" @@ -26,7 +27,7 @@ type Conn struct { // constant conn net.Conn isClient bool - handshakeFn func() error // (*Conn).clientHandshake or serverHandshake + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake // handshakeStatus is 1 if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -1192,7 +1193,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr @@ -1375,8 +1376,61 @@ func (c *Conn) closeNotify() error { // first Read or Write will call it automatically. // // For control over canceling or setting a timeout on a handshake, use -// the Dialer's DialContext method. +// HandshakeContext or the Dialer's DialContext method instead. func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -1390,7 +1444,7 @@ func (c *Conn) Handshake() error { c.in.Lock() defer c.in.Unlock() - c.handshakeErr = c.handshakeFn() + c.handshakeErr = c.handshakeFn(handshakeCtx) if c.handshakeErr == nil { c.handshakes++ } else { diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 46b0a770d53..d09a8c8ccfd 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -23,6 +24,7 @@ import ( type clientHandshakeState struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg suite *cipherSuite @@ -133,7 +135,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { return hello, params, nil } -func (c *Conn) clientHandshake() (err error) { +func (c *Conn) clientHandshake(ctx context.Context) (err error) { if c.config == nil { c.config = defaultConfig() } @@ -197,6 +199,7 @@ func (c *Conn) clientHandshake() (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, ecdheParams: ecdheParams, @@ -211,6 +214,7 @@ func (c *Conn) clientHandshake() (err error) { hs := &clientHandshakeState{ c: c, + ctx: ctx, serverHello: serverHello, hello: hello, session: session, @@ -539,7 +543,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { certRequested = true hs.finishedHash.Write(certReq.marshal()) - cri := certificateRequestInfoFromMsg(c.vers, certReq) + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { c.sendAlert(alertInternalError) return err @@ -879,10 +883,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS // <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { cri := &CertificateRequestInfo{ AcceptableCAs: certReq.certificateAuthorities, Version: vers, + ctx: ctx, } var rsaAvail, ecAvail bool diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index 12b0254123e..8889e2c8c33 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -20,6 +21,7 @@ import ( "os/exec" "path/filepath" "reflect" + "runtime" "strconv" "strings" "testing" @@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) } } + +func TestClientHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + serverConfig := testConfig.Clone() + serverErr := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + defer close(serverErr) + defer s.Close() + conn := Server(s, serverConfig) + _, err := conn.readClientHello(ctx) + cancel() + serverErr <- err + }() + cli := Client(c, testConfig) + err := cli.HandshakeContext(ctx) + if err == nil { + t.Fatal("Client handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected client handshake error: %v", err) + } + if err := <-serverErr; err != nil { + t.Errorf("Unexpected server error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = cli.Close() + if err == nil { + t.Error("Client connection was not closed when the context was canceled") + } +} diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index 9c61105cf73..0e4b3800352 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -17,6 +18,7 @@ import ( type clientHandshakeStateTLS13 struct { c *Conn + ctx context.Context serverHello *serverHelloMsg hello *clientHelloMsg ecdheParams ecdheParameters @@ -549,6 +551,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { AcceptableCAs: hs.certReq.certificateAuthorities, SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, Version: c.vers, + ctx: hs.ctx, }) if err != nil { return err diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index 16d3e643f0b..1fe026ae0e0 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -5,6 +5,7 @@ package tls import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -22,6 +23,7 @@ import ( // It's discarded once the handshake has completed. type serverHandshakeState struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg suite *cipherSuite @@ -36,8 +38,8 @@ type serverHandshakeState struct { } // serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake() error { - clientHello, err := c.readClientHello() +func (c *Conn) serverHandshake(ctx context.Context) error { + clientHello, err := c.readClientHello(ctx) if err != nil { return err } @@ -45,6 +47,7 @@ func (c *Conn) serverHandshake() error { if c.vers == VersionTLS13 { hs := serverHandshakeStateTLS13{ c: c, + ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -52,6 +55,7 @@ func (c *Conn) serverHandshake() error { hs := serverHandshakeState{ c: c, + ctx: ctx, clientHello: clientHello, } return hs.handshake() @@ -123,7 +127,7 @@ func (hs *serverHandshakeState) handshake() error { } // readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello() (*clientHelloMsg, error) { +func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { msg, err := c.readHandshake() if err != nil { return nil, err @@ -137,7 +141,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) { var configForClient *Config originalConfig := c.config if c.config.GetConfigForClient != nil { - chi := clientHelloInfo(c, clientHello) + chi := clientHelloInfo(ctx, c, clientHello) if configForClient, err = c.config.GetConfigForClient(chi); err != nil { c.sendAlert(alertInternalError) return nil, err @@ -219,7 +223,7 @@ func (hs *serverHandshakeState) processClientHello() error { } } - hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) @@ -813,7 +817,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { return nil } -func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { +func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { supportedVersions := clientHello.supportedVersions if len(clientHello.supportedVersions) == 0 { supportedVersions = supportedVersionsFromMax(clientHello.vers) @@ -829,5 +833,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { SupportedVersions: supportedVersions, Conn: c.conn, config: c.config, + ctx: ctx, } } diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index a7a53243129..c4416c379a4 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/elliptic" "crypto/x509" @@ -17,6 +18,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" "testing" "time" @@ -36,10 +38,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa cli.writeRecord(recordTypeHandshake, m.marshal()) c.Close() }() + ctx := context.Background() conn := Server(s, serverConfig) - ch, err := conn.readClientHello() + ch, err := conn.readClientHello(ctx) hs := serverHandshakeState{ c: conn, + ctx: ctx, clientHello: ch, } if err == nil { @@ -1418,9 +1422,11 @@ func TestSNIGivenOnFailure(t *testing.T) { c.Close() }() conn := Server(s, serverConfig) - ch, err := conn.readClientHello() + ctx := context.Background() + ch, err := conn.readClientHello(ctx) hs := serverHandshakeState{ c: conn, + ctx: ctx, clientHello: ch, } if err == nil { @@ -1673,3 +1679,43 @@ func TestMultipleCertificates(t *testing.T) { t.Errorf("expected RSA certificate, got %v", got) } } + +func TestServerHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + clientConfig := testConfig.Clone() + clientErr := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + defer close(clientErr) + defer c.Close() + clientHello := &clientHelloMsg{ + vers: VersionTLS10, + random: make([]byte, 32), + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{compressionNone}, + } + cli := Client(c, clientConfig) + _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + cancel() + clientErr <- err + }() + conn := Server(s, testConfig) + err := conn.HandshakeContext(ctx) + if err == nil { + t.Fatal("Server handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected server handshake error: %v", err) + } + if err := <-clientErr; err != nil { + t.Errorf("Unexpected client error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = conn.Close() + if err == nil { + t.Error("Server connection was not closed when the context was canceled") + } +} diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index 92d55e0293a..25c37b92c54 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/hmac" "crypto/rsa" @@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5 type serverHandshakeStateTLS13 struct { c *Conn + ctx context.Context clientHello *clientHelloMsg hello *serverHelloMsg sentDummyCCS bool @@ -361,7 +363,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error { return c.sendAlert(alertMissingExtension) } - certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello)) + certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) if err != nil { if err == errNoCertificates { c.sendAlert(alertUnrecognizedName) diff --git a/src/crypto/tls/tls.go b/src/crypto/tls/tls.go index 454aa0bbbc0..bf577cadeaa 100644 --- a/src/crypto/tls/tls.go +++ b/src/crypto/tls/tls.go @@ -25,7 +25,6 @@ import ( "io/ioutil" "net" "strings" - "time" ) // Server returns a new TLS server side connection @@ -116,28 +115,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (* } func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - // We want the Timeout and Deadline values from dialer to cover the - // whole process: TCP connection and TLS handshake. This means that we - // also need to start our own timers now. - timeout := netDialer.Timeout + if netDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) + defer cancel() + } if !netDialer.Deadline.IsZero() { - deadlineTimeout := time.Until(netDialer.Deadline) - if timeout == 0 || deadlineTimeout < timeout { - timeout = deadlineTimeout - } - } - - // hsErrCh is non-nil if we might not wait for Handshake to complete. - var hsErrCh chan error - if timeout != 0 || ctx.Done() != nil { - hsErrCh = make(chan error, 2) - } - if timeout != 0 { - timer := time.AfterFunc(timeout, func() { - hsErrCh <- timeoutError{} - }) - defer timer.Stop() + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) + defer cancel() } rawConn, err := netDialer.DialContext(ctx, network, addr) @@ -164,34 +151,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf } conn := Client(rawConn, config) - - if hsErrCh == nil { - err = conn.Handshake() - } else { - go func() { - hsErrCh <- conn.Handshake() - }() - - select { - case <-ctx.Done(): - err = ctx.Err() - case err = <-hsErrCh: - if err != nil { - // If the error was due to the context - // closing, prefer the context's error, rather - // than some random network teardown error. - if e := ctx.Err(); e != nil { - err = e - } - } - } - } - - if err != nil { + if err := conn.HandshakeContext(ctx); err != nil { rawConn.Close() return nil, err } - return conn, nil } diff --git a/src/net/http/server.go b/src/net/http/server.go index 4776d960e57..6c7d2817051 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -1831,7 +1831,7 @@ func (c *conn) serve(ctx context.Context) { if d := c.server.WriteTimeout; d != 0 { c.rwc.SetWriteDeadline(time.Now().Add(d)) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { // If the handshake failed due to the client not speaking // TLS, assume they're speaking plaintext HTTP and write a // 400 response on the TLS conn's underlying net.Conn. diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 29d7434f2a8..65ba6644154 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1502,7 +1502,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. -func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error { +func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error { // Initiate TLS and check remote host name against certificate. cfg := cloneTLSConfig(pconn.t.TLSClientConfig) if cfg.ServerName == "" { @@ -1524,7 +1524,7 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - err := tlsConn.Handshake() + err := tlsConn.HandshakeContext(ctx) if timer != nil { timer.Stop() } @@ -1580,7 +1580,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - if err := tc.Handshake(); err != nil { + if err := tc.HandshakeContext(ctx); err != nil { go pconn.conn.Close() if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) @@ -1604,7 +1604,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { return nil, wrapErr(err) } - if err = pconn.addTLS(firstTLSHost, trace); err != nil { + if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { return nil, wrapErr(err) } } @@ -1718,7 +1718,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if cm.proxyURL != nil && cm.targetScheme == "https" { - if err := pconn.addTLS(cm.tlsHost(), trace); err != nil { + if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { return nil, err } } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index e69133e7868..9086507d576 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -3735,7 +3735,7 @@ func TestTransportDialTLSContext(t *testing.T) { if err != nil { return nil, err } - return c, c.Handshake() + return c, c.HandshakeContext(ctx) } req, err := NewRequest("GET", ts.URL, nil)