1
0
mirror of https://github.com/golang/go synced 2024-11-07 15:26:11 -07:00

net: fix handling of Conns created by Resolver.Dial

The DNS client in net is documented to treat Conns returned by
Resolver.Dial which implement PacketConn as UDP and those which don't as
TCP regardless of what was requested. golang.org/cl/37879 changed the
DNS client to assume that the Conn returned by Resolver.Dial was the
requested type which broke compatibility.

Fixes #26573
Updates #16218

Change-Id: Idf4f073a4cc3b1db36a3804898df206907f9c43c
Reviewed-on: https://go-review.googlesource.com/125735
Run-TryBot: Ian Gudger <igudger@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Ian Gudger 2018-07-24 16:19:19 -07:00 committed by Brad Fitzpatrick
parent c1e1e882d2
commit 5f5402b723
2 changed files with 58 additions and 25 deletions

View File

@ -137,10 +137,10 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
}
var p dnsmessage.Parser
var h dnsmessage.Header
if network == "tcp" {
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
} else {
if _, ok := c.(PacketConn); ok {
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
} else {
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
}
c.Close()
if err != nil {

View File

@ -113,7 +113,7 @@ var specialDomainNameTests = []struct {
}
func TestSpecialDomainName(t *testing.T) {
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) {
}
}
var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct {
func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break
@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
t.Fatal(err)
}
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
t.Fatal(err)
}
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
r := dnsmessage.Message{
Header: dnsmessage.Header{
@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
}
type fakeDNSServer struct {
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
alwaysTCP bool
}
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
tcp := n == "tcp" || n == "tcp4" || n == "tcp6"
return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil
if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
}
return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
}
type fakeDNSConn struct {
@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) {
return len(bb), nil
}
func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
return 0, nil, nil
}
func (f *fakeDNSConn) Write(b []byte) (int, error) {
if f.tcp && len(b) >= 2 {
b = b[2:]
@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) {
return len(b), nil
}
func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
return 0, nil
}
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
f.t = t
return nil
}
type fakeDNSPacketConn struct {
PacketConn
fakeDNSConn
}
func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
return f.fakeDNSConn.SetDeadline(t)
}
func (f *fakeDNSPacketConn) Close() error {
return f.fakeDNSConn.Close()
}
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
func TestIgnoreDNSForgeries(t *testing.T) {
c, s := Pipe()
@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) {
var deadline0 time.Time
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q, deadline)
if deadline.IsZero() {
@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
}
var usedServers []string
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
usedServers = append(usedServers, s)
return mockTXTResponse(q), nil
}}
@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}
for i, tt := range cases {
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
switch tt.resolveWhich(q.Questions[0]) {
@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
const searchY = "test.y.golang.org."
const txt = "Hello World"
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
switch q.Questions[0].Name.String() {
@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
time.Sleep(10 * time.Microsecond)
return dnsmessage.Message{}, poll.ErrTimeout
}}
@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) {
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
}
}
// Issue 26573: verify that Conns that don't implement PacketConn are treated
// as streams even when udp was requested.
func TestDNSDialTCP(t *testing.T) {
fake := fakeDNSServer{
rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.Header.ID,
Response: true,
RCode: dnsmessage.RCodeSuccess,
},
Questions: q.Questions,
}
return r, nil
},
alwaysTCP: true,
}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
ctx := context.Background()
_, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second)
if err != nil {
t.Fatal("exhange failed:", err)
}
}