diff --git a/src/pkg/net/fd.go b/src/pkg/net/fd.go index 3e8780083d..cd1a21dc36 100644 --- a/src/pkg/net/fd.go +++ b/src/pkg/net/fd.go @@ -303,26 +303,18 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { fd.sysfile = os.NewFile(fd.sysfd, fd.net+":"+ls+"->"+rs) } -func (fd *netFD) connect(la, ra syscall.Sockaddr) (err os.Error) { - if la != nil { - e := syscall.Bind(fd.sysfd, la) - if e != 0 { - return os.Errno(e) +func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { + e := syscall.Connect(fd.sysfd, ra) + if e == syscall.EINPROGRESS { + var errno int + pollserver.WaitWrite(fd) + e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) + if errno != 0 { + return os.NewSyscallError("getsockopt", errno) } } - if ra != nil { - e := syscall.Connect(fd.sysfd, ra) - if e == syscall.EINPROGRESS { - var errno int - pollserver.WaitWrite(fd) - e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) - if errno != 0 { - return os.NewSyscallError("getsockopt", errno) - } - } - if e != 0 { - return os.Errno(e) - } + if e != 0 { + return os.Errno(e) } return nil } diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 0abf230ce1..c2f736cc12 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -253,18 +253,10 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { fd.raddr = raddr } -func (fd *netFD) connect(la, ra syscall.Sockaddr) (err os.Error) { - if la != nil { - e := syscall.Bind(fd.sysfd, la) - if e != 0 { - return os.Errno(e) - } - } - if ra != nil { - e := syscall.Connect(fd.sysfd, ra) - if e != 0 { - return os.Errno(e) - } +func (fd *netFD) connect(ra syscall.Sockaddr) (err os.Error) { + e := syscall.Connect(fd.sysfd, ra) + if e != 0 { + return os.Errno(e) } return nil } diff --git a/src/pkg/net/sock.go b/src/pkg/net/sock.go index 9b99ad58f8..933700af16 100644 --- a/src/pkg/net/sock.go +++ b/src/pkg/net/sock.go @@ -44,19 +44,30 @@ func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscal syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) } + if la != nil { + e = syscall.Bind(s, la) + if e != 0 { + closesocket(s) + return nil, os.Errno(e) + } + } + if fd, err = newFD(s, f, p, net); err != nil { closesocket(s) return nil, err } - if err = fd.connect(la, ra); err != nil { - closesocket(s) - return nil, err + if ra != nil { + if err = fd.connect(ra); err != nil { + fd.sysfd = -1 + closesocket(s) + return nil, err + } } - sa, _ := syscall.Getsockname(fd.sysfd) + sa, _ := syscall.Getsockname(s) laddr := toAddr(sa) - sa, _ = syscall.Getpeername(fd.sysfd) + sa, _ = syscall.Getpeername(s) raddr := toAddr(sa) fd.setAddr(laddr, raddr)