1
0
mirror of https://github.com/golang/go synced 2024-11-23 14:20:05 -07:00

net: fix CNAME resolving on Windows

Fixes #8492

LGTM=alex.brainman
R=golang-codereviews, alex.brainman
CC=golang-codereviews
https://golang.org/cl/122200043
This commit is contained in:
Egon Elbre 2014-08-15 16:37:19 +10:00 committed by Alex Brainman
parent 0235f6854c
commit a18a360379
6 changed files with 341 additions and 12 deletions

View File

@ -210,14 +210,21 @@ func lookupCNAME(name string) (cname string, err error) {
defer releaseThread()
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
// if there are no aliases, the canonical name is the input name
if name == "" || name[len(name)-1] != '.' {
return name + ".", nil
}
return name, nil
}
if e != nil {
return "", os.NewSyscallError("LookupCNAME", e)
}
defer syscall.DnsRecordListFree(r, 1)
if r != nil && r.Type == syscall.DNS_TYPE_CNAME {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0]))
cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."
}
resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
return
}
@ -236,8 +243,9 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
return "", nil, os.NewSyscallError("LookupSRV", e)
}
defer syscall.DnsRecordListFree(r, 1)
addrs = make([]*SRV, 0, 10)
for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next {
for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
}
@ -254,8 +262,9 @@ func lookupMX(name string) (mx []*MX, err error) {
return nil, os.NewSyscallError("LookupMX", e)
}
defer syscall.DnsRecordListFree(r, 1)
mx = make([]*MX, 0, 10)
for p := r; p != nil && p.Type == syscall.DNS_TYPE_MX; p = p.Next {
for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
}
@ -272,8 +281,9 @@ func lookupNS(name string) (ns []*NS, err error) {
return nil, os.NewSyscallError("LookupNS", e)
}
defer syscall.DnsRecordListFree(r, 1)
ns = make([]*NS, 0, 10)
for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next {
for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
}
@ -289,9 +299,10 @@ func lookupTXT(name string) (txt []string, err error) {
return nil, os.NewSyscallError("LookupTXT", e)
}
defer syscall.DnsRecordListFree(r, 1)
txt = make([]string, 0, 10)
if r != nil && r.Type == syscall.DNS_TYPE_TEXT {
d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0]))
for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
txt = append(txt, s)
@ -313,10 +324,58 @@ func lookupAddr(addr string) (name []string, err error) {
return nil, os.NewSyscallError("LookupAddr", e)
}
defer syscall.DnsRecordListFree(r, 1)
name = make([]string, 0, 10)
for p := r; p != nil && p.Type == syscall.DNS_TYPE_PTR; p = p.Next {
for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
}
return name, nil
}
const dnsSectionMask = 0x0003
// returns only results applicable to name and resolves CNAME entries
func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
cname := syscall.StringToUTF16Ptr(name)
if dnstype != syscall.DNS_TYPE_CNAME {
cname = resolveCNAME(cname, r)
}
rec := make([]*syscall.DNSRecord, 0, 10)
for p := r; p != nil; p = p.Next {
if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
continue
}
if p.Type != dnstype {
continue
}
if !syscall.DnsNameCompare(cname, p.Name) {
continue
}
rec = append(rec, p)
}
return rec
}
// returns the last CNAME in chain
func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
// limit cname resolving to 10 in case of a infinite CNAME loop
Cname:
for cnameloop := 0; cnameloop < 10; cnameloop++ {
for p := r; p != nil; p = p.Next {
if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
continue
}
if p.Type != syscall.DNS_TYPE_CNAME {
continue
}
if !syscall.DnsNameCompare(name, p.Name) {
continue
}
name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
continue Cname
}
break
}
return name
}

View File

@ -0,0 +1,243 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"bytes"
"encoding/json"
"errors"
"os/exec"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"testing"
)
var nslookupTestServers = []string{"mail.golang.com", "gmail.com"}
func toJson(v interface{}) string {
data, _ := json.Marshal(v)
return string(data)
}
func TestLookupMX(t *testing.T) {
if testing.Short() || !*testExternal {
t.Skip("skipping test to avoid external network")
}
for _, server := range nslookupTestServers {
mx, err := LookupMX(server)
if err != nil {
t.Errorf("failed %s: %s", server, err)
continue
}
if len(mx) == 0 {
t.Errorf("no results")
continue
}
expected, err := nslookupMX(server)
if err != nil {
t.Logf("skipping failed nslookup %s test: %s", server, err)
}
sort.Sort(byPrefAndHost(expected))
sort.Sort(byPrefAndHost(mx))
if !reflect.DeepEqual(expected, mx) {
t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(mx))
}
}
}
func TestLookupCNAME(t *testing.T) {
if testing.Short() || !*testExternal {
t.Skip("skipping test to avoid external network")
}
for _, server := range nslookupTestServers {
cname, err := LookupCNAME(server)
if err != nil {
t.Errorf("failed %s: %s", server, err)
continue
}
if cname == "" {
t.Errorf("no result %s", server)
}
expected, err := nslookupCNAME(server)
if err != nil {
t.Logf("skipping failed nslookup %s test: %s", server, err)
continue
}
if expected != cname {
t.Errorf("different results %s:\texp:%v\tgot:%v", server, expected, cname)
}
}
}
func TestLookupNS(t *testing.T) {
if testing.Short() || !*testExternal {
t.Skip("skipping test to avoid external network")
}
for _, server := range nslookupTestServers {
ns, err := LookupNS(server)
if err != nil {
t.Errorf("failed %s: %s", server, err)
continue
}
if len(ns) == 0 {
t.Errorf("no results")
continue
}
expected, err := nslookupNS(server)
if err != nil {
t.Logf("skipping failed nslookup %s test: %s", server, err)
continue
}
sort.Sort(byHost(expected))
sort.Sort(byHost(ns))
if !reflect.DeepEqual(expected, ns) {
t.Errorf("different results %s:\texp:%v\tgot:%v", toJson(server), toJson(expected), ns)
}
}
}
func TestLookupTXT(t *testing.T) {
if testing.Short() || !*testExternal {
t.Skip("skipping test to avoid external network")
}
for _, server := range nslookupTestServers {
txt, err := LookupTXT(server)
if err != nil {
t.Errorf("failed %s: %s", server, err)
continue
}
if len(txt) == 0 {
t.Errorf("no results")
continue
}
expected, err := nslookupTXT(server)
if err != nil {
t.Logf("skipping failed nslookup %s test: %s", server, err)
continue
}
sort.Strings(expected)
sort.Strings(txt)
if !reflect.DeepEqual(expected, txt) {
t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(txt))
}
}
}
type byPrefAndHost []*MX
func (s byPrefAndHost) Len() int { return len(s) }
func (s byPrefAndHost) Less(i, j int) bool {
if s[i].Pref != s[j].Pref {
return s[i].Pref < s[j].Pref
}
return s[i].Host < s[j].Host
}
func (s byPrefAndHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type byHost []*NS
func (s byHost) Len() int { return len(s) }
func (s byHost) Less(i, j int) bool { return s[i].Host < s[j].Host }
func (s byHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func fqdn(s string) string {
if len(s) == 0 || s[len(s)-1] != '.' {
return s + "."
}
return s
}
func nslookup(qtype, name string) (string, error) {
var out bytes.Buffer
var err bytes.Buffer
cmd := exec.Command("nslookup", "-querytype="+qtype, name)
cmd.Stdout = &out
cmd.Stderr = &err
if err := cmd.Run(); err != nil {
return "", err
}
r := strings.Replace(out.String(), "\r\n", "\n", -1)
// nslookup stderr output contains also debug information such as
// "Non-authoritative answer" and it doesn't return the correct errcode
if strings.Contains(err.String(), "can't find") {
return r, errors.New(err.String())
}
return r, nil
}
func nslookupMX(name string) (mx []*MX, err error) {
var r string
if r, err = nslookup("mx", name); err != nil {
return
}
mx = make([]*MX, 0, 10)
// linux nslookup syntax
// golang.org mail exchanger = 2 alt1.aspmx.l.google.com.
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
pref, _ := strconv.Atoi(ans[2])
mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)})
}
// windows nslookup syntax
// gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com
rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
pref, _ := strconv.Atoi(ans[2])
mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)})
}
return
}
func nslookupNS(name string) (ns []*NS, err error) {
var r string
if r, err = nslookup("ns", name); err != nil {
return
}
ns = make([]*NS, 0, 10)
// golang.org nameserver = ns1.google.com.
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
ns = append(ns, &NS{fqdn(ans[2])})
}
return
}
func nslookupCNAME(name string) (cname string, err error) {
var r string
if r, err = nslookup("cname", name); err != nil {
return
}
// mail.golang.com canonical name = golang.org.
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+canonical name\s*=\s*([a-z0-9.\-]+)$`)
// assumes the last CNAME is the correct one
last := name
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
last = ans[2]
}
return fqdn(last), nil
}
func nslookupTXT(name string) (txt []string, err error) {
var r string
if r, err = nslookup("txt", name); err != nil {
return
}
txt = make([]string, 0, 10)
// linux
// golang.org text = "v=spf1 redirect=_spf.google.com"
// windows
// golang.org text =
//
// "v=spf1 redirect=_spf.google.com"
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+text\s*=\s*"(.*)"$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
txt = append(txt, ans[2])
}
return
}

View File

@ -549,6 +549,7 @@ const socket_error = uintptr(^uint32(0))
//sys GetProtoByName(name string) (p *Protoent, err error) [failretval==nil] = ws2_32.getprotobyname
//sys DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status error) = dnsapi.DnsQuery_W
//sys DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree
//sys DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) = dnsapi.DnsNameCompare_W
//sys GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) = ws2_32.GetAddrInfoW
//sys FreeAddrInfoW(addrinfo *AddrinfoW) = ws2_32.FreeAddrInfoW
//sys GetIfEntry(pIfRow *MibIfRow) (errcode error) = iphlpapi.GetIfEntry

View File

@ -1,4 +1,4 @@
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_386.go
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package syscall
@ -139,6 +139,7 @@ var (
procgetprotobyname = modws2_32.NewProc("getprotobyname")
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
procDnsNameCompare_W = moddnsapi.NewProc("DnsNameCompare_W")
procGetAddrInfoW = modws2_32.NewProc("GetAddrInfoW")
procFreeAddrInfoW = modws2_32.NewProc("FreeAddrInfoW")
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) {
return
}
func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) {
r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0)
same = r0 != 0
return
}
func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) {
r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0)
if r0 != 0 {

View File

@ -1,4 +1,4 @@
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_amd64.go
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package syscall
@ -139,6 +139,7 @@ var (
procgetprotobyname = modws2_32.NewProc("getprotobyname")
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
procDnsNameCompare_W = moddnsapi.NewProc("DnsNameCompare_W")
procGetAddrInfoW = modws2_32.NewProc("GetAddrInfoW")
procFreeAddrInfoW = modws2_32.NewProc("FreeAddrInfoW")
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) {
return
}
func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) {
r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0)
same = r0 != 0
return
}
func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) {
r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0)
if r0 != 0 {

View File

@ -689,6 +689,18 @@ const (
DNS_TYPE_NBSTAT = 0xff01
)
const (
DNS_INFO_NO_RECORDS = 0x251D
)
const (
// flags inside DNSRecord.Dw
DnsSectionQuestion = 0x0000
DnsSectionAnswer = 0x0001
DnsSectionAuthority = 0x0002
DnsSectionAdditional = 0x0003
)
type DNSSRVData struct {
Target *uint16
Priority uint16