diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index ea299930a9..a44d56dcb1 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -1235,6 +1235,10 @@ func (c *Conn) Handshake() error { } if c.handshakeErr == nil { c.handshakes++ + } else { + // If an error occurred during the hadshake try to flush the + // alert that might be left in the buffer. + c.flush() } return c.handshakeErr } diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index a5491bcdf3..c87ad5babd 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "io" + "math/big" "net" "os" "os/exec" @@ -1123,3 +1124,47 @@ func TestBuffering(t *testing.T) { t.Errorf("expected server handshake to complete with only two writes, but saw %d", n) } } + +func TestAlertFlushing(t *testing.T) { + c, s := net.Pipe() + done := make(chan bool) + + clientWCC := &writeCountingConn{Conn: c} + serverWCC := &writeCountingConn{Conn: s} + + serverConfig := testConfig.Clone() + + // Cause a signature-time error + brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey} + brokenKey.D = big.NewInt(42) + serverConfig.Certificates = []Certificate{{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: &brokenKey, + }} + + go func() { + Server(serverWCC, serverConfig).Handshake() + serverWCC.Close() + done <- true + }() + + err := Client(clientWCC, testConfig).Handshake() + if err == nil { + t.Fatal("client unexpectedly returned no error") + } + + const expectedError = "remote error: tls: handshake failure" + if e := err.Error(); !strings.Contains(e, expectedError) { + t.Fatalf("expected to find %q in error but error was %q", expectedError, e) + } + clientWCC.Close() + <-done + + if n := clientWCC.numWrites; n != 1 { + t.Errorf("expected client handshake to complete with one write, but saw %d", n) + } + + if n := serverWCC.numWrites; n != 1 { + t.Errorf("expected server handshake to complete with one write, but saw %d", n) + } +}