mirror of
https://github.com/golang/go
synced 2024-11-18 12:54:44 -07:00
net: avoid race on test hooks with DNS goroutines
The DNS code can start goroutines and not wait for them to complete. This does no harm, but in tests this can cause a race condition with the test hooks that are installed and unintalled around the tests. Add a WaitGroup that tests of DNS can use to avoid the race. Fixes #21090 Change-Id: I6c1443a9c2378e8b89d0ab1d6390c0e3e726b0ce Reviewed-on: https://go-review.googlesource.com/82795 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
6af8c0d812
commit
6c877e5da7
@ -65,8 +65,10 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e
|
||||
}
|
||||
|
||||
// DoChan is like Do but returns a channel that will receive the
|
||||
// results when they are ready.
|
||||
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
|
||||
// results when they are ready. The second result is true if the function
|
||||
// will eventually be called, false if it will not (because there is
|
||||
// a pending request with this key).
|
||||
func (g *Group) DoChan(key string, fn func() (interface{}, error)) (<-chan Result, bool) {
|
||||
ch := make(chan Result, 1)
|
||||
g.mu.Lock()
|
||||
if g.m == nil {
|
||||
@ -76,7 +78,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result
|
||||
c.dups++
|
||||
c.chans = append(c.chans, ch)
|
||||
g.mu.Unlock()
|
||||
return ch
|
||||
return ch, false
|
||||
}
|
||||
c := &call{chans: []chan<- Result{ch}}
|
||||
c.wg.Add(1)
|
||||
@ -85,7 +87,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result
|
||||
|
||||
go g.doCall(c, key, fn)
|
||||
|
||||
return ch
|
||||
return ch, true
|
||||
}
|
||||
|
||||
// doCall handles the single call for a key.
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestCgoLookupIP(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
ctx := context.Background()
|
||||
_, err, ok := cgoLookupIP(ctx, "localhost")
|
||||
if !ok {
|
||||
@ -24,6 +25,7 @@ func TestCgoLookupIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCgoLookupIPWithCancel(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err, ok := cgoLookupIP(ctx, "localhost")
|
||||
@ -36,6 +38,7 @@ func TestCgoLookupIPWithCancel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCgoLookupPort(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
ctx := context.Background()
|
||||
_, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
|
||||
if !ok {
|
||||
@ -47,6 +50,7 @@ func TestCgoLookupPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCgoLookupPortWithCancel(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
|
||||
@ -59,6 +63,7 @@ func TestCgoLookupPortWithCancel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCgoLookupPTR(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
ctx := context.Background()
|
||||
_, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
|
||||
if !ok {
|
||||
@ -70,6 +75,7 @@ func TestCgoLookupPTR(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCgoLookupPTRWithCancel(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
|
||||
|
@ -479,7 +479,9 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
|
||||
var lastErr error
|
||||
for _, fqdn := range conf.nameList(name) {
|
||||
for _, qtype := range qtypes {
|
||||
dnsWaitGroup.Add(1)
|
||||
go func(qtype uint16) {
|
||||
defer dnsWaitGroup.Done()
|
||||
cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
|
||||
lane <- racer{cname, rrs, err}
|
||||
}(qtype)
|
||||
|
@ -203,6 +203,7 @@ var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.
|
||||
|
||||
// Issue 13705: don't try to resolve onion addresses, etc
|
||||
func TestLookupTorOnion(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
|
||||
addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
|
||||
if err != nil {
|
||||
@ -300,6 +301,8 @@ var updateResolvConfTests = []struct {
|
||||
}
|
||||
|
||||
func TestUpdateResolvConf(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
|
||||
|
||||
conf, err := newResolvConfTest()
|
||||
@ -455,6 +458,8 @@ var goLookupIPWithResolverConfigTests = []struct {
|
||||
}
|
||||
|
||||
func TestGoLookupIPWithResolverConfig(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
|
||||
switch s {
|
||||
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
|
||||
@ -547,6 +552,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
|
||||
|
||||
// Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
|
||||
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
|
||||
r := &dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
@ -603,6 +610,8 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
|
||||
// querying the original name instead of an error encountered
|
||||
// querying a generated name.
|
||||
func TestErrorForOriginalNameWhenSearching(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
const fqdn = "doesnotexist.domain"
|
||||
|
||||
conf, err := newResolvConfTest()
|
||||
@ -657,6 +666,8 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
|
||||
|
||||
// Issue 15434. If a name server gives a lame referral, continue to the next.
|
||||
func TestIgnoreLameReferrals(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
conf, err := newResolvConfTest()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -889,6 +900,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
|
||||
|
||||
// Issue 16865. If a name server times out, continue to the next.
|
||||
func TestRetryTimeout(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
conf, err := newResolvConfTest()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -945,6 +958,8 @@ func TestRotate(t *testing.T) {
|
||||
}
|
||||
|
||||
func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
conf, err := newResolvConfTest()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -1008,6 +1023,8 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg {
|
||||
// Issue 17448. With StrictErrors enabled, temporary errors should make
|
||||
// LookupIP fail rather than return a partial result.
|
||||
func TestStrictErrorsLookupIP(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
conf, err := newResolvConfTest()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -1256,6 +1273,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
|
||||
// Issue 17448. With StrictErrors enabled, temporary errors should make
|
||||
// LookupTXT stop walking the search list.
|
||||
func TestStrictErrorsLookupTXT(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
conf, err := newResolvConfTest()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -1312,3 +1331,25 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test for a race between uninstalling the test hooks and closing a
|
||||
// socket connection. This used to fail when testing with -race.
|
||||
func TestDNSGoroutineRace(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return nil, poll.ErrTimeout
|
||||
}}
|
||||
r := Resolver{PreferGo: true, Dial: fake.DialContext}
|
||||
|
||||
// The timeout here is less than the timeout used by the server,
|
||||
// so the goroutine started to query the (fake) server will hang
|
||||
// around after this test is done if we don't call dnsWaitGroup.Wait.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
|
||||
defer cancel()
|
||||
_, err := r.LookupIPAddr(ctx, "where.are.they.now")
|
||||
if err == nil {
|
||||
t.Fatal("fake DNS lookup unexpectedly succeeded")
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"context"
|
||||
"internal/nettrace"
|
||||
"internal/singleflight"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// protocols contains minimal mappings between internet protocol
|
||||
@ -53,6 +54,10 @@ var services = map[string]map[string]int{
|
||||
},
|
||||
}
|
||||
|
||||
// dnsWaitGroup can be used by tests to wait for all DNS goroutines to
|
||||
// complete. This avoids races on the test hooks.
|
||||
var dnsWaitGroup sync.WaitGroup
|
||||
|
||||
const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow
|
||||
|
||||
func lookupProtocolMap(name string) (int, error) {
|
||||
@ -189,9 +194,14 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
|
||||
resolverFunc = alt
|
||||
}
|
||||
|
||||
ch := lookupGroup.DoChan(host, func() (interface{}, error) {
|
||||
dnsWaitGroup.Add(1)
|
||||
ch, called := lookupGroup.DoChan(host, func() (interface{}, error) {
|
||||
defer dnsWaitGroup.Done()
|
||||
return testHookLookupIP(ctx, resolverFunc, host)
|
||||
})
|
||||
if !called {
|
||||
dnsWaitGroup.Done()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
@ -105,6 +105,8 @@ func TestLookupGmailMX(t *testing.T) {
|
||||
t.Skip("IPv4 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for _, tt := range lookupGmailMXTests {
|
||||
mxs, err := LookupMX(tt.name)
|
||||
if err != nil {
|
||||
@ -137,6 +139,8 @@ func TestLookupGmailNS(t *testing.T) {
|
||||
t.Skip("IPv4 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for _, tt := range lookupGmailNSTests {
|
||||
nss, err := LookupNS(tt.name)
|
||||
if err != nil {
|
||||
@ -170,6 +174,8 @@ func TestLookupGmailTXT(t *testing.T) {
|
||||
t.Skip("IPv4 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for _, tt := range lookupGmailTXTTests {
|
||||
txts, err := LookupTXT(tt.name)
|
||||
if err != nil {
|
||||
@ -205,6 +211,8 @@ func TestLookupGooglePublicDNSAddr(t *testing.T) {
|
||||
t.Skip("both IPv4 and IPv6 are required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for _, tt := range lookupGooglePublicDNSAddrTests {
|
||||
names, err := LookupAddr(tt.addr)
|
||||
if err != nil {
|
||||
@ -226,6 +234,8 @@ func TestLookupIPv6LinkLocalAddr(t *testing.T) {
|
||||
t.Skip("IPv6 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
addrs, err := LookupHost("localhost")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -262,6 +272,8 @@ func TestLookupCNAME(t *testing.T) {
|
||||
t.Skip("IPv4 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for _, tt := range lookupCNAMETests {
|
||||
cname, err := LookupCNAME(tt.name)
|
||||
if err != nil {
|
||||
@ -289,6 +301,8 @@ func TestLookupGoogleHost(t *testing.T) {
|
||||
t.Skip("IPv4 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for _, tt := range lookupGoogleHostTests {
|
||||
addrs, err := LookupHost(tt.name)
|
||||
if err != nil {
|
||||
@ -313,6 +327,8 @@ func TestLookupLongTXT(t *testing.T) {
|
||||
testenv.MustHaveExternalNetwork(t)
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
txts, err := LookupTXT("golang.rsc.io")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -343,6 +359,8 @@ func TestLookupGoogleIP(t *testing.T) {
|
||||
t.Skip("IPv4 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for _, tt := range lookupGoogleIPTests {
|
||||
ips, err := LookupIP(tt.name)
|
||||
if err != nil {
|
||||
@ -378,6 +396,7 @@ var revAddrTests = []struct {
|
||||
}
|
||||
|
||||
func TestReverseAddress(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
for i, tt := range revAddrTests {
|
||||
a, err := reverseaddr(tt.Addr)
|
||||
if len(tt.ErrPrefix) > 0 && err == nil {
|
||||
@ -401,6 +420,8 @@ func TestDNSFlood(t *testing.T) {
|
||||
t.Skip("test disabled; use -dnsflood to enable")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
var N = 5000
|
||||
if runtime.GOOS == "darwin" {
|
||||
// On Darwin this test consumes kernel threads much
|
||||
@ -482,6 +503,8 @@ func TestLookupDotsWithLocalSource(t *testing.T) {
|
||||
testenv.MustHaveExternalNetwork(t)
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
for i, fn := range []func() func(){forceGoDNS, forceCgoDNS} {
|
||||
fixup := fn()
|
||||
if fixup == nil {
|
||||
@ -527,6 +550,8 @@ func TestLookupDotsWithRemoteSource(t *testing.T) {
|
||||
t.Skip("IPv4 is required")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
if fixup := forceGoDNS(); fixup != nil {
|
||||
testDots(t, "go")
|
||||
fixup()
|
||||
@ -747,6 +772,9 @@ func TestLookupNonLDH(t *testing.T) {
|
||||
if runtime.GOOS == "nacl" {
|
||||
t.Skip("skip on nacl")
|
||||
}
|
||||
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
if fixup := forceGoDNS(); fixup != nil {
|
||||
defer fixup()
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestGoLookupIP(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
host := "localhost"
|
||||
ctx := context.Background()
|
||||
_, err, ok := cgoLookupIP(ctx, host)
|
||||
|
Loading…
Reference in New Issue
Block a user