From e3d6236606c627e7915d66beafda3c200e162835 Mon Sep 17 00:00:00 2001 From: Claire Wang Date: Tue, 26 Sep 2023 13:15:11 -0400 Subject: [PATCH] winutil: refactor methods to get values from registry to also return (#9536) errors Updates tailscale/corp#14879 Signed-off-by: Claire Wang --- cmd/tailscaled/tailscaled_windows.go | 4 +- control/controlclient/sign_supported.go | 2 +- hostinfo/hostinfo_windows.go | 3 +- ipn/ipnlocal/profiles.go | 12 ++++-- logpolicy/logpolicy.go | 3 +- util/winutil/policy/policy_windows.go | 14 +++++-- util/winutil/winutil.go | 36 ++++++++++-------- util/winutil/winutil_notwindows.go | 11 ++++-- util/winutil/winutil_windows.go | 49 +++++++++++++------------ 9 files changed, 78 insertions(+), 56 deletions(-) diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 0d056250e..d383bf8c3 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -131,7 +131,7 @@ func runWindowsService(pol *logpolicy.Policy) error { osdiag.LogSupportInfo(logger.WithPrefix(log.Printf, "Support Info: "), osdiag.LogSupportInfoReasonStartup) }() - if winutil.GetPolicyInteger("LogSCMInteractions", 0) != 0 { + if logSCMInteractions, _ := winutil.GetPolicyInteger("LogSCMInteractions"); logSCMInteractions != 0 { syslog, err := eventlog.Open(serviceName) if err == nil { syslogf = func(format string, args ...any) { @@ -158,7 +158,7 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch syslogf("Service start pending") svcAccepts := svc.AcceptStop - if winutil.GetPolicyInteger("FlushDNSOnSessionUnlock", 0) != 0 { + if flushDNSOnSessionUnlock, _ := winutil.GetPolicyInteger("FlushDNSOnSessionUnlock"); flushDNSOnSessionUnlock != 0 { svcAccepts |= svc.AcceptSessionChange } diff --git a/control/controlclient/sign_supported.go b/control/controlclient/sign_supported.go index 2dc8efa1e..0c7925452 100644 --- a/control/controlclient/sign_supported.go +++ b/control/controlclient/sign_supported.go @@ -40,7 +40,7 @@ var getMachineCertificateSubjectOnce struct { // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" func getMachineCertificateSubject() string { getMachineCertificateSubjectOnce.Do(func() { - getMachineCertificateSubjectOnce.v = winutil.GetRegString("MachineCertificateSubject", "") + getMachineCertificateSubjectOnce.v, _ = winutil.GetRegString("MachineCertificateSubject") }) return getMachineCertificateSubjectOnce.v diff --git a/hostinfo/hostinfo_windows.go b/hostinfo/hostinfo_windows.go index 3401655f4..d74d1db42 100644 --- a/hostinfo/hostinfo_windows.go +++ b/hostinfo/hostinfo_windows.go @@ -62,7 +62,8 @@ func packageTypeWindows() string { if _, err := os.Stat(`C:\ProgramData\chocolatey\lib\tailscale`); err == nil { return "choco" } - if msiSentinel := winutil.GetRegInteger("MSI", 0); msiSentinel == 1 { + msiSentinel, _ := winutil.GetRegInteger("MSI") + if msiSentinel == 1 { return "msi" } exe, err := os.Executable() diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index 30f4c59f8..74e5c52bd 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -452,20 +452,24 @@ var defaultPrefs = func() ipn.PrefsView { prefs.LoggedOut = true prefs.WantRunning = false - prefs.ControlURL = winutil.GetPolicyString("LoginURL", "") + controlURL, _ := winutil.GetPolicyString("LoginURL") + prefs.ControlURL = controlURL + prefs.ExitNodeIP = resolveExitNodeIP(netip.Addr{}) // Allow Incoming (used by the UI) is the negation of ShieldsUp (used by the // backend), so this has to convert between the two conventions. - prefs.ShieldsUp = winutil.GetPolicyString("AllowIncomingConnections", "") == "never" - prefs.ForceDaemon = winutil.GetPolicyString("UnattendedMode", "") == "always" + shieldsUp, _ := winutil.GetPolicyString("AllowIncomingConnections") + prefs.ShieldsUp = shieldsUp == "never" + forceDaemon, _ := winutil.GetPolicyString("UnattendedMode") + prefs.ForceDaemon = forceDaemon == "always" return prefs.View() }() func resolveExitNodeIP(defIP netip.Addr) (ret netip.Addr) { ret = defIP - if exitNode := winutil.GetPolicyString("ExitNodeIP", ""); exitNode != "" { + if exitNode, _ := winutil.GetPolicyString("ExitNodeIP"); exitNode != "" { if ip, err := netip.ParseAddr(exitNode); err == nil { ret = ip } diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index c11aaf3bc..7d22362ae 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -65,7 +65,8 @@ func getLogTarget() string { getLogTargetOnce.v = val } else { if runtime.GOOS == "windows" { - getLogTargetOnce.v = winutil.GetRegString("LogTarget", "") + logTarget, _ := winutil.GetRegString("LogTarget") + getLogTargetOnce.v = logTarget } } }) diff --git a/util/winutil/policy/policy_windows.go b/util/winutil/policy/policy_windows.go index 139cf5876..89142951f 100644 --- a/util/winutil/policy/policy_windows.go +++ b/util/winutil/policy/policy_windows.go @@ -49,7 +49,10 @@ func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool { // "always" and "never" remove the user's ability to make a selection. If not // present or set to a different value, "user-decides" is the default. func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { - opt := winutil.GetPolicyString(name, "user-decides") + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return showChoiceByPolicy + } switch opt { case "always": return alwaysByPolicy @@ -81,7 +84,10 @@ func (p VisibilityPolicy) Show() bool { // true) or "hide" (return true). If not present or set to a different value, // "show" (return false) is the default. func GetVisibilityPolicy(name string) VisibilityPolicy { - opt := winutil.GetPolicyString(name, "show") + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return visibleByPolicy + } switch opt { case "hide": return hiddenByPolicy @@ -96,8 +102,8 @@ func GetVisibilityPolicy(name string) VisibilityPolicy { // understands. If the registry value is "" or can not be processed, // defaultValue is returned instead. func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { - opt := winutil.GetPolicyString(name, "") - if opt == "" { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { return defaultValue } v, err := time.ParseDuration(opt) diff --git a/util/winutil/winutil.go b/util/winutil/winutil.go index 3ec3f7c99..4b771491a 100644 --- a/util/winutil/winutil.go +++ b/util/winutil/winutil.go @@ -13,45 +13,49 @@ import ( const RegBase = regBase // GetPolicyString looks up a registry value in the local machine's path for -// system policies, or returns the given default if it can't. +// system policies, or returns empty string and the error. // Use this function to read values that may be set by sysadmins via the MSI // installer or via GPO. For registry settings that you do *not* want to be // visible to sysadmin tools, use GetRegString instead. // // This function will only work on GOOS=windows. Trying to run it on any other -// OS will always return the default value. -func GetPolicyString(name, defval string) string { - return getPolicyString(name, defval) +// OS will always return an empty string and ErrNoValue. +// If value does not exist or another error happens, returns empty string and error. +func GetPolicyString(name string) (string, error) { + return getPolicyString(name) } // GetPolicyInteger looks up a registry value in the local machine's path for -// system policies, or returns the given default if it can't. +// system policies, or returns 0 and the associated error. // Use this function to read values that may be set by sysadmins via the MSI // installer or via GPO. For registry settings that you do *not* want to be // visible to sysadmin tools, use GetRegInteger instead. // // This function will only work on GOOS=windows. Trying to run it on any other -// OS will always return the default value. -func GetPolicyInteger(name string, defval uint64) uint64 { - return getPolicyInteger(name, defval) +// OS will always return 0 and ErrNoValue. +// If value does not exist or another error happens, returns 0 and error. +func GetPolicyInteger(name string) (uint64, error) { + return getPolicyInteger(name) } // GetRegString looks up a registry path in the local machine path, or returns -// the given default if it can't. +// an empty string and error. // // This function will only work on GOOS=windows. Trying to run it on any other -// OS will always return the default value. -func GetRegString(name, defval string) string { - return getRegString(name, defval) +// OS will always return an empty string and ErrNoValue. +// If value does not exist or another error happens, returns empty string and error. +func GetRegString(name string) (string, error) { + return getRegString(name) } // GetRegInteger looks up a registry path in the local machine path, or returns -// the given default if it can't. +// 0 and the error. // // This function will only work on GOOS=windows. Trying to run it on any other -// OS will always return the default value. -func GetRegInteger(name string, defval uint64) uint64 { - return getRegInteger(name, defval) +// OS will always return 0 and ErrNoValue. +// If value does not exist or another error happens, returns 0 and error. +func GetRegInteger(name string) (uint64, error) { + return getRegInteger(name) } // IsSIDValidPrincipal determines whether the SID contained in uid represents a diff --git a/util/winutil/winutil_notwindows.go b/util/winutil/winutil_notwindows.go index c9a292aae..a40712c3f 100644 --- a/util/winutil/winutil_notwindows.go +++ b/util/winutil/winutil_notwindows.go @@ -6,6 +6,7 @@ package winutil import ( + "errors" "fmt" "os/user" "runtime" @@ -13,13 +14,15 @@ import ( const regBase = `` -func getPolicyString(name, defval string) string { return defval } +var ErrNoValue = errors.New("no value because registry is unavailable on this OS") -func getPolicyInteger(name string, defval uint64) uint64 { return defval } +func getPolicyString(name string) (string, error) { return "", ErrNoValue } -func getRegString(name, defval string) string { return defval } +func getPolicyInteger(name string) (uint64, error) { return 0, ErrNoValue } -func getRegInteger(name string, defval uint64) uint64 { return defval } +func getRegString(name string) (string, error) { return "", ErrNoValue } + +func getRegInteger(name string) (uint64, error) { return 0, ErrNoValue } func isSIDValidPrincipal(uid string) bool { return false } diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index ed516ce6b..a686e6335 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -29,6 +29,9 @@ const ( // ErrNoShell is returned when the shell process is not found. var ErrNoShell = errors.New("no Shell process is present") +// ErrNoValue is returned when the value doesn't exist in the registry. +var ErrNoValue = registry.ErrNotExist + // GetDesktopPID searches the PID of the process that's running the // currently active desktop. Returns ErrNoShell if the shell is not present. // Usually the PID will be for explorer.exe. @@ -47,44 +50,44 @@ func GetDesktopPID() (uint32, error) { return pid, nil } -func getPolicyString(name, defval string) string { +func getPolicyString(name string) (string, error) { s, err := getRegStringInternal(regPolicyBase, name) if err != nil { // Fall back to the legacy path - return getRegString(name, defval) + return getRegString(name) } - return s + return s, err } -func getPolicyInteger(name string, defval uint64) uint64 { - i, err := getRegIntegerInternal(regPolicyBase, name) +func getRegString(name string) (string, error) { + s, err := getRegStringInternal(regBase, name) if err != nil { - // Fall back to the legacy path - return getRegInteger(name, defval) + return "", err } - return i + return s, err } -func getRegString(name, defval string) string { - s, err := getRegStringInternal(regBase, name) +func getPolicyInteger(name string) (uint64, error) { + i, err := getRegIntegerInternal(regPolicyBase, name) if err != nil { - return defval + // Fall back to the legacy path + return getRegInteger(name) } - return s + return i, err } -func getRegInteger(name string, defval uint64) uint64 { +func getRegInteger(name string) (uint64, error) { i, err := getRegIntegerInternal(regBase, name) if err != nil { - return defval + return 0, err } - return i + return i, err } func getRegStringInternal(subKey, name string) (string, error) { key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ) if err != nil { - if err != registry.ErrNotExist { + if err != ErrNoValue { log.Printf("registry.OpenKey(%v): %v", subKey, err) } return "", err @@ -93,7 +96,7 @@ func getRegStringInternal(subKey, name string) (string, error) { val, _, err := key.GetStringValue(name) if err != nil { - if err != registry.ErrNotExist { + if err != ErrNoValue { log.Printf("registry.GetStringValue(%v): %v", name, err) } return "", err @@ -114,7 +117,7 @@ func GetRegStrings(name string, defval []string) []string { func getRegStringsInternal(subKey, name string) ([]string, error) { key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ) if err != nil { - if err != registry.ErrNotExist { + if err != ErrNoValue { log.Printf("registry.OpenKey(%v): %v", subKey, err) } return nil, err @@ -123,7 +126,7 @@ func getRegStringsInternal(subKey, name string) ([]string, error) { val, _, err := key.GetStringsValue(name) if err != nil { - if err != registry.ErrNotExist { + if err != ErrNoValue { log.Printf("registry.GetStringValue(%v): %v", name, err) } return nil, err @@ -154,7 +157,7 @@ func DeleteRegValue(name string) error { func deleteRegValueInternal(subKey, name string) error { key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE) - if err == registry.ErrNotExist { + if err == ErrNoValue { return nil } if err != nil { @@ -164,7 +167,7 @@ func deleteRegValueInternal(subKey, name string) error { defer key.Close() err = key.DeleteValue(name) - if err == registry.ErrNotExist { + if err == ErrNoValue { err = nil } return err @@ -173,7 +176,7 @@ func deleteRegValueInternal(subKey, name string) error { func getRegIntegerInternal(subKey, name string) (uint64, error) { key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ) if err != nil { - if err != registry.ErrNotExist { + if err != ErrNoValue { log.Printf("registry.OpenKey(%v): %v", subKey, err) } return 0, err @@ -182,7 +185,7 @@ func getRegIntegerInternal(subKey, name string) (uint64, error) { val, _, err := key.GetIntegerValue(name) if err != nil { - if err != registry.ErrNotExist { + if err != ErrNoValue { log.Printf("registry.GetIntegerValue(%v): %v", name, err) } return 0, err