From 636670b8db9b72c80ec89da0a666be0d686269fd Mon Sep 17 00:00:00 2001 From: Scott Bell Date: Sun, 8 May 2016 18:17:59 -0700 Subject: [PATCH] net: use contexts for cgo-based DNS resolution Although calls to getaddrinfo can't be portably interrupted, we still benefit from more granular resource management by pushing the context downwards. Fixes #15321 Change-Id: I5506195fc6493080410e3d46aaa3fe02018a24fe Reviewed-on: https://go-review.googlesource.com/22961 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/cgo_stub.go | 12 ++-- src/net/cgo_unix.go | 136 +++++++++++++++++++++++++++++-------- src/net/cgo_unix_test.go | 61 ++++++++++++++++- src/net/lookup_unix.go | 12 ++-- src/net/netgo_unix_test.go | 5 +- 5 files changed, 180 insertions(+), 46 deletions(-) diff --git a/src/net/cgo_stub.go b/src/net/cgo_stub.go index 52d1dfd346..51259722ae 100644 --- a/src/net/cgo_stub.go +++ b/src/net/cgo_stub.go @@ -6,6 +6,8 @@ package net +import "context" + func init() { netGo = true } type addrinfoErrno int @@ -14,22 +16,22 @@ func (eai addrinfoErrno) Error() string { return "" } func (eai addrinfoErrno) Temporary() bool { return false } func (eai addrinfoErrno) Timeout() bool { return false } -func cgoLookupHost(name string) (addrs []string, err error, completed bool) { +func cgoLookupHost(ctx context.Context, name string) (addrs []string, err error, completed bool) { return nil, nil, false } -func cgoLookupPort(network, service string) (port int, err error, completed bool) { +func cgoLookupPort(ctx context.Context, network, service string) (port int, err error, completed bool) { return 0, nil, false } -func cgoLookupIP(name string) (addrs []IPAddr, err error, completed bool) { +func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) { return nil, nil, false } -func cgoLookupCNAME(name string) (cname string, err error, completed bool) { +func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) { return "", nil, false } -func cgoLookupPTR(addr string) (ptrs []string, err error, completed bool) { +func cgoLookupPTR(ctx context.Context, addr string) (ptrs []string, err error, completed bool) { return nil, nil, false } diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go index 59c40c8d8a..5a1eed8437 100644 --- a/src/net/cgo_unix.go +++ b/src/net/cgo_unix.go @@ -19,6 +19,7 @@ package net import "C" import ( + "context" "syscall" "unsafe" ) @@ -32,18 +33,31 @@ func (eai addrinfoErrno) Error() string { return C.GoString(C.gai_strerror(C.i func (eai addrinfoErrno) Temporary() bool { return eai == C.EAI_AGAIN } func (eai addrinfoErrno) Timeout() bool { return false } -func cgoLookupHost(name string) (hosts []string, err error, completed bool) { - addrs, err, completed := cgoLookupIP(name) +type portLookupResult struct { + port int + err error +} + +type ipLookupResult struct { + addrs []IPAddr + cname string + err error +} + +type reverseLookupResult struct { + names []string + err error +} + +func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) { + addrs, err, completed := cgoLookupIP(ctx, name) for _, addr := range addrs { hosts = append(hosts, addr.String()) } return } -func cgoLookupPort(network, service string) (port int, err error, completed bool) { - acquireThread() - defer releaseThread() - +func cgoLookupPort(ctx context.Context, network, service string) (port int, err error, completed bool) { var hints C.struct_addrinfo switch network { case "": // no hints @@ -64,11 +78,27 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool hints.ai_family = C.AF_INET6 } } + if ctx.Done() == nil { + port, err := cgoLookupServicePort(&hints, network, service) + return port, err, true + } + result := make(chan portLookupResult, 1) + go cgoPortLookup(result, &hints, network, service) + select { + case r := <-result: + return r.port, r.err, true + case <-ctx.Done(): + // Since there isn't a portable way to cancel the lookup, + // we just let it finish and write to the buffered channel. + return 0, mapErr(ctx.Err()), false + } +} +func cgoLookupServicePort(hints *C.struct_addrinfo, network, service string) (port int, err error) { s := C.CString(service) var res *C.struct_addrinfo defer C.free(unsafe.Pointer(s)) - gerrno, err := C.getaddrinfo(nil, s, &hints, &res) + gerrno, err := C.getaddrinfo(nil, s, hints, &res) if gerrno != 0 { switch gerrno { case C.EAI_SYSTEM: @@ -78,7 +108,7 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool default: err = addrinfoErrno(gerrno) } - return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}, true + return 0, &DNSError{Err: err.Error(), Name: network + "/" + service} } defer C.freeaddrinfo(res) @@ -87,17 +117,22 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool case C.AF_INET: sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) p := (*[2]byte)(unsafe.Pointer(&sa.Port)) - return int(p[0])<<8 | int(p[1]), nil, true + return int(p[0])<<8 | int(p[1]), nil case C.AF_INET6: sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) p := (*[2]byte)(unsafe.Pointer(&sa.Port)) - return int(p[0])<<8 | int(p[1]), nil, true + return int(p[0])<<8 | int(p[1]), nil } } - return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}, true + return 0, &DNSError{Err: "unknown port", Name: network + "/" + service} } -func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, completed bool) { +func cgoPortLookup(result chan<- portLookupResult, hints *C.struct_addrinfo, network, service string) { + port, err := cgoLookupServicePort(hints, network, service) + result <- portLookupResult{port, err} +} + +func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) { acquireThread() defer releaseThread() @@ -127,7 +162,7 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com default: err = addrinfoErrno(gerrno) } - return nil, "", &DNSError{Err: err.Error(), Name: name}, true + return nil, "", &DNSError{Err: err.Error(), Name: name} } defer C.freeaddrinfo(res) @@ -156,17 +191,42 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com addrs = append(addrs, addr) } } - return addrs, cname, nil, true + return addrs, cname, nil } -func cgoLookupIP(name string) (addrs []IPAddr, err error, completed bool) { - addrs, _, err, completed = cgoLookupIPCNAME(name) - return +func cgoIPLookup(result chan<- ipLookupResult, name string) { + addrs, cname, err := cgoLookupIPCNAME(name) + result <- ipLookupResult{addrs, cname, err} } -func cgoLookupCNAME(name string) (cname string, err error, completed bool) { - _, cname, err, completed = cgoLookupIPCNAME(name) - return +func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) { + if ctx.Done() == nil { + addrs, _, err = cgoLookupIPCNAME(name) + return addrs, err, true + } + result := make(chan ipLookupResult, 1) + go cgoIPLookup(result, name) + select { + case r := <-result: + return r.addrs, r.err, true + case <-ctx.Done(): + return nil, mapErr(ctx.Err()), false + } +} + +func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) { + if ctx.Done() == nil { + _, cname, err = cgoLookupIPCNAME(name) + return cname, err, true + } + result := make(chan ipLookupResult, 1) + go cgoIPLookup(result, name) + select { + case r := <-result: + return r.cname, r.err, true + case <-ctx.Done(): + return "", mapErr(ctx.Err()), false + } } // These are roughly enough for the following: @@ -182,10 +242,7 @@ const ( maxNameinfoLen = 4096 ) -func cgoLookupPTR(addr string) ([]string, error, bool) { - acquireThread() - defer releaseThread() - +func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error, completed bool) { var zone string ip := parseIPv4(addr) if ip == nil { @@ -198,9 +255,26 @@ func cgoLookupPTR(addr string) ([]string, error, bool) { if sa == nil { return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true } - var err error - var b []byte + if ctx.Done() == nil { + names, err := cgoLookupAddrPTR(addr, sa, salen) + return names, err, true + } + result := make(chan reverseLookupResult, 1) + go cgoReverseLookup(result, addr, sa, salen) + select { + case r := <-result: + return r.names, r.err, true + case <-ctx.Done(): + return nil, mapErr(ctx.Err()), false + } +} + +func cgoLookupAddrPTR(addr string, sa *C.struct_sockaddr, salen C.socklen_t) (names []string, err error) { + acquireThread() + defer releaseThread() + var gerrno int + var b []byte for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 { b = make([]byte, l) gerrno, err = cgoNameinfoPTR(b, sa, salen) @@ -217,16 +291,20 @@ func cgoLookupPTR(addr string) ([]string, error, bool) { default: err = addrinfoErrno(gerrno) } - return nil, &DNSError{Err: err.Error(), Name: addr}, true + return nil, &DNSError{Err: err.Error(), Name: addr} } - for i := 0; i < len(b); i++ { if b[i] == 0 { b = b[:i] break } } - return []string{absDomainName(b)}, nil, true + return []string{absDomainName(b)}, nil +} + +func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *C.struct_sockaddr, salen C.socklen_t) { + names, err := cgoLookupAddrPTR(addr, sa, salen) + result <- reverseLookupResult{names, err} } func cgoSockaddr(ip IP, zone string) (*C.struct_sockaddr, C.socklen_t) { diff --git a/src/net/cgo_unix_test.go b/src/net/cgo_unix_test.go index 5dc7b1a62d..e861c7aa1f 100644 --- a/src/net/cgo_unix_test.go +++ b/src/net/cgo_unix_test.go @@ -13,15 +13,70 @@ import ( ) func TestCgoLookupIP(t *testing.T) { - host := "localhost" - _, err, ok := cgoLookupIP(host) + ctx := context.Background() + _, err, ok := cgoLookupIP(ctx, "localhost") if !ok { t.Errorf("cgoLookupIP must not be a placeholder") } if err != nil { t.Error(err) } - if _, err := goLookupIP(context.Background(), host); err != nil { +} + +func TestCgoLookupIPWithCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err, ok := cgoLookupIP(ctx, "localhost") + if !ok { + t.Errorf("cgoLookupIP must not be a placeholder") + } + if err != nil { + t.Error(err) + } +} + +func TestCgoLookupPort(t *testing.T) { + ctx := context.Background() + _, err, ok := cgoLookupPort(ctx, "tcp", "smtp") + if !ok { + t.Errorf("cgoLookupPort must not be a placeholder") + } + if err != nil { + t.Error(err) + } +} + +func TestCgoLookupPortWithCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err, ok := cgoLookupPort(ctx, "tcp", "smtp") + if !ok { + t.Errorf("cgoLookupPort must not be a placeholder") + } + if err != nil { + t.Error(err) + } +} + +func TestCgoLookupPTR(t *testing.T) { + ctx := context.Background() + _, err, ok := cgoLookupPTR(ctx, "127.0.0.1") + if !ok { + t.Errorf("cgoLookupPTR must not be a placeholder") + } + if err != nil { + t.Error(err) + } +} + +func TestCgoLookupPTRWithCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err, ok := cgoLookupPTR(ctx, "127.0.0.1") + if !ok { + t.Errorf("cgoLookupPTR must not be a placeholder") + } + if err != nil { t.Error(err) } } diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go index 5461fe8a41..15397e8105 100644 --- a/src/net/lookup_unix.go +++ b/src/net/lookup_unix.go @@ -55,7 +55,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) { func lookupHost(ctx context.Context, host string) (addrs []string, err error) { order := systemConf().hostLookupOrder(host) if order == hostLookupCgo { - if addrs, err, ok := cgoLookupHost(host); ok { + if addrs, err, ok := cgoLookupHost(ctx, host); ok { return addrs, err } // cgo not available (or netgo); fall back to Go's DNS resolver @@ -67,8 +67,7 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) { func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { order := systemConf().hostLookupOrder(host) if order == hostLookupCgo { - // TODO(bradfitz): push down ctx, or at least its deadline to start - if addrs, err, ok := cgoLookupIP(host); ok { + if addrs, err, ok := cgoLookupIP(ctx, host); ok { return addrs, err } // cgo not available (or netgo); fall back to Go's DNS resolver @@ -84,7 +83,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) { // files might be on a remote filesystem, though. This should // probably race goroutines if ctx != context.Background(). if systemConf().canUseCgo() { - if port, err, ok := cgoLookupPort(network, service); ok { + if port, err, ok := cgoLookupPort(ctx, network, service); ok { return port, err } } @@ -93,8 +92,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) { func lookupCNAME(ctx context.Context, name string) (string, error) { if systemConf().canUseCgo() { - // TODO: use ctx. issue 15321. Or race goroutines. - if cname, err, ok := cgoLookupCNAME(name); ok { + if cname, err, ok := cgoLookupCNAME(ctx, name); ok { return cname, err } } @@ -161,7 +159,7 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) { func lookupAddr(ctx context.Context, addr string) ([]string, error) { if systemConf().canUseCgo() { - if ptrs, err, ok := cgoLookupPTR(addr); ok { + if ptrs, err, ok := cgoLookupPTR(ctx, addr); ok { return ptrs, err } } diff --git a/src/net/netgo_unix_test.go b/src/net/netgo_unix_test.go index 0a118874c2..5f1eb19e12 100644 --- a/src/net/netgo_unix_test.go +++ b/src/net/netgo_unix_test.go @@ -14,14 +14,15 @@ import ( func TestGoLookupIP(t *testing.T) { host := "localhost" - _, err, ok := cgoLookupIP(host) + ctx := context.Background() + _, err, ok := cgoLookupIP(ctx, host) if ok { t.Errorf("cgoLookupIP must be a placeholder") } if err != nil { t.Error(err) } - if _, err := goLookupIP(context.Background(), host); err != nil { + if _, err := goLookupIP(ctx, host); err != nil { t.Error(err) } }