diff --git a/scripts/check_license_headers.sh b/scripts/check_license_headers.sh index 9a5ae02f5..bbb128e17 100755 --- a/scripts/check_license_headers.sh +++ b/scripts/check_license_headers.sh @@ -37,15 +37,21 @@ for file in $(find $1 \( -name '*.go' -or -name '*.tsx' -or -name '*.ts' -not -n $1/cmd/tailscale/cli/authenticode_windows.go) # WireGuard copyright. ;; - *_string.go) - # Generated file from go:generate stringer - ;; - $1/control/controlbase/noiseexplorer_test.go) - # Noiseexplorer.com copyright. - ;; + *_string.go) + # Generated file from go:generate stringer + ;; + $1/control/controlbase/noiseexplorer_test.go) + # Noiseexplorer.com copyright. + ;; */zsyscall_windows.go) # Generated syscall wrappers ;; + $1/util/winutil/subprocess_windows_test.go) + # Subprocess test harness code + ;; + $1/util/winutil/testdata/testrestartableprocesses/main.go) + # Subprocess test harness code + ;; *) header="$(head -2 $file)" if ! check_file "$header"; then diff --git a/util/winutil/mksyscall.go b/util/winutil/mksyscall.go index 3c5515ee0..17f41ddcc 100644 --- a/util/winutil/mksyscall.go +++ b/util/winutil/mksyscall.go @@ -6,5 +6,11 @@ package winutil //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go +//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W //sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart +//sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession +//sys rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) = rstrtmgr.RmGetList +//sys rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) = rstrtmgr.RmJoinSession +//sys rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) = rstrtmgr.RmRegisterResources +//sys rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) = rstrtmgr.RmStartSession diff --git a/util/winutil/restartmgr_windows.go b/util/winutil/restartmgr_windows.go new file mode 100644 index 000000000..254b736b2 --- /dev/null +++ b/util/winutil/restartmgr_windows.go @@ -0,0 +1,836 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winutil + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "strings" + "time" + "unicode/utf16" + "unsafe" + + "github.com/dblohm7/wingoes" + "golang.org/x/sys/windows" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" +) + +var ( + ErrDefunctProcess = errors.New("process is defunct") + ErrProcessNotRestartable = errors.New("process is not restartable") +) + +// Implementation note: the code in this file will be invoked from within +// MSI custom actions, so please try to return windows.Errno error codes +// whenever possible; this makes the action return more accurate errors to +// the installer engine. + +const ( + _RESTART_NO_CRASH = 1 + _RESTART_NO_HANG = 2 + _RESTART_NO_PATCH = 4 + _RESTART_NO_REBOOT = 8 +) + +func registerForRestart(opts RegisterForRestartOpts) error { + var flags uint32 + + if !opts.RestartOnCrash { + flags |= _RESTART_NO_CRASH + } + if !opts.RestartOnHang { + flags |= _RESTART_NO_HANG + } + if !opts.RestartOnUpgrade { + flags |= _RESTART_NO_PATCH + } + if !opts.RestartOnReboot { + flags |= _RESTART_NO_REBOOT + } + + var cmdLine *uint16 + if opts.UseCmdLineArgs { + if len(opts.CmdLineArgs) == 0 { + // re-use our current args, excluding the exe name itself + opts.CmdLineArgs = os.Args[1:] + } + + var b strings.Builder + for _, arg := range opts.CmdLineArgs { + if b.Len() > 0 { + b.WriteByte(' ') + } + b.WriteString(windows.EscapeArg(arg)) + } + + if b.Len() > 0 { + var err error + cmdLine, err = windows.UTF16PtrFromString(b.String()) + if err != nil { + return err + } + } + } + + hr := registerApplicationRestart(cmdLine, flags) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +type _RMHANDLE uint32 + +// See https://web.archive.org/web/20231128212837/https://learn.microsoft.com/en-us/windows/win32/rstmgr/using-restart-manager-with-a-secondary-installer +const _INVALID_RMHANDLE = ^_RMHANDLE(0) + +type _RM_UNIQUE_PROCESS struct { + PID uint32 + ProcessStartTime windows.Filetime +} + +type _RM_APP_TYPE int32 + +const ( + _RmUnknownApp _RM_APP_TYPE = 0 + _RmMainWindow _RM_APP_TYPE = 1 + _RmOtherWindow _RM_APP_TYPE = 2 + _RmService _RM_APP_TYPE = 3 + _RmExplorer _RM_APP_TYPE = 4 + _RmConsole _RM_APP_TYPE = 5 + _RmCritical _RM_APP_TYPE = 1000 +) + +type _RM_APP_STATUS uint32 + +const ( + _RmStatusUnknown _RM_APP_STATUS = 0x0 + _RmStatusRunning _RM_APP_STATUS = 0x1 + _RmStatusStopped _RM_APP_STATUS = 0x2 + _RmStatusStoppedOther _RM_APP_STATUS = 0x4 + _RmStatusRestarted _RM_APP_STATUS = 0x8 + _RmStatusErrorOnStop _RM_APP_STATUS = 0x10 + _RmStatusErrorOnRestart _RM_APP_STATUS = 0x20 + _RmStatusShutdownMasked _RM_APP_STATUS = 0x40 + _RmStatusRestartMasked _RM_APP_STATUS = 0x80 +) + +type _RM_PROCESS_INFO struct { + Process _RM_UNIQUE_PROCESS + AppName [256]uint16 + ServiceShortName [64]uint16 + AppType _RM_APP_TYPE + AppStatus _RM_APP_STATUS + TSSessionID uint32 + Restartable int32 // Win32 BOOL +} + +// RestartManagerSession represents an open Restart Manager session. +type RestartManagerSession interface { + io.Closer + // AddPaths adds the fully-qualified paths in fqPaths to the set of binaries + // that will be monitored by this restart manager session. NOTE: This + // method is expensive to call, so it is better to make a single call with + // a larger slice than to make multiple calls with smaller slices. + AddPaths(fqPaths []string) error + // AffectedProcesses returns the UniqueProcess information for all running + // processes that utilize the binaries previously specified by calls to + // AddPaths. + AffectedProcesses() ([]UniqueProcess, error) + // Key returns the session key associated with this instance. + Key() string +} + +// rmSession encapsulates the necessary information to represent an open +// restart manager session. +// +// Implementation note: rmSession methods that return errors should use +// windows.Errno codes whenever possible, as we call them from the custom +// action DLL. MSI custom actions are expected to return windows.Errno values; +// to ensure our compliance with this expectation, we should also use those +// values. Failure to do so will result in a generic windows.Errno being +// returned to the Windows Installer, which obviously is less than ideal. +type rmSession struct { + session _RMHANDLE + key string + logf logger.Logf +} + +const _CCH_RM_SESSION_KEY = 32 // (excludes NUL terminator) + +// NewRestartManagerSession creates a new RestartManagerSession that utilizes +// logf for logging. +func NewRestartManagerSession(logf logger.Logf) (RestartManagerSession, error) { + var sessionKeyBuf [_CCH_RM_SESSION_KEY + 1]uint16 + result := rmSession{ + logf: logf, + } + if err := rmStartSession(&result.session, 0, &sessionKeyBuf[0]); err != nil { + return nil, err + } + + result.key = windows.UTF16ToString(sessionKeyBuf[:_CCH_RM_SESSION_KEY]) + return &result, nil +} + +// AttachRestartManagerSession opens a connection to an existing session +// specified by sessionKey, using logf for logging. +func AttachRestartManagerSession(logf logger.Logf, sessionKey string) (RestartManagerSession, error) { + sessionKey16, err := windows.UTF16PtrFromString(sessionKey) + if err != nil { + return nil, err + } + + result := rmSession{ + key: sessionKey, + logf: logf, + } + if err := rmJoinSession(&result.session, sessionKey16); err != nil { + return nil, err + } + return &result, nil +} + +func (rms *rmSession) Close() error { + if rms == nil || rms.session == _INVALID_RMHANDLE { + return nil + } + if err := rmEndSession(rms.session); err != nil { + return err + } + rms.session = _INVALID_RMHANDLE + return nil +} + +func (rms *rmSession) Key() string { + return rms.key +} + +func (rms *rmSession) AffectedProcesses() ([]UniqueProcess, error) { + infos, err := rms.processList() + if err != nil { + return nil, err + } + + result := make([]UniqueProcess, 0, len(infos)) + for _, info := range infos { + result = append(result, UniqueProcess{ + _RM_UNIQUE_PROCESS: info.Process, + CanReceiveGUIMsgs: info.AppType == _RmMainWindow || info.AppType == _RmOtherWindow, + }) + } + + return result, nil +} + +func (rms *rmSession) processList() ([]_RM_PROCESS_INFO, error) { + const maxAttempts = 5 + var avail, rebootReasons uint32 + needed := uint32(1) + + var buf []_RM_PROCESS_INFO + err := error(windows.ERROR_MORE_DATA) + numAttempts := 0 + for err == windows.ERROR_MORE_DATA && numAttempts < maxAttempts { + numAttempts++ + buf = make([]_RM_PROCESS_INFO, needed) + avail = needed + err = rmGetList(rms.session, &needed, &avail, unsafe.SliceData(buf), &rebootReasons) + } + + if err != nil { + if err == windows.ERROR_SESSION_CREDENTIAL_CONFLICT { + // Add some more context about the meaning of this error. + err = fmt.Errorf("%w (the Restart Manager does not permit calling RmGetList from a process that did not originally create the session)", err) + } + return nil, err + } + + return buf[:avail], nil +} + +func (rms *rmSession) AddPaths(fqPaths []string) error { + if len(fqPaths) == 0 { + return nil + } + + fqPaths16 := make([]*uint16, 0, len(fqPaths)) + for _, fqPath := range fqPaths { + if !filepath.IsAbs(fqPath) { + return fmt.Errorf("%w: paths must be fully-qualified", windows.ERROR_BAD_PATHNAME) + } + + fqPath16, err := windows.UTF16PtrFromString(fqPath) + if err != nil { + return err + } + + fqPaths16 = append(fqPaths16, fqPath16) + } + + return rmRegisterResources(rms.session, uint32(len(fqPaths16)), unsafe.SliceData(fqPaths16), 0, nil, 0, nil) +} + +// UniqueProcess contains the necessary information to uniquely identify a +// process in the face of potential PID reuse. +type UniqueProcess struct { + _RM_UNIQUE_PROCESS + // CanReceiveGUIMsgs is true when the process has open top-level windows. + CanReceiveGUIMsgs bool +} + +// AsRestartableProcess obtains a RestartableProcess populated using the +// information obtained from up. +func (up *UniqueProcess) AsRestartableProcess() (*RestartableProcess, error) { + // We need PROCESS_QUERY_INFORMATION instead of PROCESS_QUERY_LIMITED_INFORMATION + // in order for ProcessImageName to be able to work from within a privileged + // Windows Installer process. + // We need PROCESS_VM_READ for GetApplicationRestartSettings. + // We need PROCESS_TERMINATE and SYNCHRONIZE to terminate the process and + // to be able to wait for the terminated process's handle to signal. + access := uint32(windows.PROCESS_QUERY_INFORMATION | windows.PROCESS_TERMINATE | windows.PROCESS_VM_READ | windows.SYNCHRONIZE) + h, err := windows.OpenProcess(access, false, up.PID) + if err != nil { + return nil, fmt.Errorf("OpenProcess(%d[%#X]): %w", up.PID, up.PID, err) + } + defer func() { + if h == 0 { + return + } + windows.CloseHandle(h) + }() + + var creationTime, exitTime, kernelTime, userTime windows.Filetime + if err := windows.GetProcessTimes(h, &creationTime, &exitTime, &kernelTime, &userTime); err != nil { + return nil, fmt.Errorf("GetProcessTimes: %w", err) + } + if creationTime != up.ProcessStartTime { + // The PID has been reused and does not actually reference the original process. + return nil, ErrDefunctProcess + } + + var tok windows.Token + if err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &tok); err != nil { + return nil, fmt.Errorf("OpenProcessToken: %w", err) + } + defer tok.Close() + + tsSessionID, err := TSSessionID(tok) + if err != nil { + return nil, fmt.Errorf("TSSessionID: %w", err) + } + + logonSessionID, err := LogonSessionID(tok) + if err != nil { + return nil, fmt.Errorf("LogonSessionID: %w", err) + } + + img, err := ProcessImageName(h) + if err != nil { + return nil, fmt.Errorf("ProcessImageName: %w", err) + } + + const _RESTART_MAX_CMD_LINE = 1024 + var cmdLine [_RESTART_MAX_CMD_LINE]uint16 + cmdLineLen := uint32(len(cmdLine)) + var rmFlags uint32 + hr := getApplicationRestartSettings(h, &cmdLine[0], &cmdLineLen, &rmFlags) + // Not found is not an error; it just means that the app never set any restart settings. + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() && e != wingoes.ErrorFromErrno(windows.ERROR_NOT_FOUND) { + return nil, fmt.Errorf("GetApplicationRestartSettings: %w", error(e)) + } + if (rmFlags & _RESTART_NO_PATCH) != 0 { + // The application explicitly stated that it cannot be restarted during + // an upgrade. + return nil, ErrProcessNotRestartable + } + + var logonSID string + // Non-fatal, so we'll proceed with best-effort. + if tokenGroups, err := tok.GetTokenGroups(); err == nil { + for _, group := range tokenGroups.AllGroups() { + if (group.Attributes & windows.SE_GROUP_LOGON_ID) != 0 { + logonSID = group.Sid.String() + break + } + } + } + + var userSID string + // Non-fatal, so we'll proceed with best-effort. + if tokenUser, err := tok.GetTokenUser(); err == nil { + // Save the user's SID so that we can later check it against the currently + // logged-in Tailscale profile. + userSID = tokenUser.User.Sid.String() + } + + result := &RestartableProcess{ + Process: *up, + SessionInfo: SessionID{ + LogonSession: logonSessionID, + TSSession: tsSessionID, + }, + CommandLineInfo: CommandLineInfo{ + ExePath: img, + Args: windows.UTF16ToString(cmdLine[:cmdLineLen]), + }, + LogonSID: logonSID, + UserSID: userSID, + handle: h, + } + + runtime.SetFinalizer(result, func(rp *RestartableProcess) { rp.Close() }) + h = 0 + return result, nil +} + +// RestartableProcess contains the necessary information to uniquely identify +// an existing process, as well as the necessary information to be able to +// terminate it and later start a new instance in the identical logon session +// to the previous instance. +type RestartableProcess struct { + // Process uniquely identifies the existing process. + Process UniqueProcess + // SessionInfo uniquely identifies the Terminal Services (RDP) and logon + // sessions the existing process is running under. + SessionInfo SessionID + // CommandLineInfo contains the command line information necessary for restarting. + CommandLineInfo CommandLineInfo + // LogonSID contains the stringified SID of the existing process's token's logon session. + LogonSID string + // UserSID contains the stringified SID of the existing process's token's user. + UserSID string + // handle specifies the Win32 HANDLE associated with the existing process. + // When non-zero, it includes access rights for querying, terminating, and synchronizing. + handle windows.Handle + // hasExitCode is true when the exitCode field is valid. + hasExitCode bool + // exitCode contains exit code returned by this RestartableProcess once + // its termination has been recorded by (RestartableProcesses).Terminate. + // It is only valid when hasExitCode == true. + exitCode uint32 +} + +func (rp *RestartableProcess) Close() error { + if rp.handle == 0 { + return nil + } + windows.CloseHandle(rp.handle) + runtime.SetFinalizer(rp, nil) + rp.handle = 0 + return nil +} + +// RestartableProcesses is a map of PID to *RestartableProcess instance. +type RestartableProcesses map[uint32]*RestartableProcess + +// NewRestartableProcesses instantiates a new RestartableProcesses. +func NewRestartableProcesses() RestartableProcesses { + return make(RestartableProcesses) +} + +// Add inserts rp into rps. +func (rps RestartableProcesses) Add(rp *RestartableProcess) { + if rp != nil { + rps[rp.Process.PID] = rp + } +} + +// Delete removes rp from rps. +func (rps RestartableProcesses) Delete(rp *RestartableProcess) { + if rp != nil { + delete(rps, rp.Process.PID) + } +} + +// Close invokes (*RestartableProcess).Close on every value in rps, and then +// clears rps. +func (rps RestartableProcesses) Close() error { + for _, v := range rps { + v.Close() + } + clear(rps) + return nil +} + +// _MAXIMUM_WAIT_OBJECTS is the Win32 constant for the maximum number of +// handles that a call to WaitForMultipleObjects may receive at once. +const _MAXIMUM_WAIT_OBJECTS = 64 + +// Terminate forcibly terminates all processes in rps using exitCode, and then +// waits for their process handles to signal, up to timeout. +func (rps RestartableProcesses) Terminate(logf logger.Logf, exitCode uint32, timeout time.Duration) error { + if len(rps) == 0 { + return nil + } + + millis, err := wingoes.DurationToTimeoutMilliseconds(timeout) + if err != nil { + return err + } + + errs := make([]error, 0, len(rps)) + procs := make([]*RestartableProcess, 0, len(rps)) + handles := make([]windows.Handle, 0, len(rps)) + for _, v := range rps { + if err := windows.TerminateProcess(v.handle, exitCode); err != nil { + if err == windows.ERROR_ACCESS_DENIED { + // If v terminated before we attempted to terminate, we'll receive + // ERROR_ACCESS_DENIED, which is not really an error worth reporting in + // our use case. Just obtain the exit code and then close the process. + if err := windows.GetExitCodeProcess(v.handle, &v.exitCode); err != nil { + logf("GetExitCodeProcess failed: %v", err) + } else { + v.hasExitCode = true + } + v.Close() + } else { + errs = append(errs, &terminationError{rp: v, err: err}) + } + continue + } + procs = append(procs, v) + handles = append(handles, v.handle) + } + + for len(handles) > 0 { + // WaitForMultipleObjects can only wait on _MAXIMUM_WAIT_OBJECTS handles per + // call, so we batch them as necessary. + count := uint32(min(len(handles), _MAXIMUM_WAIT_OBJECTS)) + waitCode, err := windows.WaitForMultipleObjects(handles[:count], true, millis) + if err != nil { + errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", err)) + break + } + if e := windows.Errno(waitCode); e == windows.WAIT_TIMEOUT { + errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", error(e))) + break + } + if waitCode >= windows.WAIT_OBJECT_0 && waitCode < (windows.WAIT_OBJECT_0+count) { + // The first count process handles have all been signaled. Close them out. + for _, proc := range procs[:count] { + if err := windows.GetExitCodeProcess(proc.handle, &proc.exitCode); err != nil { + logf("GetExitCodeProcess failed: %v", err) + } else { + proc.hasExitCode = true + } + proc.Close() + } + procs = procs[count:] + handles = handles[count:] + continue + } + // We really shouldn't be reaching this point + panic(fmt.Sprintf("unexpected state from WaitForMultipleObjects: %d", waitCode)) + } + + if len(errs) != 0 { + return multierr.New(errs...) + } + return nil +} + +type terminationError struct { + rp *RestartableProcess + err error +} + +func (te *terminationError) Error() string { + pid := te.rp.Process.PID + return fmt.Sprintf("terminating process %d (%#X): %v", pid, pid, te.err) +} + +func (te *terminationError) Unwrap() error { + return te.err +} + +// SessionID encapsulates the necessary information for uniquely identifying +// sessions. In particular, SessionID contains enough information to detect +// reuse of Terminal Service session IDs. +type SessionID struct { + // LogonSession is the NT logon session ID. + LogonSession windows.LUID + // TSSession is the terminal services session ID. + TSSession uint32 +} + +// OpenToken obtains the security token associated with sessID. +func (sessID *SessionID) OpenToken() (windows.Token, error) { + var token windows.Token + if err := windows.WTSQueryUserToken(sessID.TSSession, &token); err != nil { + return 0, err + } + + var err error + defer func() { + if err != nil { + token.Close() + } + }() + + tokenLogonSession, err := LogonSessionID(token) + if err != nil { + return 0, err + } + + if tokenLogonSession != sessID.LogonSession { + err = windows.ERROR_NO_SUCH_LOGON_SESSION + return 0, err + } + + return token, nil +} + +// ContainsToken determines whether token is contained within sessID. +func (sessID *SessionID) ContainsToken(token windows.Token) (bool, error) { + tokenTSSessionID, err := TSSessionID(token) + if err != nil { + return false, err + } + + if tokenTSSessionID != sessID.TSSession { + return false, nil + } + + tokenLogonSession, err := LogonSessionID(token) + if err != nil { + return false, err + } + + return tokenLogonSession == sessID.LogonSession, nil +} + +// This is the Window Station and Desktop within a particular session that must +// be specified for interactive processes: "Winsta0\\default\x00" +var defaultDesktop = unsafe.SliceData([]uint16{'W', 'i', 'n', 's', 't', 'a', '0', '\\', 'd', 'e', 'f', 'a', 'u', 'l', 't', 0}) + +// CommandLineInfo manages the necessary information for creating a Win32 +// process using a specific command line. +type CommandLineInfo struct { + // ExePath must be a fully-qualified path to a Windows executable binary. + ExePath string + // Args must be any arguments supplied to the process, excluding the + // path to the binary itself. Args must be properly quoted according to + // Windows path rules. To create a properly quoted Args from scratch, call the + // SetArgs method instead. + Args string `json:",omitempty"` +} + +// SetArgs converts args to a string quoted as necessary to satisfy the rules +// for Win32 command lines, and sets cli.Args to that string. +func (cli *CommandLineInfo) SetArgs(args []string) { + var buf strings.Builder + for _, arg := range args { + if buf.Len() > 0 { + buf.WriteByte(' ') + } + buf.WriteString(windows.EscapeArg(arg)) + } + + cli.Args = buf.String() +} + +// Validate ensures that cli.ExePath contains an absolute path. +func (cli *CommandLineInfo) Validate() error { + if cli == nil { + return windows.ERROR_INVALID_PARAMETER + } + + if !filepath.IsAbs(cli.ExePath) { + return fmt.Errorf("%w: CommandLineInfo requires absolute ExePath", windows.ERROR_BAD_PATHNAME) + } + + return nil +} + +// Resolve converts the information in cli to a format compatible with the Win32 +// CreateProcess* family of APIs, as pointers to C-style UTF-16 strings. It also +// returns the full command line as a Go string for logging purposes. +func (cli *CommandLineInfo) Resolve() (exePath *uint16, cmdLine *uint16, cmdLineStr string, err error) { + // Resolve cmdLine first since that also does a Validate. + cmdLineStr, cmdLine, err = cli.resolveArgsAsUTF16Ptr() + if err != nil { + return nil, nil, "", err + } + + exePath, err = windows.UTF16PtrFromString(cli.ExePath) + if err != nil { + return nil, nil, "", err + } + + return exePath, cmdLine, cmdLineStr, nil +} + +// resolveArgs quotes cli.ExePath as necessary, appends Args, and returns the result. +func (cli *CommandLineInfo) resolveArgs() (string, error) { + if err := cli.Validate(); err != nil { + return "", err + } + + var cmdLineBuf strings.Builder + cmdLineBuf.WriteString(windows.EscapeArg(cli.ExePath)) + if args := cli.Args; args != "" { + cmdLineBuf.WriteByte(' ') + cmdLineBuf.WriteString(args) + } + + return cmdLineBuf.String(), nil +} + +func (cli *CommandLineInfo) resolveArgsAsUTF16Ptr() (string, *uint16, error) { + s, err := cli.resolveArgs() + if err != nil { + return "", nil, err + } + s16, err := windows.UTF16PtrFromString(s) + if err != nil { + return "", nil, err + } + return s, s16, nil +} + +// StartProcessInSession creates a new process using cmdLineInfo that will +// reside inside the session identified by sessID, with the security token whose +// logon is associated with sessID. The child process's environment will be +// inherited from the session token's environment. +func StartProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo) error { + return StartProcessInSessionWithHandler(sessID, cmdLineInfo, nil) +} + +// PostCreateProcessHandler is a function that is invoked by +// StartProcessInSessionWithHandler when the child process has been successfully +// created. It is the responsibility of the handler to close the pi.Thread and +// pi.Process handles. +type PostCreateProcessHandler func(pi *windows.ProcessInformation) + +// StartProcessInSessionWithHandler creates a new process using cmdLineInfo that +// will reside inside the session identified by sessID, with the security token +// whose logon is associated with sessID. The child process's environment will be +// inherited from the session token's environment. When the child process has +// been successfully created, handler is invoked with the windows.ProcessInformation +// that was returned by the OS. +func StartProcessInSessionWithHandler(sessID SessionID, cmdLineInfo CommandLineInfo, handler PostCreateProcessHandler) error { + pi, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0) + if err != nil { + return err + } + if handler != nil { + handler(pi) + return nil + } + windows.CloseHandle(pi.Process) + windows.CloseHandle(pi.Thread) + return nil +} + +// RunProcessInSession creates a new process and waits up to timeout for that +// child process to complete its execution. The process is created using +// cmdLineInfo and will reside inside the session identified by sessID, with the +// security token whose logon is associated with sessID. The child process's +// environment will be inherited from the session token's environment. +func RunProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo, timeout time.Duration) (uint32, error) { + timeoutMillis, err := wingoes.DurationToTimeoutMilliseconds(timeout) + if err != nil { + return 1, err + } + + pi, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0) + if err != nil { + return 1, err + } + windows.CloseHandle(pi.Thread) + defer windows.CloseHandle(pi.Process) + + waitCode, err := windows.WaitForSingleObject(pi.Process, timeoutMillis) + if err != nil { + return 1, fmt.Errorf("WaitForSingleObject: %w", err) + } + if e := windows.Errno(waitCode); e == windows.WAIT_TIMEOUT { + return 1, e + } + if waitCode != windows.WAIT_OBJECT_0 { + // This should not be possible; log + return 1, fmt.Errorf("unexpected state from WaitForSingleObject: %d", waitCode) + } + + var exitCode uint32 + if err := windows.GetExitCodeProcess(pi.Process, &exitCode); err != nil { + return 1, err + } + return exitCode, nil +} + +func startProcessInSessionInternal(sessID SessionID, cmdLineInfo CommandLineInfo, extraFlags uint32) (*windows.ProcessInformation, error) { + if err := cmdLineInfo.Validate(); err != nil { + return nil, err + } + + token, err := sessID.OpenToken() + if err != nil { + return nil, fmt.Errorf("(*SessionID).OpenToken: %w", err) + } + defer token.Close() + + exePath16, commandLine16, _, err := cmdLineInfo.Resolve() + if err != nil { + return nil, fmt.Errorf("(*CommandLineInfo).Resolve(): %w", err) + } + + wd16, err := windows.UTF16PtrFromString(filepath.Dir(cmdLineInfo.ExePath)) + if err != nil { + return nil, fmt.Errorf("UTF16PtrFromString(wd): %w", err) + } + + env, err := token.Environ(false) + if err != nil { + return nil, fmt.Errorf("token environment: %w", err) + } + env16 := newEnvBlock(env) + + // The privileges in privNames are required for CreateProcessAsUser to be + // able to start processes as other users in other logon sessions. + privNames := []string{ + "SeAssignPrimaryTokenPrivilege", + "SeIncreaseQuotaPrivilege", + } + dropPrivs, err := EnableCurrentThreadPrivileges(privNames) + if err != nil { + return nil, fmt.Errorf("EnableCurrentThreadPrivileges(%#v): %w", privNames, err) + } + defer dropPrivs() + + createFlags := extraFlags | windows.CREATE_UNICODE_ENVIRONMENT | windows.DETACHED_PROCESS + si := windows.StartupInfo{ + Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})), + Desktop: defaultDesktop, + } + var pi windows.ProcessInformation + if err := windows.CreateProcessAsUser(token, exePath16, commandLine16, nil, nil, + false, createFlags, env16, wd16, &si, &pi); err != nil { + return nil, fmt.Errorf("CreateProcessAsUser: %w", err) + } + return &pi, nil +} + +func newEnvBlock(env []string) *uint16 { + // Intentionally using bytes.Buffer here because we're writing nul bytes (the standard library does this too). + var buf bytes.Buffer + for _, v := range env { + buf.WriteString(v) + buf.WriteByte(0) + } + if buf.Len() == 0 { + // So that we end with a double-null in the empty env case + buf.WriteByte(0) + } + buf.WriteByte(0) + return unsafe.SliceData(utf16.Encode([]rune(string(buf.Bytes())))) +} diff --git a/util/winutil/restartmgr_windows_test.go b/util/winutil/restartmgr_windows_test.go new file mode 100644 index 000000000..0293bc135 --- /dev/null +++ b/util/winutil/restartmgr_windows_test.go @@ -0,0 +1,147 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winutil + +import ( + "fmt" + "os/user" + "path/filepath" + "strings" + "testing" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +const oldFashionedCleanupExitCode = 7778 + +// oldFashionedCleanup cleans up any outstanding binaries using older APIs. +// This would be necessary if the restart manager were to fail during the test. +func oldFashionedCleanup(t *testing.T, binary string) { + snap, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if err != nil { + t.Logf("CreateToolhelp32Snapshot failed: %v", err) + } + defer windows.CloseHandle(snap) + + binary = filepath.Clean(binary) + binbase := filepath.Base(binary) + pe := windows.ProcessEntry32{ + Size: uint32(unsafe.Sizeof(windows.ProcessEntry32{})), + } + for perr := windows.Process32First(snap, &pe); perr == nil; perr = windows.Process32Next(snap, &pe) { + curBin := windows.UTF16ToString(pe.ExeFile[:]) + // Coarse check against the leaf name of the binary + if !strings.EqualFold(binbase, curBin) { + continue + } + + proc, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION|windows.PROCESS_TERMINATE, false, pe.ProcessID) + if err != nil { + t.Logf("OpenProcess failed: %v", err) + continue + } + defer windows.CloseHandle(proc) + + img, err := ProcessImageName(proc) + if err != nil { + t.Logf("ProcessImageName failed: %v", err) + continue + } + + // Now check that their fully-qualified paths match. + if !strings.EqualFold(binary, filepath.Clean(img)) { + continue + } + + t.Logf("Found leftover pid %d, terminating...", pe.ProcessID) + if err := windows.TerminateProcess(proc, oldFashionedCleanupExitCode); err != nil && err != windows.ERROR_ACCESS_DENIED { + t.Logf("TerminateProcess failed: %v", err) + } + } +} + +func testRestartableProcessesImpl(N int, t *testing.T) { + const binary = "testrestartableprocesses" + fq := pathToTestProg(t, binary) + + for i := 0; i < N; i++ { + startTestProg(t, binary, "RestartableProcess") + } + t.Cleanup(func() { + oldFashionedCleanup(t, fq) + }) + + logf := func(format string, args ...any) { + t.Logf(format, args...) + } + rms, err := NewRestartManagerSession(logf) + if err != nil { + t.Fatalf("NewRestartManagerSession: %v", err) + } + defer rms.Close() + + if err := rms.AddPaths([]string{fq}); err != nil { + t.Fatalf("AddPaths: %v", err) + } + + ups, err := rms.AffectedProcesses() + if err != nil { + t.Fatalf("AffectedProcesses: %v", err) + } + + rps := NewRestartableProcesses() + defer rps.Close() + + for _, up := range ups { + rp, err := up.AsRestartableProcess() + if err != nil { + t.Errorf("AsRestartableProcess: %v", err) + continue + } + rps.Add(rp) + } + + const terminateWithExitCode = 7777 + if err := rps.Terminate(logf, terminateWithExitCode, time.Duration(15)*time.Second); err != nil { + t.Errorf("Terminate: %v", err) + } + + for k, v := range rps { + if v.hasExitCode { + if v.exitCode != terminateWithExitCode { + // Not strictly an error, but worth noting. + logf("Subprocess %d terminated with unexpected exit code %d", k, v.exitCode) + } + } else { + t.Errorf("Subprocess %d did not produce an exit code", k) + } + if v.handle != 0 { + t.Errorf("Subprocess %d is unexpectedly still open", k) + } + } +} + +func TestRestartableProcesses(t *testing.T) { + u, err := user.Current() + if err != nil { + t.Fatalf("Could not obtain current user") + } + if u.Uid != localSystemSID { + t.Skipf("This test must be run as SYSTEM") + } + + forN := func(fn func(int, *testing.T)) func([]int) { + return func(ns []int) { + for _, n := range ns { + t.Run(fmt.Sprintf("N=%d", n), func(tt *testing.T) { fn(n, tt) }) + } + } + }(testRestartableProcessesImpl) + + // Testing indicates that the restart manager cannot handle more than 127 processes (on Windows 10, at least), so we use that as our highest value. + ns := []int{0, 1, _MAXIMUM_WAIT_OBJECTS - 1, _MAXIMUM_WAIT_OBJECTS, _MAXIMUM_WAIT_OBJECTS + 1, _MAXIMUM_WAIT_OBJECTS*2 - 1} + forN(ns) +} diff --git a/util/winutil/subprocess_windows_test.go b/util/winutil/subprocess_windows_test.go new file mode 100644 index 000000000..4c6bb5977 --- /dev/null +++ b/util/winutil/subprocess_windows_test.go @@ -0,0 +1,433 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package winutil + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +// The code in this file is adapted from internal/testenv in the Go source tree +// and is used for writing tests that require spawning subprocesses. + +var toRemove []string + +func TestMain(m *testing.M) { + status := m.Run() + for _, file := range toRemove { + os.RemoveAll(file) + } + os.Exit(status) +} + +var testprog struct { + sync.Mutex + dir string + target map[string]*buildexe +} + +type buildexe struct { + once sync.Once + exe string + err error +} + +func pathToTestProg(t *testing.T, binary string) string { + exe, err := buildTestProg(t, binary, "-buildvcs=false") + if err != nil { + t.Fatal(err) + } + return exe +} + +func runTestProg(t *testing.T, binary, name string, env ...string) string { + exe, err := buildTestProg(t, binary, "-buildvcs=false") + if err != nil { + t.Fatal(err) + } + + return runBuiltTestProg(t, exe, name, env...) +} + +func startTestProg(t *testing.T, binary, name string, env ...string) { + exe, err := buildTestProg(t, binary, "-buildvcs=false") + if err != nil { + t.Fatal(err) + } + + startBuiltTestProg(t, exe, name, env...) +} + +func runBuiltTestProg(t *testing.T, exe, name string, env ...string) string { + cmd := exec.Command(exe, name) + cmd.Env = append(cmd.Env, env...) + if testing.Short() { + cmd.Env = append(cmd.Env, "RUNTIME_TEST_SHORT=1") + } + out, _ := runWithTimeout(t, cmd) + return string(out) +} + +func startBuiltTestProg(t *testing.T, exe, name string, env ...string) { + cmd := exec.Command(exe, name) + cmd.Env = append(cmd.Env, env...) + if testing.Short() { + cmd.Env = append(cmd.Env, "RUNTIME_TEST_SHORT=1") + } + start(t, cmd) +} + +var serializeBuild = make(chan bool, 2) + +func buildTestProg(t *testing.T, binary string, flags ...string) (string, error) { + testprog.Lock() + if testprog.dir == "" { + dir, err := os.MkdirTemp("", "go-build") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + testprog.dir = dir + toRemove = append(toRemove, dir) + } + + if testprog.target == nil { + testprog.target = make(map[string]*buildexe) + } + name := binary + if len(flags) > 0 { + nameFlags := make([]string, 0, len(flags)) + for _, flag := range flags { + nameFlags = append(nameFlags, strings.ReplaceAll(flag, "=", "_")) + } + name += "_" + strings.Join(nameFlags, "_") + } + target, ok := testprog.target[name] + if !ok { + target = &buildexe{} + testprog.target[name] = target + } + + dir := testprog.dir + + // Unlock testprog while actually building, so that other + // tests can look up executables that were already built. + testprog.Unlock() + + target.once.Do(func() { + // Only do two "go build"'s at a time, + // to keep load from getting too high. + serializeBuild <- true + defer func() { <-serializeBuild }() + + // Don't get confused if goToolPath calls t.Skip. + target.err = errors.New("building test called t.Skip") + + exe := filepath.Join(dir, name+".exe") + + t.Logf("running go build -o %s %s", exe, strings.Join(flags, " ")) + cmd := exec.Command(goToolPath(t), append([]string{"build", "-o", exe}, flags...)...) + cmd.Dir = "testdata/" + binary + out, err := cmd.CombinedOutput() + if err != nil { + target.err = fmt.Errorf("building %s %v: %v\n%s", binary, flags, err, out) + } else { + target.exe = exe + target.err = nil + } + }) + + return target.exe, target.err +} + +// goTool reports the path to the Go tool. +func goTool() (string, error) { + if !hasGoBuild() { + return "", errors.New("platform cannot run go tool") + } + exeSuffix := ".exe" + goroot, err := findGOROOT() + if err != nil { + return "", fmt.Errorf("cannot find go tool: %w", err) + } + path := filepath.Join(goroot, "bin", "go"+exeSuffix) + if _, err := os.Stat(path); err == nil { + return path, nil + } + goBin, err := exec.LookPath("go" + exeSuffix) + if err != nil { + return "", errors.New("cannot find go tool: " + err.Error()) + } + return goBin, nil +} + +// knownEnv is a list of environment variables that affect the operation +// of the Go command. +const knownEnv = ` + AR + CC + CGO_CFLAGS + CGO_CFLAGS_ALLOW + CGO_CFLAGS_DISALLOW + CGO_CPPFLAGS + CGO_CPPFLAGS_ALLOW + CGO_CPPFLAGS_DISALLOW + CGO_CXXFLAGS + CGO_CXXFLAGS_ALLOW + CGO_CXXFLAGS_DISALLOW + CGO_ENABLED + CGO_FFLAGS + CGO_FFLAGS_ALLOW + CGO_FFLAGS_DISALLOW + CGO_LDFLAGS + CGO_LDFLAGS_ALLOW + CGO_LDFLAGS_DISALLOW + CXX + FC + GCCGO + GO111MODULE + GO386 + GOAMD64 + GOARCH + GOARM + GOBIN + GOCACHE + GOENV + GOEXE + GOEXPERIMENT + GOFLAGS + GOGCCFLAGS + GOHOSTARCH + GOHOSTOS + GOINSECURE + GOMIPS + GOMIPS64 + GOMODCACHE + GONOPROXY + GONOSUMDB + GOOS + GOPATH + GOPPC64 + GOPRIVATE + GOPROXY + GOROOT + GOSUMDB + GOTMPDIR + GOTOOLDIR + GOVCS + GOWASM + GOWORK + GO_EXTLINK_ENABLED + PKG_CONFIG +` + +// goToolPath reports the path to the Go tool. +// It is a convenience wrapper around goTool. +// If the tool is unavailable goToolPath calls t.Skip. +// If the tool should be available and isn't, goToolPath calls t.Fatal. +func goToolPath(t testing.TB) string { + mustHaveGoBuild(t) + path, err := goTool() + if err != nil { + t.Fatal(err) + } + // Add all environment variables that affect the Go command to test metadata. + // Cached test results will be invalidate when these variables change. + // See golang.org/issue/32285. + for _, envVar := range strings.Fields(knownEnv) { + os.Getenv(envVar) + } + return path +} + +// hasGoBuild reports whether the current system can build programs with “go build” +// and then run them with os.StartProcess or exec.Command. +func hasGoBuild() bool { + if os.Getenv("GO_GCFLAGS") != "" { + // It's too much work to require every caller of the go command + // to pass along "-gcflags="+os.Getenv("GO_GCFLAGS"). + // For now, if $GO_GCFLAGS is set, report that we simply can't + // run go build. + return false + } + return true +} + +// mustHaveGoBuild checks that the current system can build programs with “go build” +// and then run them with os.StartProcess or exec.Command. +// If not, mustHaveGoBuild calls t.Skip with an explanation. +func mustHaveGoBuild(t testing.TB) { + if os.Getenv("GO_GCFLAGS") != "" { + t.Skipf("skipping test: 'go build' not compatible with setting $GO_GCFLAGS") + } + if !hasGoBuild() { + t.Skipf("skipping test: 'go build' not available on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +// hasGoRun reports whether the current system can run programs with “go run.” +func hasGoRun() bool { + // For now, having go run and having go build are the same. + return hasGoBuild() +} + +// mustHaveGoRun checks that the current system can run programs with “go run.” +// If not, mustHaveGoRun calls t.Skip with an explanation. +func mustHaveGoRun(t testing.TB) { + if !hasGoRun() { + t.Skipf("skipping test: 'go run' not available on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +var ( + gorootOnce sync.Once + gorootPath string + gorootErr error +) + +func findGOROOT() (string, error) { + gorootOnce.Do(func() { + gorootPath = runtime.GOROOT() + if gorootPath != "" { + // If runtime.GOROOT() is non-empty, assume that it is valid. + // + // (It might not be: for example, the user may have explicitly set GOROOT + // to the wrong directory, or explicitly set GOROOT_FINAL but not GOROOT + // and hasn't moved the tree to GOROOT_FINAL yet. But those cases are + // rare, and if that happens the user can fix what they broke.) + return + } + + // runtime.GOROOT doesn't know where GOROOT is (perhaps because the test + // binary was built with -trimpath, or perhaps because GOROOT_FINAL was set + // without GOROOT and the tree hasn't been moved there yet). + // + // Since this is internal/testenv, we can cheat and assume that the caller + // is a test of some package in a subdirectory of GOROOT/src. ('go test' + // runs the test in the directory containing the packaged under test.) That + // means that if we start walking up the tree, we should eventually find + // GOROOT/src/go.mod, and we can report the parent directory of that. + + cwd, err := os.Getwd() + if err != nil { + gorootErr = fmt.Errorf("finding GOROOT: %w", err) + return + } + + dir := cwd + for { + parent := filepath.Dir(dir) + if parent == dir { + // dir is either "." or only a volume name. + gorootErr = fmt.Errorf("failed to locate GOROOT/src in any parent directory") + return + } + + if base := filepath.Base(dir); base != "src" { + dir = parent + continue // dir cannot be GOROOT/src if it doesn't end in "src". + } + + b, err := os.ReadFile(filepath.Join(dir, "go.mod")) + if err != nil { + if os.IsNotExist(err) { + dir = parent + continue + } + gorootErr = fmt.Errorf("finding GOROOT: %w", err) + return + } + goMod := string(b) + + for goMod != "" { + var line string + line, goMod, _ = strings.Cut(goMod, "\n") + fields := strings.Fields(line) + if len(fields) >= 2 && fields[0] == "module" && fields[1] == "std" { + // Found "module std", which is the module declaration in GOROOT/src! + gorootPath = parent + return + } + } + } + }) + + return gorootPath, gorootErr +} + +// runWithTimeout runs cmd and returns its combined output. If the +// subprocess exits with a non-zero status, it will log that status +// and return a non-nil error, but this is not considered fatal. +func runWithTimeout(t testing.TB, cmd *exec.Cmd) ([]byte, error) { + args := cmd.Args + if args == nil { + args = []string{cmd.Path} + } + + var b bytes.Buffer + cmd.Stdout = &b + cmd.Stderr = &b + if err := cmd.Start(); err != nil { + t.Fatalf("starting %s: %v", args, err) + } + + // If the process doesn't complete within 1 minute, + // assume it is hanging and kill it to get a stack trace. + p := cmd.Process + done := make(chan bool) + go func() { + scale := 2 + if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" { + if sc, err := strconv.Atoi(s); err == nil { + scale = sc + } + } + + select { + case <-done: + case <-time.After(time.Duration(scale) * time.Minute): + p.Signal(os.Kill) + // If SIGQUIT doesn't do it after a little + // while, kill the process. + select { + case <-done: + case <-time.After(time.Duration(scale) * 30 * time.Second): + p.Signal(os.Kill) + } + } + }() + + err := cmd.Wait() + if err != nil { + t.Logf("%s exit status: %v", args, err) + } + close(done) + + return b.Bytes(), err +} + +// start runs cmd asynchronously and returns immediately. +func start(t testing.TB, cmd *exec.Cmd) { + args := cmd.Args + if args == nil { + args = []string{cmd.Path} + } + + var b bytes.Buffer + cmd.Stdout = &b + cmd.Stderr = &b + if err := cmd.Start(); err != nil { + t.Fatalf("starting %s: %v", args, err) + } +} diff --git a/util/winutil/testdata/testrestartableprocesses/main.go b/util/winutil/testdata/testrestartableprocesses/main.go new file mode 100644 index 000000000..f5afdf769 --- /dev/null +++ b/util/winutil/testdata/testrestartableprocesses/main.go @@ -0,0 +1,40 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package main + +import "os" + +var ( + cmds = map[string]func(){} + err error +) + +func register(name string, f func()) { + if cmds[name] != nil { + panic("duplicate registration: " + name) + } + cmds[name] = f +} + +func registerInit(name string, f func()) { + if len(os.Args) >= 2 && os.Args[1] == name { + f() + } +} + +func main() { + if len(os.Args) < 2 { + println("usage: " + os.Args[0] + " name-of-test") + return + } + f := cmds[os.Args[1]] + if f == nil { + println("unknown function: " + os.Args[1]) + return + } + f() +} diff --git a/util/winutil/testdata/testrestartableprocesses/restartableprocess_windows.go b/util/winutil/testdata/testrestartableprocesses/restartableprocess_windows.go new file mode 100644 index 000000000..d3be45102 --- /dev/null +++ b/util/winutil/testdata/testrestartableprocesses/restartableprocess_windows.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "golang.org/x/sys/windows" +) + +func init() { + register("RestartableProcess", RestartableProcess) +} + +func RestartableProcess() { + windows.SleepEx(windows.INFINITE, false) +} diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index 8fff9f056..53f368343 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "log" - "os" "os/exec" "os/user" "runtime" @@ -16,7 +15,6 @@ import ( "time" "unsafe" - "github.com/dblohm7/wingoes" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" ) @@ -384,7 +382,9 @@ func CreateAppMutex(name string) (windows.Handle, error) { return windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(name)) } -func getTokenInfo[T any](token windows.Token, infoClass uint32) (*T, error) { +// getTokenInfoVariableLen obtains variable-length token information. Use +// this function for information classes that output variable-length data. +func getTokenInfoVariableLen[T any](token windows.Token, infoClass uint32) (*T, error) { var buf []byte var desiredLen uint32 @@ -402,6 +402,15 @@ func getTokenInfo[T any](token windows.Token, infoClass uint32) (*T, error) { return (*T)(unsafe.Pointer(unsafe.SliceData(buf))), nil } +// getTokenInfoFixedLen obtains known fixed-length token information. Use this +// function for information classes that output enumerations, BOOLs, integers etc. +func getTokenInfoFixedLen[T any](token windows.Token, infoClass uint32) (result T, err error) { + var actualLen uint32 + p := (*byte)(unsafe.Pointer(&result)) + err = windows.GetTokenInformation(token, infoClass, p, uint32(unsafe.Sizeof(result)), &actualLen) + return result, err +} + type tokenElevationType int32 const ( @@ -410,16 +419,9 @@ const ( tokenElevationTypeLimited tokenElevationType = 3 ) -func getTokenElevationType(token windows.Token) (result tokenElevationType, err error) { - var actualLen uint32 - p := (*byte)(unsafe.Pointer(&result)) - err = windows.GetTokenInformation(token, windows.TokenElevationType, p, uint32(unsafe.Sizeof(result)), &actualLen) - return result, err -} - // IsTokenLimited returns whether token is a limited UAC token. func IsTokenLimited(token windows.Token) (bool, error) { - elevationType, err := getTokenElevationType(token) + elevationType, err := getTokenInfoFixedLen[tokenElevationType](token, windows.TokenElevationType) if err != nil { return false, err } @@ -616,61 +618,6 @@ func findHomeDirInRegistry(uid string) (dir string, err error) { return dir, nil } -const ( - _RESTART_NO_CRASH = 1 - _RESTART_NO_HANG = 2 - _RESTART_NO_PATCH = 4 - _RESTART_NO_REBOOT = 8 -) - -func registerForRestart(opts RegisterForRestartOpts) error { - var flags uint32 - - if !opts.RestartOnCrash { - flags |= _RESTART_NO_CRASH - } - if !opts.RestartOnHang { - flags |= _RESTART_NO_HANG - } - if !opts.RestartOnUpgrade { - flags |= _RESTART_NO_PATCH - } - if !opts.RestartOnReboot { - flags |= _RESTART_NO_REBOOT - } - - var cmdLine *uint16 - if opts.UseCmdLineArgs { - if len(opts.CmdLineArgs) == 0 { - // re-use our current args, excluding the exe name itself - opts.CmdLineArgs = os.Args[1:] - } - - var b strings.Builder - for _, arg := range opts.CmdLineArgs { - if b.Len() > 0 { - b.WriteByte(' ') - } - b.WriteString(windows.EscapeArg(arg)) - } - - if b.Len() > 0 { - var err error - cmdLine, err = windows.UTF16PtrFromString(b.String()) - if err != nil { - return err - } - } - } - - hr := registerApplicationRestart(cmdLine, flags) - if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { - return e - } - - return nil -} - // ProcessImageName returns the fully-qualified path to the executable image // associated with process. func ProcessImageName(process windows.Handle) (string, error) { @@ -694,13 +641,18 @@ func TSSessionIDToLogonSessionID(tsSessionID uint32) (logonSessionID windows.LUI return LogonSessionID(token) } +// TSSessionID obtains the Terminal Services (RDP) session ID associated with token. +func TSSessionID(token windows.Token) (tsSessionID uint32, err error) { + return getTokenInfoFixedLen[uint32](token, windows.TokenSessionId) +} + type tokenOrigin struct { originatingLogonSession windows.LUID } // LogonSessionID obtains the logon session ID associated with token. func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error) { - origin, err := getTokenInfo[tokenOrigin](token, windows.TokenOrigin) + origin, err := getTokenInfoFixedLen[tokenOrigin](token, windows.TokenOrigin) if err != nil { return logonSessionID, err } diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go index 77e9f36c8..b228ff158 100644 --- a/util/winutil/zsyscall_windows.go +++ b/util/winutil/zsyscall_windows.go @@ -41,9 +41,16 @@ func errnoErr(e syscall.Errno) error { var ( modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll") - procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") - procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") + procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") + procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings") + procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") + procRmEndSession = modrstrtmgr.NewProc("RmEndSession") + procRmGetList = modrstrtmgr.NewProc("RmGetList") + procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession") + procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources") + procRmStartSession = modrstrtmgr.NewProc("RmStartSession") ) func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) { @@ -54,8 +61,54 @@ func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, b return } +func getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) { + r0, _, _ := syscall.Syscall6(procGetApplicationRestartSettings.Addr(), 4, uintptr(process), uintptr(unsafe.Pointer(commandLine)), uintptr(unsafe.Pointer(commandLineLen)), uintptr(unsafe.Pointer(flags)), 0, 0) + ret = wingoes.HRESULT(r0) + return +} + func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) { r0, _, _ := syscall.Syscall(procRegisterApplicationRestart.Addr(), 2, uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags), 0) ret = wingoes.HRESULT(r0) return } + +func rmEndSession(session _RMHANDLE) (ret error) { + r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) { + r0, _, _ := syscall.Syscall6(procRmGetList.Addr(), 5, uintptr(session), uintptr(unsafe.Pointer(nProcInfoNeeded)), uintptr(unsafe.Pointer(nProcInfo)), uintptr(unsafe.Pointer(rgAffectedApps)), uintptr(unsafe.Pointer(pRebootReasons)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) { + r0, _, _ := syscall.Syscall(procRmJoinSession.Addr(), 2, uintptr(unsafe.Pointer(pSession)), uintptr(unsafe.Pointer(sessionKey)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) { + r0, _, _ := syscall.Syscall9(procRmRegisterResources.Addr(), 7, uintptr(session), uintptr(nFiles), uintptr(unsafe.Pointer(rgsFileNames)), uintptr(nApplications), uintptr(unsafe.Pointer(rgApplications)), uintptr(nServices), uintptr(unsafe.Pointer(rgsServiceNames)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) { + r0, _, _ := syscall.Syscall(procRmStartSession.Addr(), 3, uintptr(unsafe.Pointer(pSession)), uintptr(flags), uintptr(unsafe.Pointer(sessionKey))) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +}