From aeb15dea3032c0560793190731dd6a06b45850b2 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 12 Aug 2024 22:07:45 -0500 Subject: [PATCH] util/syspolicy/source: add package for reading policy settings from external stores We add package defining interfaces for policy stores, enabling creation of policy sources and reading settings from them. It includes a Windows-specific PlatformPolicyStore for GP and MDM policies stored in the Registry, and an in-memory TestStore for testing purposes. We also include an internal package that tracks and reports policy usage metrics when a policy setting is read from a store. Initially, it will be used only on Windows and Android, as macOS, iOS, and tvOS report their own metrics. However, we plan to use it across all platforms eventually. Updates #12687 Signed-off-by: Nick Khyl --- util/syspolicy/internal/loggerx/logger.go | 46 ++ util/syspolicy/internal/metrics/metrics.go | 320 +++++++++++++ .../internal/metrics/metrics_test.go | 423 ++++++++++++++++ .../internal/metrics/test_handler.go | 88 ++++ util/syspolicy/setting/errors.go | 4 +- util/syspolicy/source/policy_reader.go | 394 +++++++++++++++ util/syspolicy/source/policy_reader_test.go | 291 +++++++++++ util/syspolicy/source/policy_source.go | 146 ++++++ util/syspolicy/source/policy_store_windows.go | 450 +++++++++++++++++ .../source/policy_store_windows_test.go | 398 ++++++++++++++++ util/syspolicy/source/test_store.go | 451 ++++++++++++++++++ 11 files changed, 3009 insertions(+), 2 deletions(-) create mode 100644 util/syspolicy/internal/loggerx/logger.go create mode 100644 util/syspolicy/internal/metrics/metrics.go create mode 100644 util/syspolicy/internal/metrics/metrics_test.go create mode 100644 util/syspolicy/internal/metrics/test_handler.go create mode 100644 util/syspolicy/source/policy_reader.go create mode 100644 util/syspolicy/source/policy_reader_test.go create mode 100644 util/syspolicy/source/policy_source.go create mode 100644 util/syspolicy/source/policy_store_windows.go create mode 100644 util/syspolicy/source/policy_store_windows_test.go create mode 100644 util/syspolicy/source/test_store.go diff --git a/util/syspolicy/internal/loggerx/logger.go b/util/syspolicy/internal/loggerx/logger.go new file mode 100644 index 000000000..b28610826 --- /dev/null +++ b/util/syspolicy/internal/loggerx/logger.go @@ -0,0 +1,46 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package loggerx provides logging functions to the rest of the syspolicy packages. +package loggerx + +import ( + "log" + + "tailscale.com/types/lazy" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/internal" +) + +const ( + errorPrefix = "syspolicy: " + verbosePrefix = "syspolicy: [v2] " +) + +var ( + lazyErrorf lazy.SyncValue[logger.Logf] + lazyVerbosef lazy.SyncValue[logger.Logf] +) + +// Errorf formats and writes an error message to the log. +func Errorf(format string, args ...any) { + errorf := lazyErrorf.Get(func() logger.Logf { + return logger.WithPrefix(log.Printf, errorPrefix) + }) + errorf(format, args...) +} + +// Verbosef formats and writes an optional, verbose message to the log. +func Verbosef(format string, args ...any) { + verbosef := lazyVerbosef.Get(func() logger.Logf { + return logger.WithPrefix(log.Printf, verbosePrefix) + }) + verbosef(format, args...) +} + +// SetForTest sets the specified errorf and verbosef functions for the duration +// of tb and its subtests. +func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) { + lazyErrorf.SetForTest(tb, errorf, nil) + lazyVerbosef.SetForTest(tb, verbosef, nil) +} diff --git a/util/syspolicy/internal/metrics/metrics.go b/util/syspolicy/internal/metrics/metrics.go new file mode 100644 index 000000000..2ea02278a --- /dev/null +++ b/util/syspolicy/internal/metrics/metrics.go @@ -0,0 +1,320 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package metrics provides logging and reporting for policy settings and scopes. +package metrics + +import ( + "strings" + "sync" + + xmaps "golang.org/x/exp/maps" + + "tailscale.com/syncs" + "tailscale.com/types/lazy" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" +) + +var lazyReportMetrics lazy.SyncValue[bool] // used as a test hook + +// ShouldReport reports whether metrics should be reported on the current environment. +func ShouldReport() bool { + return lazyReportMetrics.Get(func() bool { + // macOS, iOS and tvOS create their own metrics, + // and we don't have syspolicy on any other platforms. + return setting.PlatformList{"android", "windows"}.HasCurrent() + }) +} + +// Reset metrics for the specified policy origin. +func Reset(origin *setting.Origin) { + scopeMetrics(origin).Reset() +} + +// ReportConfigured updates metrics and logs that the specified setting is +// configured with the given value in the origin. +func ReportConfigured(origin *setting.Origin, setting *setting.Definition, value any) { + settingMetricsFor(setting).ReportValue(origin, value) +} + +// ReportError updates metrics and logs that the specified setting has an error +// in the origin. +func ReportError(origin *setting.Origin, setting *setting.Definition, err error) { + settingMetricsFor(setting).ReportError(origin, err) +} + +// ReportNotConfigured updates metrics and logs that the specified setting is +// not configured in the origin. +func ReportNotConfigured(origin *setting.Origin, setting *setting.Definition) { + settingMetricsFor(setting).Reset(origin) +} + +// metric is an interface implemented by [clientmetric.Metric] and [funcMetric]. +type metric interface { + Add(v int64) + Set(v int64) +} + +// policyScopeMetrics are metrics that apply to an entire policy scope rather +// than a specific policy setting. +type policyScopeMetrics struct { + hasAny metric + numErrored metric +} + +func newScopeMetrics(scope setting.Scope) *policyScopeMetrics { + prefix := metricScopeName(scope) + // {os}_syspolicy_{scope_unless_device}_any + // Example: windows_syspolicy_any or windows_syspolicy_user_any. + hasAny := newMetric([]string{prefix, "any"}, clientmetric.TypeGauge) + // {os}_syspolicy_{scope_unless_device}_errors + // Example: windows_syspolicy_errors or windows_syspolicy_user_errors. + // + // TODO(nickkhyl): maybe make the `{os}_syspolicy_errors` metric a gauge rather than a counter? + // It was a counter prior to https://github.com/tailscale/tailscale/issues/12687, so I kept it as such. + // But I think a gauge makes more sense: syspolicy errors indicate a mismatch between the expected + // policy value type or format and the actual value read from the underlying store (like the Windows Registry). + // We'll encounter the same error every time we re-read the policy setting from the backing store + // until the policy value is corrected by the user, or until we fix the bug in the code or ADMX. + // There's probably no reason to count and accumulate them over time. + // + // Brief discussion: https://github.com/tailscale/tailscale/pull/13113#discussion_r1723475136 + numErrored := newMetric([]string{prefix, "errors"}, clientmetric.TypeCounter) + return &policyScopeMetrics{hasAny, numErrored} +} + +// ReportHasSettings is called when there's any configured policy setting in the scope. +func (m *policyScopeMetrics) ReportHasSettings() { + if m != nil { + m.hasAny.Set(1) + } +} + +// ReportError is called when there's any errored policy setting in the scope. +func (m *policyScopeMetrics) ReportError() { + if m != nil { + m.numErrored.Add(1) + } +} + +// Reset is called to reset the policy scope metrics, such as when the policy scope +// is about to be reloaded. +func (m *policyScopeMetrics) Reset() { + if m != nil { + m.hasAny.Set(0) + // numErrored is a counter and cannot be (re-)set. + } +} + +// settingMetrics are metrics for a single policy setting in one or more scopes. +type settingMetrics struct { + definition *setting.Definition + isSet []metric // by scope + hasErrors []metric // by scope +} + +// ReportValue is called when the policy setting is found to be configured in the specified source. +func (m *settingMetrics) ReportValue(origin *setting.Origin, v any) { + if m == nil { + return + } + if scope := origin.Scope().Kind(); scope >= 0 && int(scope) < len(m.isSet) { + m.isSet[scope].Set(1) + m.hasErrors[scope].Set(0) + } + scopeMetrics(origin).ReportHasSettings() + loggerx.Verbosef("%v(%q) = %v", origin, m.definition.Key(), v) +} + +// ReportError is called when there's an error with the policy setting in the specified source. +func (m *settingMetrics) ReportError(origin *setting.Origin, err error) { + if m == nil { + return + } + if scope := origin.Scope().Kind(); int(scope) < len(m.hasErrors) { + m.isSet[scope].Set(0) + m.hasErrors[scope].Set(1) + } + scopeMetrics(origin).ReportError() + loggerx.Errorf("%v(%q): %v", origin, m.definition.Key(), err) +} + +// Reset is called to reset the policy setting's metrics, such as when +// the policy setting does not exist or the source containing the policy +// is about to be reloaded. +func (m *settingMetrics) Reset(origin *setting.Origin) { + if m == nil { + return + } + if scope := origin.Scope().Kind(); scope >= 0 && int(scope) < len(m.isSet) { + m.isSet[scope].Set(0) + m.hasErrors[scope].Set(0) + } +} + +// metricFn is a function that adds or sets a metric value. +type metricFn func(name string, typ clientmetric.Type, v int64) + +// funcMetric implements [metric] by calling the specified add and set functions. +// Used for testing, and with nil functions on platforms that do not support +// syspolicy, and on platforms that report policy metrics from the GUI. +type funcMetric struct { + name string + typ clientmetric.Type + add, set metricFn +} + +func (m funcMetric) Add(v int64) { + if m.add != nil { + m.add(m.name, m.typ, v) + } +} + +func (m funcMetric) Set(v int64) { + if m.set != nil { + m.set(m.name, m.typ, v) + } +} + +var ( + lazyDeviceMetrics lazy.SyncValue[*policyScopeMetrics] + lazyProfileMetrics lazy.SyncValue[*policyScopeMetrics] + lazyUserMetrics lazy.SyncValue[*policyScopeMetrics] +) + +func scopeMetrics(origin *setting.Origin) *policyScopeMetrics { + switch origin.Scope().Kind() { + case setting.DeviceSetting: + return lazyDeviceMetrics.Get(func() *policyScopeMetrics { + return newScopeMetrics(setting.DeviceSetting) + }) + case setting.ProfileSetting: + return lazyProfileMetrics.Get(func() *policyScopeMetrics { + return newScopeMetrics(setting.ProfileSetting) + }) + case setting.UserSetting: + return lazyUserMetrics.Get(func() *policyScopeMetrics { + return newScopeMetrics(setting.UserSetting) + }) + default: + panic("unreachable") + } +} + +var ( + settingMetricsMu sync.RWMutex + settingMetricsMap map[setting.Key]*settingMetrics +) + +func settingMetricsFor(setting *setting.Definition) *settingMetrics { + settingMetricsMu.RLock() + metrics, ok := settingMetricsMap[setting.Key()] + settingMetricsMu.RUnlock() + if ok { + return metrics + } + return settingMetricsForSlow(setting) +} + +func settingMetricsForSlow(d *setting.Definition) *settingMetrics { + settingMetricsMu.Lock() + defer settingMetricsMu.Unlock() + if metrics, ok := settingMetricsMap[d.Key()]; ok { + return metrics + } + + // The loop below initializes metrics for each scope where a policy setting defined in 'd' + // can be configured. The [setting.Definition.Scope] returns the narrowest scope at which the policy + // setting may be configured, and more specific scopes always have higher numeric values. + // In other words, [setting.UserSetting] > [setting.ProfileScope] > [setting.DeviceScope]. + // It's impossible for a policy setting to be configured in a scope with a higher numeric value than + // the [setting.Definition.Scope] returns. Therefore, a policy setting can be configured in at + // most d.Scope()+1 different scopes, and having d.Scope()+1 metrics for the corresponding scopes + // is always sufficient for [settingMetrics]; it won't access elements past the end of the slice + // or need to reallocate with a longer slice if one of those arrives. + isSet := make([]metric, d.Scope()+1) + hasErrors := make([]metric, d.Scope()+1) + for i := range isSet { + scope := setting.Scope(i) + // {os}_syspolicy_{key}_{scope_unless_device} + // Example: windows_syspolicy_AdminConsole or windows_syspolicy_AdminConsole_user. + isSet[i] = newSettingMetric(d.Key(), scope, "", clientmetric.TypeGauge) + // {os}_syspolicy_{key}_{scope_unless_device}_error + // Example: windows_syspolicy_AdminConsole_error or windows_syspolicy_TestSetting01_user_error. + hasErrors[i] = newSettingMetric(d.Key(), scope, "error", clientmetric.TypeGauge) + } + metrics := &settingMetrics{d, isSet, hasErrors} + mak.Set(&settingMetricsMap, d.Key(), metrics) + return metrics +} + +// hooks for testing +var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn] + +// SetHooksForTest sets the specified addMetric and setMetric functions +// as the metric functions for the duration of tb and all its subtests. +func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { + oldAddMetric := addMetricTestHook.Swap(addMetric) + oldSetMetric := setMetricTestHook.Swap(setMetric) + tb.Cleanup(func() { + addMetricTestHook.Store(oldAddMetric) + setMetricTestHook.Store(oldSetMetric) + }) + + settingMetricsMu.Lock() + oldSettingMetricsMap := xmaps.Clone(settingMetricsMap) + clear(settingMetricsMap) + settingMetricsMu.Unlock() + tb.Cleanup(func() { + settingMetricsMu.Lock() + settingMetricsMap = oldSettingMetricsMap + settingMetricsMu.Unlock() + }) + + // (re-)set the scope metrics to use the test hooks for the duration of tb. + lazyDeviceMetrics.SetForTest(tb, newScopeMetrics(setting.DeviceSetting), nil) + lazyProfileMetrics.SetForTest(tb, newScopeMetrics(setting.ProfileSetting), nil) + lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil) +} + +func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric { + name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_") + return newMetric([]string{name, metricScopeName(scope), suffix}, typ) +} + +func newMetric(nameParts []string, typ clientmetric.Type) metric { + name := strings.Join(slicesx.Filter([]string{internal.OS(), "syspolicy"}, nameParts, isNonEmpty), "_") + switch { + case !ShouldReport(): + return &funcMetric{name: name, typ: typ} + case testenv.InTest(): + return &funcMetric{name, typ, addMetricTestHook.Load(), setMetricTestHook.Load()} + case typ == clientmetric.TypeCounter: + return clientmetric.NewCounter(name) + case typ == clientmetric.TypeGauge: + return clientmetric.NewGauge(name) + default: + panic("unreachable") + } +} + +func isNonEmpty(s string) bool { return s != "" } + +func metricScopeName(scope setting.Scope) string { + switch scope { + case setting.DeviceSetting: + return "" + case setting.ProfileSetting: + return "profile" + case setting.UserSetting: + return "user" + default: + panic("unreachable") + } +} diff --git a/util/syspolicy/internal/metrics/metrics_test.go b/util/syspolicy/internal/metrics/metrics_test.go new file mode 100644 index 000000000..07be4773c --- /dev/null +++ b/util/syspolicy/internal/metrics/metrics_test.go @@ -0,0 +1,423 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package metrics + +import ( + "errors" + "testing" + + "tailscale.com/types/lazy" + "tailscale.com/util/clientmetric" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" +) + +func TestSettingMetricNames(t *testing.T) { + tests := []struct { + name string + key setting.Key + scope setting.Scope + suffix string + typ clientmetric.Type + osOverride string + wantMetricName string + }{ + { + name: "windows-device-no-suffix", + key: "AdminConsole", + scope: setting.DeviceSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole", + }, + { + name: "windows-user-no-suffix", + key: "AdminConsole", + scope: setting.UserSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole_user", + }, + { + name: "windows-profile-no-suffix", + key: "AdminConsole", + scope: setting.ProfileSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole_profile", + }, + { + name: "windows-profile-err", + key: "AdminConsole", + scope: setting.ProfileSetting, + suffix: "error", + typ: clientmetric.TypeCounter, + osOverride: "windows", + wantMetricName: "windows_syspolicy_AdminConsole_profile_error", + }, + { + name: "android-device-no-suffix", + key: "AdminConsole", + scope: setting.DeviceSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "android", + wantMetricName: "android_syspolicy_AdminConsole", + }, + { + name: "key-path", + key: "category/subcategory/setting", + scope: setting.DeviceSetting, + suffix: "", + typ: clientmetric.TypeCounter, + osOverride: "fakeos", + wantMetricName: "fakeos_syspolicy_category_subcategory_setting", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + internal.OSForTesting.SetForTest(t, tt.osOverride, nil) + metric, ok := newSettingMetric(tt.key, tt.scope, tt.suffix, tt.typ).(*funcMetric) + if !ok { + t.Fatal("metric is not a funcMetric") + } + if metric.name != tt.wantMetricName { + t.Errorf("got %q, want %q", metric.name, tt.wantMetricName) + } + }) + } +} + +func TestScopeMetrics(t *testing.T) { + tests := []struct { + name string + scope setting.Scope + osOverride string + wantHasAnyName string + wantNumErroredName string + wantHasAnyType clientmetric.Type + wantNumErroredType clientmetric.Type + }{ + { + name: "windows-device", + scope: setting.DeviceSetting, + osOverride: "windows", + wantHasAnyName: "windows_syspolicy_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "windows_syspolicy_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + { + name: "windows-profile", + scope: setting.ProfileSetting, + osOverride: "windows", + wantHasAnyName: "windows_syspolicy_profile_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "windows_syspolicy_profile_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + { + name: "windows-user", + scope: setting.UserSetting, + osOverride: "windows", + wantHasAnyName: "windows_syspolicy_user_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "windows_syspolicy_user_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + { + name: "android-device", + scope: setting.DeviceSetting, + osOverride: "android", + wantHasAnyName: "android_syspolicy_any", + wantHasAnyType: clientmetric.TypeGauge, + wantNumErroredName: "android_syspolicy_errors", + wantNumErroredType: clientmetric.TypeCounter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + internal.OSForTesting.SetForTest(t, tt.osOverride, nil) + metrics := newScopeMetrics(tt.scope) + hasAny, ok := metrics.hasAny.(*funcMetric) + if !ok { + t.Fatal("hasAny is not a funcMetric") + } + numErrored, ok := metrics.numErrored.(*funcMetric) + if !ok { + t.Fatal("numErrored is not a funcMetric") + } + if hasAny.name != tt.wantHasAnyName { + t.Errorf("hasAny.Name: got %q, want %q", hasAny.name, tt.wantHasAnyName) + } + if hasAny.typ != tt.wantHasAnyType { + t.Errorf("hasAny.Type: got %q, want %q", hasAny.typ, tt.wantHasAnyType) + } + if numErrored.name != tt.wantNumErroredName { + t.Errorf("numErrored.Name: got %q, want %q", numErrored.name, tt.wantNumErroredName) + } + if numErrored.typ != tt.wantNumErroredType { + t.Errorf("hasAny.Type: got %q, want %q", numErrored.typ, tt.wantNumErroredType) + } + }) + } +} + +type testSettingDetails struct { + definition *setting.Definition + origin *setting.Origin + value any + err error +} + +func TestReportMetrics(t *testing.T) { + tests := []struct { + name string + osOverride string + useMetrics bool + settings []testSettingDetails + wantMetrics []TestState + wantResetMetrics []TestState + }{ + { + name: "none", + osOverride: "windows", + settings: []testSettingDetails{}, + wantMetrics: []TestState{}, + }, + { + name: "single-value", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_TestSetting01", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_TestSetting01", 0}, + }, + }, + { + name: "single-error", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting02_error", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting02_error", 0}, + }, + }, + { + name: "value-and-error", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + }, + + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting01", 1}, + {"windows_syspolicy_TestSetting02_error", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_errors", 1}, + {"windows_syspolicy_TestSetting01", 0}, + {"windows_syspolicy_TestSetting02_error", 0}, + }, + }, + { + name: "two-values", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 17, + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_TestSetting01", 1}, + {"windows_syspolicy_TestSetting02", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_TestSetting01", 0}, + {"windows_syspolicy_TestSetting02", 0}, + }, + }, + { + name: "two-errors", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + { + definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + err: errors.New("bang!"), + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_errors", 2}, + {"windows_syspolicy_TestSetting01_error", 1}, + {"windows_syspolicy_TestSetting02_error", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_errors", 2}, + {"windows_syspolicy_TestSetting01_error", 0}, + {"windows_syspolicy_TestSetting02_error", 0}, + }, + }, + { + name: "multi-scope", + osOverride: "windows", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.ProfileSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + { + definition: setting.NewDefinition("TestSetting02", setting.ProfileSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.CurrentProfileScope), + err: errors.New("bang!"), + }, + { + definition: setting.NewDefinition("TestSetting03", setting.UserSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.CurrentUserScope), + value: 17, + }, + }, + wantMetrics: []TestState{ + {"windows_syspolicy_any", 1}, + {"windows_syspolicy_profile_errors", 1}, + {"windows_syspolicy_user_any", 1}, + {"windows_syspolicy_TestSetting01", 1}, + {"windows_syspolicy_TestSetting02_profile_error", 1}, + {"windows_syspolicy_TestSetting03_user", 1}, + }, + wantResetMetrics: []TestState{ + {"windows_syspolicy_any", 0}, + {"windows_syspolicy_profile_errors", 1}, + {"windows_syspolicy_user_any", 0}, + {"windows_syspolicy_TestSetting01", 0}, + {"windows_syspolicy_TestSetting02_profile_error", 0}, + {"windows_syspolicy_TestSetting03_user", 0}, + }, + }, + { + name: "report-metrics-on-android", + osOverride: "android", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + wantMetrics: []TestState{ + {"android_syspolicy_any", 1}, + {"android_syspolicy_TestSetting01", 1}, + }, + wantResetMetrics: []TestState{ + {"android_syspolicy_any", 0}, + {"android_syspolicy_TestSetting01", 0}, + }, + }, + { + name: "do-not-report-metrics-on-macos", + osOverride: "macos", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + + wantMetrics: []TestState{}, // none reported + }, + { + name: "do-not-report-metrics-on-ios", + osOverride: "ios", + settings: []testSettingDetails{ + { + definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue), + origin: setting.NewOrigin(setting.DeviceScope), + value: 42, + }, + }, + + wantMetrics: []TestState{}, // none reported + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset the lazy value so it'll be re-evaluated with the osOverride. + lazyReportMetrics = lazy.SyncValue[bool]{} + t.Cleanup(func() { + // Also reset it during the cleanup. + lazyReportMetrics = lazy.SyncValue[bool]{} + }) + internal.OSForTesting.SetForTest(t, tt.osOverride, nil) + + h := NewTestHandler(t) + SetHooksForTest(t, h.AddMetric, h.SetMetric) + + for _, s := range tt.settings { + if s.err != nil { + ReportError(s.origin, s.definition, s.err) + } else { + ReportConfigured(s.origin, s.definition, s.value) + } + } + h.MustEqual(tt.wantMetrics...) + + for _, s := range tt.settings { + Reset(s.origin) + ReportNotConfigured(s.origin, s.definition) + } + h.MustEqual(tt.wantResetMetrics...) + }) + } +} diff --git a/util/syspolicy/internal/metrics/test_handler.go b/util/syspolicy/internal/metrics/test_handler.go new file mode 100644 index 000000000..f9e484609 --- /dev/null +++ b/util/syspolicy/internal/metrics/test_handler.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package metrics + +import ( + "strings" + + "tailscale.com/util/clientmetric" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal" +) + +// TestState represents a metric name and its expected value. +type TestState struct { + Name string // `$os` in the name will be replaced by the actual operating system name. + Value int64 +} + +// TestHandler facilitates testing of the code that uses metrics. +type TestHandler struct { + t internal.TB + + m map[string]int64 +} + +// NewTestHandler returns a new TestHandler. +func NewTestHandler(t internal.TB) *TestHandler { + return &TestHandler{t, make(map[string]int64)} +} + +// AddMetric increments the metric with the specified name and type by delta d. +func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) { + h.t.Helper() + if typ == clientmetric.TypeCounter && d < 0 { + h.t.Fatalf("an attempt was made to decrement a counter metric %q", name) + } + if v, ok := h.m[name]; ok || d != 0 { + h.m[name] = v + d + } +} + +// SetMetric sets the metric with the specified name and type to the value v. +func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) { + h.t.Helper() + if typ == clientmetric.TypeCounter { + h.t.Fatalf("an attempt was made to set a counter metric %q", name) + } + if _, ok := h.m[name]; ok || v != 0 { + h.m[name] = v + } +} + +// MustEqual fails the test if the actual metric state differs from the specified state. +func (h *TestHandler) MustEqual(metrics ...TestState) { + h.t.Helper() + h.MustContain(metrics...) + h.mustNoExtra(metrics...) +} + +// MustContain fails the test if the specified metrics are not set or have +// different values than specified. It permits other metrics to be set in +// addition to the ones being tested. +func (h *TestHandler) MustContain(metrics ...TestState) { + h.t.Helper() + for _, m := range metrics { + name := strings.ReplaceAll(m.Name, "$os", internal.OS()) + v, ok := h.m[name] + if !ok { + h.t.Errorf("%q: got (none), want %v", name, m.Value) + } else if v != m.Value { + h.t.Fatalf("%q: got %v, want %v", name, v, m.Value) + } + } +} + +func (h *TestHandler) mustNoExtra(metrics ...TestState) { + h.t.Helper() + s := make(set.Set[string]) + for i := range metrics { + s.Add(strings.ReplaceAll(metrics[i].Name, "$os", internal.OS())) + } + for n, v := range h.m { + if !s.Contains(n) { + h.t.Errorf("%q: got %v, want (none)", n, v) + } + } +} diff --git a/util/syspolicy/setting/errors.go b/util/syspolicy/setting/errors.go index d7e14df83..38dc6a88c 100644 --- a/util/syspolicy/setting/errors.go +++ b/util/syspolicy/setting/errors.go @@ -42,9 +42,9 @@ func NewErrorText(text string) *ErrorText { return ptr.To(ErrorText(text)) } -// NewErrorTextFromError returns an [ErrorText] with the text of the specified error, +// MaybeErrorText returns an [ErrorText] with the text of the specified error, // or nil if err is nil, [ErrNotConfigured], or [ErrNoSuchKey]. -func NewErrorTextFromError(err error) *ErrorText { +func MaybeErrorText(err error) *ErrorText { if err == nil || errors.Is(err, ErrNotConfigured) || errors.Is(err, ErrNoSuchKey) { return nil } diff --git a/util/syspolicy/source/policy_reader.go b/util/syspolicy/source/policy_reader.go new file mode 100644 index 000000000..a1bd3147e --- /dev/null +++ b/util/syspolicy/source/policy_reader.go @@ -0,0 +1,394 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "io" + "slices" + "sort" + "sync" + "time" + + "tailscale.com/util/mak" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/internal/metrics" + "tailscale.com/util/syspolicy/setting" +) + +// Reader reads all configured policy settings from a given [Store]. +// It registers a change callback with the [Store] and maintains the current version +// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store] +// whenever a new settings snapshot is requested with [Reader.GetSettings]. +// It is safe for concurrent use. +type Reader struct { + store Store + origin *setting.Origin + settings []*setting.Definition + unregisterChangeNotifier func() + doneCh chan struct{} // closed when [Reader] is closed. + + mu sync.Mutex + closing bool + upToDate bool + lastPolicy *setting.Snapshot + sessions set.HandleSet[*ReadingSession] +} + +// newReader returns a new [Reader] that reads policy settings from a given [Store]. +// The returned reader takes ownership of the store. If the store implements [io.Closer], +// the returned reader will close the store when it is closed. +func newReader(store Store, origin *setting.Origin) (*Reader, error) { + settings, err := setting.Definitions() + if err != nil { + return nil, err + } + + if expirable, ok := store.(Expirable); ok { + select { + case <-expirable.Done(): + return nil, ErrStoreClosed + default: + } + } + + reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})} + if changeable, ok := store.(Changeable); ok { + // We should subscribe to policy change notifications first before reading + // the policy settings from the store. This way we won't miss any notifications. + if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil { + // Errors registering policy change callbacks are non-fatal. + // TODO(nickkhyl): implement a background policy refresh every X minutes? + loggerx.Errorf("failed to register %v policy change callback: %v", origin, err) + } + } + + if _, err := reader.reload(true); err != nil { + if reader.unregisterChangeNotifier != nil { + reader.unregisterChangeNotifier() + } + return nil, err + } + + if expirable, ok := store.(Expirable); ok { + if waitCh := expirable.Done(); waitCh != nil { + go func() { + select { + case <-waitCh: + reader.Close() + case <-reader.doneCh: + } + }() + } + } + + return reader, nil +} + +// GetSettings returns the current [*setting.Snapshot], +// re-reading it from from the underlying [Store] only if the policy +// has changed since it was read last. It never fails and returns +// the previous version of the policy settings if a read attempt fails. +func (r *Reader) GetSettings() *setting.Snapshot { + r.mu.Lock() + upToDate, lastPolicy := r.upToDate, r.lastPolicy + r.mu.Unlock() + if upToDate { + return lastPolicy + } + + policy, err := r.reload(false) + if err != nil { + // If the policy fails to reload completely, log an error and return the last cached version. + // However, errors related to individual policy items are always + // propagated to callers when they fetch those settings. + loggerx.Errorf("failed to reload %v policy: %v", r.origin, err) + } + return policy +} + +// ReadSettings reads policy settings from the underlying [Store] even if no +// changes were detected. It returns the new [*setting.Snapshot],nil on +// success or an undefined snapshot (possibly `nil`) along with a non-`nil` +// error in case of failure. +func (r *Reader) ReadSettings() (*setting.Snapshot, error) { + return r.reload(true) +} + +// reload is like [Reader.ReadSettings], but allows specifying whether to re-read +// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails. +func (r *Reader) reload(force bool) (*setting.Snapshot, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.upToDate && !force { + return r.lastPolicy, nil + } + + if lockable, ok := r.store.(Lockable); ok { + if err := lockable.Lock(); err != nil { + return r.lastPolicy, err + } + defer lockable.Unlock() + } + + r.upToDate = true + + metrics.Reset(r.origin) + + var m map[setting.Key]setting.RawItem + if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 { + m = make(map[setting.Key]setting.RawItem, lastPolicyCount) + } + for _, s := range r.settings { + if !r.origin.Scope().IsConfigurableSetting(s) { + // Skip settings that cannot be configured in the current scope. + continue + } + + val, err := readPolicySettingValue(r.store, s) + if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) { + metrics.ReportNotConfigured(r.origin, s) + continue + } + + if err == nil { + metrics.ReportConfigured(r.origin, s, val) + } else { + metrics.ReportError(r.origin, s, err) + } + + // If there's an error reading a single policy, such as a value type mismatch, + // we'll wrap the error to preserve its text and return it + // whenever someone attempts to fetch the value. + // Otherwise, the errorText will be nil. + errorText := setting.MaybeErrorText(err) + item := setting.RawItemWith(val, errorText, r.origin) + mak.Set(&m, s.Key(), item) + } + + newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin)) + if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) { + r.lastPolicy = newPolicy + } + return r.lastPolicy, nil +} + +// ReadingSession is like [Reader], but with a channel that's written +// to when there's a policy change, and closed when the session is terminated. +type ReadingSession struct { + reader *Reader + policyChangedCh chan struct{} // 1-buffered channel + handle set.Handle // in the reader.sessions + closeInternal func() +} + +// OpenSession opens and returns a new session to r, allowing the caller +// to get notified whenever a policy change is reported by the [source.Store], +// or an [ErrStoreClosed] if the reader has already been closed. +func (r *Reader) OpenSession() (*ReadingSession, error) { + session := &ReadingSession{ + reader: r, + policyChangedCh: make(chan struct{}, 1), + } + session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) }) + r.mu.Lock() + defer r.mu.Unlock() + if r.closing { + return nil, ErrStoreClosed + } + session.handle = r.sessions.Add(session) + return session, nil +} + +// GetSettings is like [Reader.GetSettings]. +func (s *ReadingSession) GetSettings() *setting.Snapshot { + return s.reader.GetSettings() +} + +// ReadSettings is like [Reader.ReadSettings]. +func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) { + return s.reader.ReadSettings() +} + +// PolicyChanged returns a channel that's written to when +// there's a policy change, closed when the session is terminated. +func (s *ReadingSession) PolicyChanged() <-chan struct{} { + return s.policyChangedCh +} + +// Close unregisters this session with the [Reader]. +func (s *ReadingSession) Close() { + s.reader.mu.Lock() + delete(s.reader.sessions, s.handle) + s.closeInternal() + s.reader.mu.Unlock() +} + +// onPolicyChange handles a policy change notification from the [Store], +// invalidating the current [setting.Snapshot] in r, +// and notifying the active [ReadingSession]s. +func (r *Reader) onPolicyChange() { + r.mu.Lock() + defer r.mu.Unlock() + r.upToDate = false + for _, s := range r.sessions { + select { + case s.policyChangedCh <- struct{}{}: + // Notified. + default: + // 1-buffered channel is full, meaning that another policy change + // notification is already en route. + } + } +} + +// Close closes the store reader and the underlying store. +func (r *Reader) Close() error { + r.mu.Lock() + if r.closing { + r.mu.Unlock() + return nil + } + r.closing = true + r.mu.Unlock() + + if r.unregisterChangeNotifier != nil { + r.unregisterChangeNotifier() + r.unregisterChangeNotifier = nil + } + + if closer, ok := r.store.(io.Closer); ok { + if err := closer.Close(); err != nil { + return err + } + } + r.store = nil + + close(r.doneCh) + + r.mu.Lock() + defer r.mu.Unlock() + for _, c := range r.sessions { + c.closeInternal() + } + r.sessions = nil + return nil +} + +// Done returns a channel that is closed when the reader is closed. +func (r *Reader) Done() <-chan struct{} { + return r.doneCh +} + +// ReadableSource is a [Source] open for reading. +type ReadableSource struct { + *Source + *ReadingSession +} + +// Close closes the underlying [ReadingSession]. +func (s ReadableSource) Close() { + s.ReadingSession.Close() +} + +// ReadableSources is a slice of [ReadableSource]. +type ReadableSources []ReadableSource + +// Contains reports whether s contains the specified source. +func (s ReadableSources) Contains(source *Source) bool { + return s.IndexOf(source) != -1 +} + +// IndexOf returns position of the specified source in s, or -1 +// if the source does not exist. +func (s ReadableSources) IndexOf(source *Source) int { + return slices.IndexFunc(s, func(rs ReadableSource) bool { + return rs.Source == source + }) +} + +// InsertionIndexOf returns the position at which source can be inserted +// to maintain the sorted order of the readableSources. +// The return value is unspecified if s is not sorted on entry to InsertionIndexOf. +func (s ReadableSources) InsertionIndexOf(source *Source) int { + // Insert new sources after any existing sources with the same precedence, + // and just before the first source with higher precedence. + // Just like stable sort, but for insertion. + // It's okay to use linear search as insertions are rare + // and we never have more than just a few policy sources. + higherPrecedence := func(rs ReadableSource) bool { return rs.Compare(source) > 0 } + if i := slices.IndexFunc(s, higherPrecedence); i != -1 { + return i + } + return len(s) +} + +// StableSort sorts [ReadableSource] in s by precedence, so that policy +// settings from sources with higher precedence (e.g., [DeviceScope]) +// will be read and merged last, overriding any policy settings with +// the same keys configured in sources with lower precedence +// (e.g., [CurrentUserScope]). +func (s *ReadableSources) StableSort() { + sort.SliceStable(*s, func(i, j int) bool { + return (*s)[i].Source.Compare((*s)[j].Source) < 0 + }) +} + +// DeleteAt closes and deletes the i-th source from s. +func (s *ReadableSources) DeleteAt(i int) { + (*s)[i].Close() + *s = slices.Delete(*s, i, i+1) +} + +// Close closes and deletes all sources in s. +func (s *ReadableSources) Close() { + for _, s := range *s { + s.Close() + } + *s = nil +} + +func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) { + switch key := s.Key(); s.Type() { + case setting.BooleanValue: + return store.ReadBoolean(key) + case setting.IntegerValue: + return store.ReadUInt64(key) + case setting.StringValue: + return store.ReadString(key) + case setting.StringListValue: + return store.ReadStringArray(key) + case setting.PreferenceOptionValue: + s, err := store.ReadString(key) + if err == nil { + var value setting.PreferenceOption + if err = value.UnmarshalText([]byte(s)); err == nil { + return value, nil + } + } + return setting.ShowChoiceByPolicy, err + case setting.VisibilityValue: + s, err := store.ReadString(key) + if err == nil { + var value setting.Visibility + if err = value.UnmarshalText([]byte(s)); err == nil { + return value, nil + } + } + return setting.VisibleByPolicy, err + case setting.DurationValue: + s, err := store.ReadString(key) + if err == nil { + var value time.Duration + if value, err = time.ParseDuration(s); err == nil { + return value, nil + } + } + return nil, err + default: + return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type()) + } +} diff --git a/util/syspolicy/source/policy_reader_test.go b/util/syspolicy/source/policy_reader_test.go new file mode 100644 index 000000000..57676e67d --- /dev/null +++ b/util/syspolicy/source/policy_reader_test.go @@ -0,0 +1,291 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "cmp" + "testing" + "time" + + "tailscale.com/util/must" + "tailscale.com/util/syspolicy/setting" +) + +func TestReaderLifecycle(t *testing.T) { + tests := []struct { + name string + origin *setting.Origin + definitions []*setting.Definition + wantReads []TestExpectedReads + initStrings []TestSetting[string] + initUInt64s []TestSetting[uint64] + initWant *setting.Snapshot + addStrings []TestSetting[string] + addStringLists []TestSetting[[]string] + newWant *setting.Snapshot + }{ + { + name: "read-all-settings-once", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue), + setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "StringValue", Type: setting.StringValue, NumTimes: 1}, + {Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1}, + {Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1}, + {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, + { + name: "re-read-all-settings-when-the-policy-changes", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue), + setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "StringValue", Type: setting.StringValue, NumTimes: 1}, + {Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1}, + {Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1}, + {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")}, + addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})}, + newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, + { + name: "read-settings-if-in-scope/device", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue), + setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue), + }, + wantReads: []TestExpectedReads{ + {Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1}, + {Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1}, + }, + }, + { + name: "read-settings-if-in-scope/profile", + origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue), + setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue), + }, + wantReads: []TestExpectedReads{ + // Device settings cannot be configured at the profile scope and should not be read. + {Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1}, + {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1}, + }, + }, + { + name: "read-settings-if-in-scope/user", + origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue), + setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue), + setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue), + }, + wantReads: []TestExpectedReads{ + // Device and profile settings cannot be configured at the profile scope and should not be read. + {Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1}, + }, + }, + { + name: "read-stringy-settings", + origin: setting.NewNamedOrigin("Test", setting.DeviceScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initStrings: []TestSetting[string]{ + TestSettingOf("DurationValue", "2h30m"), + TestSettingOf("PreferenceOptionValue", "always"), + TestSettingOf("VisibilityValue", "show"), + }, + initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, setting.NewNamedOrigin("Test", setting.DeviceScope)), + }, + { + name: "read-erroneous-stringy-settings", + origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope), + definitions: []*setting.Definition{ + setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue), + setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue), + setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue), + setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue), + }, + wantReads: []TestExpectedReads{ + {Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective + {Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s + {Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility] + }, + initStrings: []TestSetting[string]{ + TestSettingOf("DurationValue1", "soon"), + TestSettingWithError[string]("DurationValue2", setting.NewErrorText("bang!")), + TestSettingOf("PreferenceOptionValue", "sometimes"), + }, + initUInt64s: []TestSetting[uint64]{ + TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch + }, + initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "DurationValue1": setting.RawItemWith(nil, setting.NewErrorText("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "DurationValue2": setting.RawItemWith(nil, setting.NewErrorText("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewErrorText("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + }, setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setting.SetDefinitionsForTest(t, tt.definitions...) + store := NewTestStore(t) + store.SetStrings(tt.initStrings...) + store.SetUInt64s(tt.initUInt64s...) + + reader, err := newReader(store, tt.origin) + if err != nil { + t.Fatalf("newReader failed: %v", err) + } + + if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) { + t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant) + } + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + + // Should not result in new reads as there were no changes. + N := 100 + for range N { + reader.GetSettings() + } + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + store.ResetCounters() + + got, err := reader.ReadSettings() + if err != nil { + t.Fatalf("ReadSettings failed: %v", err) + } + + if tt.initWant != nil && !got.Equal(tt.initWant) { + t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant) + } + + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + store.ResetCounters() + + if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 { + store.SetStrings(tt.addStrings...) + store.SetStringLists(tt.addStringLists...) + + // As the settings have changed, GetSettings needs to re-read them. + if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) { + t.Errorf("New Settings do not match: got %v, want %v", got, want) + } + if tt.wantReads != nil { + store.ReadsMustEqual(tt.wantReads...) + } + } + + select { + case <-reader.Done(): + t.Fatalf("the reader is closed") + default: + } + + store.Close() + + <-reader.Done() + }) + } +} + +func TestReadingSession(t *testing.T) { + setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue)) + store := NewTestStore(t) + + origin := setting.NewOrigin(setting.DeviceScope) + reader, err := newReader(store, origin) + if err != nil { + t.Fatalf("newReader failed: %v", err) + } + session, err := reader.OpenSession() + if err != nil { + t.Fatalf("failed to open a reading session: %v", err) + } + t.Cleanup(session.Close) + + if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) { + t.Errorf("Settings do not match: got %v, want %v", got, want) + } + + select { + case _, ok := <-session.PolicyChanged(): + if ok { + t.Fatalf("the policy changed notification was sent prematurely") + } else { + t.Fatalf("the session was closed prematurely") + } + default: + } + + store.SetStrings(TestSettingOf("StringValue", "S1")) + _, ok := <-session.PolicyChanged() + if !ok { + t.Fatalf("the session was closed prematurely") + } + + want := setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "StringValue": setting.RawItemWith("S1", nil, origin), + }, origin) + if got := session.GetSettings(); !got.Equal(want) { + t.Errorf("Settings do not match: got %v, want %v", got, want) + } + + store.Close() + if _, ok = <-session.PolicyChanged(); ok { + t.Fatalf("the session must be closed") + } +} diff --git a/util/syspolicy/source/policy_source.go b/util/syspolicy/source/policy_source.go new file mode 100644 index 000000000..7f2821b59 --- /dev/null +++ b/util/syspolicy/source/policy_source.go @@ -0,0 +1,146 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package source defines interfaces for policy stores, +// facilitates the creation of policy sources, and provides +// functionality for reading policy settings from these sources. +package source + +import ( + "cmp" + "errors" + "fmt" + "io" + + "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/setting" +) + +// ErrStoreClosed is an error returned when attempting to use a [Store] after it +// has been closed. +var ErrStoreClosed = errors.New("the policy store has been closed") + +// Store provides methods to read system policy settings from OS-specific storage. +// Implementations must be concurrency-safe, and may also implement +// [Lockable], [Changeable], [Expirable] and [io.Closer]. +// +// If a [Store] implementation also implements [io.Closer], +// it will be called by the package to release the resources +// when the store is no longer needed. +type Store interface { + // ReadString returns the value of a [setting.StringValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an error on failure. + ReadString(key setting.Key) (string, error) + // ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an error on failure. + ReadUInt64(key setting.Key) (uint64, error) + // ReadBoolean returns the value of a [setting.BooleanValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an error on failure. + ReadBoolean(key setting.Key) (bool, error) + // ReadStringArray returns the value of a [setting.StringListValue] with the specified key, + // an [setting.ErrNotConfigured] if the policy setting is not configured, or + // an error on failure. + ReadStringArray(key setting.Key) ([]string, error) +} + +// Lockable is an optional interface that [Store] implementations may support. +// Locking a [Store] is not mandatory as [Store] must be concurrency-safe, +// but is recommended to avoid issues where consecutive read calls for related +// policies might return inconsistent results if a policy change occurs between +// the calls. Implementations may use locking to pre-read policies or for +// similar performance optimizations. +type Lockable interface { + // Lock acquires a read lock on the policy store, + // ensuring the store's state remains unchanged while locked. + // Multiple readers can hold the lock simultaneously. + // It returns an error if the store cannot be locked. + Lock() error + // Unlock unlocks the policy store. + // It is a run-time error if the store is not locked on entry to Unlock. + Unlock() +} + +// Changeable is an optional interface that [Store] implementations may support +// if the policy settings they contain can be externally changed after being initially read. +type Changeable interface { + // RegisterChangeCallback adds a function that will be called + // whenever there's a policy change in the [Store]. + // The returned function can be used to unregister the callback. + RegisterChangeCallback(callback func()) (unregister func(), err error) +} + +// Expirable is an optional interface that [Store] implementations may support +// if they can be externally closed or otherwise become invalid while in use. +type Expirable interface { + // Done returns a channel that is closed when the policy [Store] should no longer be used. + // It should return nil if the store never expires. + Done() <-chan struct{} +} + +// Source represents a named source of policy settings for a given [setting.PolicyScope]. +type Source struct { + name string + scope setting.PolicyScope + store Store + origin *setting.Origin + + lazyReader lazy.SyncValue[*Reader] +} + +// NewSource returns a new [Source] with the specified name, scope, and store. +func NewSource(name string, scope setting.PolicyScope, store Store) *Source { + return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)} +} + +// Name reports the name of the policy source. +func (s *Source) Name() string { + return s.name +} + +// Scope reports the management scope of the policy source. +func (s *Source) Scope() setting.PolicyScope { + return s.scope +} + +// Reader returns a [Reader] that reads from this source's [Store]. +func (s *Source) Reader() (*Reader, error) { + return s.lazyReader.GetErr(func() (*Reader, error) { + return newReader(s.store, s.origin) + }) +} + +// Description returns a formatted string with the scope and name of this policy source. +// It can be used for logging or display purposes. +func (s *Source) Description() string { + if s.name != "" { + return fmt.Sprintf("%s (%v)", s.name, s.Scope()) + } + return s.Scope().String() +} + +// Compare returns an integer comparing s and s2 +// by their precedence, following the "last-wins" model. +// The result will be: +// +// -1 if policy settings from s should be processed before policy settings from s2; +// +1 if policy settings from s should be processed after policy settings from s2, overriding s2; +// 0 if the relative processing order of policy settings in s and s2 is unspecified. +func (s *Source) Compare(s2 *Source) int { + return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind()) +} + +// Close closes the [Source] and the underlying [Store]. +func (s *Source) Close() error { + // The [Reader], if any, owns the [Store]. + if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil { + return reader.Close() + } + // Otherwise, it is our responsibility to close it. + if closer, ok := s.store.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/util/syspolicy/source/policy_store_windows.go b/util/syspolicy/source/policy_store_windows.go new file mode 100644 index 000000000..f526b4ce1 --- /dev/null +++ b/util/syspolicy/source/policy_store_windows.go @@ -0,0 +1,450 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "strings" + "sync" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/winutil/gp" +) + +const ( + softwareKeyName = `Software` + tsPoliciesSubkey = `Policies\Tailscale` + tsIPNSubkey = `Tailscale IPN` // the legacy key we need to fallback to +) + +var ( + _ Store = (*PlatformPolicyStore)(nil) + _ Lockable = (*PlatformPolicyStore)(nil) + _ Changeable = (*PlatformPolicyStore)(nil) + _ Expirable = (*PlatformPolicyStore)(nil) +) + +// PlatformPolicyStore implements [Store] by providing read access to +// Registry-based Tailscale policies, such as those configured via Group Policy or MDM. +// For better performance and consistency, it is recommended to lock it when +// reading multiple policy settings sequentially. +// It also allows subscribing to policy change notifications. +type PlatformPolicyStore struct { + scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy] + + // The softwareKey can be HKLM\Software, HKCU\Software, or + // HKU\{SID}\Software. Anything below the Software subkey, including + // Software\Policies, may not yet exist or could be deleted throughout the + // [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer + // to always use a real registry key (rather than a predefined HKLM or HKCU) + // to simplify bookkeeping (predefined keys should never be closed). + // Finally, this will allow us to watch for any registry changes directly + // should we need this in the future in addition to gp.ChangeWatcher. + softwareKey registry.Key + watcher *gp.ChangeWatcher + + done chan struct{} // done is closed when Close call completes + + // The policyLock can be locked by the caller when reading multiple policy settings + // to prevent the Group Policy Client service from modifying policies while + // they are being read. + // + // When both policyLock and mu need to be taken, mu must be taken before policyLock. + policyLock *gp.PolicyLock + + mu sync.Mutex + tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked. + cbs set.HandleSet[func()] // policy change callbacks + lockCnt int + locked sync.WaitGroup + closing bool + closed bool +} + +type registryValueGetter[T any] func(key registry.Key, name string) (T, error) + +// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine. +func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) { + softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ) + if err != nil { + return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err) + } + return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, gp.NewMachinePolicyLock()), nil +} + +// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token. +// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY] +// and [windows.TOKEN_DUPLICATE] access. The caller retains ownership of the token. +func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) { + var err error + var softwareKey registry.Key + if token != 0 { + var user *windows.Tokenuser + if user, err = token.GetTokenUser(); err != nil { + return nil, fmt.Errorf("failed to get token user: %w", err) + } + userSid := user.User.Sid + softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ) + } else { + softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ) + } + if err != nil { + return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err) + } + policyLock, err := gp.NewUserPolicyLock(token) + if err != nil { + return nil, fmt.Errorf("failed to create a user policy lock: %w", err) + } + return newPlatformPolicyStore(gp.UserPolicy, softwareKey, policyLock), nil +} + +func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, policyLock *gp.PolicyLock) *PlatformPolicyStore { + return &PlatformPolicyStore{ + scope: scope, + softwareKey: softwareKey, + done: make(chan struct{}), + policyLock: policyLock, + } +} + +// Lock locks the policy store, preventing the system from modifying the policies +// while they are being read. It is a read lock that may be acquired by multiple goroutines. +// Each Lock call must be balanced by exactly one Unlock call. +func (ps *PlatformPolicyStore) Lock() (err error) { + ps.mu.Lock() + defer ps.mu.Unlock() + + if ps.closing { + return ErrStoreClosed + } + + ps.lockCnt += 1 + if ps.lockCnt != 1 { + return nil + } + defer func() { + if err != nil { + ps.lockCnt -= 1 + } + }() + + // Ensure ps remains open while the lock is held. + ps.locked.Add(1) + defer func() { + if err != nil { + ps.locked.Done() + } + }() + + // Acquire the GP lock to prevent the system from modifying policy settings + // while they are being read. + if err := ps.policyLock.Lock(); err != nil { + if errors.Is(err, gp.ErrInvalidLockState) { + // The policy store is being closed and we've lost the race. + return ErrStoreClosed + } + return err + } + defer func() { + if err != nil { + ps.policyLock.Unlock() + } + }() + + // Keep the Tailscale's registry keys open for the duration of the lock. + keyNames := tailscaleKeyNamesFor(ps.scope) + ps.tsKeys = make([]registry.Key, 0, len(keyNames)) + for _, keyName := range keyNames { + var tsKey registry.Key + tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ) + if err != nil { + if err == registry.ErrNotExist { + continue + } + return err + } + ps.tsKeys = append(ps.tsKeys, tsKey) + } + + return nil +} + +// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0. +// It panics if ps is not locked on entry to Unlock. +func (ps *PlatformPolicyStore) Unlock() { + ps.mu.Lock() + defer ps.mu.Unlock() + + ps.lockCnt -= 1 + if ps.lockCnt < 0 { + panic("negative lockCnt") + } else if ps.lockCnt != 0 { + return + } + + for _, key := range ps.tsKeys { + key.Close() + } + ps.tsKeys = nil + ps.policyLock.Unlock() + ps.locked.Done() +} + +// RegisterChangeCallback adds a function that will be called whenever there's a policy change. +// It returns a function that can be used to unregister the specified callback or an error. +// The error is [ErrStoreClosed] if ps has already been closed. +func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) { + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.closing { + return nil, ErrStoreClosed + } + + handle := ps.cbs.Add(cb) + if len(ps.cbs) == 1 { + if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil { + return nil, err + } + } + + return func() { + ps.mu.Lock() + defer ps.mu.Unlock() + delete(ps.cbs, handle) + if len(ps.cbs) == 0 { + if ps.watcher != nil { + ps.watcher.Close() + ps.watcher = nil + } + } + }, nil +} + +func (ps *PlatformPolicyStore) onChange() { + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.closing { + return + } + for _, callback := range ps.cbs { + go callback() + } +} + +// ReadString retrieves a string policy with the specified key. +// It returns [setting.ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadString(key setting.Key) (val string, err error) { + return getPolicyValue(ps, key, + func(key registry.Key, valueName string) (string, error) { + val, _, err := key.GetStringValue(valueName) + return val, err + }) +} + +// ReadUInt64 retrieves an integer policy with the specified key. +// It returns [setting.ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadUInt64(key setting.Key) (uint64, error) { + return getPolicyValue(ps, key, + func(key registry.Key, valueName string) (uint64, error) { + val, _, err := key.GetIntegerValue(valueName) + return val, err + }) +} + +// ReadBoolean retrieves a boolean policy with the specified key. +// It returns [setting.ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadBoolean(key setting.Key) (bool, error) { + return getPolicyValue(ps, key, + func(key registry.Key, valueName string) (bool, error) { + val, _, err := key.GetIntegerValue(valueName) + if err != nil { + return false, err + } + return val != 0, nil + }) +} + +// ReadString retrieves a multi-string policy with the specified key. +// It returns [setting.ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error) { + return getPolicyValue(ps, key, + func(key registry.Key, valueName string) ([]string, error) { + val, _, err := key.GetStringsValue(valueName) + if err != registry.ErrNotExist { + return val, err // the err may be nil or non-nil + } + + // The idiomatic way to store multiple string values in Group Policy + // and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ) + // values under a subkey rather than in a single REG_MULTI_SZ value. + // + // See the Group Policy: Registry Extension Encoding specification, + // and specifically the ListElement and ListBox types. + // https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf + valKey, err := registry.OpenKey(key, valueName, windows.KEY_READ) + if err != nil { + return nil, err + } + valNames, err := valKey.ReadValueNames(0) + if err != nil { + return nil, err + } + val = make([]string, 0, len(valNames)) + for _, name := range valNames { + switch item, _, err := valKey.GetStringValue(name); { + case err == registry.ErrNotExist: + continue + case err != nil: + return nil, err + default: + val = append(val, item) + } + } + return val, nil + }) +} + +// splitSettingKey extracts the registry key name and value name from a [setting.Key]. +// The [setting.Key] format allows grouping settings into nested categories using one +// or more [setting.KeyPathSeparator]s in the path. How individual policy settings are +// stored is an implementation detail of each [Store]. In the [PlatformPolicyStore] +// for Windows, we map nested policy categories onto the Registry key hierarchy. +// The last component after a [setting.KeyPathSeparator] is treated as the value name, +// while everything preceding it is considered a subpath (relative to the {HKLM,HKCU}\Software\Policies\Tailscale key). +// If there are no [setting.KeyPathSeparator]s in the key, the policy setting value +// is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale. +func splitSettingKey(key setting.Key) (path, valueName string) { + if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 { + path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`) + valueName = string(key[idx+len(setting.KeyPathSeparator):]) + return path, valueName + } + return "", string(key) +} + +func getPolicyValue[T any](ps *PlatformPolicyStore, key setting.Key, getter registryValueGetter[T]) (T, error) { + var zero T + + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.closed { + return zero, ErrStoreClosed + } + + path, valueName := splitSettingKey(key) + getValue := func(key registry.Key) (T, error) { + var err error + if path != "" { + key, err = registry.OpenKey(key, path, windows.KEY_READ) + if err != nil { + return zero, err + } + defer key.Close() + } + return getter(key, valueName) + } + + if ps.tsKeys != nil { + // A non-nil tsKeys indicates that ps has been locked. + // The slice may be empty if Tailscale policy keys do not exist. + for _, tsKey := range ps.tsKeys { + val, err := getValue(tsKey) + if err == nil || err != registry.ErrNotExist { + return val, err + } + } + return zero, setting.ErrNotConfigured + } + + // The ps has not been locked, so we don't have any pre-opened keys. + for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) { + var tsKey registry.Key + tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ) + if err != nil { + if err == registry.ErrNotExist { + continue + } + return zero, err + } + val, err := getValue(tsKey) + tsKey.Close() + if err == nil || err != registry.ErrNotExist { + return val, err + } + } + + return zero, setting.ErrNotConfigured +} + +// Close closes the policy store and releases any associated resources. +// It cancels pending locks and prevents any new lock attempts, +// but waits for existing locks to be released. +func (ps *PlatformPolicyStore) Close() error { + // Request to close the Group Policy read lock. + // Existing held locks will remain valid, but any new or pending locks + // will fail. In certain scenarios, the corresponding write lock may be held + // by the Group Policy service for extended periods (minutes rather than + // seconds or milliseconds). In such cases, we prefer not to wait that long + // if the ps is being closed anyway. + if ps.policyLock != nil { + ps.policyLock.Close() + } + + // Mark ps as closing to fast-fail any new lock attempts. + // Callers that have already locked it can finish their reading. + ps.mu.Lock() + if ps.closing { + ps.mu.Unlock() + return nil + } + ps.closing = true + if ps.watcher != nil { + ps.watcher.Close() + ps.watcher = nil + } + ps.mu.Unlock() + + // Signal to the external code that ps should no longer be used. + close(ps.done) + + // Wait for any outstanding locks to be released. + ps.locked.Wait() + + // Deny any further read attempts and release remaining resources. + ps.mu.Lock() + defer ps.mu.Unlock() + ps.cbs = nil + ps.policyLock = nil + ps.closed = true + if ps.softwareKey != 0 { + ps.softwareKey.Close() + ps.softwareKey = 0 + } + return nil +} + +// Done returns a channel that is closed when the Close method is called. +func (ps *PlatformPolicyStore) Done() <-chan struct{} { + return ps.done +} + +func tailscaleKeyNamesFor(scope gp.Scope) []string { + switch scope { + case gp.MachinePolicy: + // If a computer-side policy value does not exist under Software\Policies\Tailscale, + // we need to fallback and use the legacy Software\Tailscale IPN key. + return []string{tsPoliciesSubkey, tsIPNSubkey} + case gp.UserPolicy: + // However, we've never used the legacy key with user-side policies, + // and we should never do so. Unlike HKLM\Software\Tailscale IPN, + // its HKCU counterpart is user-writable. + return []string{tsPoliciesSubkey} + default: + panic("unreachable") + } +} diff --git a/util/syspolicy/source/policy_store_windows_test.go b/util/syspolicy/source/policy_store_windows_test.go new file mode 100644 index 000000000..33f85dc0b --- /dev/null +++ b/util/syspolicy/source/policy_store_windows_test.go @@ -0,0 +1,398 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "tailscale.com/tstest" + "tailscale.com/util/cibuild" + "tailscale.com/util/mak" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/winutil" + "tailscale.com/util/winutil/gp" +) + +// subkeyStrings is a test type indicating that a string slice should be written +// to the registry as multiple REG_SZ values under the setting's key, +// rather than as a single REG_MULTI_SZ value under the group key. +// This is the same format as ADMX use for string lists. +type subkeyStrings []string + +type testPolicyValue struct { + name setting.Key + value any +} + +func TestLockUnlockPolicyStore(t *testing.T) { + // Make sure we don't leak goroutines + tstest.ResourceCheck(t) + + store, err := NewMachinePlatformPolicyStore() + if err != nil { + t.Fatalf("NewMachinePolicyStore failed: %v", err) + } + + t.Run("One-Goroutine", func(t *testing.T) { + if err := store.Lock(); err != nil { + t.Errorf("store.Lock(): got %v; want nil", err) + return + } + if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) { + t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured) + } + store.Unlock() + }) + + // Lock the store N times from different goroutines. + const N = 100 + var unlocked atomic.Int32 + t.Run("N-Goroutines", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(N) + for range N { + go func() { + if err := store.Lock(); err != nil { + t.Errorf("store.Lock(): got %v; want nil", err) + return + } + if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) { + t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured) + } + wg.Done() + time.Sleep(10 * time.Millisecond) + unlocked.Add(1) + store.Unlock() + }() + } + + // Wait until the store is locked N times. + wg.Wait() + }) + + // Close the store. The call should wait for all held locks to be released. + if err := store.Close(); err != nil { + t.Fatalf("(*PolicyStore).Close failed: %v", err) + } + if locked := unlocked.Load(); locked != N { + t.Errorf("locked.Load(): got %v; want %v", locked, N) + } + + // Any further attempts to lock it should fail. + if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) { + t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed) + } +} + +func TestReadPolicyStore(t *testing.T) { + if !winutil.IsCurrentProcessElevated() { + t.Skipf("test requires running as elevated user") + } + tests := []struct { + name setting.Key + newValue any + legacyValue any + want any + }{ + {name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"}, + {name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"}, + {name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""}, + {name: "BoolPolicy_True", newValue: true, want: true}, + {name: "BoolPolicy_False", newValue: false, want: false}, + {name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64 + {name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)}, + {name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}}, + {name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}}, + {name: "StringListPolicy_SubKey", newValue: subkeyStrings{"Value1", "Value2"}, want: []string{"Value1", "Value2"}}, + {name: "StringListPolicy_SubKey_Empty", newValue: subkeyStrings{}, want: []string{}}, + } + + runTests := func(t *testing.T, userStore bool, token windows.Token) { + var hive registry.Key + if userStore { + hive = registry.CURRENT_USER + } else { + hive = registry.LOCAL_MACHINE + } + + // Write policy values to the registry. + newValues := make([]testPolicyValue, 0, len(tests)) + for _, tt := range tests { + if tt.newValue != nil { + newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue}) + } + } + policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey + cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues) + if err != nil { + t.Fatalf("createTestPolicyValues failed: %v", err) + } + t.Cleanup(cleanup) + + // Write legacy policy values to the registry. + legacyValues := make([]testPolicyValue, 0, len(tests)) + for _, tt := range tests { + if tt.legacyValue != nil { + legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue}) + } + } + legacyKeyName := softwareKeyName + `\` + tsIPNSubkey + cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues) + if err != nil { + t.Fatalf("createTestPolicyValues failed: %v", err) + } + t.Cleanup(cleanup) + + var store *PlatformPolicyStore + if userStore { + store, err = NewUserPlatformPolicyStore(token) + } else { + store, err = NewMachinePlatformPolicyStore() + } + if err != nil { + t.Fatalf("NewXPolicyStore failed: %v", err) + } + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Errorf("(*PolicyStore).Close failed: %v", err) + } + }) + + // testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry. + testReadValues := func(t *testing.T, withLocks bool) { + for _, tt := range tests { + t.Run(string(tt.name), func(t *testing.T) { + if userStore && tt.newValue == nil { + t.Skip("there is no legacy policies for users") + } + + t.Parallel() + + if withLocks { + if err := store.Lock(); err != nil { + t.Errorf("failed to acquire the lock: %v", err) + } + defer store.Unlock() + } + + var got any + var err error + switch tt.want.(type) { + case string: + got, err = store.ReadString(tt.name) + case uint64: + got, err = store.ReadUInt64(tt.name) + case bool: + got, err = store.ReadBoolean(tt.name) + case []string: + got, err = store.ReadStringArray(tt.name) + } + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } + } + t.Run("NoLock", func(t *testing.T) { + testReadValues(t, false) + }) + + t.Run("WithLock", func(t *testing.T) { + testReadValues(t, true) + }) + } + + t.Run("MachineStore", func(t *testing.T) { + runTests(t, false, 0) + }) + + t.Run("CurrentUserStore", func(t *testing.T) { + runTests(t, true, 0) + }) + + t.Run("UserStoreWithToken", func(t *testing.T) { + var token windows.Token + if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil { + t.Fatalf("OpenProcessToken: %v", err) + } + defer token.Close() + runTests(t, true, token) + }) +} + +func TestPolicyStoreChangeNotifications(t *testing.T) { + if cibuild.On() { + t.Skipf("test requires running on a real Windows environment") + } + store, err := NewMachinePlatformPolicyStore() + if err != nil { + t.Fatalf("NewMachinePolicyStore failed: %v", err) + } + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Errorf("(*PolicyStore).Close failed: %v", err) + } + }) + + done := make(chan struct{}) + unregister, err := store.RegisterChangeCallback(func() { close(done) }) + if err != nil { + t.Fatalf("RegisterChangeCallback failed: %v", err) + } + t.Cleanup(unregister) + + // RefreshMachinePolicy is a non-blocking call. + if err := gp.RefreshMachinePolicy(true); err != nil { + t.Fatalf("RefreshMachinePolicy failed: %v", err) + } + + // We should receive a policy change notification when + // the Group Policy service completes policy processing. + // Otherwise, the test will eventually time out. + <-done +} + +func TestSplitSettingKey(t *testing.T) { + tests := []struct { + name string + key setting.Key + wantPath string + wantValue string + }{ + { + name: "empty", + key: "", + wantPath: ``, + wantValue: "", + }, + { + name: "explicit-empty-path", + key: "/ValueName", + wantPath: ``, + wantValue: "ValueName", + }, + { + name: "empty-value", + key: "Root/Sub/", + wantPath: `Root\Sub`, + wantValue: "", + }, + { + name: "with-path", + key: "Root/Sub/ValueName", + wantPath: `Root\Sub`, + wantValue: "ValueName", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPath, gotValue := splitSettingKey(tt.key) + if gotPath != tt.wantPath { + t.Errorf("Path: got %q, want %q", gotPath, tt.wantPath) + } + if gotValue != tt.wantValue { + t.Errorf("Value: got %q, want %q", gotValue, tt.wantPath) + } + }) + } +} + +func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) { + key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS) + if err != nil { + return nil, err + } + var valuesToDelete map[string][]string + doCleanup := func() { + for path, values := range valuesToDelete { + if len(values) == 0 { + registry.DeleteKey(key, path) + continue + } + key, err := registry.OpenKey(key, path, windows.KEY_ALL_ACCESS) + if err != nil { + continue + } + defer key.Close() + for _, value := range values { + key.DeleteValue(value) + } + } + + key.Close() + if !existing { + registry.DeleteKey(hive, keyName) + } + } + defer func() { + if err != nil { + doCleanup() + } + }() + + for _, v := range values { + key, existing := key, existing + path, valueName := splitSettingKey(v.name) + if path != "" { + if key, existing, err = registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS); err != nil { + return nil, err + } + defer key.Close() + } + if values, ok := valuesToDelete[path]; len(values) > 0 || (!ok && existing) { + values = append(values, valueName) + mak.Set(&valuesToDelete, path, values) + } else if !ok { + mak.Set(&valuesToDelete, path, nil) + } + + switch value := v.value.(type) { + case string: + err = key.SetStringValue(valueName, value) + case uint32: + err = key.SetDWordValue(valueName, value) + case uint64: + err = key.SetQWordValue(valueName, value) + case bool: + if value { + err = key.SetDWordValue(valueName, 1) + } else { + err = key.SetDWordValue(valueName, 0) + } + case []string: + err = key.SetStringsValue(valueName, value) + case subkeyStrings: + key, _, err := registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS) + if err != nil { + return nil, err + } + defer key.Close() + mak.Set(&valuesToDelete, strings.Trim(path+`\`+valueName, `\`), nil) + for i, value := range value { + if err := key.SetStringValue(strconv.Itoa(i), value); err != nil { + return nil, err + } + } + default: + err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name) + } + if err != nil { + return nil, err + } + } + return doCleanup, nil +} diff --git a/util/syspolicy/source/test_store.go b/util/syspolicy/source/test_store.go new file mode 100644 index 000000000..bb8e164fb --- /dev/null +++ b/util/syspolicy/source/test_store.go @@ -0,0 +1,451 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "fmt" + "sync" + "sync/atomic" + + xmaps "golang.org/x/exp/maps" + "tailscale.com/util/mak" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" +) + +var ( + _ Store = (*TestStore)(nil) + _ Lockable = (*TestStore)(nil) + _ Changeable = (*TestStore)(nil) + _ Expirable = (*TestStore)(nil) +) + +// TestValueType is a constraint that allows types supported by [TestStore]. +type TestValueType interface { + bool | uint64 | string | []string +} + +// TestSetting is a policy setting in a [TestStore]. +type TestSetting[T TestValueType] struct { + // Key is the setting's unique identifier. + Key setting.Key + // Error is the error to be returned by the [TestStore] when reading + // a policy setting with the specified key. + Error error + // Value is the value to be returned by the [TestStore] when reading + // a policy setting with the specified key. + // It is only used if the Error is nil. + Value T +} + +// TestSettingOf returns a [TestSetting] representing a policy setting +// configured with the specified key and value. +func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] { + return TestSetting[T]{Key: key, Value: value} +} + +// TestSettingWithError returns a [TestSetting] representing a policy setting +// with the specified key and error. +func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] { + return TestSetting[T]{Key: key, Error: err} +} + +// testReadOperation describes a single policy setting read operation. +type testReadOperation struct { + // Key is the setting's unique identifier. + Key setting.Key + // Type is a value type of a read operation. + // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue] + Type setting.Type +} + +// TestExpectedReads is the number of read operations with the specified details. +type TestExpectedReads struct { + // Key is the setting's unique identifier. + Key setting.Key + // Type is a value type of a read operation. + // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue] + Type setting.Type + // NumTimes is how many times a setting with the specified key and type should have been read. + NumTimes int +} + +func (r TestExpectedReads) operation() testReadOperation { + return testReadOperation{r.Key, r.Type} +} + +// TestStore is a [Store] that can be used in tests. +type TestStore struct { + tb internal.TB + + done chan struct{} + + storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock]. + storeLockCount atomic.Int32 + + mu sync.RWMutex + suspendCount int // change callback are suspended if > 0 + mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended. + cbs set.HandleSet[func()] + + readsMu sync.Mutex + reads map[testReadOperation]int // how many times a policy setting was read +} + +// NewTestStore returns a new [TestStore]. +// The tb will be used to report coding errors detected by the [TestStore]. +func NewTestStore(tb internal.TB) *TestStore { + m := make(map[setting.Key]any) + return &TestStore{ + tb: tb, + done: make(chan struct{}), + mr: m, + mw: m, + } +} + +// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans], +// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists]. +func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore { + m := make(map[setting.Key]any) + store := &TestStore{ + tb: tb, + done: make(chan struct{}), + mr: m, + mw: m, + } + switch settings := any(settings).(type) { + case []TestSetting[bool]: + store.SetBooleans(settings...) + case []TestSetting[uint64]: + store.SetUInt64s(settings...) + case []TestSetting[string]: + store.SetStrings(settings...) + case []TestSetting[[]string]: + store.SetStringLists(settings...) + } + return store +} + +// Lock implements [Lockable]. +func (s *TestStore) Lock() error { + s.storeLock.RLock() + s.storeLockCount.Add(1) + return nil +} + +// Unlock implements [Lockable]. +func (s *TestStore) Unlock() { + if s.storeLockCount.Add(-1) < 0 { + s.tb.Fatal("negative storeLockCount") + } + s.storeLock.RUnlock() +} + +// RegisterChangeCallback implements [Changeable]. +func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) { + s.mu.Lock() + defer s.mu.Unlock() + handle := s.cbs.Add(callback) + return func() { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.cbs, handle) + }, nil +} + +// ReadString implements [Store]. +func (s *TestStore) ReadString(key setting.Key) (string, error) { + defer s.recordRead(key, setting.StringValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return "", setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return "", err + } + str, ok := v.(string) + if !ok { + return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v) + } + return str, nil +} + +// ReadUInt64 implements [Store]. +func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) { + defer s.recordRead(key, setting.IntegerValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return 0, setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return 0, err + } + u64, ok := v.(uint64) + if !ok { + return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v) + } + return u64, nil +} + +// ReadBoolean implements [Store]. +func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) { + defer s.recordRead(key, setting.BooleanValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return false, setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return false, err + } + b, ok := v.(bool) + if !ok { + return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v) + } + return b, nil +} + +// ReadStringArray implements [Store]. +func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) { + defer s.recordRead(key, setting.StringListValue) + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.mr[key] + if !ok { + return nil, setting.ErrNotConfigured + } + if err, ok := v.(error); ok { + return nil, err + } + slice, ok := v.([]string) + if !ok { + return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v) + } + return slice, nil +} + +func (s *TestStore) recordRead(key setting.Key, typ setting.Type) { + s.readsMu.Lock() + op := testReadOperation{key, typ} + num := s.reads[op] + num++ + mak.Set(&s.reads, op, num) + s.readsMu.Unlock() +} + +func (s *TestStore) ResetCounters() { + s.readsMu.Lock() + clear(s.reads) + s.readsMu.Unlock() +} + +// ReadsMustEqual fails the test if the actual reads differs from the specified reads. +func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) { + s.tb.Helper() + s.readsMu.Lock() + defer s.readsMu.Unlock() + s.readsMustContainLocked(reads...) + s.readMustNoExtraLocked(reads...) +} + +// ReadsMustContain fails the test if the specified reads have not been made, +// or have been made a different number of times. It permits other values to be +// read in addition to the ones being tested. +func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) { + s.tb.Helper() + s.readsMu.Lock() + defer s.readsMu.Unlock() + s.readsMustContainLocked(reads...) +} + +func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) { + s.tb.Helper() + for _, r := range reads { + if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes { + s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes) + } + } +} + +func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) { + s.tb.Helper() + rs := make(set.Set[testReadOperation]) + for i := range reads { + rs.Add(reads[i].operation()) + } + for ro, num := range s.reads { + if !rs.Contains(ro) { + s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num) + } + } +} + +// Suspend suspends the store, batching changes and notifications +// until [TestStore.Resume] is called the same number of times as Suspend. +func (s *TestStore) Suspend() { + s.mu.Lock() + defer s.mu.Unlock() + if s.suspendCount++; s.suspendCount == 1 { + s.mw = xmaps.Clone(s.mr) + } +} + +// Resume resumes the store, applying the changes and invoking +// the change callbacks. +func (s *TestStore) Resume() { + s.storeLock.Lock() + s.mu.Lock() + switch s.suspendCount--; { + case s.suspendCount == 0: + s.mr = s.mw + s.mu.Unlock() + s.storeLock.Unlock() + s.notifyPolicyChanged() + case s.suspendCount < 0: + s.tb.Fatal("negative suspendCount") + default: + s.mu.Unlock() + s.storeLock.Unlock() + } +} + +// SetBooleans sets the specified boolean settings in s. +func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// SetUInt64s sets the specified integer settings in s. +func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// SetStrings sets the specified string settings in s. +func (s *TestStore) SetStrings(settings ...TestSetting[string]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// SetStrings sets the specified string list settings in s. +func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) { + s.storeLock.Lock() + for _, setting := range settings { + if setting.Key == "" { + s.tb.Fatal("empty keys disallowed") + } + s.mu.Lock() + if setting.Error != nil { + mak.Set(&s.mw, setting.Key, any(setting.Error)) + } else { + mak.Set(&s.mw, setting.Key, any(setting.Value)) + } + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// Delete deletes the specified settings from s. +func (s *TestStore) Delete(keys ...setting.Key) { + s.storeLock.Lock() + for _, key := range keys { + s.mu.Lock() + delete(s.mw, key) + s.mu.Unlock() + } + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +// Clear deletes all settings from s. +func (s *TestStore) Clear() { + s.storeLock.Lock() + s.mu.Lock() + clear(s.mw) + s.mu.Unlock() + s.storeLock.Unlock() + s.notifyPolicyChanged() +} + +func (s *TestStore) notifyPolicyChanged() { + s.mu.RLock() + if s.suspendCount != 0 { + s.mu.RUnlock() + return + } + cbs := xmaps.Values(s.cbs) + s.mu.RUnlock() + + var wg sync.WaitGroup + wg.Add(len(cbs)) + for _, cb := range cbs { + go func() { + defer wg.Done() + cb() + }() + } + wg.Wait() +} + +// Close closes s, notifying its users that it has expired. +func (s *TestStore) Close() { + s.mu.Lock() + defer s.mu.Unlock() + if s.done != nil { + close(s.done) + s.done = nil + } +} + +// Done implements [Expirable]. +func (s *TestStore) Done() <-chan struct{} { + return s.done +}