diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go index 52527ec8f2f..e9927d9534a 100644 --- a/src/pkg/net/fd.go +++ b/src/pkg/net/fd.go @@ -612,11 +612,10 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e syscall.ForkLock.RUnlock() if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil { - syscall.Close(s) + closesocket(s) return nil, err } - lsa, _ := syscall.Getsockname(netfd.sysfd) - netfd.setAddr(toAddr(lsa), toAddr(rsa)) + netfd.setAddr(localSockname(fd, toAddr), toAddr(rsa)) return netfd, nil } diff --git a/src/pkg/net/file.go b/src/pkg/net/file.go index 1abf24f2d65..11c8f77a825 100644 --- a/src/pkg/net/file.go +++ b/src/pkg/net/file.go @@ -25,8 +25,8 @@ func newFileFD(f *os.File) (*netFD, error) { family := syscall.AF_UNSPEC toAddr := sockaddrToTCP - sa, _ := syscall.Getsockname(fd) - switch sa.(type) { + lsa, _ := syscall.Getsockname(fd) + switch lsa.(type) { default: closesocket(fd) return nil, syscall.EINVAL @@ -53,16 +53,14 @@ func newFileFD(f *os.File) (*netFD, error) { toAddr = sockaddrToUnixpacket } } - laddr := toAddr(sa) - sa, _ = syscall.Getpeername(fd) - raddr := toAddr(sa) + laddr := toAddr(lsa) netfd, err := newFD(fd, family, sotype, laddr.Network()) if err != nil { closesocket(fd) return nil, err } - netfd.setAddr(laddr, raddr) + netfd.setAddr(laddr, remoteSockname(netfd, toAddr)) return netfd, nil } @@ -80,10 +78,10 @@ func FileConn(f *os.File) (c Conn, err error) { return newTCPConn(fd), nil case *UDPAddr: return newUDPConn(fd), nil - case *UnixAddr: - return newUnixConn(fd), nil case *IPAddr: return newIPConn(fd), nil + case *UnixAddr: + return newUnixConn(fd), nil } fd.Close() return nil, syscall.EINVAL diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go index 0a28827e330..5b5b68377f7 100644 --- a/src/pkg/net/ipraw_test.go +++ b/src/pkg/net/ipraw_test.go @@ -14,6 +14,55 @@ import ( "time" ) +var ipConnAddrStringTests = []struct { + net string + laddr string + raddr string + ipv6 bool +}{ + {"ip:icmp", "127.0.0.1", "", false}, + {"ip:icmp", "::1", "", true}, + {"ip:icmp", "", "127.0.0.1", false}, + {"ip:icmp", "", "::1", true}, +} + +func TestIPConnAddrString(t *testing.T) { + if os.Getuid() != 0 { + t.Logf("skipping test; must be root") + return + } + + for i, tt := range ipConnAddrStringTests { + if tt.ipv6 && !supportsIPv6 { + continue + } + var ( + err error + c *IPConn + mode string + ) + if tt.raddr == "" { + mode = "listen" + la, _ := ResolveIPAddr(tt.net, tt.laddr) + c, err = ListenIP(tt.net, la) + if err != nil { + t.Fatalf("ListenIP(%q, %q) failed: %v", tt.net, la.String(), err) + } + } else { + mode = "dial" + la, _ := ResolveIPAddr(tt.net, tt.laddr) + ra, _ := ResolveIPAddr(tt.net, tt.raddr) + c, err = DialIP(tt.net, la, ra) + if err != nil { + t.Fatalf("DialIP(%q, %q) failed: %v", tt.net, ra.String(), err) + } + } + t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String()) + t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String()) + c.Close() + } +} + var icmpTests = []struct { net string laddr string @@ -26,7 +75,7 @@ var icmpTests = []struct { func TestICMP(t *testing.T) { if os.Getuid() != 0 { - t.Logf("test disabled; must be root") + t.Logf("skipping test; must be root") return } diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go index 3ae16054e47..bc9606048a8 100644 --- a/src/pkg/net/sock.go +++ b/src/pkg/net/sock.go @@ -16,7 +16,7 @@ import ( var listenerBacklog = maxListenerBacklog() // Generic socket creation. -func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { +func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { // See ../syscall/exec.go for description of ForkLock. syscall.ForkLock.RLock() s, err := syscall.Socket(f, t, p) @@ -27,21 +27,18 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA syscall.CloseOnExec(s) syscall.ForkLock.RUnlock() - err = setDefaultSockopts(s, f, t, ipv6only) - if err != nil { + if err = setDefaultSockopts(s, f, t, ipv6only); err != nil { closesocket(s) return nil, err } - var bla syscall.Sockaddr - if la != nil { - bla, err = listenerSockaddr(s, f, la, toAddr) - if err != nil { + var blsa syscall.Sockaddr + if ulsa != nil { + if blsa, err = listenerSockaddr(s, f, ulsa, toAddr); err != nil { closesocket(s) return nil, err } - err = syscall.Bind(s, bla) - if err != nil { + if err = syscall.Bind(s, blsa); err != nil { closesocket(s) return nil, err } @@ -52,8 +49,8 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA return nil, err } - if ra != nil { - if err = fd.connect(ra); err != nil { + if ursa != nil { + if err = fd.connect(ursa); err != nil { closesocket(s) fd.Close() return nil, err @@ -61,17 +58,13 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA fd.isConnected = true } - sa, _ := syscall.Getsockname(s) var laddr Addr - if la != nil && bla != la { - laddr = toAddr(la) + if ulsa != nil && blsa != ulsa { + laddr = toAddr(ulsa) } else { - laddr = toAddr(sa) + laddr = localSockname(fd, toAddr) } - sa, _ = syscall.Getpeername(s) - raddr := toAddr(sa) - - fd.setAddr(laddr, raddr) + fd.setAddr(laddr, remoteSockname(fd, toAddr)) return fd, nil } @@ -85,3 +78,39 @@ func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) { // Use wrapper to hide existing r.ReadFrom from io.Copy. return io.Copy(writerOnly{w}, r) } + +func localSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr { + sa, _ := syscall.Getsockname(fd.sysfd) + if sa == nil { + return nullProtocolAddr(fd.family, fd.sotype) + } + return toAddr(sa) +} + +func remoteSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr { + sa, _ := syscall.Getpeername(fd.sysfd) + if sa == nil { + return nullProtocolAddr(fd.family, fd.sotype) + } + return toAddr(sa) +} + +func nullProtocolAddr(f, t int) Addr { + switch f { + case syscall.AF_INET, syscall.AF_INET6: + switch t { + case syscall.SOCK_STREAM: + return (*TCPAddr)(nil) + case syscall.SOCK_DGRAM: + return (*UDPAddr)(nil) + case syscall.SOCK_RAW: + return (*IPAddr)(nil) + } + case syscall.AF_UNIX: + switch t { + case syscall.SOCK_STREAM, syscall.SOCK_DGRAM, syscall.SOCK_SEQPACKET: + return (*UnixAddr)(nil) + } + } + panic("unreachable") +} diff --git a/src/pkg/net/udp_test.go b/src/pkg/net/udp_test.go index f80d3b5a9cf..90365c05f0a 100644 --- a/src/pkg/net/udp_test.go +++ b/src/pkg/net/udp_test.go @@ -9,6 +9,33 @@ import ( "testing" ) +var udpConnAddrStringTests = []struct { + net string + laddr string + raddr string + ipv6 bool +}{ + {"udp", "127.0.0.1:0", "", false}, + {"udp", "[::1]:0", "", true}, +} + +func TestUDPConnAddrString(t *testing.T) { + for i, tt := range udpConnAddrStringTests { + if tt.ipv6 && !supportsIPv6 { + continue + } + mode := "listen" + la, _ := ResolveUDPAddr(tt.net, tt.laddr) + c, err := ListenUDP(tt.net, la) + if err != nil { + t.Fatalf("ListenUDP(%q, %q) failed: %v", tt.net, la.String(), err) + } + t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String()) + t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String()) + c.Close() + } +} + func TestWriteToUDP(t *testing.T) { switch runtime.GOOS { case "plan9":