mirror of
https://github.com/golang/go
synced 2024-11-23 15:50:07 -07:00
crypto/tls: add DialWithDialer.
While reviewing uses of the lower-level Client API in code, I found that in many cases, code was using Client only because it needed a timeout on the connection. DialWithDialer allows a timeout (and other values) to be specified without resorting to the low-level API. LGTM=r R=golang-codereviews, r, bradfitz CC=golang-codereviews https://golang.org/cl/68920045
This commit is contained in:
parent
b3e0a8df24
commit
1f8b2a69ec
@ -15,6 +15,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server returns a new TLS server side connection
|
// Server returns a new TLS server side connection
|
||||||
@ -76,24 +77,51 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
|||||||
return NewListener(l, config), nil
|
return NewListener(l, config), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to the given network address using net.Dial
|
type timeoutError struct{}
|
||||||
// and then initiates a TLS handshake, returning the resulting
|
|
||||||
// TLS connection.
|
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
||||||
// Dial interprets a nil configuration as equivalent to
|
func (timeoutError) Timeout() bool { return true }
|
||||||
// the zero configuration; see the documentation of Config
|
func (timeoutError) Temporary() bool { return true }
|
||||||
// for the defaults.
|
|
||||||
func Dial(network, addr string, config *Config) (*Conn, error) {
|
// DialWithDialer connects to the given network address using dialer.Dial and
|
||||||
raddr := addr
|
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
||||||
c, err := net.Dial(network, raddr)
|
// timeout or deadline given in the dialer apply to connection and TLS
|
||||||
|
// handshake as a whole.
|
||||||
|
//
|
||||||
|
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||||
|
// configuration; see the documentation of Config for the defaults.
|
||||||
|
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
|
// We want the Timeout and Deadline values from dialer to cover the
|
||||||
|
// whole process: TCP connection and TLS handshake. This means that we
|
||||||
|
// also need to start our own timers now.
|
||||||
|
timeout := dialer.Timeout
|
||||||
|
|
||||||
|
if !dialer.Deadline.IsZero() {
|
||||||
|
deadlineTimeout := dialer.Deadline.Sub(time.Now())
|
||||||
|
if timeout == 0 || deadlineTimeout < timeout {
|
||||||
|
timeout = deadlineTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errChannel chan error
|
||||||
|
|
||||||
|
if timeout != 0 {
|
||||||
|
errChannel = make(chan error, 2)
|
||||||
|
time.AfterFunc(timeout, func() {
|
||||||
|
errChannel <- timeoutError{}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
rawConn, err := dialer.Dial(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
colonPos := strings.LastIndex(raddr, ":")
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
if colonPos == -1 {
|
if colonPos == -1 {
|
||||||
colonPos = len(raddr)
|
colonPos = len(addr)
|
||||||
}
|
}
|
||||||
hostname := raddr[:colonPos]
|
hostname := addr[:colonPos]
|
||||||
|
|
||||||
if config == nil {
|
if config == nil {
|
||||||
config = defaultConfig()
|
config = defaultConfig()
|
||||||
@ -106,14 +134,37 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
|
|||||||
c.ServerName = hostname
|
c.ServerName = hostname
|
||||||
config = &c
|
config = &c
|
||||||
}
|
}
|
||||||
conn := Client(c, config)
|
|
||||||
if err = conn.Handshake(); err != nil {
|
conn := Client(rawConn, config)
|
||||||
c.Close()
|
|
||||||
|
if timeout == 0 {
|
||||||
|
err = conn.Handshake()
|
||||||
|
} else {
|
||||||
|
go func() {
|
||||||
|
errChannel <- conn.Handshake()
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = <-errChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
rawConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dial connects to the given network address using net.Dial
|
||||||
|
// and then initiates a TLS handshake, returning the resulting
|
||||||
|
// TLS connection.
|
||||||
|
// Dial interprets a nil configuration as equivalent to
|
||||||
|
// the zero configuration; see the documentation of Config
|
||||||
|
// for the defaults.
|
||||||
|
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||||
|
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
// LoadX509KeyPair reads and parses a public/private key pair from a pair of
|
// LoadX509KeyPair reads and parses a public/private key pair from a pair of
|
||||||
// files. The files must contain PEM encoded data.
|
// files. The files must contain PEM encoded data.
|
||||||
func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
|
func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
|
||||||
|
@ -5,7 +5,10 @@
|
|||||||
package tls
|
package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
|
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
|
||||||
@ -105,3 +108,45 @@ func TestX509MixedKeyPair(t *testing.T) {
|
|||||||
t.Error("Load of ECDSA certificate succeeded with RSA private key")
|
t.Error("Load of ECDSA certificate succeeded with RSA private key")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialTimeout(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
listener, err = net.Listen("tcp6", "[::1]:0")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := listener.Addr().String()
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
complete := make(chan bool)
|
||||||
|
defer close(complete)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
<-complete
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: 10 * time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
|
||||||
|
t.Fatal("DialWithTimeout completed successfully")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "timed out") {
|
||||||
|
t.Errorf("resulting error not a timeout: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user