mirror of
https://github.com/golang/go
synced 2024-11-19 06:14:39 -07:00
net: separate DNS transport from DNS query-response interaction
Before fixing issue 6579 this CL separates DNS transport from DNS message interaction to make it easier to add builtin DNS resolver control logic. Update #6579 LGTM=alex, kevlar R=golang-codereviews, alex, gobot, iant, minux, kevlar CC=golang-codereviews https://golang.org/cl/101220044
This commit is contained in:
parent
f2f17c0ff2
commit
48e7533783
@ -16,6 +16,7 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"os"
|
||||
@ -23,118 +24,187 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Send a request on the connection and hope for a reply.
|
||||
// Up to cfg.attempts attempts.
|
||||
func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) {
|
||||
_, useTCP := c.(*TCPConn)
|
||||
if len(name) >= 256 {
|
||||
return nil, &DNSError{Err: "name too long", Name: name}
|
||||
}
|
||||
out := new(dnsMsg)
|
||||
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
|
||||
out.question = []dnsQuestion{
|
||||
{name, qtype, dnsClassINET},
|
||||
}
|
||||
out.recursion_desired = true
|
||||
msg, ok := out.Pack()
|
||||
if !ok {
|
||||
return nil, &DNSError{Err: "internal error - cannot pack message", Name: name}
|
||||
}
|
||||
if useTCP {
|
||||
mlen := uint16(len(msg))
|
||||
msg = append([]byte{byte(mlen >> 8), byte(mlen)}, msg...)
|
||||
}
|
||||
for attempt := 0; attempt < cfg.attempts; attempt++ {
|
||||
n, err := c.Write(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// A dnsConn represents a DNS transport endpoint.
|
||||
type dnsConn interface {
|
||||
Conn
|
||||
|
||||
// readDNSResponse reads a DNS response message from the DNS
|
||||
// transport endpoint and returns the received DNS response
|
||||
// message.
|
||||
readDNSResponse() (*dnsMsg, error)
|
||||
|
||||
// writeDNSQuery writes a DNS query message to the DNS
|
||||
// connection endpoint.
|
||||
writeDNSQuery(*dnsMsg) error
|
||||
}
|
||||
|
||||
if cfg.timeout == 0 {
|
||||
c.SetReadDeadline(noDeadline)
|
||||
} else {
|
||||
c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second))
|
||||
}
|
||||
buf := make([]byte, 2000)
|
||||
if useTCP {
|
||||
n, err = io.ReadFull(c, buf[:2])
|
||||
func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
|
||||
b := make([]byte, 512) // see RFC 1035
|
||||
n, err := c.Read(b)
|
||||
if err != nil {
|
||||
if e, ok := err.(Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
mlen := int(buf[0])<<8 | int(buf[1])
|
||||
if mlen > len(buf) {
|
||||
buf = make([]byte, mlen)
|
||||
}
|
||||
n, err = io.ReadFull(c, buf[:mlen])
|
||||
} else {
|
||||
n, err = c.Read(buf)
|
||||
}
|
||||
if err != nil {
|
||||
if e, ok := err.(Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
buf = buf[:n]
|
||||
in := new(dnsMsg)
|
||||
if !in.Unpack(buf) || in.id != out.id {
|
||||
msg := &dnsMsg{}
|
||||
if !msg.Unpack(b[:n]) {
|
||||
return nil, errors.New("cannot unmarshal DNS message")
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
|
||||
b, ok := msg.Pack()
|
||||
if !ok {
|
||||
return errors.New("cannot marshal DNS message")
|
||||
}
|
||||
if _, err := c.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
|
||||
b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
|
||||
if _, err := io.ReadFull(c, b[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := int(b[0])<<8 | int(b[1])
|
||||
if l > len(b) {
|
||||
b = make([]byte, l)
|
||||
}
|
||||
n, err := io.ReadFull(c, b[:l])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := &dnsMsg{}
|
||||
if !msg.Unpack(b[:n]) {
|
||||
return nil, errors.New("cannot unmarshal DNS message")
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
|
||||
b, ok := msg.Pack()
|
||||
if !ok {
|
||||
return errors.New("cannot marshal DNS message")
|
||||
}
|
||||
l := uint16(len(b))
|
||||
b = append([]byte{byte(l >> 8), byte(l)}, b...)
|
||||
if _, err := c.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
|
||||
default:
|
||||
return nil, UnknownNetworkError(network)
|
||||
}
|
||||
// Calling Dial here is scary -- we have to be sure not to
|
||||
// dial a name that will require a DNS lookup, or Dial will
|
||||
// call back here to translate it. The DNS config parser has
|
||||
// already checked that all the cfg.servers[i] are IP
|
||||
// addresses, which Dial will use without a DNS lookup.
|
||||
c, err := d.Dial(network, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
return c.(*TCPConn), nil
|
||||
case "udp", "udp4", "udp6":
|
||||
return c.(*UDPConn), nil
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// exchange sends a query on the connection and hopes for a response.
|
||||
func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
|
||||
d := Dialer{Timeout: timeout}
|
||||
out := dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
recursion_desired: true,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{name, qtype, dnsClassINET},
|
||||
},
|
||||
}
|
||||
for _, network := range []string{"udp", "tcp"} {
|
||||
c, err := d.dialDNS(network, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer c.Close()
|
||||
if timeout > 0 {
|
||||
c.SetDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
|
||||
if err := c.writeDNSQuery(&out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
in, err := c.readDNSResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if in.id != out.id {
|
||||
return nil, errors.New("DNS message ID mismatch")
|
||||
}
|
||||
if in.truncated { // see RFC 5966
|
||||
continue
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
var server string
|
||||
if a := c.RemoteAddr(); a != nil {
|
||||
server = a.String()
|
||||
}
|
||||
return nil, &DNSError{Err: "no answer from server", Name: name, Server: server, IsTimeout: true}
|
||||
return nil, errors.New("no answer from DNS server")
|
||||
}
|
||||
|
||||
// Do a lookup for a single name, which must be rooted
|
||||
// (otherwise answer will not find the answers).
|
||||
func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
|
||||
func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
|
||||
if len(cfg.servers) == 0 {
|
||||
return "", nil, &DNSError{Err: "no DNS servers", Name: name}
|
||||
}
|
||||
for i := 0; i < len(cfg.servers); i++ {
|
||||
// Calling Dial here is scary -- we have to be sure
|
||||
// not to dial a name that will require a DNS lookup,
|
||||
// or Dial will call back here to translate it.
|
||||
// The DNS config parser has already checked that
|
||||
// all the cfg.servers[i] are IP addresses, which
|
||||
// Dial will use without a DNS lookup.
|
||||
server := cfg.servers[i] + ":53"
|
||||
c, cerr := Dial("udp", server)
|
||||
if cerr != nil {
|
||||
err = cerr
|
||||
if len(name) >= 256 {
|
||||
return "", nil, &DNSError{Err: "DNS name too long", Name: name}
|
||||
}
|
||||
timeout := time.Duration(cfg.timeout) * time.Second
|
||||
var lastErr error
|
||||
for _, server := range cfg.servers {
|
||||
server += ":53"
|
||||
lastErr = &DNSError{
|
||||
Err: "no answer from DNS server",
|
||||
Name: name,
|
||||
Server: server,
|
||||
IsTimeout: true,
|
||||
}
|
||||
for i := 0; i < cfg.attempts; i++ {
|
||||
msg, err := exchange(server, name, qtype, timeout)
|
||||
if err != nil {
|
||||
if nerr, ok := err.(Error); ok && nerr.Timeout() {
|
||||
lastErr = &DNSError{
|
||||
Err: nerr.Error(),
|
||||
Name: name,
|
||||
Server: server,
|
||||
IsTimeout: true,
|
||||
}
|
||||
continue
|
||||
|
||||
}
|
||||
msg, merr := exchange(cfg, c, name, qtype)
|
||||
c.Close()
|
||||
if merr != nil {
|
||||
err = merr
|
||||
continue
|
||||
lastErr = &DNSError{
|
||||
Err: err.Error(),
|
||||
Name: name,
|
||||
Server: server,
|
||||
}
|
||||
if msg.truncated { // see RFC 5966
|
||||
c, cerr = Dial("tcp", server)
|
||||
if cerr != nil {
|
||||
err = cerr
|
||||
continue
|
||||
}
|
||||
msg, merr = exchange(cfg, c, name, qtype)
|
||||
c.Close()
|
||||
if merr != nil {
|
||||
err = merr
|
||||
continue
|
||||
}
|
||||
}
|
||||
cname, addrs, err = answer(name, server, msg, qtype)
|
||||
if err == nil || err.(*DNSError).Err == noSuchHost {
|
||||
break
|
||||
}
|
||||
cname, addrs, err := answer(name, server, msg, qtype)
|
||||
if err == nil || err.(*DNSError).Err == noSuchHost {
|
||||
return cname, addrs, err
|
||||
}
|
||||
return
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return "", nil, lastErr
|
||||
}
|
||||
|
||||
func convertRR_A(records []dnsRR) []IP {
|
||||
|
@ -16,19 +16,79 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTCPLookup(t *testing.T) {
|
||||
var dnsTransportFallbackTests = []struct {
|
||||
server string
|
||||
name string
|
||||
qtype uint16
|
||||
timeout int
|
||||
rcode int
|
||||
}{
|
||||
// Querying "com." with qtype=255 usually makes an answer
|
||||
// which requires more than 512 bytes.
|
||||
{"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
|
||||
{"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
|
||||
}
|
||||
|
||||
func TestDNSTransportFallback(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("skipping test to avoid external network")
|
||||
}
|
||||
c, err := Dial("tcp", "8.8.8.8:53")
|
||||
|
||||
for _, tt := range dnsTransportFallbackTests {
|
||||
timeout := time.Duration(tt.timeout) * time.Second
|
||||
msg, err := exchange(tt.server, tt.name, tt.qtype, timeout)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed: %v", err)
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
defer c.Close()
|
||||
cfg := &dnsConfig{timeout: 10, attempts: 3}
|
||||
_, err = exchange(cfg, c, "com.", dnsTypeALL)
|
||||
switch msg.rcode {
|
||||
case tt.rcode, dnsRcodeServerFailure:
|
||||
default:
|
||||
t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See RFC 6761 for further information about the reserved, pseudo
|
||||
// domain names.
|
||||
var specialDomainNameTests = []struct {
|
||||
name string
|
||||
qtype uint16
|
||||
rcode int
|
||||
}{
|
||||
// Name resoltion APIs and libraries should not recongnize the
|
||||
// followings as special.
|
||||
{"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
|
||||
{"test.", dnsTypeALL, dnsRcodeNameError},
|
||||
{"example.com.", dnsTypeALL, dnsRcodeSuccess},
|
||||
|
||||
// Name resoltion APIs and libraries should recongnize the
|
||||
// followings as special and should not send any queries.
|
||||
// Though, we test those names here for verifying nagative
|
||||
// answers at DNS query-response interaction level.
|
||||
{"localhost.", dnsTypeALL, dnsRcodeNameError},
|
||||
{"invalid.", dnsTypeALL, dnsRcodeNameError},
|
||||
}
|
||||
|
||||
func TestSpecialDomainName(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("skipping test to avoid external network")
|
||||
}
|
||||
|
||||
server := "8.8.8.8:53"
|
||||
for _, tt := range specialDomainNameTests {
|
||||
msg, err := exchange(server, tt.name, tt.qtype, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("exchange failed: %v", err)
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
switch msg.rcode {
|
||||
case tt.rcode, dnsRcodeServerFailure:
|
||||
default:
|
||||
t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user