diff --git a/src/internal/poll/sendfile_bsd.go b/src/internal/poll/sendfile_bsd.go index f82af4407d..bc665978c3 100644 --- a/src/internal/poll/sendfile_bsd.go +++ b/src/internal/poll/sendfile_bsd.go @@ -41,22 +41,25 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error, pos += int64(n) written += int64(n) remain -= int64(n) + continue + } else if err != syscall.EAGAIN && err != syscall.EINTR { + // This includes syscall.ENOSYS (no kernel + // support) and syscall.EINVAL (fd types which + // don't implement sendfile), and other errors. + // We should end the loop when there is no error + // returned from sendfile(2) or it is not a retryable error. + break } if err == syscall.EINTR { continue } - // This includes syscall.ENOSYS (no kernel - // support) and syscall.EINVAL (fd types which - // don't implement sendfile), and other errors. - // We should end the loop when there is no error - // returned from sendfile(2) or it is not a retryable error. - if err != syscall.EAGAIN { - break - } if err = dstFD.pd.waitWrite(dstFD.isFile); err != nil { break } } + if err == syscall.EAGAIN { + err = nil + } handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL) return } diff --git a/src/internal/poll/sendfile_linux.go b/src/internal/poll/sendfile_linux.go index 1b618a72a2..7e800a3b7e 100644 --- a/src/internal/poll/sendfile_linux.go +++ b/src/internal/poll/sendfile_linux.go @@ -53,6 +53,9 @@ func SendFile(dstFD *FD, src int, remain int64) (written int64, err error, handl break } } + if err == syscall.EAGAIN { + err = nil + } handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL) return } diff --git a/src/internal/poll/sendfile_solaris.go b/src/internal/poll/sendfile_solaris.go index e6f2d908a1..113214ff39 100644 --- a/src/internal/poll/sendfile_solaris.go +++ b/src/internal/poll/sendfile_solaris.go @@ -71,6 +71,9 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error, break } } + if err == syscall.EAGAIN { + err = nil + } handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL) return } diff --git a/src/os/copy_test.go b/src/os/copy_test.go new file mode 100644 index 0000000000..82346ca4e5 --- /dev/null +++ b/src/os/copy_test.go @@ -0,0 +1,154 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package os_test + +import ( + "bytes" + "errors" + "io" + "math/rand/v2" + "net" + "os" + "runtime" + "sync" + "testing" + + "golang.org/x/net/nettest" +) + +// Exercise sendfile/splice fast paths with a moderately large file. +// +// https://go.dev/issue/70000 + +func TestLargeCopyViaNetwork(t *testing.T) { + const size = 10 * 1024 * 1024 + dir := t.TempDir() + + src, err := os.Create(dir + "/src") + if err != nil { + t.Fatal(err) + } + defer src.Close() + if _, err := io.CopyN(src, newRandReader(), size); err != nil { + t.Fatal(err) + } + if _, err := src.Seek(0, 0); err != nil { + t.Fatal(err) + } + + dst, err := os.Create(dir + "/dst") + if err != nil { + t.Fatal(err) + } + defer dst.Close() + + client, server := createSocketPair(t, "tcp") + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + if n, err := io.Copy(dst, server); n != size || err != nil { + t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size) + } + }() + go func() { + defer wg.Done() + defer client.Close() + if n, err := io.Copy(client, src); n != size || err != nil { + t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size) + } + }() + wg.Wait() + + if _, err := dst.Seek(0, 0); err != nil { + t.Fatal(err) + } + if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil { + t.Fatal(err) + } +} + +func compareReaders(a, b io.Reader) error { + bufa := make([]byte, 4096) + bufb := make([]byte, 4096) + for { + na, erra := io.ReadFull(a, bufa) + if erra != nil && erra != io.EOF { + return erra + } + nb, errb := io.ReadFull(b, bufb) + if errb != nil && errb != io.EOF { + return errb + } + if !bytes.Equal(bufa[:na], bufb[:nb]) { + return errors.New("contents mismatch") + } + if erra == io.EOF && errb == io.EOF { + break + } + } + return nil +} + +type randReader struct { + rand *rand.Rand +} + +func newRandReader() *randReader { + return &randReader{rand.New(rand.NewPCG(0, 0))} +} + +func (r *randReader) Read(p []byte) (int, error) { + var v uint64 + var n int + for i := range p { + if n == 0 { + v = r.rand.Uint64() + n = 8 + } + p[i] = byte(v & 0xff) + v >>= 8 + n-- + } + return len(p), nil +} + +func createSocketPair(t *testing.T, proto string) (client, server net.Conn) { + t.Helper() + if !nettest.TestableNetwork(proto) { + t.Skipf("%s does not support %q", runtime.GOOS, proto) + } + + ln, err := nettest.NewLocalListener(proto) + if err != nil { + t.Fatalf("NewLocalListener error: %v", err) + } + t.Cleanup(func() { + if ln != nil { + ln.Close() + } + if client != nil { + client.Close() + } + if server != nil { + server.Close() + } + }) + ch := make(chan struct{}) + go func() { + var err error + server, err = ln.Accept() + if err != nil { + t.Errorf("Accept new connection error: %v", err) + } + ch <- struct{}{} + }() + client, err = net.Dial(proto, ln.Addr().String()) + <-ch + if err != nil { + t.Fatalf("Dial new connection error: %v", err) + } + return client, server +} diff --git a/src/os/readfrom_linux_test.go b/src/os/readfrom_linux_test.go index 2dcc7f0cf8..cc0322882b 100644 --- a/src/os/readfrom_linux_test.go +++ b/src/os/readfrom_linux_test.go @@ -14,14 +14,11 @@ import ( "net" . "os" "path/filepath" - "runtime" "strconv" "sync" "syscall" "testing" "time" - - "golang.org/x/net/nettest" ) func TestSpliceFile(t *testing.T) { @@ -479,41 +476,3 @@ func testGetPollFDAndNetwork(t *testing.T, proto string) { t.Fatalf("server Control error: %v", err) } } - -func createSocketPair(t *testing.T, proto string) (client, server net.Conn) { - t.Helper() - if !nettest.TestableNetwork(proto) { - t.Skipf("%s does not support %q", runtime.GOOS, proto) - } - - ln, err := nettest.NewLocalListener(proto) - if err != nil { - t.Fatalf("NewLocalListener error: %v", err) - } - t.Cleanup(func() { - if ln != nil { - ln.Close() - } - if client != nil { - client.Close() - } - if server != nil { - server.Close() - } - }) - ch := make(chan struct{}) - go func() { - var err error - server, err = ln.Accept() - if err != nil { - t.Errorf("Accept new connection error: %v", err) - } - ch <- struct{}{} - }() - client, err = net.Dial(proto, ln.Addr().String()) - <-ch - if err != nil { - t.Fatalf("Dial new connection error: %v", err) - } - return client, server -}