mirror of https://github.com/tailscale/tailscale/
util/syspolicy/source: add package for reading policy settings from external stores
We add package defining interfaces for policy stores, enabling creation of policy sources and reading settings from them. It includes a Windows-specific PlatformPolicyStore for GP and MDM policies stored in the Registry, and an in-memory TestStore for testing purposes. We also include an internal package that tracks and reports policy usage metrics when a policy setting is read from a store. Initially, it will be used only on Windows and Android, as macOS, iOS, and tvOS report their own metrics. However, we plan to use it across all platforms eventually. Updates #12687 Signed-off-by: Nick Khyl <nickk@tailscale.com>pull/13356/head
parent
e865a0e2b0
commit
aeb15dea30
@ -0,0 +1,46 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package loggerx provides logging functions to the rest of the syspolicy packages.
|
||||
package loggerx
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
const (
|
||||
errorPrefix = "syspolicy: "
|
||||
verbosePrefix = "syspolicy: [v2] "
|
||||
)
|
||||
|
||||
var (
|
||||
lazyErrorf lazy.SyncValue[logger.Logf]
|
||||
lazyVerbosef lazy.SyncValue[logger.Logf]
|
||||
)
|
||||
|
||||
// Errorf formats and writes an error message to the log.
|
||||
func Errorf(format string, args ...any) {
|
||||
errorf := lazyErrorf.Get(func() logger.Logf {
|
||||
return logger.WithPrefix(log.Printf, errorPrefix)
|
||||
})
|
||||
errorf(format, args...)
|
||||
}
|
||||
|
||||
// Verbosef formats and writes an optional, verbose message to the log.
|
||||
func Verbosef(format string, args ...any) {
|
||||
verbosef := lazyVerbosef.Get(func() logger.Logf {
|
||||
return logger.WithPrefix(log.Printf, verbosePrefix)
|
||||
})
|
||||
verbosef(format, args...)
|
||||
}
|
||||
|
||||
// SetForTest sets the specified errorf and verbosef functions for the duration
|
||||
// of tb and its subtests.
|
||||
func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) {
|
||||
lazyErrorf.SetForTest(tb, errorf, nil)
|
||||
lazyVerbosef.SetForTest(tb, verbosef, nil)
|
||||
}
|
@ -0,0 +1,320 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package metrics provides logging and reporting for policy settings and scopes.
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/slicesx"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
var lazyReportMetrics lazy.SyncValue[bool] // used as a test hook
|
||||
|
||||
// ShouldReport reports whether metrics should be reported on the current environment.
|
||||
func ShouldReport() bool {
|
||||
return lazyReportMetrics.Get(func() bool {
|
||||
// macOS, iOS and tvOS create their own metrics,
|
||||
// and we don't have syspolicy on any other platforms.
|
||||
return setting.PlatformList{"android", "windows"}.HasCurrent()
|
||||
})
|
||||
}
|
||||
|
||||
// Reset metrics for the specified policy origin.
|
||||
func Reset(origin *setting.Origin) {
|
||||
scopeMetrics(origin).Reset()
|
||||
}
|
||||
|
||||
// ReportConfigured updates metrics and logs that the specified setting is
|
||||
// configured with the given value in the origin.
|
||||
func ReportConfigured(origin *setting.Origin, setting *setting.Definition, value any) {
|
||||
settingMetricsFor(setting).ReportValue(origin, value)
|
||||
}
|
||||
|
||||
// ReportError updates metrics and logs that the specified setting has an error
|
||||
// in the origin.
|
||||
func ReportError(origin *setting.Origin, setting *setting.Definition, err error) {
|
||||
settingMetricsFor(setting).ReportError(origin, err)
|
||||
}
|
||||
|
||||
// ReportNotConfigured updates metrics and logs that the specified setting is
|
||||
// not configured in the origin.
|
||||
func ReportNotConfigured(origin *setting.Origin, setting *setting.Definition) {
|
||||
settingMetricsFor(setting).Reset(origin)
|
||||
}
|
||||
|
||||
// metric is an interface implemented by [clientmetric.Metric] and [funcMetric].
|
||||
type metric interface {
|
||||
Add(v int64)
|
||||
Set(v int64)
|
||||
}
|
||||
|
||||
// policyScopeMetrics are metrics that apply to an entire policy scope rather
|
||||
// than a specific policy setting.
|
||||
type policyScopeMetrics struct {
|
||||
hasAny metric
|
||||
numErrored metric
|
||||
}
|
||||
|
||||
func newScopeMetrics(scope setting.Scope) *policyScopeMetrics {
|
||||
prefix := metricScopeName(scope)
|
||||
// {os}_syspolicy_{scope_unless_device}_any
|
||||
// Example: windows_syspolicy_any or windows_syspolicy_user_any.
|
||||
hasAny := newMetric([]string{prefix, "any"}, clientmetric.TypeGauge)
|
||||
// {os}_syspolicy_{scope_unless_device}_errors
|
||||
// Example: windows_syspolicy_errors or windows_syspolicy_user_errors.
|
||||
//
|
||||
// TODO(nickkhyl): maybe make the `{os}_syspolicy_errors` metric a gauge rather than a counter?
|
||||
// It was a counter prior to https://github.com/tailscale/tailscale/issues/12687, so I kept it as such.
|
||||
// But I think a gauge makes more sense: syspolicy errors indicate a mismatch between the expected
|
||||
// policy value type or format and the actual value read from the underlying store (like the Windows Registry).
|
||||
// We'll encounter the same error every time we re-read the policy setting from the backing store
|
||||
// until the policy value is corrected by the user, or until we fix the bug in the code or ADMX.
|
||||
// There's probably no reason to count and accumulate them over time.
|
||||
//
|
||||
// Brief discussion: https://github.com/tailscale/tailscale/pull/13113#discussion_r1723475136
|
||||
numErrored := newMetric([]string{prefix, "errors"}, clientmetric.TypeCounter)
|
||||
return &policyScopeMetrics{hasAny, numErrored}
|
||||
}
|
||||
|
||||
// ReportHasSettings is called when there's any configured policy setting in the scope.
|
||||
func (m *policyScopeMetrics) ReportHasSettings() {
|
||||
if m != nil {
|
||||
m.hasAny.Set(1)
|
||||
}
|
||||
}
|
||||
|
||||
// ReportError is called when there's any errored policy setting in the scope.
|
||||
func (m *policyScopeMetrics) ReportError() {
|
||||
if m != nil {
|
||||
m.numErrored.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset is called to reset the policy scope metrics, such as when the policy scope
|
||||
// is about to be reloaded.
|
||||
func (m *policyScopeMetrics) Reset() {
|
||||
if m != nil {
|
||||
m.hasAny.Set(0)
|
||||
// numErrored is a counter and cannot be (re-)set.
|
||||
}
|
||||
}
|
||||
|
||||
// settingMetrics are metrics for a single policy setting in one or more scopes.
|
||||
type settingMetrics struct {
|
||||
definition *setting.Definition
|
||||
isSet []metric // by scope
|
||||
hasErrors []metric // by scope
|
||||
}
|
||||
|
||||
// ReportValue is called when the policy setting is found to be configured in the specified source.
|
||||
func (m *settingMetrics) ReportValue(origin *setting.Origin, v any) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); scope >= 0 && int(scope) < len(m.isSet) {
|
||||
m.isSet[scope].Set(1)
|
||||
m.hasErrors[scope].Set(0)
|
||||
}
|
||||
scopeMetrics(origin).ReportHasSettings()
|
||||
loggerx.Verbosef("%v(%q) = %v", origin, m.definition.Key(), v)
|
||||
}
|
||||
|
||||
// ReportError is called when there's an error with the policy setting in the specified source.
|
||||
func (m *settingMetrics) ReportError(origin *setting.Origin, err error) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); int(scope) < len(m.hasErrors) {
|
||||
m.isSet[scope].Set(0)
|
||||
m.hasErrors[scope].Set(1)
|
||||
}
|
||||
scopeMetrics(origin).ReportError()
|
||||
loggerx.Errorf("%v(%q): %v", origin, m.definition.Key(), err)
|
||||
}
|
||||
|
||||
// Reset is called to reset the policy setting's metrics, such as when
|
||||
// the policy setting does not exist or the source containing the policy
|
||||
// is about to be reloaded.
|
||||
func (m *settingMetrics) Reset(origin *setting.Origin) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); scope >= 0 && int(scope) < len(m.isSet) {
|
||||
m.isSet[scope].Set(0)
|
||||
m.hasErrors[scope].Set(0)
|
||||
}
|
||||
}
|
||||
|
||||
// metricFn is a function that adds or sets a metric value.
|
||||
type metricFn func(name string, typ clientmetric.Type, v int64)
|
||||
|
||||
// funcMetric implements [metric] by calling the specified add and set functions.
|
||||
// Used for testing, and with nil functions on platforms that do not support
|
||||
// syspolicy, and on platforms that report policy metrics from the GUI.
|
||||
type funcMetric struct {
|
||||
name string
|
||||
typ clientmetric.Type
|
||||
add, set metricFn
|
||||
}
|
||||
|
||||
func (m funcMetric) Add(v int64) {
|
||||
if m.add != nil {
|
||||
m.add(m.name, m.typ, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (m funcMetric) Set(v int64) {
|
||||
if m.set != nil {
|
||||
m.set(m.name, m.typ, v)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
lazyDeviceMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
lazyProfileMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
lazyUserMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
)
|
||||
|
||||
func scopeMetrics(origin *setting.Origin) *policyScopeMetrics {
|
||||
switch origin.Scope().Kind() {
|
||||
case setting.DeviceSetting:
|
||||
return lazyDeviceMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.DeviceSetting)
|
||||
})
|
||||
case setting.ProfileSetting:
|
||||
return lazyProfileMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.ProfileSetting)
|
||||
})
|
||||
case setting.UserSetting:
|
||||
return lazyUserMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.UserSetting)
|
||||
})
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
settingMetricsMu sync.RWMutex
|
||||
settingMetricsMap map[setting.Key]*settingMetrics
|
||||
)
|
||||
|
||||
func settingMetricsFor(setting *setting.Definition) *settingMetrics {
|
||||
settingMetricsMu.RLock()
|
||||
metrics, ok := settingMetricsMap[setting.Key()]
|
||||
settingMetricsMu.RUnlock()
|
||||
if ok {
|
||||
return metrics
|
||||
}
|
||||
return settingMetricsForSlow(setting)
|
||||
}
|
||||
|
||||
func settingMetricsForSlow(d *setting.Definition) *settingMetrics {
|
||||
settingMetricsMu.Lock()
|
||||
defer settingMetricsMu.Unlock()
|
||||
if metrics, ok := settingMetricsMap[d.Key()]; ok {
|
||||
return metrics
|
||||
}
|
||||
|
||||
// The loop below initializes metrics for each scope where a policy setting defined in 'd'
|
||||
// can be configured. The [setting.Definition.Scope] returns the narrowest scope at which the policy
|
||||
// setting may be configured, and more specific scopes always have higher numeric values.
|
||||
// In other words, [setting.UserSetting] > [setting.ProfileScope] > [setting.DeviceScope].
|
||||
// It's impossible for a policy setting to be configured in a scope with a higher numeric value than
|
||||
// the [setting.Definition.Scope] returns. Therefore, a policy setting can be configured in at
|
||||
// most d.Scope()+1 different scopes, and having d.Scope()+1 metrics for the corresponding scopes
|
||||
// is always sufficient for [settingMetrics]; it won't access elements past the end of the slice
|
||||
// or need to reallocate with a longer slice if one of those arrives.
|
||||
isSet := make([]metric, d.Scope()+1)
|
||||
hasErrors := make([]metric, d.Scope()+1)
|
||||
for i := range isSet {
|
||||
scope := setting.Scope(i)
|
||||
// {os}_syspolicy_{key}_{scope_unless_device}
|
||||
// Example: windows_syspolicy_AdminConsole or windows_syspolicy_AdminConsole_user.
|
||||
isSet[i] = newSettingMetric(d.Key(), scope, "", clientmetric.TypeGauge)
|
||||
// {os}_syspolicy_{key}_{scope_unless_device}_error
|
||||
// Example: windows_syspolicy_AdminConsole_error or windows_syspolicy_TestSetting01_user_error.
|
||||
hasErrors[i] = newSettingMetric(d.Key(), scope, "error", clientmetric.TypeGauge)
|
||||
}
|
||||
metrics := &settingMetrics{d, isSet, hasErrors}
|
||||
mak.Set(&settingMetricsMap, d.Key(), metrics)
|
||||
return metrics
|
||||
}
|
||||
|
||||
// hooks for testing
|
||||
var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn]
|
||||
|
||||
// SetHooksForTest sets the specified addMetric and setMetric functions
|
||||
// as the metric functions for the duration of tb and all its subtests.
|
||||
func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) {
|
||||
oldAddMetric := addMetricTestHook.Swap(addMetric)
|
||||
oldSetMetric := setMetricTestHook.Swap(setMetric)
|
||||
tb.Cleanup(func() {
|
||||
addMetricTestHook.Store(oldAddMetric)
|
||||
setMetricTestHook.Store(oldSetMetric)
|
||||
})
|
||||
|
||||
settingMetricsMu.Lock()
|
||||
oldSettingMetricsMap := xmaps.Clone(settingMetricsMap)
|
||||
clear(settingMetricsMap)
|
||||
settingMetricsMu.Unlock()
|
||||
tb.Cleanup(func() {
|
||||
settingMetricsMu.Lock()
|
||||
settingMetricsMap = oldSettingMetricsMap
|
||||
settingMetricsMu.Unlock()
|
||||
})
|
||||
|
||||
// (re-)set the scope metrics to use the test hooks for the duration of tb.
|
||||
lazyDeviceMetrics.SetForTest(tb, newScopeMetrics(setting.DeviceSetting), nil)
|
||||
lazyProfileMetrics.SetForTest(tb, newScopeMetrics(setting.ProfileSetting), nil)
|
||||
lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil)
|
||||
}
|
||||
|
||||
func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric {
|
||||
name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_")
|
||||
return newMetric([]string{name, metricScopeName(scope), suffix}, typ)
|
||||
}
|
||||
|
||||
func newMetric(nameParts []string, typ clientmetric.Type) metric {
|
||||
name := strings.Join(slicesx.Filter([]string{internal.OS(), "syspolicy"}, nameParts, isNonEmpty), "_")
|
||||
switch {
|
||||
case !ShouldReport():
|
||||
return &funcMetric{name: name, typ: typ}
|
||||
case testenv.InTest():
|
||||
return &funcMetric{name, typ, addMetricTestHook.Load(), setMetricTestHook.Load()}
|
||||
case typ == clientmetric.TypeCounter:
|
||||
return clientmetric.NewCounter(name)
|
||||
case typ == clientmetric.TypeGauge:
|
||||
return clientmetric.NewGauge(name)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func isNonEmpty(s string) bool { return s != "" }
|
||||
|
||||
func metricScopeName(scope setting.Scope) string {
|
||||
switch scope {
|
||||
case setting.DeviceSetting:
|
||||
return ""
|
||||
case setting.ProfileSetting:
|
||||
return "profile"
|
||||
case setting.UserSetting:
|
||||
return "user"
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
@ -0,0 +1,423 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestSettingMetricNames(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key setting.Key
|
||||
scope setting.Scope
|
||||
suffix string
|
||||
typ clientmetric.Type
|
||||
osOverride string
|
||||
wantMetricName string
|
||||
}{
|
||||
{
|
||||
name: "windows-device-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole",
|
||||
},
|
||||
{
|
||||
name: "windows-user-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.UserSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_user",
|
||||
},
|
||||
{
|
||||
name: "windows-profile-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.ProfileSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_profile",
|
||||
},
|
||||
{
|
||||
name: "windows-profile-err",
|
||||
key: "AdminConsole",
|
||||
scope: setting.ProfileSetting,
|
||||
suffix: "error",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_profile_error",
|
||||
},
|
||||
{
|
||||
name: "android-device-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "android",
|
||||
wantMetricName: "android_syspolicy_AdminConsole",
|
||||
},
|
||||
{
|
||||
name: "key-path",
|
||||
key: "category/subcategory/setting",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "fakeos",
|
||||
wantMetricName: "fakeos_syspolicy_category_subcategory_setting",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
metric, ok := newSettingMetric(tt.key, tt.scope, tt.suffix, tt.typ).(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("metric is not a funcMetric")
|
||||
}
|
||||
if metric.name != tt.wantMetricName {
|
||||
t.Errorf("got %q, want %q", metric.name, tt.wantMetricName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopeMetrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope setting.Scope
|
||||
osOverride string
|
||||
wantHasAnyName string
|
||||
wantNumErroredName string
|
||||
wantHasAnyType clientmetric.Type
|
||||
wantNumErroredType clientmetric.Type
|
||||
}{
|
||||
{
|
||||
name: "windows-device",
|
||||
scope: setting.DeviceSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "windows-profile",
|
||||
scope: setting.ProfileSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_profile_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_profile_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "windows-user",
|
||||
scope: setting.UserSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_user_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_user_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "android-device",
|
||||
scope: setting.DeviceSetting,
|
||||
osOverride: "android",
|
||||
wantHasAnyName: "android_syspolicy_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "android_syspolicy_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
metrics := newScopeMetrics(tt.scope)
|
||||
hasAny, ok := metrics.hasAny.(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("hasAny is not a funcMetric")
|
||||
}
|
||||
numErrored, ok := metrics.numErrored.(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("numErrored is not a funcMetric")
|
||||
}
|
||||
if hasAny.name != tt.wantHasAnyName {
|
||||
t.Errorf("hasAny.Name: got %q, want %q", hasAny.name, tt.wantHasAnyName)
|
||||
}
|
||||
if hasAny.typ != tt.wantHasAnyType {
|
||||
t.Errorf("hasAny.Type: got %q, want %q", hasAny.typ, tt.wantHasAnyType)
|
||||
}
|
||||
if numErrored.name != tt.wantNumErroredName {
|
||||
t.Errorf("numErrored.Name: got %q, want %q", numErrored.name, tt.wantNumErroredName)
|
||||
}
|
||||
if numErrored.typ != tt.wantNumErroredType {
|
||||
t.Errorf("hasAny.Type: got %q, want %q", numErrored.typ, tt.wantNumErroredType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testSettingDetails struct {
|
||||
definition *setting.Definition
|
||||
origin *setting.Origin
|
||||
value any
|
||||
err error
|
||||
}
|
||||
|
||||
func TestReportMetrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
osOverride string
|
||||
useMetrics bool
|
||||
settings []testSettingDetails
|
||||
wantMetrics []TestState
|
||||
wantResetMetrics []TestState
|
||||
}{
|
||||
{
|
||||
name: "none",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{},
|
||||
wantMetrics: []TestState{},
|
||||
},
|
||||
{
|
||||
name: "single-value",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-error",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "value-and-error",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two-values",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 17,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two-errors",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 2},
|
||||
{"windows_syspolicy_TestSetting01_error", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 2},
|
||||
{"windows_syspolicy_TestSetting01_error", 0},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-scope",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.ProfileSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.ProfileSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.CurrentProfileScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting03", setting.UserSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.CurrentUserScope),
|
||||
value: 17,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_profile_errors", 1},
|
||||
{"windows_syspolicy_user_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02_profile_error", 1},
|
||||
{"windows_syspolicy_TestSetting03_user", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_profile_errors", 1},
|
||||
{"windows_syspolicy_user_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02_profile_error", 0},
|
||||
{"windows_syspolicy_TestSetting03_user", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "report-metrics-on-android",
|
||||
osOverride: "android",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"android_syspolicy_any", 1},
|
||||
{"android_syspolicy_TestSetting01", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"android_syspolicy_any", 0},
|
||||
{"android_syspolicy_TestSetting01", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "do-not-report-metrics-on-macos",
|
||||
osOverride: "macos",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{}, // none reported
|
||||
},
|
||||
{
|
||||
name: "do-not-report-metrics-on-ios",
|
||||
osOverride: "ios",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{}, // none reported
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset the lazy value so it'll be re-evaluated with the osOverride.
|
||||
lazyReportMetrics = lazy.SyncValue[bool]{}
|
||||
t.Cleanup(func() {
|
||||
// Also reset it during the cleanup.
|
||||
lazyReportMetrics = lazy.SyncValue[bool]{}
|
||||
})
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
|
||||
h := NewTestHandler(t)
|
||||
SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
for _, s := range tt.settings {
|
||||
if s.err != nil {
|
||||
ReportError(s.origin, s.definition, s.err)
|
||||
} else {
|
||||
ReportConfigured(s.origin, s.definition, s.value)
|
||||
}
|
||||
}
|
||||
h.MustEqual(tt.wantMetrics...)
|
||||
|
||||
for _, s := range tt.settings {
|
||||
Reset(s.origin)
|
||||
ReportNotConfigured(s.origin, s.definition)
|
||||
}
|
||||
h.MustEqual(tt.wantResetMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,88 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
// TestState represents a metric name and its expected value.
|
||||
type TestState struct {
|
||||
Name string // `$os` in the name will be replaced by the actual operating system name.
|
||||
Value int64
|
||||
}
|
||||
|
||||
// TestHandler facilitates testing of the code that uses metrics.
|
||||
type TestHandler struct {
|
||||
t internal.TB
|
||||
|
||||
m map[string]int64
|
||||
}
|
||||
|
||||
// NewTestHandler returns a new TestHandler.
|
||||
func NewTestHandler(t internal.TB) *TestHandler {
|
||||
return &TestHandler{t, make(map[string]int64)}
|
||||
}
|
||||
|
||||
// AddMetric increments the metric with the specified name and type by delta d.
|
||||
func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) {
|
||||
h.t.Helper()
|
||||
if typ == clientmetric.TypeCounter && d < 0 {
|
||||
h.t.Fatalf("an attempt was made to decrement a counter metric %q", name)
|
||||
}
|
||||
if v, ok := h.m[name]; ok || d != 0 {
|
||||
h.m[name] = v + d
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetric sets the metric with the specified name and type to the value v.
|
||||
func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) {
|
||||
h.t.Helper()
|
||||
if typ == clientmetric.TypeCounter {
|
||||
h.t.Fatalf("an attempt was made to set a counter metric %q", name)
|
||||
}
|
||||
if _, ok := h.m[name]; ok || v != 0 {
|
||||
h.m[name] = v
|
||||
}
|
||||
}
|
||||
|
||||
// MustEqual fails the test if the actual metric state differs from the specified state.
|
||||
func (h *TestHandler) MustEqual(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
h.MustContain(metrics...)
|
||||
h.mustNoExtra(metrics...)
|
||||
}
|
||||
|
||||
// MustContain fails the test if the specified metrics are not set or have
|
||||
// different values than specified. It permits other metrics to be set in
|
||||
// addition to the ones being tested.
|
||||
func (h *TestHandler) MustContain(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
for _, m := range metrics {
|
||||
name := strings.ReplaceAll(m.Name, "$os", internal.OS())
|
||||
v, ok := h.m[name]
|
||||
if !ok {
|
||||
h.t.Errorf("%q: got (none), want %v", name, m.Value)
|
||||
} else if v != m.Value {
|
||||
h.t.Fatalf("%q: got %v, want %v", name, v, m.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TestHandler) mustNoExtra(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
s := make(set.Set[string])
|
||||
for i := range metrics {
|
||||
s.Add(strings.ReplaceAll(metrics[i].Name, "$os", internal.OS()))
|
||||
}
|
||||
for n, v := range h.m {
|
||||
if !s.Contains(n) {
|
||||
h.t.Errorf("%q: got %v, want (none)", n, v)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,394 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/internal/metrics"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// Reader reads all configured policy settings from a given [Store].
|
||||
// It registers a change callback with the [Store] and maintains the current version
|
||||
// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store]
|
||||
// whenever a new settings snapshot is requested with [Reader.GetSettings].
|
||||
// It is safe for concurrent use.
|
||||
type Reader struct {
|
||||
store Store
|
||||
origin *setting.Origin
|
||||
settings []*setting.Definition
|
||||
unregisterChangeNotifier func()
|
||||
doneCh chan struct{} // closed when [Reader] is closed.
|
||||
|
||||
mu sync.Mutex
|
||||
closing bool
|
||||
upToDate bool
|
||||
lastPolicy *setting.Snapshot
|
||||
sessions set.HandleSet[*ReadingSession]
|
||||
}
|
||||
|
||||
// newReader returns a new [Reader] that reads policy settings from a given [Store].
|
||||
// The returned reader takes ownership of the store. If the store implements [io.Closer],
|
||||
// the returned reader will close the store when it is closed.
|
||||
func newReader(store Store, origin *setting.Origin) (*Reader, error) {
|
||||
settings, err := setting.Definitions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if expirable, ok := store.(Expirable); ok {
|
||||
select {
|
||||
case <-expirable.Done():
|
||||
return nil, ErrStoreClosed
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})}
|
||||
if changeable, ok := store.(Changeable); ok {
|
||||
// We should subscribe to policy change notifications first before reading
|
||||
// the policy settings from the store. This way we won't miss any notifications.
|
||||
if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil {
|
||||
// Errors registering policy change callbacks are non-fatal.
|
||||
// TODO(nickkhyl): implement a background policy refresh every X minutes?
|
||||
loggerx.Errorf("failed to register %v policy change callback: %v", origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := reader.reload(true); err != nil {
|
||||
if reader.unregisterChangeNotifier != nil {
|
||||
reader.unregisterChangeNotifier()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if expirable, ok := store.(Expirable); ok {
|
||||
if waitCh := expirable.Done(); waitCh != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-waitCh:
|
||||
reader.Close()
|
||||
case <-reader.doneCh:
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// GetSettings returns the current [*setting.Snapshot],
|
||||
// re-reading it from from the underlying [Store] only if the policy
|
||||
// has changed since it was read last. It never fails and returns
|
||||
// the previous version of the policy settings if a read attempt fails.
|
||||
func (r *Reader) GetSettings() *setting.Snapshot {
|
||||
r.mu.Lock()
|
||||
upToDate, lastPolicy := r.upToDate, r.lastPolicy
|
||||
r.mu.Unlock()
|
||||
if upToDate {
|
||||
return lastPolicy
|
||||
}
|
||||
|
||||
policy, err := r.reload(false)
|
||||
if err != nil {
|
||||
// If the policy fails to reload completely, log an error and return the last cached version.
|
||||
// However, errors related to individual policy items are always
|
||||
// propagated to callers when they fetch those settings.
|
||||
loggerx.Errorf("failed to reload %v policy: %v", r.origin, err)
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
// ReadSettings reads policy settings from the underlying [Store] even if no
|
||||
// changes were detected. It returns the new [*setting.Snapshot],nil on
|
||||
// success or an undefined snapshot (possibly `nil`) along with a non-`nil`
|
||||
// error in case of failure.
|
||||
func (r *Reader) ReadSettings() (*setting.Snapshot, error) {
|
||||
return r.reload(true)
|
||||
}
|
||||
|
||||
// reload is like [Reader.ReadSettings], but allows specifying whether to re-read
|
||||
// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails.
|
||||
func (r *Reader) reload(force bool) (*setting.Snapshot, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.upToDate && !force {
|
||||
return r.lastPolicy, nil
|
||||
}
|
||||
|
||||
if lockable, ok := r.store.(Lockable); ok {
|
||||
if err := lockable.Lock(); err != nil {
|
||||
return r.lastPolicy, err
|
||||
}
|
||||
defer lockable.Unlock()
|
||||
}
|
||||
|
||||
r.upToDate = true
|
||||
|
||||
metrics.Reset(r.origin)
|
||||
|
||||
var m map[setting.Key]setting.RawItem
|
||||
if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 {
|
||||
m = make(map[setting.Key]setting.RawItem, lastPolicyCount)
|
||||
}
|
||||
for _, s := range r.settings {
|
||||
if !r.origin.Scope().IsConfigurableSetting(s) {
|
||||
// Skip settings that cannot be configured in the current scope.
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := readPolicySettingValue(r.store, s)
|
||||
if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) {
|
||||
metrics.ReportNotConfigured(r.origin, s)
|
||||
continue
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
metrics.ReportConfigured(r.origin, s, val)
|
||||
} else {
|
||||
metrics.ReportError(r.origin, s, err)
|
||||
}
|
||||
|
||||
// If there's an error reading a single policy, such as a value type mismatch,
|
||||
// we'll wrap the error to preserve its text and return it
|
||||
// whenever someone attempts to fetch the value.
|
||||
// Otherwise, the errorText will be nil.
|
||||
errorText := setting.MaybeErrorText(err)
|
||||
item := setting.RawItemWith(val, errorText, r.origin)
|
||||
mak.Set(&m, s.Key(), item)
|
||||
}
|
||||
|
||||
newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin))
|
||||
if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) {
|
||||
r.lastPolicy = newPolicy
|
||||
}
|
||||
return r.lastPolicy, nil
|
||||
}
|
||||
|
||||
// ReadingSession is like [Reader], but with a channel that's written
|
||||
// to when there's a policy change, and closed when the session is terminated.
|
||||
type ReadingSession struct {
|
||||
reader *Reader
|
||||
policyChangedCh chan struct{} // 1-buffered channel
|
||||
handle set.Handle // in the reader.sessions
|
||||
closeInternal func()
|
||||
}
|
||||
|
||||
// OpenSession opens and returns a new session to r, allowing the caller
|
||||
// to get notified whenever a policy change is reported by the [source.Store],
|
||||
// or an [ErrStoreClosed] if the reader has already been closed.
|
||||
func (r *Reader) OpenSession() (*ReadingSession, error) {
|
||||
session := &ReadingSession{
|
||||
reader: r,
|
||||
policyChangedCh: make(chan struct{}, 1),
|
||||
}
|
||||
session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) })
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.closing {
|
||||
return nil, ErrStoreClosed
|
||||
}
|
||||
session.handle = r.sessions.Add(session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// GetSettings is like [Reader.GetSettings].
|
||||
func (s *ReadingSession) GetSettings() *setting.Snapshot {
|
||||
return s.reader.GetSettings()
|
||||
}
|
||||
|
||||
// ReadSettings is like [Reader.ReadSettings].
|
||||
func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) {
|
||||
return s.reader.ReadSettings()
|
||||
}
|
||||
|
||||
// PolicyChanged returns a channel that's written to when
|
||||
// there's a policy change, closed when the session is terminated.
|
||||
func (s *ReadingSession) PolicyChanged() <-chan struct{} {
|
||||
return s.policyChangedCh
|
||||
}
|
||||
|
||||
// Close unregisters this session with the [Reader].
|
||||
func (s *ReadingSession) Close() {
|
||||
s.reader.mu.Lock()
|
||||
delete(s.reader.sessions, s.handle)
|
||||
s.closeInternal()
|
||||
s.reader.mu.Unlock()
|
||||
}
|
||||
|
||||
// onPolicyChange handles a policy change notification from the [Store],
|
||||
// invalidating the current [setting.Snapshot] in r,
|
||||
// and notifying the active [ReadingSession]s.
|
||||
func (r *Reader) onPolicyChange() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.upToDate = false
|
||||
for _, s := range r.sessions {
|
||||
select {
|
||||
case s.policyChangedCh <- struct{}{}:
|
||||
// Notified.
|
||||
default:
|
||||
// 1-buffered channel is full, meaning that another policy change
|
||||
// notification is already en route.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the store reader and the underlying store.
|
||||
func (r *Reader) Close() error {
|
||||
r.mu.Lock()
|
||||
if r.closing {
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
r.closing = true
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.unregisterChangeNotifier != nil {
|
||||
r.unregisterChangeNotifier()
|
||||
r.unregisterChangeNotifier = nil
|
||||
}
|
||||
|
||||
if closer, ok := r.store.(io.Closer); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.store = nil
|
||||
|
||||
close(r.doneCh)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, c := range r.sessions {
|
||||
c.closeInternal()
|
||||
}
|
||||
r.sessions = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the reader is closed.
|
||||
func (r *Reader) Done() <-chan struct{} {
|
||||
return r.doneCh
|
||||
}
|
||||
|
||||
// ReadableSource is a [Source] open for reading.
|
||||
type ReadableSource struct {
|
||||
*Source
|
||||
*ReadingSession
|
||||
}
|
||||
|
||||
// Close closes the underlying [ReadingSession].
|
||||
func (s ReadableSource) Close() {
|
||||
s.ReadingSession.Close()
|
||||
}
|
||||
|
||||
// ReadableSources is a slice of [ReadableSource].
|
||||
type ReadableSources []ReadableSource
|
||||
|
||||
// Contains reports whether s contains the specified source.
|
||||
func (s ReadableSources) Contains(source *Source) bool {
|
||||
return s.IndexOf(source) != -1
|
||||
}
|
||||
|
||||
// IndexOf returns position of the specified source in s, or -1
|
||||
// if the source does not exist.
|
||||
func (s ReadableSources) IndexOf(source *Source) int {
|
||||
return slices.IndexFunc(s, func(rs ReadableSource) bool {
|
||||
return rs.Source == source
|
||||
})
|
||||
}
|
||||
|
||||
// InsertionIndexOf returns the position at which source can be inserted
|
||||
// to maintain the sorted order of the readableSources.
|
||||
// The return value is unspecified if s is not sorted on entry to InsertionIndexOf.
|
||||
func (s ReadableSources) InsertionIndexOf(source *Source) int {
|
||||
// Insert new sources after any existing sources with the same precedence,
|
||||
// and just before the first source with higher precedence.
|
||||
// Just like stable sort, but for insertion.
|
||||
// It's okay to use linear search as insertions are rare
|
||||
// and we never have more than just a few policy sources.
|
||||
higherPrecedence := func(rs ReadableSource) bool { return rs.Compare(source) > 0 }
|
||||
if i := slices.IndexFunc(s, higherPrecedence); i != -1 {
|
||||
return i
|
||||
}
|
||||
return len(s)
|
||||
}
|
||||
|
||||
// StableSort sorts [ReadableSource] in s by precedence, so that policy
|
||||
// settings from sources with higher precedence (e.g., [DeviceScope])
|
||||
// will be read and merged last, overriding any policy settings with
|
||||
// the same keys configured in sources with lower precedence
|
||||
// (e.g., [CurrentUserScope]).
|
||||
func (s *ReadableSources) StableSort() {
|
||||
sort.SliceStable(*s, func(i, j int) bool {
|
||||
return (*s)[i].Source.Compare((*s)[j].Source) < 0
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAt closes and deletes the i-th source from s.
|
||||
func (s *ReadableSources) DeleteAt(i int) {
|
||||
(*s)[i].Close()
|
||||
*s = slices.Delete(*s, i, i+1)
|
||||
}
|
||||
|
||||
// Close closes and deletes all sources in s.
|
||||
func (s *ReadableSources) Close() {
|
||||
for _, s := range *s {
|
||||
s.Close()
|
||||
}
|
||||
*s = nil
|
||||
}
|
||||
|
||||
func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) {
|
||||
switch key := s.Key(); s.Type() {
|
||||
case setting.BooleanValue:
|
||||
return store.ReadBoolean(key)
|
||||
case setting.IntegerValue:
|
||||
return store.ReadUInt64(key)
|
||||
case setting.StringValue:
|
||||
return store.ReadString(key)
|
||||
case setting.StringListValue:
|
||||
return store.ReadStringArray(key)
|
||||
case setting.PreferenceOptionValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value setting.PreferenceOption
|
||||
if err = value.UnmarshalText([]byte(s)); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return setting.ShowChoiceByPolicy, err
|
||||
case setting.VisibilityValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value setting.Visibility
|
||||
if err = value.UnmarshalText([]byte(s)); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return setting.VisibleByPolicy, err
|
||||
case setting.DurationValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value time.Duration
|
||||
if value, err = time.ParseDuration(s); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type())
|
||||
}
|
||||
}
|
@ -0,0 +1,291 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestReaderLifecycle(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origin *setting.Origin
|
||||
definitions []*setting.Definition
|
||||
wantReads []TestExpectedReads
|
||||
initStrings []TestSetting[string]
|
||||
initUInt64s []TestSetting[uint64]
|
||||
initWant *setting.Snapshot
|
||||
addStrings []TestSetting[string]
|
||||
addStringLists []TestSetting[[]string]
|
||||
newWant *setting.Snapshot
|
||||
}{
|
||||
{
|
||||
name: "read-all-settings-once",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
|
||||
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "re-read-all-settings-when-the-policy-changes",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
|
||||
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")},
|
||||
addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})},
|
||||
newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/device",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/profile",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
// Device settings cannot be configured at the profile scope and should not be read.
|
||||
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/user",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
// Device and profile settings cannot be configured at the profile scope and should not be read.
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-stringy-settings",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initStrings: []TestSetting[string]{
|
||||
TestSettingOf("DurationValue", "2h30m"),
|
||||
TestSettingOf("PreferenceOptionValue", "always"),
|
||||
TestSettingOf("VisibilityValue", "show"),
|
||||
},
|
||||
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "read-erroneous-stringy-settings",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initStrings: []TestSetting[string]{
|
||||
TestSettingOf("DurationValue1", "soon"),
|
||||
TestSettingWithError[string]("DurationValue2", setting.NewErrorText("bang!")),
|
||||
TestSettingOf("PreferenceOptionValue", "sometimes"),
|
||||
},
|
||||
initUInt64s: []TestSetting[uint64]{
|
||||
TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch
|
||||
},
|
||||
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"DurationValue1": setting.RawItemWith(nil, setting.NewErrorText("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"DurationValue2": setting.RawItemWith(nil, setting.NewErrorText("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewErrorText("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setting.SetDefinitionsForTest(t, tt.definitions...)
|
||||
store := NewTestStore(t)
|
||||
store.SetStrings(tt.initStrings...)
|
||||
store.SetUInt64s(tt.initUInt64s...)
|
||||
|
||||
reader, err := newReader(store, tt.origin)
|
||||
if err != nil {
|
||||
t.Fatalf("newReader failed: %v", err)
|
||||
}
|
||||
|
||||
if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
|
||||
// Should not result in new reads as there were no changes.
|
||||
N := 100
|
||||
for range N {
|
||||
reader.GetSettings()
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
store.ResetCounters()
|
||||
|
||||
got, err := reader.ReadSettings()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadSettings failed: %v", err)
|
||||
}
|
||||
|
||||
if tt.initWant != nil && !got.Equal(tt.initWant) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
|
||||
}
|
||||
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
store.ResetCounters()
|
||||
|
||||
if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 {
|
||||
store.SetStrings(tt.addStrings...)
|
||||
store.SetStringLists(tt.addStringLists...)
|
||||
|
||||
// As the settings have changed, GetSettings needs to re-read them.
|
||||
if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) {
|
||||
t.Errorf("New Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-reader.Done():
|
||||
t.Fatalf("the reader is closed")
|
||||
default:
|
||||
}
|
||||
|
||||
store.Close()
|
||||
|
||||
<-reader.Done()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadingSession(t *testing.T) {
|
||||
setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue))
|
||||
store := NewTestStore(t)
|
||||
|
||||
origin := setting.NewOrigin(setting.DeviceScope)
|
||||
reader, err := newReader(store, origin)
|
||||
if err != nil {
|
||||
t.Fatalf("newReader failed: %v", err)
|
||||
}
|
||||
session, err := reader.OpenSession()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open a reading session: %v", err)
|
||||
}
|
||||
t.Cleanup(session.Close)
|
||||
|
||||
if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
select {
|
||||
case _, ok := <-session.PolicyChanged():
|
||||
if ok {
|
||||
t.Fatalf("the policy changed notification was sent prematurely")
|
||||
} else {
|
||||
t.Fatalf("the session was closed prematurely")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
store.SetStrings(TestSettingOf("StringValue", "S1"))
|
||||
_, ok := <-session.PolicyChanged()
|
||||
if !ok {
|
||||
t.Fatalf("the session was closed prematurely")
|
||||
}
|
||||
|
||||
want := setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"StringValue": setting.RawItemWith("S1", nil, origin),
|
||||
}, origin)
|
||||
if got := session.GetSettings(); !got.Equal(want) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
store.Close()
|
||||
if _, ok = <-session.PolicyChanged(); ok {
|
||||
t.Fatalf("the session must be closed")
|
||||
}
|
||||
}
|
@ -0,0 +1,146 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package source defines interfaces for policy stores,
|
||||
// facilitates the creation of policy sources, and provides
|
||||
// functionality for reading policy settings from these sources.
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// ErrStoreClosed is an error returned when attempting to use a [Store] after it
|
||||
// has been closed.
|
||||
var ErrStoreClosed = errors.New("the policy store has been closed")
|
||||
|
||||
// Store provides methods to read system policy settings from OS-specific storage.
|
||||
// Implementations must be concurrency-safe, and may also implement
|
||||
// [Lockable], [Changeable], [Expirable] and [io.Closer].
|
||||
//
|
||||
// If a [Store] implementation also implements [io.Closer],
|
||||
// it will be called by the package to release the resources
|
||||
// when the store is no longer needed.
|
||||
type Store interface {
|
||||
// ReadString returns the value of a [setting.StringValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an error on failure.
|
||||
ReadString(key setting.Key) (string, error)
|
||||
// ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an error on failure.
|
||||
ReadUInt64(key setting.Key) (uint64, error)
|
||||
// ReadBoolean returns the value of a [setting.BooleanValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an error on failure.
|
||||
ReadBoolean(key setting.Key) (bool, error)
|
||||
// ReadStringArray returns the value of a [setting.StringListValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an error on failure.
|
||||
ReadStringArray(key setting.Key) ([]string, error)
|
||||
}
|
||||
|
||||
// Lockable is an optional interface that [Store] implementations may support.
|
||||
// Locking a [Store] is not mandatory as [Store] must be concurrency-safe,
|
||||
// but is recommended to avoid issues where consecutive read calls for related
|
||||
// policies might return inconsistent results if a policy change occurs between
|
||||
// the calls. Implementations may use locking to pre-read policies or for
|
||||
// similar performance optimizations.
|
||||
type Lockable interface {
|
||||
// Lock acquires a read lock on the policy store,
|
||||
// ensuring the store's state remains unchanged while locked.
|
||||
// Multiple readers can hold the lock simultaneously.
|
||||
// It returns an error if the store cannot be locked.
|
||||
Lock() error
|
||||
// Unlock unlocks the policy store.
|
||||
// It is a run-time error if the store is not locked on entry to Unlock.
|
||||
Unlock()
|
||||
}
|
||||
|
||||
// Changeable is an optional interface that [Store] implementations may support
|
||||
// if the policy settings they contain can be externally changed after being initially read.
|
||||
type Changeable interface {
|
||||
// RegisterChangeCallback adds a function that will be called
|
||||
// whenever there's a policy change in the [Store].
|
||||
// The returned function can be used to unregister the callback.
|
||||
RegisterChangeCallback(callback func()) (unregister func(), err error)
|
||||
}
|
||||
|
||||
// Expirable is an optional interface that [Store] implementations may support
|
||||
// if they can be externally closed or otherwise become invalid while in use.
|
||||
type Expirable interface {
|
||||
// Done returns a channel that is closed when the policy [Store] should no longer be used.
|
||||
// It should return nil if the store never expires.
|
||||
Done() <-chan struct{}
|
||||
}
|
||||
|
||||
// Source represents a named source of policy settings for a given [setting.PolicyScope].
|
||||
type Source struct {
|
||||
name string
|
||||
scope setting.PolicyScope
|
||||
store Store
|
||||
origin *setting.Origin
|
||||
|
||||
lazyReader lazy.SyncValue[*Reader]
|
||||
}
|
||||
|
||||
// NewSource returns a new [Source] with the specified name, scope, and store.
|
||||
func NewSource(name string, scope setting.PolicyScope, store Store) *Source {
|
||||
return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)}
|
||||
}
|
||||
|
||||
// Name reports the name of the policy source.
|
||||
func (s *Source) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Scope reports the management scope of the policy source.
|
||||
func (s *Source) Scope() setting.PolicyScope {
|
||||
return s.scope
|
||||
}
|
||||
|
||||
// Reader returns a [Reader] that reads from this source's [Store].
|
||||
func (s *Source) Reader() (*Reader, error) {
|
||||
return s.lazyReader.GetErr(func() (*Reader, error) {
|
||||
return newReader(s.store, s.origin)
|
||||
})
|
||||
}
|
||||
|
||||
// Description returns a formatted string with the scope and name of this policy source.
|
||||
// It can be used for logging or display purposes.
|
||||
func (s *Source) Description() string {
|
||||
if s.name != "" {
|
||||
return fmt.Sprintf("%s (%v)", s.name, s.Scope())
|
||||
}
|
||||
return s.Scope().String()
|
||||
}
|
||||
|
||||
// Compare returns an integer comparing s and s2
|
||||
// by their precedence, following the "last-wins" model.
|
||||
// The result will be:
|
||||
//
|
||||
// -1 if policy settings from s should be processed before policy settings from s2;
|
||||
// +1 if policy settings from s should be processed after policy settings from s2, overriding s2;
|
||||
// 0 if the relative processing order of policy settings in s and s2 is unspecified.
|
||||
func (s *Source) Compare(s2 *Source) int {
|
||||
return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind())
|
||||
}
|
||||
|
||||
// Close closes the [Source] and the underlying [Store].
|
||||
func (s *Source) Close() error {
|
||||
// The [Reader], if any, owns the [Store].
|
||||
if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil {
|
||||
return reader.Close()
|
||||
}
|
||||
// Otherwise, it is our responsibility to close it.
|
||||
if closer, ok := s.store.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,450 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/winutil/gp"
|
||||
)
|
||||
|
||||
const (
|
||||
softwareKeyName = `Software`
|
||||
tsPoliciesSubkey = `Policies\Tailscale`
|
||||
tsIPNSubkey = `Tailscale IPN` // the legacy key we need to fallback to
|
||||
)
|
||||
|
||||
var (
|
||||
_ Store = (*PlatformPolicyStore)(nil)
|
||||
_ Lockable = (*PlatformPolicyStore)(nil)
|
||||
_ Changeable = (*PlatformPolicyStore)(nil)
|
||||
_ Expirable = (*PlatformPolicyStore)(nil)
|
||||
)
|
||||
|
||||
// PlatformPolicyStore implements [Store] by providing read access to
|
||||
// Registry-based Tailscale policies, such as those configured via Group Policy or MDM.
|
||||
// For better performance and consistency, it is recommended to lock it when
|
||||
// reading multiple policy settings sequentially.
|
||||
// It also allows subscribing to policy change notifications.
|
||||
type PlatformPolicyStore struct {
|
||||
scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy]
|
||||
|
||||
// The softwareKey can be HKLM\Software, HKCU\Software, or
|
||||
// HKU\{SID}\Software. Anything below the Software subkey, including
|
||||
// Software\Policies, may not yet exist or could be deleted throughout the
|
||||
// [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer
|
||||
// to always use a real registry key (rather than a predefined HKLM or HKCU)
|
||||
// to simplify bookkeeping (predefined keys should never be closed).
|
||||
// Finally, this will allow us to watch for any registry changes directly
|
||||
// should we need this in the future in addition to gp.ChangeWatcher.
|
||||
softwareKey registry.Key
|
||||
watcher *gp.ChangeWatcher
|
||||
|
||||
done chan struct{} // done is closed when Close call completes
|
||||
|
||||
// The policyLock can be locked by the caller when reading multiple policy settings
|
||||
// to prevent the Group Policy Client service from modifying policies while
|
||||
// they are being read.
|
||||
//
|
||||
// When both policyLock and mu need to be taken, mu must be taken before policyLock.
|
||||
policyLock *gp.PolicyLock
|
||||
|
||||
mu sync.Mutex
|
||||
tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked.
|
||||
cbs set.HandleSet[func()] // policy change callbacks
|
||||
lockCnt int
|
||||
locked sync.WaitGroup
|
||||
closing bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
type registryValueGetter[T any] func(key registry.Key, name string) (T, error)
|
||||
|
||||
// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine.
|
||||
func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) {
|
||||
softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
|
||||
}
|
||||
return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, gp.NewMachinePolicyLock()), nil
|
||||
}
|
||||
|
||||
// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token.
|
||||
// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY]
|
||||
// and [windows.TOKEN_DUPLICATE] access. The caller retains ownership of the token.
|
||||
func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) {
|
||||
var err error
|
||||
var softwareKey registry.Key
|
||||
if token != 0 {
|
||||
var user *windows.Tokenuser
|
||||
if user, err = token.GetTokenUser(); err != nil {
|
||||
return nil, fmt.Errorf("failed to get token user: %w", err)
|
||||
}
|
||||
userSid := user.User.Sid
|
||||
softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ)
|
||||
} else {
|
||||
softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
|
||||
}
|
||||
policyLock, err := gp.NewUserPolicyLock(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create a user policy lock: %w", err)
|
||||
}
|
||||
return newPlatformPolicyStore(gp.UserPolicy, softwareKey, policyLock), nil
|
||||
}
|
||||
|
||||
func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, policyLock *gp.PolicyLock) *PlatformPolicyStore {
|
||||
return &PlatformPolicyStore{
|
||||
scope: scope,
|
||||
softwareKey: softwareKey,
|
||||
done: make(chan struct{}),
|
||||
policyLock: policyLock,
|
||||
}
|
||||
}
|
||||
|
||||
// Lock locks the policy store, preventing the system from modifying the policies
|
||||
// while they are being read. It is a read lock that may be acquired by multiple goroutines.
|
||||
// Each Lock call must be balanced by exactly one Unlock call.
|
||||
func (ps *PlatformPolicyStore) Lock() (err error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
if ps.closing {
|
||||
return ErrStoreClosed
|
||||
}
|
||||
|
||||
ps.lockCnt += 1
|
||||
if ps.lockCnt != 1 {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.lockCnt -= 1
|
||||
}
|
||||
}()
|
||||
|
||||
// Ensure ps remains open while the lock is held.
|
||||
ps.locked.Add(1)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.locked.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
// Acquire the GP lock to prevent the system from modifying policy settings
|
||||
// while they are being read.
|
||||
if err := ps.policyLock.Lock(); err != nil {
|
||||
if errors.Is(err, gp.ErrInvalidLockState) {
|
||||
// The policy store is being closed and we've lost the race.
|
||||
return ErrStoreClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.policyLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Keep the Tailscale's registry keys open for the duration of the lock.
|
||||
keyNames := tailscaleKeyNamesFor(ps.scope)
|
||||
ps.tsKeys = make([]registry.Key, 0, len(keyNames))
|
||||
for _, keyName := range keyNames {
|
||||
var tsKey registry.Key
|
||||
tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
ps.tsKeys = append(ps.tsKeys, tsKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0.
|
||||
// It panics if ps is not locked on entry to Unlock.
|
||||
func (ps *PlatformPolicyStore) Unlock() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
ps.lockCnt -= 1
|
||||
if ps.lockCnt < 0 {
|
||||
panic("negative lockCnt")
|
||||
} else if ps.lockCnt != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range ps.tsKeys {
|
||||
key.Close()
|
||||
}
|
||||
ps.tsKeys = nil
|
||||
ps.policyLock.Unlock()
|
||||
ps.locked.Done()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback adds a function that will be called whenever there's a policy change.
|
||||
// It returns a function that can be used to unregister the specified callback or an error.
|
||||
// The error is [ErrStoreClosed] if ps has already been closed.
|
||||
func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
if ps.closing {
|
||||
return nil, ErrStoreClosed
|
||||
}
|
||||
|
||||
handle := ps.cbs.Add(cb)
|
||||
if len(ps.cbs) == 1 {
|
||||
if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return func() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
delete(ps.cbs, handle)
|
||||
if len(ps.cbs) == 0 {
|
||||
if ps.watcher != nil {
|
||||
ps.watcher.Close()
|
||||
ps.watcher = nil
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ps *PlatformPolicyStore) onChange() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
if ps.closing {
|
||||
return
|
||||
}
|
||||
for _, callback := range ps.cbs {
|
||||
go callback()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString retrieves a string policy with the specified key.
|
||||
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadString(key setting.Key) (val string, err error) {
|
||||
return getPolicyValue(ps, key,
|
||||
func(key registry.Key, valueName string) (string, error) {
|
||||
val, _, err := key.GetStringValue(valueName)
|
||||
return val, err
|
||||
})
|
||||
}
|
||||
|
||||
// ReadUInt64 retrieves an integer policy with the specified key.
|
||||
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
return getPolicyValue(ps, key,
|
||||
func(key registry.Key, valueName string) (uint64, error) {
|
||||
val, _, err := key.GetIntegerValue(valueName)
|
||||
return val, err
|
||||
})
|
||||
}
|
||||
|
||||
// ReadBoolean retrieves a boolean policy with the specified key.
|
||||
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
return getPolicyValue(ps, key,
|
||||
func(key registry.Key, valueName string) (bool, error) {
|
||||
val, _, err := key.GetIntegerValue(valueName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return val != 0, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ReadString retrieves a multi-string policy with the specified key.
|
||||
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
return getPolicyValue(ps, key,
|
||||
func(key registry.Key, valueName string) ([]string, error) {
|
||||
val, _, err := key.GetStringsValue(valueName)
|
||||
if err != registry.ErrNotExist {
|
||||
return val, err // the err may be nil or non-nil
|
||||
}
|
||||
|
||||
// The idiomatic way to store multiple string values in Group Policy
|
||||
// and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ)
|
||||
// values under a subkey rather than in a single REG_MULTI_SZ value.
|
||||
//
|
||||
// See the Group Policy: Registry Extension Encoding specification,
|
||||
// and specifically the ListElement and ListBox types.
|
||||
// https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf
|
||||
valKey, err := registry.OpenKey(key, valueName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valNames, err := valKey.ReadValueNames(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val = make([]string, 0, len(valNames))
|
||||
for _, name := range valNames {
|
||||
switch item, _, err := valKey.GetStringValue(name); {
|
||||
case err == registry.ErrNotExist:
|
||||
continue
|
||||
case err != nil:
|
||||
return nil, err
|
||||
default:
|
||||
val = append(val, item)
|
||||
}
|
||||
}
|
||||
return val, nil
|
||||
})
|
||||
}
|
||||
|
||||
// splitSettingKey extracts the registry key name and value name from a [setting.Key].
|
||||
// The [setting.Key] format allows grouping settings into nested categories using one
|
||||
// or more [setting.KeyPathSeparator]s in the path. How individual policy settings are
|
||||
// stored is an implementation detail of each [Store]. In the [PlatformPolicyStore]
|
||||
// for Windows, we map nested policy categories onto the Registry key hierarchy.
|
||||
// The last component after a [setting.KeyPathSeparator] is treated as the value name,
|
||||
// while everything preceding it is considered a subpath (relative to the {HKLM,HKCU}\Software\Policies\Tailscale key).
|
||||
// If there are no [setting.KeyPathSeparator]s in the key, the policy setting value
|
||||
// is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale.
|
||||
func splitSettingKey(key setting.Key) (path, valueName string) {
|
||||
if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 {
|
||||
path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`)
|
||||
valueName = string(key[idx+len(setting.KeyPathSeparator):])
|
||||
return path, valueName
|
||||
}
|
||||
return "", string(key)
|
||||
}
|
||||
|
||||
func getPolicyValue[T any](ps *PlatformPolicyStore, key setting.Key, getter registryValueGetter[T]) (T, error) {
|
||||
var zero T
|
||||
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
if ps.closed {
|
||||
return zero, ErrStoreClosed
|
||||
}
|
||||
|
||||
path, valueName := splitSettingKey(key)
|
||||
getValue := func(key registry.Key) (T, error) {
|
||||
var err error
|
||||
if path != "" {
|
||||
key, err = registry.OpenKey(key, path, windows.KEY_READ)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
defer key.Close()
|
||||
}
|
||||
return getter(key, valueName)
|
||||
}
|
||||
|
||||
if ps.tsKeys != nil {
|
||||
// A non-nil tsKeys indicates that ps has been locked.
|
||||
// The slice may be empty if Tailscale policy keys do not exist.
|
||||
for _, tsKey := range ps.tsKeys {
|
||||
val, err := getValue(tsKey)
|
||||
if err == nil || err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
}
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
// The ps has not been locked, so we don't have any pre-opened keys.
|
||||
for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) {
|
||||
var tsKey registry.Key
|
||||
tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
continue
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
val, err := getValue(tsKey)
|
||||
tsKey.Close()
|
||||
if err == nil || err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
}
|
||||
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
// Close closes the policy store and releases any associated resources.
|
||||
// It cancels pending locks and prevents any new lock attempts,
|
||||
// but waits for existing locks to be released.
|
||||
func (ps *PlatformPolicyStore) Close() error {
|
||||
// Request to close the Group Policy read lock.
|
||||
// Existing held locks will remain valid, but any new or pending locks
|
||||
// will fail. In certain scenarios, the corresponding write lock may be held
|
||||
// by the Group Policy service for extended periods (minutes rather than
|
||||
// seconds or milliseconds). In such cases, we prefer not to wait that long
|
||||
// if the ps is being closed anyway.
|
||||
if ps.policyLock != nil {
|
||||
ps.policyLock.Close()
|
||||
}
|
||||
|
||||
// Mark ps as closing to fast-fail any new lock attempts.
|
||||
// Callers that have already locked it can finish their reading.
|
||||
ps.mu.Lock()
|
||||
if ps.closing {
|
||||
ps.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
ps.closing = true
|
||||
if ps.watcher != nil {
|
||||
ps.watcher.Close()
|
||||
ps.watcher = nil
|
||||
}
|
||||
ps.mu.Unlock()
|
||||
|
||||
// Signal to the external code that ps should no longer be used.
|
||||
close(ps.done)
|
||||
|
||||
// Wait for any outstanding locks to be released.
|
||||
ps.locked.Wait()
|
||||
|
||||
// Deny any further read attempts and release remaining resources.
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
ps.cbs = nil
|
||||
ps.policyLock = nil
|
||||
ps.closed = true
|
||||
if ps.softwareKey != 0 {
|
||||
ps.softwareKey.Close()
|
||||
ps.softwareKey = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the Close method is called.
|
||||
func (ps *PlatformPolicyStore) Done() <-chan struct{} {
|
||||
return ps.done
|
||||
}
|
||||
|
||||
func tailscaleKeyNamesFor(scope gp.Scope) []string {
|
||||
switch scope {
|
||||
case gp.MachinePolicy:
|
||||
// If a computer-side policy value does not exist under Software\Policies\Tailscale,
|
||||
// we need to fallback and use the legacy Software\Tailscale IPN key.
|
||||
return []string{tsPoliciesSubkey, tsIPNSubkey}
|
||||
case gp.UserPolicy:
|
||||
// However, we've never used the legacy key with user-side policies,
|
||||
// and we should never do so. Unlike HKLM\Software\Tailscale IPN,
|
||||
// its HKCU counterpart is user-writable.
|
||||
return []string{tsPoliciesSubkey}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
@ -0,0 +1,398 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/util/cibuild"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/winutil"
|
||||
"tailscale.com/util/winutil/gp"
|
||||
)
|
||||
|
||||
// subkeyStrings is a test type indicating that a string slice should be written
|
||||
// to the registry as multiple REG_SZ values under the setting's key,
|
||||
// rather than as a single REG_MULTI_SZ value under the group key.
|
||||
// This is the same format as ADMX use for string lists.
|
||||
type subkeyStrings []string
|
||||
|
||||
type testPolicyValue struct {
|
||||
name setting.Key
|
||||
value any
|
||||
}
|
||||
|
||||
func TestLockUnlockPolicyStore(t *testing.T) {
|
||||
// Make sure we don't leak goroutines
|
||||
tstest.ResourceCheck(t)
|
||||
|
||||
store, err := NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
||||
}
|
||||
|
||||
t.Run("One-Goroutine", func(t *testing.T) {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("store.Lock(): got %v; want nil", err)
|
||||
return
|
||||
}
|
||||
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
||||
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
||||
}
|
||||
store.Unlock()
|
||||
})
|
||||
|
||||
// Lock the store N times from different goroutines.
|
||||
const N = 100
|
||||
var unlocked atomic.Int32
|
||||
t.Run("N-Goroutines", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(N)
|
||||
for range N {
|
||||
go func() {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("store.Lock(): got %v; want nil", err)
|
||||
return
|
||||
}
|
||||
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
||||
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
||||
}
|
||||
wg.Done()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
unlocked.Add(1)
|
||||
store.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait until the store is locked N times.
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
// Close the store. The call should wait for all held locks to be released.
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
if locked := unlocked.Load(); locked != N {
|
||||
t.Errorf("locked.Load(): got %v; want %v", locked, N)
|
||||
}
|
||||
|
||||
// Any further attempts to lock it should fail.
|
||||
if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) {
|
||||
t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPolicyStore(t *testing.T) {
|
||||
if !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user")
|
||||
}
|
||||
tests := []struct {
|
||||
name setting.Key
|
||||
newValue any
|
||||
legacyValue any
|
||||
want any
|
||||
}{
|
||||
{name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"},
|
||||
{name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"},
|
||||
{name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""},
|
||||
{name: "BoolPolicy_True", newValue: true, want: true},
|
||||
{name: "BoolPolicy_False", newValue: false, want: false},
|
||||
{name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64
|
||||
{name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)},
|
||||
{name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
|
||||
{name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}},
|
||||
{name: "StringListPolicy_SubKey", newValue: subkeyStrings{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
|
||||
{name: "StringListPolicy_SubKey_Empty", newValue: subkeyStrings{}, want: []string{}},
|
||||
}
|
||||
|
||||
runTests := func(t *testing.T, userStore bool, token windows.Token) {
|
||||
var hive registry.Key
|
||||
if userStore {
|
||||
hive = registry.CURRENT_USER
|
||||
} else {
|
||||
hive = registry.LOCAL_MACHINE
|
||||
}
|
||||
|
||||
// Write policy values to the registry.
|
||||
newValues := make([]testPolicyValue, 0, len(tests))
|
||||
for _, tt := range tests {
|
||||
if tt.newValue != nil {
|
||||
newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue})
|
||||
}
|
||||
}
|
||||
policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey
|
||||
cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues)
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPolicyValues failed: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// Write legacy policy values to the registry.
|
||||
legacyValues := make([]testPolicyValue, 0, len(tests))
|
||||
for _, tt := range tests {
|
||||
if tt.legacyValue != nil {
|
||||
legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue})
|
||||
}
|
||||
}
|
||||
legacyKeyName := softwareKeyName + `\` + tsIPNSubkey
|
||||
cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues)
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPolicyValues failed: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
var store *PlatformPolicyStore
|
||||
if userStore {
|
||||
store, err = NewUserPlatformPolicyStore(token)
|
||||
} else {
|
||||
store, err = NewMachinePlatformPolicyStore()
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NewXPolicyStore failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := store.Close(); err != nil {
|
||||
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry.
|
||||
testReadValues := func(t *testing.T, withLocks bool) {
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.name), func(t *testing.T) {
|
||||
if userStore && tt.newValue == nil {
|
||||
t.Skip("there is no legacy policies for users")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
if withLocks {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("failed to acquire the lock: %v", err)
|
||||
}
|
||||
defer store.Unlock()
|
||||
}
|
||||
|
||||
var got any
|
||||
var err error
|
||||
switch tt.want.(type) {
|
||||
case string:
|
||||
got, err = store.ReadString(tt.name)
|
||||
case uint64:
|
||||
got, err = store.ReadUInt64(tt.name)
|
||||
case bool:
|
||||
got, err = store.ReadBoolean(tt.name)
|
||||
case []string:
|
||||
got, err = store.ReadStringArray(tt.name)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("got %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
t.Run("NoLock", func(t *testing.T) {
|
||||
testReadValues(t, false)
|
||||
})
|
||||
|
||||
t.Run("WithLock", func(t *testing.T) {
|
||||
testReadValues(t, true)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("MachineStore", func(t *testing.T) {
|
||||
runTests(t, false, 0)
|
||||
})
|
||||
|
||||
t.Run("CurrentUserStore", func(t *testing.T) {
|
||||
runTests(t, true, 0)
|
||||
})
|
||||
|
||||
t.Run("UserStoreWithToken", func(t *testing.T) {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
|
||||
t.Fatalf("OpenProcessToken: %v", err)
|
||||
}
|
||||
defer token.Close()
|
||||
runTests(t, true, token)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicyStoreChangeNotifications(t *testing.T) {
|
||||
if cibuild.On() {
|
||||
t.Skipf("test requires running on a real Windows environment")
|
||||
}
|
||||
store, err := NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := store.Close(); err != nil {
|
||||
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
unregister, err := store.RegisterChangeCallback(func() { close(done) })
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterChangeCallback failed: %v", err)
|
||||
}
|
||||
t.Cleanup(unregister)
|
||||
|
||||
// RefreshMachinePolicy is a non-blocking call.
|
||||
if err := gp.RefreshMachinePolicy(true); err != nil {
|
||||
t.Fatalf("RefreshMachinePolicy failed: %v", err)
|
||||
}
|
||||
|
||||
// We should receive a policy change notification when
|
||||
// the Group Policy service completes policy processing.
|
||||
// Otherwise, the test will eventually time out.
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestSplitSettingKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key setting.Key
|
||||
wantPath string
|
||||
wantValue string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
key: "",
|
||||
wantPath: ``,
|
||||
wantValue: "",
|
||||
},
|
||||
{
|
||||
name: "explicit-empty-path",
|
||||
key: "/ValueName",
|
||||
wantPath: ``,
|
||||
wantValue: "ValueName",
|
||||
},
|
||||
{
|
||||
name: "empty-value",
|
||||
key: "Root/Sub/",
|
||||
wantPath: `Root\Sub`,
|
||||
wantValue: "",
|
||||
},
|
||||
{
|
||||
name: "with-path",
|
||||
key: "Root/Sub/ValueName",
|
||||
wantPath: `Root\Sub`,
|
||||
wantValue: "ValueName",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotPath, gotValue := splitSettingKey(tt.key)
|
||||
if gotPath != tt.wantPath {
|
||||
t.Errorf("Path: got %q, want %q", gotPath, tt.wantPath)
|
||||
}
|
||||
if gotValue != tt.wantValue {
|
||||
t.Errorf("Value: got %q, want %q", gotValue, tt.wantPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) {
|
||||
key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var valuesToDelete map[string][]string
|
||||
doCleanup := func() {
|
||||
for path, values := range valuesToDelete {
|
||||
if len(values) == 0 {
|
||||
registry.DeleteKey(key, path)
|
||||
continue
|
||||
}
|
||||
key, err := registry.OpenKey(key, path, windows.KEY_ALL_ACCESS)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defer key.Close()
|
||||
for _, value := range values {
|
||||
key.DeleteValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
key.Close()
|
||||
if !existing {
|
||||
registry.DeleteKey(hive, keyName)
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
doCleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, v := range values {
|
||||
key, existing := key, existing
|
||||
path, valueName := splitSettingKey(v.name)
|
||||
if path != "" {
|
||||
if key, existing, err = registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer key.Close()
|
||||
}
|
||||
if values, ok := valuesToDelete[path]; len(values) > 0 || (!ok && existing) {
|
||||
values = append(values, valueName)
|
||||
mak.Set(&valuesToDelete, path, values)
|
||||
} else if !ok {
|
||||
mak.Set(&valuesToDelete, path, nil)
|
||||
}
|
||||
|
||||
switch value := v.value.(type) {
|
||||
case string:
|
||||
err = key.SetStringValue(valueName, value)
|
||||
case uint32:
|
||||
err = key.SetDWordValue(valueName, value)
|
||||
case uint64:
|
||||
err = key.SetQWordValue(valueName, value)
|
||||
case bool:
|
||||
if value {
|
||||
err = key.SetDWordValue(valueName, 1)
|
||||
} else {
|
||||
err = key.SetDWordValue(valueName, 0)
|
||||
}
|
||||
case []string:
|
||||
err = key.SetStringsValue(valueName, value)
|
||||
case subkeyStrings:
|
||||
key, _, err := registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer key.Close()
|
||||
mak.Set(&valuesToDelete, strings.Trim(path+`\`+valueName, `\`), nil)
|
||||
for i, value := range value {
|
||||
if err := key.SetStringValue(strconv.Itoa(i), value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return doCleanup, nil
|
||||
}
|
@ -0,0 +1,451 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Store = (*TestStore)(nil)
|
||||
_ Lockable = (*TestStore)(nil)
|
||||
_ Changeable = (*TestStore)(nil)
|
||||
_ Expirable = (*TestStore)(nil)
|
||||
)
|
||||
|
||||
// TestValueType is a constraint that allows types supported by [TestStore].
|
||||
type TestValueType interface {
|
||||
bool | uint64 | string | []string
|
||||
}
|
||||
|
||||
// TestSetting is a policy setting in a [TestStore].
|
||||
type TestSetting[T TestValueType] struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Error is the error to be returned by the [TestStore] when reading
|
||||
// a policy setting with the specified key.
|
||||
Error error
|
||||
// Value is the value to be returned by the [TestStore] when reading
|
||||
// a policy setting with the specified key.
|
||||
// It is only used if the Error is nil.
|
||||
Value T
|
||||
}
|
||||
|
||||
// TestSettingOf returns a [TestSetting] representing a policy setting
|
||||
// configured with the specified key and value.
|
||||
func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
|
||||
return TestSetting[T]{Key: key, Value: value}
|
||||
}
|
||||
|
||||
// TestSettingWithError returns a [TestSetting] representing a policy setting
|
||||
// with the specified key and error.
|
||||
func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
|
||||
return TestSetting[T]{Key: key, Error: err}
|
||||
}
|
||||
|
||||
// testReadOperation describes a single policy setting read operation.
|
||||
type testReadOperation struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Type is a value type of a read operation.
|
||||
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
|
||||
Type setting.Type
|
||||
}
|
||||
|
||||
// TestExpectedReads is the number of read operations with the specified details.
|
||||
type TestExpectedReads struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Type is a value type of a read operation.
|
||||
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
|
||||
Type setting.Type
|
||||
// NumTimes is how many times a setting with the specified key and type should have been read.
|
||||
NumTimes int
|
||||
}
|
||||
|
||||
func (r TestExpectedReads) operation() testReadOperation {
|
||||
return testReadOperation{r.Key, r.Type}
|
||||
}
|
||||
|
||||
// TestStore is a [Store] that can be used in tests.
|
||||
type TestStore struct {
|
||||
tb internal.TB
|
||||
|
||||
done chan struct{}
|
||||
|
||||
storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock].
|
||||
storeLockCount atomic.Int32
|
||||
|
||||
mu sync.RWMutex
|
||||
suspendCount int // change callback are suspended if > 0
|
||||
mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
|
||||
cbs set.HandleSet[func()]
|
||||
|
||||
readsMu sync.Mutex
|
||||
reads map[testReadOperation]int // how many times a policy setting was read
|
||||
}
|
||||
|
||||
// NewTestStore returns a new [TestStore].
|
||||
// The tb will be used to report coding errors detected by the [TestStore].
|
||||
func NewTestStore(tb internal.TB) *TestStore {
|
||||
m := make(map[setting.Key]any)
|
||||
return &TestStore{
|
||||
tb: tb,
|
||||
done: make(chan struct{}),
|
||||
mr: m,
|
||||
mw: m,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
|
||||
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
|
||||
func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
|
||||
m := make(map[setting.Key]any)
|
||||
store := &TestStore{
|
||||
tb: tb,
|
||||
done: make(chan struct{}),
|
||||
mr: m,
|
||||
mw: m,
|
||||
}
|
||||
switch settings := any(settings).(type) {
|
||||
case []TestSetting[bool]:
|
||||
store.SetBooleans(settings...)
|
||||
case []TestSetting[uint64]:
|
||||
store.SetUInt64s(settings...)
|
||||
case []TestSetting[string]:
|
||||
store.SetStrings(settings...)
|
||||
case []TestSetting[[]string]:
|
||||
store.SetStringLists(settings...)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// Lock implements [Lockable].
|
||||
func (s *TestStore) Lock() error {
|
||||
s.storeLock.RLock()
|
||||
s.storeLockCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock implements [Lockable].
|
||||
func (s *TestStore) Unlock() {
|
||||
if s.storeLockCount.Add(-1) < 0 {
|
||||
s.tb.Fatal("negative storeLockCount")
|
||||
}
|
||||
s.storeLock.RUnlock()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback implements [Changeable].
|
||||
func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
handle := s.cbs.Add(callback)
|
||||
return func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.cbs, handle)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadString implements [Store].
|
||||
func (s *TestStore) ReadString(key setting.Key) (string, error) {
|
||||
defer s.recordRead(key, setting.StringValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return "", setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return "", err
|
||||
}
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 implements [Store].
|
||||
func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
defer s.recordRead(key, setting.IntegerValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return 0, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return 0, err
|
||||
}
|
||||
u64, ok := v.(uint64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return u64, nil
|
||||
}
|
||||
|
||||
// ReadBoolean implements [Store].
|
||||
func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
defer s.recordRead(key, setting.BooleanValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return false, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return false, err
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// ReadStringArray implements [Store].
|
||||
func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
defer s.recordRead(key, setting.StringListValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return nil, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return nil, err
|
||||
}
|
||||
slice, ok := v.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
|
||||
s.readsMu.Lock()
|
||||
op := testReadOperation{key, typ}
|
||||
num := s.reads[op]
|
||||
num++
|
||||
mak.Set(&s.reads, op, num)
|
||||
s.readsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *TestStore) ResetCounters() {
|
||||
s.readsMu.Lock()
|
||||
clear(s.reads)
|
||||
s.readsMu.Unlock()
|
||||
}
|
||||
|
||||
// ReadsMustEqual fails the test if the actual reads differs from the specified reads.
|
||||
func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
s.readsMu.Lock()
|
||||
defer s.readsMu.Unlock()
|
||||
s.readsMustContainLocked(reads...)
|
||||
s.readMustNoExtraLocked(reads...)
|
||||
}
|
||||
|
||||
// ReadsMustContain fails the test if the specified reads have not been made,
|
||||
// or have been made a different number of times. It permits other values to be
|
||||
// read in addition to the ones being tested.
|
||||
func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
s.readsMu.Lock()
|
||||
defer s.readsMu.Unlock()
|
||||
s.readsMustContainLocked(reads...)
|
||||
}
|
||||
|
||||
func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
for _, r := range reads {
|
||||
if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
|
||||
s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
rs := make(set.Set[testReadOperation])
|
||||
for i := range reads {
|
||||
rs.Add(reads[i].operation())
|
||||
}
|
||||
for ro, num := range s.reads {
|
||||
if !rs.Contains(ro) {
|
||||
s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Suspend suspends the store, batching changes and notifications
|
||||
// until [TestStore.Resume] is called the same number of times as Suspend.
|
||||
func (s *TestStore) Suspend() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.suspendCount++; s.suspendCount == 1 {
|
||||
s.mw = xmaps.Clone(s.mr)
|
||||
}
|
||||
}
|
||||
|
||||
// Resume resumes the store, applying the changes and invoking
|
||||
// the change callbacks.
|
||||
func (s *TestStore) Resume() {
|
||||
s.storeLock.Lock()
|
||||
s.mu.Lock()
|
||||
switch s.suspendCount--; {
|
||||
case s.suspendCount == 0:
|
||||
s.mr = s.mw
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
case s.suspendCount < 0:
|
||||
s.tb.Fatal("negative suspendCount")
|
||||
default:
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// SetBooleans sets the specified boolean settings in s.
|
||||
func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetUInt64s sets the specified integer settings in s.
|
||||
func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetStrings sets the specified string settings in s.
|
||||
func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetStrings sets the specified string list settings in s.
|
||||
func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// Delete deletes the specified settings from s.
|
||||
func (s *TestStore) Delete(keys ...setting.Key) {
|
||||
s.storeLock.Lock()
|
||||
for _, key := range keys {
|
||||
s.mu.Lock()
|
||||
delete(s.mw, key)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// Clear deletes all settings from s.
|
||||
func (s *TestStore) Clear() {
|
||||
s.storeLock.Lock()
|
||||
s.mu.Lock()
|
||||
clear(s.mw)
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
func (s *TestStore) notifyPolicyChanged() {
|
||||
s.mu.RLock()
|
||||
if s.suspendCount != 0 {
|
||||
s.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
cbs := xmaps.Values(s.cbs)
|
||||
s.mu.RUnlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(cbs))
|
||||
for _, cb := range cbs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cb()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Close closes s, notifying its users that it has expired.
|
||||
func (s *TestStore) Close() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.done != nil {
|
||||
close(s.done)
|
||||
s.done = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Done implements [Expirable].
|
||||
func (s *TestStore) Done() <-chan struct{} {
|
||||
return s.done
|
||||
}
|
Loading…
Reference in New Issue