mirror of
https://github.com/golang/go
synced 2024-11-25 06:57:58 -07:00
net: add ControlContext to Dialer
Fixes #55301 Change-Id: Ie8abcd383eee9af75038bde908ac638f43d33b7e Reviewed-on: https://go-review.googlesource.com/c/go/+/444955 Reviewed-by: Bryan Mills <bcmills@google.com> Reviewed-by: Ian Lance Taylor <iant@google.com> Run-TryBot: Ian Lance Taylor <iant@google.com> Auto-Submit: Ian Lance Taylor <iant@google.com> TryBot-Result: Gopher Robot <gobot@golang.org> Run-TryBot: xie cui <523516579@qq.com>
This commit is contained in:
parent
bc2dc23846
commit
90b40c0496
1
api/next/55301.txt
Normal file
1
api/next/55301.txt
Normal file
@ -0,0 +1 @@
|
||||
pkg net, type Dialer struct, ControlContext func(context.Context, string, string, syscall.RawConn) error #55301
|
@ -95,7 +95,19 @@ type Dialer struct {
|
||||
// Network and address parameters passed to Control method are not
|
||||
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
|
||||
// will cause the Control function to be called with "tcp4" or "tcp6".
|
||||
//
|
||||
// Control is ignored if ControlContext is not nil.
|
||||
Control func(network, address string, c syscall.RawConn) error
|
||||
|
||||
// If ControlContext is not nil, it is called after creating the network
|
||||
// connection but before actually dialing.
|
||||
//
|
||||
// Network and address parameters passed to Control method are not
|
||||
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
|
||||
// will cause the Control function to be called with "tcp4" or "tcp6".
|
||||
//
|
||||
// If ControlContext is not nil, Control is ignored.
|
||||
ControlContext func(cxt context.Context, network, address string, c syscall.RawConn) error
|
||||
}
|
||||
|
||||
func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -939,6 +940,36 @@ func TestDialerControl(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialerControlContext(t *testing.T) {
|
||||
switch runtime.GOOS {
|
||||
case "plan9":
|
||||
t.Skipf("%s does not have full support of socktest", runtime.GOOS)
|
||||
}
|
||||
t.Run("StreamDial", func(t *testing.T) {
|
||||
for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
|
||||
if !testableNetwork(network) {
|
||||
continue
|
||||
}
|
||||
ln := newLocalListener(t, network)
|
||||
defer ln.Close()
|
||||
var id int
|
||||
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
|
||||
id = ctx.Value("id").(int)
|
||||
return controlOnConnSetup(network, address, c)
|
||||
}}
|
||||
c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
if id != i+1 {
|
||||
t.Errorf("got id %d, want %d", id, i+1)
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
|
||||
// except that it won't skip testing on non-mobile builders.
|
||||
func mustHaveExternalNetwork(t *testing.T) {
|
||||
|
@ -122,7 +122,13 @@ func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn,
|
||||
default:
|
||||
return nil, UnknownNetworkError(sd.network)
|
||||
}
|
||||
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
|
||||
ctrlCtxFn := sd.Dialer.ControlContext
|
||||
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sd.Dialer.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -139,7 +145,13 @@ func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, er
|
||||
default:
|
||||
return nil, UnknownNetworkError(sl.network)
|
||||
}
|
||||
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
|
||||
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
|
||||
if sl.ListenConfig.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sl.ListenConfig.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -134,12 +134,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam
|
||||
return syscall.AF_INET6, false
|
||||
}
|
||||
|
||||
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
|
||||
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
|
||||
if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() {
|
||||
raddr = raddr.toLocal(net)
|
||||
}
|
||||
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
|
||||
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
|
||||
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn)
|
||||
}
|
||||
|
||||
func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) {
|
||||
|
@ -57,7 +57,7 @@ type netFD struct {
|
||||
|
||||
// socket returns a network file descriptor that is ready for
|
||||
// asynchronous I/O using the network poller.
|
||||
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
|
||||
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
|
||||
fd := &netFD{family: family, sotype: sotype, net: net}
|
||||
|
||||
if laddr != nil && raddr == nil { // listener
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
|
||||
// socket returns a network file descriptor that is ready for
|
||||
// asynchronous I/O using the network poller.
|
||||
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
|
||||
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
|
||||
s, err := sysSocket(family, sotype, proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -54,20 +54,20 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
|
||||
if laddr != nil && raddr == nil {
|
||||
switch sotype {
|
||||
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
|
||||
if err := fd.listenStream(laddr, listenerBacklog(), ctrlFn); err != nil {
|
||||
if err := fd.listenStream(ctx, laddr, listenerBacklog(), ctrlCtxFn); err != nil {
|
||||
fd.Close()
|
||||
return nil, err
|
||||
}
|
||||
return fd, nil
|
||||
case syscall.SOCK_DGRAM:
|
||||
if err := fd.listenDatagram(laddr, ctrlFn); err != nil {
|
||||
if err := fd.listenDatagram(ctx, laddr, ctrlCtxFn); err != nil {
|
||||
fd.Close()
|
||||
return nil, err
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
}
|
||||
if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil {
|
||||
if err := fd.dial(ctx, laddr, raddr, ctrlCtxFn); err != nil {
|
||||
fd.Close()
|
||||
return nil, err
|
||||
}
|
||||
@ -113,9 +113,11 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
|
||||
return func(syscall.Sockaddr) Addr { return nil }
|
||||
}
|
||||
|
||||
func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
|
||||
if ctrlFn != nil {
|
||||
c, err := newRawConn(fd)
|
||||
func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
|
||||
var c *rawConn
|
||||
var err error
|
||||
if ctrlCtxFn != nil {
|
||||
c, err = newRawConn(fd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -125,11 +127,11 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st
|
||||
} else if laddr != nil {
|
||||
ctrlAddr = laddr.String()
|
||||
}
|
||||
if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil {
|
||||
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), ctrlAddr, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var err error
|
||||
|
||||
var lsa syscall.Sockaddr
|
||||
if laddr != nil {
|
||||
if lsa, err = laddr.sockaddr(fd.family); err != nil {
|
||||
@ -172,7 +174,7 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error {
|
||||
func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
|
||||
var err error
|
||||
if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
|
||||
return err
|
||||
@ -181,15 +183,17 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s
|
||||
if lsa, err = laddr.sockaddr(fd.family); err != nil {
|
||||
return err
|
||||
}
|
||||
if ctrlFn != nil {
|
||||
|
||||
if ctrlCtxFn != nil {
|
||||
c, err := newRawConn(fd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
|
||||
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
|
||||
return os.NewSyscallError("bind", err)
|
||||
}
|
||||
@ -204,7 +208,7 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
|
||||
func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
|
||||
switch addr := laddr.(type) {
|
||||
case *UDPAddr:
|
||||
// We provide a socket that listens to a wildcard
|
||||
@ -233,12 +237,13 @@ func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, sysc
|
||||
if lsa, err = laddr.sockaddr(fd.family); err != nil {
|
||||
return err
|
||||
}
|
||||
if ctrlFn != nil {
|
||||
|
||||
if ctrlCtxFn != nil {
|
||||
c, err := newRawConn(fd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
|
||||
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -65,7 +65,13 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo
|
||||
}
|
||||
|
||||
func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
|
||||
ctrlCtxFn := sd.Dialer.ControlContext
|
||||
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sd.Dialer.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
|
||||
|
||||
// TCP has a rarely used mechanism called a 'simultaneous connection' in
|
||||
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
|
||||
@ -95,7 +101,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
|
||||
if err == nil {
|
||||
fd.Close()
|
||||
}
|
||||
fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
|
||||
fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -168,7 +174,13 @@ func (ln *TCPListener) file() (*os.File, error) {
|
||||
}
|
||||
|
||||
func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
|
||||
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control)
|
||||
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
|
||||
if sl.ListenConfig.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sl.ListenConfig.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -203,7 +203,13 @@ func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn
|
||||
}
|
||||
|
||||
func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
|
||||
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
|
||||
ctrlCtxFn := sd.Dialer.ControlContext
|
||||
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sd.Dialer.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -211,7 +217,13 @@ func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPCo
|
||||
}
|
||||
|
||||
func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
|
||||
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
|
||||
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
|
||||
if sl.ListenConfig.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sl.ListenConfig.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -219,7 +231,13 @@ func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn,
|
||||
}
|
||||
|
||||
func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
|
||||
fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
|
||||
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
|
||||
if sl.ListenConfig.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sl.ListenConfig.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
|
||||
func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctxCtrlFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
|
||||
var sotype int
|
||||
switch net {
|
||||
case "unix":
|
||||
@ -42,7 +42,7 @@ func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode str
|
||||
return nil, errors.New("unknown mode: " + mode)
|
||||
}
|
||||
|
||||
fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn)
|
||||
fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctxCtrlFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -155,7 +155,13 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
|
||||
}
|
||||
|
||||
func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
|
||||
fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control)
|
||||
ctrlCtxFn := sd.Dialer.ControlContext
|
||||
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sd.Dialer.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -211,7 +217,13 @@ func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
|
||||
}
|
||||
|
||||
func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
|
||||
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
|
||||
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
|
||||
if sl.ListenConfig.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sl.ListenConfig.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -219,7 +231,13 @@ func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixLi
|
||||
}
|
||||
|
||||
func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
|
||||
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
|
||||
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
|
||||
if sl.ListenConfig.Control != nil {
|
||||
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
|
||||
return sl.ListenConfig.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user