mirror of
https://github.com/golang/go
synced 2024-11-23 14:20:05 -07:00
net: fix CNAME resolving on Windows
Fixes #8492 LGTM=alex.brainman R=golang-codereviews, alex.brainman CC=golang-codereviews https://golang.org/cl/122200043
This commit is contained in:
parent
0235f6854c
commit
a18a360379
@ -210,14 +210,21 @@ func lookupCNAME(name string) (cname string, err error) {
|
||||
defer releaseThread()
|
||||
var r *syscall.DNSRecord
|
||||
e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
|
||||
// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
|
||||
if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
|
||||
// if there are no aliases, the canonical name is the input name
|
||||
if name == "" || name[len(name)-1] != '.' {
|
||||
return name + ".", nil
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
if e != nil {
|
||||
return "", os.NewSyscallError("LookupCNAME", e)
|
||||
}
|
||||
defer syscall.DnsRecordListFree(r, 1)
|
||||
if r != nil && r.Type == syscall.DNS_TYPE_CNAME {
|
||||
v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0]))
|
||||
cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."
|
||||
}
|
||||
|
||||
resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
|
||||
cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
|
||||
return
|
||||
}
|
||||
|
||||
@ -236,8 +243,9 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
|
||||
return "", nil, os.NewSyscallError("LookupSRV", e)
|
||||
}
|
||||
defer syscall.DnsRecordListFree(r, 1)
|
||||
|
||||
addrs = make([]*SRV, 0, 10)
|
||||
for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next {
|
||||
for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
|
||||
v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
|
||||
addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
|
||||
}
|
||||
@ -254,8 +262,9 @@ func lookupMX(name string) (mx []*MX, err error) {
|
||||
return nil, os.NewSyscallError("LookupMX", e)
|
||||
}
|
||||
defer syscall.DnsRecordListFree(r, 1)
|
||||
|
||||
mx = make([]*MX, 0, 10)
|
||||
for p := r; p != nil && p.Type == syscall.DNS_TYPE_MX; p = p.Next {
|
||||
for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
|
||||
v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
|
||||
mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
|
||||
}
|
||||
@ -272,8 +281,9 @@ func lookupNS(name string) (ns []*NS, err error) {
|
||||
return nil, os.NewSyscallError("LookupNS", e)
|
||||
}
|
||||
defer syscall.DnsRecordListFree(r, 1)
|
||||
|
||||
ns = make([]*NS, 0, 10)
|
||||
for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next {
|
||||
for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
|
||||
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
|
||||
ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
|
||||
}
|
||||
@ -289,9 +299,10 @@ func lookupTXT(name string) (txt []string, err error) {
|
||||
return nil, os.NewSyscallError("LookupTXT", e)
|
||||
}
|
||||
defer syscall.DnsRecordListFree(r, 1)
|
||||
|
||||
txt = make([]string, 0, 10)
|
||||
if r != nil && r.Type == syscall.DNS_TYPE_TEXT {
|
||||
d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0]))
|
||||
for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
|
||||
d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
|
||||
for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
|
||||
s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
|
||||
txt = append(txt, s)
|
||||
@ -313,10 +324,58 @@ func lookupAddr(addr string) (name []string, err error) {
|
||||
return nil, os.NewSyscallError("LookupAddr", e)
|
||||
}
|
||||
defer syscall.DnsRecordListFree(r, 1)
|
||||
|
||||
name = make([]string, 0, 10)
|
||||
for p := r; p != nil && p.Type == syscall.DNS_TYPE_PTR; p = p.Next {
|
||||
for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
|
||||
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
|
||||
name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
const dnsSectionMask = 0x0003
|
||||
|
||||
// returns only results applicable to name and resolves CNAME entries
|
||||
func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
|
||||
cname := syscall.StringToUTF16Ptr(name)
|
||||
if dnstype != syscall.DNS_TYPE_CNAME {
|
||||
cname = resolveCNAME(cname, r)
|
||||
}
|
||||
rec := make([]*syscall.DNSRecord, 0, 10)
|
||||
for p := r; p != nil; p = p.Next {
|
||||
if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
|
||||
continue
|
||||
}
|
||||
if p.Type != dnstype {
|
||||
continue
|
||||
}
|
||||
if !syscall.DnsNameCompare(cname, p.Name) {
|
||||
continue
|
||||
}
|
||||
rec = append(rec, p)
|
||||
}
|
||||
return rec
|
||||
}
|
||||
|
||||
// returns the last CNAME in chain
|
||||
func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
|
||||
// limit cname resolving to 10 in case of a infinite CNAME loop
|
||||
Cname:
|
||||
for cnameloop := 0; cnameloop < 10; cnameloop++ {
|
||||
for p := r; p != nil; p = p.Next {
|
||||
if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
|
||||
continue
|
||||
}
|
||||
if p.Type != syscall.DNS_TYPE_CNAME {
|
||||
continue
|
||||
}
|
||||
if !syscall.DnsNameCompare(name, p.Name) {
|
||||
continue
|
||||
}
|
||||
name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
|
||||
continue Cname
|
||||
}
|
||||
break
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
243
src/pkg/net/lookup_windows_test.go
Normal file
243
src/pkg/net/lookup_windows_test.go
Normal file
@ -0,0 +1,243 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var nslookupTestServers = []string{"mail.golang.com", "gmail.com"}
|
||||
|
||||
func toJson(v interface{}) string {
|
||||
data, _ := json.Marshal(v)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func TestLookupMX(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("skipping test to avoid external network")
|
||||
}
|
||||
for _, server := range nslookupTestServers {
|
||||
mx, err := LookupMX(server)
|
||||
if err != nil {
|
||||
t.Errorf("failed %s: %s", server, err)
|
||||
continue
|
||||
}
|
||||
if len(mx) == 0 {
|
||||
t.Errorf("no results")
|
||||
continue
|
||||
}
|
||||
expected, err := nslookupMX(server)
|
||||
if err != nil {
|
||||
t.Logf("skipping failed nslookup %s test: %s", server, err)
|
||||
}
|
||||
sort.Sort(byPrefAndHost(expected))
|
||||
sort.Sort(byPrefAndHost(mx))
|
||||
if !reflect.DeepEqual(expected, mx) {
|
||||
t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(mx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCNAME(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("skipping test to avoid external network")
|
||||
}
|
||||
for _, server := range nslookupTestServers {
|
||||
cname, err := LookupCNAME(server)
|
||||
if err != nil {
|
||||
t.Errorf("failed %s: %s", server, err)
|
||||
continue
|
||||
}
|
||||
if cname == "" {
|
||||
t.Errorf("no result %s", server)
|
||||
}
|
||||
expected, err := nslookupCNAME(server)
|
||||
if err != nil {
|
||||
t.Logf("skipping failed nslookup %s test: %s", server, err)
|
||||
continue
|
||||
}
|
||||
if expected != cname {
|
||||
t.Errorf("different results %s:\texp:%v\tgot:%v", server, expected, cname)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupNS(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("skipping test to avoid external network")
|
||||
}
|
||||
for _, server := range nslookupTestServers {
|
||||
ns, err := LookupNS(server)
|
||||
if err != nil {
|
||||
t.Errorf("failed %s: %s", server, err)
|
||||
continue
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
t.Errorf("no results")
|
||||
continue
|
||||
}
|
||||
expected, err := nslookupNS(server)
|
||||
if err != nil {
|
||||
t.Logf("skipping failed nslookup %s test: %s", server, err)
|
||||
continue
|
||||
}
|
||||
sort.Sort(byHost(expected))
|
||||
sort.Sort(byHost(ns))
|
||||
if !reflect.DeepEqual(expected, ns) {
|
||||
t.Errorf("different results %s:\texp:%v\tgot:%v", toJson(server), toJson(expected), ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupTXT(t *testing.T) {
|
||||
if testing.Short() || !*testExternal {
|
||||
t.Skip("skipping test to avoid external network")
|
||||
}
|
||||
for _, server := range nslookupTestServers {
|
||||
txt, err := LookupTXT(server)
|
||||
if err != nil {
|
||||
t.Errorf("failed %s: %s", server, err)
|
||||
continue
|
||||
}
|
||||
if len(txt) == 0 {
|
||||
t.Errorf("no results")
|
||||
continue
|
||||
}
|
||||
expected, err := nslookupTXT(server)
|
||||
if err != nil {
|
||||
t.Logf("skipping failed nslookup %s test: %s", server, err)
|
||||
continue
|
||||
}
|
||||
sort.Strings(expected)
|
||||
sort.Strings(txt)
|
||||
if !reflect.DeepEqual(expected, txt) {
|
||||
t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(txt))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type byPrefAndHost []*MX
|
||||
|
||||
func (s byPrefAndHost) Len() int { return len(s) }
|
||||
func (s byPrefAndHost) Less(i, j int) bool {
|
||||
if s[i].Pref != s[j].Pref {
|
||||
return s[i].Pref < s[j].Pref
|
||||
}
|
||||
return s[i].Host < s[j].Host
|
||||
}
|
||||
func (s byPrefAndHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
type byHost []*NS
|
||||
|
||||
func (s byHost) Len() int { return len(s) }
|
||||
func (s byHost) Less(i, j int) bool { return s[i].Host < s[j].Host }
|
||||
func (s byHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
func fqdn(s string) string {
|
||||
if len(s) == 0 || s[len(s)-1] != '.' {
|
||||
return s + "."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func nslookup(qtype, name string) (string, error) {
|
||||
var out bytes.Buffer
|
||||
var err bytes.Buffer
|
||||
cmd := exec.Command("nslookup", "-querytype="+qtype, name)
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &err
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
r := strings.Replace(out.String(), "\r\n", "\n", -1)
|
||||
// nslookup stderr output contains also debug information such as
|
||||
// "Non-authoritative answer" and it doesn't return the correct errcode
|
||||
if strings.Contains(err.String(), "can't find") {
|
||||
return r, errors.New(err.String())
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func nslookupMX(name string) (mx []*MX, err error) {
|
||||
var r string
|
||||
if r, err = nslookup("mx", name); err != nil {
|
||||
return
|
||||
}
|
||||
mx = make([]*MX, 0, 10)
|
||||
// linux nslookup syntax
|
||||
// golang.org mail exchanger = 2 alt1.aspmx.l.google.com.
|
||||
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
|
||||
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
|
||||
pref, _ := strconv.Atoi(ans[2])
|
||||
mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)})
|
||||
}
|
||||
// windows nslookup syntax
|
||||
// gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com
|
||||
rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`)
|
||||
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
|
||||
pref, _ := strconv.Atoi(ans[2])
|
||||
mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nslookupNS(name string) (ns []*NS, err error) {
|
||||
var r string
|
||||
if r, err = nslookup("ns", name); err != nil {
|
||||
return
|
||||
}
|
||||
ns = make([]*NS, 0, 10)
|
||||
// golang.org nameserver = ns1.google.com.
|
||||
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`)
|
||||
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
|
||||
ns = append(ns, &NS{fqdn(ans[2])})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nslookupCNAME(name string) (cname string, err error) {
|
||||
var r string
|
||||
if r, err = nslookup("cname", name); err != nil {
|
||||
return
|
||||
}
|
||||
// mail.golang.com canonical name = golang.org.
|
||||
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+canonical name\s*=\s*([a-z0-9.\-]+)$`)
|
||||
// assumes the last CNAME is the correct one
|
||||
last := name
|
||||
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
|
||||
last = ans[2]
|
||||
}
|
||||
return fqdn(last), nil
|
||||
}
|
||||
|
||||
func nslookupTXT(name string) (txt []string, err error) {
|
||||
var r string
|
||||
if r, err = nslookup("txt", name); err != nil {
|
||||
return
|
||||
}
|
||||
txt = make([]string, 0, 10)
|
||||
// linux
|
||||
// golang.org text = "v=spf1 redirect=_spf.google.com"
|
||||
|
||||
// windows
|
||||
// golang.org text =
|
||||
//
|
||||
// "v=spf1 redirect=_spf.google.com"
|
||||
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+text\s*=\s*"(.*)"$`)
|
||||
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
|
||||
txt = append(txt, ans[2])
|
||||
}
|
||||
return
|
||||
}
|
@ -549,6 +549,7 @@ const socket_error = uintptr(^uint32(0))
|
||||
//sys GetProtoByName(name string) (p *Protoent, err error) [failretval==nil] = ws2_32.getprotobyname
|
||||
//sys DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status error) = dnsapi.DnsQuery_W
|
||||
//sys DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree
|
||||
//sys DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) = dnsapi.DnsNameCompare_W
|
||||
//sys GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) = ws2_32.GetAddrInfoW
|
||||
//sys FreeAddrInfoW(addrinfo *AddrinfoW) = ws2_32.FreeAddrInfoW
|
||||
//sys GetIfEntry(pIfRow *MibIfRow) (errcode error) = iphlpapi.GetIfEntry
|
||||
|
@ -1,4 +1,4 @@
|
||||
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_386.go
|
||||
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go
|
||||
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
|
||||
|
||||
package syscall
|
||||
@ -139,6 +139,7 @@ var (
|
||||
procgetprotobyname = modws2_32.NewProc("getprotobyname")
|
||||
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
|
||||
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
|
||||
procDnsNameCompare_W = moddnsapi.NewProc("DnsNameCompare_W")
|
||||
procGetAddrInfoW = modws2_32.NewProc("GetAddrInfoW")
|
||||
procFreeAddrInfoW = modws2_32.NewProc("FreeAddrInfoW")
|
||||
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
|
||||
@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) {
|
||||
return
|
||||
}
|
||||
|
||||
func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) {
|
||||
r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0)
|
||||
same = r0 != 0
|
||||
return
|
||||
}
|
||||
|
||||
func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) {
|
||||
r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0)
|
||||
if r0 != 0 {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_amd64.go
|
||||
// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go
|
||||
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
|
||||
|
||||
package syscall
|
||||
@ -139,6 +139,7 @@ var (
|
||||
procgetprotobyname = modws2_32.NewProc("getprotobyname")
|
||||
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
|
||||
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
|
||||
procDnsNameCompare_W = moddnsapi.NewProc("DnsNameCompare_W")
|
||||
procGetAddrInfoW = modws2_32.NewProc("GetAddrInfoW")
|
||||
procFreeAddrInfoW = modws2_32.NewProc("FreeAddrInfoW")
|
||||
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
|
||||
@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) {
|
||||
return
|
||||
}
|
||||
|
||||
func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) {
|
||||
r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0)
|
||||
same = r0 != 0
|
||||
return
|
||||
}
|
||||
|
||||
func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) {
|
||||
r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0)
|
||||
if r0 != 0 {
|
||||
|
@ -689,6 +689,18 @@ const (
|
||||
DNS_TYPE_NBSTAT = 0xff01
|
||||
)
|
||||
|
||||
const (
|
||||
DNS_INFO_NO_RECORDS = 0x251D
|
||||
)
|
||||
|
||||
const (
|
||||
// flags inside DNSRecord.Dw
|
||||
DnsSectionQuestion = 0x0000
|
||||
DnsSectionAnswer = 0x0001
|
||||
DnsSectionAuthority = 0x0002
|
||||
DnsSectionAdditional = 0x0003
|
||||
)
|
||||
|
||||
type DNSSRVData struct {
|
||||
Target *uint16
|
||||
Priority uint16
|
||||
|
Loading…
Reference in New Issue
Block a user