From f60fcca5f1e7b7a33e219ec45d4bd9dc58dd2552 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 16 Apr 2016 14:17:40 -0700 Subject: [PATCH] net: fix plan9 after context change, propagate contexts more My previous https://golang.org/cl/22101 to add context throughout the net package broke Plan 9, which isn't currently tested (#15251). It also broke some old unsupported version of Windows (Windows 2000?) which doesn't have the ConnectEx function, but that was only found visually, since our minimum supported Windows version has ConnectEx. This change simplifies the Windows and deletes the non-ConnectEx code path. Windows 2000 will work even less now, if it even worked before. Windows XP remains our minimum supported version. Specifically, the previous CL stopped using the "dial" function, which 0intro noted: https://github.com/golang/go/issues/15333#issuecomment-210842761 This CL removes the dial function instead and makes plan9's net implementation respect contexts, which likely fixes a number of t.Skipped tests. I'm leaving that to 0intro to investigate. In the process of propagating and respecting contexts for plan9, I had to change some signatures to add contexts to more places and ended up pushing contexts down into the Go-based DNS resolution as well, replacing the pure-Go DNS implementation's use of "timeout time.Duration" with a context instead. Updates #11932 Updates #15328 Fixes #15333 Change-Id: I6ad1e62f38271cdd86b3f40921f2d0f23374936a Reviewed-on: https://go-review.googlesource.com/22144 Reviewed-by: David du Colombier <0intro@gmail.com> Reviewed-by: Mikio Hara Reviewed-by: Ian Lance Taylor Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/dial.go | 9 ++--- src/net/dial_gen.go | 40 -------------------- src/net/dnsclient_unix.go | 50 +++++++++++++----------- src/net/dnsclient_unix_test.go | 13 ++++--- src/net/fd_plan9.go | 6 --- src/net/fd_windows.go | 25 +++++------- src/net/iprawsock.go | 2 +- src/net/iprawsock_posix.go | 4 +- src/net/ipsock_plan9.go | 48 +++++++++++++++++------ src/net/lookup.go | 14 +++---- src/net/lookup_plan9.go | 51 ++++++++++++------------- src/net/lookup_stub.go | 16 ++++---- src/net/lookup_unix.go | 34 ++++++++++------- src/net/lookup_windows.go | 69 ++++++++++++++++++++++------------ src/net/tcpsock_plan9.go | 8 +--- src/net/udpsock_plan9.go | 7 +--- 16 files changed, 199 insertions(+), 197 deletions(-) delete mode 100644 src/net/dial_gen.go diff --git a/src/net/dial.go b/src/net/dial.go index 1f31e8f2cc7..59e41f536b2 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -124,7 +124,7 @@ func (d *Dialer) fallbackDelay() time.Duration { } } -func parseNetwork(net string) (afnet string, proto int, err error) { +func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) { i := last(net, ':') if i < 0 { // no colon switch net { @@ -143,7 +143,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) { protostr := net[i+1:] proto, i, ok := dtoi(protostr, 0) if !ok || i != len(protostr) { - proto, err = lookupProtocol(protostr) + proto, err = lookupProtocol(ctx, protostr) if err != nil { return "", 0, err } @@ -157,7 +157,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) { // addresses. The result contains at least one address when error is // nil. func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { - afnet, _, err := parseNetwork(network) + afnet, _, err := parseNetwork(ctx, network) if err != nil { return nil, err } @@ -472,8 +472,7 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) } // dialSingle attempts to establish and returns a single connection to -// the destination address. This must be called through the OS-specific -// dial function, because some OSes don't implement the deadline feature. +// the destination address. func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) { la := dp.LocalAddr switch ra := ra.(type) { diff --git a/src/net/dial_gen.go b/src/net/dial_gen.go deleted file mode 100644 index a2cd8cb44df..00000000000 --- a/src/net/dial_gen.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build windows plan9 - -package net - -import "time" - -// dialChannel is the simple pure-Go implementation of dial, still -// used on operating systems where the deadline hasn't been pushed -// down into the pollserver. (Plan 9 and some old versions of Windows) -func dialChannel(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { - if deadline.IsZero() { - return dialer(noDeadline) - } - timeout := deadline.Sub(time.Now()) - if timeout <= 0 { - return nil, &OpError{Op: "dial", Net: net, Source: nil, Addr: ra, Err: errTimeout} - } - t := time.NewTimer(timeout) - defer t.Stop() - type racer struct { - Conn - error - } - ch := make(chan racer, 1) - go func() { - testHookDialChannel() - c, err := dialer(noDeadline) - ch <- racer{c, err} - }() - select { - case <-t.C: - return nil, &OpError{Op: "dial", Net: net, Source: nil, Addr: ra, Err: errTimeout} - case racer := <-ch: - return racer.Conn, racer.error - } -} diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index 914dd767d33..5ae21012e3c 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -27,10 +27,10 @@ import ( // A dnsDialer provides dialing suitable for DNS queries. type dnsDialer interface { - dialDNS(string, string) (dnsConn, error) + dialDNS(ctx context.Context, network, addr string) (dnsConn, error) } -var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} } +var testHookDNSDialer = func() dnsDialer { return &Dialer{} } // A dnsConn represents a DNS transport endpoint. type dnsConn interface { @@ -105,7 +105,7 @@ func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error { return nil } -func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { +func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) { switch network { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": default: @@ -116,9 +116,9 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { // call back here to translate it. The DNS config parser has // already checked that all the cfg.servers[i] are IP // addresses, which Dial will use without a DNS lookup. - c, err := d.Dial(network, server) + c, err := d.DialContext(ctx, network, server) if err != nil { - return nil, err + return nil, mapErr(err) } switch network { case "tcp", "tcp4", "tcp6": @@ -130,8 +130,8 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { } // exchange sends a query on the connection and hopes for a response. -func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { - d := testHookDNSDialer(timeout) +func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, error) { + d := testHookDNSDialer() out := dnsMsg{ dnsMsgHdr: dnsMsgHdr{ recursion_desired: true, @@ -141,21 +141,21 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg }, } for _, network := range []string{"udp", "tcp"} { - c, err := d.dialDNS(network, server) + c, err := d.dialDNS(ctx, network, server) if err != nil { return nil, err } defer c.Close() - if timeout > 0 { - c.SetDeadline(time.Now().Add(timeout)) + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + c.SetDeadline(d) } out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) if err := c.writeDNSQuery(&out); err != nil { - return nil, err + return nil, mapErr(err) } in, err := c.readDNSResponse() if err != nil { - return nil, err + return nil, mapErr(err) } if in.id != out.id { return nil, errors.New("DNS message ID mismatch") @@ -170,16 +170,24 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). -func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { +func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { if len(cfg.servers) == 0 { return "", nil, &DNSError{Err: "no DNS servers", Name: name} } + timeout := time.Duration(cfg.timeout) * time.Second + deadline := time.Now().Add(timeout) + if old, ok := ctx.Deadline(); !ok || deadline.Before(old) { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() + } + var lastErr error for i := 0; i < cfg.attempts; i++ { for _, server := range cfg.servers { server = JoinHostPort(server, "53") - msg, err := exchange(server, name, qtype, timeout) + msg, err := exchange(ctx, server, name, qtype) if err != nil { lastErr = &DNSError{ Err: err.Error(), @@ -297,7 +305,7 @@ func (conf *resolverConfig) releaseSema() { <-conf.ch } -func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { +func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { if !isDomainName(name) { return "", nil, &DNSError{Err: "invalid domain name", Name: name} } @@ -306,7 +314,7 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() for _, fqdn := range conf.nameList(name) { - cname, rrs, err = tryOneName(conf, fqdn, qtype) + cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype) if err == nil { break } @@ -467,7 +475,7 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { go func(qtype uint16) { - _, rrs, err := tryOneName(conf, fqdn, qtype) + _, rrs, err := tryOneName(ctx, conf, fqdn, qtype) lane <- racer{fqdn, rrs, err} }(qtype) } @@ -510,8 +518,8 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupCNAME(name string) (cname string, err error) { - _, rrs, err := lookup(name, dnsTypeCNAME) +func goLookupCNAME(ctx context.Context, name string) (cname string, err error) { + _, rrs, err := lookup(ctx, name, dnsTypeCNAME) if err != nil { return } @@ -524,7 +532,7 @@ func goLookupCNAME(name string) (cname string, err error) { // only if cgoLookupPTR is the stub in cgo_stub.go). // Normally we let cgo use the C library resolver instead of depending // on our lookup code, so that Go and C get the same answers. -func goLookupPTR(addr string) ([]string, error) { +func goLookupPTR(ctx context.Context, addr string) ([]string, error) { names := lookupStaticAddr(addr) if len(names) > 0 { return names, nil @@ -533,7 +541,7 @@ func goLookupPTR(addr string) ([]string, error) { if err != nil { return nil, err } - _, rrs, err := lookup(arpa, dnsTypePTR) + _, rrs, err := lookup(ctx, arpa, dnsTypePTR) if err != nil { return nil, err } diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 145a3b6a33b..761fb23f142 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -37,8 +37,9 @@ func TestDNSTransportFallback(t *testing.T) { testenv.MustHaveExternalNetwork(t) for _, tt := range dnsTransportFallbackTests { - timeout := time.Duration(tt.timeout) * time.Second - msg, err := exchange(tt.server, tt.name, tt.qtype, timeout) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(tt.timeout)*time.Second) + defer cancel() + msg, err := exchange(ctx, tt.server, tt.name, tt.qtype) if err != nil { t.Error(err) continue @@ -78,7 +79,9 @@ func TestSpecialDomainName(t *testing.T) { server := "8.8.8.8:53" for _, tt := range specialDomainNameTests { - msg, err := exchange(server, tt.name, tt.qtype, 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msg, err := exchange(ctx, server, tt.name, tt.qtype) if err != nil { t.Error(err) continue @@ -492,7 +495,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { } d := &fakeDNSConn{} - testHookDNSDialer = func(time.Duration) dnsDialer { return d } + testHookDNSDialer = func() dnsDialer { return d } d.rh = func(q *dnsMsg) (*dnsMsg, error) { r := &dnsMsg{ @@ -571,7 +574,7 @@ type fakeDNSConn struct { rh func(*dnsMsg) (*dnsMsg, error) } -func (f *fakeDNSConn) dialDNS(n, s string) (dnsConn, error) { +func (f *fakeDNSConn) dialDNS(_ context.Context, n, s string) (dnsConn, error) { return f, nil } diff --git a/src/net/fd_plan9.go b/src/net/fd_plan9.go index d0e9c53fca6..35d16243178 100644 --- a/src/net/fd_plan9.go +++ b/src/net/fd_plan9.go @@ -32,12 +32,6 @@ func sysInit() { netdir = "/net" } -func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { - // On plan9, use the relatively inefficient - // goroutine-racing implementation. - return dialChannel(net, ra, dialer, deadline) -} - func newFD(net, name string, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) { return &netFD{net: net, n: name, dir: netdir + "/" + net + "/" + name, ctl: ctl, data: data, laddr: laddr, raddr: raddr}, nil } diff --git a/src/net/fd_windows.go b/src/net/fd_windows.go index d1d91a6a5c5..ca46bf93610 100644 --- a/src/net/fd_windows.go +++ b/src/net/fd_windows.go @@ -11,7 +11,6 @@ import ( "runtime" "sync" "syscall" - "time" "unsafe" ) @@ -70,22 +69,15 @@ func sysInit() { } } +// canUseConnectEx reports whether we can use the ConnectEx Windows API call +// for the given network type. func canUseConnectEx(net string) bool { switch net { - case "udp", "udp4", "udp6", "ip", "ip4", "ip6": - // ConnectEx windows API does not support connectionless sockets. - return false + case "tcp", "tcp4", "tcp6": + return true } - return syscall.LoadConnectEx() == nil -} - -func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { - if !canUseConnectEx(net) { - // Use the relatively inefficient goroutine-racing - // implementation of DialTimeout. - return dialChannel(net, ra, dialer, deadline) - } - return dialer(deadline) + // ConnectEx windows API does not support connectionless sockets. + return false } // operation contains superset of data necessary to perform all async IO. @@ -328,12 +320,13 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { if err := fd.init(); err != nil { return err } - if deadline, _ := ctx.Deadline(); !deadline.IsZero() { + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { fd.setWriteDeadline(deadline) defer fd.setWriteDeadline(noDeadline) } if !canUseConnectEx(fd.net) { - return os.NewSyscallError("connect", connectFunc(fd.sysfd, ra)) + err := connectFunc(fd.sysfd, ra) + return os.NewSyscallError("connect", err) } // ConnectEx windows API requires an unconnected, previously bound socket. if la == nil { diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go index f4a4de82fcd..173b3cb4114 100644 --- a/src/net/iprawsock.go +++ b/src/net/iprawsock.go @@ -50,7 +50,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) { if net == "" { // a hint wildcard for Go 1.0 undocumented behavior net = "ip" } - afnet, _, err := parseNetwork(net) + afnet, _, err := parseNetwork(context.Background(), net) if err != nil { return nil, err } diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index 68dc307b606..3e0b060a8a4 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -121,7 +121,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) } func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(netProto) + network, proto, err := parseNetwork(ctx, netProto) if err != nil { return nil, err } @@ -141,7 +141,7 @@ func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn } func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(netProto) + network, proto, err := parseNetwork(ctx, netProto) if err != nil { return nil, err } diff --git a/src/net/ipsock_plan9.go b/src/net/ipsock_plan9.go index f7c2b446883..2b84683eeb5 100644 --- a/src/net/ipsock_plan9.go +++ b/src/net/ipsock_plan9.go @@ -7,6 +7,7 @@ package net import ( + "context" "os" "syscall" ) @@ -99,7 +100,7 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) { return addr, nil } -func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) { +func startPlan9(ctx context.Context, net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) { var ( ip IP port int @@ -118,7 +119,7 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, return } - clone, dest, err := queryCS1(proto, ip, port) + clone, dest, err := queryCS1(ctx, proto, ip, port) if err != nil { return } @@ -135,8 +136,8 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, return f, dest, proto, string(buf[:n]), nil } -func netErr(e error) { - oe, ok := e.(*OpError) +func fixErr(err error) { + oe, ok := err.(*OpError) if !ok { return } @@ -165,9 +166,34 @@ func netErr(e error) { } } -func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) { - defer func() { netErr(err) }() - f, dest, proto, name, err := startPlan9(net, raddr) +func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) { + defer func() { fixErr(err) }() + type res struct { + fd *netFD + err error + } + resc := make(chan res) + go func() { + testHookDialChannel() + fd, err := dialPlan9Blocking(ctx, net, laddr, raddr) + select { + case resc <- res{fd, err}: + case <-ctx.Done(): + if fd != nil { + fd.Close() + } + } + }() + select { + case res := <-resc: + return res.fd, res.err + case <-ctx.Done(): + return nil, mapErr(ctx.Err()) + } +} + +func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) { + f, dest, proto, name, err := startPlan9(ctx, net, raddr) if err != nil { return nil, err } @@ -190,9 +216,9 @@ func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) { return newFD(proto, name, f, data, laddr, raddr) } -func listenPlan9(net string, laddr Addr) (fd *netFD, err error) { - defer func() { netErr(err) }() - f, dest, proto, name, err := startPlan9(net, laddr) +func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err error) { + defer func() { fixErr(err) }() + f, dest, proto, name, err := startPlan9(ctx, net, laddr) if err != nil { return nil, err } @@ -214,7 +240,7 @@ func (fd *netFD) netFD() (*netFD, error) { } func (fd *netFD) acceptPlan9() (nfd *netFD, err error) { - defer func() { netErr(err) }() + defer func() { fixErr(err) }() if err := fd.readLock(); err != nil { return nil, err } diff --git a/src/net/lookup.go b/src/net/lookup.go index 8f02787422b..5e60011165d 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -114,7 +114,7 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro func LookupPort(network, service string) (port int, err error) { port, needsLookup := parsePort(service) if needsLookup { - port, err = lookupPort(network, service) + port, err = lookupPort(context.Background(), network, service) if err != nil { return 0, err } @@ -130,7 +130,7 @@ func LookupPort(network, service string) (port int, err error) { // LookupHost or LookupIP directly; both take care of resolving // the canonical name as part of the lookup. func LookupCNAME(name string) (cname string, err error) { - return lookupCNAME(name) + return lookupCNAME(context.Background(), name) } // LookupSRV tries to resolve an SRV query of the given service, @@ -143,26 +143,26 @@ func LookupCNAME(name string) (cname string, err error) { // publishing SRV records under non-standard names, if both service // and proto are empty strings, LookupSRV looks up name directly. func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { - return lookupSRV(service, proto, name) + return lookupSRV(context.Background(), service, proto, name) } // LookupMX returns the DNS MX records for the given domain name sorted by preference. func LookupMX(name string) (mxs []*MX, err error) { - return lookupMX(name) + return lookupMX(context.Background(), name) } // LookupNS returns the DNS NS records for the given domain name. func LookupNS(name string) (nss []*NS, err error) { - return lookupNS(name) + return lookupNS(context.Background(), name) } // LookupTXT returns the DNS TXT records for the given domain name. func LookupTXT(name string) (txts []string, err error) { - return lookupTXT(name) + return lookupTXT(context.Background(), name) } // LookupAddr performs a reverse lookup for the given address, returning a list // of names mapping to that address. func LookupAddr(addr string) (names []string, err error) { - return lookupAddr(addr) + return lookupAddr(context.Background(), addr) } diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go index 4224263602b..73147a2d3f7 100644 --- a/src/net/lookup_plan9.go +++ b/src/net/lookup_plan9.go @@ -10,7 +10,7 @@ import ( "os" ) -func query(filename, query string, bufSize int) (res []string, err error) { +func query(ctx context.Context, filename, query string, bufSize int) (res []string, err error) { file, err := os.OpenFile(filename, os.O_RDWR, 0) if err != nil { return @@ -40,7 +40,7 @@ func query(filename, query string, bufSize int) (res []string, err error) { return } -func queryCS(net, host, service string) (res []string, err error) { +func queryCS(ctx context.Context, net, host, service string) (res []string, err error) { switch net { case "tcp4", "tcp6": net = "tcp" @@ -50,15 +50,15 @@ func queryCS(net, host, service string) (res []string, err error) { if host == "" { host = "*" } - return query(netdir+"/cs", net+"!"+host+"!"+service, 128) + return query(ctx, netdir+"/cs", net+"!"+host+"!"+service, 128) } -func queryCS1(net string, ip IP, port int) (clone, dest string, err error) { +func queryCS1(ctx context.Context, net string, ip IP, port int) (clone, dest string, err error) { ips := "*" if len(ip) != 0 && !ip.IsUnspecified() { ips = ip.String() } - lines, err := queryCS(net, ips, itoa(port)) + lines, err := queryCS(ctx, net, ips, itoa(port)) if err != nil { return } @@ -70,8 +70,8 @@ func queryCS1(net string, ip IP, port int) (clone, dest string, err error) { return } -func queryDNS(addr string, typ string) (res []string, err error) { - return query(netdir+"/dns", addr+" "+typ, 1024) +func queryDNS(ctx context.Context, addr string, typ string) (res []string, err error) { + return query(ctx, netdir+"/dns", addr+" "+typ, 1024) } // toLower returns a lower-case version of in. Restricting us to @@ -97,8 +97,8 @@ func toLower(in string) string { // lookupProtocol looks up IP protocol name and returns // the corresponding protocol number. -func lookupProtocol(name string) (proto int, err error) { - lines, err := query(netdir+"/cs", "!protocol="+toLower(name), 128) +func lookupProtocol(ctx context.Context, name string) (proto int, err error) { + lines, err := query(ctx, netdir+"/cs", "!protocol="+toLower(name), 128) if err != nil { return 0, err } @@ -119,7 +119,7 @@ func lookupProtocol(name string) (proto int, err error) { func lookupHost(ctx context.Context, host string) (addrs []string, err error) { // Use netdir/cs instead of netdir/dns because cs knows about // host names in local network (e.g. from /lib/ndb/local) - lines, err := queryCS("net", host, "1") + lines, err := queryCS(ctx, "net", host, "1") if err != nil { return } @@ -148,8 +148,7 @@ loop: } func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { - // TODO(bradfitz): push down ctx - lits, err := LookupHost(host) + lits, err := lookupHost(ctx, host) if err != nil { return } @@ -163,14 +162,14 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return } -func lookupPort(network, service string) (port int, err error) { +func lookupPort(ctx context.Context, network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" case "udp4", "udp6": network = "udp" } - lines, err := queryCS(network, "127.0.0.1", service) + lines, err := queryCS(ctx, network, "127.0.0.1", service) if err != nil { return } @@ -192,8 +191,8 @@ func lookupPort(network, service string) (port int, err error) { return 0, unknownPortError } -func lookupCNAME(name string) (cname string, err error) { - lines, err := queryDNS(name, "cname") +func lookupCNAME(ctx context.Context, name string) (cname string, err error) { + lines, err := queryDNS(ctx, name, "cname") if err != nil { return } @@ -205,14 +204,14 @@ func lookupCNAME(name string) (cname string, err error) { return "", errors.New("bad response from ndb/dns") } -func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { +func lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { var target string if service == "" && proto == "" { target = name } else { target = "_" + service + "._" + proto + "." + name } - lines, err := queryDNS(target, "srv") + lines, err := queryDNS(ctx, target, "srv") if err != nil { return } @@ -234,8 +233,8 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err return } -func lookupMX(name string) (mx []*MX, err error) { - lines, err := queryDNS(name, "mx") +func lookupMX(ctx context.Context, name string) (mx []*MX, err error) { + lines, err := queryDNS(ctx, name, "mx") if err != nil { return } @@ -252,8 +251,8 @@ func lookupMX(name string) (mx []*MX, err error) { return } -func lookupNS(name string) (ns []*NS, err error) { - lines, err := queryDNS(name, "ns") +func lookupNS(ctx context.Context, name string) (ns []*NS, err error) { + lines, err := queryDNS(ctx, name, "ns") if err != nil { return } @@ -267,8 +266,8 @@ func lookupNS(name string) (ns []*NS, err error) { return } -func lookupTXT(name string) (txt []string, err error) { - lines, err := queryDNS(name, "txt") +func lookupTXT(ctx context.Context, name string) (txt []string, err error) { + lines, err := queryDNS(ctx, name, "txt") if err != nil { return } @@ -280,12 +279,12 @@ func lookupTXT(name string) (txt []string, err error) { return } -func lookupAddr(addr string) (name []string, err error) { +func lookupAddr(ctx context.Context, addr string) (name []string, err error) { arpa, err := reverseaddr(addr) if err != nil { return } - lines, err := queryDNS(arpa, "ptr") + lines, err := queryDNS(ctx, arpa, "ptr") if err != nil { return } diff --git a/src/net/lookup_stub.go b/src/net/lookup_stub.go index 38a4f0bae48..bd096b39652 100644 --- a/src/net/lookup_stub.go +++ b/src/net/lookup_stub.go @@ -11,7 +11,7 @@ import ( "syscall" ) -func lookupProtocol(name string) (proto int, err error) { +func lookupProtocol(ctx context.Context, name string) (proto int, err error) { return 0, syscall.ENOPROTOOPT } @@ -23,30 +23,30 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return nil, syscall.ENOPROTOOPT } -func lookupPort(network, service string) (port int, err error) { +func lookupPort(ctx context.Context, network, service string) (port int, err error) { return 0, syscall.ENOPROTOOPT } -func lookupCNAME(name string) (cname string, err error) { +func lookupCNAME(ctx context.Context, name string) (cname string, err error) { return "", syscall.ENOPROTOOPT } -func lookupSRV(service, proto, name string) (cname string, srvs []*SRV, err error) { +func lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) { return "", nil, syscall.ENOPROTOOPT } -func lookupMX(name string) (mxs []*MX, err error) { +func lookupMX(ctx context.Context, name string) (mxs []*MX, err error) { return nil, syscall.ENOPROTOOPT } -func lookupNS(name string) (nss []*NS, err error) { +func lookupNS(ctx context.Context, name string) (nss []*NS, err error) { return nil, syscall.ENOPROTOOPT } -func lookupTXT(name string) (txts []string, err error) { +func lookupTXT(ctx context.Context, name string) (txts []string, err error) { return nil, syscall.ENOPROTOOPT } -func lookupAddr(addr string) (ptrs []string, err error) { +func lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) { return nil, syscall.ENOPROTOOPT } diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go index 8d3fa477828..5461fe8a411 100644 --- a/src/net/lookup_unix.go +++ b/src/net/lookup_unix.go @@ -43,7 +43,7 @@ func readProtocols() { // lookupProtocol looks up IP protocol name in /etc/protocols and // returns correspondent protocol number. -func lookupProtocol(name string) (int, error) { +func lookupProtocol(_ context.Context, name string) (int, error) { onceReadProtocols.Do(readProtocols) proto, found := protocols[name] if !found { @@ -77,7 +77,12 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return goLookupIPOrder(ctx, host, order) } -func lookupPort(network, service string) (int, error) { +func lookupPort(ctx context.Context, network, service string) (int, error) { + // TODO: use the context if there ever becomes a need. Related + // is issue 15321. But port lookup generally just involves + // local files, and the os package has no context support. The + // files might be on a remote filesystem, though. This should + // probably race goroutines if ctx != context.Background(). if systemConf().canUseCgo() { if port, err, ok := cgoLookupPort(network, service); ok { return port, err @@ -86,23 +91,24 @@ func lookupPort(network, service string) (int, error) { return goLookupPort(network, service) } -func lookupCNAME(name string) (string, error) { +func lookupCNAME(ctx context.Context, name string) (string, error) { if systemConf().canUseCgo() { + // TODO: use ctx. issue 15321. Or race goroutines. if cname, err, ok := cgoLookupCNAME(name); ok { return cname, err } } - return goLookupCNAME(name) + return goLookupCNAME(ctx, name) } -func lookupSRV(service, proto, name string) (string, []*SRV, error) { +func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { var target string if service == "" && proto == "" { target = name } else { target = "_" + service + "._" + proto + "." + name } - cname, rrs, err := lookup(target, dnsTypeSRV) + cname, rrs, err := lookup(ctx, target, dnsTypeSRV) if err != nil { return "", nil, err } @@ -115,8 +121,8 @@ func lookupSRV(service, proto, name string) (string, []*SRV, error) { return cname, srvs, nil } -func lookupMX(name string) ([]*MX, error) { - _, rrs, err := lookup(name, dnsTypeMX) +func lookupMX(ctx context.Context, name string) ([]*MX, error) { + _, rrs, err := lookup(ctx, name, dnsTypeMX) if err != nil { return nil, err } @@ -129,8 +135,8 @@ func lookupMX(name string) ([]*MX, error) { return mxs, nil } -func lookupNS(name string) ([]*NS, error) { - _, rrs, err := lookup(name, dnsTypeNS) +func lookupNS(ctx context.Context, name string) ([]*NS, error) { + _, rrs, err := lookup(ctx, name, dnsTypeNS) if err != nil { return nil, err } @@ -141,8 +147,8 @@ func lookupNS(name string) ([]*NS, error) { return nss, nil } -func lookupTXT(name string) ([]string, error) { - _, rrs, err := lookup(name, dnsTypeTXT) +func lookupTXT(ctx context.Context, name string) ([]string, error) { + _, rrs, err := lookup(ctx, name, dnsTypeTXT) if err != nil { return nil, err } @@ -153,11 +159,11 @@ func lookupTXT(name string) ([]string, error) { return txts, nil } -func lookupAddr(addr string) ([]string, error) { +func lookupAddr(ctx context.Context, addr string) ([]string, error) { if systemConf().canUseCgo() { if ptrs, err, ok := cgoLookupPTR(addr); ok { return ptrs, err } } - return goLookupPTR(addr) + return goLookupPTR(ctx, addr) } diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index ce012ba873f..7a04cc89984 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -26,30 +26,37 @@ func getprotobyname(name string) (proto int, err error) { } // lookupProtocol looks up IP protocol name and returns correspondent protocol number. -func lookupProtocol(name string) (int, error) { +func lookupProtocol(ctx context.Context, name string) (int, error) { // GetProtoByName return value is stored in thread local storage. // Start new os thread before the call to prevent races. type result struct { proto int err error } - ch := make(chan result) + ch := make(chan result) // unbuffered go func() { acquireThread() defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() proto, err := getprotobyname(name) - ch <- result{proto: proto, err: err} - }() - r := <-ch - if r.err != nil { - if proto, ok := protocols[name]; ok { - return proto, nil + select { + case ch <- result{proto: proto, err: err}: + case <-ctx.Done(): } - r.err = &DNSError{Err: r.err.Error(), Name: name} + }() + select { + case r := <-ch: + if r.err != nil { + if proto, ok := protocols[name]; ok { + return proto, nil + } + r.err = &DNSError{Err: r.err.Error(), Name: name} + } + return r.proto, r.err + case <-ctx.Done(): + return 0, mapErr(ctx.Err()) } - return r.proto, r.err } func lookupHost(ctx context.Context, name string) ([]string, error) { @@ -193,30 +200,38 @@ func getservbyname(network, service string) (int, error) { return int(syscall.Ntohs(s.Port)), nil } -func oldLookupPort(network, service string) (int, error) { +func oldLookupPort(ctx context.Context, network, service string) (int, error) { // GetServByName return value is stored in thread local storage. // Start new os thread before the call to prevent races. type result struct { port int err error } - ch := make(chan result) + ch := make(chan result) // unbuffered go func() { acquireThread() defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() port, err := getservbyname(network, service) - ch <- result{port: port, err: err} + select { + case ch <- result{port: port, err: err}: + case <-ctx.Done(): + } }() - r := <-ch - if r.err != nil { - r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service} + select { + case r := <-ch: + if r.err != nil { + r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service} + } + return r.port, r.err + case <-ctx.Done(): + return 0, mapErr(ctx.Err()) } - return r.port, r.err } -func newLookupPort(network, service string) (int, error) { +func newLookupPort(ctx context.Context, network, service string) (int, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var stype int32 @@ -252,7 +267,8 @@ func newLookupPort(network, service string) (int, error) { return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service} } -func lookupCNAME(name string) (string, error) { +func lookupCNAME(ctx context.Context, name string) (string, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -272,7 +288,8 @@ func lookupCNAME(name string) (string, error) { return absDomainName([]byte(cname)), nil } -func lookupSRV(service, proto, name string) (string, []*SRV, error) { +func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var target string @@ -297,7 +314,8 @@ func lookupSRV(service, proto, name string) (string, []*SRV, error) { return absDomainName([]byte(target)), srvs, nil } -func lookupMX(name string) ([]*MX, error) { +func lookupMX(ctx context.Context, name string) ([]*MX, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -316,7 +334,8 @@ func lookupMX(name string) ([]*MX, error) { return mxs, nil } -func lookupNS(name string) ([]*NS, error) { +func lookupNS(ctx context.Context, name string) ([]*NS, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -334,7 +353,8 @@ func lookupNS(name string) ([]*NS, error) { return nss, nil } -func lookupTXT(name string) ([]string, error) { +func lookupTXT(ctx context.Context, name string) ([]string, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -355,7 +375,8 @@ func lookupTXT(name string) ([]string, error) { return txts, nil } -func lookupAddr(addr string) ([]string, error) { +func lookupAddr(ctx context.Context, addr string) ([]string, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() arpa, err := reverseaddr(addr) diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index 08ad9be8f41..d2860607f8b 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -22,10 +22,6 @@ func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, } func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { - if d, _ := ctx.Deadline(); !d.IsZero() { - // TODO: deadline not implemented on Plan 9 (see golang.og/issue/11932) - } - // TODO(bradfitz,0intro): also use the cancel channel. switch net { case "tcp", "tcp4", "tcp6": default: @@ -34,7 +30,7 @@ func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn if raddr == nil { return nil, errMissingAddress } - fd, err := dialPlan9(net, laddr, raddr) + fd, err := dialPlan9(ctx, net, laddr, raddr) if err != nil { return nil, err } @@ -71,7 +67,7 @@ func (ln *TCPListener) file() (*os.File, error) { } func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) { - fd, err := listenPlan9(network, laddr) + fd, err := listenPlan9(ctx, network, laddr) if err != nil { return nil, err } diff --git a/src/net/udpsock_plan9.go b/src/net/udpsock_plan9.go index 3b3d8d7615d..666f20622f6 100644 --- a/src/net/udpsock_plan9.go +++ b/src/net/udpsock_plan9.go @@ -56,10 +56,7 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error } func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) { - if deadline, _ := ctx.Deadline(); !deadline.IsZero() { - // TODO: deadline not implemented on Plan 9 (see golang.og/issue/11932) - } - fd, err := dialPlan9(net, laddr, raddr) + fd, err := dialPlan9(ctx, net, laddr, raddr) if err != nil { return nil, err } @@ -95,7 +92,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { } func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) { - l, err := listenPlan9(network, laddr) + l, err := listenPlan9(ctx, network, laddr) if err != nil { return nil, err }