From 5f5402b72313463a25c1bcacd559254834690d7e Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 24 Jul 2018 16:19:19 -0700 Subject: [PATCH] net: fix handling of Conns created by Resolver.Dial The DNS client in net is documented to treat Conns returned by Resolver.Dial which implement PacketConn as UDP and those which don't as TCP regardless of what was requested. golang.org/cl/37879 changed the DNS client to assume that the Conn returned by Resolver.Dial was the requested type which broke compatibility. Fixes #26573 Updates #16218 Change-Id: Idf4f073a4cc3b1db36a3804898df206907f9c43c Reviewed-on: https://go-review.googlesource.com/125735 Run-TryBot: Ian Gudger TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- src/net/dnsclient_unix.go | 6 +-- src/net/dnsclient_unix_test.go | 77 ++++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index fe00fe19fe..2fee3346e9 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -137,10 +137,10 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que } var p dnsmessage.Parser var h dnsmessage.Header - if network == "tcp" { - p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) - } else { + if _, ok := c.(PacketConn); ok { p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) } c.Close() if err != nil { diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index a95b2fe645..bb014b903a 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -113,7 +113,7 @@ var specialDomainNameTests = []struct { } func TestSpecialDomainName(t *testing.T) { - fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) { } } -var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { +var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct { func TestGoLookupIPWithResolverConfig(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { switch s { case "[2001:4860:4860::8888]:53", "8.8.8.8:53": break @@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { func TestGoLookupIPOrderFallbackToFile(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { t.Fatal(err) } - fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, @@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) { t.Fatal(err) } - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { t.Log(s, q) r := dnsmessage.Message{ Header: dnsmessage.Header{ @@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { } type fakeDNSServer struct { - rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) + rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) + alwaysTCP bool } func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { - tcp := n == "tcp" || n == "tcp4" || n == "tcp6" - return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil + if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" { + return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil + } + return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil } type fakeDNSConn struct { @@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) { return len(bb), nil } -func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) { - return 0, nil, nil -} - func (f *fakeDNSConn) Write(b []byte) (int, error) { if f.tcp && len(b) >= 2 { b = b[2:] @@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) { return len(b), nil } -func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) { - return 0, nil -} - func (f *fakeDNSConn) SetDeadline(t time.Time) error { f.t = t return nil } +type fakeDNSPacketConn struct { + PacketConn + fakeDNSConn +} + +func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error { + return f.fakeDNSConn.SetDeadline(t) +} + +func (f *fakeDNSPacketConn) Close() error { + return f.fakeDNSConn.Close() +} + // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). func TestIgnoreDNSForgeries(t *testing.T) { c, s := Pipe() @@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) { var deadline0 time.Time - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q, deadline) if deadline.IsZero() { @@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { } var usedServers []string - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { usedServers = append(usedServers, s) return mockTXTResponse(q), nil }} @@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) { } for i, tt := range cases { - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) switch tt.resolveWhich(q.Questions[0]) { @@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { const searchY = "test.y.golang.org." const txt = "Hello World" - fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) switch q.Questions[0].Name.String() { @@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { func TestDNSGoroutineRace(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { + fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { time.Sleep(10 * time.Microsecond) return dnsmessage.Message{}, poll.ErrTimeout }} @@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) { t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error()) } } + +// Issue 26573: verify that Conns that don't implement PacketConn are treated +// as streams even when udp was requested. +func TestDNSDialTCP(t *testing.T) { + fake := fakeDNSServer{ + rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, + }, + Questions: q.Questions, + } + return r, nil + }, + alwaysTCP: true, + } + r := Resolver{PreferGo: true, Dial: fake.DialContext} + ctx := context.Background() + _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second) + if err != nil { + t.Fatal("exhange failed:", err) + } +}