diff --git a/src/pkg/net/fd_unix.go b/src/pkg/net/fd_unix.go index 0540df8255b..d7a83b6393e 100644 --- a/src/pkg/net/fd_unix.go +++ b/src/pkg/net/fd_unix.go @@ -234,6 +234,20 @@ func (s *pollServer) Run() { } } +func (s *pollServer) PrepareRead(fd *netFD) error { + if fd.rdeadline.expired() { + return errTimeout + } + return nil +} + +func (s *pollServer) PrepareWrite(fd *netFD) error { + if fd.wdeadline.expired() { + return errTimeout + } + return nil +} + func (s *pollServer) WaitRead(fd *netFD) error { err := s.AddFD(fd, 'r') if err == nil { @@ -331,6 +345,11 @@ func (fd *netFD) name() string { } func (fd *netFD) connect(ra syscall.Sockaddr) error { + fd.wio.Lock() + defer fd.wio.Unlock() + if err := fd.pollServer.PrepareWrite(fd); err != nil { + return err + } err := syscall.Connect(fd.sysfd, ra) if err == syscall.EINPROGRESS { if err = fd.pollServer.WaitWrite(fd); err != nil { @@ -425,11 +444,10 @@ func (fd *netFD) Read(p []byte) (n int, err error) { return 0, err } defer fd.decref() + if err := fd.pollServer.PrepareRead(fd); err != nil { + return 0, &OpError{"read", fd.net, fd.raddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, err = syscall.Read(int(fd.sysfd), p) if err != nil { n = 0 @@ -455,11 +473,10 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { return 0, nil, err } defer fd.decref() + if err := fd.pollServer.PrepareRead(fd); err != nil { + return 0, nil, &OpError{"read", fd.net, fd.laddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) if err != nil { n = 0 @@ -485,11 +502,10 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S return 0, 0, 0, nil, err } defer fd.decref() + if err := fd.pollServer.PrepareRead(fd); err != nil { + return 0, 0, 0, nil, &OpError{"read", fd.net, fd.laddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) if err != nil { // TODO(dfc) should n and oobn be set to 0 @@ -522,11 +538,10 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { return 0, err } defer fd.decref() + if err := fd.pollServer.PrepareWrite(fd); err != nil { + return 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } var n int n, err = syscall.Write(int(fd.sysfd), p[nn:]) if n > 0 { @@ -562,11 +577,10 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { return 0, err } defer fd.decref() + if err := fd.pollServer.PrepareWrite(fd); err != nil { + return 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } err = syscall.Sendto(fd.sysfd, p, 0, sa) if err == syscall.EAGAIN { if err = fd.pollServer.WaitWrite(fd); err == nil { @@ -590,11 +604,10 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob return 0, 0, err } defer fd.decref() + if err := fd.pollServer.PrepareWrite(fd); err != nil { + return 0, 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) if err == syscall.EAGAIN { if err = fd.pollServer.WaitWrite(fd); err == nil { @@ -613,6 +626,8 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob } func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) { + fd.rio.Lock() + defer fd.rio.Unlock() if err := fd.incref(false); err != nil { return nil, err } @@ -620,6 +635,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e var s int var rsa syscall.Sockaddr + if err = fd.pollServer.PrepareRead(fd); err != nil { + return nil, &OpError{"accept", fd.net, fd.laddr, err} + } for { s, rsa, err = accept(fd.sysfd) if err != nil { diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 0260efcc0b6..2e92147b8e3 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -532,7 +532,7 @@ func TestReadDeadlineDataAvailable(t *testing.T) { defer ln.Close() servec := make(chan copyRes) - const msg = "data client shouldn't read, even though it it'll be waiting" + const msg = "data client shouldn't read, even though it'll be waiting" go func() { c, err := ln.Accept() if err != nil { @@ -596,6 +596,64 @@ func TestWriteDeadlineBufferAvailable(t *testing.T) { } } +// TestAcceptDeadlineConnectionAvailable tests that accept deadlines work, even +// if there's incoming connections available. +func TestAcceptDeadlineConnectionAvailable(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t).(*TCPListener) + defer ln.Close() + + go func() { + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + var buf [1]byte + c.Read(buf[:]) // block until the connection or listener is closed + }() + time.Sleep(10 * time.Millisecond) + ln.SetDeadline(time.Now().Add(-5 * time.Second)) // in the past + c, err := ln.Accept() + if err == nil { + defer c.Close() + } + if !isTimeout(err) { + t.Fatalf("Accept: got %v; want timeout", err) + } +} + +// TestConnectDeadlineInThePast tests that connect deadlines work, even +// if the connection can be established w/o blocking. +func TestConnectDeadlineInThePast(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t).(*TCPListener) + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err == nil { + defer c.Close() + } + }() + time.Sleep(10 * time.Millisecond) + c, err := DialTimeout("tcp", ln.Addr().String(), -5*time.Second) // in the past + if err == nil { + defer c.Close() + } + if !isTimeout(err) { + t.Fatalf("DialTimeout: got %v; want timeout", err) + } +} + // TestProlongTimeout tests concurrent deadline modification. // Known to cause data races in the past. func TestProlongTimeout(t *testing.T) {