mirror of
https://github.com/golang/go
synced 2024-11-27 02:31:18 -07:00
net: check read and write deadlines before doing syscalls
Otherwise a fast sender or receiver can make sockets always readable or writable, preventing deadline checks from ever occuring. Update #4191 (fixes it with other CL, coming separately) Fixes #4403 R=golang-dev, alex.brainman, dave, mikioh.mikioh CC=golang-dev https://golang.org/cl/6851096
This commit is contained in:
parent
314fd62434
commit
5fa3aeb14d
@ -423,6 +423,12 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
defer fd.decref()
|
||||
for {
|
||||
if fd.rdeadline > 0 {
|
||||
if time.Now().UnixNano() >= fd.rdeadline {
|
||||
err = errTimeout
|
||||
break
|
||||
}
|
||||
}
|
||||
n, err = syscall.Read(int(fd.sysfd), p)
|
||||
if err == syscall.EAGAIN {
|
||||
err = errTimeout
|
||||
@ -453,6 +459,12 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
|
||||
}
|
||||
defer fd.decref()
|
||||
for {
|
||||
if fd.rdeadline > 0 {
|
||||
if time.Now().UnixNano() >= fd.rdeadline {
|
||||
err = errTimeout
|
||||
break
|
||||
}
|
||||
}
|
||||
n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
|
||||
if err == syscall.EAGAIN {
|
||||
err = errTimeout
|
||||
@ -481,6 +493,12 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
|
||||
}
|
||||
defer fd.decref()
|
||||
for {
|
||||
if fd.rdeadline > 0 {
|
||||
if time.Now().UnixNano() >= fd.rdeadline {
|
||||
err = errTimeout
|
||||
break
|
||||
}
|
||||
}
|
||||
n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
|
||||
if err == syscall.EAGAIN {
|
||||
err = errTimeout
|
||||
@ -512,6 +530,12 @@ func (fd *netFD) Write(p []byte) (int, error) {
|
||||
var err error
|
||||
nn := 0
|
||||
for {
|
||||
if fd.wdeadline > 0 {
|
||||
if time.Now().UnixNano() >= fd.wdeadline {
|
||||
err = errTimeout
|
||||
break
|
||||
}
|
||||
}
|
||||
var n int
|
||||
n, err = syscall.Write(int(fd.sysfd), p[nn:])
|
||||
if n > 0 {
|
||||
@ -551,6 +575,12 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
|
||||
}
|
||||
defer fd.decref()
|
||||
for {
|
||||
if fd.wdeadline > 0 {
|
||||
if time.Now().UnixNano() >= fd.wdeadline {
|
||||
err = errTimeout
|
||||
break
|
||||
}
|
||||
}
|
||||
err = syscall.Sendto(fd.sysfd, p, 0, sa)
|
||||
if err == syscall.EAGAIN {
|
||||
err = errTimeout
|
||||
@ -578,6 +608,12 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
|
||||
}
|
||||
defer fd.decref()
|
||||
for {
|
||||
if fd.wdeadline > 0 {
|
||||
if time.Now().UnixNano() >= fd.wdeadline {
|
||||
err = errTimeout
|
||||
break
|
||||
}
|
||||
}
|
||||
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
|
||||
if err == syscall.EAGAIN {
|
||||
err = errTimeout
|
||||
|
@ -6,11 +6,24 @@ package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
e, ok := err.(Error)
|
||||
return ok && e.Timeout()
|
||||
}
|
||||
|
||||
type copyRes struct {
|
||||
n int64
|
||||
err error
|
||||
d time.Duration
|
||||
}
|
||||
|
||||
func testTimeout(t *testing.T, net, addr string, readFrom bool) {
|
||||
c, err := Dial(net, addr)
|
||||
if err != nil {
|
||||
@ -230,3 +243,191 @@ func TestReadWriteDeadline(t *testing.T) {
|
||||
<-quit
|
||||
<-lnquit
|
||||
}
|
||||
|
||||
type neverEnding byte
|
||||
|
||||
func (b neverEnding) Read(p []byte) (n int, err error) {
|
||||
for i := range p {
|
||||
p[i] = byte(b)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestVariousDeadlines1Proc(t *testing.T) {
|
||||
testVariousDeadlines(t, 1)
|
||||
}
|
||||
|
||||
func TestVariousDeadlines4Proc(t *testing.T) {
|
||||
testVariousDeadlines(t, 4)
|
||||
}
|
||||
|
||||
func testVariousDeadlines(t *testing.T, maxProcs int) {
|
||||
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
|
||||
ln := newLocalListener(t)
|
||||
defer ln.Close()
|
||||
donec := make(chan struct{})
|
||||
defer close(donec)
|
||||
|
||||
testsDone := func() bool {
|
||||
select {
|
||||
case <-donec:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// The server, with no timeouts of its own, sending bytes to clients
|
||||
// as fast as it can.
|
||||
servec := make(chan copyRes)
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
if !testsDone() {
|
||||
t.Fatalf("Accept: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
t0 := time.Now()
|
||||
n, err := io.Copy(c, neverEnding('a'))
|
||||
d := time.Since(t0)
|
||||
c.Close()
|
||||
servec <- copyRes{n, err, d}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, timeout := range []time.Duration{
|
||||
1 * time.Nanosecond,
|
||||
2 * time.Nanosecond,
|
||||
5 * time.Nanosecond,
|
||||
50 * time.Nanosecond,
|
||||
100 * time.Nanosecond,
|
||||
200 * time.Nanosecond,
|
||||
500 * time.Nanosecond,
|
||||
750 * time.Nanosecond,
|
||||
1 * time.Microsecond,
|
||||
5 * time.Microsecond,
|
||||
25 * time.Microsecond,
|
||||
250 * time.Microsecond,
|
||||
500 * time.Microsecond,
|
||||
1 * time.Millisecond,
|
||||
5 * time.Millisecond,
|
||||
100 * time.Millisecond,
|
||||
250 * time.Millisecond,
|
||||
500 * time.Millisecond,
|
||||
1 * time.Second,
|
||||
} {
|
||||
numRuns := 3
|
||||
if testing.Short() {
|
||||
numRuns = 1
|
||||
if timeout > 500*time.Microsecond {
|
||||
continue
|
||||
}
|
||||
}
|
||||
for run := 0; run < numRuns; run++ {
|
||||
name := fmt.Sprintf("%v run %d/%d", timeout, run+1, numRuns)
|
||||
t.Log(name)
|
||||
|
||||
c, err := Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
clientc := make(chan copyRes)
|
||||
go func() {
|
||||
t0 := time.Now()
|
||||
c.SetDeadline(t0.Add(timeout))
|
||||
n, err := io.Copy(ioutil.Discard, c)
|
||||
d := time.Since(t0)
|
||||
c.Close()
|
||||
clientc <- copyRes{n, err, d}
|
||||
}()
|
||||
|
||||
const tooLong = 2000 * time.Millisecond
|
||||
select {
|
||||
case res := <-clientc:
|
||||
if isTimeout(res.err) {
|
||||
t.Logf("for %v, good client timeout after %v, reading %d bytes", name, res.d, res.n)
|
||||
} else {
|
||||
t.Fatalf("for %v: client Copy = %d, %v (want timeout)", name, res.n, res.err)
|
||||
}
|
||||
case <-time.After(tooLong):
|
||||
t.Fatalf("for %v: timeout (%v) waiting for client to timeout (%v) reading", name, tooLong, timeout)
|
||||
}
|
||||
|
||||
select {
|
||||
case res := <-servec:
|
||||
t.Logf("for %v: server in %v wrote %d, %v", name, res.d, res.n, res.err)
|
||||
case <-time.After(tooLong):
|
||||
t.Fatalf("for %v, timeout waiting for server to finish writing", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadDeadlineDataAvailable tests that read deadlines work, even
|
||||
// if there's data ready to be read.
|
||||
func TestReadDeadlineDataAvailable(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer ln.Close()
|
||||
|
||||
servec := make(chan copyRes)
|
||||
const msg = "data client shouldn't read, even though it it'll be waiting"
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("Accept: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
n, err := c.Write([]byte(msg))
|
||||
servec <- copyRes{n: int64(n), err: err}
|
||||
}()
|
||||
|
||||
c, err := Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
if res := <-servec; res.err != nil || res.n != int64(len(msg)) {
|
||||
t.Fatalf("unexpected server Write: n=%d, err=%d; want n=%d, err=nil", res.n, res.err, len(msg))
|
||||
}
|
||||
c.SetReadDeadline(time.Now().Add(-5 * time.Second)) // in the psat.
|
||||
buf := make([]byte, len(msg)/2)
|
||||
n, err := c.Read(buf)
|
||||
if n > 0 || !isTimeout(err) {
|
||||
t.Fatalf("client read = %d (%q) err=%v; want 0, timeout", n, buf[:n], err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteDeadlineBufferAvailable tests that write deadlines work, even
|
||||
// if there's buffer space available to write.
|
||||
func TestWriteDeadlineBufferAvailable(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer ln.Close()
|
||||
|
||||
servec := make(chan copyRes)
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("Accept: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
c.SetWriteDeadline(time.Now().Add(-5 * time.Second)) // in the past
|
||||
n, err := c.Write([]byte{'x'})
|
||||
servec <- copyRes{n: int64(n), err: err}
|
||||
}()
|
||||
|
||||
c, err := Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
res := <-servec
|
||||
if res.n != 0 {
|
||||
t.Errorf("Write = %d; want 0", res.n)
|
||||
}
|
||||
if !isTimeout(res.err) {
|
||||
t.Errorf("Write error = %v; want timeout", res.err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user