From 7ef6a9f38b39bd442ad07f5faf895486a5f58750 Mon Sep 17 00:00:00 2001 From: Mikio Hara Date: Thu, 11 Jun 2015 12:46:01 +0900 Subject: [PATCH] net: clean up builtin DNS stub resolver, fix tests This change does clean up as preparation for fixing #11081. - renames cfg to resolvConf for clarification - adds a new type resolverConfig and its methods: init, update, tryAcquireSema, releaseSema for mutual exclusion of resolv.conf data - deflakes, simplifies tests for resolv.conf data; previously the tests sometimes left some garbage in the data Change-Id: I277ced853fddc3791dde40ab54dbd5c78114b78c Reviewed-on: https://go-review.googlesource.com/10931 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick --- src/net/dnsclient_unix.go | 138 +++++++++++----------- src/net/dnsclient_unix_test.go | 203 ++++++++++++++++++--------------- 2 files changed, 185 insertions(+), 156 deletions(-) diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index 8a1745f3cb..8f636055ab 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -195,91 +195,106 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err return "", nil, lastErr } -func convertRR_A(records []dnsRR) []IP { - addrs := make([]IP, len(records)) - for i, rr := range records { - a := rr.(*dnsRR_A).A - addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a)) +// addrRecordList converts and returns a list of IP addresses from DNS +// address records (both A and AAAA). Other record types are ignored. +func addrRecordList(rrs []dnsRR) []IPAddr { + addrs := make([]IPAddr, 0, 4) + for _, rr := range rrs { + switch rr := rr.(type) { + case *dnsRR_A: + addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))}) + case *dnsRR_AAAA: + ip := make(IP, IPv6len) + copy(ip, rr.AAAA[:]) + addrs = append(addrs, IPAddr{IP: ip}) + } } return addrs } -func convertRR_AAAA(records []dnsRR) []IP { - addrs := make([]IP, len(records)) - for i, rr := range records { - a := make(IP, IPv6len) - copy(a, rr.(*dnsRR_AAAA).AAAA[:]) - addrs[i] = a - } - return addrs -} +// A resolverConfig represents a DNS stub resolver configuration. +type resolverConfig struct { + initOnce sync.Once // guards init of resolverConfig -// cfg is used for the storage and reparsing of /etc/resolv.conf -var cfg struct { - // ch is used as a semaphore that only allows one lookup at a time to - // recheck resolv.conf. It acts as guard for lastChecked and modTime. - ch chan struct{} - lastChecked time.Time // last time resolv.conf was checked - modTime time.Time // time of resolv.conf modification + // ch is used as a semaphore that only allows one lookup at a + // time to recheck resolv.conf. + ch chan struct{} // guards lastChecked and modTime + lastChecked time.Time // last time resolv.conf was checked + modTime time.Time // time of resolv.conf modification mu sync.RWMutex // protects dnsConfig dnsConfig *dnsConfig // parsed resolv.conf structure used in lookups } -var onceLoadConfig sync.Once +var resolvConf resolverConfig -func initCfg() { +// init initializes conf and is only called via conf.initOnce. +func (conf *resolverConfig) init() { // Set dnsConfig, modTime, and lastChecked so we don't parse // resolv.conf twice the first time. - cfg.dnsConfig = systemConf().resolv - if cfg.dnsConfig == nil { - cfg.dnsConfig = dnsReadConfig("/etc/resolv.conf") + conf.dnsConfig = systemConf().resolv + if conf.dnsConfig == nil { + conf.dnsConfig = dnsReadConfig("/etc/resolv.conf") } if fi, err := os.Stat("/etc/resolv.conf"); err == nil { - cfg.modTime = fi.ModTime() + conf.modTime = fi.ModTime() } - cfg.lastChecked = time.Now() + conf.lastChecked = time.Now() - // Prepare ch so that only one loadConfig may run at once - cfg.ch = make(chan struct{}, 1) - cfg.ch <- struct{}{} + // Prepare ch so that only one update of resolverConfig may + // run at once. + conf.ch = make(chan struct{}, 1) } -func loadConfig(resolvConfPath string) { - onceLoadConfig.Do(initCfg) +// tryUpdate tries to update conf with the named resolv.conf file. +// The name variable only exists for testing. It is otherwise always +// "/etc/resolv.conf". +func (conf *resolverConfig) tryUpdate(name string) { + conf.initOnce.Do(conf.init) - // ensure only one loadConfig at a time checks /etc/resolv.conf - select { - case <-cfg.ch: - defer func() { cfg.ch <- struct{}{} }() - default: + // Ensure only one update at a time checks resolv.conf. + if !conf.tryAcquireSema() { return } + defer conf.releaseSema() now := time.Now() - if cfg.lastChecked.After(now.Add(-5 * time.Second)) { + if conf.lastChecked.After(now.Add(-5 * time.Second)) { return } - cfg.lastChecked = now + conf.lastChecked = now - if fi, err := os.Stat(resolvConfPath); err == nil { - if fi.ModTime().Equal(cfg.modTime) { + if fi, err := os.Stat(name); err == nil { + if fi.ModTime().Equal(conf.modTime) { return } - cfg.modTime = fi.ModTime() + conf.modTime = fi.ModTime() } else { // If modTime wasn't set prior, assume nothing has changed. - if cfg.modTime.IsZero() { + if conf.modTime.IsZero() { return } - cfg.modTime = time.Time{} + conf.modTime = time.Time{} } - ncfg := dnsReadConfig(resolvConfPath) - cfg.mu.Lock() - cfg.dnsConfig = ncfg - cfg.mu.Unlock() + dnsConf := dnsReadConfig(name) + conf.mu.Lock() + conf.dnsConfig = dnsConf + conf.mu.Unlock() +} + +func (conf *resolverConfig) tryAcquireSema() bool { + select { + case conf.ch <- struct{}{}: + return true + default: + return false + } +} + +func (conf *resolverConfig) releaseSema() { + <-conf.ch } func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { @@ -287,10 +302,10 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { return name, nil, &DNSError{Err: "invalid domain name", Name: name} } - loadConfig("/etc/resolv.conf") - cfg.mu.RLock() - resolv := cfg.dnsConfig - cfg.mu.RUnlock() + resolvConf.tryUpdate("/etc/resolv.conf") + resolvConf.mu.RLock() + resolv := resolvConf.dnsConfig + resolvConf.mu.RUnlock() // If name is rooted (trailing dot) or has enough dots, // try it by itself first. @@ -441,18 +456,7 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er lastErr = racer.error continue } - switch racer.qtype { - case dnsTypeA: - for _, ip := range convertRR_A(racer.rrs) { - addr := IPAddr{IP: ip} - addrs = append(addrs, addr) - } - case dnsTypeAAAA: - for _, ip := range convertRR_AAAA(racer.rrs) { - addr := IPAddr{IP: ip} - addrs = append(addrs, addr) - } - } + addrs = append(addrs, addrRecordList(racer.rrs)...) } if len(addrs) == 0 { if lastErr != nil { @@ -472,11 +476,11 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er // depending on our lookup code, so that Go and C get the same // answers. func goLookupCNAME(name string) (cname string, err error) { - _, rr, err := lookup(name, dnsTypeCNAME) + _, rrs, err := lookup(name, dnsTypeCNAME) if err != nil { return } - cname = rr[0].(*dnsRR_CNAME).Cname + cname = rrs[0].(*dnsRR_CNAME).Cname return } diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 06c9ad3134..c6bfc67abc 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -7,11 +7,13 @@ package net import ( - "io" + "fmt" "io/ioutil" "os" "path" "reflect" + "strings" + "sync" "testing" "time" ) @@ -93,119 +95,135 @@ func TestSpecialDomainName(t *testing.T) { } type resolvConfTest struct { - *testing.T dir string path string + *resolverConfig } -func newResolvConfTest(t *testing.T) *resolvConfTest { +func newResolvConfTest() (*resolvConfTest, error) { dir, err := ioutil.TempDir("", "go-resolvconftest") if err != nil { - t.Fatal(err) + return nil, err } - - r := &resolvConfTest{ - T: t, - dir: dir, - path: path.Join(dir, "resolv.conf"), + conf := &resolvConfTest{ + dir: dir, + path: path.Join(dir, "resolv.conf"), + resolverConfig: &resolvConf, } - - return r + conf.initOnce.Do(conf.init) + return conf, nil } -func (r *resolvConfTest) SetConf(s string) { - // Make sure the file mtime will be different once we're done here, - // even on systems with coarse (1s) mtime resolution. - time.Sleep(time.Second) - - f, err := os.OpenFile(r.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) +func (conf *resolvConfTest) writeAndUpdate(lines []string) error { + f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) if err != nil { - r.Fatalf("failed to create temp file %s: %v", r.path, err) + return err } - if _, err := io.WriteString(f, s); err != nil { + if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil { f.Close() - r.Fatalf("failed to write temp file: %v", err) + return err } f.Close() - cfg.lastChecked = time.Time{} - loadConfig(r.path) -} - -func (r *resolvConfTest) WantServers(want []string) { - cfg.mu.RLock() - defer cfg.mu.RUnlock() - if got := cfg.dnsConfig.servers; !reflect.DeepEqual(got, want) { - r.Fatalf("unexpected dns server loaded, got %v want %v", got, want) + if err := conf.forceUpdate(conf.path); err != nil { + return err } + return nil } -func (r *resolvConfTest) Close() { - if err := os.RemoveAll(r.dir); err != nil { - r.Logf("failed to remove temp dir %s: %v", r.dir, err) +func (conf *resolvConfTest) forceUpdate(name string) error { + dnsConf := dnsReadConfig(name) + conf.mu.Lock() + conf.dnsConfig = dnsConf + conf.mu.Unlock() + for i := 0; i < 5; i++ { + if conf.tryAcquireSema() { + conf.lastChecked = time.Time{} + conf.releaseSema() + return nil + } } + return fmt.Errorf("tryAcquireSema for %s failed", name) } -func TestReloadResolvConfFail(t *testing.T) { +func (conf *resolvConfTest) servers() []string { + conf.mu.RLock() + servers := conf.dnsConfig.servers + conf.mu.RUnlock() + return servers +} + +func (conf *resolvConfTest) teardown() error { + err := conf.forceUpdate("/etc/resolv.conf") + os.RemoveAll(conf.dir) + return err +} + +var updateResolvConfTests = []struct { + name string // query name + lines []string // resolver configuration lines + servers []string // expected name servers +}{ + { + name: "golang.org", + lines: []string{"nameserver 8.8.8.8"}, + servers: []string{"8.8.8.8"}, + }, + { + name: "", + lines: nil, // an empty resolv.conf should use defaultNS as name servers + servers: defaultNS, + }, + { + name: "www.example.com", + lines: []string{"nameserver 8.8.4.4"}, + servers: []string{"8.8.4.4"}, + }, +} + +func TestUpdateResolvConf(t *testing.T) { if testing.Short() || !*testExternal { t.Skip("avoid external network") } - r := newResolvConfTest(t) - defer r.Close() - - r.SetConf("nameserver 8.8.8.8") - - if _, err := goLookupIP("golang.org"); err != nil { + conf, err := newResolvConfTest() + if err != nil { t.Fatal(err) } + defer conf.teardown() - // Using an empty resolv.conf should use localhost as servers - r.SetConf("") - - if len(cfg.dnsConfig.servers) != len(defaultNS) { - t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS) - } - - for i := range cfg.dnsConfig.servers { - if cfg.dnsConfig.servers[i] != defaultNS[i] { - t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS) + for i, tt := range updateResolvConfTests { + if err := conf.writeAndUpdate(tt.lines); err != nil { + t.Error(err) + continue + } + if tt.name != "" { + var wg sync.WaitGroup + const N = 10 + wg.Add(N) + for j := 0; j < N; j++ { + go func(name string) { + defer wg.Done() + ips, err := goLookupIP(name) + if err != nil { + t.Error(err) + return + } + if len(ips) == 0 { + t.Errorf("no records for %s", name) + return + } + }(tt.name) + } + wg.Wait() + } + servers := conf.servers() + if !reflect.DeepEqual(servers, tt.servers) { + t.Errorf("#%d: got %v; want %v", i, servers, tt.servers) + continue } } } -func TestReloadResolvConfChange(t *testing.T) { - if testing.Short() || !*testExternal { - t.Skip("avoid external network") - } - - r := newResolvConfTest(t) - defer r.Close() - - r.SetConf("nameserver 8.8.8.8") - - if _, err := goLookupIP("golang.org"); err != nil { - t.Fatal(err) - } - r.WantServers([]string{"8.8.8.8"}) - - // Using an empty resolv.conf should use localhost as servers - r.SetConf("") - - if len(cfg.dnsConfig.servers) != len(defaultNS) { - t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS) - } - - for i := range cfg.dnsConfig.servers { - if cfg.dnsConfig.servers[i] != defaultNS[i] { - t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS) - } - } - - // A new good config should get picked up - r.SetConf("nameserver 8.8.4.4") - r.WantServers([]string{"8.8.4.4"}) -} - func BenchmarkGoLookupIP(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) @@ -225,14 +243,21 @@ func BenchmarkGoLookupIPNoSuchHost(b *testing.B) { func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) - // This looks ugly but it's safe as long as benchmarks are run - // sequentially in package testing. - <-cfg.ch // keep config from being reloaded upon lookup - orig := cfg.dnsConfig - cfg.dnsConfig.servers = append([]string{"203.0.113.254"}, cfg.dnsConfig.servers...) // use TEST-NET-3 block, see RFC 5737 + conf, err := newResolvConfTest() + if err != nil { + b.Fatal(err) + } + defer conf.teardown() + + lines := []string{ + "nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737 + "nameserver 8.8.8.8", + } + if err := conf.writeAndUpdate(lines); err != nil { + b.Fatal(err) + } + for i := 0; i < b.N; i++ { goLookupIP("www.example.com") } - cfg.dnsConfig = orig - cfg.ch <- struct{}{} }