From ff25900bb660722d3496e7afafbdb9e3454271b6 Mon Sep 17 00:00:00 2001 From: Wei Guangjing Date: Wed, 19 Jan 2011 14:49:25 -0500 Subject: [PATCH] net: implement windows timeout R=brainman, rsc CC=golang-dev https://golang.org/cl/1731047 --- src/pkg/net/fd_windows.go | 138 +++++++++++++++++++++--- src/pkg/net/timeout_test.go | 5 - src/pkg/syscall/syscall_windows.go | 1 + src/pkg/syscall/zsyscall_windows_386.go | 16 +++ 4 files changed, 141 insertions(+), 19 deletions(-) diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 72685d612a..f3e5761c87 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -6,14 +6,14 @@ package net import ( "os" + "runtime" "sync" "syscall" + "time" "unsafe" "runtime" ) -// BUG(brainman): The Windows implementation does not implement SetTimeout. - // IO completion result parameters. type ioResult struct { key uint32 @@ -79,6 +79,8 @@ type ioPacket struct { // Link to the io owner. c chan *ioResult + + w *syscall.WSABuf } func (s *pollServer) getCompletedIO() (ov *syscall.Overlapped, result *ioResult, err os.Error) { @@ -126,6 +128,8 @@ func startServer() { panic("Start pollServer: " + err.String() + "\n") } pollserver = p + + go timeoutIO() } var initErr os.Error @@ -143,8 +147,8 @@ func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err sysfd: fd, family: family, proto: proto, - cr: make(chan *ioResult), - cw: make(chan *ioResult), + cr: make(chan *ioResult, 1), + cw: make(chan *ioResult, 1), net: net, laddr: laddr, raddr: raddr, @@ -199,6 +203,80 @@ func newWSABuf(p []byte) *syscall.WSABuf { return &syscall.WSABuf{uint32(len(p)), p0} } +func waitPacket(fd *netFD, pckt *ioPacket, mode int) (r *ioResult) { + var delta int64 + if mode == 'r' { + delta = fd.rdeadline_delta + } + if mode == 'w' { + delta = fd.wdeadline_delta + } + if delta <= 0 { + return <-pckt.c + } + + select { + case r = <-pckt.c: + case <-time.After(delta): + a := &arg{f: cancel, fd: fd, pckt: pckt, c: make(chan int)} + ioChan <- a + <-a.c + r = <-pckt.c + if r.errno == 995 { // IO Canceled + r.errno = syscall.EWOULDBLOCK + } + } + return r +} + +const ( + read = iota + readfrom + write + writeto + cancel +) + +type arg struct { + f int + fd *netFD + pckt *ioPacket + done *uint32 + flags *uint32 + rsa *syscall.RawSockaddrAny + size *int32 + sa *syscall.Sockaddr + c chan int +} + +var ioChan chan *arg = make(chan *arg) + +func timeoutIO() { + // CancelIO only cancels all pending input and output (I/O) operations that are + // issued by the calling thread for the specified file, does not cancel I/O + // operations that other threads issue for a file handle. So we need do all timeout + // I/O in single OS thread. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + for { + o := <-ioChan + var e int + switch o.f { + case read: + e = syscall.WSARecv(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, o.flags, &o.pckt.o, nil) + case readfrom: + e = syscall.WSARecvFrom(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, o.flags, o.rsa, o.size, &o.pckt.o, nil) + case write: + e = syscall.WSASend(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, uint32(0), &o.pckt.o, nil) + case writeto: + e = syscall.WSASendto(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, 0, *o.sa, &o.pckt.o, nil) + case cancel: + _, e = syscall.CancelIo(uint32(o.fd.sysfd)) + } + o.c <- e + } +} + func (fd *netFD) Read(p []byte) (n int, err os.Error) { if fd == nil { return 0, os.EINVAL @@ -213,9 +291,17 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) { // Submit receive request. var pckt ioPacket pckt.c = fd.cr + pckt.w = newWSABuf(p) var done uint32 flags := uint32(0) - e := syscall.WSARecv(uint32(fd.sysfd), newWSABuf(p), 1, &done, &flags, &pckt.o, nil) + var e int + if fd.rdeadline_delta > 0 { + a := &arg{f: read, fd: fd, pckt: &pckt, done: &done, flags: &flags, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSARecv(uint32(fd.sysfd), pckt.w, 1, &done, &flags, &pckt.o, nil) + } switch e { case 0: // IO completed immediately, but we need to get our completion message anyway. @@ -225,7 +311,7 @@ func (fd *netFD) Read(p []byte) (n int, err os.Error) { return 0, &OpError{"WSARecv", fd.net, fd.laddr, os.Errno(e)} } // Wait for our request to complete. - r := <-pckt.c + r := waitPacket(fd, &pckt, 'r') if r.errno != 0 { err = &OpError{"WSARecv", fd.net, fd.laddr, os.Errno(r.errno)} } @@ -253,11 +339,19 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { // Submit receive request. var pckt ioPacket pckt.c = fd.cr + pckt.w = newWSABuf(p) var done uint32 flags := uint32(0) var rsa syscall.RawSockaddrAny l := int32(unsafe.Sizeof(rsa)) - e := syscall.WSARecvFrom(uint32(fd.sysfd), newWSABuf(p), 1, &done, &flags, &rsa, &l, &pckt.o, nil) + var e int + if fd.rdeadline_delta > 0 { + a := &arg{f: readfrom, fd: fd, pckt: &pckt, done: &done, flags: &flags, rsa: &rsa, size: &l, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSARecvFrom(uint32(fd.sysfd), pckt.w, 1, &done, &flags, &rsa, &l, &pckt.o, nil) + } switch e { case 0: // IO completed immediately, but we need to get our completion message anyway. @@ -267,7 +361,7 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { return 0, nil, &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(e)} } // Wait for our request to complete. - r := <-pckt.c + r := waitPacket(fd, &pckt, 'r') if r.errno != 0 { err = &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(r.errno)} } @@ -290,8 +384,16 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) { // Submit send request. var pckt ioPacket pckt.c = fd.cw + pckt.w = newWSABuf(p) var done uint32 - e := syscall.WSASend(uint32(fd.sysfd), newWSABuf(p), 1, &done, uint32(0), &pckt.o, nil) + var e int + if fd.wdeadline_delta > 0 { + a := &arg{f: write, fd: fd, pckt: &pckt, done: &done, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSASend(uint32(fd.sysfd), pckt.w, 1, &done, uint32(0), &pckt.o, nil) + } switch e { case 0: // IO completed immediately, but we need to get our completion message anyway. @@ -301,7 +403,7 @@ func (fd *netFD) Write(p []byte) (n int, err os.Error) { return 0, &OpError{"WSASend", fd.net, fd.laddr, os.Errno(e)} } // Wait for our request to complete. - r := <-pckt.c + r := waitPacket(fd, &pckt, 'w') if r.errno != 0 { err = &OpError{"WSASend", fd.net, fd.laddr, os.Errno(r.errno)} } @@ -326,8 +428,16 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { // Submit send request. var pckt ioPacket pckt.c = fd.cw + pckt.w = newWSABuf(p) var done uint32 - e := syscall.WSASendto(uint32(fd.sysfd), newWSABuf(p), 1, &done, 0, sa, &pckt.o, nil) + var e int + if fd.wdeadline_delta > 0 { + a := &arg{f: writeto, fd: fd, pckt: &pckt, done: &done, sa: &sa, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSASendto(uint32(fd.sysfd), pckt.w, 1, &done, 0, sa, &pckt.o, nil) + } switch e { case 0: // IO completed immediately, but we need to get our completion message anyway. @@ -337,7 +447,7 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { return 0, &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(e)} } // Wait for our request to complete. - r := <-pckt.c + r := waitPacket(fd, &pckt, 'w') if r.errno != 0 { err = &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(r.errno)} } @@ -410,8 +520,8 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os. sysfd: s, family: fd.family, proto: fd.proto, - cr: make(chan *ioResult), - cw: make(chan *ioResult), + cr: make(chan *ioResult, 1), + cw: make(chan *ioResult, 1), net: fd.net, laddr: laddr, raddr: raddr, diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 092781685e..3594c0a350 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -8,14 +8,9 @@ import ( "os" "testing" "time" - "runtime" ) func testTimeout(t *testing.T, network, addr string, readFrom bool) { - // Timeouts are not implemented on windows. - if runtime.GOOS == "windows" { - return - } fd, err := Dial(network, "", addr) if err != nil { t.Errorf("dial %s %s failed: %v", network, addr, err) diff --git a/src/pkg/syscall/syscall_windows.go b/src/pkg/syscall/syscall_windows.go index 06dde518fd..d3d22dba80 100644 --- a/src/pkg/syscall/syscall_windows.go +++ b/src/pkg/syscall/syscall_windows.go @@ -126,6 +126,7 @@ func getSysProcAddr(m uint32, pname string) uintptr { //sys GetTimeZoneInformation(tzi *Timezoneinformation) (rc uint32, errno int) [failretval==0xffffffff] //sys CreateIoCompletionPort(filehandle int32, cphandle int32, key uint32, threadcnt uint32) (handle int32, errno int) //sys GetQueuedCompletionStatus(cphandle int32, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) (ok bool, errno int) +//sys CancelIo(s uint32) (ok bool, errno int) //sys CreateProcess(appName *int16, commandLine *uint16, procSecurity *int16, threadSecurity *int16, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (ok bool, errno int) = CreateProcessW //sys GetStartupInfo(startupInfo *StartupInfo) (ok bool, errno int) = GetStartupInfoW //sys GetCurrentProcess() (pseudoHandle int32, errno int) diff --git a/src/pkg/syscall/zsyscall_windows_386.go b/src/pkg/syscall/zsyscall_windows_386.go index 29880f2b28..18e36a0226 100644 --- a/src/pkg/syscall/zsyscall_windows_386.go +++ b/src/pkg/syscall/zsyscall_windows_386.go @@ -43,6 +43,7 @@ var ( procGetTimeZoneInformation = getSysProcAddr(modkernel32, "GetTimeZoneInformation") procCreateIoCompletionPort = getSysProcAddr(modkernel32, "CreateIoCompletionPort") procGetQueuedCompletionStatus = getSysProcAddr(modkernel32, "GetQueuedCompletionStatus") + procCancelIo = getSysProcAddr(modkernel32, "CancelIo") procCreateProcessW = getSysProcAddr(modkernel32, "CreateProcessW") procGetStartupInfoW = getSysProcAddr(modkernel32, "GetStartupInfoW") procGetCurrentProcess = getSysProcAddr(modkernel32, "GetCurrentProcess") @@ -512,6 +513,21 @@ func GetQueuedCompletionStatus(cphandle int32, qty *uint32, key *uint32, overlap return } +func CancelIo(s uint32) (ok bool, errno int) { + r0, _, e1 := Syscall(procCancelIo, uintptr(s), 0, 0) + ok = bool(r0 != 0) + if !ok { + if e1 != 0 { + errno = int(e1) + } else { + errno = EINVAL + } + } else { + errno = 0 + } + return +} + func CreateProcess(appName *int16, commandLine *uint16, procSecurity *int16, threadSecurity *int16, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (ok bool, errno int) { var _p0 uint32 if inheritHandles {