From cd589c8a73415afbf94a8976f20cbed9d4061ba6 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Tue, 8 Aug 2023 15:31:43 +0200 Subject: [PATCH] os: make MkdirAll support volume names MkdirAll fails to create directories under root paths using volume names (e.g. //?/Volume{GUID}/foo). This is because fixRootDirectory only handle extended length paths using drive letters (e.g. //?/C:/foo). This CL fixes that issue by also detecting volume names without path separator. Updates #22230 Fixes #39785 Change-Id: I813fdc0b968ce71a4297f69245b935558e6cd789 Reviewed-on: https://go-review.googlesource.com/c/go/+/517015 Run-TryBot: Quim Muntal TryBot-Result: Gopher Robot Reviewed-by: Bryan Mills Reviewed-by: Michael Knyszek --- .../syscall/windows/syscall_windows.go | 1 + .../syscall/windows/zsyscall_windows.go | 89 ++++++++++--------- src/os/path.go | 24 +++-- src/os/path_plan9.go | 4 +- src/os/path_unix.go | 4 +- src/os/path_windows.go | 11 --- src/os/path_windows_test.go | 49 ++++++++++ 7 files changed, 118 insertions(+), 64 deletions(-) diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go index 924dd1e1212..ad36bd48a68 100644 --- a/src/internal/syscall/windows/syscall_windows.go +++ b/src/internal/syscall/windows/syscall_windows.go @@ -398,6 +398,7 @@ type FILE_ID_BOTH_DIR_INFO struct { } //sys GetVolumeInformationByHandle(file syscall.Handle, volumeNameBuffer *uint16, volumeNameSize uint32, volumeNameSerialNumber *uint32, maximumComponentLength *uint32, fileSystemFlags *uint32, fileSystemNameBuffer *uint16, fileSystemNameSize uint32) (err error) = GetVolumeInformationByHandleW +//sys GetVolumeNameForVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16, bufferlength uint32) (err error) = GetVolumeNameForVolumeMountPointW //sys RtlLookupFunctionEntry(pc uintptr, baseAddress *uintptr, table *byte) (ret uintptr) = kernel32.RtlLookupFunctionEntry //sys RtlVirtualUnwind(handlerType uint32, baseAddress uintptr, pc uintptr, entry uintptr, ctxt uintptr, data *uintptr, frame *uintptr, ctxptrs *byte) (ret uintptr) = kernel32.RtlVirtualUnwind diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go index fb87bd03a27..e3f6d8d2a22 100644 --- a/src/internal/syscall/windows/zsyscall_windows.go +++ b/src/internal/syscall/windows/zsyscall_windows.go @@ -45,46 +45,47 @@ var ( moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll")) modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll")) - procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges") - procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx") - procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf") - procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") - procOpenSCManagerW = modadvapi32.NewProc("OpenSCManagerW") - procOpenServiceW = modadvapi32.NewProc("OpenServiceW") - procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken") - procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus") - procRevertToSelf = modadvapi32.NewProc("RevertToSelf") - procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation") - procSystemFunction036 = modadvapi32.NewProc("SystemFunction036") - procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses") - procCreateEventW = modkernel32.NewProc("CreateEventW") - procGetACP = modkernel32.NewProc("GetACP") - procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW") - procGetConsoleCP = modkernel32.NewProc("GetConsoleCP") - procGetCurrentThread = modkernel32.NewProc("GetCurrentThread") - procGetFileInformationByHandleEx = modkernel32.NewProc("GetFileInformationByHandleEx") - procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW") - procGetModuleFileNameW = modkernel32.NewProc("GetModuleFileNameW") - procGetTempPath2W = modkernel32.NewProc("GetTempPath2W") - procGetVolumeInformationByHandleW = modkernel32.NewProc("GetVolumeInformationByHandleW") - procLockFileEx = modkernel32.NewProc("LockFileEx") - procModule32FirstW = modkernel32.NewProc("Module32FirstW") - procModule32NextW = modkernel32.NewProc("Module32NextW") - procMoveFileExW = modkernel32.NewProc("MoveFileExW") - procMultiByteToWideChar = modkernel32.NewProc("MultiByteToWideChar") - procRtlLookupFunctionEntry = modkernel32.NewProc("RtlLookupFunctionEntry") - procRtlVirtualUnwind = modkernel32.NewProc("RtlVirtualUnwind") - procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle") - procUnlockFileEx = modkernel32.NewProc("UnlockFileEx") - procVirtualQuery = modkernel32.NewProc("VirtualQuery") - procNetShareAdd = modnetapi32.NewProc("NetShareAdd") - procNetShareDel = modnetapi32.NewProc("NetShareDel") - procNetUserGetLocalGroups = modnetapi32.NewProc("NetUserGetLocalGroups") - procGetProcessMemoryInfo = modpsapi.NewProc("GetProcessMemoryInfo") - procCreateEnvironmentBlock = moduserenv.NewProc("CreateEnvironmentBlock") - procDestroyEnvironmentBlock = moduserenv.NewProc("DestroyEnvironmentBlock") - procGetProfilesDirectoryW = moduserenv.NewProc("GetProfilesDirectoryW") - procWSASocketW = modws2_32.NewProc("WSASocketW") + procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges") + procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx") + procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf") + procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") + procOpenSCManagerW = modadvapi32.NewProc("OpenSCManagerW") + procOpenServiceW = modadvapi32.NewProc("OpenServiceW") + procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken") + procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus") + procRevertToSelf = modadvapi32.NewProc("RevertToSelf") + procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation") + procSystemFunction036 = modadvapi32.NewProc("SystemFunction036") + procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses") + procCreateEventW = modkernel32.NewProc("CreateEventW") + procGetACP = modkernel32.NewProc("GetACP") + procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW") + procGetConsoleCP = modkernel32.NewProc("GetConsoleCP") + procGetCurrentThread = modkernel32.NewProc("GetCurrentThread") + procGetFileInformationByHandleEx = modkernel32.NewProc("GetFileInformationByHandleEx") + procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW") + procGetModuleFileNameW = modkernel32.NewProc("GetModuleFileNameW") + procGetTempPath2W = modkernel32.NewProc("GetTempPath2W") + procGetVolumeInformationByHandleW = modkernel32.NewProc("GetVolumeInformationByHandleW") + procGetVolumeNameForVolumeMountPointW = modkernel32.NewProc("GetVolumeNameForVolumeMountPointW") + procLockFileEx = modkernel32.NewProc("LockFileEx") + procModule32FirstW = modkernel32.NewProc("Module32FirstW") + procModule32NextW = modkernel32.NewProc("Module32NextW") + procMoveFileExW = modkernel32.NewProc("MoveFileExW") + procMultiByteToWideChar = modkernel32.NewProc("MultiByteToWideChar") + procRtlLookupFunctionEntry = modkernel32.NewProc("RtlLookupFunctionEntry") + procRtlVirtualUnwind = modkernel32.NewProc("RtlVirtualUnwind") + procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle") + procUnlockFileEx = modkernel32.NewProc("UnlockFileEx") + procVirtualQuery = modkernel32.NewProc("VirtualQuery") + procNetShareAdd = modnetapi32.NewProc("NetShareAdd") + procNetShareDel = modnetapi32.NewProc("NetShareDel") + procNetUserGetLocalGroups = modnetapi32.NewProc("NetUserGetLocalGroups") + procGetProcessMemoryInfo = modpsapi.NewProc("GetProcessMemoryInfo") + procCreateEnvironmentBlock = moduserenv.NewProc("CreateEnvironmentBlock") + procDestroyEnvironmentBlock = moduserenv.NewProc("DestroyEnvironmentBlock") + procGetProfilesDirectoryW = moduserenv.NewProc("GetProfilesDirectoryW") + procWSASocketW = modws2_32.NewProc("WSASocketW") ) func adjustTokenPrivileges(token syscall.Token, disableAllPrivileges bool, newstate *TOKEN_PRIVILEGES, buflen uint32, prevstate *TOKEN_PRIVILEGES, returnlen *uint32) (ret uint32, err error) { @@ -279,6 +280,14 @@ func GetVolumeInformationByHandle(file syscall.Handle, volumeNameBuffer *uint16, return } +func GetVolumeNameForVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16, bufferlength uint32) (err error) { + r1, _, e1 := syscall.Syscall(procGetVolumeNameForVolumeMountPointW.Addr(), 3, uintptr(unsafe.Pointer(volumeMountPoint)), uintptr(unsafe.Pointer(volumeName)), uintptr(bufferlength)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func LockFileEx(file syscall.Handle, flags uint32, reserved uint32, bytesLow uint32, bytesHigh uint32, overlapped *syscall.Overlapped) (err error) { r1, _, e1 := syscall.Syscall6(procLockFileEx.Addr(), 6, uintptr(file), uintptr(flags), uintptr(reserved), uintptr(bytesLow), uintptr(bytesHigh), uintptr(unsafe.Pointer(overlapped))) if r1 == 0 { diff --git a/src/os/path.go b/src/os/path.go index df87887b9be..6ac4cbe20f7 100644 --- a/src/os/path.go +++ b/src/os/path.go @@ -26,19 +26,25 @@ func MkdirAll(path string, perm FileMode) error { } // Slow path: make sure parent exists and then call Mkdir for path. - i := len(path) - for i > 0 && IsPathSeparator(path[i-1]) { // Skip trailing path separator. + + // Extract the parent folder from path by first removing any trailing + // path separator and then scanning backward until finding a path + // separator or reaching the beginning of the string. + i := len(path) - 1 + for i >= 0 && IsPathSeparator(path[i]) { i-- } - - j := i - for j > 0 && !IsPathSeparator(path[j-1]) { // Scan backward over element. - j-- + for i >= 0 && !IsPathSeparator(path[i]) { + i-- + } + if i < 0 { + i = 0 } - if j > 1 { - // Create parent. - err = MkdirAll(fixRootDirectory(path[:j-1]), perm) + // If there is a parent directory, and it is not the volume name, + // recurse to ensure parent directory exists. + if parent := path[:i]; len(parent) > len(volumeName(path)) { + err = MkdirAll(parent, perm) if err != nil { return err } diff --git a/src/os/path_plan9.go b/src/os/path_plan9.go index a54b4b98f12..f1c9dbc048c 100644 --- a/src/os/path_plan9.go +++ b/src/os/path_plan9.go @@ -14,6 +14,6 @@ func IsPathSeparator(c uint8) bool { return PathSeparator == c } -func fixRootDirectory(p string) string { - return p +func volumeName(p string) string { + return "" } diff --git a/src/os/path_unix.go b/src/os/path_unix.go index c975cdb11e0..1c80fa91f8b 100644 --- a/src/os/path_unix.go +++ b/src/os/path_unix.go @@ -70,6 +70,6 @@ func splitPath(path string) (string, string) { return dirname, basename } -func fixRootDirectory(p string) string { - return p +func volumeName(p string) string { + return "" } diff --git a/src/os/path_windows.go b/src/os/path_windows.go index 3356908a360..ec9a87274de 100644 --- a/src/os/path_windows.go +++ b/src/os/path_windows.go @@ -214,14 +214,3 @@ func fixLongPath(path string) string { } return string(pathbuf[:w]) } - -// fixRootDirectory fixes a reference to a drive's root directory to -// have the required trailing slash. -func fixRootDirectory(p string) string { - if len(p) == len(`\\?\c:`) { - if IsPathSeparator(p[0]) && IsPathSeparator(p[1]) && p[2] == '?' && IsPathSeparator(p[3]) && p[5] == ':' { - return p + `\` - } - } - return p -} diff --git a/src/os/path_windows_test.go b/src/os/path_windows_test.go index 2506b4f0d8d..4e5e501d1f1 100644 --- a/src/os/path_windows_test.go +++ b/src/os/path_windows_test.go @@ -5,7 +5,11 @@ package os_test import ( + "fmt" + "internal/syscall/windows" + "internal/testenv" "os" + "path/filepath" "strings" "syscall" "testing" @@ -106,3 +110,48 @@ func TestOpenRootSlash(t *testing.T) { dir.Close() } } + +func testMkdirAllAtRoot(t *testing.T, root string) { + // Create a unique-enough directory name in root. + base := fmt.Sprintf("%s-%d", t.Name(), os.Getpid()) + path := filepath.Join(root, base) + if err := os.MkdirAll(path, 0777); err != nil { + t.Fatalf("MkdirAll(%q) failed: %v", path, err) + } + // Clean up + if err := os.RemoveAll(path); err != nil { + t.Fatal(err) + } +} + +func TestMkdirAllExtendedLengthAtRoot(t *testing.T) { + if testenv.Builder() == "" { + t.Skipf("skipping non-hermetic test outside of Go builders") + } + + const prefix = `\\?\` + vol := filepath.VolumeName(t.TempDir()) + `\` + if len(vol) < 4 || vol[:4] != prefix { + vol = prefix + vol + } + testMkdirAllAtRoot(t, vol) +} + +func TestMkdirAllVolumeNameAtRoot(t *testing.T) { + if testenv.Builder() == "" { + t.Skipf("skipping non-hermetic test outside of Go builders") + } + + vol, err := syscall.UTF16PtrFromString(filepath.VolumeName(t.TempDir()) + `\`) + if err != nil { + t.Fatal(err) + } + const maxVolNameLen = 50 + var buf [maxVolNameLen]uint16 + err = windows.GetVolumeNameForVolumeMountPoint(vol, &buf[0], maxVolNameLen) + if err != nil { + t.Fatal(err) + } + volName := syscall.UTF16ToString(buf[:]) + testMkdirAllAtRoot(t, volName) +}