From db39a43f063a73525b6ea93006d888b53e692846 Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Fri, 3 Nov 2023 13:56:46 -0600 Subject: [PATCH] util/winutil: add support for restarting Windows processes in specific sessions This PR is all about adding functionality that will enable the installer's upgrade sequence to terminate processes belonging to the previous version, and then subsequently restart instances belonging to the new version within the session(s) corresponding to the processes that were killed. There are multiple parts to this: * We add support for the Restart Manager APIs, which allow us to query the OS for a list of processes locking specific files; * We add the RestartableProcess and RestartableProcesses types that query additional information about the running processes that will allow us to correctly restart them in the future. These types also provide the ability to terminate the processes. * We add the StartProcessInSession family of APIs that permit us to create new processes within specific sessions. This is needed in order to properly attach a new GUI process to the same RDP session and desktop that its previously-terminated counterpart would have been running in. * I tweaked the winutil token APIs again. * A lot of this stuff is pretty hard to test without a very elaborate harness, but I added a unit test for the most complicated part (though it requires LocalSystem to run). Updates https://github.com/tailscale/corp/issues/13998 Signed-off-by: Aaron Klotz --- scripts/check_license_headers.sh | 18 +- util/winutil/mksyscall.go | 6 + util/winutil/restartmgr_windows.go | 836 ++++++++++++++++++ util/winutil/restartmgr_windows_test.go | 147 +++ util/winutil/subprocess_windows_test.go | 433 +++++++++ .../testdata/testrestartableprocesses/main.go | 40 + .../restartableprocess_windows.go | 16 + util/winutil/winutil_windows.go | 86 +- util/winutil/zsyscall_windows.go | 57 +- 9 files changed, 1564 insertions(+), 75 deletions(-) create mode 100644 util/winutil/restartmgr_windows.go create mode 100644 util/winutil/restartmgr_windows_test.go create mode 100644 util/winutil/subprocess_windows_test.go create mode 100644 util/winutil/testdata/testrestartableprocesses/main.go create mode 100644 util/winutil/testdata/testrestartableprocesses/restartableprocess_windows.go diff --git a/scripts/check_license_headers.sh b/scripts/check_license_headers.sh index 9a5ae02f5..bbb128e17 100755 --- a/scripts/check_license_headers.sh +++ b/scripts/check_license_headers.sh @@ -37,15 +37,21 @@ for file in $(find $1 \( -name '*.go' -or -name '*.tsx' -or -name '*.ts' -not -n $1/cmd/tailscale/cli/authenticode_windows.go) # WireGuard copyright. ;; - *_string.go) - # Generated file from go:generate stringer - ;; - $1/control/controlbase/noiseexplorer_test.go) - # Noiseexplorer.com copyright. - ;; + *_string.go) + # Generated file from go:generate stringer + ;; + $1/control/controlbase/noiseexplorer_test.go) + # Noiseexplorer.com copyright. + ;; */zsyscall_windows.go) # Generated syscall wrappers ;; + $1/util/winutil/subprocess_windows_test.go) + # Subprocess test harness code + ;; + $1/util/winutil/testdata/testrestartableprocesses/main.go) + # Subprocess test harness code + ;; *) header="$(head -2 $file)" if ! check_file "$header"; then diff --git a/util/winutil/mksyscall.go b/util/winutil/mksyscall.go index 3c5515ee0..17f41ddcc 100644 --- a/util/winutil/mksyscall.go +++ b/util/winutil/mksyscall.go @@ -6,5 +6,11 @@ package winutil //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go +//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W //sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart +//sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession +//sys rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) = rstrtmgr.RmGetList +//sys rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) = rstrtmgr.RmJoinSession +//sys rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) = rstrtmgr.RmRegisterResources +//sys rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) = rstrtmgr.RmStartSession diff --git a/util/winutil/restartmgr_windows.go b/util/winutil/restartmgr_windows.go new file mode 100644 index 000000000..254b736b2 --- /dev/null +++ b/util/winutil/restartmgr_windows.go @@ -0,0 +1,836 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winutil + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "strings" + "time" + "unicode/utf16" + "unsafe" + + "github.com/dblohm7/wingoes" + "golang.org/x/sys/windows" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" +) + +var ( + ErrDefunctProcess = errors.New("process is defunct") + ErrProcessNotRestartable = errors.New("process is not restartable") +) + +// Implementation note: the code in this file will be invoked from within +// MSI custom actions, so please try to return windows.Errno error codes +// whenever possible; this makes the action return more accurate errors to +// the installer engine. + +const ( + _RESTART_NO_CRASH = 1 + _RESTART_NO_HANG = 2 + _RESTART_NO_PATCH = 4 + _RESTART_NO_REBOOT = 8 +) + +func registerForRestart(opts RegisterForRestartOpts) error { + var flags uint32 + + if !opts.RestartOnCrash { + flags |= _RESTART_NO_CRASH + } + if !opts.RestartOnHang { + flags |= _RESTART_NO_HANG + } + if !opts.RestartOnUpgrade { + flags |= _RESTART_NO_PATCH + } + if !opts.RestartOnReboot { + flags |= _RESTART_NO_REBOOT + } + + var cmdLine *uint16 + if opts.UseCmdLineArgs { + if len(opts.CmdLineArgs) == 0 { + // re-use our current args, excluding the exe name itself + opts.CmdLineArgs = os.Args[1:] + } + + var b strings.Builder + for _, arg := range opts.CmdLineArgs { + if b.Len() > 0 { + b.WriteByte(' ') + } + b.WriteString(windows.EscapeArg(arg)) + } + + if b.Len() > 0 { + var err error + cmdLine, err = windows.UTF16PtrFromString(b.String()) + if err != nil { + return err + } + } + } + + hr := registerApplicationRestart(cmdLine, flags) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +type _RMHANDLE uint32 + +// See https://web.archive.org/web/20231128212837/https://learn.microsoft.com/en-us/windows/win32/rstmgr/using-restart-manager-with-a-secondary-installer +const _INVALID_RMHANDLE = ^_RMHANDLE(0) + +type _RM_UNIQUE_PROCESS struct { + PID uint32 + ProcessStartTime windows.Filetime +} + +type _RM_APP_TYPE int32 + +const ( + _RmUnknownApp _RM_APP_TYPE = 0 + _RmMainWindow _RM_APP_TYPE = 1 + _RmOtherWindow _RM_APP_TYPE = 2 + _RmService _RM_APP_TYPE = 3 + _RmExplorer _RM_APP_TYPE = 4 + _RmConsole _RM_APP_TYPE = 5 + _RmCritical _RM_APP_TYPE = 1000 +) + +type _RM_APP_STATUS uint32 + +const ( + _RmStatusUnknown _RM_APP_STATUS = 0x0 + _RmStatusRunning _RM_APP_STATUS = 0x1 + _RmStatusStopped _RM_APP_STATUS = 0x2 + _RmStatusStoppedOther _RM_APP_STATUS = 0x4 + _RmStatusRestarted _RM_APP_STATUS = 0x8 + _RmStatusErrorOnStop _RM_APP_STATUS = 0x10 + _RmStatusErrorOnRestart _RM_APP_STATUS = 0x20 + _RmStatusShutdownMasked _RM_APP_STATUS = 0x40 + _RmStatusRestartMasked _RM_APP_STATUS = 0x80 +) + +type _RM_PROCESS_INFO struct { + Process _RM_UNIQUE_PROCESS + AppName [256]uint16 + ServiceShortName [64]uint16 + AppType _RM_APP_TYPE + AppStatus _RM_APP_STATUS + TSSessionID uint32 + Restartable int32 // Win32 BOOL +} + +// RestartManagerSession represents an open Restart Manager session. +type RestartManagerSession interface { + io.Closer + // AddPaths adds the fully-qualified paths in fqPaths to the set of binaries + // that will be monitored by this restart manager session. NOTE: This + // method is expensive to call, so it is better to make a single call with + // a larger slice than to make multiple calls with smaller slices. + AddPaths(fqPaths []string) error + // AffectedProcesses returns the UniqueProcess information for all running + // processes that utilize the binaries previously specified by calls to + // AddPaths. + AffectedProcesses() ([]UniqueProcess, error) + // Key returns the session key associated with this instance. + Key() string +} + +// rmSession encapsulates the necessary information to represent an open +// restart manager session. +// +// Implementation note: rmSession methods that return errors should use +// windows.Errno codes whenever possible, as we call them from the custom +// action DLL. MSI custom actions are expected to return windows.Errno values; +// to ensure our compliance with this expectation, we should also use those +// values. Failure to do so will result in a generic windows.Errno being +// returned to the Windows Installer, which obviously is less than ideal. +type rmSession struct { + session _RMHANDLE + key string + logf logger.Logf +} + +const _CCH_RM_SESSION_KEY = 32 // (excludes NUL terminator) + +// NewRestartManagerSession creates a new RestartManagerSession that utilizes +// logf for logging. +func NewRestartManagerSession(logf logger.Logf) (RestartManagerSession, error) { + var sessionKeyBuf [_CCH_RM_SESSION_KEY + 1]uint16 + result := rmSession{ + logf: logf, + } + if err := rmStartSession(&result.session, 0, &sessionKeyBuf[0]); err != nil { + return nil, err + } + + result.key = windows.UTF16ToString(sessionKeyBuf[:_CCH_RM_SESSION_KEY]) + return &result, nil +} + +// AttachRestartManagerSession opens a connection to an existing session +// specified by sessionKey, using logf for logging. +func AttachRestartManagerSession(logf logger.Logf, sessionKey string) (RestartManagerSession, error) { + sessionKey16, err := windows.UTF16PtrFromString(sessionKey) + if err != nil { + return nil, err + } + + result := rmSession{ + key: sessionKey, + logf: logf, + } + if err := rmJoinSession(&result.session, sessionKey16); err != nil { + return nil, err + } + return &result, nil +} + +func (rms *rmSession) Close() error { + if rms == nil || rms.session == _INVALID_RMHANDLE { + return nil + } + if err := rmEndSession(rms.session); err != nil { + return err + } + rms.session = _INVALID_RMHANDLE + return nil +} + +func (rms *rmSession) Key() string { + return rms.key +} + +func (rms *rmSession) AffectedProcesses() ([]UniqueProcess, error) { + infos, err := rms.processList() + if err != nil { + return nil, err + } + + result := make([]UniqueProcess, 0, len(infos)) + for _, info := range infos { + result = append(result, UniqueProcess{ + _RM_UNIQUE_PROCESS: info.Process, + CanReceiveGUIMsgs: info.AppType == _RmMainWindow || info.AppType == _RmOtherWindow, + }) + } + + return result, nil +} + +func (rms *rmSession) processList() ([]_RM_PROCESS_INFO, error) { + const maxAttempts = 5 + var avail, rebootReasons uint32 + needed := uint32(1) + + var buf []_RM_PROCESS_INFO + err := error(windows.ERROR_MORE_DATA) + numAttempts := 0 + for err == windows.ERROR_MORE_DATA && numAttempts < maxAttempts { + numAttempts++ + buf = make([]_RM_PROCESS_INFO, needed) + avail = needed + err = rmGetList(rms.session, &needed, &avail, unsafe.SliceData(buf), &rebootReasons) + } + + if err != nil { + if err == windows.ERROR_SESSION_CREDENTIAL_CONFLICT { + // Add some more context about the meaning of this error. + err = fmt.Errorf("%w (the Restart Manager does not permit calling RmGetList from a process that did not originally create the session)", err) + } + return nil, err + } + + return buf[:avail], nil +} + +func (rms *rmSession) AddPaths(fqPaths []string) error { + if len(fqPaths) == 0 { + return nil + } + + fqPaths16 := make([]*uint16, 0, len(fqPaths)) + for _, fqPath := range fqPaths { + if !filepath.IsAbs(fqPath) { + return fmt.Errorf("%w: paths must be fully-qualified", windows.ERROR_BAD_PATHNAME) + } + + fqPath16, err := windows.UTF16PtrFromString(fqPath) + if err != nil { + return err + } + + fqPaths16 = append(fqPaths16, fqPath16) + } + + return rmRegisterResources(rms.session, uint32(len(fqPaths16)), unsafe.SliceData(fqPaths16), 0, nil, 0, nil) +} + +// UniqueProcess contains the necessary information to uniquely identify a +// process in the face of potential PID reuse. +type UniqueProcess struct { + _RM_UNIQUE_PROCESS + // CanReceiveGUIMsgs is true when the process has open top-level windows. + CanReceiveGUIMsgs bool +} + +// AsRestartableProcess obtains a RestartableProcess populated using the +// information obtained from up. +func (up *UniqueProcess) AsRestartableProcess() (*RestartableProcess, error) { + // We need PROCESS_QUERY_INFORMATION instead of PROCESS_QUERY_LIMITED_INFORMATION + // in order for ProcessImageName to be able to work from within a privileged + // Windows Installer process. + // We need PROCESS_VM_READ for GetApplicationRestartSettings. + // We need PROCESS_TERMINATE and SYNCHRONIZE to terminate the process and + // to be able to wait for the terminated process's handle to signal. + access := uint32(windows.PROCESS_QUERY_INFORMATION | windows.PROCESS_TERMINATE | windows.PROCESS_VM_READ | windows.SYNCHRONIZE) + h, err := windows.OpenProcess(access, false, up.PID) + if err != nil { + return nil, fmt.Errorf("OpenProcess(%d[%#X]): %w", up.PID, up.PID, err) + } + defer func() { + if h == 0 { + return + } + windows.CloseHandle(h) + }() + + var creationTime, exitTime, kernelTime, userTime windows.Filetime + if err := windows.GetProcessTimes(h, &creationTime, &exitTime, &kernelTime, &userTime); err != nil { + return nil, fmt.Errorf("GetProcessTimes: %w", err) + } + if creationTime != up.ProcessStartTime { + // The PID has been reused and does not actually reference the original process. + return nil, ErrDefunctProcess + } + + var tok windows.Token + if err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &tok); err != nil { + return nil, fmt.Errorf("OpenProcessToken: %w", err) + } + defer tok.Close() + + tsSessionID, err := TSSessionID(tok) + if err != nil { + return nil, fmt.Errorf("TSSessionID: %w", err) + } + + logonSessionID, err := LogonSessionID(tok) + if err != nil { + return nil, fmt.Errorf("LogonSessionID: %w", err) + } + + img, err := ProcessImageName(h) + if err != nil { + return nil, fmt.Errorf("ProcessImageName: %w", err) + } + + const _RESTART_MAX_CMD_LINE = 1024 + var cmdLine [_RESTART_MAX_CMD_LINE]uint16 + cmdLineLen := uint32(len(cmdLine)) + var rmFlags uint32 + hr := getApplicationRestartSettings(h, &cmdLine[0], &cmdLineLen, &rmFlags) + // Not found is not an error; it just means that the app never set any restart settings. + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() && e != wingoes.ErrorFromErrno(windows.ERROR_NOT_FOUND) { + return nil, fmt.Errorf("GetApplicationRestartSettings: %w", error(e)) + } + if (rmFlags & _RESTART_NO_PATCH) != 0 { + // The application explicitly stated that it cannot be restarted during + // an upgrade. + return nil, ErrProcessNotRestartable + } + + var logonSID string + // Non-fatal, so we'll proceed with best-effort. + if tokenGroups, err := tok.GetTokenGroups(); err == nil { + for _, group := range tokenGroups.AllGroups() { + if (group.Attributes & windows.SE_GROUP_LOGON_ID) != 0 { + logonSID = group.Sid.String() + break + } + } + } + + var userSID string + // Non-fatal, so we'll proceed with best-effort. + if tokenUser, err := tok.GetTokenUser(); err == nil { + // Save the user's SID so that we can later check it against the currently + // logged-in Tailscale profile. + userSID = tokenUser.User.Sid.String() + } + + result := &RestartableProcess{ + Process: *up, + SessionInfo: SessionID{ + LogonSession: logonSessionID, + TSSession: tsSessionID, + }, + CommandLineInfo: CommandLineInfo{ + ExePath: img, + Args: windows.UTF16ToString(cmdLine[:cmdLineLen]), + }, + LogonSID: logonSID, + UserSID: userSID, + handle: h, + } + + runtime.SetFinalizer(result, func(rp *RestartableProcess) { rp.Close() }) + h = 0 + return result, nil +} + +// RestartableProcess contains the necessary information to uniquely identify +// an existing process, as well as the necessary information to be able to +// terminate it and later start a new instance in the identical logon session +// to the previous instance. +type RestartableProcess struct { + // Process uniquely identifies the existing process. + Process UniqueProcess + // SessionInfo uniquely identifies the Terminal Services (RDP) and logon + // sessions the existing process is running under. + SessionInfo SessionID + // CommandLineInfo contains the command line information necessary for restarting. + CommandLineInfo CommandLineInfo + // LogonSID contains the stringified SID of the existing process's token's logon session. + LogonSID string + // UserSID contains the stringified SID of the existing process's token's user. + UserSID string + // handle specifies the Win32 HANDLE associated with the existing process. + // When non-zero, it includes access rights for querying, terminating, and synchronizing. + handle windows.Handle + // hasExitCode is true when the exitCode field is valid. + hasExitCode bool + // exitCode contains exit code returned by this RestartableProcess once + // its termination has been recorded by (RestartableProcesses).Terminate. + // It is only valid when hasExitCode == true. + exitCode uint32 +} + +func (rp *RestartableProcess) Close() error { + if rp.handle == 0 { + return nil + } + windows.CloseHandle(rp.handle) + runtime.SetFinalizer(rp, nil) + rp.handle = 0 + return nil +} + +// RestartableProcesses is a map of PID to *RestartableProcess instance. +type RestartableProcesses map[uint32]*RestartableProcess + +// NewRestartableProcesses instantiates a new RestartableProcesses. +func NewRestartableProcesses() RestartableProcesses { + return make(RestartableProcesses) +} + +// Add inserts rp into rps. +func (rps RestartableProcesses) Add(rp *RestartableProcess) { + if rp != nil { + rps[rp.Process.PID] = rp + } +} + +// Delete removes rp from rps. +func (rps RestartableProcesses) Delete(rp *RestartableProcess) { + if rp != nil { + delete(rps, rp.Process.PID) + } +} + +// Close invokes (*RestartableProcess).Close on every value in rps, and then +// clears rps. +func (rps RestartableProcesses) Close() error { + for _, v := range rps { + v.Close() + } + clear(rps) + return nil +} + +// _MAXIMUM_WAIT_OBJECTS is the Win32 constant for the maximum number of +// handles that a call to WaitForMultipleObjects may receive at once. +const _MAXIMUM_WAIT_OBJECTS = 64 + +// Terminate forcibly terminates all processes in rps using exitCode, and then +// waits for their process handles to signal, up to timeout. +func (rps RestartableProcesses) Terminate(logf logger.Logf, exitCode uint32, timeout time.Duration) error { + if len(rps) == 0 { + return nil + } + + millis, err := wingoes.DurationToTimeoutMilliseconds(timeout) + if err != nil { + return err + } + + errs := make([]error, 0, len(rps)) + procs := make([]*RestartableProcess, 0, len(rps)) + handles := make([]windows.Handle, 0, len(rps)) + for _, v := range rps { + if err := windows.TerminateProcess(v.handle, exitCode); err != nil { + if err == windows.ERROR_ACCESS_DENIED { + // If v terminated before we attempted to terminate, we'll receive + // ERROR_ACCESS_DENIED, which is not really an error worth reporting in + // our use case. Just obtain the exit code and then close the process. + if err := windows.GetExitCodeProcess(v.handle, &v.exitCode); err != nil { + logf("GetExitCodeProcess failed: %v", err) + } else { + v.hasExitCode = true + } + v.Close() + } else { + errs = append(errs, &terminationError{rp: v, err: err}) + } + continue + } + procs = append(procs, v) + handles = append(handles, v.handle) + } + + for len(handles) > 0 { + // WaitForMultipleObjects can only wait on _MAXIMUM_WAIT_OBJECTS handles per + // call, so we batch them as necessary. + count := uint32(min(len(handles), _MAXIMUM_WAIT_OBJECTS)) + waitCode, err := windows.WaitForMultipleObjects(handles[:count], true, millis) + if err != nil { + errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", err)) + break + } + if e := windows.Errno(waitCode); e == windows.WAIT_TIMEOUT { + errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", error(e))) + break + } + if waitCode >= windows.WAIT_OBJECT_0 && waitCode < (windows.WAIT_OBJECT_0+count) { + // The first count process handles have all been signaled. Close them out. + for _, proc := range procs[:count] { + if err := windows.GetExitCodeProcess(proc.handle, &proc.exitCode); err != nil { + logf("GetExitCodeProcess failed: %v", err) + } else { + proc.hasExitCode = true + } + proc.Close() + } + procs = procs[count:] + handles = handles[count:] + continue + } + // We really shouldn't be reaching this point + panic(fmt.Sprintf("unexpected state from WaitForMultipleObjects: %d", waitCode)) + } + + if len(errs) != 0 { + return multierr.New(errs...) + } + return nil +} + +type terminationError struct { + rp *RestartableProcess + err error +} + +func (te *terminationError) Error() string { + pid := te.rp.Process.PID + return fmt.Sprintf("terminating process %d (%#X): %v", pid, pid, te.err) +} + +func (te *terminationError) Unwrap() error { + return te.err +} + +// SessionID encapsulates the necessary information for uniquely identifying +// sessions. In particular, SessionID contains enough information to detect +// reuse of Terminal Service session IDs. +type SessionID struct { + // LogonSession is the NT logon session ID. + LogonSession windows.LUID + // TSSession is the terminal services session ID. + TSSession uint32 +} + +// OpenToken obtains the security token associated with sessID. +func (sessID *SessionID) OpenToken() (windows.Token, error) { + var token windows.Token + if err := windows.WTSQueryUserToken(sessID.TSSession, &token); err != nil { + return 0, err + } + + var err error + defer func() { + if err != nil { + token.Close() + } + }() + + tokenLogonSession, err := LogonSessionID(token) + if err != nil { + return 0, err + } + + if tokenLogonSession != sessID.LogonSession { + err = windows.ERROR_NO_SUCH_LOGON_SESSION + return 0, err + } + + return token, nil +} + +// ContainsToken determines whether token is contained within sessID. +func (sessID *SessionID) ContainsToken(token windows.Token) (bool, error) { + tokenTSSessionID, err := TSSessionID(token) + if err != nil { + return false, err + } + + if tokenTSSessionID != sessID.TSSession { + return false, nil + } + + tokenLogonSession, err := LogonSessionID(token) + if err != nil { + return false, err + } + + return tokenLogonSession == sessID.LogonSession, nil +} + +// This is the Window Station and Desktop within a particular session that must +// be specified for interactive processes: "Winsta0\\default\x00" +var defaultDesktop = unsafe.SliceData([]uint16{'W', 'i', 'n', 's', 't', 'a', '0', '\\', 'd', 'e', 'f', 'a', 'u', 'l', 't', 0}) + +// CommandLineInfo manages the necessary information for creating a Win32 +// process using a specific command line. +type CommandLineInfo struct { + // ExePath must be a fully-qualified path to a Windows executable binary. + ExePath string + // Args must be any arguments supplied to the process, excluding the + // path to the binary itself. Args must be properly quoted according to + // Windows path rules. To create a properly quoted Args from scratch, call the + // SetArgs method instead. + Args string `json:",omitempty"` +} + +// SetArgs converts args to a string quoted as necessary to satisfy the rules +// for Win32 command lines, and sets cli.Args to that string. +func (cli *CommandLineInfo) SetArgs(args []string) { + var buf strings.Builder + for _, arg := range args { + if buf.Len() > 0 { + buf.WriteByte(' ') + } + buf.WriteString(windows.EscapeArg(arg)) + } + + cli.Args = buf.String() +} + +// Validate ensures that cli.ExePath contains an absolute path. +func (cli *CommandLineInfo) Validate() error { + if cli == nil { + return windows.ERROR_INVALID_PARAMETER + } + + if !filepath.IsAbs(cli.ExePath) { + return fmt.Errorf("%w: CommandLineInfo requires absolute ExePath", windows.ERROR_BAD_PATHNAME) + } + + return nil +} + +// Resolve converts the information in cli to a format compatible with the Win32 +// CreateProcess* family of APIs, as pointers to C-style UTF-16 strings. It also +// returns the full command line as a Go string for logging purposes. +func (cli *CommandLineInfo) Resolve() (exePath *uint16, cmdLine *uint16, cmdLineStr string, err error) { + // Resolve cmdLine first since that also does a Validate. + cmdLineStr, cmdLine, err = cli.resolveArgsAsUTF16Ptr() + if err != nil { + return nil, nil, "", err + } + + exePath, err = windows.UTF16PtrFromString(cli.ExePath) + if err != nil { + return nil, nil, "", err + } + + return exePath, cmdLine, cmdLineStr, nil +} + +// resolveArgs quotes cli.ExePath as necessary, appends Args, and returns the result. +func (cli *CommandLineInfo) resolveArgs() (string, error) { + if err := cli.Validate(); err != nil { + return "", err + } + + var cmdLineBuf strings.Builder + cmdLineBuf.WriteString(windows.EscapeArg(cli.ExePath)) + if args := cli.Args; args != "" { + cmdLineBuf.WriteByte(' ') + cmdLineBuf.WriteString(args) + } + + return cmdLineBuf.String(), nil +} + +func (cli *CommandLineInfo) resolveArgsAsUTF16Ptr() (string, *uint16, error) { + s, err := cli.resolveArgs() + if err != nil { + return "", nil, err + } + s16, err := windows.UTF16PtrFromString(s) + if err != nil { + return "", nil, err + } + return s, s16, nil +} + +// StartProcessInSession creates a new process using cmdLineInfo that will +// reside inside the session identified by sessID, with the security token whose +// logon is associated with sessID. The child process's environment will be +// inherited from the session token's environment. +func StartProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo) error { + return StartProcessInSessionWithHandler(sessID, cmdLineInfo, nil) +} + +// PostCreateProcessHandler is a function that is invoked by +// StartProcessInSessionWithHandler when the child process has been successfully +// created. It is the responsibility of the handler to close the pi.Thread and +// pi.Process handles. +type PostCreateProcessHandler func(pi *windows.ProcessInformation) + +// StartProcessInSessionWithHandler creates a new process using cmdLineInfo that +// will reside inside the session identified by sessID, with the security token +// whose logon is associated with sessID. The child process's environment will be +// inherited from the session token's environment. When the child process has +// been successfully created, handler is invoked with the windows.ProcessInformation +// that was returned by the OS. +func StartProcessInSessionWithHandler(sessID SessionID, cmdLineInfo CommandLineInfo, handler PostCreateProcessHandler) error { + pi, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0) + if err != nil { + return err + } + if handler != nil { + handler(pi) + return nil + } + windows.CloseHandle(pi.Process) + windows.CloseHandle(pi.Thread) + return nil +} + +// RunProcessInSession creates a new process and waits up to timeout for that +// child process to complete its execution. The process is created using +// cmdLineInfo and will reside inside the session identified by sessID, with the +// security token whose logon is associated with sessID. The child process's +// environment will be inherited from the session token's environment. +func RunProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo, timeout time.Duration) (uint32, error) { + timeoutMillis, err := wingoes.DurationToTimeoutMilliseconds(timeout) + if err != nil { + return 1, err + } + + pi, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0) + if err != nil { + return 1, err + } + windows.CloseHandle(pi.Thread) + defer windows.CloseHandle(pi.Process) + + waitCode, err := windows.WaitForSingleObject(pi.Process, timeoutMillis) + if err != nil { + return 1, fmt.Errorf("WaitForSingleObject: %w", err) + } + if e := windows.Errno(waitCode); e == windows.WAIT_TIMEOUT { + return 1, e + } + if waitCode != windows.WAIT_OBJECT_0 { + // This should not be possible; log + return 1, fmt.Errorf("unexpected state from WaitForSingleObject: %d", waitCode) + } + + var exitCode uint32 + if err := windows.GetExitCodeProcess(pi.Process, &exitCode); err != nil { + return 1, err + } + return exitCode, nil +} + +func startProcessInSessionInternal(sessID SessionID, cmdLineInfo CommandLineInfo, extraFlags uint32) (*windows.ProcessInformation, error) { + if err := cmdLineInfo.Validate(); err != nil { + return nil, err + } + + token, err := sessID.OpenToken() + if err != nil { + return nil, fmt.Errorf("(*SessionID).OpenToken: %w", err) + } + defer token.Close() + + exePath16, commandLine16, _, err := cmdLineInfo.Resolve() + if err != nil { + return nil, fmt.Errorf("(*CommandLineInfo).Resolve(): %w", err) + } + + wd16, err := windows.UTF16PtrFromString(filepath.Dir(cmdLineInfo.ExePath)) + if err != nil { + return nil, fmt.Errorf("UTF16PtrFromString(wd): %w", err) + } + + env, err := token.Environ(false) + if err != nil { + return nil, fmt.Errorf("token environment: %w", err) + } + env16 := newEnvBlock(env) + + // The privileges in privNames are required for CreateProcessAsUser to be + // able to start processes as other users in other logon sessions. + privNames := []string{ + "SeAssignPrimaryTokenPrivilege", + "SeIncreaseQuotaPrivilege", + } + dropPrivs, err := EnableCurrentThreadPrivileges(privNames) + if err != nil { + return nil, fmt.Errorf("EnableCurrentThreadPrivileges(%#v): %w", privNames, err) + } + defer dropPrivs() + + createFlags := extraFlags | windows.CREATE_UNICODE_ENVIRONMENT | windows.DETACHED_PROCESS + si := windows.StartupInfo{ + Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})), + Desktop: defaultDesktop, + } + var pi windows.ProcessInformation + if err := windows.CreateProcessAsUser(token, exePath16, commandLine16, nil, nil, + false, createFlags, env16, wd16, &si, &pi); err != nil { + return nil, fmt.Errorf("CreateProcessAsUser: %w", err) + } + return &pi, nil +} + +func newEnvBlock(env []string) *uint16 { + // Intentionally using bytes.Buffer here because we're writing nul bytes (the standard library does this too). + var buf bytes.Buffer + for _, v := range env { + buf.WriteString(v) + buf.WriteByte(0) + } + if buf.Len() == 0 { + // So that we end with a double-null in the empty env case + buf.WriteByte(0) + } + buf.WriteByte(0) + return unsafe.SliceData(utf16.Encode([]rune(string(buf.Bytes())))) +} diff --git a/util/winutil/restartmgr_windows_test.go b/util/winutil/restartmgr_windows_test.go new file mode 100644 index 000000000..0293bc135 --- /dev/null +++ b/util/winutil/restartmgr_windows_test.go @@ -0,0 +1,147 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winutil + +import ( + "fmt" + "os/user" + "path/filepath" + "strings" + "testing" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +const oldFashionedCleanupExitCode = 7778 + +// oldFashionedCleanup cleans up any outstanding binaries using older APIs. +// This would be necessary if the restart manager were to fail during the test. +func oldFashionedCleanup(t *testing.T, binary string) { + snap, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if err != nil { + t.Logf("CreateToolhelp32Snapshot failed: %v", err) + } + defer windows.CloseHandle(snap) + + binary = filepath.Clean(binary) + binbase := filepath.Base(binary) + pe := windows.ProcessEntry32{ + Size: uint32(unsafe.Sizeof(windows.ProcessEntry32{})), + } + for perr := windows.Process32First(snap, &pe); perr == nil; perr = windows.Process32Next(snap, &pe) { + curBin := windows.UTF16ToString(pe.ExeFile[:]) + // Coarse check against the leaf name of the binary + if !strings.EqualFold(binbase, curBin) { + continue + } + + proc, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION|windows.PROCESS_TERMINATE, false, pe.ProcessID) + if err != nil { + t.Logf("OpenProcess failed: %v", err) + continue + } + defer windows.CloseHandle(proc) + + img, err := ProcessImageName(proc) + if err != nil { + t.Logf("ProcessImageName failed: %v", err) + continue + } + + // Now check that their fully-qualified paths match. + if !strings.EqualFold(binary, filepath.Clean(img)) { + continue + } + + t.Logf("Found leftover pid %d, terminating...", pe.ProcessID) + if err := windows.TerminateProcess(proc, oldFashionedCleanupExitCode); err != nil && err != windows.ERROR_ACCESS_DENIED { + t.Logf("TerminateProcess failed: %v", err) + } + } +} + +func testRestartableProcessesImpl(N int, t *testing.T) { + const binary = "testrestartableprocesses" + fq := pathToTestProg(t, binary) + + for i := 0; i < N; i++ { + startTestProg(t, binary, "RestartableProcess") + } + t.Cleanup(func() { + oldFashionedCleanup(t, fq) + }) + + logf := func(format string, args ...any) { + t.Logf(format, args...) + } + rms, err := NewRestartManagerSession(logf) + if err != nil { + t.Fatalf("NewRestartManagerSession: %v", err) + } + defer rms.Close() + + if err := rms.AddPaths([]string{fq}); err != nil { + t.Fatalf("AddPaths: %v", err) + } + + ups, err := rms.AffectedProcesses() + if err != nil { + t.Fatalf("AffectedProcesses: %v", err) + } + + rps := NewRestartableProcesses() + defer rps.Close() + + for _, up := range ups { + rp, err := up.AsRestartableProcess() + if err != nil { + t.Errorf("AsRestartableProcess: %v", err) + continue + } + rps.Add(rp) + } + + const terminateWithExitCode = 7777 + if err := rps.Terminate(logf, terminateWithExitCode, time.Duration(15)*time.Second); err != nil { + t.Errorf("Terminate: %v", err) + } + + for k, v := range rps { + if v.hasExitCode { + if v.exitCode != terminateWithExitCode { + // Not strictly an error, but worth noting. + logf("Subprocess %d terminated with unexpected exit code %d", k, v.exitCode) + } + } else { + t.Errorf("Subprocess %d did not produce an exit code", k) + } + if v.handle != 0 { + t.Errorf("Subprocess %d is unexpectedly still open", k) + } + } +} + +func TestRestartableProcesses(t *testing.T) { + u, err := user.Current() + if err != nil { + t.Fatalf("Could not obtain current user") + } + if u.Uid != localSystemSID { + t.Skipf("This test must be run as SYSTEM") + } + + forN := func(fn func(int, *testing.T)) func([]int) { + return func(ns []int) { + for _, n := range ns { + t.Run(fmt.Sprintf("N=%d", n), func(tt *testing.T) { fn(n, tt) }) + } + } + }(testRestartableProcessesImpl) + + // Testing indicates that the restart manager cannot handle more than 127 processes (on Windows 10, at least), so we use that as our highest value. + ns := []int{0, 1, _MAXIMUM_WAIT_OBJECTS - 1, _MAXIMUM_WAIT_OBJECTS, _MAXIMUM_WAIT_OBJECTS + 1, _MAXIMUM_WAIT_OBJECTS*2 - 1} + forN(ns) +} diff --git a/util/winutil/subprocess_windows_test.go b/util/winutil/subprocess_windows_test.go new file mode 100644 index 000000000..4c6bb5977 --- /dev/null +++ b/util/winutil/subprocess_windows_test.go @@ -0,0 +1,433 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package winutil + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +// The code in this file is adapted from internal/testenv in the Go source tree +// and is used for writing tests that require spawning subprocesses. + +var toRemove []string + +func TestMain(m *testing.M) { + status := m.Run() + for _, file := range toRemove { + os.RemoveAll(file) + } + os.Exit(status) +} + +var testprog struct { + sync.Mutex + dir string + target map[string]*buildexe +} + +type buildexe struct { + once sync.Once + exe string + err error +} + +func pathToTestProg(t *testing.T, binary string) string { + exe, err := buildTestProg(t, binary, "-buildvcs=false") + if err != nil { + t.Fatal(err) + } + return exe +} + +func runTestProg(t *testing.T, binary, name string, env ...string) string { + exe, err := buildTestProg(t, binary, "-buildvcs=false") + if err != nil { + t.Fatal(err) + } + + return runBuiltTestProg(t, exe, name, env...) +} + +func startTestProg(t *testing.T, binary, name string, env ...string) { + exe, err := buildTestProg(t, binary, "-buildvcs=false") + if err != nil { + t.Fatal(err) + } + + startBuiltTestProg(t, exe, name, env...) +} + +func runBuiltTestProg(t *testing.T, exe, name string, env ...string) string { + cmd := exec.Command(exe, name) + cmd.Env = append(cmd.Env, env...) + if testing.Short() { + cmd.Env = append(cmd.Env, "RUNTIME_TEST_SHORT=1") + } + out, _ := runWithTimeout(t, cmd) + return string(out) +} + +func startBuiltTestProg(t *testing.T, exe, name string, env ...string) { + cmd := exec.Command(exe, name) + cmd.Env = append(cmd.Env, env...) + if testing.Short() { + cmd.Env = append(cmd.Env, "RUNTIME_TEST_SHORT=1") + } + start(t, cmd) +} + +var serializeBuild = make(chan bool, 2) + +func buildTestProg(t *testing.T, binary string, flags ...string) (string, error) { + testprog.Lock() + if testprog.dir == "" { + dir, err := os.MkdirTemp("", "go-build") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + testprog.dir = dir + toRemove = append(toRemove, dir) + } + + if testprog.target == nil { + testprog.target = make(map[string]*buildexe) + } + name := binary + if len(flags) > 0 { + nameFlags := make([]string, 0, len(flags)) + for _, flag := range flags { + nameFlags = append(nameFlags, strings.ReplaceAll(flag, "=", "_")) + } + name += "_" + strings.Join(nameFlags, "_") + } + target, ok := testprog.target[name] + if !ok { + target = &buildexe{} + testprog.target[name] = target + } + + dir := testprog.dir + + // Unlock testprog while actually building, so that other + // tests can look up executables that were already built. + testprog.Unlock() + + target.once.Do(func() { + // Only do two "go build"'s at a time, + // to keep load from getting too high. + serializeBuild <- true + defer func() { <-serializeBuild }() + + // Don't get confused if goToolPath calls t.Skip. + target.err = errors.New("building test called t.Skip") + + exe := filepath.Join(dir, name+".exe") + + t.Logf("running go build -o %s %s", exe, strings.Join(flags, " ")) + cmd := exec.Command(goToolPath(t), append([]string{"build", "-o", exe}, flags...)...) + cmd.Dir = "testdata/" + binary + out, err := cmd.CombinedOutput() + if err != nil { + target.err = fmt.Errorf("building %s %v: %v\n%s", binary, flags, err, out) + } else { + target.exe = exe + target.err = nil + } + }) + + return target.exe, target.err +} + +// goTool reports the path to the Go tool. +func goTool() (string, error) { + if !hasGoBuild() { + return "", errors.New("platform cannot run go tool") + } + exeSuffix := ".exe" + goroot, err := findGOROOT() + if err != nil { + return "", fmt.Errorf("cannot find go tool: %w", err) + } + path := filepath.Join(goroot, "bin", "go"+exeSuffix) + if _, err := os.Stat(path); err == nil { + return path, nil + } + goBin, err := exec.LookPath("go" + exeSuffix) + if err != nil { + return "", errors.New("cannot find go tool: " + err.Error()) + } + return goBin, nil +} + +// knownEnv is a list of environment variables that affect the operation +// of the Go command. +const knownEnv = ` + AR + CC + CGO_CFLAGS + CGO_CFLAGS_ALLOW + CGO_CFLAGS_DISALLOW + CGO_CPPFLAGS + CGO_CPPFLAGS_ALLOW + CGO_CPPFLAGS_DISALLOW + CGO_CXXFLAGS + CGO_CXXFLAGS_ALLOW + CGO_CXXFLAGS_DISALLOW + CGO_ENABLED + CGO_FFLAGS + CGO_FFLAGS_ALLOW + CGO_FFLAGS_DISALLOW + CGO_LDFLAGS + CGO_LDFLAGS_ALLOW + CGO_LDFLAGS_DISALLOW + CXX + FC + GCCGO + GO111MODULE + GO386 + GOAMD64 + GOARCH + GOARM + GOBIN + GOCACHE + GOENV + GOEXE + GOEXPERIMENT + GOFLAGS + GOGCCFLAGS + GOHOSTARCH + GOHOSTOS + GOINSECURE + GOMIPS + GOMIPS64 + GOMODCACHE + GONOPROXY + GONOSUMDB + GOOS + GOPATH + GOPPC64 + GOPRIVATE + GOPROXY + GOROOT + GOSUMDB + GOTMPDIR + GOTOOLDIR + GOVCS + GOWASM + GOWORK + GO_EXTLINK_ENABLED + PKG_CONFIG +` + +// goToolPath reports the path to the Go tool. +// It is a convenience wrapper around goTool. +// If the tool is unavailable goToolPath calls t.Skip. +// If the tool should be available and isn't, goToolPath calls t.Fatal. +func goToolPath(t testing.TB) string { + mustHaveGoBuild(t) + path, err := goTool() + if err != nil { + t.Fatal(err) + } + // Add all environment variables that affect the Go command to test metadata. + // Cached test results will be invalidate when these variables change. + // See golang.org/issue/32285. + for _, envVar := range strings.Fields(knownEnv) { + os.Getenv(envVar) + } + return path +} + +// hasGoBuild reports whether the current system can build programs with “go build” +// and then run them with os.StartProcess or exec.Command. +func hasGoBuild() bool { + if os.Getenv("GO_GCFLAGS") != "" { + // It's too much work to require every caller of the go command + // to pass along "-gcflags="+os.Getenv("GO_GCFLAGS"). + // For now, if $GO_GCFLAGS is set, report that we simply can't + // run go build. + return false + } + return true +} + +// mustHaveGoBuild checks that the current system can build programs with “go build” +// and then run them with os.StartProcess or exec.Command. +// If not, mustHaveGoBuild calls t.Skip with an explanation. +func mustHaveGoBuild(t testing.TB) { + if os.Getenv("GO_GCFLAGS") != "" { + t.Skipf("skipping test: 'go build' not compatible with setting $GO_GCFLAGS") + } + if !hasGoBuild() { + t.Skipf("skipping test: 'go build' not available on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +// hasGoRun reports whether the current system can run programs with “go run.” +func hasGoRun() bool { + // For now, having go run and having go build are the same. + return hasGoBuild() +} + +// mustHaveGoRun checks that the current system can run programs with “go run.” +// If not, mustHaveGoRun calls t.Skip with an explanation. +func mustHaveGoRun(t testing.TB) { + if !hasGoRun() { + t.Skipf("skipping test: 'go run' not available on %s/%s", runtime.GOOS, runtime.GOARCH) + } +} + +var ( + gorootOnce sync.Once + gorootPath string + gorootErr error +) + +func findGOROOT() (string, error) { + gorootOnce.Do(func() { + gorootPath = runtime.GOROOT() + if gorootPath != "" { + // If runtime.GOROOT() is non-empty, assume that it is valid. + // + // (It might not be: for example, the user may have explicitly set GOROOT + // to the wrong directory, or explicitly set GOROOT_FINAL but not GOROOT + // and hasn't moved the tree to GOROOT_FINAL yet. But those cases are + // rare, and if that happens the user can fix what they broke.) + return + } + + // runtime.GOROOT doesn't know where GOROOT is (perhaps because the test + // binary was built with -trimpath, or perhaps because GOROOT_FINAL was set + // without GOROOT and the tree hasn't been moved there yet). + // + // Since this is internal/testenv, we can cheat and assume that the caller + // is a test of some package in a subdirectory of GOROOT/src. ('go test' + // runs the test in the directory containing the packaged under test.) That + // means that if we start walking up the tree, we should eventually find + // GOROOT/src/go.mod, and we can report the parent directory of that. + + cwd, err := os.Getwd() + if err != nil { + gorootErr = fmt.Errorf("finding GOROOT: %w", err) + return + } + + dir := cwd + for { + parent := filepath.Dir(dir) + if parent == dir { + // dir is either "." or only a volume name. + gorootErr = fmt.Errorf("failed to locate GOROOT/src in any parent directory") + return + } + + if base := filepath.Base(dir); base != "src" { + dir = parent + continue // dir cannot be GOROOT/src if it doesn't end in "src". + } + + b, err := os.ReadFile(filepath.Join(dir, "go.mod")) + if err != nil { + if os.IsNotExist(err) { + dir = parent + continue + } + gorootErr = fmt.Errorf("finding GOROOT: %w", err) + return + } + goMod := string(b) + + for goMod != "" { + var line string + line, goMod, _ = strings.Cut(goMod, "\n") + fields := strings.Fields(line) + if len(fields) >= 2 && fields[0] == "module" && fields[1] == "std" { + // Found "module std", which is the module declaration in GOROOT/src! + gorootPath = parent + return + } + } + } + }) + + return gorootPath, gorootErr +} + +// runWithTimeout runs cmd and returns its combined output. If the +// subprocess exits with a non-zero status, it will log that status +// and return a non-nil error, but this is not considered fatal. +func runWithTimeout(t testing.TB, cmd *exec.Cmd) ([]byte, error) { + args := cmd.Args + if args == nil { + args = []string{cmd.Path} + } + + var b bytes.Buffer + cmd.Stdout = &b + cmd.Stderr = &b + if err := cmd.Start(); err != nil { + t.Fatalf("starting %s: %v", args, err) + } + + // If the process doesn't complete within 1 minute, + // assume it is hanging and kill it to get a stack trace. + p := cmd.Process + done := make(chan bool) + go func() { + scale := 2 + if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" { + if sc, err := strconv.Atoi(s); err == nil { + scale = sc + } + } + + select { + case <-done: + case <-time.After(time.Duration(scale) * time.Minute): + p.Signal(os.Kill) + // If SIGQUIT doesn't do it after a little + // while, kill the process. + select { + case <-done: + case <-time.After(time.Duration(scale) * 30 * time.Second): + p.Signal(os.Kill) + } + } + }() + + err := cmd.Wait() + if err != nil { + t.Logf("%s exit status: %v", args, err) + } + close(done) + + return b.Bytes(), err +} + +// start runs cmd asynchronously and returns immediately. +func start(t testing.TB, cmd *exec.Cmd) { + args := cmd.Args + if args == nil { + args = []string{cmd.Path} + } + + var b bytes.Buffer + cmd.Stdout = &b + cmd.Stderr = &b + if err := cmd.Start(); err != nil { + t.Fatalf("starting %s: %v", args, err) + } +} diff --git a/util/winutil/testdata/testrestartableprocesses/main.go b/util/winutil/testdata/testrestartableprocesses/main.go new file mode 100644 index 000000000..f5afdf769 --- /dev/null +++ b/util/winutil/testdata/testrestartableprocesses/main.go @@ -0,0 +1,40 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package main + +import "os" + +var ( + cmds = map[string]func(){} + err error +) + +func register(name string, f func()) { + if cmds[name] != nil { + panic("duplicate registration: " + name) + } + cmds[name] = f +} + +func registerInit(name string, f func()) { + if len(os.Args) >= 2 && os.Args[1] == name { + f() + } +} + +func main() { + if len(os.Args) < 2 { + println("usage: " + os.Args[0] + " name-of-test") + return + } + f := cmds[os.Args[1]] + if f == nil { + println("unknown function: " + os.Args[1]) + return + } + f() +} diff --git a/util/winutil/testdata/testrestartableprocesses/restartableprocess_windows.go b/util/winutil/testdata/testrestartableprocesses/restartableprocess_windows.go new file mode 100644 index 000000000..d3be45102 --- /dev/null +++ b/util/winutil/testdata/testrestartableprocesses/restartableprocess_windows.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "golang.org/x/sys/windows" +) + +func init() { + register("RestartableProcess", RestartableProcess) +} + +func RestartableProcess() { + windows.SleepEx(windows.INFINITE, false) +} diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index 8fff9f056..53f368343 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "log" - "os" "os/exec" "os/user" "runtime" @@ -16,7 +15,6 @@ import ( "time" "unsafe" - "github.com/dblohm7/wingoes" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" ) @@ -384,7 +382,9 @@ func CreateAppMutex(name string) (windows.Handle, error) { return windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(name)) } -func getTokenInfo[T any](token windows.Token, infoClass uint32) (*T, error) { +// getTokenInfoVariableLen obtains variable-length token information. Use +// this function for information classes that output variable-length data. +func getTokenInfoVariableLen[T any](token windows.Token, infoClass uint32) (*T, error) { var buf []byte var desiredLen uint32 @@ -402,6 +402,15 @@ func getTokenInfo[T any](token windows.Token, infoClass uint32) (*T, error) { return (*T)(unsafe.Pointer(unsafe.SliceData(buf))), nil } +// getTokenInfoFixedLen obtains known fixed-length token information. Use this +// function for information classes that output enumerations, BOOLs, integers etc. +func getTokenInfoFixedLen[T any](token windows.Token, infoClass uint32) (result T, err error) { + var actualLen uint32 + p := (*byte)(unsafe.Pointer(&result)) + err = windows.GetTokenInformation(token, infoClass, p, uint32(unsafe.Sizeof(result)), &actualLen) + return result, err +} + type tokenElevationType int32 const ( @@ -410,16 +419,9 @@ const ( tokenElevationTypeLimited tokenElevationType = 3 ) -func getTokenElevationType(token windows.Token) (result tokenElevationType, err error) { - var actualLen uint32 - p := (*byte)(unsafe.Pointer(&result)) - err = windows.GetTokenInformation(token, windows.TokenElevationType, p, uint32(unsafe.Sizeof(result)), &actualLen) - return result, err -} - // IsTokenLimited returns whether token is a limited UAC token. func IsTokenLimited(token windows.Token) (bool, error) { - elevationType, err := getTokenElevationType(token) + elevationType, err := getTokenInfoFixedLen[tokenElevationType](token, windows.TokenElevationType) if err != nil { return false, err } @@ -616,61 +618,6 @@ func findHomeDirInRegistry(uid string) (dir string, err error) { return dir, nil } -const ( - _RESTART_NO_CRASH = 1 - _RESTART_NO_HANG = 2 - _RESTART_NO_PATCH = 4 - _RESTART_NO_REBOOT = 8 -) - -func registerForRestart(opts RegisterForRestartOpts) error { - var flags uint32 - - if !opts.RestartOnCrash { - flags |= _RESTART_NO_CRASH - } - if !opts.RestartOnHang { - flags |= _RESTART_NO_HANG - } - if !opts.RestartOnUpgrade { - flags |= _RESTART_NO_PATCH - } - if !opts.RestartOnReboot { - flags |= _RESTART_NO_REBOOT - } - - var cmdLine *uint16 - if opts.UseCmdLineArgs { - if len(opts.CmdLineArgs) == 0 { - // re-use our current args, excluding the exe name itself - opts.CmdLineArgs = os.Args[1:] - } - - var b strings.Builder - for _, arg := range opts.CmdLineArgs { - if b.Len() > 0 { - b.WriteByte(' ') - } - b.WriteString(windows.EscapeArg(arg)) - } - - if b.Len() > 0 { - var err error - cmdLine, err = windows.UTF16PtrFromString(b.String()) - if err != nil { - return err - } - } - } - - hr := registerApplicationRestart(cmdLine, flags) - if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { - return e - } - - return nil -} - // ProcessImageName returns the fully-qualified path to the executable image // associated with process. func ProcessImageName(process windows.Handle) (string, error) { @@ -694,13 +641,18 @@ func TSSessionIDToLogonSessionID(tsSessionID uint32) (logonSessionID windows.LUI return LogonSessionID(token) } +// TSSessionID obtains the Terminal Services (RDP) session ID associated with token. +func TSSessionID(token windows.Token) (tsSessionID uint32, err error) { + return getTokenInfoFixedLen[uint32](token, windows.TokenSessionId) +} + type tokenOrigin struct { originatingLogonSession windows.LUID } // LogonSessionID obtains the logon session ID associated with token. func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error) { - origin, err := getTokenInfo[tokenOrigin](token, windows.TokenOrigin) + origin, err := getTokenInfoFixedLen[tokenOrigin](token, windows.TokenOrigin) if err != nil { return logonSessionID, err } diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go index 77e9f36c8..b228ff158 100644 --- a/util/winutil/zsyscall_windows.go +++ b/util/winutil/zsyscall_windows.go @@ -41,9 +41,16 @@ func errnoErr(e syscall.Errno) error { var ( modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll") - procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") - procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") + procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") + procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings") + procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") + procRmEndSession = modrstrtmgr.NewProc("RmEndSession") + procRmGetList = modrstrtmgr.NewProc("RmGetList") + procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession") + procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources") + procRmStartSession = modrstrtmgr.NewProc("RmStartSession") ) func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) { @@ -54,8 +61,54 @@ func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, b return } +func getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) { + r0, _, _ := syscall.Syscall6(procGetApplicationRestartSettings.Addr(), 4, uintptr(process), uintptr(unsafe.Pointer(commandLine)), uintptr(unsafe.Pointer(commandLineLen)), uintptr(unsafe.Pointer(flags)), 0, 0) + ret = wingoes.HRESULT(r0) + return +} + func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) { r0, _, _ := syscall.Syscall(procRegisterApplicationRestart.Addr(), 2, uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags), 0) ret = wingoes.HRESULT(r0) return } + +func rmEndSession(session _RMHANDLE) (ret error) { + r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) { + r0, _, _ := syscall.Syscall6(procRmGetList.Addr(), 5, uintptr(session), uintptr(unsafe.Pointer(nProcInfoNeeded)), uintptr(unsafe.Pointer(nProcInfo)), uintptr(unsafe.Pointer(rgAffectedApps)), uintptr(unsafe.Pointer(pRebootReasons)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) { + r0, _, _ := syscall.Syscall(procRmJoinSession.Addr(), 2, uintptr(unsafe.Pointer(pSession)), uintptr(unsafe.Pointer(sessionKey)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) { + r0, _, _ := syscall.Syscall9(procRmRegisterResources.Addr(), 7, uintptr(session), uintptr(nFiles), uintptr(unsafe.Pointer(rgsFileNames)), uintptr(nApplications), uintptr(unsafe.Pointer(rgApplications)), uintptr(nServices), uintptr(unsafe.Pointer(rgsServiceNames)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) { + r0, _, _ := syscall.Syscall(procRmStartSession.Addr(), 3, uintptr(unsafe.Pointer(pSession)), uintptr(flags), uintptr(unsafe.Pointer(sessionKey))) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +}