From 5fa3aeb14d56f6af4d6ad3cc9c81a20775770911 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 23 Nov 2012 22:15:26 -0800 Subject: [PATCH] net: check read and write deadlines before doing syscalls Otherwise a fast sender or receiver can make sockets always readable or writable, preventing deadline checks from ever occuring. Update #4191 (fixes it with other CL, coming separately) Fixes #4403 R=golang-dev, alex.brainman, dave, mikioh.mikioh CC=golang-dev https://golang.org/cl/6851096 --- src/pkg/net/fd_unix.go | 36 +++++++ src/pkg/net/timeout_test.go | 201 ++++++++++++++++++++++++++++++++++++ 2 files changed, 237 insertions(+) diff --git a/src/pkg/net/fd_unix.go b/src/pkg/net/fd_unix.go index d87c51ec663..16da53f0f55 100644 --- a/src/pkg/net/fd_unix.go +++ b/src/pkg/net/fd_unix.go @@ -423,6 +423,12 @@ func (fd *netFD) Read(p []byte) (n int, err error) { } defer fd.decref() for { + if fd.rdeadline > 0 { + if time.Now().UnixNano() >= fd.rdeadline { + err = errTimeout + break + } + } n, err = syscall.Read(int(fd.sysfd), p) if err == syscall.EAGAIN { err = errTimeout @@ -453,6 +459,12 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { } defer fd.decref() for { + if fd.rdeadline > 0 { + if time.Now().UnixNano() >= fd.rdeadline { + err = errTimeout + break + } + } n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) if err == syscall.EAGAIN { err = errTimeout @@ -481,6 +493,12 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S } defer fd.decref() for { + if fd.rdeadline > 0 { + if time.Now().UnixNano() >= fd.rdeadline { + err = errTimeout + break + } + } n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) if err == syscall.EAGAIN { err = errTimeout @@ -512,6 +530,12 @@ func (fd *netFD) Write(p []byte) (int, error) { var err error nn := 0 for { + if fd.wdeadline > 0 { + if time.Now().UnixNano() >= fd.wdeadline { + err = errTimeout + break + } + } var n int n, err = syscall.Write(int(fd.sysfd), p[nn:]) if n > 0 { @@ -551,6 +575,12 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { } defer fd.decref() for { + if fd.wdeadline > 0 { + if time.Now().UnixNano() >= fd.wdeadline { + err = errTimeout + break + } + } err = syscall.Sendto(fd.sysfd, p, 0, sa) if err == syscall.EAGAIN { err = errTimeout @@ -578,6 +608,12 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob } defer fd.decref() for { + if fd.wdeadline > 0 { + if time.Now().UnixNano() >= fd.wdeadline { + err = errTimeout + break + } + } err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) if err == syscall.EAGAIN { err = errTimeout diff --git a/src/pkg/net/timeout_test.go b/src/pkg/net/timeout_test.go index 68d8ced011a..b5b2fa28962 100644 --- a/src/pkg/net/timeout_test.go +++ b/src/pkg/net/timeout_test.go @@ -6,11 +6,24 @@ package net import ( "fmt" + "io" + "io/ioutil" "runtime" "testing" "time" ) +func isTimeout(err error) bool { + e, ok := err.(Error) + return ok && e.Timeout() +} + +type copyRes struct { + n int64 + err error + d time.Duration +} + func testTimeout(t *testing.T, net, addr string, readFrom bool) { c, err := Dial(net, addr) if err != nil { @@ -230,3 +243,191 @@ func TestReadWriteDeadline(t *testing.T) { <-quit <-lnquit } + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +func TestVariousDeadlines1Proc(t *testing.T) { + testVariousDeadlines(t, 1) +} + +func TestVariousDeadlines4Proc(t *testing.T) { + testVariousDeadlines(t, 4) +} + +func testVariousDeadlines(t *testing.T, maxProcs int) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ln := newLocalListener(t) + defer ln.Close() + donec := make(chan struct{}) + defer close(donec) + + testsDone := func() bool { + select { + case <-donec: + return true + } + return false + } + + // The server, with no timeouts of its own, sending bytes to clients + // as fast as it can. + servec := make(chan copyRes) + go func() { + for { + c, err := ln.Accept() + if err != nil { + if !testsDone() { + t.Fatalf("Accept: %v", err) + } + return + } + go func() { + t0 := time.Now() + n, err := io.Copy(c, neverEnding('a')) + d := time.Since(t0) + c.Close() + servec <- copyRes{n, err, d} + }() + } + }() + + for _, timeout := range []time.Duration{ + 1 * time.Nanosecond, + 2 * time.Nanosecond, + 5 * time.Nanosecond, + 50 * time.Nanosecond, + 100 * time.Nanosecond, + 200 * time.Nanosecond, + 500 * time.Nanosecond, + 750 * time.Nanosecond, + 1 * time.Microsecond, + 5 * time.Microsecond, + 25 * time.Microsecond, + 250 * time.Microsecond, + 500 * time.Microsecond, + 1 * time.Millisecond, + 5 * time.Millisecond, + 100 * time.Millisecond, + 250 * time.Millisecond, + 500 * time.Millisecond, + 1 * time.Second, + } { + numRuns := 3 + if testing.Short() { + numRuns = 1 + if timeout > 500*time.Microsecond { + continue + } + } + for run := 0; run < numRuns; run++ { + name := fmt.Sprintf("%v run %d/%d", timeout, run+1, numRuns) + t.Log(name) + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + clientc := make(chan copyRes) + go func() { + t0 := time.Now() + c.SetDeadline(t0.Add(timeout)) + n, err := io.Copy(ioutil.Discard, c) + d := time.Since(t0) + c.Close() + clientc <- copyRes{n, err, d} + }() + + const tooLong = 2000 * time.Millisecond + select { + case res := <-clientc: + if isTimeout(res.err) { + t.Logf("for %v, good client timeout after %v, reading %d bytes", name, res.d, res.n) + } else { + t.Fatalf("for %v: client Copy = %d, %v (want timeout)", name, res.n, res.err) + } + case <-time.After(tooLong): + t.Fatalf("for %v: timeout (%v) waiting for client to timeout (%v) reading", name, tooLong, timeout) + } + + select { + case res := <-servec: + t.Logf("for %v: server in %v wrote %d, %v", name, res.d, res.n, res.err) + case <-time.After(tooLong): + t.Fatalf("for %v, timeout waiting for server to finish writing", name) + } + } + } +} + +// TestReadDeadlineDataAvailable tests that read deadlines work, even +// if there's data ready to be read. +func TestReadDeadlineDataAvailable(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + servec := make(chan copyRes) + const msg = "data client shouldn't read, even though it it'll be waiting" + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + defer c.Close() + n, err := c.Write([]byte(msg)) + servec <- copyRes{n: int64(n), err: err} + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + if res := <-servec; res.err != nil || res.n != int64(len(msg)) { + t.Fatalf("unexpected server Write: n=%d, err=%d; want n=%d, err=nil", res.n, res.err, len(msg)) + } + c.SetReadDeadline(time.Now().Add(-5 * time.Second)) // in the psat. + buf := make([]byte, len(msg)/2) + n, err := c.Read(buf) + if n > 0 || !isTimeout(err) { + t.Fatalf("client read = %d (%q) err=%v; want 0, timeout", n, buf[:n], err) + } +} + +// TestWriteDeadlineBufferAvailable tests that write deadlines work, even +// if there's buffer space available to write. +func TestWriteDeadlineBufferAvailable(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + servec := make(chan copyRes) + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + defer c.Close() + c.SetWriteDeadline(time.Now().Add(-5 * time.Second)) // in the past + n, err := c.Write([]byte{'x'}) + servec <- copyRes{n: int64(n), err: err} + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + res := <-servec + if res.n != 0 { + t.Errorf("Write = %d; want 0", res.n) + } + if !isTimeout(res.err) { + t.Errorf("Write error = %v; want timeout", res.err) + } +}