1
0
mirror of https://github.com/golang/go synced 2024-11-23 07:30:05 -07:00

net: fix testHookDialTCP race

CL 410754 introduces a race accessing the global testHookDialTCP hook.
Avoiding this race is difficult, since Dial can return while
goroutines it starts are still running. Add a version of this
hook to sysDialer, so it can be set on a per-test basis.

(Perhaps other uses of this hook should be moved to use the
sysDialer-local hook, but this change fixes the immediate data race.)

For #52173.

Change-Id: I8fb9be13957e91f92919cae7be213c38ad2af75a
Reviewed-on: https://go-review.googlesource.com/c/go/+/410957
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
Damien Neil 2022-06-07 16:53:53 -07:00
parent 899f0a29c7
commit 432158b69a
4 changed files with 16 additions and 10 deletions

View File

@ -341,6 +341,7 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
type sysDialer struct { type sysDialer struct {
Dialer Dialer
network, address string network, address string
testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
} }
// Dial connects to the address on the named network. // Dial connects to the address on the named network.

View File

@ -234,9 +234,7 @@ func TestDialParallel(t *testing.T) {
for i, tt := range testCases { for i, tt := range testCases {
i, tt := i, tt i, tt := i, tt
t.Run(fmt.Sprint(i), func(t *testing.T) { t.Run(fmt.Sprint(i), func(t *testing.T) {
origTestHookDialTCP := testHookDialTCP dialTCP := func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
defer func() { testHookDialTCP = origTestHookDialTCP }()
testHookDialTCP = func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
n := "tcp6" n := "tcp6"
if raddr.IP.To4() != nil { if raddr.IP.To4() != nil {
n = "tcp4" n = "tcp4"
@ -262,9 +260,10 @@ func TestDialParallel(t *testing.T) {
} }
startTime := time.Now() startTime := time.Now()
sd := &sysDialer{ sd := &sysDialer{
Dialer: d, Dialer: d,
network: "tcp", network: "tcp",
address: "?", address: "?",
testHookDialTCP: dialTCP,
} }
c, err := sd.dialParallel(context.Background(), primaries, fallbacks) c, err := sd.dialParallel(context.Background(), primaries, fallbacks)
elapsed := time.Since(startTime) elapsed := time.Since(startTime)

View File

@ -15,8 +15,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
} }
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if testHookDialTCP != nil { if h := sd.testHookDialTCP; h != nil {
return testHookDialTCP(ctx, sd.network, laddr, raddr) return h(ctx, sd.network, laddr, raddr)
}
if h := testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
} }
return sd.doDialTCP(ctx, laddr, raddr) return sd.doDialTCP(ctx, laddr, raddr)
} }

View File

@ -55,8 +55,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
} }
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if testHookDialTCP != nil { if h := sd.testHookDialTCP; h != nil {
return testHookDialTCP(ctx, sd.network, laddr, raddr) return h(ctx, sd.network, laddr, raddr)
}
if h := testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
} }
return sd.doDialTCP(ctx, laddr, raddr) return sd.doDialTCP(ctx, laddr, raddr)
} }