From c3db64c0f45e8f2d75c5b59401e0fc925701b6f4 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Sat, 10 Jun 2023 11:31:43 -0700 Subject: [PATCH] net: fix panic when calling net.Listen or net.Dial on wasip1 Address a panic that was caused by net.Dial/net.Listen entering the fake network stack and assuming that the addresses would be of type *TCPAddr, where in fact they could have been *UDPAddr or *UnixAddr as well. The fix consist in implementing the fake network facility for udp and unix addresses, preventing the assumed type assertion to TCPAddr from triggering a panic. New tests are added to verify that using the fake network from the exported functions of the net package satisfies the minimal requirement of being able to create a listener and establish a connection for all the supported network types. Fixes #60012 Fixes #60739 Change-Id: I2688f1a0a7c6c9894ad3d137a5d311192c77a9b2 Reviewed-on: https://go-review.googlesource.com/c/go/+/502315 Reviewed-by: Dmitri Shuralyov TryBot-Result: Gopher Robot Run-TryBot: Dmitri Shuralyov Reviewed-by: Johan Brandhorst-Satzkorn Reviewed-by: Ian Lance Taylor Auto-Submit: Dmitri Shuralyov --- src/net/net_fake.go | 134 +++++++++++++++++++++----- src/net/net_fake_test.go | 203 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 312 insertions(+), 25 deletions(-) create mode 100644 src/net/net_fake_test.go diff --git a/src/net/net_fake.go b/src/net/net_fake.go index a816213f8de..68d36966ca2 100644 --- a/src/net/net_fake.go +++ b/src/net/net_fake.go @@ -20,7 +20,7 @@ import ( ) var listenersMu sync.Mutex -var listeners = make(map[string]*netFD) +var listeners = make(map[fakeNetAddr]*netFD) var portCounterMu sync.Mutex var portCounter = 0 @@ -32,13 +32,16 @@ func nextPort() int { return portCounter } +type fakeNetAddr struct { + network string + address string +} + type fakeNetFD struct { - listener bool - laddr Addr + listener fakeNetAddr r *bufferedPipe w *bufferedPipe incoming chan *netFD - closedMu sync.Mutex closed bool } @@ -51,32 +54,110 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only return fakelistener(fd, laddr) } fd2 := &netFD{family: family, sotype: sotype, net: net} - return fakeconn(fd, fd2, raddr) + return fakeconn(fd, fd2, laddr, raddr) +} + +func fakeIPAndPort(ip IP, port int) (IP, int) { + if ip == nil { + ip = IPv4(127, 0, 0, 1) + } + if port == 0 { + port = nextPort() + } + return ip, port +} + +func fakeTCPAddr(addr *TCPAddr) *TCPAddr { + var ip IP + var port int + var zone string + if addr != nil { + ip, port, zone = addr.IP, addr.Port, addr.Zone + } + ip, port = fakeIPAndPort(ip, port) + return &TCPAddr{IP: ip, Port: port, Zone: zone} +} + +func fakeUDPAddr(addr *UDPAddr) *UDPAddr { + var ip IP + var port int + var zone string + if addr != nil { + ip, port, zone = addr.IP, addr.Port, addr.Zone + } + ip, port = fakeIPAndPort(ip, port) + return &UDPAddr{IP: ip, Port: port, Zone: zone} +} + +func fakeUnixAddr(sotype int, addr *UnixAddr) *UnixAddr { + var net, name string + if addr != nil { + name = addr.Name + } + switch sotype { + case syscall.SOCK_DGRAM: + net = "unixgram" + case syscall.SOCK_SEQPACKET: + net = "unixpacket" + default: + net = "unix" + } + return &UnixAddr{Net: net, Name: name} } func fakelistener(fd *netFD, laddr sockaddr) (*netFD, error) { - l := laddr.(*TCPAddr) - fd.laddr = &TCPAddr{ - IP: l.IP, - Port: nextPort(), - Zone: l.Zone, + switch l := laddr.(type) { + case *TCPAddr: + laddr = fakeTCPAddr(l) + case *UDPAddr: + laddr = fakeUDPAddr(l) + case *UnixAddr: + if l.Name == "" { + return nil, syscall.ENOENT + } + laddr = fakeUnixAddr(fd.sotype, l) + default: + return nil, syscall.EOPNOTSUPP } + + listener := fakeNetAddr{ + network: laddr.Network(), + address: laddr.String(), + } + fd.fakeNetFD = &fakeNetFD{ - listener: true, - laddr: fd.laddr, + listener: listener, incoming: make(chan *netFD, 1024), } + + fd.laddr = laddr listenersMu.Lock() - listeners[fd.laddr.(*TCPAddr).String()] = fd - listenersMu.Unlock() + defer listenersMu.Unlock() + if _, exists := listeners[listener]; exists { + return nil, syscall.EADDRINUSE + } + listeners[listener] = fd return fd, nil } -func fakeconn(fd *netFD, fd2 *netFD, raddr sockaddr) (*netFD, error) { - fd.laddr = &TCPAddr{ - IP: IPv4(127, 0, 0, 1), - Port: nextPort(), +func fakeconn(fd *netFD, fd2 *netFD, laddr, raddr sockaddr) (*netFD, error) { + switch r := raddr.(type) { + case *TCPAddr: + r = fakeTCPAddr(r) + raddr = r + laddr = fakeTCPAddr(laddr.(*TCPAddr)) + case *UDPAddr: + r = fakeUDPAddr(r) + raddr = r + laddr = fakeUDPAddr(laddr.(*UDPAddr)) + case *UnixAddr: + r = fakeUnixAddr(fd.sotype, r) + raddr = r + laddr = &UnixAddr{Net: r.Net, Name: r.Name} + default: + return nil, syscall.EAFNOSUPPORT } + fd.laddr = laddr fd.raddr = raddr fd.fakeNetFD = &fakeNetFD{ @@ -90,15 +171,18 @@ func fakeconn(fd *netFD, fd2 *netFD, raddr sockaddr) (*netFD, error) { fd2.laddr = fd.raddr fd2.raddr = fd.laddr + + listener := fakeNetAddr{ + network: fd.raddr.Network(), + address: fd.raddr.String(), + } listenersMu.Lock() - l, ok := listeners[fd.raddr.(*TCPAddr).String()] + defer listenersMu.Unlock() + l, ok := listeners[listener] if !ok { - listenersMu.Unlock() return nil, syscall.ECONNREFUSED } l.incoming <- fd2 - listenersMu.Unlock() - return fd, nil } @@ -119,11 +203,11 @@ func (fd *fakeNetFD) Close() error { fd.closed = true fd.closedMu.Unlock() - if fd.listener { + if fd.listener != (fakeNetAddr{}) { listenersMu.Lock() - delete(listeners, fd.laddr.String()) + delete(listeners, fd.listener) close(fd.incoming) - fd.listener = false + fd.listener = fakeNetAddr{} listenersMu.Unlock() return nil } diff --git a/src/net/net_fake_test.go b/src/net/net_fake_test.go new file mode 100644 index 00000000000..783304d5314 --- /dev/null +++ b/src/net/net_fake_test.go @@ -0,0 +1,203 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build js || wasip1 + +package net + +// GOOS=js and GOOS=wasip1 do not have typical socket networking capabilities +// found on other platforms. To help run test suites of the stdlib packages, +// an in-memory "fake network" facility is implemented. +// +// The tests in this files are intended to validate the behavior of the fake +// network stack on these platforms. + +import "testing" + +func TestFakeConn(t *testing.T) { + tests := []struct { + name string + listen func() (Listener, error) + dial func(Addr) (Conn, error) + addr func(*testing.T, Addr) + }{ + { + name: "Listener:tcp", + listen: func() (Listener, error) { + return Listen("tcp", ":0") + }, + dial: func(addr Addr) (Conn, error) { + return Dial(addr.Network(), addr.String()) + }, + addr: testFakeTCPAddr, + }, + + { + name: "ListenTCP:tcp", + listen: func() (Listener, error) { + // Creating a listening TCP connection with a nil address must + // select an IP address on localhost with a random port. + // This test verifies that the fake network facility does that. + return ListenTCP("tcp", nil) + }, + dial: func(addr Addr) (Conn, error) { + // Connecting a listening TCP connection will select a local + // address on the local network and connects to the destination + // address. + return DialTCP("tcp", nil, addr.(*TCPAddr)) + }, + addr: testFakeTCPAddr, + }, + + { + name: "ListenUnix:unix", + listen: func() (Listener, error) { + return ListenUnix("unix", &UnixAddr{Name: "test"}) + }, + dial: func(addr Addr) (Conn, error) { + return DialUnix("unix", nil, addr.(*UnixAddr)) + }, + addr: testFakeUnixAddr("unix", "test"), + }, + + { + name: "ListenUnix:unixpacket", + listen: func() (Listener, error) { + return ListenUnix("unixpacket", &UnixAddr{Name: "test"}) + }, + dial: func(addr Addr) (Conn, error) { + return DialUnix("unixpacket", nil, addr.(*UnixAddr)) + }, + addr: testFakeUnixAddr("unixpacket", "test"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + l, err := test.listen() + if err != nil { + t.Fatal(err) + } + defer l.Close() + test.addr(t, l.Addr()) + + c, err := test.dial(l.Addr()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + test.addr(t, c.LocalAddr()) + test.addr(t, c.RemoteAddr()) + }) + } +} + +func TestFakePacketConn(t *testing.T) { + tests := []struct { + name string + listen func() (PacketConn, error) + dial func(Addr) (Conn, error) + addr func(*testing.T, Addr) + }{ + { + name: "ListenPacket:udp", + listen: func() (PacketConn, error) { + return ListenPacket("udp", ":0") + }, + dial: func(addr Addr) (Conn, error) { + return Dial(addr.Network(), addr.String()) + }, + addr: testFakeUDPAddr, + }, + + { + name: "ListenUDP:udp", + listen: func() (PacketConn, error) { + // Creating a listening UDP connection with a nil address must + // select an IP address on localhost with a random port. + // This test verifies that the fake network facility does that. + return ListenUDP("udp", nil) + }, + dial: func(addr Addr) (Conn, error) { + // Connecting a listening UDP connection will select a local + // address on the local network and connects to the destination + // address. + return DialUDP("udp", nil, addr.(*UDPAddr)) + }, + addr: testFakeUDPAddr, + }, + + { + name: "ListenUnixgram:unixgram", + listen: func() (PacketConn, error) { + return ListenUnixgram("unixgram", &UnixAddr{Name: "test"}) + }, + dial: func(addr Addr) (Conn, error) { + return DialUnix("unixgram", nil, addr.(*UnixAddr)) + }, + addr: testFakeUnixAddr("unixgram", "test"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + l, err := test.listen() + if err != nil { + t.Fatal(err) + } + defer l.Close() + test.addr(t, l.LocalAddr()) + + c, err := test.dial(l.LocalAddr()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + test.addr(t, c.LocalAddr()) + test.addr(t, c.RemoteAddr()) + }) + } +} + +func testFakeTCPAddr(t *testing.T, addr Addr) { + t.Helper() + if a, ok := addr.(*TCPAddr); !ok { + t.Errorf("Addr is not *TCPAddr: %T", addr) + } else { + testFakeNetAddr(t, a.IP, a.Port) + } +} + +func testFakeUDPAddr(t *testing.T, addr Addr) { + t.Helper() + if a, ok := addr.(*UDPAddr); !ok { + t.Errorf("Addr is not *UDPAddr: %T", addr) + } else { + testFakeNetAddr(t, a.IP, a.Port) + } +} + +func testFakeNetAddr(t *testing.T, ip IP, port int) { + t.Helper() + if port == 0 { + t.Error("network address is missing port") + } else if len(ip) == 0 { + t.Error("network address is missing IP") + } else if !ip.Equal(IPv4(127, 0, 0, 1)) { + t.Errorf("network address has wrong IP: %s", ip) + } +} + +func testFakeUnixAddr(net, name string) func(*testing.T, Addr) { + return func(t *testing.T, addr Addr) { + t.Helper() + if a, ok := addr.(*UnixAddr); !ok { + t.Errorf("Addr is not *UnixAddr: %T", addr) + } else if a.Net != net { + t.Errorf("unix address has wrong net: want=%q got=%q", net, a.Net) + } else if a.Name != name { + t.Errorf("unix address has wrong name: want=%q got=%q", name, a.Name) + } + } +}