mirror of
https://github.com/golang/go
synced 2024-11-25 08:07:57 -07:00
net: avoid nil pointer dereference when RemoteAddr.String method chain is called
Fixes #3721. R=dave, rsc CC=golang-dev https://golang.org/cl/6395055
This commit is contained in:
parent
e80f6a4de1
commit
6cf77f2af4
@ -612,11 +612,10 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
|
|||||||
syscall.ForkLock.RUnlock()
|
syscall.ForkLock.RUnlock()
|
||||||
|
|
||||||
if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil {
|
if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil {
|
||||||
syscall.Close(s)
|
closesocket(s)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
lsa, _ := syscall.Getsockname(netfd.sysfd)
|
netfd.setAddr(localSockname(fd, toAddr), toAddr(rsa))
|
||||||
netfd.setAddr(toAddr(lsa), toAddr(rsa))
|
|
||||||
return netfd, nil
|
return netfd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,8 +25,8 @@ func newFileFD(f *os.File) (*netFD, error) {
|
|||||||
|
|
||||||
family := syscall.AF_UNSPEC
|
family := syscall.AF_UNSPEC
|
||||||
toAddr := sockaddrToTCP
|
toAddr := sockaddrToTCP
|
||||||
sa, _ := syscall.Getsockname(fd)
|
lsa, _ := syscall.Getsockname(fd)
|
||||||
switch sa.(type) {
|
switch lsa.(type) {
|
||||||
default:
|
default:
|
||||||
closesocket(fd)
|
closesocket(fd)
|
||||||
return nil, syscall.EINVAL
|
return nil, syscall.EINVAL
|
||||||
@ -53,16 +53,14 @@ func newFileFD(f *os.File) (*netFD, error) {
|
|||||||
toAddr = sockaddrToUnixpacket
|
toAddr = sockaddrToUnixpacket
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
laddr := toAddr(sa)
|
laddr := toAddr(lsa)
|
||||||
sa, _ = syscall.Getpeername(fd)
|
|
||||||
raddr := toAddr(sa)
|
|
||||||
|
|
||||||
netfd, err := newFD(fd, family, sotype, laddr.Network())
|
netfd, err := newFD(fd, family, sotype, laddr.Network())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
closesocket(fd)
|
closesocket(fd)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
netfd.setAddr(laddr, raddr)
|
netfd.setAddr(laddr, remoteSockname(netfd, toAddr))
|
||||||
return netfd, nil
|
return netfd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,10 +78,10 @@ func FileConn(f *os.File) (c Conn, err error) {
|
|||||||
return newTCPConn(fd), nil
|
return newTCPConn(fd), nil
|
||||||
case *UDPAddr:
|
case *UDPAddr:
|
||||||
return newUDPConn(fd), nil
|
return newUDPConn(fd), nil
|
||||||
case *UnixAddr:
|
|
||||||
return newUnixConn(fd), nil
|
|
||||||
case *IPAddr:
|
case *IPAddr:
|
||||||
return newIPConn(fd), nil
|
return newIPConn(fd), nil
|
||||||
|
case *UnixAddr:
|
||||||
|
return newUnixConn(fd), nil
|
||||||
}
|
}
|
||||||
fd.Close()
|
fd.Close()
|
||||||
return nil, syscall.EINVAL
|
return nil, syscall.EINVAL
|
||||||
|
@ -14,6 +14,55 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ipConnAddrStringTests = []struct {
|
||||||
|
net string
|
||||||
|
laddr string
|
||||||
|
raddr string
|
||||||
|
ipv6 bool
|
||||||
|
}{
|
||||||
|
{"ip:icmp", "127.0.0.1", "", false},
|
||||||
|
{"ip:icmp", "::1", "", true},
|
||||||
|
{"ip:icmp", "", "127.0.0.1", false},
|
||||||
|
{"ip:icmp", "", "::1", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPConnAddrString(t *testing.T) {
|
||||||
|
if os.Getuid() != 0 {
|
||||||
|
t.Logf("skipping test; must be root")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range ipConnAddrStringTests {
|
||||||
|
if tt.ipv6 && !supportsIPv6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
c *IPConn
|
||||||
|
mode string
|
||||||
|
)
|
||||||
|
if tt.raddr == "" {
|
||||||
|
mode = "listen"
|
||||||
|
la, _ := ResolveIPAddr(tt.net, tt.laddr)
|
||||||
|
c, err = ListenIP(tt.net, la)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenIP(%q, %q) failed: %v", tt.net, la.String(), err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mode = "dial"
|
||||||
|
la, _ := ResolveIPAddr(tt.net, tt.laddr)
|
||||||
|
ra, _ := ResolveIPAddr(tt.net, tt.raddr)
|
||||||
|
c, err = DialIP(tt.net, la, ra)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DialIP(%q, %q) failed: %v", tt.net, ra.String(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String())
|
||||||
|
t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String())
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var icmpTests = []struct {
|
var icmpTests = []struct {
|
||||||
net string
|
net string
|
||||||
laddr string
|
laddr string
|
||||||
@ -26,7 +75,7 @@ var icmpTests = []struct {
|
|||||||
|
|
||||||
func TestICMP(t *testing.T) {
|
func TestICMP(t *testing.T) {
|
||||||
if os.Getuid() != 0 {
|
if os.Getuid() != 0 {
|
||||||
t.Logf("test disabled; must be root")
|
t.Logf("skipping test; must be root")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ import (
|
|||||||
var listenerBacklog = maxListenerBacklog()
|
var listenerBacklog = maxListenerBacklog()
|
||||||
|
|
||||||
// Generic socket creation.
|
// Generic socket creation.
|
||||||
func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
|
func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
|
||||||
// See ../syscall/exec.go for description of ForkLock.
|
// See ../syscall/exec.go for description of ForkLock.
|
||||||
syscall.ForkLock.RLock()
|
syscall.ForkLock.RLock()
|
||||||
s, err := syscall.Socket(f, t, p)
|
s, err := syscall.Socket(f, t, p)
|
||||||
@ -27,21 +27,18 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
|
|||||||
syscall.CloseOnExec(s)
|
syscall.CloseOnExec(s)
|
||||||
syscall.ForkLock.RUnlock()
|
syscall.ForkLock.RUnlock()
|
||||||
|
|
||||||
err = setDefaultSockopts(s, f, t, ipv6only)
|
if err = setDefaultSockopts(s, f, t, ipv6only); err != nil {
|
||||||
if err != nil {
|
|
||||||
closesocket(s)
|
closesocket(s)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var bla syscall.Sockaddr
|
var blsa syscall.Sockaddr
|
||||||
if la != nil {
|
if ulsa != nil {
|
||||||
bla, err = listenerSockaddr(s, f, la, toAddr)
|
if blsa, err = listenerSockaddr(s, f, ulsa, toAddr); err != nil {
|
||||||
if err != nil {
|
|
||||||
closesocket(s)
|
closesocket(s)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = syscall.Bind(s, bla)
|
if err = syscall.Bind(s, blsa); err != nil {
|
||||||
if err != nil {
|
|
||||||
closesocket(s)
|
closesocket(s)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -52,8 +49,8 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ra != nil {
|
if ursa != nil {
|
||||||
if err = fd.connect(ra); err != nil {
|
if err = fd.connect(ursa); err != nil {
|
||||||
closesocket(s)
|
closesocket(s)
|
||||||
fd.Close()
|
fd.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -61,17 +58,13 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
|
|||||||
fd.isConnected = true
|
fd.isConnected = true
|
||||||
}
|
}
|
||||||
|
|
||||||
sa, _ := syscall.Getsockname(s)
|
|
||||||
var laddr Addr
|
var laddr Addr
|
||||||
if la != nil && bla != la {
|
if ulsa != nil && blsa != ulsa {
|
||||||
laddr = toAddr(la)
|
laddr = toAddr(ulsa)
|
||||||
} else {
|
} else {
|
||||||
laddr = toAddr(sa)
|
laddr = localSockname(fd, toAddr)
|
||||||
}
|
}
|
||||||
sa, _ = syscall.Getpeername(s)
|
fd.setAddr(laddr, remoteSockname(fd, toAddr))
|
||||||
raddr := toAddr(sa)
|
|
||||||
|
|
||||||
fd.setAddr(laddr, raddr)
|
|
||||||
return fd, nil
|
return fd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,3 +78,39 @@ func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
|
|||||||
// Use wrapper to hide existing r.ReadFrom from io.Copy.
|
// Use wrapper to hide existing r.ReadFrom from io.Copy.
|
||||||
return io.Copy(writerOnly{w}, r)
|
return io.Copy(writerOnly{w}, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func localSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr {
|
||||||
|
sa, _ := syscall.Getsockname(fd.sysfd)
|
||||||
|
if sa == nil {
|
||||||
|
return nullProtocolAddr(fd.family, fd.sotype)
|
||||||
|
}
|
||||||
|
return toAddr(sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
func remoteSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr {
|
||||||
|
sa, _ := syscall.Getpeername(fd.sysfd)
|
||||||
|
if sa == nil {
|
||||||
|
return nullProtocolAddr(fd.family, fd.sotype)
|
||||||
|
}
|
||||||
|
return toAddr(sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullProtocolAddr(f, t int) Addr {
|
||||||
|
switch f {
|
||||||
|
case syscall.AF_INET, syscall.AF_INET6:
|
||||||
|
switch t {
|
||||||
|
case syscall.SOCK_STREAM:
|
||||||
|
return (*TCPAddr)(nil)
|
||||||
|
case syscall.SOCK_DGRAM:
|
||||||
|
return (*UDPAddr)(nil)
|
||||||
|
case syscall.SOCK_RAW:
|
||||||
|
return (*IPAddr)(nil)
|
||||||
|
}
|
||||||
|
case syscall.AF_UNIX:
|
||||||
|
switch t {
|
||||||
|
case syscall.SOCK_STREAM, syscall.SOCK_DGRAM, syscall.SOCK_SEQPACKET:
|
||||||
|
return (*UnixAddr)(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
@ -9,6 +9,33 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var udpConnAddrStringTests = []struct {
|
||||||
|
net string
|
||||||
|
laddr string
|
||||||
|
raddr string
|
||||||
|
ipv6 bool
|
||||||
|
}{
|
||||||
|
{"udp", "127.0.0.1:0", "", false},
|
||||||
|
{"udp", "[::1]:0", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPConnAddrString(t *testing.T) {
|
||||||
|
for i, tt := range udpConnAddrStringTests {
|
||||||
|
if tt.ipv6 && !supportsIPv6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mode := "listen"
|
||||||
|
la, _ := ResolveUDPAddr(tt.net, tt.laddr)
|
||||||
|
c, err := ListenUDP(tt.net, la)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenUDP(%q, %q) failed: %v", tt.net, la.String(), err)
|
||||||
|
}
|
||||||
|
t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String())
|
||||||
|
t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String())
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteToUDP(t *testing.T) {
|
func TestWriteToUDP(t *testing.T) {
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "plan9":
|
case "plan9":
|
||||||
|
Loading…
Reference in New Issue
Block a user