mirror of
https://github.com/golang/go
synced 2024-11-20 01:04:40 -07:00
net: clean up builtin DNS stub resolver, fix tests
This change does clean up as preparation for fixing #11081. - renames cfg to resolvConf for clarification - adds a new type resolverConfig and its methods: init, update, tryAcquireSema, releaseSema for mutual exclusion of resolv.conf data - deflakes, simplifies tests for resolv.conf data; previously the tests sometimes left some garbage in the data Change-Id: I277ced853fddc3791dde40ab54dbd5c78114b78c Reviewed-on: https://go-review.googlesource.com/10931 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
d7a8d3eb08
commit
7ef6a9f38b
@ -195,91 +195,106 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err
|
||||
return "", nil, lastErr
|
||||
}
|
||||
|
||||
func convertRR_A(records []dnsRR) []IP {
|
||||
addrs := make([]IP, len(records))
|
||||
for i, rr := range records {
|
||||
a := rr.(*dnsRR_A).A
|
||||
addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a))
|
||||
// addrRecordList converts and returns a list of IP addresses from DNS
|
||||
// address records (both A and AAAA). Other record types are ignored.
|
||||
func addrRecordList(rrs []dnsRR) []IPAddr {
|
||||
addrs := make([]IPAddr, 0, 4)
|
||||
for _, rr := range rrs {
|
||||
switch rr := rr.(type) {
|
||||
case *dnsRR_A:
|
||||
addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
|
||||
case *dnsRR_AAAA:
|
||||
ip := make(IP, IPv6len)
|
||||
copy(ip, rr.AAAA[:])
|
||||
addrs = append(addrs, IPAddr{IP: ip})
|
||||
}
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
||||
func convertRR_AAAA(records []dnsRR) []IP {
|
||||
addrs := make([]IP, len(records))
|
||||
for i, rr := range records {
|
||||
a := make(IP, IPv6len)
|
||||
copy(a, rr.(*dnsRR_AAAA).AAAA[:])
|
||||
addrs[i] = a
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
// A resolverConfig represents a DNS stub resolver configuration.
|
||||
type resolverConfig struct {
|
||||
initOnce sync.Once // guards init of resolverConfig
|
||||
|
||||
// cfg is used for the storage and reparsing of /etc/resolv.conf
|
||||
var cfg struct {
|
||||
// ch is used as a semaphore that only allows one lookup at a time to
|
||||
// recheck resolv.conf. It acts as guard for lastChecked and modTime.
|
||||
ch chan struct{}
|
||||
lastChecked time.Time // last time resolv.conf was checked
|
||||
modTime time.Time // time of resolv.conf modification
|
||||
// ch is used as a semaphore that only allows one lookup at a
|
||||
// time to recheck resolv.conf.
|
||||
ch chan struct{} // guards lastChecked and modTime
|
||||
lastChecked time.Time // last time resolv.conf was checked
|
||||
modTime time.Time // time of resolv.conf modification
|
||||
|
||||
mu sync.RWMutex // protects dnsConfig
|
||||
dnsConfig *dnsConfig // parsed resolv.conf structure used in lookups
|
||||
}
|
||||
|
||||
var onceLoadConfig sync.Once
|
||||
var resolvConf resolverConfig
|
||||
|
||||
func initCfg() {
|
||||
// init initializes conf and is only called via conf.initOnce.
|
||||
func (conf *resolverConfig) init() {
|
||||
// Set dnsConfig, modTime, and lastChecked so we don't parse
|
||||
// resolv.conf twice the first time.
|
||||
cfg.dnsConfig = systemConf().resolv
|
||||
if cfg.dnsConfig == nil {
|
||||
cfg.dnsConfig = dnsReadConfig("/etc/resolv.conf")
|
||||
conf.dnsConfig = systemConf().resolv
|
||||
if conf.dnsConfig == nil {
|
||||
conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
|
||||
}
|
||||
|
||||
if fi, err := os.Stat("/etc/resolv.conf"); err == nil {
|
||||
cfg.modTime = fi.ModTime()
|
||||
conf.modTime = fi.ModTime()
|
||||
}
|
||||
cfg.lastChecked = time.Now()
|
||||
conf.lastChecked = time.Now()
|
||||
|
||||
// Prepare ch so that only one loadConfig may run at once
|
||||
cfg.ch = make(chan struct{}, 1)
|
||||
cfg.ch <- struct{}{}
|
||||
// Prepare ch so that only one update of resolverConfig may
|
||||
// run at once.
|
||||
conf.ch = make(chan struct{}, 1)
|
||||
}
|
||||
|
||||
func loadConfig(resolvConfPath string) {
|
||||
onceLoadConfig.Do(initCfg)
|
||||
// tryUpdate tries to update conf with the named resolv.conf file.
|
||||
// The name variable only exists for testing. It is otherwise always
|
||||
// "/etc/resolv.conf".
|
||||
func (conf *resolverConfig) tryUpdate(name string) {
|
||||
conf.initOnce.Do(conf.init)
|
||||
|
||||
// ensure only one loadConfig at a time checks /etc/resolv.conf
|
||||
select {
|
||||
case <-cfg.ch:
|
||||
defer func() { cfg.ch <- struct{}{} }()
|
||||
default:
|
||||
// Ensure only one update at a time checks resolv.conf.
|
||||
if !conf.tryAcquireSema() {
|
||||
return
|
||||
}
|
||||
defer conf.releaseSema()
|
||||
|
||||
now := time.Now()
|
||||
if cfg.lastChecked.After(now.Add(-5 * time.Second)) {
|
||||
if conf.lastChecked.After(now.Add(-5 * time.Second)) {
|
||||
return
|
||||
}
|
||||
cfg.lastChecked = now
|
||||
conf.lastChecked = now
|
||||
|
||||
if fi, err := os.Stat(resolvConfPath); err == nil {
|
||||
if fi.ModTime().Equal(cfg.modTime) {
|
||||
if fi, err := os.Stat(name); err == nil {
|
||||
if fi.ModTime().Equal(conf.modTime) {
|
||||
return
|
||||
}
|
||||
cfg.modTime = fi.ModTime()
|
||||
conf.modTime = fi.ModTime()
|
||||
} else {
|
||||
// If modTime wasn't set prior, assume nothing has changed.
|
||||
if cfg.modTime.IsZero() {
|
||||
if conf.modTime.IsZero() {
|
||||
return
|
||||
}
|
||||
cfg.modTime = time.Time{}
|
||||
conf.modTime = time.Time{}
|
||||
}
|
||||
|
||||
ncfg := dnsReadConfig(resolvConfPath)
|
||||
cfg.mu.Lock()
|
||||
cfg.dnsConfig = ncfg
|
||||
cfg.mu.Unlock()
|
||||
dnsConf := dnsReadConfig(name)
|
||||
conf.mu.Lock()
|
||||
conf.dnsConfig = dnsConf
|
||||
conf.mu.Unlock()
|
||||
}
|
||||
|
||||
func (conf *resolverConfig) tryAcquireSema() bool {
|
||||
select {
|
||||
case conf.ch <- struct{}{}:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (conf *resolverConfig) releaseSema() {
|
||||
<-conf.ch
|
||||
}
|
||||
|
||||
func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
|
||||
@ -287,10 +302,10 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
|
||||
return name, nil, &DNSError{Err: "invalid domain name", Name: name}
|
||||
}
|
||||
|
||||
loadConfig("/etc/resolv.conf")
|
||||
cfg.mu.RLock()
|
||||
resolv := cfg.dnsConfig
|
||||
cfg.mu.RUnlock()
|
||||
resolvConf.tryUpdate("/etc/resolv.conf")
|
||||
resolvConf.mu.RLock()
|
||||
resolv := resolvConf.dnsConfig
|
||||
resolvConf.mu.RUnlock()
|
||||
|
||||
// If name is rooted (trailing dot) or has enough dots,
|
||||
// try it by itself first.
|
||||
@ -441,18 +456,7 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
|
||||
lastErr = racer.error
|
||||
continue
|
||||
}
|
||||
switch racer.qtype {
|
||||
case dnsTypeA:
|
||||
for _, ip := range convertRR_A(racer.rrs) {
|
||||
addr := IPAddr{IP: ip}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
case dnsTypeAAAA:
|
||||
for _, ip := range convertRR_AAAA(racer.rrs) {
|
||||
addr := IPAddr{IP: ip}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
addrs = append(addrs, addrRecordList(racer.rrs)...)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
if lastErr != nil {
|
||||
@ -472,11 +476,11 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
|
||||
// depending on our lookup code, so that Go and C get the same
|
||||
// answers.
|
||||
func goLookupCNAME(name string) (cname string, err error) {
|
||||
_, rr, err := lookup(name, dnsTypeCNAME)
|
||||
_, rrs, err := lookup(name, dnsTypeCNAME)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cname = rr[0].(*dnsRR_CNAME).Cname
|
||||
cname = rrs[0].(*dnsRR_CNAME).Cname
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -7,11 +7,13 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"io"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -93,119 +95,135 @@ func TestSpecialDomainName(t *testing.T) {
|
||||
}
|
||||
|
||||
type resolvConfTest struct {
|
||||
*testing.T
|
||||
dir string
|
||||
path string
|
||||
*resolverConfig
|
||||
}
|
||||
|
||||
func newResolvConfTest(t *testing.T) *resolvConfTest {
|
||||
func newResolvConfTest() (*resolvConfTest, error) {
|
||||
dir, err := ioutil.TempDir("", "go-resolvconftest")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &resolvConfTest{
|
||||
T: t,
|
||||
dir: dir,
|
||||
path: path.Join(dir, "resolv.conf"),
|
||||
conf := &resolvConfTest{
|
||||
dir: dir,
|
||||
path: path.Join(dir, "resolv.conf"),
|
||||
resolverConfig: &resolvConf,
|
||||
}
|
||||
|
||||
return r
|
||||
conf.initOnce.Do(conf.init)
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (r *resolvConfTest) SetConf(s string) {
|
||||
// Make sure the file mtime will be different once we're done here,
|
||||
// even on systems with coarse (1s) mtime resolution.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
f, err := os.OpenFile(r.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
|
||||
func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
|
||||
f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
r.Fatalf("failed to create temp file %s: %v", r.path, err)
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(f, s); err != nil {
|
||||
if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
|
||||
f.Close()
|
||||
r.Fatalf("failed to write temp file: %v", err)
|
||||
return err
|
||||
}
|
||||
f.Close()
|
||||
cfg.lastChecked = time.Time{}
|
||||
loadConfig(r.path)
|
||||
}
|
||||
|
||||
func (r *resolvConfTest) WantServers(want []string) {
|
||||
cfg.mu.RLock()
|
||||
defer cfg.mu.RUnlock()
|
||||
if got := cfg.dnsConfig.servers; !reflect.DeepEqual(got, want) {
|
||||
r.Fatalf("unexpected dns server loaded, got %v want %v", got, want)
|
||||
if err := conf.forceUpdate(conf.path); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvConfTest) Close() {
|
||||
if err := os.RemoveAll(r.dir); err != nil {
|
||||
r.Logf("failed to remove temp dir %s: %v", r.dir, err)
|
||||
func (conf *resolvConfTest) forceUpdate(name string) error {
|
||||
dnsConf := dnsReadConfig(name)
|
||||
conf.mu.Lock()
|
||||
conf.dnsConfig = dnsConf
|
||||
conf.mu.Unlock()
|
||||
for i := 0; i < 5; i++ {
|
||||
if conf.tryAcquireSema() {
|
||||
conf.lastChecked = time.Time{}
|
||||
conf.releaseSema()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("tryAcquireSema for %s failed", name)
|
||||
}
|
||||
|
||||
func TestReloadResolvConfFail(t *testing.T) {
|
||||
func (conf *resolvConfTest) servers() []string {
|
||||
conf.mu.RLock()
|
||||
servers := conf.dnsConfig.servers
|
||||
conf.mu.RUnlock()
|
||||
return servers
|
||||
}
|
||||
|
||||
func (conf *resolvConfTest) teardown() error {
|
||||
err := conf.forceUpdate("/etc/resolv.conf")
|
||||
os.RemoveAll(conf.dir)
|
||||
return err
|
||||
}
|
||||
|
||||
var updateResolvConfTests = []struct {
|
||||
name string // query name
|
||||
lines []string // resolver configuration lines
|
||||
servers []string // expected name servers
|
||||
}{
|
||||
{
|
||||
name: "golang.org",
|
||||
lines: []string{"nameserver 8.8.8.8"},
|
||||
servers: []string{"8.8.8.8"},
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
lines: nil, // an empty resolv.conf should use defaultNS as name servers
|
||||
servers: defaultNS,
|
||||
},
|
||||
{
|
||||
name: "www.example.com",
|
||||
lines: []string{"nameserver 8.8.4.4"},
|
||||
servers: []string{"8.8.4.4"},
|
||||
},
|
||||
}
|
||||
|
||||
func TestUpdateResolvConf(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("avoid external network")
|
||||
}
|
||||
|
||||
r := newResolvConfTest(t)
|
||||
defer r.Close()
|
||||
|
||||
r.SetConf("nameserver 8.8.8.8")
|
||||
|
||||
if _, err := goLookupIP("golang.org"); err != nil {
|
||||
conf, err := newResolvConfTest()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conf.teardown()
|
||||
|
||||
// Using an empty resolv.conf should use localhost as servers
|
||||
r.SetConf("")
|
||||
|
||||
if len(cfg.dnsConfig.servers) != len(defaultNS) {
|
||||
t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
|
||||
}
|
||||
|
||||
for i := range cfg.dnsConfig.servers {
|
||||
if cfg.dnsConfig.servers[i] != defaultNS[i] {
|
||||
t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
|
||||
for i, tt := range updateResolvConfTests {
|
||||
if err := conf.writeAndUpdate(tt.lines); err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
if tt.name != "" {
|
||||
var wg sync.WaitGroup
|
||||
const N = 10
|
||||
wg.Add(N)
|
||||
for j := 0; j < N; j++ {
|
||||
go func(name string) {
|
||||
defer wg.Done()
|
||||
ips, err := goLookupIP(name)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
t.Errorf("no records for %s", name)
|
||||
return
|
||||
}
|
||||
}(tt.name)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
servers := conf.servers()
|
||||
if !reflect.DeepEqual(servers, tt.servers) {
|
||||
t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadResolvConfChange(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("avoid external network")
|
||||
}
|
||||
|
||||
r := newResolvConfTest(t)
|
||||
defer r.Close()
|
||||
|
||||
r.SetConf("nameserver 8.8.8.8")
|
||||
|
||||
if _, err := goLookupIP("golang.org"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r.WantServers([]string{"8.8.8.8"})
|
||||
|
||||
// Using an empty resolv.conf should use localhost as servers
|
||||
r.SetConf("")
|
||||
|
||||
if len(cfg.dnsConfig.servers) != len(defaultNS) {
|
||||
t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
|
||||
}
|
||||
|
||||
for i := range cfg.dnsConfig.servers {
|
||||
if cfg.dnsConfig.servers[i] != defaultNS[i] {
|
||||
t.Fatalf("goLookupIP(missing; good; bad) failed: servers=%v, want: %v", cfg.dnsConfig.servers, defaultNS)
|
||||
}
|
||||
}
|
||||
|
||||
// A new good config should get picked up
|
||||
r.SetConf("nameserver 8.8.4.4")
|
||||
r.WantServers([]string{"8.8.4.4"})
|
||||
}
|
||||
|
||||
func BenchmarkGoLookupIP(b *testing.B) {
|
||||
testHookUninstaller.Do(uninstallTestHooks)
|
||||
|
||||
@ -225,14 +243,21 @@ func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
|
||||
func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
|
||||
testHookUninstaller.Do(uninstallTestHooks)
|
||||
|
||||
// This looks ugly but it's safe as long as benchmarks are run
|
||||
// sequentially in package testing.
|
||||
<-cfg.ch // keep config from being reloaded upon lookup
|
||||
orig := cfg.dnsConfig
|
||||
cfg.dnsConfig.servers = append([]string{"203.0.113.254"}, cfg.dnsConfig.servers...) // use TEST-NET-3 block, see RFC 5737
|
||||
conf, err := newResolvConfTest()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conf.teardown()
|
||||
|
||||
lines := []string{
|
||||
"nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737
|
||||
"nameserver 8.8.8.8",
|
||||
}
|
||||
if err := conf.writeAndUpdate(lines); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
goLookupIP("www.example.com")
|
||||
}
|
||||
cfg.dnsConfig = orig
|
||||
cfg.ch <- struct{}{}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user