winutil: refactor methods to get values from registry to also return (#9536)

errors
Updates tailscale/corp#14879

Signed-off-by: Claire Wang <claire@tailscale.com>
pull/9541/head
Claire Wang 1 year ago committed by GitHub
parent c608660d12
commit e3d6236606
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -131,7 +131,7 @@ func runWindowsService(pol *logpolicy.Policy) error {
osdiag.LogSupportInfo(logger.WithPrefix(log.Printf, "Support Info: "), osdiag.LogSupportInfoReasonStartup) osdiag.LogSupportInfo(logger.WithPrefix(log.Printf, "Support Info: "), osdiag.LogSupportInfoReasonStartup)
}() }()
if winutil.GetPolicyInteger("LogSCMInteractions", 0) != 0 { if logSCMInteractions, _ := winutil.GetPolicyInteger("LogSCMInteractions"); logSCMInteractions != 0 {
syslog, err := eventlog.Open(serviceName) syslog, err := eventlog.Open(serviceName)
if err == nil { if err == nil {
syslogf = func(format string, args ...any) { syslogf = func(format string, args ...any) {
@ -158,7 +158,7 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch
syslogf("Service start pending") syslogf("Service start pending")
svcAccepts := svc.AcceptStop svcAccepts := svc.AcceptStop
if winutil.GetPolicyInteger("FlushDNSOnSessionUnlock", 0) != 0 { if flushDNSOnSessionUnlock, _ := winutil.GetPolicyInteger("FlushDNSOnSessionUnlock"); flushDNSOnSessionUnlock != 0 {
svcAccepts |= svc.AcceptSessionChange svcAccepts |= svc.AcceptSessionChange
} }

@ -40,7 +40,7 @@ var getMachineCertificateSubjectOnce struct {
// Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA"
func getMachineCertificateSubject() string { func getMachineCertificateSubject() string {
getMachineCertificateSubjectOnce.Do(func() { getMachineCertificateSubjectOnce.Do(func() {
getMachineCertificateSubjectOnce.v = winutil.GetRegString("MachineCertificateSubject", "") getMachineCertificateSubjectOnce.v, _ = winutil.GetRegString("MachineCertificateSubject")
}) })
return getMachineCertificateSubjectOnce.v return getMachineCertificateSubjectOnce.v

@ -62,7 +62,8 @@ func packageTypeWindows() string {
if _, err := os.Stat(`C:\ProgramData\chocolatey\lib\tailscale`); err == nil { if _, err := os.Stat(`C:\ProgramData\chocolatey\lib\tailscale`); err == nil {
return "choco" return "choco"
} }
if msiSentinel := winutil.GetRegInteger("MSI", 0); msiSentinel == 1 { msiSentinel, _ := winutil.GetRegInteger("MSI")
if msiSentinel == 1 {
return "msi" return "msi"
} }
exe, err := os.Executable() exe, err := os.Executable()

@ -452,20 +452,24 @@ var defaultPrefs = func() ipn.PrefsView {
prefs.LoggedOut = true prefs.LoggedOut = true
prefs.WantRunning = false prefs.WantRunning = false
prefs.ControlURL = winutil.GetPolicyString("LoginURL", "") controlURL, _ := winutil.GetPolicyString("LoginURL")
prefs.ControlURL = controlURL
prefs.ExitNodeIP = resolveExitNodeIP(netip.Addr{}) prefs.ExitNodeIP = resolveExitNodeIP(netip.Addr{})
// Allow Incoming (used by the UI) is the negation of ShieldsUp (used by the // Allow Incoming (used by the UI) is the negation of ShieldsUp (used by the
// backend), so this has to convert between the two conventions. // backend), so this has to convert between the two conventions.
prefs.ShieldsUp = winutil.GetPolicyString("AllowIncomingConnections", "") == "never" shieldsUp, _ := winutil.GetPolicyString("AllowIncomingConnections")
prefs.ForceDaemon = winutil.GetPolicyString("UnattendedMode", "") == "always" prefs.ShieldsUp = shieldsUp == "never"
forceDaemon, _ := winutil.GetPolicyString("UnattendedMode")
prefs.ForceDaemon = forceDaemon == "always"
return prefs.View() return prefs.View()
}() }()
func resolveExitNodeIP(defIP netip.Addr) (ret netip.Addr) { func resolveExitNodeIP(defIP netip.Addr) (ret netip.Addr) {
ret = defIP ret = defIP
if exitNode := winutil.GetPolicyString("ExitNodeIP", ""); exitNode != "" { if exitNode, _ := winutil.GetPolicyString("ExitNodeIP"); exitNode != "" {
if ip, err := netip.ParseAddr(exitNode); err == nil { if ip, err := netip.ParseAddr(exitNode); err == nil {
ret = ip ret = ip
} }

@ -65,7 +65,8 @@ func getLogTarget() string {
getLogTargetOnce.v = val getLogTargetOnce.v = val
} else { } else {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
getLogTargetOnce.v = winutil.GetRegString("LogTarget", "") logTarget, _ := winutil.GetRegString("LogTarget")
getLogTargetOnce.v = logTarget
} }
} }
}) })

@ -49,7 +49,10 @@ func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool {
// "always" and "never" remove the user's ability to make a selection. If not // "always" and "never" remove the user's ability to make a selection. If not
// present or set to a different value, "user-decides" is the default. // present or set to a different value, "user-decides" is the default.
func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy {
opt := winutil.GetPolicyString(name, "user-decides") opt, err := winutil.GetPolicyString(name)
if opt == "" || err != nil {
return showChoiceByPolicy
}
switch opt { switch opt {
case "always": case "always":
return alwaysByPolicy return alwaysByPolicy
@ -81,7 +84,10 @@ func (p VisibilityPolicy) Show() bool {
// true) or "hide" (return true). If not present or set to a different value, // true) or "hide" (return true). If not present or set to a different value,
// "show" (return false) is the default. // "show" (return false) is the default.
func GetVisibilityPolicy(name string) VisibilityPolicy { func GetVisibilityPolicy(name string) VisibilityPolicy {
opt := winutil.GetPolicyString(name, "show") opt, err := winutil.GetPolicyString(name)
if opt == "" || err != nil {
return visibleByPolicy
}
switch opt { switch opt {
case "hide": case "hide":
return hiddenByPolicy return hiddenByPolicy
@ -96,8 +102,8 @@ func GetVisibilityPolicy(name string) VisibilityPolicy {
// understands. If the registry value is "" or can not be processed, // understands. If the registry value is "" or can not be processed,
// defaultValue is returned instead. // defaultValue is returned instead.
func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration {
opt := winutil.GetPolicyString(name, "") opt, err := winutil.GetPolicyString(name)
if opt == "" { if opt == "" || err != nil {
return defaultValue return defaultValue
} }
v, err := time.ParseDuration(opt) v, err := time.ParseDuration(opt)

@ -13,45 +13,49 @@ import (
const RegBase = regBase const RegBase = regBase
// GetPolicyString looks up a registry value in the local machine's path for // GetPolicyString looks up a registry value in the local machine's path for
// system policies, or returns the given default if it can't. // system policies, or returns empty string and the error.
// Use this function to read values that may be set by sysadmins via the MSI // Use this function to read values that may be set by sysadmins via the MSI
// installer or via GPO. For registry settings that you do *not* want to be // installer or via GPO. For registry settings that you do *not* want to be
// visible to sysadmin tools, use GetRegString instead. // visible to sysadmin tools, use GetRegString instead.
// //
// This function will only work on GOOS=windows. Trying to run it on any other // This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return the default value. // OS will always return an empty string and ErrNoValue.
func GetPolicyString(name, defval string) string { // If value does not exist or another error happens, returns empty string and error.
return getPolicyString(name, defval) func GetPolicyString(name string) (string, error) {
return getPolicyString(name)
} }
// GetPolicyInteger looks up a registry value in the local machine's path for // GetPolicyInteger looks up a registry value in the local machine's path for
// system policies, or returns the given default if it can't. // system policies, or returns 0 and the associated error.
// Use this function to read values that may be set by sysadmins via the MSI // Use this function to read values that may be set by sysadmins via the MSI
// installer or via GPO. For registry settings that you do *not* want to be // installer or via GPO. For registry settings that you do *not* want to be
// visible to sysadmin tools, use GetRegInteger instead. // visible to sysadmin tools, use GetRegInteger instead.
// //
// This function will only work on GOOS=windows. Trying to run it on any other // This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return the default value. // OS will always return 0 and ErrNoValue.
func GetPolicyInteger(name string, defval uint64) uint64 { // If value does not exist or another error happens, returns 0 and error.
return getPolicyInteger(name, defval) func GetPolicyInteger(name string) (uint64, error) {
return getPolicyInteger(name)
} }
// GetRegString looks up a registry path in the local machine path, or returns // GetRegString looks up a registry path in the local machine path, or returns
// the given default if it can't. // an empty string and error.
// //
// This function will only work on GOOS=windows. Trying to run it on any other // This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return the default value. // OS will always return an empty string and ErrNoValue.
func GetRegString(name, defval string) string { // If value does not exist or another error happens, returns empty string and error.
return getRegString(name, defval) func GetRegString(name string) (string, error) {
return getRegString(name)
} }
// GetRegInteger looks up a registry path in the local machine path, or returns // GetRegInteger looks up a registry path in the local machine path, or returns
// the given default if it can't. // 0 and the error.
// //
// This function will only work on GOOS=windows. Trying to run it on any other // This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return the default value. // OS will always return 0 and ErrNoValue.
func GetRegInteger(name string, defval uint64) uint64 { // If value does not exist or another error happens, returns 0 and error.
return getRegInteger(name, defval) func GetRegInteger(name string) (uint64, error) {
return getRegInteger(name)
} }
// IsSIDValidPrincipal determines whether the SID contained in uid represents a // IsSIDValidPrincipal determines whether the SID contained in uid represents a

@ -6,6 +6,7 @@
package winutil package winutil
import ( import (
"errors"
"fmt" "fmt"
"os/user" "os/user"
"runtime" "runtime"
@ -13,13 +14,15 @@ import (
const regBase = `` const regBase = ``
func getPolicyString(name, defval string) string { return defval } var ErrNoValue = errors.New("no value because registry is unavailable on this OS")
func getPolicyInteger(name string, defval uint64) uint64 { return defval } func getPolicyString(name string) (string, error) { return "", ErrNoValue }
func getRegString(name, defval string) string { return defval } func getPolicyInteger(name string) (uint64, error) { return 0, ErrNoValue }
func getRegInteger(name string, defval uint64) uint64 { return defval } func getRegString(name string) (string, error) { return "", ErrNoValue }
func getRegInteger(name string) (uint64, error) { return 0, ErrNoValue }
func isSIDValidPrincipal(uid string) bool { return false } func isSIDValidPrincipal(uid string) bool { return false }

@ -29,6 +29,9 @@ const (
// ErrNoShell is returned when the shell process is not found. // ErrNoShell is returned when the shell process is not found.
var ErrNoShell = errors.New("no Shell process is present") var ErrNoShell = errors.New("no Shell process is present")
// ErrNoValue is returned when the value doesn't exist in the registry.
var ErrNoValue = registry.ErrNotExist
// GetDesktopPID searches the PID of the process that's running the // GetDesktopPID searches the PID of the process that's running the
// currently active desktop. Returns ErrNoShell if the shell is not present. // currently active desktop. Returns ErrNoShell if the shell is not present.
// Usually the PID will be for explorer.exe. // Usually the PID will be for explorer.exe.
@ -47,44 +50,44 @@ func GetDesktopPID() (uint32, error) {
return pid, nil return pid, nil
} }
func getPolicyString(name, defval string) string { func getPolicyString(name string) (string, error) {
s, err := getRegStringInternal(regPolicyBase, name) s, err := getRegStringInternal(regPolicyBase, name)
if err != nil { if err != nil {
// Fall back to the legacy path // Fall back to the legacy path
return getRegString(name, defval) return getRegString(name)
} }
return s return s, err
} }
func getPolicyInteger(name string, defval uint64) uint64 { func getRegString(name string) (string, error) {
i, err := getRegIntegerInternal(regPolicyBase, name) s, err := getRegStringInternal(regBase, name)
if err != nil { if err != nil {
// Fall back to the legacy path return "", err
return getRegInteger(name, defval)
} }
return i return s, err
} }
func getRegString(name, defval string) string { func getPolicyInteger(name string) (uint64, error) {
s, err := getRegStringInternal(regBase, name) i, err := getRegIntegerInternal(regPolicyBase, name)
if err != nil { if err != nil {
return defval // Fall back to the legacy path
return getRegInteger(name)
} }
return s return i, err
} }
func getRegInteger(name string, defval uint64) uint64 { func getRegInteger(name string) (uint64, error) {
i, err := getRegIntegerInternal(regBase, name) i, err := getRegIntegerInternal(regBase, name)
if err != nil { if err != nil {
return defval return 0, err
} }
return i return i, err
} }
func getRegStringInternal(subKey, name string) (string, error) { func getRegStringInternal(subKey, name string) (string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ) key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
if err != nil { if err != nil {
if err != registry.ErrNotExist { if err != ErrNoValue {
log.Printf("registry.OpenKey(%v): %v", subKey, err) log.Printf("registry.OpenKey(%v): %v", subKey, err)
} }
return "", err return "", err
@ -93,7 +96,7 @@ func getRegStringInternal(subKey, name string) (string, error) {
val, _, err := key.GetStringValue(name) val, _, err := key.GetStringValue(name)
if err != nil { if err != nil {
if err != registry.ErrNotExist { if err != ErrNoValue {
log.Printf("registry.GetStringValue(%v): %v", name, err) log.Printf("registry.GetStringValue(%v): %v", name, err)
} }
return "", err return "", err
@ -114,7 +117,7 @@ func GetRegStrings(name string, defval []string) []string {
func getRegStringsInternal(subKey, name string) ([]string, error) { func getRegStringsInternal(subKey, name string) ([]string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ) key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
if err != nil { if err != nil {
if err != registry.ErrNotExist { if err != ErrNoValue {
log.Printf("registry.OpenKey(%v): %v", subKey, err) log.Printf("registry.OpenKey(%v): %v", subKey, err)
} }
return nil, err return nil, err
@ -123,7 +126,7 @@ func getRegStringsInternal(subKey, name string) ([]string, error) {
val, _, err := key.GetStringsValue(name) val, _, err := key.GetStringsValue(name)
if err != nil { if err != nil {
if err != registry.ErrNotExist { if err != ErrNoValue {
log.Printf("registry.GetStringValue(%v): %v", name, err) log.Printf("registry.GetStringValue(%v): %v", name, err)
} }
return nil, err return nil, err
@ -154,7 +157,7 @@ func DeleteRegValue(name string) error {
func deleteRegValueInternal(subKey, name string) error { func deleteRegValueInternal(subKey, name string) error {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE) key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE)
if err == registry.ErrNotExist { if err == ErrNoValue {
return nil return nil
} }
if err != nil { if err != nil {
@ -164,7 +167,7 @@ func deleteRegValueInternal(subKey, name string) error {
defer key.Close() defer key.Close()
err = key.DeleteValue(name) err = key.DeleteValue(name)
if err == registry.ErrNotExist { if err == ErrNoValue {
err = nil err = nil
} }
return err return err
@ -173,7 +176,7 @@ func deleteRegValueInternal(subKey, name string) error {
func getRegIntegerInternal(subKey, name string) (uint64, error) { func getRegIntegerInternal(subKey, name string) (uint64, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ) key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
if err != nil { if err != nil {
if err != registry.ErrNotExist { if err != ErrNoValue {
log.Printf("registry.OpenKey(%v): %v", subKey, err) log.Printf("registry.OpenKey(%v): %v", subKey, err)
} }
return 0, err return 0, err
@ -182,7 +185,7 @@ func getRegIntegerInternal(subKey, name string) (uint64, error) {
val, _, err := key.GetIntegerValue(name) val, _, err := key.GetIntegerValue(name)
if err != nil { if err != nil {
if err != registry.ErrNotExist { if err != ErrNoValue {
log.Printf("registry.GetIntegerValue(%v): %v", name, err) log.Printf("registry.GetIntegerValue(%v): %v", name, err)
} }
return 0, err return 0, err

Loading…
Cancel
Save