diff --git a/src/pkg/crypto/tls/ca_set.go b/src/pkg/crypto/tls/ca_set.go index fe2a540f4db..ae00ac55868 100644 --- a/src/pkg/crypto/tls/ca_set.go +++ b/src/pkg/crypto/tls/ca_set.go @@ -16,6 +16,7 @@ type CASet struct { byName map[string][]*x509.Certificate } +// NewCASet returns a new, empty CASet. func NewCASet() *CASet { return &CASet{ make(map[string][]*x509.Certificate), diff --git a/src/pkg/crypto/tls/common.go b/src/pkg/crypto/tls/common.go index a4f2b804f10..4fb17ad3a89 100644 --- a/src/pkg/crypto/tls/common.go +++ b/src/pkg/crypto/tls/common.go @@ -78,6 +78,7 @@ const ( // Rest of these are reserved by the TLS spec ) +// ConnectionState records basic TLS details about the connection. type ConnectionState struct { HandshakeComplete bool CipherSuite uint16 @@ -88,28 +89,65 @@ type ConnectionState struct { // has been passed to a TLS function it must not be modified. type Config struct { // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. Rand io.Reader + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses the system time.Seconds. Time func() int64 - // Certificates contains one or more certificate chains. + + // Certificates contains one or more certificate chains + // to present to the other side of the connection. + // Server configurations must include at least one certificate. Certificates []Certificate - RootCAs *CASet + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *CASet + // NextProtos is a list of supported, application level protocols. // Currently only server-side handling is supported. NextProtos []string + // ServerName is included in the client's handshake to support virtual // hosting. ServerName string - // AuthenticateClient determines if a server will request a certificate + + // AuthenticateClient controls whether a server will request a certificate // from the client. It does not require that the client send a - // certificate nor, if it does, that the certificate is anything more - // than self-signed. + // certificate nor does it require that the certificate sent be + // anything more than self-signed. AuthenticateClient bool } +func (c *Config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *Config) time() int64 { + t := c.Time + if t == nil { + t = time.Seconds + } + return t() +} + +func (c *Config) rootCAs() *CASet { + s := c.RootCAs + if s == nil { + s = defaultRoots() + } + return s +} + +// A Certificate is a chain of one or more certificates, leaf first. type Certificate struct { - // Certificate contains a chain of one or more certificates. Leaf - // certificate first. Certificate [][]byte PrivateKey *rsa.PrivateKey } @@ -143,14 +181,10 @@ func mutualVersion(vers uint16) (uint16, bool) { return vers, true } -// The defaultConfig is used in place of a nil *Config in the TLS server and client. -var varDefaultConfig *Config - -var once sync.Once +var emptyConfig Config func defaultConfig() *Config { - once.Do(initDefaultConfig) - return varDefaultConfig + return &emptyConfig } // Possible certificate files; stop after finding one. @@ -162,7 +196,16 @@ var certFiles = []string{ "/usr/share/curl/curl-ca-bundle.crt", // OS X } -func initDefaultConfig() { +var once sync.Once + +func defaultRoots() *CASet { + once.Do(initDefaultRoots) + return varDefaultRoots +} + +var varDefaultRoots *CASet + +func initDefaultRoots() { roots := NewCASet() for _, file := range certFiles { data, err := ioutil.ReadFile(file) @@ -171,10 +214,5 @@ func initDefaultConfig() { break } } - - varDefaultConfig = &Config{ - Rand: rand.Reader, - Time: time.Seconds, - RootCAs: roots, - } + varDefaultRoots = roots } diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index b6b0e0fad37..4cddba33030 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -30,12 +30,12 @@ func (c *Conn) clientHandshake() os.Error { serverName: c.config.ServerName, } - t := uint32(c.config.Time()) + t := uint32(c.config.time()) hello.random[0] = byte(t >> 24) hello.random[1] = byte(t >> 16) hello.random[2] = byte(t >> 8) hello.random[3] = byte(t) - _, err := io.ReadFull(c.config.Rand, hello.random[4:]) + _, err := io.ReadFull(c.config.rand(), hello.random[4:]) if err != nil { c.sendAlert(alertInternalError) return os.ErrorString("short read from Rand") @@ -217,12 +217,12 @@ func (c *Conn) clientHandshake() os.Error { preMasterSecret := make([]byte, 48) preMasterSecret[0] = byte(hello.vers >> 8) preMasterSecret[1] = byte(hello.vers) - _, err = io.ReadFull(c.config.Rand, preMasterSecret[2:]) + _, err = io.ReadFull(c.config.rand(), preMasterSecret[2:]) if err != nil { return c.sendAlert(alertInternalError) } - ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret) + ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.rand(), pub, preMasterSecret) if err != nil { return c.sendAlert(alertInternalError) } @@ -235,7 +235,7 @@ func (c *Conn) clientHandshake() os.Error { var digest [36]byte copy(digest[0:16], finishedHash.serverMD5.Sum()) copy(digest[16:36], finishedHash.serverSHA1.Sum()) - signed, err := rsa.SignPKCS1v15(c.config.Rand, c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:]) + signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, rsa.HashMD5SHA1, digest[0:]) if err != nil { return c.sendAlert(alertInternalError) } diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go index 22550384610..6db2a6a1bf6 100644 --- a/src/pkg/crypto/tls/handshake_server.go +++ b/src/pkg/crypto/tls/handshake_server.go @@ -84,13 +84,13 @@ func (c *Conn) serverHandshake() os.Error { hello.vers = vers hello.cipherSuite = suite.id - t := uint32(config.Time()) + t := uint32(config.time()) hello.random = make([]byte, 32) hello.random[0] = byte(t >> 24) hello.random[1] = byte(t >> 16) hello.random[2] = byte(t >> 8) hello.random[3] = byte(t) - _, err = io.ReadFull(config.Rand, hello.random[4:]) + _, err = io.ReadFull(config.rand(), hello.random[4:]) if err != nil { return c.sendAlert(alertInternalError) } @@ -209,12 +209,12 @@ func (c *Conn) serverHandshake() os.Error { } preMasterSecret := make([]byte, 48) - _, err = io.ReadFull(config.Rand, preMasterSecret[2:]) + _, err = io.ReadFull(config.rand(), preMasterSecret[2:]) if err != nil { return c.sendAlert(alertInternalError) } - err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) + err = rsa.DecryptPKCS1v15SessionKey(config.rand(), config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret) if err != nil { return c.sendAlert(alertHandshakeFailure) } diff --git a/src/pkg/crypto/tls/tls.go b/src/pkg/crypto/tls/tls.go index 61f0a9702dc..b11d3225daa 100644 --- a/src/pkg/crypto/tls/tls.go +++ b/src/pkg/crypto/tls/tls.go @@ -15,19 +15,31 @@ import ( "strings" ) +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must have +// at least one certificate. func Server(conn net.Conn, config *Config) *Conn { return &Conn{conn: conn, config: config} } +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// Client interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. func Client(conn net.Conn, config *Config) *Conn { return &Conn{conn: conn, config: config, isClient: true} } +// A Listener implements a network listener (net.Listener) for TLS connections. type Listener struct { listener net.Listener config *Config } +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. func (l *Listener) Accept() (c net.Conn, err os.Error) { c, err = l.listener.Accept() if err != nil { @@ -37,8 +49,10 @@ func (l *Listener) Accept() (c net.Conn, err os.Error) { return } +// Close closes the listener. func (l *Listener) Close() os.Error { return l.listener.Close() } +// Addr returns the listener's network address. func (l *Listener) Addr() net.Addr { return l.listener.Addr() } // NewListener creates a Listener which accepts connections from an inner @@ -52,7 +66,11 @@ func NewListener(listener net.Listener, config *Config) (l *Listener) { return } -func Listen(network, laddr string, config *Config) (net.Listener, os.Error) { +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must have +// at least one certificate. +func Listen(network, laddr string, config *Config) (*Listener, os.Error) { if config == nil || len(config.Certificates) == 0 { return nil, os.NewError("tls.Listen: no certificates in configuration") } @@ -63,7 +81,13 @@ func Listen(network, laddr string, config *Config) (net.Listener, os.Error) { return NewListener(l, config), nil } -func Dial(network, laddr, raddr string) (net.Conn, os.Error) { +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, laddr, raddr string, config *Config) (*Conn, os.Error) { c, err := net.Dial(network, laddr, raddr) if err != nil { return nil, err @@ -75,15 +99,21 @@ func Dial(network, laddr, raddr string) (net.Conn, os.Error) { } hostname := raddr[:colonPos] - config := defaultConfig() - config.ServerName = hostname - conn := Client(c, config) - err = conn.Handshake() - if err == nil { - return conn, nil + if config == nil { + config = defaultConfig() } - c.Close() - return nil, err + if config.ServerName != "" { + // Make a copy to avoid polluting argument or default. + c := *config + c.ServerName = hostname + config = &c + } + conn := Client(c, config) + if err = conn.Handshake(); err != nil { + c.Close() + return nil, err + } + return conn, nil } // LoadX509KeyPair reads and parses a public/private key pair from a pair of diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index e902369e7c2..29678ee32ae 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -63,7 +63,7 @@ func send(req *Request) (resp *Response, err os.Error) { return nil, err } } else { // https - conn, err = tls.Dial("tcp", "", addr) + conn, err = tls.Dial("tcp", "", addr, nil) if err != nil { return nil, err } diff --git a/src/pkg/websocket/client.go b/src/pkg/websocket/client.go index caf63f16f65..09134594405 100644 --- a/src/pkg/websocket/client.go +++ b/src/pkg/websocket/client.go @@ -111,7 +111,7 @@ func Dial(url, protocol, origin string) (ws *Conn, err os.Error) { client, err = net.Dial("tcp", "", parsedUrl.Host) case "wss": - client, err = tls.Dial("tcp", "", parsedUrl.Host) + client, err = tls.Dial("tcp", "", parsedUrl.Host, nil) default: err = ErrBadScheme