diff --git a/src/net/dial.go b/src/net/dial.go index 1f31e8f2cc7..59e41f536b2 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -124,7 +124,7 @@ func (d *Dialer) fallbackDelay() time.Duration { } } -func parseNetwork(net string) (afnet string, proto int, err error) { +func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) { i := last(net, ':') if i < 0 { // no colon switch net { @@ -143,7 +143,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) { protostr := net[i+1:] proto, i, ok := dtoi(protostr, 0) if !ok || i != len(protostr) { - proto, err = lookupProtocol(protostr) + proto, err = lookupProtocol(ctx, protostr) if err != nil { return "", 0, err } @@ -157,7 +157,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) { // addresses. The result contains at least one address when error is // nil. func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { - afnet, _, err := parseNetwork(network) + afnet, _, err := parseNetwork(ctx, network) if err != nil { return nil, err } @@ -472,8 +472,7 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) } // dialSingle attempts to establish and returns a single connection to -// the destination address. This must be called through the OS-specific -// dial function, because some OSes don't implement the deadline feature. +// the destination address. func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) { la := dp.LocalAddr switch ra := ra.(type) { diff --git a/src/net/dial_gen.go b/src/net/dial_gen.go deleted file mode 100644 index a2cd8cb44df..00000000000 --- a/src/net/dial_gen.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build windows plan9 - -package net - -import "time" - -// dialChannel is the simple pure-Go implementation of dial, still -// used on operating systems where the deadline hasn't been pushed -// down into the pollserver. (Plan 9 and some old versions of Windows) -func dialChannel(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { - if deadline.IsZero() { - return dialer(noDeadline) - } - timeout := deadline.Sub(time.Now()) - if timeout <= 0 { - return nil, &OpError{Op: "dial", Net: net, Source: nil, Addr: ra, Err: errTimeout} - } - t := time.NewTimer(timeout) - defer t.Stop() - type racer struct { - Conn - error - } - ch := make(chan racer, 1) - go func() { - testHookDialChannel() - c, err := dialer(noDeadline) - ch <- racer{c, err} - }() - select { - case <-t.C: - return nil, &OpError{Op: "dial", Net: net, Source: nil, Addr: ra, Err: errTimeout} - case racer := <-ch: - return racer.Conn, racer.error - } -} diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index 914dd767d33..5ae21012e3c 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -27,10 +27,10 @@ import ( // A dnsDialer provides dialing suitable for DNS queries. type dnsDialer interface { - dialDNS(string, string) (dnsConn, error) + dialDNS(ctx context.Context, network, addr string) (dnsConn, error) } -var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} } +var testHookDNSDialer = func() dnsDialer { return &Dialer{} } // A dnsConn represents a DNS transport endpoint. type dnsConn interface { @@ -105,7 +105,7 @@ func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error { return nil } -func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { +func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) { switch network { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": default: @@ -116,9 +116,9 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { // call back here to translate it. The DNS config parser has // already checked that all the cfg.servers[i] are IP // addresses, which Dial will use without a DNS lookup. - c, err := d.Dial(network, server) + c, err := d.DialContext(ctx, network, server) if err != nil { - return nil, err + return nil, mapErr(err) } switch network { case "tcp", "tcp4", "tcp6": @@ -130,8 +130,8 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { } // exchange sends a query on the connection and hopes for a response. -func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { - d := testHookDNSDialer(timeout) +func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, error) { + d := testHookDNSDialer() out := dnsMsg{ dnsMsgHdr: dnsMsgHdr{ recursion_desired: true, @@ -141,21 +141,21 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg }, } for _, network := range []string{"udp", "tcp"} { - c, err := d.dialDNS(network, server) + c, err := d.dialDNS(ctx, network, server) if err != nil { return nil, err } defer c.Close() - if timeout > 0 { - c.SetDeadline(time.Now().Add(timeout)) + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + c.SetDeadline(d) } out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) if err := c.writeDNSQuery(&out); err != nil { - return nil, err + return nil, mapErr(err) } in, err := c.readDNSResponse() if err != nil { - return nil, err + return nil, mapErr(err) } if in.id != out.id { return nil, errors.New("DNS message ID mismatch") @@ -170,16 +170,24 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). -func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { +func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { if len(cfg.servers) == 0 { return "", nil, &DNSError{Err: "no DNS servers", Name: name} } + timeout := time.Duration(cfg.timeout) * time.Second + deadline := time.Now().Add(timeout) + if old, ok := ctx.Deadline(); !ok || deadline.Before(old) { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() + } + var lastErr error for i := 0; i < cfg.attempts; i++ { for _, server := range cfg.servers { server = JoinHostPort(server, "53") - msg, err := exchange(server, name, qtype, timeout) + msg, err := exchange(ctx, server, name, qtype) if err != nil { lastErr = &DNSError{ Err: err.Error(), @@ -297,7 +305,7 @@ func (conf *resolverConfig) releaseSema() { <-conf.ch } -func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { +func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { if !isDomainName(name) { return "", nil, &DNSError{Err: "invalid domain name", Name: name} } @@ -306,7 +314,7 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() for _, fqdn := range conf.nameList(name) { - cname, rrs, err = tryOneName(conf, fqdn, qtype) + cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype) if err == nil { break } @@ -467,7 +475,7 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { go func(qtype uint16) { - _, rrs, err := tryOneName(conf, fqdn, qtype) + _, rrs, err := tryOneName(ctx, conf, fqdn, qtype) lane <- racer{fqdn, rrs, err} }(qtype) } @@ -510,8 +518,8 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupCNAME(name string) (cname string, err error) { - _, rrs, err := lookup(name, dnsTypeCNAME) +func goLookupCNAME(ctx context.Context, name string) (cname string, err error) { + _, rrs, err := lookup(ctx, name, dnsTypeCNAME) if err != nil { return } @@ -524,7 +532,7 @@ func goLookupCNAME(name string) (cname string, err error) { // only if cgoLookupPTR is the stub in cgo_stub.go). // Normally we let cgo use the C library resolver instead of depending // on our lookup code, so that Go and C get the same answers. -func goLookupPTR(addr string) ([]string, error) { +func goLookupPTR(ctx context.Context, addr string) ([]string, error) { names := lookupStaticAddr(addr) if len(names) > 0 { return names, nil @@ -533,7 +541,7 @@ func goLookupPTR(addr string) ([]string, error) { if err != nil { return nil, err } - _, rrs, err := lookup(arpa, dnsTypePTR) + _, rrs, err := lookup(ctx, arpa, dnsTypePTR) if err != nil { return nil, err } diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 145a3b6a33b..761fb23f142 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -37,8 +37,9 @@ func TestDNSTransportFallback(t *testing.T) { testenv.MustHaveExternalNetwork(t) for _, tt := range dnsTransportFallbackTests { - timeout := time.Duration(tt.timeout) * time.Second - msg, err := exchange(tt.server, tt.name, tt.qtype, timeout) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(tt.timeout)*time.Second) + defer cancel() + msg, err := exchange(ctx, tt.server, tt.name, tt.qtype) if err != nil { t.Error(err) continue @@ -78,7 +79,9 @@ func TestSpecialDomainName(t *testing.T) { server := "8.8.8.8:53" for _, tt := range specialDomainNameTests { - msg, err := exchange(server, tt.name, tt.qtype, 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msg, err := exchange(ctx, server, tt.name, tt.qtype) if err != nil { t.Error(err) continue @@ -492,7 +495,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { } d := &fakeDNSConn{} - testHookDNSDialer = func(time.Duration) dnsDialer { return d } + testHookDNSDialer = func() dnsDialer { return d } d.rh = func(q *dnsMsg) (*dnsMsg, error) { r := &dnsMsg{ @@ -571,7 +574,7 @@ type fakeDNSConn struct { rh func(*dnsMsg) (*dnsMsg, error) } -func (f *fakeDNSConn) dialDNS(n, s string) (dnsConn, error) { +func (f *fakeDNSConn) dialDNS(_ context.Context, n, s string) (dnsConn, error) { return f, nil } diff --git a/src/net/fd_plan9.go b/src/net/fd_plan9.go index d0e9c53fca6..35d16243178 100644 --- a/src/net/fd_plan9.go +++ b/src/net/fd_plan9.go @@ -32,12 +32,6 @@ func sysInit() { netdir = "/net" } -func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { - // On plan9, use the relatively inefficient - // goroutine-racing implementation. - return dialChannel(net, ra, dialer, deadline) -} - func newFD(net, name string, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) { return &netFD{net: net, n: name, dir: netdir + "/" + net + "/" + name, ctl: ctl, data: data, laddr: laddr, raddr: raddr}, nil } diff --git a/src/net/fd_windows.go b/src/net/fd_windows.go index d1d91a6a5c5..ca46bf93610 100644 --- a/src/net/fd_windows.go +++ b/src/net/fd_windows.go @@ -11,7 +11,6 @@ import ( "runtime" "sync" "syscall" - "time" "unsafe" ) @@ -70,22 +69,15 @@ func sysInit() { } } +// canUseConnectEx reports whether we can use the ConnectEx Windows API call +// for the given network type. func canUseConnectEx(net string) bool { switch net { - case "udp", "udp4", "udp6", "ip", "ip4", "ip6": - // ConnectEx windows API does not support connectionless sockets. - return false + case "tcp", "tcp4", "tcp6": + return true } - return syscall.LoadConnectEx() == nil -} - -func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { - if !canUseConnectEx(net) { - // Use the relatively inefficient goroutine-racing - // implementation of DialTimeout. - return dialChannel(net, ra, dialer, deadline) - } - return dialer(deadline) + // ConnectEx windows API does not support connectionless sockets. + return false } // operation contains superset of data necessary to perform all async IO. @@ -328,12 +320,13 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { if err := fd.init(); err != nil { return err } - if deadline, _ := ctx.Deadline(); !deadline.IsZero() { + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { fd.setWriteDeadline(deadline) defer fd.setWriteDeadline(noDeadline) } if !canUseConnectEx(fd.net) { - return os.NewSyscallError("connect", connectFunc(fd.sysfd, ra)) + err := connectFunc(fd.sysfd, ra) + return os.NewSyscallError("connect", err) } // ConnectEx windows API requires an unconnected, previously bound socket. if la == nil { diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go index f4a4de82fcd..173b3cb4114 100644 --- a/src/net/iprawsock.go +++ b/src/net/iprawsock.go @@ -50,7 +50,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) { if net == "" { // a hint wildcard for Go 1.0 undocumented behavior net = "ip" } - afnet, _, err := parseNetwork(net) + afnet, _, err := parseNetwork(context.Background(), net) if err != nil { return nil, err } diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index 68dc307b606..3e0b060a8a4 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -121,7 +121,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) } func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(netProto) + network, proto, err := parseNetwork(ctx, netProto) if err != nil { return nil, err } @@ -141,7 +141,7 @@ func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn } func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(netProto) + network, proto, err := parseNetwork(ctx, netProto) if err != nil { return nil, err } diff --git a/src/net/ipsock_plan9.go b/src/net/ipsock_plan9.go index f7c2b446883..2b84683eeb5 100644 --- a/src/net/ipsock_plan9.go +++ b/src/net/ipsock_plan9.go @@ -7,6 +7,7 @@ package net import ( + "context" "os" "syscall" ) @@ -99,7 +100,7 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) { return addr, nil } -func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) { +func startPlan9(ctx context.Context, net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) { var ( ip IP port int @@ -118,7 +119,7 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, return } - clone, dest, err := queryCS1(proto, ip, port) + clone, dest, err := queryCS1(ctx, proto, ip, port) if err != nil { return } @@ -135,8 +136,8 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, return f, dest, proto, string(buf[:n]), nil } -func netErr(e error) { - oe, ok := e.(*OpError) +func fixErr(err error) { + oe, ok := err.(*OpError) if !ok { return } @@ -165,9 +166,34 @@ func netErr(e error) { } } -func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) { - defer func() { netErr(err) }() - f, dest, proto, name, err := startPlan9(net, raddr) +func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) { + defer func() { fixErr(err) }() + type res struct { + fd *netFD + err error + } + resc := make(chan res) + go func() { + testHookDialChannel() + fd, err := dialPlan9Blocking(ctx, net, laddr, raddr) + select { + case resc <- res{fd, err}: + case <-ctx.Done(): + if fd != nil { + fd.Close() + } + } + }() + select { + case res := <-resc: + return res.fd, res.err + case <-ctx.Done(): + return nil, mapErr(ctx.Err()) + } +} + +func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) { + f, dest, proto, name, err := startPlan9(ctx, net, raddr) if err != nil { return nil, err } @@ -190,9 +216,9 @@ func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) { return newFD(proto, name, f, data, laddr, raddr) } -func listenPlan9(net string, laddr Addr) (fd *netFD, err error) { - defer func() { netErr(err) }() - f, dest, proto, name, err := startPlan9(net, laddr) +func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err error) { + defer func() { fixErr(err) }() + f, dest, proto, name, err := startPlan9(ctx, net, laddr) if err != nil { return nil, err } @@ -214,7 +240,7 @@ func (fd *netFD) netFD() (*netFD, error) { } func (fd *netFD) acceptPlan9() (nfd *netFD, err error) { - defer func() { netErr(err) }() + defer func() { fixErr(err) }() if err := fd.readLock(); err != nil { return nil, err } diff --git a/src/net/lookup.go b/src/net/lookup.go index 8f02787422b..5e60011165d 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -114,7 +114,7 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro func LookupPort(network, service string) (port int, err error) { port, needsLookup := parsePort(service) if needsLookup { - port, err = lookupPort(network, service) + port, err = lookupPort(context.Background(), network, service) if err != nil { return 0, err } @@ -130,7 +130,7 @@ func LookupPort(network, service string) (port int, err error) { // LookupHost or LookupIP directly; both take care of resolving // the canonical name as part of the lookup. func LookupCNAME(name string) (cname string, err error) { - return lookupCNAME(name) + return lookupCNAME(context.Background(), name) } // LookupSRV tries to resolve an SRV query of the given service, @@ -143,26 +143,26 @@ func LookupCNAME(name string) (cname string, err error) { // publishing SRV records under non-standard names, if both service // and proto are empty strings, LookupSRV looks up name directly. func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { - return lookupSRV(service, proto, name) + return lookupSRV(context.Background(), service, proto, name) } // LookupMX returns the DNS MX records for the given domain name sorted by preference. func LookupMX(name string) (mxs []*MX, err error) { - return lookupMX(name) + return lookupMX(context.Background(), name) } // LookupNS returns the DNS NS records for the given domain name. func LookupNS(name string) (nss []*NS, err error) { - return lookupNS(name) + return lookupNS(context.Background(), name) } // LookupTXT returns the DNS TXT records for the given domain name. func LookupTXT(name string) (txts []string, err error) { - return lookupTXT(name) + return lookupTXT(context.Background(), name) } // LookupAddr performs a reverse lookup for the given address, returning a list // of names mapping to that address. func LookupAddr(addr string) (names []string, err error) { - return lookupAddr(addr) + return lookupAddr(context.Background(), addr) } diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go index 4224263602b..73147a2d3f7 100644 --- a/src/net/lookup_plan9.go +++ b/src/net/lookup_plan9.go @@ -10,7 +10,7 @@ import ( "os" ) -func query(filename, query string, bufSize int) (res []string, err error) { +func query(ctx context.Context, filename, query string, bufSize int) (res []string, err error) { file, err := os.OpenFile(filename, os.O_RDWR, 0) if err != nil { return @@ -40,7 +40,7 @@ func query(filename, query string, bufSize int) (res []string, err error) { return } -func queryCS(net, host, service string) (res []string, err error) { +func queryCS(ctx context.Context, net, host, service string) (res []string, err error) { switch net { case "tcp4", "tcp6": net = "tcp" @@ -50,15 +50,15 @@ func queryCS(net, host, service string) (res []string, err error) { if host == "" { host = "*" } - return query(netdir+"/cs", net+"!"+host+"!"+service, 128) + return query(ctx, netdir+"/cs", net+"!"+host+"!"+service, 128) } -func queryCS1(net string, ip IP, port int) (clone, dest string, err error) { +func queryCS1(ctx context.Context, net string, ip IP, port int) (clone, dest string, err error) { ips := "*" if len(ip) != 0 && !ip.IsUnspecified() { ips = ip.String() } - lines, err := queryCS(net, ips, itoa(port)) + lines, err := queryCS(ctx, net, ips, itoa(port)) if err != nil { return } @@ -70,8 +70,8 @@ func queryCS1(net string, ip IP, port int) (clone, dest string, err error) { return } -func queryDNS(addr string, typ string) (res []string, err error) { - return query(netdir+"/dns", addr+" "+typ, 1024) +func queryDNS(ctx context.Context, addr string, typ string) (res []string, err error) { + return query(ctx, netdir+"/dns", addr+" "+typ, 1024) } // toLower returns a lower-case version of in. Restricting us to @@ -97,8 +97,8 @@ func toLower(in string) string { // lookupProtocol looks up IP protocol name and returns // the corresponding protocol number. -func lookupProtocol(name string) (proto int, err error) { - lines, err := query(netdir+"/cs", "!protocol="+toLower(name), 128) +func lookupProtocol(ctx context.Context, name string) (proto int, err error) { + lines, err := query(ctx, netdir+"/cs", "!protocol="+toLower(name), 128) if err != nil { return 0, err } @@ -119,7 +119,7 @@ func lookupProtocol(name string) (proto int, err error) { func lookupHost(ctx context.Context, host string) (addrs []string, err error) { // Use netdir/cs instead of netdir/dns because cs knows about // host names in local network (e.g. from /lib/ndb/local) - lines, err := queryCS("net", host, "1") + lines, err := queryCS(ctx, "net", host, "1") if err != nil { return } @@ -148,8 +148,7 @@ loop: } func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { - // TODO(bradfitz): push down ctx - lits, err := LookupHost(host) + lits, err := lookupHost(ctx, host) if err != nil { return } @@ -163,14 +162,14 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return } -func lookupPort(network, service string) (port int, err error) { +func lookupPort(ctx context.Context, network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" case "udp4", "udp6": network = "udp" } - lines, err := queryCS(network, "127.0.0.1", service) + lines, err := queryCS(ctx, network, "127.0.0.1", service) if err != nil { return } @@ -192,8 +191,8 @@ func lookupPort(network, service string) (port int, err error) { return 0, unknownPortError } -func lookupCNAME(name string) (cname string, err error) { - lines, err := queryDNS(name, "cname") +func lookupCNAME(ctx context.Context, name string) (cname string, err error) { + lines, err := queryDNS(ctx, name, "cname") if err != nil { return } @@ -205,14 +204,14 @@ func lookupCNAME(name string) (cname string, err error) { return "", errors.New("bad response from ndb/dns") } -func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { +func lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { var target string if service == "" && proto == "" { target = name } else { target = "_" + service + "._" + proto + "." + name } - lines, err := queryDNS(target, "srv") + lines, err := queryDNS(ctx, target, "srv") if err != nil { return } @@ -234,8 +233,8 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err return } -func lookupMX(name string) (mx []*MX, err error) { - lines, err := queryDNS(name, "mx") +func lookupMX(ctx context.Context, name string) (mx []*MX, err error) { + lines, err := queryDNS(ctx, name, "mx") if err != nil { return } @@ -252,8 +251,8 @@ func lookupMX(name string) (mx []*MX, err error) { return } -func lookupNS(name string) (ns []*NS, err error) { - lines, err := queryDNS(name, "ns") +func lookupNS(ctx context.Context, name string) (ns []*NS, err error) { + lines, err := queryDNS(ctx, name, "ns") if err != nil { return } @@ -267,8 +266,8 @@ func lookupNS(name string) (ns []*NS, err error) { return } -func lookupTXT(name string) (txt []string, err error) { - lines, err := queryDNS(name, "txt") +func lookupTXT(ctx context.Context, name string) (txt []string, err error) { + lines, err := queryDNS(ctx, name, "txt") if err != nil { return } @@ -280,12 +279,12 @@ func lookupTXT(name string) (txt []string, err error) { return } -func lookupAddr(addr string) (name []string, err error) { +func lookupAddr(ctx context.Context, addr string) (name []string, err error) { arpa, err := reverseaddr(addr) if err != nil { return } - lines, err := queryDNS(arpa, "ptr") + lines, err := queryDNS(ctx, arpa, "ptr") if err != nil { return } diff --git a/src/net/lookup_stub.go b/src/net/lookup_stub.go index 38a4f0bae48..bd096b39652 100644 --- a/src/net/lookup_stub.go +++ b/src/net/lookup_stub.go @@ -11,7 +11,7 @@ import ( "syscall" ) -func lookupProtocol(name string) (proto int, err error) { +func lookupProtocol(ctx context.Context, name string) (proto int, err error) { return 0, syscall.ENOPROTOOPT } @@ -23,30 +23,30 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return nil, syscall.ENOPROTOOPT } -func lookupPort(network, service string) (port int, err error) { +func lookupPort(ctx context.Context, network, service string) (port int, err error) { return 0, syscall.ENOPROTOOPT } -func lookupCNAME(name string) (cname string, err error) { +func lookupCNAME(ctx context.Context, name string) (cname string, err error) { return "", syscall.ENOPROTOOPT } -func lookupSRV(service, proto, name string) (cname string, srvs []*SRV, err error) { +func lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) { return "", nil, syscall.ENOPROTOOPT } -func lookupMX(name string) (mxs []*MX, err error) { +func lookupMX(ctx context.Context, name string) (mxs []*MX, err error) { return nil, syscall.ENOPROTOOPT } -func lookupNS(name string) (nss []*NS, err error) { +func lookupNS(ctx context.Context, name string) (nss []*NS, err error) { return nil, syscall.ENOPROTOOPT } -func lookupTXT(name string) (txts []string, err error) { +func lookupTXT(ctx context.Context, name string) (txts []string, err error) { return nil, syscall.ENOPROTOOPT } -func lookupAddr(addr string) (ptrs []string, err error) { +func lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) { return nil, syscall.ENOPROTOOPT } diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go index 8d3fa477828..5461fe8a411 100644 --- a/src/net/lookup_unix.go +++ b/src/net/lookup_unix.go @@ -43,7 +43,7 @@ func readProtocols() { // lookupProtocol looks up IP protocol name in /etc/protocols and // returns correspondent protocol number. -func lookupProtocol(name string) (int, error) { +func lookupProtocol(_ context.Context, name string) (int, error) { onceReadProtocols.Do(readProtocols) proto, found := protocols[name] if !found { @@ -77,7 +77,12 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return goLookupIPOrder(ctx, host, order) } -func lookupPort(network, service string) (int, error) { +func lookupPort(ctx context.Context, network, service string) (int, error) { + // TODO: use the context if there ever becomes a need. Related + // is issue 15321. But port lookup generally just involves + // local files, and the os package has no context support. The + // files might be on a remote filesystem, though. This should + // probably race goroutines if ctx != context.Background(). if systemConf().canUseCgo() { if port, err, ok := cgoLookupPort(network, service); ok { return port, err @@ -86,23 +91,24 @@ func lookupPort(network, service string) (int, error) { return goLookupPort(network, service) } -func lookupCNAME(name string) (string, error) { +func lookupCNAME(ctx context.Context, name string) (string, error) { if systemConf().canUseCgo() { + // TODO: use ctx. issue 15321. Or race goroutines. if cname, err, ok := cgoLookupCNAME(name); ok { return cname, err } } - return goLookupCNAME(name) + return goLookupCNAME(ctx, name) } -func lookupSRV(service, proto, name string) (string, []*SRV, error) { +func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { var target string if service == "" && proto == "" { target = name } else { target = "_" + service + "._" + proto + "." + name } - cname, rrs, err := lookup(target, dnsTypeSRV) + cname, rrs, err := lookup(ctx, target, dnsTypeSRV) if err != nil { return "", nil, err } @@ -115,8 +121,8 @@ func lookupSRV(service, proto, name string) (string, []*SRV, error) { return cname, srvs, nil } -func lookupMX(name string) ([]*MX, error) { - _, rrs, err := lookup(name, dnsTypeMX) +func lookupMX(ctx context.Context, name string) ([]*MX, error) { + _, rrs, err := lookup(ctx, name, dnsTypeMX) if err != nil { return nil, err } @@ -129,8 +135,8 @@ func lookupMX(name string) ([]*MX, error) { return mxs, nil } -func lookupNS(name string) ([]*NS, error) { - _, rrs, err := lookup(name, dnsTypeNS) +func lookupNS(ctx context.Context, name string) ([]*NS, error) { + _, rrs, err := lookup(ctx, name, dnsTypeNS) if err != nil { return nil, err } @@ -141,8 +147,8 @@ func lookupNS(name string) ([]*NS, error) { return nss, nil } -func lookupTXT(name string) ([]string, error) { - _, rrs, err := lookup(name, dnsTypeTXT) +func lookupTXT(ctx context.Context, name string) ([]string, error) { + _, rrs, err := lookup(ctx, name, dnsTypeTXT) if err != nil { return nil, err } @@ -153,11 +159,11 @@ func lookupTXT(name string) ([]string, error) { return txts, nil } -func lookupAddr(addr string) ([]string, error) { +func lookupAddr(ctx context.Context, addr string) ([]string, error) { if systemConf().canUseCgo() { if ptrs, err, ok := cgoLookupPTR(addr); ok { return ptrs, err } } - return goLookupPTR(addr) + return goLookupPTR(ctx, addr) } diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index ce012ba873f..7a04cc89984 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -26,30 +26,37 @@ func getprotobyname(name string) (proto int, err error) { } // lookupProtocol looks up IP protocol name and returns correspondent protocol number. -func lookupProtocol(name string) (int, error) { +func lookupProtocol(ctx context.Context, name string) (int, error) { // GetProtoByName return value is stored in thread local storage. // Start new os thread before the call to prevent races. type result struct { proto int err error } - ch := make(chan result) + ch := make(chan result) // unbuffered go func() { acquireThread() defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() proto, err := getprotobyname(name) - ch <- result{proto: proto, err: err} - }() - r := <-ch - if r.err != nil { - if proto, ok := protocols[name]; ok { - return proto, nil + select { + case ch <- result{proto: proto, err: err}: + case <-ctx.Done(): } - r.err = &DNSError{Err: r.err.Error(), Name: name} + }() + select { + case r := <-ch: + if r.err != nil { + if proto, ok := protocols[name]; ok { + return proto, nil + } + r.err = &DNSError{Err: r.err.Error(), Name: name} + } + return r.proto, r.err + case <-ctx.Done(): + return 0, mapErr(ctx.Err()) } - return r.proto, r.err } func lookupHost(ctx context.Context, name string) ([]string, error) { @@ -193,30 +200,38 @@ func getservbyname(network, service string) (int, error) { return int(syscall.Ntohs(s.Port)), nil } -func oldLookupPort(network, service string) (int, error) { +func oldLookupPort(ctx context.Context, network, service string) (int, error) { // GetServByName return value is stored in thread local storage. // Start new os thread before the call to prevent races. type result struct { port int err error } - ch := make(chan result) + ch := make(chan result) // unbuffered go func() { acquireThread() defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() port, err := getservbyname(network, service) - ch <- result{port: port, err: err} + select { + case ch <- result{port: port, err: err}: + case <-ctx.Done(): + } }() - r := <-ch - if r.err != nil { - r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service} + select { + case r := <-ch: + if r.err != nil { + r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service} + } + return r.port, r.err + case <-ctx.Done(): + return 0, mapErr(ctx.Err()) } - return r.port, r.err } -func newLookupPort(network, service string) (int, error) { +func newLookupPort(ctx context.Context, network, service string) (int, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var stype int32 @@ -252,7 +267,8 @@ func newLookupPort(network, service string) (int, error) { return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service} } -func lookupCNAME(name string) (string, error) { +func lookupCNAME(ctx context.Context, name string) (string, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -272,7 +288,8 @@ func lookupCNAME(name string) (string, error) { return absDomainName([]byte(cname)), nil } -func lookupSRV(service, proto, name string) (string, []*SRV, error) { +func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var target string @@ -297,7 +314,8 @@ func lookupSRV(service, proto, name string) (string, []*SRV, error) { return absDomainName([]byte(target)), srvs, nil } -func lookupMX(name string) ([]*MX, error) { +func lookupMX(ctx context.Context, name string) ([]*MX, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -316,7 +334,8 @@ func lookupMX(name string) ([]*MX, error) { return mxs, nil } -func lookupNS(name string) ([]*NS, error) { +func lookupNS(ctx context.Context, name string) ([]*NS, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -334,7 +353,8 @@ func lookupNS(name string) ([]*NS, error) { return nss, nil } -func lookupTXT(name string) ([]string, error) { +func lookupTXT(ctx context.Context, name string) ([]string, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord @@ -355,7 +375,8 @@ func lookupTXT(name string) ([]string, error) { return txts, nil } -func lookupAddr(addr string) ([]string, error) { +func lookupAddr(ctx context.Context, addr string) ([]string, error) { + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() arpa, err := reverseaddr(addr) diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index 08ad9be8f41..d2860607f8b 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -22,10 +22,6 @@ func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, } func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { - if d, _ := ctx.Deadline(); !d.IsZero() { - // TODO: deadline not implemented on Plan 9 (see golang.og/issue/11932) - } - // TODO(bradfitz,0intro): also use the cancel channel. switch net { case "tcp", "tcp4", "tcp6": default: @@ -34,7 +30,7 @@ func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn if raddr == nil { return nil, errMissingAddress } - fd, err := dialPlan9(net, laddr, raddr) + fd, err := dialPlan9(ctx, net, laddr, raddr) if err != nil { return nil, err } @@ -71,7 +67,7 @@ func (ln *TCPListener) file() (*os.File, error) { } func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) { - fd, err := listenPlan9(network, laddr) + fd, err := listenPlan9(ctx, network, laddr) if err != nil { return nil, err } diff --git a/src/net/udpsock_plan9.go b/src/net/udpsock_plan9.go index 3b3d8d7615d..666f20622f6 100644 --- a/src/net/udpsock_plan9.go +++ b/src/net/udpsock_plan9.go @@ -56,10 +56,7 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error } func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) { - if deadline, _ := ctx.Deadline(); !deadline.IsZero() { - // TODO: deadline not implemented on Plan 9 (see golang.og/issue/11932) - } - fd, err := dialPlan9(net, laddr, raddr) + fd, err := dialPlan9(ctx, net, laddr, raddr) if err != nil { return nil, err } @@ -95,7 +92,7 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { } func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) { - l, err := listenPlan9(network, laddr) + l, err := listenPlan9(ctx, network, laddr) if err != nil { return nil, err }