diff --git a/doc/go1.16.html b/doc/go1.16.html
index 6c4d076d502..bb920a0cb8a 100644
--- a/doc/go1.16.html
+++ b/doc/go1.16.html
@@ -271,6 +271,21 @@ Do not send CLs removing the interior tags from such phrases.
indefinitely.
+
+ (*Conn).HandshakeContext was added to
+ allow the user to control cancellation of an in-progress TLS Handshake.
+ The context provided is propagated into the
+ ClientHelloInfo
+ and CertificateRequestInfo
+ structs and accessible through the new
+ (*ClientHelloInfo).Context
+ and
+
+ (*CertificateRequestInfo).Context
+ methods respectively. Canceling the context after the handshake has finished
+ has no effect.
+
+
@@ -405,6 +420,13 @@ Do not send CLs removing the interior tags from such phrases.
Cookies set with SameSiteDefaultMode
now behave according to the current
spec (no attribute is set) instead of generating a SameSite key without a value.
+
+
+ The net/http
package now uses the new
+ (*tls.Conn).HandshakeContext
+ with the Request
context
+ when performing TLS handshakes in the client or server.
+
diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go
index 86dc0dd3b2e..1370d26fe2c 100644
--- a/src/crypto/tls/common.go
+++ b/src/crypto/tls/common.go
@@ -7,6 +7,7 @@ package tls
import (
"bytes"
"container/list"
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@@ -444,6 +445,16 @@ type ClientHelloInfo struct {
// config is embedded by the GetCertificate or GetConfigForClient caller,
// for use with SupportsCertificate.
config *Config
+
+ // ctx is the context of the handshake that is in progress.
+ ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *ClientHelloInfo) Context() context.Context {
+ return c.ctx
}
// CertificateRequestInfo contains information from a server's
@@ -462,6 +473,16 @@ type CertificateRequestInfo struct {
// Version is the TLS version that was negotiated for this connection.
Version uint16
+
+ // ctx is the context of the handshake that is in progress.
+ ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *CertificateRequestInfo) Context() context.Context {
+ return c.ctx
}
// RenegotiationSupport enumerates the different levels of support for TLS
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index b9a1095862a..2f5d4303c25 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -8,6 +8,7 @@ package tls
import (
"bytes"
+ "context"
"crypto/cipher"
"crypto/subtle"
"crypto/x509"
@@ -26,7 +27,7 @@ type Conn struct {
// constant
conn net.Conn
isClient bool
- handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
+ handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
// handshakeStatus is 1 if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
@@ -1192,7 +1193,7 @@ func (c *Conn) handleRenegotiation() error {
defer c.handshakeMutex.Unlock()
atomic.StoreUint32(&c.handshakeStatus, 0)
- if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
+ if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
@@ -1375,8 +1376,61 @@ func (c *Conn) closeNotify() error {
// first Read or Write will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
-// the Dialer's DialContext method.
+// HandshakeContext or the Dialer's DialContext method instead.
func (c *Conn) Handshake() error {
+ return c.HandshakeContext(context.Background())
+}
+
+// HandshakeContext runs the client or server handshake
+// protocol if it has not yet been run.
+//
+// The provided Context must be non-nil. If the context is canceled before
+// the handshake is complete, the handshake is interrupted and an error is returned.
+// Once the handshake has completed, cancellation of the context will not affect the
+// connection.
+//
+// Most uses of this package need not call HandshakeContext explicitly: the
+// first Read or Write will call it automatically.
+func (c *Conn) HandshakeContext(ctx context.Context) error {
+ // Delegate to unexported method for named return
+ // without confusing documented signature.
+ return c.handshakeContext(ctx)
+}
+
+func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
+ handshakeCtx, cancel := context.WithCancel(ctx)
+ // Note: defer this before starting the "interrupter" goroutine
+ // so that we can tell the difference between the input being canceled and
+ // this cancellation. In the former case, we need to close the connection.
+ defer cancel()
+
+ // Start the "interrupter" goroutine, if this context might be canceled.
+ // (The background context cannot).
+ //
+ // The interrupter goroutine waits for the input context to be done and
+ // closes the connection if this happens before the function returns.
+ if ctx.Done() != nil {
+ done := make(chan struct{})
+ interruptRes := make(chan error, 1)
+ defer func() {
+ close(done)
+ if ctxErr := <-interruptRes; ctxErr != nil {
+ // Return context error to user.
+ ret = ctxErr
+ }
+ }()
+ go func() {
+ select {
+ case <-handshakeCtx.Done():
+ // Close the connection, discarding the error
+ _ = c.conn.Close()
+ interruptRes <- handshakeCtx.Err()
+ case <-done:
+ interruptRes <- nil
+ }
+ }()
+ }
+
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
@@ -1390,7 +1444,7 @@ func (c *Conn) Handshake() error {
c.in.Lock()
defer c.in.Unlock()
- c.handshakeErr = c.handshakeFn()
+ c.handshakeErr = c.handshakeFn(handshakeCtx)
if c.handshakeErr == nil {
c.handshakes++
} else {
diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go
index 46b0a770d53..d09a8c8ccfd 100644
--- a/src/crypto/tls/handshake_client.go
+++ b/src/crypto/tls/handshake_client.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@@ -23,6 +24,7 @@ import (
type clientHandshakeState struct {
c *Conn
+ ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
suite *cipherSuite
@@ -133,7 +135,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
return hello, params, nil
}
-func (c *Conn) clientHandshake() (err error) {
+func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if c.config == nil {
c.config = defaultConfig()
}
@@ -197,6 +199,7 @@ func (c *Conn) clientHandshake() (err error) {
if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{
c: c,
+ ctx: ctx,
serverHello: serverHello,
hello: hello,
ecdheParams: ecdheParams,
@@ -211,6 +214,7 @@ func (c *Conn) clientHandshake() (err error) {
hs := &clientHandshakeState{
c: c,
+ ctx: ctx,
serverHello: serverHello,
hello: hello,
session: session,
@@ -539,7 +543,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
certRequested = true
hs.finishedHash.Write(certReq.marshal())
- cri := certificateRequestInfoFromMsg(c.vers, certReq)
+ cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
if chainToSend, err = c.getClientCertificate(cri); err != nil {
c.sendAlert(alertInternalError)
return err
@@ -879,10 +883,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
// <= 1.2 CertificateRequest, making an effort to fill in missing information.
-func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
+func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
cri := &CertificateRequestInfo{
AcceptableCAs: certReq.certificateAuthorities,
Version: vers,
+ ctx: ctx,
}
var rsaAvail, ecAvail bool
diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go
index 12b0254123e..8889e2c8c33 100644
--- a/src/crypto/tls/handshake_client_test.go
+++ b/src/crypto/tls/handshake_client_test.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
@@ -20,6 +21,7 @@ import (
"os/exec"
"path/filepath"
"reflect"
+ "runtime"
"strconv"
"strings"
"testing"
@@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
}
}
+
+func TestClientHandshakeContextCancellation(t *testing.T) {
+ c, s := localPipe(t)
+ serverConfig := testConfig.Clone()
+ serverErr := make(chan error, 1)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ go func() {
+ defer close(serverErr)
+ defer s.Close()
+ conn := Server(s, serverConfig)
+ _, err := conn.readClientHello(ctx)
+ cancel()
+ serverErr <- err
+ }()
+ cli := Client(c, testConfig)
+ err := cli.HandshakeContext(ctx)
+ if err == nil {
+ t.Fatal("Client handshake did not error when the context was canceled")
+ }
+ if err != context.Canceled {
+ t.Errorf("Unexpected client handshake error: %v", err)
+ }
+ if err := <-serverErr; err != nil {
+ t.Errorf("Unexpected server error: %v", err)
+ }
+ if runtime.GOARCH == "wasm" {
+ t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+ }
+ err = cli.Close()
+ if err == nil {
+ t.Error("Client connection was not closed when the context was canceled")
+ }
+}
diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go
index 9c61105cf73..0e4b3800352 100644
--- a/src/crypto/tls/handshake_client_tls13.go
+++ b/src/crypto/tls/handshake_client_tls13.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/hmac"
"crypto/rsa"
@@ -17,6 +18,7 @@ import (
type clientHandshakeStateTLS13 struct {
c *Conn
+ ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
ecdheParams ecdheParameters
@@ -549,6 +551,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers,
+ ctx: hs.ctx,
})
if err != nil {
return err
diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go
index 16d3e643f0b..1fe026ae0e0 100644
--- a/src/crypto/tls/handshake_server.go
+++ b/src/crypto/tls/handshake_server.go
@@ -5,6 +5,7 @@
package tls
import (
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@@ -22,6 +23,7 @@ import (
// It's discarded once the handshake has completed.
type serverHandshakeState struct {
c *Conn
+ ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
suite *cipherSuite
@@ -36,8 +38,8 @@ type serverHandshakeState struct {
}
// serverHandshake performs a TLS handshake as a server.
-func (c *Conn) serverHandshake() error {
- clientHello, err := c.readClientHello()
+func (c *Conn) serverHandshake(ctx context.Context) error {
+ clientHello, err := c.readClientHello(ctx)
if err != nil {
return err
}
@@ -45,6 +47,7 @@ func (c *Conn) serverHandshake() error {
if c.vers == VersionTLS13 {
hs := serverHandshakeStateTLS13{
c: c,
+ ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
@@ -52,6 +55,7 @@ func (c *Conn) serverHandshake() error {
hs := serverHandshakeState{
c: c,
+ ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
@@ -123,7 +127,7 @@ func (hs *serverHandshakeState) handshake() error {
}
// readClientHello reads a ClientHello message and selects the protocol version.
-func (c *Conn) readClientHello() (*clientHelloMsg, error) {
+func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
msg, err := c.readHandshake()
if err != nil {
return nil, err
@@ -137,7 +141,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) {
var configForClient *Config
originalConfig := c.config
if c.config.GetConfigForClient != nil {
- chi := clientHelloInfo(c, clientHello)
+ chi := clientHelloInfo(ctx, c, clientHello)
if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
c.sendAlert(alertInternalError)
return nil, err
@@ -219,7 +223,7 @@ func (hs *serverHandshakeState) processClientHello() error {
}
}
- hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+ hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
@@ -813,7 +817,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
return nil
}
-func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
+func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
supportedVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
supportedVersions = supportedVersionsFromMax(clientHello.vers)
@@ -829,5 +833,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
SupportedVersions: supportedVersions,
Conn: c.conn,
config: c.config,
+ ctx: ctx,
}
}
diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go
index a7a53243129..c4416c379a4 100644
--- a/src/crypto/tls/handshake_server_test.go
+++ b/src/crypto/tls/handshake_server_test.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/elliptic"
"crypto/x509"
@@ -17,6 +18,7 @@ import (
"os"
"os/exec"
"path/filepath"
+ "runtime"
"strings"
"testing"
"time"
@@ -36,10 +38,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
cli.writeRecord(recordTypeHandshake, m.marshal())
c.Close()
}()
+ ctx := context.Background()
conn := Server(s, serverConfig)
- ch, err := conn.readClientHello()
+ ch, err := conn.readClientHello(ctx)
hs := serverHandshakeState{
c: conn,
+ ctx: ctx,
clientHello: ch,
}
if err == nil {
@@ -1418,9 +1422,11 @@ func TestSNIGivenOnFailure(t *testing.T) {
c.Close()
}()
conn := Server(s, serverConfig)
- ch, err := conn.readClientHello()
+ ctx := context.Background()
+ ch, err := conn.readClientHello(ctx)
hs := serverHandshakeState{
c: conn,
+ ctx: ctx,
clientHello: ch,
}
if err == nil {
@@ -1673,3 +1679,43 @@ func TestMultipleCertificates(t *testing.T) {
t.Errorf("expected RSA certificate, got %v", got)
}
}
+
+func TestServerHandshakeContextCancellation(t *testing.T) {
+ c, s := localPipe(t)
+ clientConfig := testConfig.Clone()
+ clientErr := make(chan error, 1)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ go func() {
+ defer close(clientErr)
+ defer c.Close()
+ clientHello := &clientHelloMsg{
+ vers: VersionTLS10,
+ random: make([]byte, 32),
+ cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
+ compressionMethods: []uint8{compressionNone},
+ }
+ cli := Client(c, clientConfig)
+ _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+ cancel()
+ clientErr <- err
+ }()
+ conn := Server(s, testConfig)
+ err := conn.HandshakeContext(ctx)
+ if err == nil {
+ t.Fatal("Server handshake did not error when the context was canceled")
+ }
+ if err != context.Canceled {
+ t.Errorf("Unexpected server handshake error: %v", err)
+ }
+ if err := <-clientErr; err != nil {
+ t.Errorf("Unexpected client error: %v", err)
+ }
+ if runtime.GOARCH == "wasm" {
+ t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+ }
+ err = conn.Close()
+ if err == nil {
+ t.Error("Server connection was not closed when the context was canceled")
+ }
+}
diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go
index 92d55e0293a..25c37b92c54 100644
--- a/src/crypto/tls/handshake_server_tls13.go
+++ b/src/crypto/tls/handshake_server_tls13.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/hmac"
"crypto/rsa"
@@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5
type serverHandshakeStateTLS13 struct {
c *Conn
+ ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
sentDummyCCS bool
@@ -361,7 +363,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
return c.sendAlert(alertMissingExtension)
}
- certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+ certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
diff --git a/src/crypto/tls/tls.go b/src/crypto/tls/tls.go
index 454aa0bbbc0..bf577cadeaa 100644
--- a/src/crypto/tls/tls.go
+++ b/src/crypto/tls/tls.go
@@ -25,7 +25,6 @@ import (
"io/ioutil"
"net"
"strings"
- "time"
)
// Server returns a new TLS server side connection
@@ -116,28 +115,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
- // We want the Timeout and Deadline values from dialer to cover the
- // whole process: TCP connection and TLS handshake. This means that we
- // also need to start our own timers now.
- timeout := netDialer.Timeout
+ if netDialer.Timeout != 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
+ defer cancel()
+ }
if !netDialer.Deadline.IsZero() {
- deadlineTimeout := time.Until(netDialer.Deadline)
- if timeout == 0 || deadlineTimeout < timeout {
- timeout = deadlineTimeout
- }
- }
-
- // hsErrCh is non-nil if we might not wait for Handshake to complete.
- var hsErrCh chan error
- if timeout != 0 || ctx.Done() != nil {
- hsErrCh = make(chan error, 2)
- }
- if timeout != 0 {
- timer := time.AfterFunc(timeout, func() {
- hsErrCh <- timeoutError{}
- })
- defer timer.Stop()
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
+ defer cancel()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
@@ -164,34 +151,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
}
conn := Client(rawConn, config)
-
- if hsErrCh == nil {
- err = conn.Handshake()
- } else {
- go func() {
- hsErrCh <- conn.Handshake()
- }()
-
- select {
- case <-ctx.Done():
- err = ctx.Err()
- case err = <-hsErrCh:
- if err != nil {
- // If the error was due to the context
- // closing, prefer the context's error, rather
- // than some random network teardown error.
- if e := ctx.Err(); e != nil {
- err = e
- }
- }
- }
- }
-
- if err != nil {
+ if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
}
-
return conn, nil
}
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 4776d960e57..6c7d2817051 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -1831,7 +1831,7 @@ func (c *conn) serve(ctx context.Context) {
if d := c.server.WriteTimeout; d != 0 {
c.rwc.SetWriteDeadline(time.Now().Add(d))
}
- if err := tlsConn.Handshake(); err != nil {
+ if err := tlsConn.HandshakeContext(ctx); err != nil {
// If the handshake failed due to the client not speaking
// TLS, assume they're speaking plaintext HTTP and write a
// 400 response on the TLS conn's underlying net.Conn.
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index 29d7434f2a8..65ba6644154 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -1502,7 +1502,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) {
// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
// tunnel, this function establishes a nested TLS session inside the encrypted channel.
// The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
-func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error {
+func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error {
// Initiate TLS and check remote host name against certificate.
cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
if cfg.ServerName == "" {
@@ -1524,7 +1524,7 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
- err := tlsConn.Handshake()
+ err := tlsConn.HandshakeContext(ctx)
if timer != nil {
timer.Stop()
}
@@ -1580,7 +1580,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
- if err := tc.Handshake(); err != nil {
+ if err := tc.HandshakeContext(ctx); err != nil {
go pconn.conn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
@@ -1604,7 +1604,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
return nil, wrapErr(err)
}
- if err = pconn.addTLS(firstTLSHost, trace); err != nil {
+ if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil {
return nil, wrapErr(err)
}
}
@@ -1718,7 +1718,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
}
if cm.proxyURL != nil && cm.targetScheme == "https" {
- if err := pconn.addTLS(cm.tlsHost(), trace); err != nil {
+ if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil {
return nil, err
}
}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index e69133e7868..9086507d576 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -3735,7 +3735,7 @@ func TestTransportDialTLSContext(t *testing.T) {
if err != nil {
return nil, err
}
- return c, c.Handshake()
+ return c, c.HandshakeContext(ctx)
}
req, err := NewRequest("GET", ts.URL, nil)