mirror of
https://github.com/golang/go
synced 2024-11-07 15:26:11 -07:00
net: fix handling of Conns created by Resolver.Dial
The DNS client in net is documented to treat Conns returned by Resolver.Dial which implement PacketConn as UDP and those which don't as TCP regardless of what was requested. golang.org/cl/37879 changed the DNS client to assume that the Conn returned by Resolver.Dial was the requested type which broke compatibility. Fixes #26573 Updates #16218 Change-Id: Idf4f073a4cc3b1db36a3804898df206907f9c43c Reviewed-on: https://go-review.googlesource.com/125735 Run-TryBot: Ian Gudger <igudger@google.com> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
c1e1e882d2
commit
5f5402b723
@ -137,10 +137,10 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
|
||||
}
|
||||
var p dnsmessage.Parser
|
||||
var h dnsmessage.Header
|
||||
if network == "tcp" {
|
||||
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
|
||||
} else {
|
||||
if _, ok := c.(PacketConn); ok {
|
||||
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
|
||||
} else {
|
||||
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
|
||||
}
|
||||
c.Close()
|
||||
if err != nil {
|
||||
|
@ -113,7 +113,7 @@ var specialDomainNameTests = []struct {
|
||||
}
|
||||
|
||||
func TestSpecialDomainName(t *testing.T) {
|
||||
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
r := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
ID: q.ID,
|
||||
@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
r := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
ID: q.ID,
|
||||
@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct {
|
||||
|
||||
func TestGoLookupIPWithResolverConfig(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
switch s {
|
||||
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
|
||||
break
|
||||
@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
|
||||
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
|
||||
r := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
ID: q.ID,
|
||||
@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
r := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
ID: q.ID,
|
||||
@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
t.Log(s, q)
|
||||
r := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
|
||||
}
|
||||
|
||||
type fakeDNSServer struct {
|
||||
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
|
||||
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
|
||||
alwaysTCP bool
|
||||
}
|
||||
|
||||
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
|
||||
tcp := n == "tcp" || n == "tcp4" || n == "tcp6"
|
||||
return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil
|
||||
if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
|
||||
return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
|
||||
}
|
||||
return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
|
||||
}
|
||||
|
||||
type fakeDNSConn struct {
|
||||
@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) {
|
||||
return len(bb), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) Write(b []byte) (int, error) {
|
||||
if f.tcp && len(b) >= 2 {
|
||||
b = b[2:]
|
||||
@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
|
||||
f.t = t
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeDNSPacketConn struct {
|
||||
PacketConn
|
||||
fakeDNSConn
|
||||
}
|
||||
|
||||
func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
|
||||
return f.fakeDNSConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (f *fakeDNSPacketConn) Close() error {
|
||||
return f.fakeDNSConn.Close()
|
||||
}
|
||||
|
||||
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
|
||||
func TestIgnoreDNSForgeries(t *testing.T) {
|
||||
c, s := Pipe()
|
||||
@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) {
|
||||
|
||||
var deadline0 time.Time
|
||||
|
||||
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
t.Log(s, q, deadline)
|
||||
|
||||
if deadline.IsZero() {
|
||||
@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
|
||||
}
|
||||
|
||||
var usedServers []string
|
||||
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
usedServers = append(usedServers, s)
|
||||
return mockTXTResponse(q), nil
|
||||
}}
|
||||
@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, tt := range cases {
|
||||
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
t.Log(s, q)
|
||||
|
||||
switch tt.resolveWhich(q.Questions[0]) {
|
||||
@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
|
||||
const searchY = "test.y.golang.org."
|
||||
const txt = "Hello World"
|
||||
|
||||
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
|
||||
t.Log(s, q)
|
||||
|
||||
switch q.Questions[0].Name.String() {
|
||||
@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
|
||||
func TestDNSGoroutineRace(t *testing.T) {
|
||||
defer dnsWaitGroup.Wait()
|
||||
|
||||
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
|
||||
fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return dnsmessage.Message{}, poll.ErrTimeout
|
||||
}}
|
||||
@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) {
|
||||
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Issue 26573: verify that Conns that don't implement PacketConn are treated
|
||||
// as streams even when udp was requested.
|
||||
func TestDNSDialTCP(t *testing.T) {
|
||||
fake := fakeDNSServer{
|
||||
rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
r := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
ID: q.Header.ID,
|
||||
Response: true,
|
||||
RCode: dnsmessage.RCodeSuccess,
|
||||
},
|
||||
Questions: q.Questions,
|
||||
}
|
||||
return r, nil
|
||||
},
|
||||
alwaysTCP: true,
|
||||
}
|
||||
r := Resolver{PreferGo: true, Dial: fake.DialContext}
|
||||
ctx := context.Background()
|
||||
_, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second)
|
||||
if err != nil {
|
||||
t.Fatal("exhange failed:", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user