mirror of
https://github.com/golang/go
synced 2024-11-23 04:00:03 -07:00
net: keep waiting for valid DNS response until timeout
Prevents denial of service attacks from bogus UDP packets. Fixes #13281. Change-Id: Ifb51b17a1b0807bfd27b144d6037431701184e7b Reviewed-on: https://go-review.googlesource.com/22126 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Matthew Dempsky <mdempsky@google.com> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
9f1ccd647f
commit
3411d63219
@ -38,46 +38,67 @@ type dnsConn interface {
|
||||
|
||||
SetDeadline(time.Time) error
|
||||
|
||||
// readDNSResponse reads a DNS response message from the DNS
|
||||
// transport endpoint and returns the received DNS response
|
||||
// message.
|
||||
readDNSResponse() (*dnsMsg, error)
|
||||
|
||||
// writeDNSQuery writes a DNS query message to the DNS
|
||||
// connection endpoint.
|
||||
writeDNSQuery(*dnsMsg) error
|
||||
// dnsRoundTrip executes a single DNS transaction, returning a
|
||||
// DNS response message for the provided DNS query message.
|
||||
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
|
||||
}
|
||||
|
||||
func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
|
||||
b := make([]byte, 512) // see RFC 1035
|
||||
n, err := c.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := &dnsMsg{}
|
||||
if !msg.Unpack(b[:n]) {
|
||||
return nil, errors.New("cannot unmarshal DNS message")
|
||||
}
|
||||
return msg, nil
|
||||
func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
|
||||
return dnsRoundTripUDP(c, query)
|
||||
}
|
||||
|
||||
func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
|
||||
b, ok := msg.Pack()
|
||||
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
|
||||
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
|
||||
// such as a *UDPConn.
|
||||
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
|
||||
b, ok := query.Pack()
|
||||
if !ok {
|
||||
return errors.New("cannot marshal DNS message")
|
||||
return nil, errors.New("cannot marshal DNS message")
|
||||
}
|
||||
if _, err := c.Write(b); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b = make([]byte, 512) // see RFC 1035
|
||||
for {
|
||||
n, err := c.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := &dnsMsg{}
|
||||
if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
|
||||
// Ignore invalid responses as they may be malicious
|
||||
// forgery attempts. Instead continue waiting until
|
||||
// timeout. See golang.org/issue/13281.
|
||||
continue
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
|
||||
b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
|
||||
func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
|
||||
return dnsRoundTripTCP(c, out)
|
||||
}
|
||||
|
||||
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
|
||||
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
|
||||
// such as a *TCPConn.
|
||||
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
|
||||
b, ok := query.Pack()
|
||||
if !ok {
|
||||
return nil, errors.New("cannot marshal DNS message")
|
||||
}
|
||||
l := len(b)
|
||||
b = append([]byte{byte(l >> 8), byte(l)}, b...)
|
||||
if _, err := c.Write(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
|
||||
if _, err := io.ReadFull(c, b[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := int(b[0])<<8 | int(b[1])
|
||||
l = int(b[0])<<8 | int(b[1])
|
||||
if l > len(b) {
|
||||
b = make([]byte, l)
|
||||
}
|
||||
@ -85,24 +106,14 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := &dnsMsg{}
|
||||
if !msg.Unpack(b[:n]) {
|
||||
resp := &dnsMsg{}
|
||||
if !resp.Unpack(b[:n]) {
|
||||
return nil, errors.New("cannot unmarshal DNS message")
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
|
||||
b, ok := msg.Pack()
|
||||
if !ok {
|
||||
return errors.New("cannot marshal DNS message")
|
||||
if !resp.IsResponseTo(query) {
|
||||
return nil, errors.New("invalid DNS response")
|
||||
}
|
||||
l := uint16(len(b))
|
||||
b = append([]byte{byte(l >> 8), byte(l)}, b...)
|
||||
if _, err := c.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
|
||||
@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg,
|
||||
c.SetDeadline(d)
|
||||
}
|
||||
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
|
||||
if err := c.writeDNSQuery(&out); err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
in, err := c.readDNSResponse()
|
||||
in, err := c.dnsRoundTrip(&out)
|
||||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
if in.id != out.id {
|
||||
return nil, errors.New("DNS message ID mismatch")
|
||||
}
|
||||
if in.truncated { // see RFC 5966
|
||||
continue
|
||||
}
|
||||
|
@ -567,9 +567,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
|
||||
}
|
||||
|
||||
type fakeDNSConn struct {
|
||||
// last query
|
||||
qmu sync.Mutex // guards q
|
||||
q *dnsMsg
|
||||
// reply handler
|
||||
rh func(*dnsMsg) (*dnsMsg, error)
|
||||
}
|
||||
@ -586,16 +583,76 @@ func (f *fakeDNSConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error {
|
||||
f.qmu.Lock()
|
||||
defer f.qmu.Unlock()
|
||||
f.q = q
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) {
|
||||
f.qmu.Lock()
|
||||
q := f.q
|
||||
f.qmu.Unlock()
|
||||
func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
|
||||
return f.rh(q)
|
||||
}
|
||||
|
||||
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
|
||||
func TestIgnoreDNSForgeries(t *testing.T) {
|
||||
const TestAddr uint32 = 0x80420001
|
||||
|
||||
c, s := Pipe()
|
||||
go func() {
|
||||
b := make([]byte, 512)
|
||||
n, err := s.Read(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg := &dnsMsg{}
|
||||
if !msg.Unpack(b[:n]) {
|
||||
t.Fatal("invalid DNS query")
|
||||
}
|
||||
|
||||
s.Write([]byte("garbage DNS response packet"))
|
||||
|
||||
msg.response = true
|
||||
msg.id++ // make invalid ID
|
||||
b, ok := msg.Pack()
|
||||
if !ok {
|
||||
t.Fatal("failed to pack DNS response")
|
||||
}
|
||||
s.Write(b)
|
||||
|
||||
msg.id-- // restore original ID
|
||||
msg.answer = []dnsRR{
|
||||
&dnsRR_A{
|
||||
Hdr: dnsRR_Header{
|
||||
Name: "www.example.com.",
|
||||
Rrtype: dnsTypeA,
|
||||
Class: dnsClassINET,
|
||||
Rdlength: 4,
|
||||
},
|
||||
A: TestAddr,
|
||||
},
|
||||
}
|
||||
|
||||
b, ok = msg.Pack()
|
||||
if !ok {
|
||||
t.Fatal("failed to pack DNS response")
|
||||
}
|
||||
s.Write(b)
|
||||
}()
|
||||
|
||||
msg := &dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 42,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Qtype: dnsTypeA,
|
||||
Qclass: dnsClassINET,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := dnsRoundTripUDP(c, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("dnsRoundTripUDP failed: %v", err)
|
||||
}
|
||||
|
||||
if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
|
||||
t.Error("got address %v, want %v", got, TestAddr)
|
||||
}
|
||||
}
|
||||
|
@ -934,3 +934,23 @@ func (dns *dnsMsg) String() string {
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// IsResponseTo reports whether m is an acceptable response to query.
|
||||
func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
|
||||
if !m.response {
|
||||
return false
|
||||
}
|
||||
if m.id != query.id {
|
||||
return false
|
||||
}
|
||||
if len(m.question) != len(query.question) {
|
||||
return false
|
||||
}
|
||||
for i, q := range m.question {
|
||||
q2 := query.question[i]
|
||||
if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -280,6 +280,124 @@ func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsResponseTo(t *testing.T) {
|
||||
// Sample DNS query.
|
||||
query := dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 42,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Qtype: dnsTypeA,
|
||||
Qclass: dnsClassINET,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp := query
|
||||
resp.response = true
|
||||
if !resp.IsResponseTo(&query) {
|
||||
t.Error("got false, want true")
|
||||
}
|
||||
|
||||
badResponses := []dnsMsg{
|
||||
// Different ID.
|
||||
{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 43,
|
||||
response: true,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Qtype: dnsTypeA,
|
||||
Qclass: dnsClassINET,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Different query name.
|
||||
{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 42,
|
||||
response: true,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{
|
||||
Name: "www.google.com.",
|
||||
Qtype: dnsTypeA,
|
||||
Qclass: dnsClassINET,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Different query type.
|
||||
{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 42,
|
||||
response: true,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Qtype: dnsTypeAAAA,
|
||||
Qclass: dnsClassINET,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Different query class.
|
||||
{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 42,
|
||||
response: true,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Qtype: dnsTypeA,
|
||||
Qclass: dnsClassCSNET,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// No questions.
|
||||
{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 42,
|
||||
response: true,
|
||||
},
|
||||
},
|
||||
|
||||
// Extra questions.
|
||||
{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: 42,
|
||||
response: true,
|
||||
},
|
||||
question: []dnsQuestion{
|
||||
{
|
||||
Name: "www.example.com.",
|
||||
Qtype: dnsTypeA,
|
||||
Qclass: dnsClassINET,
|
||||
},
|
||||
{
|
||||
Name: "www.golang.org.",
|
||||
Qtype: dnsTypeAAAA,
|
||||
Qclass: dnsClassINET,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i := range badResponses {
|
||||
if badResponses[i].IsResponseTo(&query) {
|
||||
t.Error("%v: got true, want false", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Valid DNS SRV reply
|
||||
const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
|
||||
"6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
|
||||
|
Loading…
Reference in New Issue
Block a user