diff --git a/src/net/dial.go b/src/net/dial.go index 0a7da408fe5..bed0b1e3e0c 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -135,23 +135,26 @@ func (d *Dialer) fallbackDelay() time.Duration { } } -func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) { - i := last(net, ':') +func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) { + i := last(network, ':') if i < 0 { // no colon - switch net { + switch network { case "tcp", "tcp4", "tcp6": case "udp", "udp4", "udp6": case "ip", "ip4", "ip6": + if needsProto { + return "", 0, UnknownNetworkError(network) + } case "unix", "unixgram", "unixpacket": default: - return "", 0, UnknownNetworkError(net) + return "", 0, UnknownNetworkError(network) } - return net, 0, nil + return network, 0, nil } - afnet = net[:i] + afnet = network[:i] switch afnet { case "ip", "ip4", "ip6": - protostr := net[i+1:] + protostr := network[i+1:] proto, i, ok := dtoi(protostr) if !ok || i != len(protostr) { proto, err = lookupProtocol(ctx, protostr) @@ -161,14 +164,14 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err } return afnet, proto, nil } - return "", 0, UnknownNetworkError(net) + return "", 0, UnknownNetworkError(network) } // resolveAddrList resolves addr using hint and returns a list of // addresses. The result contains at least one address when error is // nil. func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { - afnet, _, err := parseNetwork(ctx, network) + afnet, _, err := parseNetwork(ctx, network, true) if err != nil { return nil, err } diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go index d994fc67c65..d69a303d786 100644 --- a/src/net/iprawsock.go +++ b/src/net/iprawsock.go @@ -71,7 +71,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) { if net == "" { // a hint wildcard for Go 1.0 undocumented behavior net = "ip" } - afnet, _, err := parseNetwork(context.Background(), net) + afnet, _, err := parseNetwork(context.Background(), net, false) if err != nil { return nil, err } diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index 8f4b702e480..5d76818af9f 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -113,7 +113,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) } func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(ctx, netProto) + network, proto, err := parseNetwork(ctx, netProto, true) if err != nil { return nil, err } @@ -133,7 +133,7 @@ func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn } func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(ctx, netProto) + network, proto, err := parseNetwork(ctx, netProto, true) if err != nil { return nil, err } diff --git a/src/net/iprawsock_test.go b/src/net/iprawsock_test.go index 5d33b26a91e..8972051f5d5 100644 --- a/src/net/iprawsock_test.go +++ b/src/net/iprawsock_test.go @@ -117,3 +117,75 @@ func TestIPConnRemoteName(t *testing.T) { t.Fatalf("got %#v; want %#v", c.RemoteAddr(), raddr) } } + +func TestDialListenIPArgs(t *testing.T) { + type test struct { + argLists [][2]string + shouldFail bool + } + tests := []test{ + { + argLists: [][2]string{ + {"ip", "127.0.0.1"}, + {"ip:", "127.0.0.1"}, + {"ip::", "127.0.0.1"}, + {"ip", "::1"}, + {"ip:", "::1"}, + {"ip::", "::1"}, + {"ip4", "127.0.0.1"}, + {"ip4:", "127.0.0.1"}, + {"ip4::", "127.0.0.1"}, + {"ip6", "::1"}, + {"ip6:", "::1"}, + {"ip6::", "::1"}, + }, + shouldFail: true, + }, + } + if testableNetwork("ip") { + priv := test{shouldFail: false} + for _, tt := range []struct { + network, address string + args [2]string + }{ + {"ip4:47", "127.0.0.1", [2]string{"ip4:47", "127.0.0.1"}}, + {"ip6:47", "::1", [2]string{"ip6:47", "::1"}}, + } { + c, err := ListenPacket(tt.network, tt.address) + if err != nil { + continue + } + c.Close() + priv.argLists = append(priv.argLists, tt.args) + } + if len(priv.argLists) > 0 { + tests = append(tests, priv) + } + } + + for _, tt := range tests { + for _, args := range tt.argLists { + _, err := Dial(args[0], args[1]) + if tt.shouldFail != (err != nil) { + t.Errorf("Dial(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail) + } + _, err = ListenPacket(args[0], args[1]) + if tt.shouldFail != (err != nil) { + t.Errorf("ListenPacket(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail) + } + a, err := ResolveIPAddr("ip", args[1]) + if err != nil { + t.Errorf("ResolveIPAddr(\"ip\", %q) = %v", args[1], err) + continue + } + _, err = DialIP(args[0], nil, a) + if tt.shouldFail != (err != nil) { + t.Errorf("DialIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail) + } + _, err = ListenIP(args[0], a) + if tt.shouldFail != (err != nil) { + t.Errorf("ListenIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail) + } + } + } +}