mirror of
https://github.com/golang/go
synced 2024-11-23 06:20:07 -07:00
net: use windows ConnectEx to dial (when possible)
Update #2631. Update #3097. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/7061060
This commit is contained in:
parent
e32d1154ec
commit
810e439859
@ -5,7 +5,6 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -113,30 +112,16 @@ func dialAddr(net, addr string, addri Addr, deadline time.Time) (c Conn, err err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const useDialTimeoutRace = runtime.GOOS == "windows" || runtime.GOOS == "plan9"
|
|
||||||
|
|
||||||
// DialTimeout acts like Dial but takes a timeout.
|
// DialTimeout acts like Dial but takes a timeout.
|
||||||
// The timeout includes name resolution, if required.
|
// The timeout includes name resolution, if required.
|
||||||
func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
|
func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
|
||||||
if useDialTimeoutRace {
|
return dialTimeout(net, addr, timeout)
|
||||||
// On windows and plan9, use the relatively inefficient
|
|
||||||
// goroutine-racing implementation of DialTimeout that
|
|
||||||
// doesn't push down deadlines to the pollster.
|
|
||||||
// TODO: remove this once those are implemented.
|
|
||||||
return dialTimeoutRace(net, addr, timeout)
|
|
||||||
}
|
|
||||||
deadline := time.Now().Add(timeout)
|
|
||||||
_, addri, err := resolveNetAddr("dial", net, addr, deadline)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return dialAddr(net, addr, addri, deadline)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// dialTimeoutRace is the old implementation of DialTimeout, still used
|
// dialTimeoutRace is the old implementation of DialTimeout, still used
|
||||||
// on operating systems where the deadline hasn't been pushed down
|
// on operating systems where the deadline hasn't been pushed down
|
||||||
// into the pollserver.
|
// into the pollserver.
|
||||||
// TODO: fix this on Windows and plan9.
|
// TODO: fix this on plan9.
|
||||||
func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) {
|
func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) {
|
||||||
t := time.NewTimer(timeout)
|
t := time.NewTimer(timeout)
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
|
74
src/pkg/net/dial_windows_test.go
Normal file
74
src/pkg/net/dial_windows_test.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
// Copyright 2013 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var handleCounter struct {
|
||||||
|
once sync.Once
|
||||||
|
proc *syscall.Proc
|
||||||
|
}
|
||||||
|
|
||||||
|
func numHandles(t *testing.T) int {
|
||||||
|
|
||||||
|
handleCounter.once.Do(func() {
|
||||||
|
d, err := syscall.LoadDLL("kernel32.dll")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadDLL: %v\n", err)
|
||||||
|
}
|
||||||
|
handleCounter.proc, err = d.FindProc("GetProcessHandleCount")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FindProc: %v\n", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
cp, err := syscall.GetCurrentProcess()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCurrentProcess: %v\n", err)
|
||||||
|
}
|
||||||
|
var n uint32
|
||||||
|
r, _, err := handleCounter.proc.Call(uintptr(cp), uintptr(unsafe.Pointer(&n)))
|
||||||
|
if r == 0 {
|
||||||
|
t.Fatalf("GetProcessHandleCount: %v\n", error(err))
|
||||||
|
}
|
||||||
|
return int(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDialTimeoutHandleLeak(t *testing.T) (before, after int) {
|
||||||
|
before = numHandles(t)
|
||||||
|
// See comment in TestDialTimeout about why we use this address.
|
||||||
|
c, err := DialTimeout("tcp", "127.0.71.111:49151", 200*time.Millisecond)
|
||||||
|
after = numHandles(t)
|
||||||
|
if err == nil {
|
||||||
|
c.Close()
|
||||||
|
t.Fatalf("unexpected: connected to %s", c.RemoteAddr())
|
||||||
|
}
|
||||||
|
terr, ok := err.(timeout)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("got error %q; want error with timeout interface", err)
|
||||||
|
}
|
||||||
|
if !terr.Timeout() {
|
||||||
|
t.Fatalf("got error %q; not a timeout", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialTimeoutHandleLeak(t *testing.T) {
|
||||||
|
if !canUseConnectEx("tcp") {
|
||||||
|
t.Logf("skipping test; no ConnectEx found.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
testDialTimeoutHandleLeak(t) // ignore first call results
|
||||||
|
before, after := testDialTimeoutHandleLeak(t)
|
||||||
|
if before != after {
|
||||||
|
t.Fatalf("handle count is different before=%d and after=%d", before, after)
|
||||||
|
}
|
||||||
|
}
|
@ -23,6 +23,12 @@ var canCancelIO = true // used for testing current package
|
|||||||
func sysInit() {
|
func sysInit() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
|
||||||
|
// On plan9, use the relatively inefficient
|
||||||
|
// goroutine-racing implementation.
|
||||||
|
return dialTimeoutRace(net, addr, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
func newFD(proto, name string, ctl *os.File, laddr, raddr Addr) *netFD {
|
func newFD(proto, name string, ctl *os.File, laddr, raddr Addr) *netFD {
|
||||||
return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr}
|
return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr}
|
||||||
}
|
}
|
||||||
|
@ -288,6 +288,15 @@ func server(fd int) *pollServer {
|
|||||||
return pollservers[k]
|
return pollservers[k]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
_, addri, err := resolveNetAddr("dial", net, addr, deadline)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dialAddr(net, addr, addri, deadline)
|
||||||
|
}
|
||||||
|
|
||||||
func newFD(fd, family, sotype int, net string) (*netFD, error) {
|
func newFD(fd, family, sotype int, net string) (*netFD, error) {
|
||||||
if err := syscall.SetNonblock(fd, true); err != nil {
|
if err := syscall.SetNonblock(fd, true); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -45,6 +45,28 @@ func closesocket(s syscall.Handle) error {
|
|||||||
return syscall.Closesocket(s)
|
return syscall.Closesocket(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func canUseConnectEx(net string) bool {
|
||||||
|
if net == "udp" || net == "udp4" || net == "udp6" {
|
||||||
|
// ConnectEx windows API does not support connectionless sockets.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return syscall.LoadConnectEx() == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
|
||||||
|
if !canUseConnectEx(net) {
|
||||||
|
// Use the relatively inefficient goroutine-racing
|
||||||
|
// implementation of DialTimeout.
|
||||||
|
return dialTimeoutRace(net, addr, timeout)
|
||||||
|
}
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
_, addri, err := resolveNetAddr("dial", net, addr, deadline)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dialAddr(net, addr, addri, deadline)
|
||||||
|
}
|
||||||
|
|
||||||
// Interface for all IO operations.
|
// Interface for all IO operations.
|
||||||
type anOpIface interface {
|
type anOpIface interface {
|
||||||
Op() *anOp
|
Op() *anOp
|
||||||
@ -321,8 +343,48 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
|
|||||||
runtime.SetFinalizer(fd, (*netFD).closesocket)
|
runtime.SetFinalizer(fd, (*netFD).closesocket)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make new connection.
|
||||||
|
|
||||||
|
type connectOp struct {
|
||||||
|
anOp
|
||||||
|
ra syscall.Sockaddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *connectOp) Submit() error {
|
||||||
|
return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *connectOp) Name() string {
|
||||||
|
return "ConnectEx"
|
||||||
|
}
|
||||||
|
|
||||||
func (fd *netFD) connect(ra syscall.Sockaddr) error {
|
func (fd *netFD) connect(ra syscall.Sockaddr) error {
|
||||||
|
if !canUseConnectEx(fd.net) {
|
||||||
return syscall.Connect(fd.sysfd, ra)
|
return syscall.Connect(fd.sysfd, ra)
|
||||||
|
}
|
||||||
|
// ConnectEx windows API requires an unconnected, previously bound socket.
|
||||||
|
var la syscall.Sockaddr
|
||||||
|
switch ra.(type) {
|
||||||
|
case *syscall.SockaddrInet4:
|
||||||
|
la = &syscall.SockaddrInet4{}
|
||||||
|
case *syscall.SockaddrInet6:
|
||||||
|
la = &syscall.SockaddrInet6{}
|
||||||
|
default:
|
||||||
|
panic("unexpected type in connect")
|
||||||
|
}
|
||||||
|
if err := syscall.Bind(fd.sysfd, la); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Call ConnectEx API.
|
||||||
|
var o connectOp
|
||||||
|
o.Init(fd, 'w')
|
||||||
|
o.ra = ra
|
||||||
|
_, err := iosrv.ExecIO(&o, fd.wdeadline.value())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Refresh socket properties.
|
||||||
|
return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a reference to this fd.
|
// Add a reference to this fd.
|
||||||
|
@ -7,6 +7,8 @@
|
|||||||
package syscall
|
package syscall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
errorspkg "errors"
|
||||||
|
"sync"
|
||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@ -712,6 +714,56 @@ func LoadGetAddrInfo() error {
|
|||||||
return procGetAddrInfoW.Find()
|
return procGetAddrInfoW.Find()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var connectExFunc struct {
|
||||||
|
once sync.Once
|
||||||
|
addr uintptr
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConnectEx() error {
|
||||||
|
connectExFunc.once.Do(func() {
|
||||||
|
var s Handle
|
||||||
|
s, connectExFunc.err = Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)
|
||||||
|
if connectExFunc.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer CloseHandle(s)
|
||||||
|
var n uint32
|
||||||
|
connectExFunc.err = WSAIoctl(s,
|
||||||
|
SIO_GET_EXTENSION_FUNCTION_POINTER,
|
||||||
|
(*byte)(unsafe.Pointer(&WSAID_CONNECTEX)),
|
||||||
|
uint32(unsafe.Sizeof(WSAID_CONNECTEX)),
|
||||||
|
(*byte)(unsafe.Pointer(&connectExFunc.addr)),
|
||||||
|
uint32(unsafe.Sizeof(connectExFunc.addr)),
|
||||||
|
&n, nil, 0)
|
||||||
|
})
|
||||||
|
return connectExFunc.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectEx(s Handle, name uintptr, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) (err error) {
|
||||||
|
r1, _, e1 := Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = error(e1)
|
||||||
|
} else {
|
||||||
|
err = EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectEx(fd Handle, sa Sockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) error {
|
||||||
|
err := LoadConnectEx()
|
||||||
|
if err != nil {
|
||||||
|
return errorspkg.New("failed to find ConnectEx: " + err.Error())
|
||||||
|
}
|
||||||
|
ptr, n, err := sa.sockaddr()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
|
||||||
|
}
|
||||||
|
|
||||||
// Invented structures to support what package os expects.
|
// Invented structures to support what package os expects.
|
||||||
type Rusage struct {
|
type Rusage struct {
|
||||||
CreationTime Filetime
|
CreationTime Filetime
|
||||||
|
@ -505,6 +505,13 @@ const (
|
|||||||
SO_RCVBUF = 0x1002
|
SO_RCVBUF = 0x1002
|
||||||
SO_SNDBUF = 0x1001
|
SO_SNDBUF = 0x1001
|
||||||
SO_UPDATE_ACCEPT_CONTEXT = 0x700b
|
SO_UPDATE_ACCEPT_CONTEXT = 0x700b
|
||||||
|
SO_UPDATE_CONNECT_CONTEXT = 0x7010
|
||||||
|
|
||||||
|
IOC_OUT = 0x40000000
|
||||||
|
IOC_IN = 0x80000000
|
||||||
|
IOC_INOUT = IOC_IN | IOC_OUT
|
||||||
|
IOC_WS2 = 0x08000000
|
||||||
|
SIO_GET_EXTENSION_FUNCTION_POINTER = IOC_INOUT | IOC_WS2 | 6
|
||||||
|
|
||||||
// cf. http://support.microsoft.com/default.aspx?scid=kb;en-us;257460
|
// cf. http://support.microsoft.com/default.aspx?scid=kb;en-us;257460
|
||||||
|
|
||||||
@ -941,3 +948,17 @@ const (
|
|||||||
AI_CANONNAME = 2
|
AI_CANONNAME = 2
|
||||||
AI_NUMERICHOST = 4
|
AI_NUMERICHOST = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type GUID struct {
|
||||||
|
Data1 uint32
|
||||||
|
Data2 uint16
|
||||||
|
Data3 uint16
|
||||||
|
Data4 [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var WSAID_CONNECTEX = GUID{
|
||||||
|
0x25a207b9,
|
||||||
|
0xddf3,
|
||||||
|
0x4660,
|
||||||
|
[8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user