diff --git a/src/net/fd_wasip1.go b/src/net/fd_wasip1.go index 3f64ff4683c..a3584e82bd0 100644 --- a/src/net/fd_wasip1.go +++ b/src/net/fd_wasip1.go @@ -38,22 +38,21 @@ type netFD struct { *fakeNetFD } -func newFD(sysfd int) (*netFD, error) { - return newPollFD(poll.FD{ +func newFD(net string, sysfd int) *netFD { + return newPollFD(net, poll.FD{ Sysfd: sysfd, IsStream: true, ZeroReadIsEOF: true, }) } -func newPollFD(pfd poll.FD) (*netFD, error) { - ret := &netFD{ +func newPollFD(net string, pfd poll.FD) *netFD { + return &netFD{ pfd: pfd, - net: "tcp", + net: net, laddr: unknownAddr{}, raddr: unknownAddr{}, } - return ret, nil } func (fd *netFD) init() error { @@ -75,10 +74,7 @@ func (fd *netFD) accept() (netfd *netFD, err error) { } return nil, err } - if netfd, err = newFD(d); err != nil { - poll.CloseFunc(d) - return nil, err - } + netfd = newFD("tcp", d) if err = netfd.init(); err != nil { netfd.Close() return nil, err diff --git a/src/net/file_wasip1.go b/src/net/file_wasip1.go index 95fd5403a6b..a3624efb55e 100644 --- a/src/net/file_wasip1.go +++ b/src/net/file_wasip1.go @@ -13,45 +13,85 @@ import ( ) func fileListener(f *os.File) (Listener, error) { - fd, err := newFileFD(f) + filetype, err := fd_fdstat_get_type(f.PollFD().Sysfd) if err != nil { return nil, err } - return &TCPListener{fd: fd}, nil + net, err := fileListenNet(filetype) + if err != nil { + return nil, err + } + pfd := f.PollFD().Copy() + fd := newPollFD(net, pfd) + if err := fd.init(); err != nil { + pfd.Close() + return nil, err + } + return newFileListener(fd), nil } func fileConn(f *os.File) (Conn, error) { - fd, err := newFileFD(f) + filetype, err := fd_fdstat_get_type(f.PollFD().Sysfd) if err != nil { return nil, err } - return &TCPConn{conn{fd: fd}}, nil + net, err := fileConnNet(filetype) + if err != nil { + return nil, err + } + pfd := f.PollFD().Copy() + fd := newPollFD(net, pfd) + if err := fd.init(); err != nil { + pfd.Close() + return nil, err + } + return newFileConn(fd), nil } -func filePacketConn(f *os.File) (PacketConn, error) { return nil, syscall.ENOPROTOOPT } +func filePacketConn(f *os.File) (PacketConn, error) { + return nil, syscall.ENOPROTOOPT +} -func newFileFD(f *os.File) (fd *netFD, err error) { - pfd := f.PollFD().Copy() - defer func() { - if err != nil { - pfd.Close() - } - }() - filetype, err := fd_fdstat_get_type(pfd.Sysfd) - if err != nil { - return nil, err +func fileListenNet(filetype syscall.Filetype) (string, error) { + switch filetype { + case syscall.FILETYPE_SOCKET_STREAM: + return "tcp", nil + case syscall.FILETYPE_SOCKET_DGRAM: + return "", syscall.EOPNOTSUPP + default: + return "", syscall.ENOTSOCK } - if filetype != syscall.FILETYPE_SOCKET_STREAM { - return nil, syscall.ENOTSOCK +} + +func fileConnNet(filetype syscall.Filetype) (string, error) { + switch filetype { + case syscall.FILETYPE_SOCKET_STREAM: + return "tcp", nil + case syscall.FILETYPE_SOCKET_DGRAM: + return "udp", nil + default: + return "", syscall.ENOTSOCK } - fd, err = newPollFD(pfd) - if err != nil { - return nil, err +} + +func newFileListener(fd *netFD) Listener { + switch fd.net { + case "tcp": + return &TCPListener{fd: fd} + default: + panic("unsupported network for file listener: " + fd.net) } - if err := fd.init(); err != nil { - return nil, err +} + +func newFileConn(fd *netFD) Conn { + switch fd.net { + case "tcp": + return &TCPConn{conn{fd: fd}} + case "udp": + return &UDPConn{conn{fd: fd}} + default: + panic("unsupported network for file connection: " + fd.net) } - return fd, nil } // This helper is implemented in the syscall package. It means we don't have diff --git a/src/net/file_wasip1_test.go b/src/net/file_wasip1_test.go new file mode 100644 index 00000000000..137574090fd --- /dev/null +++ b/src/net/file_wasip1_test.go @@ -0,0 +1,92 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build wasip1 + +package net + +import ( + "syscall" + "testing" +) + +// The tests in this file intend to validate the ability for net.FileConn and +// net.FileListener to handle both TCP and UDP sockets. Ideally we would test +// the public interface by constructing an *os.File from a file descriptor +// opened on a socket, but the WASI preview 1 specification is too limited to +// support this approach for UDP sockets. Instead, we test the internals that +// make it possible for WASI host runtimes and guest programs to integrate +// socket extensions with the net package using net.FileConn/net.FileListener. +// +// Note that the creation of net.Conn and net.Listener values for TCP sockets +// has an end-to-end test in src/runtime/internal/wasitest, here we are only +// verifying the code paths specific to UDP, and error handling for invalid use +// of the functions. + +func TestWasip1FileConnNet(t *testing.T) { + tests := []struct { + filetype syscall.Filetype + network string + error error + }{ + {syscall.FILETYPE_SOCKET_STREAM, "tcp", nil}, + {syscall.FILETYPE_SOCKET_DGRAM, "udp", nil}, + {syscall.FILETYPE_BLOCK_DEVICE, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_CHARACTER_DEVICE, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_DIRECTORY, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_REGULAR_FILE, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_SYMBOLIC_LINK, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_UNKNOWN, "", syscall.ENOTSOCK}, + } + for _, test := range tests { + net, err := fileConnNet(test.filetype) + if net != test.network { + t.Errorf("fileConnNet: network mismatch: want=%q got=%q", test.network, net) + } + if err != test.error { + t.Errorf("fileConnNet: error mismatch: want=%v got=%v", test.error, err) + } + } +} + +func TestWasip1FileListenNet(t *testing.T) { + tests := []struct { + filetype syscall.Filetype + network string + error error + }{ + {syscall.FILETYPE_SOCKET_STREAM, "tcp", nil}, + {syscall.FILETYPE_SOCKET_DGRAM, "", syscall.EOPNOTSUPP}, + {syscall.FILETYPE_BLOCK_DEVICE, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_CHARACTER_DEVICE, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_DIRECTORY, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_REGULAR_FILE, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_SYMBOLIC_LINK, "", syscall.ENOTSOCK}, + {syscall.FILETYPE_UNKNOWN, "", syscall.ENOTSOCK}, + } + for _, test := range tests { + net, err := fileListenNet(test.filetype) + if net != test.network { + t.Errorf("fileListenNet: network mismatch: want=%q got=%q", test.network, net) + } + if err != test.error { + t.Errorf("fileListenNet: error mismatch: want=%v got=%v", test.error, err) + } + } +} + +func TestWasip1NewFileListener(t *testing.T) { + if l, ok := newFileListener(newFD("tcp", -1)).(*TCPListener); !ok { + t.Errorf("newFileListener: tcp listener type mismatch: %T", l) + } +} + +func TestWasip1NewFileConn(t *testing.T) { + if c, ok := newFileConn(newFD("tcp", -1)).(*TCPConn); !ok { + t.Errorf("newFileConn: tcp conn type mismatch: %T", c) + } + if c, ok := newFileConn(newFD("udp", -1)).(*UDPConn); !ok { + t.Errorf("newFileConn: udp conn type mismatch: %T", c) + } +}