mirror of
https://github.com/golang/go
synced 2024-11-23 04:10:04 -07:00
crypto/tls: add Dialer
Fixes #18482 Change-Id: I99d65dc5d824c00093ea61e7445fc121314af87f Reviewed-on: https://go-review.googlesource.com/c/go/+/214977 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> Reviewed-by: Ian Lance Taylor <iant@golang.org>
This commit is contained in:
parent
7004be998b
commit
40a144b94f
@ -133,6 +133,18 @@ TODO
|
|||||||
TODO
|
TODO
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<dl id="crypto/tls"><dt><a href="/crypto/tls/">crypto/tls</a></dt>
|
||||||
|
<dd>
|
||||||
|
<p><!-- CL 214977 -->
|
||||||
|
The new
|
||||||
|
<a href="/pkg/crypto/tls/#Dialer"><code>Dialer</code></a>
|
||||||
|
type and its
|
||||||
|
<a href="/pkg/crypto/tls/#Dialer.DialContext"><code>DialContext</code></a>
|
||||||
|
method permits using a context to both connect and handshake with a TLS server.
|
||||||
|
</p>
|
||||||
|
</dd>
|
||||||
|
</dl>
|
||||||
|
|
||||||
<dl id="flag"><dt><a href="/pkg/flag/">flag</a></dt>
|
<dl id="flag"><dt><a href="/pkg/flag/">flag</a></dt>
|
||||||
<dd>
|
<dd>
|
||||||
<p><!-- CL 221427 -->
|
<p><!-- CL 221427 -->
|
||||||
|
@ -1334,8 +1334,12 @@ func (c *Conn) closeNotify() error {
|
|||||||
|
|
||||||
// Handshake runs the client or server handshake
|
// Handshake runs the client or server handshake
|
||||||
// protocol if it has not yet been run.
|
// protocol if it has not yet been run.
|
||||||
// Most uses of this package need not call Handshake
|
//
|
||||||
// explicitly: the first Read or Write will call it automatically.
|
// Most uses of this package need not call Handshake explicitly: the
|
||||||
|
// first Read or Write will call it automatically.
|
||||||
|
//
|
||||||
|
// For control over canceling or setting a timeout on a handshake, use
|
||||||
|
// the Dialer's DialContext method.
|
||||||
func (c *Conn) Handshake() error {
|
func (c *Conn) Handshake() error {
|
||||||
c.handshakeMutex.Lock()
|
c.handshakeMutex.Lock()
|
||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
|
@ -13,6 +13,7 @@ package tls
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
@ -111,29 +112,35 @@ func (timeoutError) Temporary() bool { return true }
|
|||||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||||
// configuration; see the documentation of Config for the defaults.
|
// configuration; see the documentation of Config for the defaults.
|
||||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
|
return dial(context.Background(), dialer, network, addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
// We want the Timeout and Deadline values from dialer to cover the
|
// We want the Timeout and Deadline values from dialer to cover the
|
||||||
// whole process: TCP connection and TLS handshake. This means that we
|
// whole process: TCP connection and TLS handshake. This means that we
|
||||||
// also need to start our own timers now.
|
// also need to start our own timers now.
|
||||||
timeout := dialer.Timeout
|
timeout := netDialer.Timeout
|
||||||
|
|
||||||
if !dialer.Deadline.IsZero() {
|
if !netDialer.Deadline.IsZero() {
|
||||||
deadlineTimeout := time.Until(dialer.Deadline)
|
deadlineTimeout := time.Until(netDialer.Deadline)
|
||||||
if timeout == 0 || deadlineTimeout < timeout {
|
if timeout == 0 || deadlineTimeout < timeout {
|
||||||
timeout = deadlineTimeout
|
timeout = deadlineTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var errChannel chan error
|
// hsErrCh is non-nil if we might not wait for Handshake to complete.
|
||||||
|
var hsErrCh chan error
|
||||||
|
if timeout != 0 || ctx.Done() != nil {
|
||||||
|
hsErrCh = make(chan error, 2)
|
||||||
|
}
|
||||||
if timeout != 0 {
|
if timeout != 0 {
|
||||||
errChannel = make(chan error, 2)
|
|
||||||
timer := time.AfterFunc(timeout, func() {
|
timer := time.AfterFunc(timeout, func() {
|
||||||
errChannel <- timeoutError{}
|
hsErrCh <- timeoutError{}
|
||||||
})
|
})
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
rawConn, err := dialer.Dial(network, addr)
|
rawConn, err := netDialer.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -158,14 +165,26 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
|
|||||||
|
|
||||||
conn := Client(rawConn, config)
|
conn := Client(rawConn, config)
|
||||||
|
|
||||||
if timeout == 0 {
|
if hsErrCh == nil {
|
||||||
err = conn.Handshake()
|
err = conn.Handshake()
|
||||||
} else {
|
} else {
|
||||||
go func() {
|
go func() {
|
||||||
errChannel <- conn.Handshake()
|
hsErrCh <- conn.Handshake()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = <-errChannel
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
case err = <-hsErrCh:
|
||||||
|
if err != nil {
|
||||||
|
// If the error was due to the context
|
||||||
|
// closing, prefer the context's error, rather
|
||||||
|
// than some random network teardown error.
|
||||||
|
if e := ctx.Err(); e != nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -186,6 +205,54 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
|
|||||||
return DialWithDialer(new(net.Dialer), network, addr, config)
|
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dialer dials TLS connections given a configuration and a Dialer for the
|
||||||
|
// underlying connection.
|
||||||
|
type Dialer struct {
|
||||||
|
// NetDialer is the optional dialer to use for the TLS connections'
|
||||||
|
// underlying TCP connections.
|
||||||
|
// A nil NetDialer is equivalent to the net.Dialer zero value.
|
||||||
|
NetDialer *net.Dialer
|
||||||
|
|
||||||
|
// Config is the TLS configuration to use for new connections.
|
||||||
|
// A nil configuration is equivalent to the zero
|
||||||
|
// configuration; see the documentation of Config for the
|
||||||
|
// defaults.
|
||||||
|
Config *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address and initiates a TLS
|
||||||
|
// handshake, returning the resulting TLS connection.
|
||||||
|
//
|
||||||
|
// The returned Conn, if any, will always be of type *Conn.
|
||||||
|
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
|
||||||
|
return d.DialContext(context.Background(), network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) netDialer() *net.Dialer {
|
||||||
|
if d.NetDialer != nil {
|
||||||
|
return d.NetDialer
|
||||||
|
}
|
||||||
|
return new(net.Dialer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address and initiates a TLS
|
||||||
|
// handshake, returning the resulting TLS connection.
|
||||||
|
//
|
||||||
|
// The provided Context must be non-nil. If the context expires before
|
||||||
|
// the connection is complete, an error is returned. Once successfully
|
||||||
|
// connected, any expiration of the context will not affect the
|
||||||
|
// connection.
|
||||||
|
//
|
||||||
|
// The returned Conn, if any, will always be of type *Conn.
|
||||||
|
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
|
||||||
|
if err != nil {
|
||||||
|
// Don't return c (a typed nil) in an interface.
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadX509KeyPair reads and parses a public/private key pair from a pair
|
// LoadX509KeyPair reads and parses a public/private key pair from a pair
|
||||||
// of files. The files must contain PEM encoded data. The certificate file
|
// of files. The files must contain PEM encoded data. The certificate file
|
||||||
// may contain intermediate certificates following the leaf certificate to
|
// may contain intermediate certificates following the leaf certificate to
|
||||||
|
@ -6,6 +6,7 @@ package tls
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -272,6 +273,47 @@ func TestDeadlineOnWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type readerFunc func([]byte) (int, error)
|
||||||
|
|
||||||
|
func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
|
||||||
|
|
||||||
|
// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
|
||||||
|
// (The other cases are all handled by the existing dial tests in this package, which
|
||||||
|
// all also flow through the same code shared code paths)
|
||||||
|
func TestDialer(t *testing.T) {
|
||||||
|
ln := newLocalListener(t)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
unblockServer := make(chan struct{}) // close-only
|
||||||
|
defer close(unblockServer)
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
<-unblockServer
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
d := Dialer{Config: &Config{
|
||||||
|
Rand: readerFunc(func(b []byte) (n int, err error) {
|
||||||
|
// By the time crypto/tls wants randomness, that means it has a TCP
|
||||||
|
// connection, so we're past the Dialer's dial and now blocked
|
||||||
|
// in a handshake. Cancel our context and see if we get unstuck.
|
||||||
|
// (Our TCP listener above never reads or writes, so the Handshake
|
||||||
|
// would otherwise be stuck forever)
|
||||||
|
cancel()
|
||||||
|
return len(b), nil
|
||||||
|
}),
|
||||||
|
ServerName: "foo",
|
||||||
|
}}
|
||||||
|
_, err := d.DialContext(ctx, "tcp", ln.Addr().String())
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Errorf("err = %v; want context.Canceled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isTimeoutError(err error) bool {
|
func isTimeoutError(err error) bool {
|
||||||
if ne, ok := err.(net.Error); ok {
|
if ne, ok := err.(net.Error); ok {
|
||||||
return ne.Timeout()
|
return ne.Timeout()
|
||||||
|
@ -408,7 +408,7 @@ var pkgDeps = map[string][]string{
|
|||||||
// SSL/TLS.
|
// SSL/TLS.
|
||||||
"crypto/tls": {
|
"crypto/tls": {
|
||||||
"L4", "CRYPTO-MATH", "OS", "golang.org/x/crypto/cryptobyte", "golang.org/x/crypto/hkdf",
|
"L4", "CRYPTO-MATH", "OS", "golang.org/x/crypto/cryptobyte", "golang.org/x/crypto/hkdf",
|
||||||
"container/list", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
|
"container/list", "context", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
|
||||||
},
|
},
|
||||||
"crypto/x509": {
|
"crypto/x509": {
|
||||||
"L4", "CRYPTO-MATH", "OS", "CGO", "crypto/ed25519",
|
"L4", "CRYPTO-MATH", "OS", "CGO", "crypto/ed25519",
|
||||||
|
Loading…
Reference in New Issue
Block a user