diff --git a/src/os/readfrom_linux_test.go b/src/os/readfrom_linux_test.go index cb6a59abdb..982a2b6330 100644 --- a/src/os/readfrom_linux_test.go +++ b/src/os/readfrom_linux_test.go @@ -13,6 +13,7 @@ import ( . "os" "path/filepath" "strconv" + "strings" "syscall" "testing" "time" @@ -74,6 +75,56 @@ func TestCopyFileRange(t *testing.T) { mustSeekStart(t, dst2) mustContainData(t, dst2, data) // through traditional means }) + t.Run("CopyFileItself", func(t *testing.T) { + hook := hookCopyFileRange(t) + + f, err := os.CreateTemp("", "file-readfrom-itself-test") + if err != nil { + t.Fatalf("failed to create tmp file: %v", err) + } + t.Cleanup(func() { + f.Close() + os.Remove(f.Name()) + }) + + data := []byte("hello world!") + if _, err := f.Write(data); err != nil { + t.Fatalf("failed to create and feed the file: %v", err) + } + + if err := f.Sync(); err != nil { + t.Fatalf("failed to save the file: %v", err) + } + + // Rewind it. + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatalf("failed to rewind the file: %v", err) + } + + // Read data from the file itself. + if _, err := io.Copy(f, f); err != nil { + t.Fatalf("failed to read from the file: %v", err) + } + + if !hook.called || hook.written != 0 || hook.handled || hook.err != nil { + t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err) + } + + // Rewind it. + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatalf("failed to rewind the file: %v", err) + } + + data2, err := io.ReadAll(f) + if err != nil { + t.Fatalf("failed to read from the file: %v", err) + } + + // It should wind up a double of the original data. + if strings.Repeat(string(data), 2) != string(data2) { + t.Fatalf("data mismatch: %s != %s", string(data), string(data2)) + } + }) t.Run("NotRegular", func(t *testing.T) { t.Run("BothPipes", func(t *testing.T) { hook := hookCopyFileRange(t) @@ -344,6 +395,10 @@ type copyFileRangeHook struct { srcfd int remain int64 + written int64 + handled bool + err error + original func(dst, src *poll.FD, remain int64) (int64, bool, error) } @@ -354,7 +409,8 @@ func (h *copyFileRangeHook) install() { h.dstfd = dst.Sysfd h.srcfd = src.Sysfd h.remain = remain - return h.original(dst, src, remain) + h.written, h.handled, h.err = h.original(dst, src, remain) + return h.written, h.handled, h.err } }