diff --git a/src/pkg/net/fd_unix.go b/src/pkg/net/fd_unix.go index 72317426aa..828e998e3e 100644 --- a/src/pkg/net/fd_unix.go +++ b/src/pkg/net/fd_unix.go @@ -261,6 +261,8 @@ var pollMaxN int var pollservers []*pollServer var startServersOnce []func() +var canCancelIO = true // used for testing current package + func init() { pollMaxN = runtime.NumCPU() if pollMaxN > 8 { diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index eecb4a866a..f94f08295f 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -17,19 +17,32 @@ import ( var initErr error +// CancelIo Windows API cancels all outstanding IO for a particular +// socket on current thread. To overcome that limitation, we run +// special goroutine, locked to OS single thread, that both starts +// and cancels IO. It means, there are 2 unavoidable thread switches +// for every IO. +// Some newer versions of Windows has new CancelIoEx API, that does +// not have that limitation and can be used from any thread. This +// package uses CancelIoEx API, if present, otherwise it fallback +// to CancelIo. + +var canCancelIO bool // determines if CancelIoEx API is present + func init() { var d syscall.WSAData e := syscall.WSAStartup(uint32(0x202), &d) if e != nil { initErr = os.NewSyscallError("WSAStartup", e) } + canCancelIO = syscall.LoadCancelIoEx() == nil } func closesocket(s syscall.Handle) error { return syscall.Closesocket(s) } -// Interface for all io operations. +// Interface for all IO operations. type anOpIface interface { Op() *anOp Name() string @@ -42,7 +55,7 @@ type ioResult struct { err error } -// anOp implements functionality common to all io operations. +// anOp implements functionality common to all IO operations. type anOp struct { // Used by IOCP interface, it must be first field // of the struct, as our code rely on it. @@ -75,7 +88,7 @@ func (o *anOp) Op() *anOp { return o } -// bufOp is used by io operations that read / write +// bufOp is used by IO operations that read / write // data from / to client buffer. type bufOp struct { anOp @@ -92,7 +105,7 @@ func (o *bufOp) Init(fd *netFD, buf []byte, mode int) { } } -// resultSrv will retrieve all io completion results from +// resultSrv will retrieve all IO completion results from // iocp and send them to the correspondent waiting client // goroutine via channel supplied in the request. type resultSrv struct { @@ -107,7 +120,7 @@ func (s *resultSrv) Run() { r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) switch { case r.err == nil: - // Dequeued successfully completed io packet. + // Dequeued successfully completed IO packet. case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil: // Wait has timed out (should not happen now, but might be used in the future). panic("GetQueuedCompletionStatus timed out") @@ -115,22 +128,23 @@ func (s *resultSrv) Run() { // Failed to dequeue anything -> report the error. panic("GetQueuedCompletionStatus failed " + r.err.Error()) default: - // Dequeued failed io packet. + // Dequeued failed IO packet. } (*anOp)(unsafe.Pointer(o)).resultc <- r } } -// ioSrv executes net io requests. +// ioSrv executes net IO requests. type ioSrv struct { - submchan chan anOpIface // submit io requests - canchan chan anOpIface // cancel io requests + submchan chan anOpIface // submit IO requests + canchan chan anOpIface // cancel IO requests } -// ProcessRemoteIO will execute submit io requests on behalf +// ProcessRemoteIO will execute submit IO requests on behalf // of other goroutines, all on a single os thread, so it can // cancel them later. Results of all operations will be sent // back to their requesters via channel supplied in request. +// It is used only when the CancelIoEx API is unavailable. func (s *ioSrv) ProcessRemoteIO() { runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -144,20 +158,21 @@ func (s *ioSrv) ProcessRemoteIO() { } } -// ExecIO executes a single io operation. It either executes it -// inline, or, if a deadline is employed, passes the request onto +// ExecIO executes a single IO operation oi. It submits and cancels +// IO in the current thread for systems where Windows CancelIoEx API +// is available. Alternatively, it passes the request onto // a special goroutine and waits for completion or cancels request. // deadline is unix nanos. func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) { var err error o := oi.Op() - if deadline != 0 { + if canCancelIO { + err = oi.Submit() + } else { // Send request to a special dedicated thread, - // so it can stop the io with CancelIO later. + // so it can stop the IO with CancelIO later. s.submchan <- oi err = <-o.errnoc - } else { - err = oi.Submit() } switch err { case nil: @@ -168,27 +183,45 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) { default: return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, err} } - // Wait for our request to complete. - var r ioResult + // Setup timer, if deadline is given. + var timer <-chan time.Time if deadline != 0 { dt := deadline - time.Now().UnixNano() if dt < 1 { dt = 1 } - timer := time.NewTimer(time.Duration(dt) * time.Nanosecond) - defer timer.Stop() - select { - case r = <-o.resultc: - case <-timer.C: + t := time.NewTimer(time.Duration(dt) * time.Nanosecond) + defer t.Stop() + timer = t.C + } + // Wait for our request to complete. + var r ioResult + var cancelled bool + select { + case r = <-o.resultc: + case <-timer: + cancelled = true + case <-o.fd.closec: + cancelled = true + } + if cancelled { + // Cancel it. + if canCancelIO { + err := syscall.CancelIoEx(syscall.Handle(o.Op().fd.sysfd), &o.o) + // Assuming ERROR_NOT_FOUND is returned, if IO is completed. + if err != nil && err != syscall.ERROR_NOT_FOUND { + // TODO(brainman): maybe do something else, but panic. + panic(err) + } + } else { s.canchan <- oi <-o.errnoc - r = <-o.resultc - if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled - r.err = syscall.EWOULDBLOCK - } } - } else { + // Wait for IO to be canceled or complete successfully. r = <-o.resultc + if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled + r.err = syscall.EWOULDBLOCK + } } if r.err != nil { err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err} @@ -211,9 +244,13 @@ func startServer() { go resultsrv.Run() iosrv = new(ioSrv) - iosrv.submchan = make(chan anOpIface) - iosrv.canchan = make(chan anOpIface) - go iosrv.ProcessRemoteIO() + if !canCancelIO { + // Only CancelIo API is available. Lets start special goroutine + // locked to an OS thread, that both starts and cancels IO. + iosrv.submchan = make(chan anOpIface) + iosrv.canchan = make(chan anOpIface) + go iosrv.ProcessRemoteIO() + } } // Network file descriptor. @@ -233,6 +270,7 @@ type netFD struct { raddr Addr resultc [2]chan ioResult // read/write completion results errnoc [2]chan error // read/write submit or cancel operation errors + closec chan bool // used by Close to cancel pending IO // owned by client rdeadline int64 @@ -247,6 +285,7 @@ func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD { family: family, sotype: sotype, net: net, + closec: make(chan bool), } runtime.SetFinalizer(netfd, (*netFD).Close) return netfd @@ -299,24 +338,12 @@ func (fd *netFD) incref(closing bool) error { // Remove a reference to this FD and close if we've been asked to do so (and // there are no references left. func (fd *netFD) decref() { + if fd == nil { + return + } fd.sysmu.Lock() fd.sysref-- - // NOTE(rsc): On Unix we check fd.sysref == 0 here before closing, - // but on Windows we have no way to wake up the blocked I/O other - // than closing the socket (or calling Shutdown, which breaks other - // programs that might have a reference to the socket). So there is - // a small race here that we might close fd.sysfd and then some other - // goroutine might start a read of fd.sysfd (having read it before we - // write InvalidHandle to it), which might refer to some other file - // if the specific handle value gets reused. I think handle values on - // Windows are not reused as aggressively as file descriptors on Unix, - // so this might be tolerable. - if fd.closing && fd.sysfd != syscall.InvalidHandle { - // In case the user has set linger, switch to blocking mode so - // the close blocks. As long as this doesn't happen often, we - // can handle the extra OS processes. Otherwise we'll need to - // use the resultsrv for Close too. Sigh. - syscall.SetNonblock(fd.sysfd, false) + if fd.closing && fd.sysref == 0 && fd.sysfd != syscall.InvalidHandle { closesocket(fd.sysfd) fd.sysfd = syscall.InvalidHandle // no need for a finalizer anymore @@ -329,7 +356,14 @@ func (fd *netFD) Close() error { if err := fd.incref(true); err != nil { return err } - fd.decref() + defer fd.decref() + // unblock pending reader and writer + close(fd.closec) + // wait for both reader and writer to exit + fd.rio.Lock() + defer fd.rio.Unlock() + fd.wio.Lock() + defer fd.wio.Unlock() return nil } @@ -368,18 +402,12 @@ func (o *readOp) Name() string { } func (fd *netFD) Read(buf []byte) (int, error) { - if fd == nil { - return 0, syscall.EINVAL - } - fd.rio.Lock() - defer fd.rio.Unlock() if err := fd.incref(false); err != nil { return 0, err } defer fd.decref() - if fd.sysfd == syscall.InvalidHandle { - return 0, syscall.EINVAL - } + fd.rio.Lock() + defer fd.rio.Unlock() var o readOp o.Init(fd, buf, 'r') n, err := iosrv.ExecIO(&o, fd.rdeadline) @@ -407,18 +435,15 @@ func (o *readFromOp) Name() string { } func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) { - if fd == nil { - return 0, nil, syscall.EINVAL - } if len(buf) == 0 { return 0, nil, nil } - fd.rio.Lock() - defer fd.rio.Unlock() if err := fd.incref(false); err != nil { return 0, nil, err } defer fd.decref() + fd.rio.Lock() + defer fd.rio.Unlock() var o readFromOp o.Init(fd, buf, 'r') o.rsan = int32(unsafe.Sizeof(o.rsa)) @@ -446,15 +471,12 @@ func (o *writeOp) Name() string { } func (fd *netFD) Write(buf []byte) (int, error) { - if fd == nil { - return 0, syscall.EINVAL - } - fd.wio.Lock() - defer fd.wio.Unlock() if err := fd.incref(false); err != nil { return 0, err } defer fd.decref() + fd.wio.Lock() + defer fd.wio.Unlock() var o writeOp o.Init(fd, buf, 'w') return iosrv.ExecIO(&o, fd.wdeadline) @@ -477,21 +499,15 @@ func (o *writeToOp) Name() string { } func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (int, error) { - if fd == nil { - return 0, syscall.EINVAL - } if len(buf) == 0 { return 0, nil } - fd.wio.Lock() - defer fd.wio.Unlock() if err := fd.incref(false); err != nil { return 0, err } defer fd.decref() - if fd.sysfd == syscall.InvalidHandle { - return 0, syscall.EINVAL - } + fd.wio.Lock() + defer fd.wio.Unlock() var o writeToOp o.Init(fd, buf, 'w') o.sa = sa diff --git a/src/pkg/net/net_test.go b/src/pkg/net/net_test.go index 623a788f9a..a4e8dcd445 100644 --- a/src/pkg/net/net_test.go +++ b/src/pkg/net/net_test.go @@ -174,3 +174,42 @@ func TestUDPListenClose(t *testing.T) { t.Fatal("timeout waiting for UDP close") } } + +func TestTCPClose(t *testing.T) { + l, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + read := func(r io.Reader) error { + var m [1]byte + _, err := r.Read(m[:]) + return err + } + + go func() { + c, err := Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + go read(c) + + time.Sleep(10 * time.Millisecond) + c.Close() + }() + + c, err := l.Accept() + if err != nil { + t.Fatal(err) + } + defer c.Close() + + for err == nil { + err = read(c) + } + if err != nil && err != io.EOF { + t.Fatal(err) + } +} diff --git a/src/pkg/net/rpc/server_test.go b/src/pkg/net/rpc/server_test.go index 62c7b1e600..a718e8a940 100644 --- a/src/pkg/net/rpc/server_test.go +++ b/src/pkg/net/rpc/server_test.go @@ -499,6 +499,27 @@ func TestClientWriteError(t *testing.T) { w.done <- true } +func TestTCPClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + defer client.Close() + + args := Args{17, 8} + var reply Reply + err = client.Call("Arith.Mul", args, &reply) + if err != nil { + t.Fatal("arith error:", err) + } + t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) + if reply.C != args.A*args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) + } +} + func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { b.StopTimer() once.Do(startServer) diff --git a/src/pkg/net/sendfile_windows.go b/src/pkg/net/sendfile_windows.go index f5a6d8804d..2d64f2f5bf 100644 --- a/src/pkg/net/sendfile_windows.go +++ b/src/pkg/net/sendfile_windows.go @@ -48,12 +48,12 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, false } - c.wio.Lock() - defer c.wio.Unlock() if err := c.incref(false); err != nil { return 0, err, true } defer c.decref() + c.wio.Lock() + defer c.wio.Unlock() var o sendfileOp o.Init(c, 'w') diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index eec371cfb2..3343c4a551 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -146,3 +146,82 @@ func TestTimeoutAccept(t *testing.T) { // Pass. } } + +func TestReadWriteDeadline(t *testing.T) { + if !canCancelIO { + t.Logf("skipping test on this system") + return + } + const ( + readTimeout = 100 * time.Millisecond + writeTimeout = 200 * time.Millisecond + delta = 40 * time.Millisecond + ) + checkTimeout := func(command string, start time.Time, should time.Duration) { + is := time.Now().Sub(start) + d := should - is + if d < -delta || delta < d { + t.Errorf("%s timeout test failed: is=%v should=%v\n", command, is, should) + } + } + + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenTCP on :0: %v", err) + } + + lnquit := make(chan bool) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + defer c.Close() + lnquit <- true + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + + start := time.Now() + err = c.SetReadDeadline(start.Add(readTimeout)) + if err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + err = c.SetWriteDeadline(start.Add(writeTimeout)) + if err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + + quit := make(chan bool) + + go func() { + var buf [10]byte + _, err = c.Read(buf[:]) + if err == nil { + t.Errorf("Read should not succeed") + } + checkTimeout("Read", start, readTimeout) + quit <- true + }() + + go func() { + var buf [10000]byte + for { + _, err = c.Write(buf[:]) + if err != nil { + break + } + } + checkTimeout("Write", start, writeTimeout) + quit <- true + }() + + <-quit + <-quit + <-lnquit +} diff --git a/src/pkg/syscall/syscall_windows.go b/src/pkg/syscall/syscall_windows.go index e997409d23..535bd55466 100644 --- a/src/pkg/syscall/syscall_windows.go +++ b/src/pkg/syscall/syscall_windows.go @@ -142,6 +142,7 @@ func NewCallback(fn interface{}) uintptr //sys GetQueuedCompletionStatus(cphandle Handle, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) (err error) //sys PostQueuedCompletionStatus(cphandle Handle, qty uint32, key uint32, overlapped *Overlapped) (err error) //sys CancelIo(s Handle) (err error) +//sys CancelIoEx(s Handle, o *Overlapped) (err error) //sys CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) = CreateProcessW //sys OpenProcess(da uint32, inheritHandle bool, pid uint32) (handle Handle, err error) //sys TerminateProcess(handle Handle, exitcode uint32) (err error) @@ -474,6 +475,10 @@ func Chmod(path string, mode uint32) (err error) { return SetFileAttributes(p, attrs) } +func LoadCancelIoEx() error { + return procCancelIoEx.Find() +} + // net api calls const socket_error = uintptr(^uint32(0)) diff --git a/src/pkg/syscall/zsyscall_windows_386.go b/src/pkg/syscall/zsyscall_windows_386.go index f2b359672d..debe3cd596 100644 --- a/src/pkg/syscall/zsyscall_windows_386.go +++ b/src/pkg/syscall/zsyscall_windows_386.go @@ -49,6 +49,7 @@ var ( procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") procPostQueuedCompletionStatus = modkernel32.NewProc("PostQueuedCompletionStatus") procCancelIo = modkernel32.NewProc("CancelIo") + procCancelIoEx = modkernel32.NewProc("CancelIoEx") procCreateProcessW = modkernel32.NewProc("CreateProcessW") procOpenProcess = modkernel32.NewProc("OpenProcess") procTerminateProcess = modkernel32.NewProc("TerminateProcess") @@ -535,6 +536,18 @@ func CancelIo(s Handle) (err error) { return } +func CancelIoEx(s Handle, o *Overlapped) (err error) { + r1, _, e1 := Syscall(procCancelIoEx.Addr(), 2, uintptr(s), uintptr(unsafe.Pointer(o)), 0) + if r1 == 0 { + if e1 != 0 { + err = error(e1) + } else { + err = EINVAL + } + } + return +} + func CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) { var _p0 uint32 if inheritHandles { diff --git a/src/pkg/syscall/zsyscall_windows_amd64.go b/src/pkg/syscall/zsyscall_windows_amd64.go index 37270fa590..5a7e74c645 100644 --- a/src/pkg/syscall/zsyscall_windows_amd64.go +++ b/src/pkg/syscall/zsyscall_windows_amd64.go @@ -49,6 +49,7 @@ var ( procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") procPostQueuedCompletionStatus = modkernel32.NewProc("PostQueuedCompletionStatus") procCancelIo = modkernel32.NewProc("CancelIo") + procCancelIoEx = modkernel32.NewProc("CancelIoEx") procCreateProcessW = modkernel32.NewProc("CreateProcessW") procOpenProcess = modkernel32.NewProc("OpenProcess") procTerminateProcess = modkernel32.NewProc("TerminateProcess") @@ -535,6 +536,18 @@ func CancelIo(s Handle) (err error) { return } +func CancelIoEx(s Handle, o *Overlapped) (err error) { + r1, _, e1 := Syscall(procCancelIoEx.Addr(), 2, uintptr(s), uintptr(unsafe.Pointer(o)), 0) + if r1 == 0 { + if e1 != 0 { + err = error(e1) + } else { + err = EINVAL + } + } + return +} + func CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) { var _p0 uint32 if inheritHandles { diff --git a/src/pkg/syscall/ztypes_windows.go b/src/pkg/syscall/ztypes_windows.go index 485a0cc5c5..9827e129c0 100644 --- a/src/pkg/syscall/ztypes_windows.go +++ b/src/pkg/syscall/ztypes_windows.go @@ -20,6 +20,7 @@ const ( ERROR_ENVVAR_NOT_FOUND Errno = 203 ERROR_OPERATION_ABORTED Errno = 995 ERROR_IO_PENDING Errno = 997 + ERROR_NOT_FOUND Errno = 1168 ) const (