From bd388c0216bcb33d7325b0ad9722a3be8155a289 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 23 Oct 2024 16:01:08 -0700 Subject: [PATCH] internal/poll: keep copying after successful Sendfile return on BSD The BSD implementation of poll.SendFile incorrectly halted copying after succesfully writing one full chunk of data. Adjust the copy loop to match the Linux and Solaris implementations. In testing, empirically macOS appears to sometimes return EAGAIN from sendfile after successfully copying a full chunk. Add a check to all implementations to return nil after successfully copying all data if the last sendfile call returns EAGAIN. For #70000 Change-Id: I57ba649491fc078c7330310b23e1cfd85135c8ff Reviewed-on: https://go-review.googlesource.com/c/go/+/622235 LUCI-TryBot-Result: Go LUCI Reviewed-by: Ian Lance Taylor --- src/internal/poll/sendfile_bsd.go | 19 ++-- src/internal/poll/sendfile_linux.go | 3 + src/internal/poll/sendfile_solaris.go | 3 + src/os/copy_test.go | 154 ++++++++++++++++++++++++++ src/os/readfrom_linux_test.go | 41 ------- 5 files changed, 171 insertions(+), 49 deletions(-) create mode 100644 src/os/copy_test.go diff --git a/src/internal/poll/sendfile_bsd.go b/src/internal/poll/sendfile_bsd.go index f82af4407d..bc665978c3 100644 --- a/src/internal/poll/sendfile_bsd.go +++ b/src/internal/poll/sendfile_bsd.go @@ -41,22 +41,25 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error, pos += int64(n) written += int64(n) remain -= int64(n) + continue + } else if err != syscall.EAGAIN && err != syscall.EINTR { + // This includes syscall.ENOSYS (no kernel + // support) and syscall.EINVAL (fd types which + // don't implement sendfile), and other errors. + // We should end the loop when there is no error + // returned from sendfile(2) or it is not a retryable error. + break } if err == syscall.EINTR { continue } - // This includes syscall.ENOSYS (no kernel - // support) and syscall.EINVAL (fd types which - // don't implement sendfile), and other errors. - // We should end the loop when there is no error - // returned from sendfile(2) or it is not a retryable error. - if err != syscall.EAGAIN { - break - } if err = dstFD.pd.waitWrite(dstFD.isFile); err != nil { break } } + if err == syscall.EAGAIN { + err = nil + } handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL) return } diff --git a/src/internal/poll/sendfile_linux.go b/src/internal/poll/sendfile_linux.go index 1b618a72a2..7e800a3b7e 100644 --- a/src/internal/poll/sendfile_linux.go +++ b/src/internal/poll/sendfile_linux.go @@ -53,6 +53,9 @@ func SendFile(dstFD *FD, src int, remain int64) (written int64, err error, handl break } } + if err == syscall.EAGAIN { + err = nil + } handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL) return } diff --git a/src/internal/poll/sendfile_solaris.go b/src/internal/poll/sendfile_solaris.go index e6f2d908a1..113214ff39 100644 --- a/src/internal/poll/sendfile_solaris.go +++ b/src/internal/poll/sendfile_solaris.go @@ -71,6 +71,9 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error, break } } + if err == syscall.EAGAIN { + err = nil + } handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL) return } diff --git a/src/os/copy_test.go b/src/os/copy_test.go new file mode 100644 index 0000000000..82346ca4e5 --- /dev/null +++ b/src/os/copy_test.go @@ -0,0 +1,154 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package os_test + +import ( + "bytes" + "errors" + "io" + "math/rand/v2" + "net" + "os" + "runtime" + "sync" + "testing" + + "golang.org/x/net/nettest" +) + +// Exercise sendfile/splice fast paths with a moderately large file. +// +// https://go.dev/issue/70000 + +func TestLargeCopyViaNetwork(t *testing.T) { + const size = 10 * 1024 * 1024 + dir := t.TempDir() + + src, err := os.Create(dir + "/src") + if err != nil { + t.Fatal(err) + } + defer src.Close() + if _, err := io.CopyN(src, newRandReader(), size); err != nil { + t.Fatal(err) + } + if _, err := src.Seek(0, 0); err != nil { + t.Fatal(err) + } + + dst, err := os.Create(dir + "/dst") + if err != nil { + t.Fatal(err) + } + defer dst.Close() + + client, server := createSocketPair(t, "tcp") + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + if n, err := io.Copy(dst, server); n != size || err != nil { + t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size) + } + }() + go func() { + defer wg.Done() + defer client.Close() + if n, err := io.Copy(client, src); n != size || err != nil { + t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size) + } + }() + wg.Wait() + + if _, err := dst.Seek(0, 0); err != nil { + t.Fatal(err) + } + if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil { + t.Fatal(err) + } +} + +func compareReaders(a, b io.Reader) error { + bufa := make([]byte, 4096) + bufb := make([]byte, 4096) + for { + na, erra := io.ReadFull(a, bufa) + if erra != nil && erra != io.EOF { + return erra + } + nb, errb := io.ReadFull(b, bufb) + if errb != nil && errb != io.EOF { + return errb + } + if !bytes.Equal(bufa[:na], bufb[:nb]) { + return errors.New("contents mismatch") + } + if erra == io.EOF && errb == io.EOF { + break + } + } + return nil +} + +type randReader struct { + rand *rand.Rand +} + +func newRandReader() *randReader { + return &randReader{rand.New(rand.NewPCG(0, 0))} +} + +func (r *randReader) Read(p []byte) (int, error) { + var v uint64 + var n int + for i := range p { + if n == 0 { + v = r.rand.Uint64() + n = 8 + } + p[i] = byte(v & 0xff) + v >>= 8 + n-- + } + return len(p), nil +} + +func createSocketPair(t *testing.T, proto string) (client, server net.Conn) { + t.Helper() + if !nettest.TestableNetwork(proto) { + t.Skipf("%s does not support %q", runtime.GOOS, proto) + } + + ln, err := nettest.NewLocalListener(proto) + if err != nil { + t.Fatalf("NewLocalListener error: %v", err) + } + t.Cleanup(func() { + if ln != nil { + ln.Close() + } + if client != nil { + client.Close() + } + if server != nil { + server.Close() + } + }) + ch := make(chan struct{}) + go func() { + var err error + server, err = ln.Accept() + if err != nil { + t.Errorf("Accept new connection error: %v", err) + } + ch <- struct{}{} + }() + client, err = net.Dial(proto, ln.Addr().String()) + <-ch + if err != nil { + t.Fatalf("Dial new connection error: %v", err) + } + return client, server +} diff --git a/src/os/readfrom_linux_test.go b/src/os/readfrom_linux_test.go index 2dcc7f0cf8..cc0322882b 100644 --- a/src/os/readfrom_linux_test.go +++ b/src/os/readfrom_linux_test.go @@ -14,14 +14,11 @@ import ( "net" . "os" "path/filepath" - "runtime" "strconv" "sync" "syscall" "testing" "time" - - "golang.org/x/net/nettest" ) func TestSpliceFile(t *testing.T) { @@ -479,41 +476,3 @@ func testGetPollFDAndNetwork(t *testing.T, proto string) { t.Fatalf("server Control error: %v", err) } } - -func createSocketPair(t *testing.T, proto string) (client, server net.Conn) { - t.Helper() - if !nettest.TestableNetwork(proto) { - t.Skipf("%s does not support %q", runtime.GOOS, proto) - } - - ln, err := nettest.NewLocalListener(proto) - if err != nil { - t.Fatalf("NewLocalListener error: %v", err) - } - t.Cleanup(func() { - if ln != nil { - ln.Close() - } - if client != nil { - client.Close() - } - if server != nil { - server.Close() - } - }) - ch := make(chan struct{}) - go func() { - var err error - server, err = ln.Accept() - if err != nil { - t.Errorf("Accept new connection error: %v", err) - } - ch <- struct{}{} - }() - client, err = net.Dial(proto, ln.Addr().String()) - <-ch - if err != nil { - t.Fatalf("Dial new connection error: %v", err) - } - return client, server -}