diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index fe00fe19fe..2fee3346e9 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -137,10 +137,10 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que } var p dnsmessage.Parser var h dnsmessage.Header - if network == "tcp" { - p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) - } else { + if _, ok := c.(PacketConn); ok { p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) } c.Close() if err != nil { diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index a95b2fe645..bb014b903a 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -113,7 +113,7 @@ var specialDomainNameTests = []struct { } func TestSpecialDomainName(t *testing.T) { - fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) { } } -var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { +var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct { func TestGoLookupIPWithResolverConfig(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { switch s { case "[2001:4860:4860::8888]:53", "8.8.8.8:53": break @@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { func TestGoLookupIPOrderFallbackToFile(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { t.Fatal(err) } - fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) { t.Fatal(err) } - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { t.Log(s, q) r := dnsmessage.Message{ Header: dnsmessage.Header{ @@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { } type fakeDNSServer struct { - rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) + rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) + alwaysTCP bool } func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { - tcp := n == "tcp" || n == "tcp4" || n == "tcp6" - return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil + if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" { + return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil + } + return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil } type fakeDNSConn struct { @@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) { return len(bb), nil } -func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) { - return 0, nil, nil -} - func (f *fakeDNSConn) Write(b []byte) (int, error) { if f.tcp && len(b) >= 2 { b = b[2:] @@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) { return len(b), nil } -func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) { - return 0, nil -} - func (f *fakeDNSConn) SetDeadline(t time.Time) error { f.t = t return nil } +type fakeDNSPacketConn struct { + PacketConn + fakeDNSConn +} + +func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error { + return f.fakeDNSConn.SetDeadline(t) +} + +func (f *fakeDNSPacketConn) Close() error { + return f.fakeDNSConn.Close() +} + // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). func TestIgnoreDNSForgeries(t *testing.T) { c, s := Pipe() @@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) { var deadline0 time.Time - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q, deadline) if deadline.IsZero() { @@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { } var usedServers []string - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { usedServers = append(usedServers, s) return mockTXTResponse(q), nil }} @@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) { } for i, tt := range cases { - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) switch tt.resolveWhich(q.Questions[0]) { @@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { const searchY = "test.y.golang.org." const txt = "Hello World" - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) switch q.Questions[0].Name.String() { @@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { func TestDNSGoroutineRace(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { time.Sleep(10 * time.Microsecond) return dnsmessage.Message{}, poll.ErrTimeout }} @@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) { t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error()) } } + +// Issue 26573: verify that Conns that don't implement PacketConn are treated +// as streams even when udp was requested. +func TestDNSDialTCP(t *testing.T) { + fake := fakeDNSServer{ + rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, + }, + Questions: q.Questions, + } + return r, nil + }, + alwaysTCP: true, + } + r := Resolver{PreferGo: true, Dial: fake.DialContext} + ctx := context.Background() + _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second) + if err != nil { + t.Fatal("exhange failed:", err) + } +}