From fae362e97e852cf04c6c089e61e92c1ad559b29b Mon Sep 17 00:00:00 2001 From: Alexey Borzenkov Date: Wed, 15 May 2013 13:24:54 +1000 Subject: [PATCH] os/user: faster user lookup on Windows Trying to lookup user's display name with directory services can take several seconds when user's computer is not in a domain. As a workaround, check if computer is joined in a domain first, and don't use directory services if it is not. Additionally, don't leak tokens in user.Current(). Fixes #5298. R=golang-dev, bradfitz, alex.brainman, lucio.dere CC=golang-dev https://golang.org/cl/8541047 --- src/pkg/os/user/lookup_windows.go | 81 +++++++++++++++-------- src/pkg/syscall/security_windows.go | 9 +++ src/pkg/syscall/zsyscall_windows_386.go | 9 +++ src/pkg/syscall/zsyscall_windows_amd64.go | 9 +++ 4 files changed, 81 insertions(+), 27 deletions(-) diff --git a/src/pkg/os/user/lookup_windows.go b/src/pkg/os/user/lookup_windows.go index a0a8a4ec10..99c325ff01 100644 --- a/src/pkg/os/user/lookup_windows.go +++ b/src/pkg/os/user/lookup_windows.go @@ -10,37 +10,63 @@ import ( "unsafe" ) -func lookupFullName(domain, username, domainAndUser string) (string, error) { - // try domain controller first - name, e := syscall.TranslateAccountName(domainAndUser, - syscall.NameSamCompatible, syscall.NameDisplay, 50) - if e != nil { - // domain lookup failed, perhaps this pc is not part of domain - d, e := syscall.UTF16PtrFromString(domain) - if e != nil { - return "", e - } - u, e := syscall.UTF16PtrFromString(username) - if e != nil { - return "", e - } - var p *byte - e = syscall.NetUserGetInfo(d, u, 10, &p) - if e != nil { - // path executed when a domain user is disconnected from the domain - // pretend username is fullname - return username, nil - } - defer syscall.NetApiBufferFree(p) - i := (*syscall.UserInfo10)(unsafe.Pointer(p)) - if i.FullName == nil { - return "", nil - } - name = syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:]) +func isDomainJoined() (bool, error) { + var domain *uint16 + var status uint32 + err := syscall.NetGetJoinInformation(nil, &domain, &status) + if err != nil { + return false, err } + syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + return status == syscall.NetSetupDomainName, nil +} + +func lookupFullNameDomain(domainAndUser string) (string, error) { + return syscall.TranslateAccountName(domainAndUser, + syscall.NameSamCompatible, syscall.NameDisplay, 50) +} + +func lookupFullNameServer(servername, username string) (string, error) { + s, e := syscall.UTF16PtrFromString(servername) + if e != nil { + return "", e + } + u, e := syscall.UTF16PtrFromString(username) + if e != nil { + return "", e + } + var p *byte + e = syscall.NetUserGetInfo(s, u, 10, &p) + if e != nil { + return "", e + } + defer syscall.NetApiBufferFree(p) + i := (*syscall.UserInfo10)(unsafe.Pointer(p)) + if i.FullName == nil { + return "", nil + } + name := syscall.UTF16ToString((*[1024]uint16)(unsafe.Pointer(i.FullName))[:]) return name, nil } +func lookupFullName(domain, username, domainAndUser string) (string, error) { + joined, err := isDomainJoined() + if err == nil && joined { + name, err := lookupFullNameDomain(domainAndUser) + if err == nil { + return name, nil + } + } + name, err := lookupFullNameServer(domain, username) + if err == nil { + return name, nil + } + // domain worked neigher as a domain nor as a server + // could be domain server unavailable + // pretend username is fullname + return username, nil +} + func newUser(usid *syscall.SID, gid, dir string) (*User, error) { username, domain, t, e := usid.LookupAccount("") if e != nil { @@ -73,6 +99,7 @@ func current() (*User, error) { if e != nil { return nil, e } + defer t.Close() u, e := t.GetTokenUser() if e != nil { return nil, e diff --git a/src/pkg/syscall/security_windows.go b/src/pkg/syscall/security_windows.go index 017b270146..b22ecf578e 100644 --- a/src/pkg/syscall/security_windows.go +++ b/src/pkg/syscall/security_windows.go @@ -58,6 +58,14 @@ func TranslateAccountName(username string, from, to uint32, initSize int) (strin return UTF16ToString(b), nil } +const ( + // do not reorder + NetSetupUnknownStatus = iota + NetSetupUnjoined + NetSetupWorkgroupName + NetSetupDomainName +) + type UserInfo10 struct { Name *uint16 Comment *uint16 @@ -66,6 +74,7 @@ type UserInfo10 struct { } //sys NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **byte) (neterr error) = netapi32.NetUserGetInfo +//sys NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) = netapi32.NetGetJoinInformation //sys NetApiBufferFree(buf *byte) (neterr error) = netapi32.NetApiBufferFree const ( diff --git a/src/pkg/syscall/zsyscall_windows_386.go b/src/pkg/syscall/zsyscall_windows_386.go index e5c48488ba..838812a620 100644 --- a/src/pkg/syscall/zsyscall_windows_386.go +++ b/src/pkg/syscall/zsyscall_windows_386.go @@ -140,6 +140,7 @@ var ( procTranslateNameW = modsecur32.NewProc("TranslateNameW") procGetUserNameExW = modsecur32.NewProc("GetUserNameExW") procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo") + procNetGetJoinInformation = modnetapi32.NewProc("NetGetJoinInformation") procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree") procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW") procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") @@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by return } +func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) { + r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType))) + if r0 != 0 { + neterr = Errno(r0) + } + return +} + func NetApiBufferFree(buf *byte) (neterr error) { r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0) if r0 != 0 { diff --git a/src/pkg/syscall/zsyscall_windows_amd64.go b/src/pkg/syscall/zsyscall_windows_amd64.go index 465b509ae7..1b403be495 100644 --- a/src/pkg/syscall/zsyscall_windows_amd64.go +++ b/src/pkg/syscall/zsyscall_windows_amd64.go @@ -140,6 +140,7 @@ var ( procTranslateNameW = modsecur32.NewProc("TranslateNameW") procGetUserNameExW = modsecur32.NewProc("GetUserNameExW") procNetUserGetInfo = modnetapi32.NewProc("NetUserGetInfo") + procNetGetJoinInformation = modnetapi32.NewProc("NetGetJoinInformation") procNetApiBufferFree = modnetapi32.NewProc("NetApiBufferFree") procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW") procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") @@ -1613,6 +1614,14 @@ func NetUserGetInfo(serverName *uint16, userName *uint16, level uint32, buf **by return } +func NetGetJoinInformation(server *uint16, name **uint16, bufType *uint32) (neterr error) { + r0, _, _ := Syscall(procNetGetJoinInformation.Addr(), 3, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(bufType))) + if r0 != 0 { + neterr = Errno(r0) + } + return +} + func NetApiBufferFree(buf *byte) (neterr error) { r0, _, _ := Syscall(procNetApiBufferFree.Addr(), 1, uintptr(unsafe.Pointer(buf)), 0, 0) if r0 != 0 {