diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index 693f9686a79..f499cf3970d 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -20,6 +21,7 @@ import ( "os/exec" "path/filepath" "reflect" + "runtime" "strconv" "strings" "testing" @@ -2511,3 +2513,35 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) } } + +// TestClientHandshakeContextCancellation tests that cancelling +// the context given to the client side conn.HandshakeContext +// interrupts the in-progress handshake. +func TestClientHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + ctx, cancel := context.WithCancel(context.Background()) + unblockServer := make(chan struct{}) + defer close(unblockServer) + go func() { + cancel() + <-unblockServer + _ = s.Close() + }() + cli := Client(c, testConfig) + // Initiates client side handshake, which will block until the client hello is read + // by the server, unless the cancellation works. + err := cli.HandshakeContext(ctx) + if err == nil { + t.Fatal("Client handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected client handshake error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = cli.Close() + if err == nil { + t.Error("Client connection was not closed when the context was canceled") + } +} diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index 756d288cb35..4c2d319fb1a 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -1946,27 +1946,22 @@ func TestAESCipherReordering13(t *testing.T) { } } +// TestServerHandshakeContextCancellation tests that cancelling +// the context given to the server side conn.HandshakeContext +// interrupts the in-progress handshake. func TestServerHandshakeContextCancellation(t *testing.T) { c, s := localPipe(t) - clientConfig := testConfig.Clone() - clientErr := make(chan error, 1) ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + unblockClient := make(chan struct{}) + defer close(unblockClient) go func() { - defer close(clientErr) - defer c.Close() - clientHello := &clientHelloMsg{ - vers: VersionTLS10, - random: make([]byte, 32), - cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, - compressionMethods: []uint8{compressionNone}, - } - cli := Client(c, clientConfig) - _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal()) cancel() - clientErr <- err + <-unblockClient + _ = c.Close() }() conn := Server(s, testConfig) + // Initiates server side handshake, which will block until a client hello is read + // unless the cancellation works. err := conn.HandshakeContext(ctx) if err == nil { t.Fatal("Server handshake did not error when the context was canceled") @@ -1974,9 +1969,6 @@ func TestServerHandshakeContextCancellation(t *testing.T) { if err != context.Canceled { t.Errorf("Unexpected server handshake error: %v", err) } - if err := <-clientErr; err != nil { - t.Errorf("Unexpected client error: %v", err) - } if runtime.GOARCH == "wasm" { t.Skip("conn.Close does not error as expected when called multiple times on WASM") }