From 6c877e5da7ab14f0d8a206c09f24cf51fbbc393a Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Thu, 7 Dec 2017 20:30:28 -0800 Subject: [PATCH] net: avoid race on test hooks with DNS goroutines The DNS code can start goroutines and not wait for them to complete. This does no harm, but in tests this can cause a race condition with the test hooks that are installed and unintalled around the tests. Add a WaitGroup that tests of DNS can use to avoid the race. Fixes #21090 Change-Id: I6c1443a9c2378e8b89d0ab1d6390c0e3e726b0ce Reviewed-on: https://go-review.googlesource.com/82795 Reviewed-by: Brad Fitzpatrick --- src/internal/singleflight/singleflight.go | 10 +++--- src/net/cgo_unix_test.go | 6 ++++ src/net/dnsclient_unix.go | 2 ++ src/net/dnsclient_unix_test.go | 41 +++++++++++++++++++++++ src/net/lookup.go | 12 ++++++- src/net/lookup_test.go | 28 ++++++++++++++++ src/net/netgo_unix_test.go | 1 + 7 files changed, 95 insertions(+), 5 deletions(-) diff --git a/src/internal/singleflight/singleflight.go b/src/internal/singleflight/singleflight.go index de81ac87b9..1e9960d575 100644 --- a/src/internal/singleflight/singleflight.go +++ b/src/internal/singleflight/singleflight.go @@ -65,8 +65,10 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e } // DoChan is like Do but returns a channel that will receive the -// results when they are ready. -func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result { +// results when they are ready. The second result is true if the function +// will eventually be called, false if it will not (because there is +// a pending request with this key). +func (g *Group) DoChan(key string, fn func() (interface{}, error)) (<-chan Result, bool) { ch := make(chan Result, 1) g.mu.Lock() if g.m == nil { @@ -76,7 +78,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result c.dups++ c.chans = append(c.chans, ch) g.mu.Unlock() - return ch + return ch, false } c := &call{chans: []chan<- Result{ch}} c.wg.Add(1) @@ -85,7 +87,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result go g.doCall(c, key, fn) - return ch + return ch, true } // doCall handles the single call for a key. diff --git a/src/net/cgo_unix_test.go b/src/net/cgo_unix_test.go index e861c7aa1f..b476a6d626 100644 --- a/src/net/cgo_unix_test.go +++ b/src/net/cgo_unix_test.go @@ -13,6 +13,7 @@ import ( ) func TestCgoLookupIP(t *testing.T) { + defer dnsWaitGroup.Wait() ctx := context.Background() _, err, ok := cgoLookupIP(ctx, "localhost") if !ok { @@ -24,6 +25,7 @@ func TestCgoLookupIP(t *testing.T) { } func TestCgoLookupIPWithCancel(t *testing.T) { + defer dnsWaitGroup.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, err, ok := cgoLookupIP(ctx, "localhost") @@ -36,6 +38,7 @@ func TestCgoLookupIPWithCancel(t *testing.T) { } func TestCgoLookupPort(t *testing.T) { + defer dnsWaitGroup.Wait() ctx := context.Background() _, err, ok := cgoLookupPort(ctx, "tcp", "smtp") if !ok { @@ -47,6 +50,7 @@ func TestCgoLookupPort(t *testing.T) { } func TestCgoLookupPortWithCancel(t *testing.T) { + defer dnsWaitGroup.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, err, ok := cgoLookupPort(ctx, "tcp", "smtp") @@ -59,6 +63,7 @@ func TestCgoLookupPortWithCancel(t *testing.T) { } func TestCgoLookupPTR(t *testing.T) { + defer dnsWaitGroup.Wait() ctx := context.Background() _, err, ok := cgoLookupPTR(ctx, "127.0.0.1") if !ok { @@ -70,6 +75,7 @@ func TestCgoLookupPTR(t *testing.T) { } func TestCgoLookupPTRWithCancel(t *testing.T) { + defer dnsWaitGroup.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, err, ok := cgoLookupPTR(ctx, "127.0.0.1") diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index acbf6c3b2a..9026fd8c74 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -479,7 +479,9 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order var lastErr error for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { + dnsWaitGroup.Add(1) go func(qtype uint16) { + defer dnsWaitGroup.Done() cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype) lane <- racer{cname, rrs, err} }(qtype) diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 73b628c1b5..295ed9770c 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -203,6 +203,7 @@ var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time. // Issue 13705: don't try to resolve onion addresses, etc func TestLookupTorOnion(t *testing.T) { + defer dnsWaitGroup.Wait() r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext} addrs, err := r.LookupIPAddr(context.Background(), "foo.onion") if err != nil { @@ -300,6 +301,8 @@ var updateResolvConfTests = []struct { } func TestUpdateResolvConf(t *testing.T) { + defer dnsWaitGroup.Wait() + r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext} conf, err := newResolvConfTest() @@ -455,6 +458,8 @@ var goLookupIPWithResolverConfigTests = []struct { } func TestGoLookupIPWithResolverConfig(t *testing.T) { + defer dnsWaitGroup.Wait() + fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { switch s { case "[2001:4860:4860::8888]:53", "8.8.8.8:53": @@ -547,6 +552,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { // Test that goLookupIPOrder falls back to the host file when no DNS servers are available. func TestGoLookupIPOrderFallbackToFile(t *testing.T) { + defer dnsWaitGroup.Wait() + fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) { r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ @@ -603,6 +610,8 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { // querying the original name instead of an error encountered // querying a generated name. func TestErrorForOriginalNameWhenSearching(t *testing.T) { + defer dnsWaitGroup.Wait() + const fqdn = "doesnotexist.domain" conf, err := newResolvConfTest() @@ -657,6 +666,8 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { // Issue 15434. If a name server gives a lame referral, continue to the next. func TestIgnoreLameReferrals(t *testing.T) { + defer dnsWaitGroup.Wait() + conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -889,6 +900,8 @@ func TestIgnoreDNSForgeries(t *testing.T) { // Issue 16865. If a name server times out, continue to the next. func TestRetryTimeout(t *testing.T) { + defer dnsWaitGroup.Wait() + conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -945,6 +958,8 @@ func TestRotate(t *testing.T) { } func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { + defer dnsWaitGroup.Wait() + conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -1008,6 +1023,8 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg { // Issue 17448. With StrictErrors enabled, temporary errors should make // LookupIP fail rather than return a partial result. func TestStrictErrorsLookupIP(t *testing.T) { + defer dnsWaitGroup.Wait() + conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -1256,6 +1273,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { // Issue 17448. With StrictErrors enabled, temporary errors should make // LookupTXT stop walking the search list. func TestStrictErrorsLookupTXT(t *testing.T) { + defer dnsWaitGroup.Wait() + conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -1312,3 +1331,25 @@ func TestStrictErrorsLookupTXT(t *testing.T) { } } } + +// Test for a race between uninstalling the test hooks and closing a +// socket connection. This used to fail when testing with -race. +func TestDNSGoroutineRace(t *testing.T) { + defer dnsWaitGroup.Wait() + + fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) { + time.Sleep(10 * time.Microsecond) + return nil, poll.ErrTimeout + }} + r := Resolver{PreferGo: true, Dial: fake.DialContext} + + // The timeout here is less than the timeout used by the server, + // so the goroutine started to query the (fake) server will hang + // around after this test is done if we don't call dnsWaitGroup.Wait. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond) + defer cancel() + _, err := r.LookupIPAddr(ctx, "where.are.they.now") + if err == nil { + t.Fatal("fake DNS lookup unexpectedly succeeded") + } +} diff --git a/src/net/lookup.go b/src/net/lookup.go index c9f327050a..85e472932f 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -8,6 +8,7 @@ import ( "context" "internal/nettrace" "internal/singleflight" + "sync" ) // protocols contains minimal mappings between internet protocol @@ -53,6 +54,10 @@ var services = map[string]map[string]int{ }, } +// dnsWaitGroup can be used by tests to wait for all DNS goroutines to +// complete. This avoids races on the test hooks. +var dnsWaitGroup sync.WaitGroup + const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow func lookupProtocolMap(name string) (int, error) { @@ -189,9 +194,14 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err resolverFunc = alt } - ch := lookupGroup.DoChan(host, func() (interface{}, error) { + dnsWaitGroup.Add(1) + ch, called := lookupGroup.DoChan(host, func() (interface{}, error) { + defer dnsWaitGroup.Done() return testHookLookupIP(ctx, resolverFunc, host) }) + if !called { + dnsWaitGroup.Done() + } select { case <-ctx.Done(): diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go index e3bf114a8e..bfb872551c 100644 --- a/src/net/lookup_test.go +++ b/src/net/lookup_test.go @@ -105,6 +105,8 @@ func TestLookupGmailMX(t *testing.T) { t.Skip("IPv4 is required") } + defer dnsWaitGroup.Wait() + for _, tt := range lookupGmailMXTests { mxs, err := LookupMX(tt.name) if err != nil { @@ -137,6 +139,8 @@ func TestLookupGmailNS(t *testing.T) { t.Skip("IPv4 is required") } + defer dnsWaitGroup.Wait() + for _, tt := range lookupGmailNSTests { nss, err := LookupNS(tt.name) if err != nil { @@ -170,6 +174,8 @@ func TestLookupGmailTXT(t *testing.T) { t.Skip("IPv4 is required") } + defer dnsWaitGroup.Wait() + for _, tt := range lookupGmailTXTTests { txts, err := LookupTXT(tt.name) if err != nil { @@ -205,6 +211,8 @@ func TestLookupGooglePublicDNSAddr(t *testing.T) { t.Skip("both IPv4 and IPv6 are required") } + defer dnsWaitGroup.Wait() + for _, tt := range lookupGooglePublicDNSAddrTests { names, err := LookupAddr(tt.addr) if err != nil { @@ -226,6 +234,8 @@ func TestLookupIPv6LinkLocalAddr(t *testing.T) { t.Skip("IPv6 is required") } + defer dnsWaitGroup.Wait() + addrs, err := LookupHost("localhost") if err != nil { t.Fatal(err) @@ -262,6 +272,8 @@ func TestLookupCNAME(t *testing.T) { t.Skip("IPv4 is required") } + defer dnsWaitGroup.Wait() + for _, tt := range lookupCNAMETests { cname, err := LookupCNAME(tt.name) if err != nil { @@ -289,6 +301,8 @@ func TestLookupGoogleHost(t *testing.T) { t.Skip("IPv4 is required") } + defer dnsWaitGroup.Wait() + for _, tt := range lookupGoogleHostTests { addrs, err := LookupHost(tt.name) if err != nil { @@ -313,6 +327,8 @@ func TestLookupLongTXT(t *testing.T) { testenv.MustHaveExternalNetwork(t) } + defer dnsWaitGroup.Wait() + txts, err := LookupTXT("golang.rsc.io") if err != nil { t.Fatal(err) @@ -343,6 +359,8 @@ func TestLookupGoogleIP(t *testing.T) { t.Skip("IPv4 is required") } + defer dnsWaitGroup.Wait() + for _, tt := range lookupGoogleIPTests { ips, err := LookupIP(tt.name) if err != nil { @@ -378,6 +396,7 @@ var revAddrTests = []struct { } func TestReverseAddress(t *testing.T) { + defer dnsWaitGroup.Wait() for i, tt := range revAddrTests { a, err := reverseaddr(tt.Addr) if len(tt.ErrPrefix) > 0 && err == nil { @@ -401,6 +420,8 @@ func TestDNSFlood(t *testing.T) { t.Skip("test disabled; use -dnsflood to enable") } + defer dnsWaitGroup.Wait() + var N = 5000 if runtime.GOOS == "darwin" { // On Darwin this test consumes kernel threads much @@ -482,6 +503,8 @@ func TestLookupDotsWithLocalSource(t *testing.T) { testenv.MustHaveExternalNetwork(t) } + defer dnsWaitGroup.Wait() + for i, fn := range []func() func(){forceGoDNS, forceCgoDNS} { fixup := fn() if fixup == nil { @@ -527,6 +550,8 @@ func TestLookupDotsWithRemoteSource(t *testing.T) { t.Skip("IPv4 is required") } + defer dnsWaitGroup.Wait() + if fixup := forceGoDNS(); fixup != nil { testDots(t, "go") fixup() @@ -747,6 +772,9 @@ func TestLookupNonLDH(t *testing.T) { if runtime.GOOS == "nacl" { t.Skip("skip on nacl") } + + defer dnsWaitGroup.Wait() + if fixup := forceGoDNS(); fixup != nil { defer fixup() } diff --git a/src/net/netgo_unix_test.go b/src/net/netgo_unix_test.go index 47901b03cf..f2244ea69c 100644 --- a/src/net/netgo_unix_test.go +++ b/src/net/netgo_unix_test.go @@ -13,6 +13,7 @@ import ( ) func TestGoLookupIP(t *testing.T) { + defer dnsWaitGroup.Wait() host := "localhost" ctx := context.Background() _, err, ok := cgoLookupIP(ctx, host)