diff --git a/src/net/udpsock.go b/src/net/udpsock.go index 622b1f83fb..6d29a39edf 100644 --- a/src/net/udpsock.go +++ b/src/net/udpsock.go @@ -99,16 +99,12 @@ func ResolveUDPAddr(network, address string) (*UDPAddr, error) { return addrs.forResolve(network, address).(*UDPAddr), nil } -// UDPAddrFromAddrPort returns addr as a UDPAddr. -// -// If addr is not valid, it returns nil. +// UDPAddrFromAddrPort returns addr as a UDPAddr. If addr.IsValid() is false, +// then the returned UDPAddr will contain a nil IP field, indicating an +// address family-agnostic unspecified address. func UDPAddrFromAddrPort(addr netip.AddrPort) *UDPAddr { - if !addr.IsValid() { - return nil - } - ip16 := addr.Addr().As16() return &UDPAddr{ - IP: IP(ip16[:]), + IP: addr.Addr().AsSlice(), Zone: addr.Addr().Zone(), Port: int(addr.Port()), } @@ -189,7 +185,9 @@ func (c *UDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { var ap netip.AddrPort n, oobn, flags, ap, err = c.ReadMsgUDPAddrPort(b, oob) - addr = UDPAddrFromAddrPort(ap) + if ap.IsValid() { + addr = UDPAddrFromAddrPort(ap) + } return } diff --git a/src/net/udpsock_test.go b/src/net/udpsock_test.go index 9fe74f47a2..01b8d39216 100644 --- a/src/net/udpsock_test.go +++ b/src/net/udpsock_test.go @@ -603,3 +603,35 @@ func BenchmarkWriteToReadFromUDPAddrPort(b *testing.B) { } } } + +func TestUDPIPVersionReadMsg(t *testing.T) { + conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + daddr := conn.LocalAddr().(*UDPAddr).AddrPort() + buf := make([]byte, 8) + _, err = conn.WriteToUDPAddrPort(buf, daddr) + if err != nil { + t.Fatal(err) + } + _, _, _, saddr, err := conn.ReadMsgUDPAddrPort(buf, nil) + if err != nil { + t.Fatal(err) + } + if !saddr.Addr().Is4() { + t.Error("returned AddrPort is not IPv4") + } + _, err = conn.WriteToUDPAddrPort(buf, daddr) + if err != nil { + t.Fatal(err) + } + _, _, _, soldaddr, err := conn.ReadMsgUDP(buf, nil) + if err != nil { + t.Fatal(err) + } + if len(soldaddr.IP) != 4 { + t.Error("returned UDPAddr is not IPv4") + } +}