From c93f74d34bf45bcbfa1cfda5ccd198ed5682ddf4 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 8 Sep 2014 17:47:12 -0700 Subject: [PATCH] syscall: keep Windows syscall pointers live too Like https://golang.org/cl/139360044 LGTM=rsc, alex.brainman R=alex.brainman, rsc CC=golang-codereviews https://golang.org/cl/138250043 --- src/syscall/mksyscall_windows.go | 24 ++++++++++++++++++++++-- src/syscall/zsyscall_windows.go | 15 +++++++++++---- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/syscall/mksyscall_windows.go b/src/syscall/mksyscall_windows.go index 1cdd6b4d22..ea9ee45511 100644 --- a/src/syscall/mksyscall_windows.go +++ b/src/syscall/mksyscall_windows.go @@ -158,6 +158,7 @@ func (p *Param) SyscallArgList() []string { case p.Type[0] == '*': s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) case p.Type == "string": + p.fn.use(p.tmpVar()) s = fmt.Sprintf("unsafe.Pointer(%s)", p.tmpVar()) case p.Type == "bool": s = p.tmpVar() @@ -303,6 +304,7 @@ type Fn struct { Params []*Param Rets *Rets PrintTrace bool + Used []string dllname string dllfuncname string src string @@ -310,6 +312,15 @@ type Fn struct { curTmpVarIdx int // insure tmp variables have uniq names } +func (f *Fn) use(v string) { + for _, e := range f.Used { + if e == v { + return + } + } + f.Used = append(f.Used, v) +} + // extractParams parses s to extract function parameters. func extractParams(s string, f *Fn) ([]*Param, error) { s = trim(s) @@ -328,7 +339,7 @@ func extractParams(s string, f *Fn) ([]*Param, error) { } } ps[i] = &Param{ - Name: trim(b[0]), + Name: sanitizeName(trim(b[0])), Type: trim(b[1]), fn: f, tmpVarIdx: -1, @@ -337,6 +348,13 @@ func extractParams(s string, f *Fn) ([]*Param, error) { return ps, nil } +func sanitizeName(n string) string { + if n == "use" { + return "use_" + } + return n +} + // extractSection extracts text out of string s starting after start // and ending just before end. found return value will indicate success, // and prefix, body and suffix will contain correspondent parts of string s. @@ -680,7 +698,7 @@ var ( {{define "funcbody"}} func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{ {{template "tmpvars" .}} {{template "syscall" .}} -{{template "seterror" .}}{{template "printtrace" .}} return +{{template "used" .}}{{template "seterror" .}}{{template "printtrace" .}} return } {{end}} @@ -689,6 +707,8 @@ func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{ {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} +{{define "used"}}{{range .Used}}use(unsafe.Pointer({{.}}));{{end}}{{end}} + {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} {{end}}{{end}} diff --git a/src/syscall/zsyscall_windows.go b/src/syscall/zsyscall_windows.go index 1f44750b7f..9f2c84fb1f 100644 --- a/src/syscall/zsyscall_windows.go +++ b/src/syscall/zsyscall_windows.go @@ -177,6 +177,7 @@ func LoadLibrary(libname string) (handle Handle, err error) { return } r0, _, e1 := Syscall(procLoadLibraryW.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) + use(unsafe.Pointer(_p0)) handle = Handle(r0) if handle == 0 { if e1 != 0 { @@ -207,6 +208,7 @@ func GetProcAddress(module Handle, procname string) (proc uintptr, err error) { return } r0, _, e1 := Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(_p0)), 0) + use(unsafe.Pointer(_p0)) proc = uintptr(r0) if proc == 0 { if e1 != 0 { @@ -1559,6 +1561,7 @@ func GetHostByName(name string) (h *Hostent, err error) { return } r0, _, e1 := Syscall(procgethostbyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) + use(unsafe.Pointer(_p0)) h = (*Hostent)(unsafe.Pointer(r0)) if h == nil { if e1 != 0 { @@ -1582,6 +1585,8 @@ func GetServByName(name string, proto string) (s *Servent, err error) { return } r0, _, e1 := Syscall(procgetservbyname.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), 0) + use(unsafe.Pointer(_p0)) + use(unsafe.Pointer(_p1)) s = (*Servent)(unsafe.Pointer(r0)) if s == nil { if e1 != 0 { @@ -1606,6 +1611,7 @@ func GetProtoByName(name string) (p *Protoent, err error) { return } r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) + use(unsafe.Pointer(_p0)) p = (*Protoent)(unsafe.Pointer(r0)) if p == nil { if e1 != 0 { @@ -1624,6 +1630,7 @@ func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSR return } r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(_p0)), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr))) + use(unsafe.Pointer(_p0)) if r0 != 0 { status = Errno(r0) } @@ -1743,8 +1750,8 @@ func NetApiBufferFree(buf *byte) (neterr error) { return } -func LookupAccountSid(systemName *uint16, sid *SID, name *uint16, nameLen *uint32, refdDomainName *uint16, refdDomainNameLen *uint32, use *uint32) (err error) { - r1, _, e1 := Syscall9(procLookupAccountSidW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen)), uintptr(unsafe.Pointer(refdDomainName)), uintptr(unsafe.Pointer(refdDomainNameLen)), uintptr(unsafe.Pointer(use)), 0, 0) +func LookupAccountSid(systemName *uint16, sid *SID, name *uint16, nameLen *uint32, refdDomainName *uint16, refdDomainNameLen *uint32, use_ *uint32) (err error) { + r1, _, e1 := Syscall9(procLookupAccountSidW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen)), uintptr(unsafe.Pointer(refdDomainName)), uintptr(unsafe.Pointer(refdDomainNameLen)), uintptr(unsafe.Pointer(use_)), 0, 0) if r1 == 0 { if e1 != 0 { err = error(e1) @@ -1755,8 +1762,8 @@ func LookupAccountSid(systemName *uint16, sid *SID, name *uint16, nameLen *uint3 return } -func LookupAccountName(systemName *uint16, accountName *uint16, sid *SID, sidLen *uint32, refdDomainName *uint16, refdDomainNameLen *uint32, use *uint32) (err error) { - r1, _, e1 := Syscall9(procLookupAccountNameW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidLen)), uintptr(unsafe.Pointer(refdDomainName)), uintptr(unsafe.Pointer(refdDomainNameLen)), uintptr(unsafe.Pointer(use)), 0, 0) +func LookupAccountName(systemName *uint16, accountName *uint16, sid *SID, sidLen *uint32, refdDomainName *uint16, refdDomainNameLen *uint32, use_ *uint32) (err error) { + r1, _, e1 := Syscall9(procLookupAccountNameW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidLen)), uintptr(unsafe.Pointer(refdDomainName)), uintptr(unsafe.Pointer(refdDomainNameLen)), uintptr(unsafe.Pointer(use_)), 0, 0) if r1 == 0 { if e1 != 0 { err = error(e1)