diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index c1eb983cc0f..354028a157a 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -5,7 +5,6 @@ package net import ( - "runtime" "time" ) @@ -113,30 +112,16 @@ func dialAddr(net, addr string, addri Addr, deadline time.Time) (c Conn, err err return } -const useDialTimeoutRace = runtime.GOOS == "windows" || runtime.GOOS == "plan9" - // DialTimeout acts like Dial but takes a timeout. // The timeout includes name resolution, if required. func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) { - if useDialTimeoutRace { - // On windows and plan9, use the relatively inefficient - // goroutine-racing implementation of DialTimeout that - // doesn't push down deadlines to the pollster. - // TODO: remove this once those are implemented. - return dialTimeoutRace(net, addr, timeout) - } - deadline := time.Now().Add(timeout) - _, addri, err := resolveNetAddr("dial", net, addr, deadline) - if err != nil { - return nil, err - } - return dialAddr(net, addr, addri, deadline) + return dialTimeout(net, addr, timeout) } // dialTimeoutRace is the old implementation of DialTimeout, still used // on operating systems where the deadline hasn't been pushed down // into the pollserver. -// TODO: fix this on Windows and plan9. +// TODO: fix this on plan9. func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) { t := time.NewTimer(timeout) defer t.Stop() diff --git a/src/pkg/net/dial_windows_test.go b/src/pkg/net/dial_windows_test.go new file mode 100644 index 00000000000..8fc9b2fd462 --- /dev/null +++ b/src/pkg/net/dial_windows_test.go @@ -0,0 +1,74 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "sync" + "syscall" + "testing" + "time" + "unsafe" +) + +var handleCounter struct { + once sync.Once + proc *syscall.Proc +} + +func numHandles(t *testing.T) int { + + handleCounter.once.Do(func() { + d, err := syscall.LoadDLL("kernel32.dll") + if err != nil { + t.Fatalf("LoadDLL: %v\n", err) + } + handleCounter.proc, err = d.FindProc("GetProcessHandleCount") + if err != nil { + t.Fatalf("FindProc: %v\n", err) + } + }) + + cp, err := syscall.GetCurrentProcess() + if err != nil { + t.Fatalf("GetCurrentProcess: %v\n", err) + } + var n uint32 + r, _, err := handleCounter.proc.Call(uintptr(cp), uintptr(unsafe.Pointer(&n))) + if r == 0 { + t.Fatalf("GetProcessHandleCount: %v\n", error(err)) + } + return int(n) +} + +func testDialTimeoutHandleLeak(t *testing.T) (before, after int) { + before = numHandles(t) + // See comment in TestDialTimeout about why we use this address. + c, err := DialTimeout("tcp", "127.0.71.111:49151", 200*time.Millisecond) + after = numHandles(t) + if err == nil { + c.Close() + t.Fatalf("unexpected: connected to %s", c.RemoteAddr()) + } + terr, ok := err.(timeout) + if !ok { + t.Fatalf("got error %q; want error with timeout interface", err) + } + if !terr.Timeout() { + t.Fatalf("got error %q; not a timeout", err) + } + return +} + +func TestDialTimeoutHandleLeak(t *testing.T) { + if !canUseConnectEx("tcp") { + t.Logf("skipping test; no ConnectEx found.") + return + } + testDialTimeoutHandleLeak(t) // ignore first call results + before, after := testDialTimeoutHandleLeak(t) + if before != after { + t.Fatalf("handle count is different before=%d and after=%d", before, after) + } +} diff --git a/src/pkg/net/fd_plan9.go b/src/pkg/net/fd_plan9.go index 6d7ab388ae7..3462792816e 100644 --- a/src/pkg/net/fd_plan9.go +++ b/src/pkg/net/fd_plan9.go @@ -23,6 +23,12 @@ var canCancelIO = true // used for testing current package func sysInit() { } +func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { + // On plan9, use the relatively inefficient + // goroutine-racing implementation. + return dialTimeoutRace(net, addr, timeout) +} + func newFD(proto, name string, ctl *os.File, laddr, raddr Addr) *netFD { return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr} } diff --git a/src/pkg/net/fd_unix.go b/src/pkg/net/fd_unix.go index 6d8af0ab2e2..cfe6df21308 100644 --- a/src/pkg/net/fd_unix.go +++ b/src/pkg/net/fd_unix.go @@ -288,6 +288,15 @@ func server(fd int) *pollServer { return pollservers[k] } +func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { + deadline := time.Now().Add(timeout) + _, addri, err := resolveNetAddr("dial", net, addr, deadline) + if err != nil { + return nil, err + } + return dialAddr(net, addr, addri, deadline) +} + func newFD(fd, family, sotype int, net string) (*netFD, error) { if err := syscall.SetNonblock(fd, true); err != nil { return nil, err diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 18712191fee..ea6ef10ec1a 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -45,6 +45,28 @@ func closesocket(s syscall.Handle) error { return syscall.Closesocket(s) } +func canUseConnectEx(net string) bool { + if net == "udp" || net == "udp4" || net == "udp6" { + // ConnectEx windows API does not support connectionless sockets. + return false + } + return syscall.LoadConnectEx() == nil +} + +func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { + if !canUseConnectEx(net) { + // Use the relatively inefficient goroutine-racing + // implementation of DialTimeout. + return dialTimeoutRace(net, addr, timeout) + } + deadline := time.Now().Add(timeout) + _, addri, err := resolveNetAddr("dial", net, addr, deadline) + if err != nil { + return nil, err + } + return dialAddr(net, addr, addri, deadline) +} + // Interface for all IO operations. type anOpIface interface { Op() *anOp @@ -321,8 +343,48 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { runtime.SetFinalizer(fd, (*netFD).closesocket) } +// Make new connection. + +type connectOp struct { + anOp + ra syscall.Sockaddr +} + +func (o *connectOp) Submit() error { + return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o) +} + +func (o *connectOp) Name() string { + return "ConnectEx" +} + func (fd *netFD) connect(ra syscall.Sockaddr) error { - return syscall.Connect(fd.sysfd, ra) + if !canUseConnectEx(fd.net) { + return syscall.Connect(fd.sysfd, ra) + } + // ConnectEx windows API requires an unconnected, previously bound socket. + var la syscall.Sockaddr + switch ra.(type) { + case *syscall.SockaddrInet4: + la = &syscall.SockaddrInet4{} + case *syscall.SockaddrInet6: + la = &syscall.SockaddrInet6{} + default: + panic("unexpected type in connect") + } + if err := syscall.Bind(fd.sysfd, la); err != nil { + return err + } + // Call ConnectEx API. + var o connectOp + o.Init(fd, 'w') + o.ra = ra + _, err := iosrv.ExecIO(&o, fd.wdeadline.value()) + if err != nil { + return err + } + // Refresh socket properties. + return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) } // Add a reference to this fd. diff --git a/src/pkg/syscall/syscall_windows.go b/src/pkg/syscall/syscall_windows.go index 5acb65dee14..e745fbe510f 100644 --- a/src/pkg/syscall/syscall_windows.go +++ b/src/pkg/syscall/syscall_windows.go @@ -7,6 +7,8 @@ package syscall import ( + errorspkg "errors" + "sync" "unicode/utf16" "unsafe" ) @@ -712,6 +714,56 @@ func LoadGetAddrInfo() error { return procGetAddrInfoW.Find() } +var connectExFunc struct { + once sync.Once + addr uintptr + err error +} + +func LoadConnectEx() error { + connectExFunc.once.Do(func() { + var s Handle + s, connectExFunc.err = Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + if connectExFunc.err != nil { + return + } + defer CloseHandle(s) + var n uint32 + connectExFunc.err = WSAIoctl(s, + SIO_GET_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(&WSAID_CONNECTEX)), + uint32(unsafe.Sizeof(WSAID_CONNECTEX)), + (*byte)(unsafe.Pointer(&connectExFunc.addr)), + uint32(unsafe.Sizeof(connectExFunc.addr)), + &n, nil, 0) + }) + return connectExFunc.err +} + +func connectEx(s Handle, name uintptr, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) (err error) { + r1, _, e1 := Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0) + if r1 == 0 { + if e1 != 0 { + err = error(e1) + } else { + err = EINVAL + } + } + return +} + +func ConnectEx(fd Handle, sa Sockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) error { + err := LoadConnectEx() + if err != nil { + return errorspkg.New("failed to find ConnectEx: " + err.Error()) + } + ptr, n, err := sa.sockaddr() + if err != nil { + return err + } + return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped) +} + // Invented structures to support what package os expects. type Rusage struct { CreationTime Filetime diff --git a/src/pkg/syscall/ztypes_windows.go b/src/pkg/syscall/ztypes_windows.go index 1f7308796fb..a2006f803d6 100644 --- a/src/pkg/syscall/ztypes_windows.go +++ b/src/pkg/syscall/ztypes_windows.go @@ -496,15 +496,22 @@ const ( IPPROTO_TCP = 6 IPPROTO_UDP = 17 - SOL_SOCKET = 0xffff - SO_REUSEADDR = 4 - SO_KEEPALIVE = 8 - SO_DONTROUTE = 16 - SO_BROADCAST = 32 - SO_LINGER = 128 - SO_RCVBUF = 0x1002 - SO_SNDBUF = 0x1001 - SO_UPDATE_ACCEPT_CONTEXT = 0x700b + SOL_SOCKET = 0xffff + SO_REUSEADDR = 4 + SO_KEEPALIVE = 8 + SO_DONTROUTE = 16 + SO_BROADCAST = 32 + SO_LINGER = 128 + SO_RCVBUF = 0x1002 + SO_SNDBUF = 0x1001 + SO_UPDATE_ACCEPT_CONTEXT = 0x700b + SO_UPDATE_CONNECT_CONTEXT = 0x7010 + + IOC_OUT = 0x40000000 + IOC_IN = 0x80000000 + IOC_INOUT = IOC_IN | IOC_OUT + IOC_WS2 = 0x08000000 + SIO_GET_EXTENSION_FUNCTION_POINTER = IOC_INOUT | IOC_WS2 | 6 // cf. http://support.microsoft.com/default.aspx?scid=kb;en-us;257460 @@ -941,3 +948,17 @@ const ( AI_CANONNAME = 2 AI_NUMERICHOST = 4 ) + +type GUID struct { + Data1 uint32 + Data2 uint16 + Data3 uint16 + Data4 [8]byte +} + +var WSAID_CONNECTEX = GUID{ + 0x25a207b9, + 0xddf3, + 0x4660, + [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e}, +}