diff --git a/src/net/smtp/smtp.go b/src/net/smtp/smtp.go index 87dea442c4..c9b3c07aa8 100644 --- a/src/net/smtp/smtp.go +++ b/src/net/smtp/smtp.go @@ -157,6 +157,17 @@ func (c *Client) StartTLS(config *tls.Config) error { return c.ehlo() } +// TLSConnectionState returns the client's TLS connection state. +// The return values are their zero values if StartTLS did +// not succeed. +func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { + tc, ok := c.conn.(*tls.Conn) + if !ok { + return + } + return tc.ConnectionState(), true +} + // Verify checks the validity of an email address on the server. // If Verify returns nil, the address is valid. A non-nil return // does not necessarily indicate an invalid address. Many servers diff --git a/src/net/smtp/smtp_test.go b/src/net/smtp/smtp_test.go index 5c659e8a09..3ae0d5bf1d 100644 --- a/src/net/smtp/smtp_test.go +++ b/src/net/smtp/smtp_test.go @@ -571,6 +571,50 @@ func TestTLSClient(t *testing.T) { } } +func TestTLSConnState(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer c.Close() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer c.Quit() + cfg := &tls.Config{ServerName: "example.com"} + testHookStartTLS(cfg) // set the RootCAs + if err := c.StartTLS(cfg); err != nil { + t.Errorf("StartTLS: %v", err) + return + } + cs, ok := c.TLSConnectionState() + if !ok { + t.Errorf("TLSConnectionState returned ok == false; want true") + return + } + if cs.Version == 0 || !cs.HandshakeComplete { + t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs) + } + }() + <-clientDone + <-serverDone +} + func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil {