diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index e8f901a66..112dee653 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -529,12 +529,11 @@ func uninstallWinTun(logf logger.Logf) { func fullyQualifiedWintunPath(logf logger.Logf) string { var dir string - var buf [windows.MAX_PATH]uint16 - length := uint32(len(buf)) - if err := windows.QueryFullProcessImageName(windows.CurrentProcess(), 0, &buf[0], &length); err != nil { - logf("QueryFullProcessImageName failed: %v", err) + imgName, err := winutil.ProcessImageName(windows.CurrentProcess()) + if err != nil { + logf("ProcessImageName failed: %v", err) } else { - dir = filepath.Dir(windows.UTF16ToString(buf[:length])) + dir = filepath.Dir(imgName) } return filepath.Join(dir, "wintun.dll") diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index 039c7547d..8fff9f056 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -225,21 +225,40 @@ func isSIDValidPrincipal(uid string) bool { } } -// EnableCurrentThreadPrivilege enables the named privilege in the current -// thread access token. The current goroutine is also locked to the OS thread -// (runtime.LockOSThread). Callers must call the returned disable function when -// done with the privileged task. -func EnableCurrentThreadPrivilege(name string) (disable func() error, err error) { +// EnableCurrentThreadPrivilege enables the named privilege +// in the current thread's access token. The current goroutine is also locked to +// the OS thread (runtime.LockOSThread). Callers must call the returned disable +// function when done with the privileged task. +func EnableCurrentThreadPrivilege(name string) (disable func(), err error) { + return EnableCurrentThreadPrivileges([]string{name}) +} + +// EnableCurrentThreadPrivileges enables the named privileges +// in the current thread's access token. The current goroutine is also locked to +// the OS thread (runtime.LockOSThread). Callers must call the returned disable +// function when done with the privileged task. +func EnableCurrentThreadPrivileges(names []string) (disable func(), err error) { runtime.LockOSThread() + if len(names) == 0 { + // Nothing to enable; no-op isn't really an error... + return runtime.UnlockOSThread, nil + } if err := windows.ImpersonateSelf(windows.SecurityImpersonation); err != nil { runtime.UnlockOSThread() return nil, err } - disable = func() error { + + disable = func() { defer runtime.UnlockOSThread() - return windows.RevertToSelf() + // If RevertToSelf fails, it's not really recoverable and we should panic. + // Failure to do so would leak the privileges we're enabling, which is a + // security issue. + if err := windows.RevertToSelf(); err != nil { + panic(fmt.Sprintf("RevertToSelf failed: %v", err)) + } } + defer func() { if err != nil { disable() @@ -254,19 +273,38 @@ func EnableCurrentThreadPrivilege(name string) (disable func() error, err error) } defer t.Close() - var tp windows.Tokenprivileges + tp := newTokenPrivileges(len(names)) + privs := tp.AllPrivileges() + for i := range privs { + var privStr *uint16 + privStr, err = windows.UTF16PtrFromString(names[i]) + if err != nil { + return nil, err + } + err = windows.LookupPrivilegeValue(nil, privStr, &privs[i].Luid) + if err != nil { + return nil, err + } + privs[i].Attributes = windows.SE_PRIVILEGE_ENABLED + } - privStr, err := syscall.UTF16PtrFromString(name) + err = windows.AdjustTokenPrivileges(t, false, tp, 0, nil, nil) if err != nil { return nil, err } - err = windows.LookupPrivilegeValue(nil, privStr, &tp.Privileges[0].Luid) - if err != nil { - return nil, err + + return disable, nil +} + +func newTokenPrivileges(numPrivs int) *windows.Tokenprivileges { + if numPrivs <= 0 { + panic("numPrivs must be > 0") } - tp.PrivilegeCount = 1 - tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED - return disable, windows.AdjustTokenPrivileges(t, false, &tp, 0, nil, nil) + numBytes := unsafe.Sizeof(windows.Tokenprivileges{}) + (uintptr(numPrivs-1) * unsafe.Sizeof(windows.LUIDAndAttributes{})) + buf := make([]byte, numBytes) + result := (*windows.Tokenprivileges)(unsafe.Pointer(unsafe.SliceData(buf))) + result.PrivilegeCount = uint32(numPrivs) + return result } // StartProcessAsChild starts exePath process as a child of parentPID. @@ -346,35 +384,22 @@ func CreateAppMutex(name string) (windows.Handle, error) { return windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(name)) } -func getTokenInfo(token windows.Token, infoClass uint32) ([]byte, error) { +func getTokenInfo[T any](token windows.Token, infoClass uint32) (*T, error) { + var buf []byte var desiredLen uint32 - err := windows.GetTokenInformation(token, infoClass, nil, 0, &desiredLen) - if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { - return nil, err - } - buf := make([]byte, desiredLen) - actualLen := desiredLen - err = windows.GetTokenInformation(token, infoClass, &buf[0], desiredLen, &actualLen) - return buf, err -} + err := windows.GetTokenInformation(token, infoClass, nil, 0, &desiredLen) -func getTokenUserInfo(token windows.Token) (*windows.Tokenuser, error) { - buf, err := getTokenInfo(token, windows.TokenUser) - if err != nil { - return nil, err + for err == windows.ERROR_INSUFFICIENT_BUFFER { + buf = make([]byte, desiredLen) + err = windows.GetTokenInformation(token, infoClass, unsafe.SliceData(buf), desiredLen, &desiredLen) } - return (*windows.Tokenuser)(unsafe.Pointer(&buf[0])), nil -} - -func getTokenPrimaryGroupInfo(token windows.Token) (*windows.Tokenprimarygroup, error) { - buf, err := getTokenInfo(token, windows.TokenPrimaryGroup) if err != nil { return nil, err } - return (*windows.Tokenprimarygroup)(unsafe.Pointer(&buf[0])), nil + return (*T)(unsafe.Pointer(unsafe.SliceData(buf))), nil } type tokenElevationType int32 @@ -417,12 +442,12 @@ func GetCurrentUserSIDs() (*UserSIDs, error) { } defer token.Close() - userInfo, err := getTokenUserInfo(token) + userInfo, err := token.GetTokenUser() if err != nil { return nil, err } - primaryGroup, err := getTokenPrimaryGroupInfo(token) + primaryGroup, err := token.GetTokenPrimaryGroup() if err != nil { return nil, err } @@ -645,3 +670,40 @@ func registerForRestart(opts RegisterForRestartOpts) error { return nil } + +// ProcessImageName returns the fully-qualified path to the executable image +// associated with process. +func ProcessImageName(process windows.Handle) (string, error) { + var pathBuf [windows.MAX_PATH]uint16 + pathBufLen := uint32(len(pathBuf)) + if err := windows.QueryFullProcessImageName(process, 0, &pathBuf[0], &pathBufLen); err != nil { + return "", err + } + return windows.UTF16ToString(pathBuf[:pathBufLen]), nil +} + +// TSSessionIDToLogonSessionID retrieves the logon session ID associated with +// tsSessionId, which is a Terminal Services / RDP session ID. The calling +// process must be running as LocalSystem. +func TSSessionIDToLogonSessionID(tsSessionID uint32) (logonSessionID windows.LUID, err error) { + var token windows.Token + if err := windows.WTSQueryUserToken(tsSessionID, &token); err != nil { + return logonSessionID, fmt.Errorf("WTSQueryUserToken: %w", err) + } + defer token.Close() + return LogonSessionID(token) +} + +type tokenOrigin struct { + originatingLogonSession windows.LUID +} + +// LogonSessionID obtains the logon session ID associated with token. +func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error) { + origin, err := getTokenInfo[tokenOrigin](token, windows.TokenOrigin) + if err != nil { + return logonSessionID, err + } + + return origin.originatingLogonSession, nil +}