1
0
mirror of https://github.com/golang/go synced 2024-11-25 06:57:58 -07:00

net: add ControlContext to Dialer

Fixes #55301

Change-Id: Ie8abcd383eee9af75038bde908ac638f43d33b7e
Reviewed-on: https://go-review.googlesource.com/c/go/+/444955
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: xie cui <523516579@qq.com>
This commit is contained in:
cuiweixie 2022-10-23 08:41:38 +08:00 committed by Gopher Robot
parent bc2dc23846
commit 90b40c0496
10 changed files with 140 additions and 31 deletions

1
api/next/55301.txt Normal file
View File

@ -0,0 +1 @@
pkg net, type Dialer struct, ControlContext func(context.Context, string, string, syscall.RawConn) error #55301

View File

@ -95,7 +95,19 @@ type Dialer struct {
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
// will cause the Control function to be called with "tcp4" or "tcp6".
//
// Control is ignored if ControlContext is not nil.
Control func(network, address string, c syscall.RawConn) error
// If ControlContext is not nil, it is called after creating the network
// connection but before actually dialing.
//
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
// will cause the Control function to be called with "tcp4" or "tcp6".
//
// If ControlContext is not nil, Control is ignored.
ControlContext func(cxt context.Context, network, address string, c syscall.RawConn) error
}
func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }

View File

@ -17,6 +17,7 @@ import (
"runtime"
"strings"
"sync"
"syscall"
"testing"
"time"
)
@ -939,6 +940,36 @@ func TestDialerControl(t *testing.T) {
})
}
func TestDialerControlContext(t *testing.T) {
switch runtime.GOOS {
case "plan9":
t.Skipf("%s does not have full support of socktest", runtime.GOOS)
}
t.Run("StreamDial", func(t *testing.T) {
for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
if !testableNetwork(network) {
continue
}
ln := newLocalListener(t, network)
defer ln.Close()
var id int
d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
id = ctx.Value("id").(int)
return controlOnConnSetup(network, address, c)
}}
c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String())
if err != nil {
t.Error(err)
continue
}
if id != i+1 {
t.Errorf("got id %d, want %d", id, i+1)
}
c.Close()
}
})
}
// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
// except that it won't skip testing on non-mobile builders.
func mustHaveExternalNetwork(t *testing.T) {

View File

@ -122,7 +122,13 @@ func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn,
default:
return nil, UnknownNetworkError(sd.network)
}
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
@ -139,7 +145,13 @@ func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, er
default:
return nil, UnknownNetworkError(sl.network)
}
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}

View File

@ -134,12 +134,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam
return syscall.AF_INET6, false
}
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() {
raddr = raddr.toLocal(net)
}
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn)
}
func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) {

View File

@ -57,7 +57,7 @@ type netFD struct {
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
fd := &netFD{family: family, sotype: sotype, net: net}
if laddr != nil && raddr == nil { // listener

View File

@ -15,7 +15,7 @@ import (
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
@ -54,20 +54,20 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
if laddr != nil && raddr == nil {
switch sotype {
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
if err := fd.listenStream(laddr, listenerBacklog(), ctrlFn); err != nil {
if err := fd.listenStream(ctx, laddr, listenerBacklog(), ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
case syscall.SOCK_DGRAM:
if err := fd.listenDatagram(laddr, ctrlFn); err != nil {
if err := fd.listenDatagram(ctx, laddr, ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
}
if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil {
if err := fd.dial(ctx, laddr, raddr, ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
@ -113,9 +113,11 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
return func(syscall.Sockaddr) Addr { return nil }
}
func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
if ctrlFn != nil {
c, err := newRawConn(fd)
func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
var c *rawConn
var err error
if ctrlCtxFn != nil {
c, err = newRawConn(fd)
if err != nil {
return err
}
@ -125,11 +127,11 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st
} else if laddr != nil {
ctrlAddr = laddr.String()
}
if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil {
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), ctrlAddr, c); err != nil {
return err
}
}
var err error
var lsa syscall.Sockaddr
if laddr != nil {
if lsa, err = laddr.sockaddr(fd.family); err != nil {
@ -172,7 +174,7 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(st
return nil
}
func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error {
func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
var err error
if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
return err
@ -181,15 +183,17 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
}
if ctrlFn != nil {
if ctrlCtxFn != nil {
c, err := newRawConn(fd)
if err != nil {
return err
}
if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
}
if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
return os.NewSyscallError("bind", err)
}
@ -204,7 +208,7 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, s
return nil
}
func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
switch addr := laddr.(type) {
case *UDPAddr:
// We provide a socket that listens to a wildcard
@ -233,12 +237,13 @@ func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, sysc
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
}
if ctrlFn != nil {
if ctrlCtxFn != nil {
c, err := newRawConn(fd)
if err != nil {
return err
}
if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
}

View File

@ -65,7 +65,13 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo
}
func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@ -95,7 +101,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
if err == nil {
fd.Close()
}
fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
}
if err != nil {
@ -168,7 +174,13 @@ func (ln *TCPListener) file() (*os.File, error) {
}
func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control)
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}

View File

@ -203,7 +203,13 @@ func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn
}
func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
@ -211,7 +217,13 @@ func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPCo
}
func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
@ -219,7 +231,13 @@ func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn,
}
func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}

View File

@ -13,7 +13,7 @@ import (
"syscall"
)
func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctxCtrlFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
var sotype int
switch net {
case "unix":
@ -42,7 +42,7 @@ func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode str
return nil, errors.New("unknown mode: " + mode)
}
fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn)
fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctxCtrlFn)
if err != nil {
return nil, err
}
@ -155,7 +155,13 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
}
func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control)
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
@ -211,7 +217,13 @@ func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
}
func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
@ -219,7 +231,13 @@ func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixLi
}
func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}