util/winutil: update UserProfile to ensure any environment variables in the roaming profile path are expanded

Updates #12383

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
pull/12480/head
Aaron Klotz 2 weeks ago
parent a8ee83e2c5
commit 7354547bd8

@ -6,6 +6,7 @@ package winutil
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings //sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
//sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW //sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W

@ -80,7 +80,7 @@ func LoadUserProfile(token windows.Token, u *user.User) (up *UserProfile, err er
var roamingProfilePath *uint16 var roamingProfilePath *uint16
if winenv.IsDomainJoined() { if winenv.IsDomainJoined() {
roamingProfilePath, err = getRoamingProfilePath(nil, computerName, userName) roamingProfilePath, err = getRoamingProfilePath(nil, token, computerName, userName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,7 +134,7 @@ func (up *UserProfile) Close() error {
return nil return nil
} }
func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (path *uint16, err error) { func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
// logf is for debugging/testing. // logf is for debugging/testing.
if logf == nil { if logf == nil {
logf = logger.Discard logf = logger.Discard
@ -152,19 +152,18 @@ func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (pa
if profilePath == nil { if profilePath == nil {
return nil, nil return nil, nil
} }
if *profilePath == 0 {
var sz int // Empty string
for ptr := unsafe.Pointer(profilePath); *(*uint16)(ptr) != 0; sz++ { return nil, nil
ptr = unsafe.Pointer(uintptr(ptr) + unsafe.Sizeof(*profilePath))
} }
if sz == 0 { var expanded [windows.MAX_PATH + 1]uint16
return nil, nil if err := expandEnvironmentStringsForUser(token, profilePath, &expanded[0], uint32(len(expanded))); err != nil {
return nil, err
} }
buf := unsafe.Slice(profilePath, sz+1) // This buffer is only used briefly, so we don't bother copying it into a shorter slice.
cp := append([]uint16{}, buf...) return &expanded[0], nil
return unsafe.SliceData(cp), nil
} }
func getComputerAndUserName(token windows.Token, u *user.User) (computerName *uint16, userName *uint16, err error) { func getComputerAndUserName(token windows.Token, u *user.User) (computerName *uint16, userName *uint16, err error) {

@ -45,16 +45,17 @@ var (
modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll") modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
moduserenv = windows.NewLazySystemDLL("userenv.dll") moduserenv = windows.NewLazySystemDLL("userenv.dll")
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings") procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
procRmEndSession = modrstrtmgr.NewProc("RmEndSession") procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
procRmGetList = modrstrtmgr.NewProc("RmGetList") procRmGetList = modrstrtmgr.NewProc("RmGetList")
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession") procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources") procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources")
procRmStartSession = modrstrtmgr.NewProc("RmStartSession") procRmStartSession = modrstrtmgr.NewProc("RmStartSession")
procLoadUserProfileW = moduserenv.NewProc("LoadUserProfileW") procExpandEnvironmentStringsForUserW = moduserenv.NewProc("ExpandEnvironmentStringsForUserW")
procUnloadUserProfile = moduserenv.NewProc("UnloadUserProfile") procLoadUserProfileW = moduserenv.NewProc("LoadUserProfileW")
procUnloadUserProfile = moduserenv.NewProc("UnloadUserProfile")
) )
func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) { func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) {
@ -117,6 +118,14 @@ func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret
return return
} }
func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procExpandEnvironmentStringsForUserW.Addr(), 4, uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) { func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) {
r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0) r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0)
if int32(r1) == 0 { if int32(r1) == 0 {

Loading…
Cancel
Save