diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index 5ae21012e3..6a1fdfccb8 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -38,46 +38,67 @@ type dnsConn interface { SetDeadline(time.Time) error - // readDNSResponse reads a DNS response message from the DNS - // transport endpoint and returns the received DNS response - // message. - readDNSResponse() (*dnsMsg, error) - - // writeDNSQuery writes a DNS query message to the DNS - // connection endpoint. - writeDNSQuery(*dnsMsg) error + // dnsRoundTrip executes a single DNS transaction, returning a + // DNS response message for the provided DNS query message. + dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) } -func (c *UDPConn) readDNSResponse() (*dnsMsg, error) { - b := make([]byte, 512) // see RFC 1035 - n, err := c.Read(b) - if err != nil { - return nil, err - } - msg := &dnsMsg{} - if !msg.Unpack(b[:n]) { - return nil, errors.New("cannot unmarshal DNS message") - } - return msg, nil +func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { + return dnsRoundTripUDP(c, query) } -func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error { - b, ok := msg.Pack() +// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's +// "UDP usage" transport mechanism. c should be a packet-oriented connection, +// such as a *UDPConn. +func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { + b, ok := query.Pack() if !ok { - return errors.New("cannot marshal DNS message") + return nil, errors.New("cannot marshal DNS message") } if _, err := c.Write(b); err != nil { - return err + return nil, err + } + + b = make([]byte, 512) // see RFC 1035 + for { + n, err := c.Read(b) + if err != nil { + return nil, err + } + resp := &dnsMsg{} + if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) { + // Ignore invalid responses as they may be malicious + // forgery attempts. Instead continue waiting until + // timeout. See golang.org/issue/13281. + continue + } + return resp, nil } - return nil } -func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { - b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 +func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) { + return dnsRoundTripTCP(c, out) +} + +// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's +// "TCP usage" transport mechanism. c should be a stream-oriented connection, +// such as a *TCPConn. +func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { + b, ok := query.Pack() + if !ok { + return nil, errors.New("cannot marshal DNS message") + } + l := len(b) + b = append([]byte{byte(l >> 8), byte(l)}, b...) + if _, err := c.Write(b); err != nil { + return nil, err + } + + b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 if _, err := io.ReadFull(c, b[:2]); err != nil { return nil, err } - l := int(b[0])<<8 | int(b[1]) + l = int(b[0])<<8 | int(b[1]) if l > len(b) { b = make([]byte, l) } @@ -85,24 +106,14 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { if err != nil { return nil, err } - msg := &dnsMsg{} - if !msg.Unpack(b[:n]) { + resp := &dnsMsg{} + if !resp.Unpack(b[:n]) { return nil, errors.New("cannot unmarshal DNS message") } - return msg, nil -} - -func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error { - b, ok := msg.Pack() - if !ok { - return errors.New("cannot marshal DNS message") + if !resp.IsResponseTo(query) { + return nil, errors.New("invalid DNS response") } - l := uint16(len(b)) - b = append([]byte{byte(l >> 8), byte(l)}, b...) - if _, err := c.Write(b); err != nil { - return err - } - return nil + return resp, nil } func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) { @@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, c.SetDeadline(d) } out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) - if err := c.writeDNSQuery(&out); err != nil { - return nil, mapErr(err) - } - in, err := c.readDNSResponse() + in, err := c.dnsRoundTrip(&out) if err != nil { return nil, mapErr(err) } - if in.id != out.id { - return nil, errors.New("DNS message ID mismatch") - } if in.truncated { // see RFC 5966 continue } diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 761fb23f14..0b78adb853 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -567,9 +567,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { } type fakeDNSConn struct { - // last query - qmu sync.Mutex // guards q - q *dnsMsg // reply handler rh func(*dnsMsg) (*dnsMsg, error) } @@ -586,16 +583,76 @@ func (f *fakeDNSConn) SetDeadline(time.Time) error { return nil } -func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error { - f.qmu.Lock() - defer f.qmu.Unlock() - f.q = q - return nil -} - -func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) { - f.qmu.Lock() - q := f.q - f.qmu.Unlock() +func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) { return f.rh(q) } + +// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). +func TestIgnoreDNSForgeries(t *testing.T) { + const TestAddr uint32 = 0x80420001 + + c, s := Pipe() + go func() { + b := make([]byte, 512) + n, err := s.Read(b) + if err != nil { + t.Fatal(err) + } + + msg := &dnsMsg{} + if !msg.Unpack(b[:n]) { + t.Fatal("invalid DNS query") + } + + s.Write([]byte("garbage DNS response packet")) + + msg.response = true + msg.id++ // make invalid ID + b, ok := msg.Pack() + if !ok { + t.Fatal("failed to pack DNS response") + } + s.Write(b) + + msg.id-- // restore original ID + msg.answer = []dnsRR{ + &dnsRR_A{ + Hdr: dnsRR_Header{ + Name: "www.example.com.", + Rrtype: dnsTypeA, + Class: dnsClassINET, + Rdlength: 4, + }, + A: TestAddr, + }, + } + + b, ok = msg.Pack() + if !ok { + t.Fatal("failed to pack DNS response") + } + s.Write(b) + }() + + msg := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: 42, + }, + question: []dnsQuestion{ + { + Name: "www.example.com.", + Qtype: dnsTypeA, + Qclass: dnsClassINET, + }, + }, + } + + resp, err := dnsRoundTripUDP(c, msg) + if err != nil { + t.Fatalf("dnsRoundTripUDP failed: %v", err) + } + + if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr { + t.Error("got address %v, want %v", got, TestAddr) + } +} diff --git a/src/net/dnsmsg.go b/src/net/dnsmsg.go index c01381f190..5e339c5fbf 100644 --- a/src/net/dnsmsg.go +++ b/src/net/dnsmsg.go @@ -934,3 +934,23 @@ func (dns *dnsMsg) String() string { } return s } + +// IsResponseTo reports whether m is an acceptable response to query. +func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool { + if !m.response { + return false + } + if m.id != query.id { + return false + } + if len(m.question) != len(query.question) { + return false + } + for i, q := range m.question { + q2 := query.question[i] + if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass { + return false + } + } + return true +} diff --git a/src/net/dnsmsg_test.go b/src/net/dnsmsg_test.go index 841c32fa84..25bd98cff7 100644 --- a/src/net/dnsmsg_test.go +++ b/src/net/dnsmsg_test.go @@ -280,6 +280,124 @@ func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) { } } +func TestIsResponseTo(t *testing.T) { + // Sample DNS query. + query := dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: 42, + }, + question: []dnsQuestion{ + { + Name: "www.example.com.", + Qtype: dnsTypeA, + Qclass: dnsClassINET, + }, + }, + } + + resp := query + resp.response = true + if !resp.IsResponseTo(&query) { + t.Error("got false, want true") + } + + badResponses := []dnsMsg{ + // Different ID. + { + dnsMsgHdr: dnsMsgHdr{ + id: 43, + response: true, + }, + question: []dnsQuestion{ + { + Name: "www.example.com.", + Qtype: dnsTypeA, + Qclass: dnsClassINET, + }, + }, + }, + + // Different query name. + { + dnsMsgHdr: dnsMsgHdr{ + id: 42, + response: true, + }, + question: []dnsQuestion{ + { + Name: "www.google.com.", + Qtype: dnsTypeA, + Qclass: dnsClassINET, + }, + }, + }, + + // Different query type. + { + dnsMsgHdr: dnsMsgHdr{ + id: 42, + response: true, + }, + question: []dnsQuestion{ + { + Name: "www.example.com.", + Qtype: dnsTypeAAAA, + Qclass: dnsClassINET, + }, + }, + }, + + // Different query class. + { + dnsMsgHdr: dnsMsgHdr{ + id: 42, + response: true, + }, + question: []dnsQuestion{ + { + Name: "www.example.com.", + Qtype: dnsTypeA, + Qclass: dnsClassCSNET, + }, + }, + }, + + // No questions. + { + dnsMsgHdr: dnsMsgHdr{ + id: 42, + response: true, + }, + }, + + // Extra questions. + { + dnsMsgHdr: dnsMsgHdr{ + id: 42, + response: true, + }, + question: []dnsQuestion{ + { + Name: "www.example.com.", + Qtype: dnsTypeA, + Qclass: dnsClassINET, + }, + { + Name: "www.golang.org.", + Qtype: dnsTypeAAAA, + Qclass: dnsClassINET, + }, + }, + }, + } + + for i := range badResponses { + if badResponses[i].IsResponseTo(&query) { + t.Error("%v: got true, want false", i) + } + } +} + // Valid DNS SRV reply const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" + "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +