mirror of
https://github.com/golang/go
synced 2024-11-26 14:56:47 -07:00
net: fix accept/connect deadline handling
Ensure that accept/connect respect deadline, even if the operation can be executed w/o blocking. Note this changes external behavior, but it makes it consistent with read/write. Factor out deadline check into pollServer.PrepareRead/Write, in preparation for edge triggered pollServer. Ensure that pollServer.WaitRead/Write are not called concurrently by adding rio/wio locks around connect/accept. R=golang-dev, mikioh.mikioh, bradfitz, iant CC=golang-dev https://golang.org/cl/7436048
This commit is contained in:
parent
48fa4a10fb
commit
0f136f2c05
@ -234,6 +234,20 @@ func (s *pollServer) Run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *pollServer) PrepareRead(fd *netFD) error {
|
||||||
|
if fd.rdeadline.expired() {
|
||||||
|
return errTimeout
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pollServer) PrepareWrite(fd *netFD) error {
|
||||||
|
if fd.wdeadline.expired() {
|
||||||
|
return errTimeout
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *pollServer) WaitRead(fd *netFD) error {
|
func (s *pollServer) WaitRead(fd *netFD) error {
|
||||||
err := s.AddFD(fd, 'r')
|
err := s.AddFD(fd, 'r')
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -331,6 +345,11 @@ func (fd *netFD) name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (fd *netFD) connect(ra syscall.Sockaddr) error {
|
func (fd *netFD) connect(ra syscall.Sockaddr) error {
|
||||||
|
fd.wio.Lock()
|
||||||
|
defer fd.wio.Unlock()
|
||||||
|
if err := fd.pollServer.PrepareWrite(fd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
err := syscall.Connect(fd.sysfd, ra)
|
err := syscall.Connect(fd.sysfd, ra)
|
||||||
if err == syscall.EINPROGRESS {
|
if err == syscall.EINPROGRESS {
|
||||||
if err = fd.pollServer.WaitWrite(fd); err != nil {
|
if err = fd.pollServer.WaitWrite(fd); err != nil {
|
||||||
@ -425,11 +444,10 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer fd.decref()
|
defer fd.decref()
|
||||||
|
if err := fd.pollServer.PrepareRead(fd); err != nil {
|
||||||
|
return 0, &OpError{"read", fd.net, fd.raddr, err}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
if fd.rdeadline.expired() {
|
|
||||||
err = errTimeout
|
|
||||||
break
|
|
||||||
}
|
|
||||||
n, err = syscall.Read(int(fd.sysfd), p)
|
n, err = syscall.Read(int(fd.sysfd), p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n = 0
|
n = 0
|
||||||
@ -455,11 +473,10 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
|
|||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
defer fd.decref()
|
defer fd.decref()
|
||||||
|
if err := fd.pollServer.PrepareRead(fd); err != nil {
|
||||||
|
return 0, nil, &OpError{"read", fd.net, fd.laddr, err}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
if fd.rdeadline.expired() {
|
|
||||||
err = errTimeout
|
|
||||||
break
|
|
||||||
}
|
|
||||||
n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
|
n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n = 0
|
n = 0
|
||||||
@ -485,11 +502,10 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
|
|||||||
return 0, 0, 0, nil, err
|
return 0, 0, 0, nil, err
|
||||||
}
|
}
|
||||||
defer fd.decref()
|
defer fd.decref()
|
||||||
|
if err := fd.pollServer.PrepareRead(fd); err != nil {
|
||||||
|
return 0, 0, 0, nil, &OpError{"read", fd.net, fd.laddr, err}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
if fd.rdeadline.expired() {
|
|
||||||
err = errTimeout
|
|
||||||
break
|
|
||||||
}
|
|
||||||
n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
|
n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(dfc) should n and oobn be set to 0
|
// TODO(dfc) should n and oobn be set to 0
|
||||||
@ -522,11 +538,10 @@ func (fd *netFD) Write(p []byte) (nn int, err error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer fd.decref()
|
defer fd.decref()
|
||||||
|
if err := fd.pollServer.PrepareWrite(fd); err != nil {
|
||||||
|
return 0, &OpError{"write", fd.net, fd.raddr, err}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
if fd.wdeadline.expired() {
|
|
||||||
err = errTimeout
|
|
||||||
break
|
|
||||||
}
|
|
||||||
var n int
|
var n int
|
||||||
n, err = syscall.Write(int(fd.sysfd), p[nn:])
|
n, err = syscall.Write(int(fd.sysfd), p[nn:])
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
@ -562,11 +577,10 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer fd.decref()
|
defer fd.decref()
|
||||||
|
if err := fd.pollServer.PrepareWrite(fd); err != nil {
|
||||||
|
return 0, &OpError{"write", fd.net, fd.raddr, err}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
if fd.wdeadline.expired() {
|
|
||||||
err = errTimeout
|
|
||||||
break
|
|
||||||
}
|
|
||||||
err = syscall.Sendto(fd.sysfd, p, 0, sa)
|
err = syscall.Sendto(fd.sysfd, p, 0, sa)
|
||||||
if err == syscall.EAGAIN {
|
if err == syscall.EAGAIN {
|
||||||
if err = fd.pollServer.WaitWrite(fd); err == nil {
|
if err = fd.pollServer.WaitWrite(fd); err == nil {
|
||||||
@ -590,11 +604,10 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
|
|||||||
return 0, 0, err
|
return 0, 0, err
|
||||||
}
|
}
|
||||||
defer fd.decref()
|
defer fd.decref()
|
||||||
|
if err := fd.pollServer.PrepareWrite(fd); err != nil {
|
||||||
|
return 0, 0, &OpError{"write", fd.net, fd.raddr, err}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
if fd.wdeadline.expired() {
|
|
||||||
err = errTimeout
|
|
||||||
break
|
|
||||||
}
|
|
||||||
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
|
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
|
||||||
if err == syscall.EAGAIN {
|
if err == syscall.EAGAIN {
|
||||||
if err = fd.pollServer.WaitWrite(fd); err == nil {
|
if err = fd.pollServer.WaitWrite(fd); err == nil {
|
||||||
@ -613,6 +626,8 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) {
|
func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) {
|
||||||
|
fd.rio.Lock()
|
||||||
|
defer fd.rio.Unlock()
|
||||||
if err := fd.incref(false); err != nil {
|
if err := fd.incref(false); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -620,6 +635,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
|
|||||||
|
|
||||||
var s int
|
var s int
|
||||||
var rsa syscall.Sockaddr
|
var rsa syscall.Sockaddr
|
||||||
|
if err = fd.pollServer.PrepareRead(fd); err != nil {
|
||||||
|
return nil, &OpError{"accept", fd.net, fd.laddr, err}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
s, rsa, err = accept(fd.sysfd)
|
s, rsa, err = accept(fd.sysfd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -532,7 +532,7 @@ func TestReadDeadlineDataAvailable(t *testing.T) {
|
|||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
servec := make(chan copyRes)
|
servec := make(chan copyRes)
|
||||||
const msg = "data client shouldn't read, even though it it'll be waiting"
|
const msg = "data client shouldn't read, even though it'll be waiting"
|
||||||
go func() {
|
go func() {
|
||||||
c, err := ln.Accept()
|
c, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -596,6 +596,64 @@ func TestWriteDeadlineBufferAvailable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAcceptDeadlineConnectionAvailable tests that accept deadlines work, even
|
||||||
|
// if there's incoming connections available.
|
||||||
|
func TestAcceptDeadlineConnectionAvailable(t *testing.T) {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "plan9":
|
||||||
|
t.Skipf("skipping test on %q", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln := newLocalListener(t).(*TCPListener)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c, err := Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
var buf [1]byte
|
||||||
|
c.Read(buf[:]) // block until the connection or listener is closed
|
||||||
|
}()
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
ln.SetDeadline(time.Now().Add(-5 * time.Second)) // in the past
|
||||||
|
c, err := ln.Accept()
|
||||||
|
if err == nil {
|
||||||
|
defer c.Close()
|
||||||
|
}
|
||||||
|
if !isTimeout(err) {
|
||||||
|
t.Fatalf("Accept: got %v; want timeout", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnectDeadlineInThePast tests that connect deadlines work, even
|
||||||
|
// if the connection can be established w/o blocking.
|
||||||
|
func TestConnectDeadlineInThePast(t *testing.T) {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "plan9":
|
||||||
|
t.Skipf("skipping test on %q", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln := newLocalListener(t).(*TCPListener)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c, err := ln.Accept()
|
||||||
|
if err == nil {
|
||||||
|
defer c.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
c, err := DialTimeout("tcp", ln.Addr().String(), -5*time.Second) // in the past
|
||||||
|
if err == nil {
|
||||||
|
defer c.Close()
|
||||||
|
}
|
||||||
|
if !isTimeout(err) {
|
||||||
|
t.Fatalf("DialTimeout: got %v; want timeout", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestProlongTimeout tests concurrent deadline modification.
|
// TestProlongTimeout tests concurrent deadline modification.
|
||||||
// Known to cause data races in the past.
|
// Known to cause data races in the past.
|
||||||
func TestProlongTimeout(t *testing.T) {
|
func TestProlongTimeout(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user