mirror of
https://github.com/golang/go
synced 2024-11-23 06:40:05 -07:00
net: fix testHookDialTCP race
CL 410754 introduces a race accessing the global testHookDialTCP hook. Avoiding this race is difficult, since Dial can return while goroutines it starts are still running. Add a version of this hook to sysDialer, so it can be set on a per-test basis. (Perhaps other uses of this hook should be moved to use the sysDialer-local hook, but this change fixes the immediate data race.) For #52173. Change-Id: I8fb9be13957e91f92919cae7be213c38ad2af75a Reviewed-on: https://go-review.googlesource.com/c/go/+/410957 Run-TryBot: Damien Neil <dneil@google.com> Reviewed-by: Ian Lance Taylor <iant@golang.org> Reviewed-by: Cherry Mui <cherryyz@google.com> TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
parent
899f0a29c7
commit
432158b69a
@ -341,6 +341,7 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
|
|||||||
type sysDialer struct {
|
type sysDialer struct {
|
||||||
Dialer
|
Dialer
|
||||||
network, address string
|
network, address string
|
||||||
|
testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to the address on the named network.
|
// Dial connects to the address on the named network.
|
||||||
|
@ -234,9 +234,7 @@ func TestDialParallel(t *testing.T) {
|
|||||||
for i, tt := range testCases {
|
for i, tt := range testCases {
|
||||||
i, tt := i, tt
|
i, tt := i, tt
|
||||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||||
origTestHookDialTCP := testHookDialTCP
|
dialTCP := func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||||
defer func() { testHookDialTCP = origTestHookDialTCP }()
|
|
||||||
testHookDialTCP = func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
|
||||||
n := "tcp6"
|
n := "tcp6"
|
||||||
if raddr.IP.To4() != nil {
|
if raddr.IP.To4() != nil {
|
||||||
n = "tcp4"
|
n = "tcp4"
|
||||||
@ -265,6 +263,7 @@ func TestDialParallel(t *testing.T) {
|
|||||||
Dialer: d,
|
Dialer: d,
|
||||||
network: "tcp",
|
network: "tcp",
|
||||||
address: "?",
|
address: "?",
|
||||||
|
testHookDialTCP: dialTCP,
|
||||||
}
|
}
|
||||||
c, err := sd.dialParallel(context.Background(), primaries, fallbacks)
|
c, err := sd.dialParallel(context.Background(), primaries, fallbacks)
|
||||||
elapsed := time.Since(startTime)
|
elapsed := time.Since(startTime)
|
||||||
|
@ -15,8 +15,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||||
if testHookDialTCP != nil {
|
if h := sd.testHookDialTCP; h != nil {
|
||||||
return testHookDialTCP(ctx, sd.network, laddr, raddr)
|
return h(ctx, sd.network, laddr, raddr)
|
||||||
|
}
|
||||||
|
if h := testHookDialTCP; h != nil {
|
||||||
|
return h(ctx, sd.network, laddr, raddr)
|
||||||
}
|
}
|
||||||
return sd.doDialTCP(ctx, laddr, raddr)
|
return sd.doDialTCP(ctx, laddr, raddr)
|
||||||
}
|
}
|
||||||
|
@ -55,8 +55,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
|
||||||
if testHookDialTCP != nil {
|
if h := sd.testHookDialTCP; h != nil {
|
||||||
return testHookDialTCP(ctx, sd.network, laddr, raddr)
|
return h(ctx, sd.network, laddr, raddr)
|
||||||
|
}
|
||||||
|
if h := testHookDialTCP; h != nil {
|
||||||
|
return h(ctx, sd.network, laddr, raddr)
|
||||||
}
|
}
|
||||||
return sd.doDialTCP(ctx, laddr, raddr)
|
return sd.doDialTCP(ctx, laddr, raddr)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user