diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go index 130364231d4..6a925b0a7ad 100644 --- a/src/pkg/net/lookup_windows.go +++ b/src/pkg/net/lookup_windows.go @@ -210,14 +210,21 @@ func lookupCNAME(name string) (cname string, err error) { defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) + // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s + if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS { + // if there are no aliases, the canonical name is the input name + if name == "" || name[len(name)-1] != '.' { + return name + ".", nil + } + return name, nil + } if e != nil { return "", os.NewSyscallError("LookupCNAME", e) } defer syscall.DnsRecordListFree(r, 1) - if r != nil && r.Type == syscall.DNS_TYPE_CNAME { - v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])) - cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "." - } + + resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r) + cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "." return } @@ -236,8 +243,9 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err return "", nil, os.NewSyscallError("LookupSRV", e) } defer syscall.DnsRecordListFree(r, 1) + addrs = make([]*SRV, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) { v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight}) } @@ -254,8 +262,9 @@ func lookupMX(name string) (mx []*MX, err error) { return nil, os.NewSyscallError("LookupMX", e) } defer syscall.DnsRecordListFree(r, 1) + mx = make([]*MX, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_MX; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) { v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference}) } @@ -272,8 +281,9 @@ func lookupNS(name string) (ns []*NS, err error) { return nil, os.NewSyscallError("LookupNS", e) } defer syscall.DnsRecordListFree(r, 1) + ns = make([]*NS, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."}) } @@ -289,9 +299,10 @@ func lookupTXT(name string) (txt []string, err error) { return nil, os.NewSyscallError("LookupTXT", e) } defer syscall.DnsRecordListFree(r, 1) + txt = make([]string, 0, 10) - if r != nil && r.Type == syscall.DNS_TYPE_TEXT { - d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0])) + for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) { + d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0])) for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] { s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:]) txt = append(txt, s) @@ -313,10 +324,58 @@ func lookupAddr(addr string) (name []string, err error) { return nil, os.NewSyscallError("LookupAddr", e) } defer syscall.DnsRecordListFree(r, 1) + name = make([]string, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_PTR; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])) } return name, nil } + +const dnsSectionMask = 0x0003 + +// returns only results applicable to name and resolves CNAME entries +func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord { + cname := syscall.StringToUTF16Ptr(name) + if dnstype != syscall.DNS_TYPE_CNAME { + cname = resolveCNAME(cname, r) + } + rec := make([]*syscall.DNSRecord, 0, 10) + for p := r; p != nil; p = p.Next { + if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer { + continue + } + if p.Type != dnstype { + continue + } + if !syscall.DnsNameCompare(cname, p.Name) { + continue + } + rec = append(rec, p) + } + return rec +} + +// returns the last CNAME in chain +func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 { + // limit cname resolving to 10 in case of a infinite CNAME loop +Cname: + for cnameloop := 0; cnameloop < 10; cnameloop++ { + for p := r; p != nil; p = p.Next { + if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer { + continue + } + if p.Type != syscall.DNS_TYPE_CNAME { + continue + } + if !syscall.DnsNameCompare(name, p.Name) { + continue + } + name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host + continue Cname + } + break + } + return name +} diff --git a/src/pkg/net/lookup_windows_test.go b/src/pkg/net/lookup_windows_test.go new file mode 100644 index 00000000000..7495b5b5781 --- /dev/null +++ b/src/pkg/net/lookup_windows_test.go @@ -0,0 +1,243 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "encoding/json" + "errors" + "os/exec" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "testing" +) + +var nslookupTestServers = []string{"mail.golang.com", "gmail.com"} + +func toJson(v interface{}) string { + data, _ := json.Marshal(v) + return string(data) +} + +func TestLookupMX(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + for _, server := range nslookupTestServers { + mx, err := LookupMX(server) + if err != nil { + t.Errorf("failed %s: %s", server, err) + continue + } + if len(mx) == 0 { + t.Errorf("no results") + continue + } + expected, err := nslookupMX(server) + if err != nil { + t.Logf("skipping failed nslookup %s test: %s", server, err) + } + sort.Sort(byPrefAndHost(expected)) + sort.Sort(byPrefAndHost(mx)) + if !reflect.DeepEqual(expected, mx) { + t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(mx)) + } + } +} + +func TestLookupCNAME(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + for _, server := range nslookupTestServers { + cname, err := LookupCNAME(server) + if err != nil { + t.Errorf("failed %s: %s", server, err) + continue + } + if cname == "" { + t.Errorf("no result %s", server) + } + expected, err := nslookupCNAME(server) + if err != nil { + t.Logf("skipping failed nslookup %s test: %s", server, err) + continue + } + if expected != cname { + t.Errorf("different results %s:\texp:%v\tgot:%v", server, expected, cname) + } + } +} + +func TestLookupNS(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + for _, server := range nslookupTestServers { + ns, err := LookupNS(server) + if err != nil { + t.Errorf("failed %s: %s", server, err) + continue + } + if len(ns) == 0 { + t.Errorf("no results") + continue + } + expected, err := nslookupNS(server) + if err != nil { + t.Logf("skipping failed nslookup %s test: %s", server, err) + continue + } + sort.Sort(byHost(expected)) + sort.Sort(byHost(ns)) + if !reflect.DeepEqual(expected, ns) { + t.Errorf("different results %s:\texp:%v\tgot:%v", toJson(server), toJson(expected), ns) + } + } +} + +func TestLookupTXT(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + for _, server := range nslookupTestServers { + txt, err := LookupTXT(server) + if err != nil { + t.Errorf("failed %s: %s", server, err) + continue + } + if len(txt) == 0 { + t.Errorf("no results") + continue + } + expected, err := nslookupTXT(server) + if err != nil { + t.Logf("skipping failed nslookup %s test: %s", server, err) + continue + } + sort.Strings(expected) + sort.Strings(txt) + if !reflect.DeepEqual(expected, txt) { + t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(txt)) + } + } +} + +type byPrefAndHost []*MX + +func (s byPrefAndHost) Len() int { return len(s) } +func (s byPrefAndHost) Less(i, j int) bool { + if s[i].Pref != s[j].Pref { + return s[i].Pref < s[j].Pref + } + return s[i].Host < s[j].Host +} +func (s byPrefAndHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type byHost []*NS + +func (s byHost) Len() int { return len(s) } +func (s byHost) Less(i, j int) bool { return s[i].Host < s[j].Host } +func (s byHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func fqdn(s string) string { + if len(s) == 0 || s[len(s)-1] != '.' { + return s + "." + } + return s +} + +func nslookup(qtype, name string) (string, error) { + var out bytes.Buffer + var err bytes.Buffer + cmd := exec.Command("nslookup", "-querytype="+qtype, name) + cmd.Stdout = &out + cmd.Stderr = &err + if err := cmd.Run(); err != nil { + return "", err + } + r := strings.Replace(out.String(), "\r\n", "\n", -1) + // nslookup stderr output contains also debug information such as + // "Non-authoritative answer" and it doesn't return the correct errcode + if strings.Contains(err.String(), "can't find") { + return r, errors.New(err.String()) + } + return r, nil +} + +func nslookupMX(name string) (mx []*MX, err error) { + var r string + if r, err = nslookup("mx", name); err != nil { + return + } + mx = make([]*MX, 0, 10) + // linux nslookup syntax + // golang.org mail exchanger = 2 alt1.aspmx.l.google.com. + rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`) + for _, ans := range rx.FindAllStringSubmatch(r, -1) { + pref, _ := strconv.Atoi(ans[2]) + mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)}) + } + // windows nslookup syntax + // gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com + rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`) + for _, ans := range rx.FindAllStringSubmatch(r, -1) { + pref, _ := strconv.Atoi(ans[2]) + mx = append(mx, &MX{fqdn(ans[3]), uint16(pref)}) + } + return +} + +func nslookupNS(name string) (ns []*NS, err error) { + var r string + if r, err = nslookup("ns", name); err != nil { + return + } + ns = make([]*NS, 0, 10) + // golang.org nameserver = ns1.google.com. + rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`) + for _, ans := range rx.FindAllStringSubmatch(r, -1) { + ns = append(ns, &NS{fqdn(ans[2])}) + } + return +} + +func nslookupCNAME(name string) (cname string, err error) { + var r string + if r, err = nslookup("cname", name); err != nil { + return + } + // mail.golang.com canonical name = golang.org. + rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+canonical name\s*=\s*([a-z0-9.\-]+)$`) + // assumes the last CNAME is the correct one + last := name + for _, ans := range rx.FindAllStringSubmatch(r, -1) { + last = ans[2] + } + return fqdn(last), nil +} + +func nslookupTXT(name string) (txt []string, err error) { + var r string + if r, err = nslookup("txt", name); err != nil { + return + } + txt = make([]string, 0, 10) + // linux + // golang.org text = "v=spf1 redirect=_spf.google.com" + + // windows + // golang.org text = + // + // "v=spf1 redirect=_spf.google.com" + rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+text\s*=\s*"(.*)"$`) + for _, ans := range rx.FindAllStringSubmatch(r, -1) { + txt = append(txt, ans[2]) + } + return +} diff --git a/src/pkg/syscall/syscall_windows.go b/src/pkg/syscall/syscall_windows.go index 1fe1ae0fab4..32a7aed0018 100644 --- a/src/pkg/syscall/syscall_windows.go +++ b/src/pkg/syscall/syscall_windows.go @@ -549,6 +549,7 @@ const socket_error = uintptr(^uint32(0)) //sys GetProtoByName(name string) (p *Protoent, err error) [failretval==nil] = ws2_32.getprotobyname //sys DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status error) = dnsapi.DnsQuery_W //sys DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree +//sys DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) = dnsapi.DnsNameCompare_W //sys GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) = ws2_32.GetAddrInfoW //sys FreeAddrInfoW(addrinfo *AddrinfoW) = ws2_32.FreeAddrInfoW //sys GetIfEntry(pIfRow *MibIfRow) (errcode error) = iphlpapi.GetIfEntry diff --git a/src/pkg/syscall/zsyscall_windows_386.go b/src/pkg/syscall/zsyscall_windows_386.go index d55211ee754..1f44750b7f3 100644 --- a/src/pkg/syscall/zsyscall_windows_386.go +++ b/src/pkg/syscall/zsyscall_windows_386.go @@ -1,4 +1,4 @@ -// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_386.go +// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT package syscall @@ -139,6 +139,7 @@ var ( procgetprotobyname = modws2_32.NewProc("getprotobyname") procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W") procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree") + procDnsNameCompare_W = moddnsapi.NewProc("DnsNameCompare_W") procGetAddrInfoW = modws2_32.NewProc("GetAddrInfoW") procFreeAddrInfoW = modws2_32.NewProc("FreeAddrInfoW") procGetIfEntry = modiphlpapi.NewProc("GetIfEntry") @@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) { return } +func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) { + r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0) + same = r0 != 0 + return +} + func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) { r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0) if r0 != 0 { diff --git a/src/pkg/syscall/zsyscall_windows_amd64.go b/src/pkg/syscall/zsyscall_windows_amd64.go index 47affab73da..1f44750b7f3 100644 --- a/src/pkg/syscall/zsyscall_windows_amd64.go +++ b/src/pkg/syscall/zsyscall_windows_amd64.go @@ -1,4 +1,4 @@ -// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go syscall_windows_amd64.go +// go build mksyscall_windows.go && ./mksyscall_windows syscall_windows.go security_windows.go // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT package syscall @@ -139,6 +139,7 @@ var ( procgetprotobyname = modws2_32.NewProc("getprotobyname") procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W") procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree") + procDnsNameCompare_W = moddnsapi.NewProc("DnsNameCompare_W") procGetAddrInfoW = modws2_32.NewProc("GetAddrInfoW") procFreeAddrInfoW = modws2_32.NewProc("FreeAddrInfoW") procGetIfEntry = modiphlpapi.NewProc("GetIfEntry") @@ -1634,6 +1635,12 @@ func DnsRecordListFree(rl *DNSRecord, freetype uint32) { return } +func DnsNameCompare(name1 *uint16, name2 *uint16) (same bool) { + r0, _, _ := Syscall(procDnsNameCompare_W.Addr(), 2, uintptr(unsafe.Pointer(name1)), uintptr(unsafe.Pointer(name2)), 0) + same = r0 != 0 + return +} + func GetAddrInfoW(nodename *uint16, servicename *uint16, hints *AddrinfoW, result **AddrinfoW) (sockerr error) { r0, _, _ := Syscall6(procGetAddrInfoW.Addr(), 4, uintptr(unsafe.Pointer(nodename)), uintptr(unsafe.Pointer(servicename)), uintptr(unsafe.Pointer(hints)), uintptr(unsafe.Pointer(result)), 0, 0) if r0 != 0 { diff --git a/src/pkg/syscall/ztypes_windows.go b/src/pkg/syscall/ztypes_windows.go index 8b3625f1467..1363da01a88 100644 --- a/src/pkg/syscall/ztypes_windows.go +++ b/src/pkg/syscall/ztypes_windows.go @@ -689,6 +689,18 @@ const ( DNS_TYPE_NBSTAT = 0xff01 ) +const ( + DNS_INFO_NO_RECORDS = 0x251D +) + +const ( + // flags inside DNSRecord.Dw + DnsSectionQuestion = 0x0000 + DnsSectionAnswer = 0x0001 + DnsSectionAuthority = 0x0002 + DnsSectionAdditional = 0x0003 +) + type DNSSRVData struct { Target *uint16 Priority uint16