diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index fac4b91473..eeab030eca 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -162,9 +162,22 @@ type halfConn struct { trafficSecret []byte // current TLS 1.3 traffic secret } +type permamentError struct { + err net.Error +} + +func (e *permamentError) Error() string { return e.err.Error() } +func (e *permamentError) Unwrap() error { return e.err } +func (e *permamentError) Timeout() bool { return e.err.Timeout() } +func (e *permamentError) Temporary() bool { return false } + func (hc *halfConn) setErrorLocked(err error) error { - hc.err = err - return err + if e, ok := err.(net.Error); ok { + hc.err = &permamentError{err: e} + } else { + hc.err = err + } + return hc.err } // prepareCipherSpec sets the encryption and MAC states diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index 953ca0026e..61f0ca2bf7 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -355,7 +355,8 @@ func TestAlertForwarding(t *testing.T) { err := Server(s, testConfig).Handshake() s.Close() - if e, ok := err.(*net.OpError); !ok || e.Err != error(alertUnknownCA) { + var opErr *net.OpError + if !errors.As(err, &opErr) || opErr.Err != error(alertUnknownCA) { t.Errorf("Got error: %s; expected: %s", err, error(alertUnknownCA)) } } diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go index 89fac607e1..42fd5e1b8c 100644 --- a/src/crypto/tls/tls_test.go +++ b/src/crypto/tls/tls_test.go @@ -201,6 +201,77 @@ func TestDialTimeout(t *testing.T) { } } +func TestDeadlineOnWrite(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + ln := newLocalListener(t) + defer ln.Close() + + srvCh := make(chan *Conn, 1) + + go func() { + sconn, err := ln.Accept() + if err != nil { + srvCh <- nil + return + } + srv := Server(sconn, testConfig.Clone()) + if err := srv.Handshake(); err != nil { + srvCh <- nil + return + } + srvCh <- srv + }() + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = VersionTLS12 + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + srv := <-srvCh + if srv == nil { + t.Error(err) + } + + // Make sure the client/server is setup correctly and is able to do a typical Write/Read + buf := make([]byte, 6) + if _, err := srv.Write([]byte("foobar")); err != nil { + t.Errorf("Write err: %v", err) + } + if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" { + t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) + } + + // Set a deadline which should cause Write to timeout + if err = srv.SetDeadline(time.Now()); err != nil { + t.Fatalf("SetDeadline(time.Now()) err: %v", err) + } + if _, err = srv.Write([]byte("should fail")); err == nil { + t.Fatal("Write should have timed out") + } + + // Clear deadline and make sure it still times out + if err = srv.SetDeadline(time.Time{}); err != nil { + t.Fatalf("SetDeadline(time.Time{}) err: %v", err) + } + if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil { + t.Fatal("Write which previously failed should still time out") + } + + // Verify the error + if ne := err.(net.Error); ne.Temporary() != false { + t.Error("Write timed out but incorrectly classified the error as Temporary") + } + if !isTimeoutError(err) { + t.Error("Write timed out but did not classify the error as a Timeout") + } +} + func isTimeoutError(err error) bool { if ne, ok := err.(net.Error); ok { return ne.Timeout()