diff --git a/util/syspolicy/internal/internal.go b/util/syspolicy/internal/internal.go index 4c3e28d39..8f2889625 100644 --- a/util/syspolicy/internal/internal.go +++ b/util/syspolicy/internal/internal.go @@ -13,6 +13,9 @@ import ( "tailscale.com/version" ) +// Init facilitates deferred invocation of initializers. +var Init lazy.DeferredInit + // OSForTesting is the operating system override used for testing. // It follows the same naming convention as [version.OS]. var OSForTesting lazy.SyncValue[string] diff --git a/util/syspolicy/rsop/change_callbacks.go b/util/syspolicy/rsop/change_callbacks.go new file mode 100644 index 000000000..b962f30c0 --- /dev/null +++ b/util/syspolicy/rsop/change_callbacks.go @@ -0,0 +1,107 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "reflect" + "slices" + "sync" + "time" + + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" +) + +// Change represents a change from the Old to the New value of type T. +type Change[T any] struct { + New, Old T +} + +// PolicyChangeCallback is a function called whenever a policy changes. +type PolicyChangeCallback func(*PolicyChange) + +// PolicyChange describes a policy change. +type PolicyChange struct { + snapshots Change[*setting.Snapshot] +} + +// New returns the [setting.Snapshot] after the change. +func (c PolicyChange) New() *setting.Snapshot { + return c.snapshots.New +} + +// Old returns the [setting.Snapshot] before the change. +func (c PolicyChange) Old() *setting.Snapshot { + return c.snapshots.Old +} + +// HasChanged reports whether a policy setting with the specified [setting.Key], has changed. +func (c PolicyChange) HasChanged(key setting.Key) bool { + new, newErr := c.snapshots.New.GetErr(key) + old, oldErr := c.snapshots.Old.GetErr(key) + if newErr != nil && oldErr != nil { + return false + } + if newErr != nil || oldErr != nil { + return true + } + switch newVal := new.(type) { + case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration: + return newVal != old + case []string: + oldVal, ok := old.([]string) + return !ok || !slices.Equal(newVal, oldVal) + default: + loggerx.Errorf("[unexpected] %q has an unsupported value type: %T", key, newVal) + return !reflect.DeepEqual(new, old) + } +} + +// policyChangeCallbacks are the callbacks to invoke when the effective policy changes. +// It is safe for concurrent use. +type policyChangeCallbacks struct { + mu sync.Mutex + cbs set.HandleSet[PolicyChangeCallback] +} + +// Register adds the specified callback to be invoked whenever the policy changes. +func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) { + c.mu.Lock() + handle := c.cbs.Add(callback) + c.mu.Unlock() + return func() { + c.mu.Lock() + delete(c.cbs, handle) + c.mu.Unlock() + } +} + +// Invoke calls the registered callback functions with the specified policy change info. +func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) { + var wg sync.WaitGroup + defer wg.Wait() + + c.mu.Lock() + defer c.mu.Unlock() + + wg.Add(len(c.cbs)) + change := &PolicyChange{snapshots: snapshots} + for _, cb := range c.cbs { + go func() { + defer wg.Done() + cb(change) + }() + } +} + +// Close awaits the completion of active callbacks and prevents any further invocations. +func (c *policyChangeCallbacks) Close() { + c.mu.Lock() + defer c.mu.Unlock() + if c.cbs != nil { + clear(c.cbs) + c.cbs = nil + } +} diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go new file mode 100644 index 000000000..019b8f602 --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy.go @@ -0,0 +1,449 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "fmt" + "slices" + "sync" + "sync/atomic" + "time" + + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" + + "tailscale.com/util/syspolicy/source" +) + +// ErrPolicyClosed is returned by [Policy.Reload], [Policy.addSource], +// [Policy.removeSource] and [Policy.replaceSource] if the policy has been closed. +var ErrPolicyClosed = errors.New("effective policy closed") + +// The minimum and maximum wait times after detecting a policy change +// before reloading the policy. This only affects policy reloads triggered +// by a change in the underlying [source.Store] and does not impact +// synchronous, caller-initiated reloads, such as when [Policy.Reload] is called. +// +// Policy changes occurring within [policyReloadMinDelay] of each other +// will be batched together, resulting in a single policy reload +// no later than [policyReloadMaxDelay] after the first detected change. +// In other words, the effective policy will be reloaded no more often than once +// every 5 seconds, but at most 15 seconds after an underlying [source.Store] +// has issued a policy change callback. +// +// See [Policy.watchReload]. +var ( + policyReloadMinDelay = 5 * time.Second + policyReloadMaxDelay = 15 * time.Second +) + +// Policy provides access to the current effective [setting.Snapshot] for a given +// scope and allows to reload it from the underlying [source.Store] list. It also allows to +// subscribe and receive a callback whenever the effective [setting.Snapshot] is changed. +// +// It is safe for concurrent use. +type Policy struct { + scope setting.PolicyScope + + reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required + closeCh chan struct{} // closed to signal that the Policy is being closed + doneCh chan struct{} // closed by [Policy.closeInternal] + + // effective is the most recent version of the [setting.Snapshot] + // containing policy settings merged from all applicable sources. + effective atomic.Pointer[setting.Snapshot] + + changeCallbacks policyChangeCallbacks + + mu sync.Mutex + watcherStarted bool // whether [Policy.watchReload] was started + sources source.ReadableSources + closing bool // whether [Policy.Close] was called (even if we're still closing) +} + +// newPolicy returns a new [Policy] for the specified [setting.PolicyScope] +// that tracks changes and merges policy settings read from the specified sources. +func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (_ *Policy, err error) { + readableSources := make(source.ReadableSources, 0, len(sources)) + defer func() { + if err != nil { + readableSources.Close() + } + }() + for _, s := range sources { + reader, err := s.Reader() + if err != nil { + return nil, fmt.Errorf("failed to get a store reader: %w", err) + } + session, err := reader.OpenSession() + if err != nil { + return nil, fmt.Errorf("failed to open a reading session: %w", err) + } + readableSources = append(readableSources, source.ReadableSource{Source: s, ReadingSession: session}) + } + + // Sort policy sources by their precedence from lower to higher. + // For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}. + readableSources.StableSort() + + p := &Policy{ + scope: scope, + sources: readableSources, + reloadCh: make(chan reloadRequest, 1), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + if _, err := p.reloadNow(false); err != nil { + p.Close() + return nil, err + } + p.startWatchReloadIfNeeded() + return p, nil +} + +// IsValid reports whether p is in a valid state and has not been closed. +// +// Since p's state can be changed by other goroutines at any time, this should +// only be used as an optimization. +func (p *Policy) IsValid() bool { + select { + case <-p.closeCh: + return false + default: + return true + } +} + +// Scope returns the [setting.PolicyScope] that this policy applies to. +func (p *Policy) Scope() setting.PolicyScope { + return p.scope +} + +// Get returns the effective [setting.Snapshot]. +func (p *Policy) Get() *setting.Snapshot { + return p.effective.Load() +} + +// RegisterChangeCallback adds a function to be called whenever the effective +// policy changes. The returned function can be used to unregister the callback. +func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) { + return p.changeCallbacks.Register(callback) +} + +// Reload synchronously re-reads policy settings from the underlying list of policy sources, +// constructing a new merged [setting.Snapshot] even if the policy remains unchanged. +// In most scenarios, there's no need to re-read the policy manually. +// Instead, it is recommended to register a policy change callback, or to use +// the most recent [setting.Snapshot] returned by the [Policy.Get] method. +// +// It must not be called with p.mu held. +func (p *Policy) Reload() (*setting.Snapshot, error) { + return p.reload(true) +} + +// reload is like Reload, but allows to specify whether to re-read policy settings +// from unchanged policy sources. +// +// It must not be called with p.mu held. +func (p *Policy) reload(force bool) (*setting.Snapshot, error) { + if !p.startWatchReloadIfNeeded() { + return p.Get(), nil + } + + respCh := make(chan reloadResponse, 1) + select { + case p.reloadCh <- reloadRequest{force: force, respCh: respCh}: + // continue + case <-p.closeCh: + return nil, ErrPolicyClosed + } + select { + case resp := <-respCh: + return resp.policy, resp.err + case <-p.closeCh: + return nil, ErrPolicyClosed + } +} + +// reloadAsync requests an asynchronous background policy reload. +// The policy will be reloaded no later than in [policyReloadMaxDelay]. +// +// It must not be called with p.mu held. +func (p *Policy) reloadAsync() { + if !p.startWatchReloadIfNeeded() { + return + } + select { + case p.reloadCh <- reloadRequest{}: + // Sent. + default: + // A reload request is already en route. + } +} + +// reloadNow loads and merges policies from all sources, updating the effective policy. +// If the force parameter is true, it forcibly reloads policies +// from the underlying policy store, even if no policy changes were detected. +// +// Except for the initial policy reload during the [Policy] creation, +// this method should only be called from the [Policy.watchReload] goroutine. +func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) { + new, err := p.readAndMerge(force) + if err != nil { + return nil, err + } + old := p.effective.Swap(new) + // A nil old value indicates the initial policy load rather than a policy change. + // Additionally, we should not invoke the policy change callbacks unless the + // policy items have actually changed. + if old != nil && !old.EqualItems(new) { + snapshots := Change[*setting.Snapshot]{New: new, Old: old} + p.changeCallbacks.Invoke(snapshots) + } + return new, nil +} + +// Done returns a channel that is closed when the [Policy] is closed. +func (p *Policy) Done() <-chan struct{} { + return p.doneCh +} + +// readAndMerge reads and merges policy settings from all applicable sources, +// returning a [setting.Snapshot] with the merged result. +// If the force parameter is true, it re-reads policy settings from each source +// even if no policy change was observed, and returns an error if the read +// operation fails. +func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) { + p.mu.Lock() + defer p.mu.Unlock() + // Start with an empty policy in the target scope. + effective := setting.NewSnapshot(nil, setting.SummaryWith(p.scope)) + // Then merge policy settings from all sources. + // Policy sources with the highest precedence (e.g., the device policy) are merged last, + // overriding any conflicting policy settings with lower precedence. + for _, s := range p.sources { + var policy *setting.Snapshot + if force { + var err error + if policy, err = s.ReadSettings(); err != nil { + return nil, err + } + } else { + policy = s.GetSettings() + } + effective = setting.MergeSnapshots(effective, policy) + } + return effective, nil +} + +// addSource adds the specified source to the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this effective policy, +// or if the effective policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) addSource(source *source.Source) error { + return p.applySourcesChange(source, nil) +} + +// removeSource removes the specified source from the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error if the +// effective policy is being closed, or if policy refresh fails with an error. +func (p *Policy) removeSource(source *source.Source) error { + return p.applySourcesChange(nil, source) +} + +// replaceSource replaces the old source with the new source atomically, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this effective policy, +// or if the effective policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) replaceSource(old, new *source.Source) error { + return p.applySourcesChange(new, old) +} + +func (p *Policy) applySourcesChange(toAdd, toRemove *source.Source) error { + if toAdd == toRemove { + return nil + } + if toAdd != nil && !toAdd.Scope().Contains(p.scope) { + return errors.New("scope mismatch") + } + + changed, err := func() (changed bool, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if toAdd != nil && !p.sources.Contains(toAdd) { + reader, err := toAdd.Reader() + if err != nil { + return false, fmt.Errorf("failed to get a store reader: %w", err) + } + session, err := reader.OpenSession() + if err != nil { + return false, fmt.Errorf("failed to open a reading session: %w", err) + } + + addAt := p.sources.InsertionIndexOf(toAdd) + toAdd := source.ReadableSource{ + Source: toAdd, + ReadingSession: session, + } + p.sources = slices.Insert(p.sources, addAt, toAdd) + go p.watchPolicyChanges(toAdd) + changed = true + } + if toRemove != nil { + if deleteAt := p.sources.IndexOf(toRemove); deleteAt != -1 { + p.sources.DeleteAt(deleteAt) + changed = true + } + } + return changed, nil + }() + if changed { + _, err = p.reload(false) + } + return err // may be nil or non-nil +} + +func (p *Policy) watchPolicyChanges(s source.ReadableSource) { + for { + select { + case _, ok := <-s.ReadingSession.PolicyChanged(): + if !ok { + p.mu.Lock() + abruptlyClosed := slices.Contains(p.sources, s) + p.mu.Unlock() + if abruptlyClosed { + // The underlying [source.Source] was closed abruptly without + // being properly removed or replaced by another policy source. + // We can't keep this [Policy] up to date, so we should close it. + p.Close() + } + return + } + // The PolicyChanged channel was signaled. + // Request an asynchronous policy reload. + p.reloadAsync() + case <-p.closeCh: + // The [Policy] is being closed. + return + } + } +} + +// startWatchReloadIfNeeded starts [Policy.watchReload] in a new goroutine +// if the list of policy sources is not empty, it hasn't been started yet, +// and the [Policy] is not being closed. +// It reports whether [Policy.watchReload] has ever been started. +// +// It must not be called with p.mu held. +func (p *Policy) startWatchReloadIfNeeded() bool { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.sources) != 0 && !p.watcherStarted && !p.closing { + go p.watchReload() + for i := range p.sources { + go p.watchPolicyChanges(p.sources[i]) + } + p.watcherStarted = true + } + return p.watcherStarted +} + +// reloadRequest describes a policy reload request. +type reloadRequest struct { + // force policy reload regardless of whether a policy change was detected. + force bool + // respCh is an optional channel. If non-nil, it makes the reload request + // synchronous and receives the result. + respCh chan<- reloadResponse +} + +// reloadResponse is a result of a synchronous policy reload. +type reloadResponse struct { + policy *setting.Snapshot + err error +} + +// watchReload processes incoming synchronous and asynchronous policy reload requests. +// +// Synchronous requests (with a non-nil respCh) are served immediately. +// +// Asynchronous requests are debounced and throttled: they are executed at least +// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay] +// after the first request in a batch. +func (p *Policy) watchReload() { + defer p.closeInternal() + + force := false // whether a forced refresh was requested + var delayCh, timeoutCh <-chan time.Time + reload := func(respCh chan<- reloadResponse) { + delayCh, timeoutCh = nil, nil + policy, err := p.reloadNow(force) + if err != nil { + loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err) + } + if respCh != nil { + respCh <- reloadResponse{policy: policy, err: err} + } + force = false + } + +loop: + for { + select { + case req := <-p.reloadCh: + if req.force { + force = true + } + if req.respCh != nil { + reload(req.respCh) + continue + } + if delayCh == nil { + timeoutCh = time.After(policyReloadMinDelay) + } + delayCh = time.After(policyReloadMaxDelay) + case <-delayCh: + reload(nil) + case <-timeoutCh: + reload(nil) + case <-p.closeCh: + break loop + } + } +} + +func (p *Policy) closeInternal() { + p.mu.Lock() + defer p.mu.Unlock() + p.sources.Close() + p.changeCallbacks.Close() + close(p.doneCh) + deletePolicy(p) +} + +// Close initiates the closing of the policy. +// The [Policy.Done] channel is closed to signal that the operation has been completed. +func (p *Policy) Close() { + p.mu.Lock() + alreadyClosing := p.closing + watcherStarted := p.watcherStarted + p.closing = true + p.mu.Unlock() + + if alreadyClosing { + return + } + + close(p.closeCh) + if !watcherStarted { + // Normally, closing p.closeCh signals [Policy.watchReload] to exit, + // and [Policy.closeInternal] performs the actual closing when + // [Policy.watchReload] returns. However, if the watcher was never + // started, we need to call [Policy.closeInternal] manually. + go p.closeInternal() + } +} diff --git a/util/syspolicy/rsop/resultant_policy_test.go b/util/syspolicy/rsop/resultant_policy_test.go new file mode 100644 index 000000000..b2408c7f7 --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy_test.go @@ -0,0 +1,986 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "slices" + "sort" + "strconv" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tstest" + "tailscale.com/util/syspolicy/setting" + + "tailscale.com/util/syspolicy/source" +) + +func TestGetEffectivePolicyNoSource(t *testing.T) { + tests := []struct { + name string + scope setting.PolicyScope + }{ + { + name: "DevicePolicy", + scope: setting.DeviceScope, + }, + { + name: "CurrentProfilePolicy", + scope: setting.CurrentProfileScope, + }, + { + name: "CurrentUserPolicy", + scope: setting.CurrentUserScope, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var policy *Policy + t.Cleanup(func() { + if policy != nil { + policy.Close() + <-policy.Done() + } + }) + + // Make sure we don't create any goroutines. + // We intentionally call ResourceCheck after t.Cleanup, so that when the test exits, + // the resource check runs before the test cleanup closes the policy. + // This helps to report any unexpectedly created goroutines. + // The goal is to ensure that using the syspolicy package, and particularly + // the rsop sub-package, is not wasteful and does not create unnecessary goroutines + // on platforms without registered policy sources. + tstest.ResourceCheck(t) + + policy, err := PolicyFor(tt.scope) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) + } + + if got := policy.Get(); got.Len() != 0 { + t.Errorf("Snapshot: got %v; want empty", got) + } + + if got, err := policy.Reload(); err != nil { + t.Errorf("Reload failed: %v", err) + } else if got.Len() != 0 { + t.Errorf("Snapshot: got %v; want empty", got) + } + }) + } +} + +func TestRegisterSourceAndGetEffectivePolicy(t *testing.T) { + type sourceConfig struct { + name string + scope setting.PolicyScope + settingKey setting.Key + settingValue string + wantEffective bool + } + tests := []struct { + name string + scope setting.PolicyScope + initialSources []sourceConfig + additionalSources []sourceConfig + wantSnapshot *setting.Snapshot + }{ + { + name: "DevicePolicy/NoSources", + scope: setting.DeviceScope, + wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope), + }, + { + name: "UserScope/NoSources", + scope: setting.CurrentUserScope, + wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/OneInitialSource", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/OneAdditionalSource", + scope: setting.DeviceScope, + additionalSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/ManyInitialSources/NoConflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/ManyInitialSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/MixedSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceD", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueD", + wantEffective: true, + }, + { + name: "TestSourceE", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueE", + wantEffective: true, + }, + { + name: "TestSourceF", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueF", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "UserScope/Init-DeviceSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)), + }, setting.CurrentUserScope), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceProfile", + scope: setting.CurrentProfileScope, + settingKey: "TestKeyB", + settingValue: "ProfileValue", + wantEffective: true, + }, + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)), + }, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/User-Source-does-not-apply", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyA", + settingValue: "UserValue", + wantEffective: false, // Registering a user source should have no impact on the device policy. + }, + }, + wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Register all settings that we use in this test. + var definitions []*setting.Definition + for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) { + definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue)) + } + if err := setting.SetDefinitionsForTest(t, definitions...); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Add the initial policy sources. + var wantSources []*source.Source + for _, s := range tt.initialSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("Failed to register policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + // Retrieve the effective policy. + policy, err := policyForTest(t, tt.scope) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) + } + + checkPolicySources(t, policy, wantSources) + + // Add additional setting sources. + for _, s := range tt.additionalSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("Failed to register additional policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + checkPolicySources(t, policy, wantSources) + + // Verify the final effective settings snapshots. + if got := policy.Get(); !got.Equal(tt.wantSnapshot) { + t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot) + } + }) + } +} + +func TestPolicyFor(t *testing.T) { + tests := []struct { + name string + scopeA, scopeB setting.PolicyScope + closePolicy bool // indicates whether to close policyA before retrieving policyB + wantSame bool // specifies whether policyA and policyB should reference the same [Policy] instance + }{ + { + name: "Device/Device", + scopeA: setting.DeviceScope, + scopeB: setting.DeviceScope, + wantSame: true, + }, + { + name: "Device/CurrentProfile", + scopeA: setting.DeviceScope, + scopeB: setting.CurrentProfileScope, + wantSame: false, + }, + { + name: "Device/CurrentUser", + scopeA: setting.DeviceScope, + scopeB: setting.CurrentUserScope, + wantSame: false, + }, + { + name: "CurrentProfile/CurrentProfile", + scopeA: setting.CurrentProfileScope, + scopeB: setting.CurrentProfileScope, + wantSame: true, + }, + { + name: "CurrentProfile/CurrentUser", + scopeA: setting.CurrentProfileScope, + scopeB: setting.CurrentUserScope, + wantSame: false, + }, + { + name: "CurrentUser/CurrentUser", + scopeA: setting.CurrentUserScope, + scopeB: setting.CurrentUserScope, + wantSame: true, + }, + { + name: "UserA/UserA", + scopeA: setting.UserScopeOf("UserA"), + scopeB: setting.UserScopeOf("UserA"), + wantSame: true, + }, + { + name: "UserA/UserB", + scopeA: setting.UserScopeOf("UserA"), + scopeB: setting.UserScopeOf("UserB"), + wantSame: false, + }, + { + name: "New-after-close", + scopeA: setting.DeviceScope, + scopeB: setting.DeviceScope, + closePolicy: true, + wantSame: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policyA, err := policyForTest(t, tt.scopeA) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeA, err) + } + + if tt.closePolicy { + policyA.Close() + } + + policyB, err := policyForTest(t, tt.scopeB) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeB, err) + } + + if gotSame := policyA == policyB; gotSame != tt.wantSame { + t.Fatalf("Got same: %v; want same %v", gotSame, tt.wantSame) + } + }) + } +} + +func TestPolicyChangeHasChanged(t *testing.T) { + tests := []struct { + name string + old, new map[setting.Key]setting.RawItem + wantChanged []setting.Key + wantUnchanged []setting.Key + }{ + { + name: "String-Settings", + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf("Old"), + "UnchangedSetting": setting.RawItemOf("Value"), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf("New"), + "UnchangedSetting": setting.RawItemOf("Value"), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + { + name: "UInt64-Settings", + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(uint64(0)), + "UnchangedSetting": setting.RawItemOf(uint64(42)), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(uint64(1)), + "UnchangedSetting": setting.RawItemOf(uint64(42)), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + { + name: "StringSlice-Settings", + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf([]string{"Chicago"}), + "UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf([]string{"New York"}), + "UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + { + name: "Int8-Settings", // We don't have actual int8 settings, but this should still work. + old: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(int8(0)), + "UnchangedSetting": setting.RawItemOf(int8(42)), + }, + new: map[setting.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(int8(1)), + "UnchangedSetting": setting.RawItemOf(int8(42)), + }, + wantChanged: []setting.Key{"ChangedSetting"}, + wantUnchanged: []setting.Key{"UnchangedSetting"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := setting.NewSnapshot(tt.old) + new := setting.NewSnapshot(tt.new) + change := PolicyChange{Change[*setting.Snapshot]{old, new}} + for _, wantChanged := range tt.wantChanged { + if !change.HasChanged(wantChanged) { + t.Errorf("%q changed: got false; want true", wantChanged) + } + } + for _, wantUnchanged := range tt.wantUnchanged { + if change.HasChanged(wantUnchanged) { + t.Errorf("%q unchanged: got true; want false", wantUnchanged) + } + } + }) + } +} + +func TestChangePolicySetting(t *testing.T) { + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + + // Register policy settings used in this test. + settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Register a test policy store and create a effective policy that reads the policy settings from it. + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // The policy setting is not configured yet. + if _, ok := policy.Get().GetSetting(settingA.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) + } + + // Subscribe to the policy change callback... + policyChanged := make(chan *PolicyChange) + unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc }) + t.Cleanup(unregister) + + // ...make the change, and measure the time between initiating the change + // and receiving the callback. + start := time.Now() + const wantValueA = "TestValueA" + store.SetStrings(source.TestSettingOf(settingA.Key(), wantValueA)) + change := <-policyChanged + gotDelay := time.Since(start) + + // Ensure there is at least a [policyReloadMinDelay] delay between + // a change and the policy reload along with the callback invocation. + // This prevents reloading policy settings too frequently + // when multiple settings change within a short period of time. + if gotDelay < policyReloadMinDelay { + t.Errorf("Delay: got %v; want >= %v", gotDelay, policyReloadMinDelay) + } + + // Verify that the [PolicyChange] passed to the policy change callback + // contains the correct information regarding the policy setting changes. + if !change.HasChanged(settingA.Key()) { + t.Errorf("Policy setting %q has not changed", settingA.Key()) + } + if change.HasChanged(settingB.Key()) { + t.Errorf("Policy setting %q was unexpectedly changed", settingB.Key()) + } + if _, ok := change.Old().GetSetting(settingA.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) + } + if gotValue := change.New().Get(settingA.Key()); gotValue != wantValueA { + t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) + } + + // And also verify that the current (most recent) [setting.Snapshot] + // includes the change we just made. + if gotValue := policy.Get().Get(settingA.Key()); gotValue != wantValueA { + t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) + } + + // Now, let's change another policy setting value N times. + const N = 10 + wantValueB := strconv.Itoa(N) + start = time.Now() + for i := range N { + store.SetStrings(source.TestSettingOf(settingB.Key(), strconv.Itoa(i+1))) + } + + // The callback should be invoked only once, even though the policy setting + // has changed N times. + change = <-policyChanged + gotDelay = time.Since(start) + gotCallbacks := 1 +drain: + for { + select { + case <-policyChanged: + gotCallbacks++ + case <-time.After(policyReloadMaxDelay): + break drain + } + } + if wantCallbacks := 1; gotCallbacks > wantCallbacks { + t.Errorf("Callbacks: got %d; want %d", gotCallbacks, wantCallbacks) + } + + // Additionally, the policy change callback should be received no sooner + // than [policyReloadMinDelay] and no later than [policyReloadMaxDelay]. + if gotDelay < policyReloadMinDelay || gotDelay > policyReloadMaxDelay { + t.Errorf("Delay: got %v; want >= %v && <= %v", gotDelay, policyReloadMinDelay, policyReloadMaxDelay) + } + + // Verify that the [PolicyChange] received via the callback + // contains the final policy setting value. + if !change.HasChanged(settingB.Key()) { + t.Errorf("Policy setting %q has not changed", settingB.Key()) + } + if change.HasChanged(settingA.Key()) { + t.Errorf("Policy setting %q was unexpectedly changed", settingA.Key()) + } + if _, ok := change.Old().GetSetting(settingB.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingB.Key()) + } + if gotValue := change.New().Get(settingB.Key()); gotValue != wantValueB { + t.Errorf("Policy setting %q: got %q; want %q", settingB.Key(), gotValue, wantValueB) + } + + // Lastly, if a policy store issues a change notification, but the effective policy + // remains unchanged, the [Policy] should ignore it without invoking the change callbacks. + store.NotifyPolicyChanged() + select { + case <-policyChanged: + t.Fatal("Unexpected policy changed notification") + case <-time.After(policyReloadMaxDelay): + } +} + +func TestClosePolicySource(t *testing.T) { + testSetting := setting.NewDefinition("TestSetting", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + wantSettingValue := "TestValue" + store := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), wantSettingValue)) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + initialSnapshot, err := policy.Reload() + if err != nil { + t.Fatalf("Failed to reload policy: %v", err) + } + if gotSettingValue, err := initialSnapshot.GetErr(testSetting.Key()); err != nil { + t.Fatalf("Failed to get %q setting value: %v", testSetting.Key(), err) + } else if gotSettingValue != wantSettingValue { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) + } + + store.Close() + + // Closing a policy source abruptly without removing it first should invalidate and close the policy. + <-policy.Done() + if policy.IsValid() { + t.Fatal("The policy was not properly closed") + } + + // The resulting policy snapshot should remain valid and unchanged. + finalSnapshot := policy.Get() + if !finalSnapshot.Equal(initialSnapshot) { + t.Fatal("Policy snapshot has changed") + } + if gotSettingValue, err := finalSnapshot.GetErr(testSetting.Key()); err != nil { + t.Fatalf("Failed to get final %q setting value: %v", testSetting.Key(), err) + } else if gotSettingValue != wantSettingValue { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) + } + + // However, any further requests to reload the policy should fail. + if _, err := policy.Reload(); err == nil || !errors.Is(err, ErrPolicyClosed) { + t.Fatalf("Reload: gotErr: %v; wantErr: %v", err, ErrPolicyClosed) + } +} + +func TestRemovePolicySource(t *testing.T) { + // Register policy settings used in this test. + settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Register two policy stores. + storeA := source.NewTestStoreOf(t, source.TestSettingOf(settingA.Key(), "A")) + storeRegA, err := RegisterStoreForTest(t, "TestSourceA", setting.DeviceScope, storeA) + if err != nil { + t.Fatalf("Failed to register policy store A: %v", err) + } + storeB := source.NewTestStoreOf(t, source.TestSettingOf(settingB.Key(), "B")) + storeRegB, err := RegisterStoreForTest(t, "TestSourceB", setting.DeviceScope, storeB) + if err != nil { + t.Fatalf("Failed to register policy store A: %v", err) + } + + // Create a effective [Policy] that reads policy settings from the two stores. + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // Verify that the [Policy] uses both stores and includes policy settings from each. + if gotSources, wantSources := len(policy.sources), 2; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got, want := policy.Get().Get(settingA.Key()), "A"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) + } + if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) + } + + // Unregister Store A and verify that the effective policy remains valid. + // It should no longer use the removed store or include any policy settings from it. + if err := storeRegA.Unregister(); err != nil { + t.Fatalf("Failed to unregister Store A: %v", err) + } + if !policy.IsValid() { + t.Fatalf("Policy was unexpectedly closed") + } + if gotSources, wantSources := len(policy.sources), 1; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got, want := policy.Get().Get(settingA.Key()), any(nil); got != want { + t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) + } + if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) + } + + // Unregister Store B and verify that the effective policy is still valid. + // However, it should be empty since there are no associated policy sources. + if err := storeRegB.Unregister(); err != nil { + t.Fatalf("Failed to unregister Store B: %v", err) + } + if !policy.IsValid() { + t.Fatalf("Policy was unexpectedly closed") + } + if gotSources, wantSources := len(policy.sources), 0; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got := policy.Get(); got.Len() != 0 { + t.Fatalf("Settings: got %v; want {Empty}", got) + } +} + +func TestReplacePolicySource(t *testing.T) { + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + + // Register policy settings used in this test. + testSetting := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Create two policy stores. + initialStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "InitialValue")) + newStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) + unchangedStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) + + // Register the initial store and create a effective [Policy] that reads policy settings from it. + reg, err := RegisterStoreForTest(t, "TestStore", setting.DeviceScope, initialStore) + if err != nil { + t.Fatalf("Failed to register the initial store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // Verify that the test setting has its initial value. + if got, want := policy.Get().Get(testSetting.Key()), "InitialValue"; got != want { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) + } + + // Subscribe to the policy change callback. + policyChanged := make(chan *PolicyChange, 1) + unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc }) + t.Cleanup(unregister) + + // Now, let's replace the initial store with the new store. + reg, err = reg.ReplaceStore(newStore) + if err != nil { + t.Fatalf("Failed to replace the policy store: %v", err) + } + t.Cleanup(func() { reg.Unregister() }) + + // We should receive a policy change notification as the setting value has changed. + <-policyChanged + + // Verify that the test setting has the new value. + if got, want := policy.Get().Get(testSetting.Key()), "NewValue"; got != want { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) + } + + // Replacing a policy store with an identical one containing the same + // values for the same settings should not be considered a policy change. + reg, err = reg.ReplaceStore(unchangedStore) + if err != nil { + t.Fatalf("Failed to replace the policy store: %v", err) + } + t.Cleanup(func() { reg.Unregister() }) + + select { + case <-policyChanged: + t.Fatal("Unexpected policy changed notification") + default: + <-time.After(policyReloadMaxDelay) + } +} + +func TestAddClosedPolicySource(t *testing.T) { + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + store.Close() + + _, err := policyForTest(t, setting.DeviceScope) + if err == nil || !errors.Is(err, source.ErrStoreClosed) { + t.Fatalf("got: %v; want: %v", err, source.ErrStoreClosed) + } +} + +func TestClosePolicyMoreThanOnce(t *testing.T) { + tests := []struct { + name string + numSources int + }{ + { + name: "NoSources", + numSources: 0, + }, + { + name: "OneSource", + numSources: 1, + }, + { + name: "ManySources", + numSources: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i := range tt.numSources { + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource #"+strconv.Itoa(i), setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + } + + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("failed to get effective policy: %v", err) + } + + const N = 10000 + var wg sync.WaitGroup + for range N { + wg.Add(1) + go func() { + wg.Done() + policy.Close() + <-policy.Done() + }() + } + wg.Wait() + }) + } +} + +func checkPolicySources(tb testing.TB, gotPolicy *Policy, wantSources []*source.Source) { + tb.Helper() + sort.SliceStable(wantSources, func(i, j int) bool { + return wantSources[i].Compare(wantSources[j]) < 0 + }) + gotSources := make([]*source.Source, len(gotPolicy.sources)) + for i := range gotPolicy.sources { + gotSources[i] = gotPolicy.sources[i].Source + } + type sourceSummary struct{ Name, Scope string } + toSourceSummary := cmp.Transformer("source", func(s *source.Source) sourceSummary { return sourceSummary{s.Name(), s.Scope().String()} }) + if diff := cmp.Diff(wantSources, gotSources, toSourceSummary, cmpopts.EquateEmpty()); diff != "" { + tb.Errorf("Policy Sources mismatch: %v", diff) + } +} + +// policyForTest is like [PolicyFor], but it deletes the policy +// when tb and all its subtests complete. +func policyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) { + tb.Helper() + + policy, err := PolicyFor(target) + if err != nil { + return nil, err + } + tb.Cleanup(func() { + policy.Close() + <-policy.Done() + deletePolicy(policy) + }) + return policy, nil +} + +func setForTest[T any](tb testing.TB, target *T, newValue T) { + oldValue := *target + tb.Cleanup(func() { *target = oldValue }) + *target = newValue +} diff --git a/util/syspolicy/rsop/rsop.go b/util/syspolicy/rsop/rsop.go new file mode 100644 index 000000000..429b9b101 --- /dev/null +++ b/util/syspolicy/rsop/rsop.go @@ -0,0 +1,174 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rsop facilitates [source.Store] registration via [RegisterStore] +// and provides access to the effective policy merged from all registered sources +// via [PolicyFor]. +package rsop + +import ( + "errors" + "fmt" + "slices" + "sync" + + "tailscale.com/syncs" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" +) + +var ( + policyMu sync.Mutex // protects [policySources] and [effectivePolicies] + policySources []*source.Source // all registered policy sources + effectivePolicies []*Policy // all active (non-closed) effective policies returned by [PolicyFor] + + // effectivePolicyLRU is an LRU cache of [Policy] by [setting.Scope]. + // Although there could be multiple [setting.PolicyScope] instances with the same [setting.Scope], + // such as two user scopes for different users, there is only one [setting.DeviceScope], only one + // [setting.CurrentProfileScope], and in most cases, only one active user scope. + // Therefore, cache misses that require falling back to [effectivePolicies] are extremely rare. + // It's a fixed-size array of atomic values and can be accessed without [policyMu] held. + effectivePolicyLRU [setting.NumScopes]syncs.AtomicValue[*Policy] +) + +// PolicyFor returns the [Policy] for the specified scope, +// creating it from the registered [source.Store]s if it doesn't already exist. +func PolicyFor(scope setting.PolicyScope) (*Policy, error) { + if err := internal.Init.Do(); err != nil { + return nil, err + } + policy := effectivePolicyLRU[scope.Kind()].Load() + if policy != nil && policy.Scope() == scope && policy.IsValid() { + return policy, nil + } + return policyForSlow(scope) +} + +func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) { + defer func() { + // Always update the LRU cache on exit if we found (or created) + // a policy for the specified scope. + if policy != nil { + effectivePolicyLRU[scope.Kind()].Store(policy) + } + }() + + policyMu.Lock() + defer policyMu.Unlock() + if policy, ok := findPolicyByScopeLocked(scope); ok { + return policy, nil + } + + // If there is no existing effective policy for the specified scope, + // we need to create one using the policy sources registered for that scope. + sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool { + return source.Scope().Contains(scope) + }) + policy, err = newPolicy(scope, sources...) + if err != nil { + return nil, err + } + effectivePolicies = append(effectivePolicies, policy) + return policy, nil +} + +// findPolicyByScopeLocked returns a policy with the specified scope and true if +// one exists in the [effectivePolicies] list, otherwise it returns nil, false. +// [policyMu] must be held. +func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) { + for _, policy := range effectivePolicies { + if policy.Scope() == target && policy.IsValid() { + return policy, true + } + } + return nil, false +} + +// deletePolicy deletes the specified effective policy from [effectivePolicies] +// and [effectivePolicyLRU]. +func deletePolicy(policy *Policy) { + policyMu.Lock() + defer policyMu.Unlock() + if i := slices.Index(effectivePolicies, policy); i != -1 { + effectivePolicies = slices.Delete(effectivePolicies, i, i+1) + } + effectivePolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil) +} + +// registerSource registers the specified [source.Source] to be used by the package. +// It updates existing [Policy]s returned by [PolicyFor] to use this source if +// they are within the source's [setting.PolicyScope]. +func registerSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + if slices.Contains(policySources, source) { + // already registered + return nil + } + policySources = append(policySources, source) + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !source.Scope().Contains(policy.Scope()) { + // Policy settings in the specified source do not apply + // to the scope of this effective policy. + // For example, a user policy source is being registered + // while the effective policy is for the device (or another user). + return nil + } + return policy.addSource(source) + }) +} + +// replaceSource is like [unregisterSource](old) followed by [registerSource](new), +// but performed atomically: the effective policy will contain settings +// either from the old source or the new source, never both and never neither. +func replaceSource(old, new *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + oldIndex := slices.Index(policySources, old) + if oldIndex == -1 { + return fmt.Errorf("the source is not registered: %v", old) + } + policySources[oldIndex] = new + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !old.Scope().Contains(policy.Scope()) || !new.Scope().Contains(policy.Scope()) { + return nil + } + return policy.replaceSource(old, new) + }) +} + +// unregisterSource unregisters the specified [source.Source], +// so that it won't be used by any new or existing [Policy]. +func unregisterSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + index := slices.Index(policySources, source) + if index == -1 { + return nil + } + policySources = slices.Delete(policySources, index, index+1) + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !source.Scope().Contains(policy.Scope()) { + return nil + } + return policy.removeSource(source) + }) +} + +// forEachEffectivePolicyLocked calls fn for every non-closed [Policy] in [effectivePolicies]. +// It accumulates the returned errors and returns an error that wraps all errors returned by fn. +// The [policyMu] mutex must be held while this function is executed. +func forEachEffectivePolicyLocked(fn func(p *Policy) error) error { + var errs []error + for _, policy := range effectivePolicies { + if policy.IsValid() { + err := fn(policy) + if err != nil && !errors.Is(err, ErrPolicyClosed) { + errs = append(errs, err) + } + } + } + return errors.Join(errs...) +} diff --git a/util/syspolicy/rsop/store_registration.go b/util/syspolicy/rsop/store_registration.go new file mode 100644 index 000000000..09c83e988 --- /dev/null +++ b/util/syspolicy/rsop/store_registration.go @@ -0,0 +1,94 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "sync" + "sync/atomic" + + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" +) + +// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore] +// or [StoreRegistration.Unregister] is called more than once. +var ErrAlreadyConsumed = errors.New("the store registration is no longer valid") + +// StoreRegistration is a [source.Store] registered for use in the specified scope. +// It can be used to unregister the store, or replace it with another one. +type StoreRegistration struct { + source *source.Source + m sync.Mutex // protects the [StoreRegistration.consumeSlow] path + consumed atomic.Bool // can be read without holding m, but must be written with m held +} + +// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope]. +func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + return newStoreRegistration(name, scope, store) +} + +// RegisterStoreForTest is like [RegisterStore], but unregisters the store when +// tb and all its subtests complete. +func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + reg, err := RegisterStore(name, scope, store) + if err == nil { + tb.Cleanup(func() { + if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) { + tb.Fatalf("Unregister failed: %v", err) + } + }) + } + return reg, err // may be nil or non-nil +} + +func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + source := source.NewSource(name, scope, store) + if err := registerSource(source); err != nil { + return nil, err + } + return &StoreRegistration{source: source}, nil +} + +// ReplaceStore replaces the registered store with the new one, +// returning a new [StoreRegistration] or an error. +func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) { + var res *StoreRegistration + err := r.consume(func() error { + newSource := source.NewSource(r.source.Name(), r.source.Scope(), new) + if err := replaceSource(r.source, newSource); err != nil { + return err + } + res = &StoreRegistration{source: newSource} + return nil + }) + return res, err +} + +// Unregister reverts the registration. +func (r *StoreRegistration) Unregister() error { + return r.consume(func() error { return unregisterSource(r.source) }) +} + +// consume invokes fn, consuming r if no error is returned. +// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call. +func (r *StoreRegistration) consume(fn func() error) (err error) { + if r.consumed.Load() { + return ErrAlreadyConsumed + } + return r.consumeSlow(fn) +} + +func (r *StoreRegistration) consumeSlow(fn func() error) (err error) { + r.m.Lock() + defer r.m.Unlock() + if r.consumed.Load() { + return ErrAlreadyConsumed + } + if err = fn(); err == nil { + r.consumed.Store(true) + } + return err // may be nil or non-nil +} diff --git a/util/syspolicy/setting/policy_scope.go b/util/syspolicy/setting/policy_scope.go index 55fa339e7..c2039fdda 100644 --- a/util/syspolicy/setting/policy_scope.go +++ b/util/syspolicy/setting/policy_scope.go @@ -8,6 +8,7 @@ import ( "strings" "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal" ) var ( @@ -35,6 +36,8 @@ type PolicyScope struct { // when querying policy settings. // It returns [DeviceScope], unless explicitly changed with [SetDefaultScope]. func DefaultScope() PolicyScope { + // Allow deferred package init functions to override the default scope. + internal.Init.Do() return lazyDefaultScope.Get(func() PolicyScope { return DeviceScope }) } diff --git a/util/syspolicy/setting/setting.go b/util/syspolicy/setting/setting.go index 93be287b1..70fb0a931 100644 --- a/util/syspolicy/setting/setting.go +++ b/util/syspolicy/setting/setting.go @@ -243,6 +243,9 @@ func registerLocked(d *Definition) { func settingDefinitions() (DefinitionMap, error) { return definitions.GetErr(func() (DefinitionMap, error) { + if err := internal.Init.Do(); err != nil { + return nil, err + } definitionsMu.Lock() defer definitionsMu.Unlock() definitionsUsed = true diff --git a/util/syspolicy/source/test_store.go b/util/syspolicy/source/test_store.go index bb8e164fb..1f19bbb43 100644 --- a/util/syspolicy/source/test_store.go +++ b/util/syspolicy/source/test_store.go @@ -89,6 +89,7 @@ type TestStore struct { suspendCount int // change callback are suspended if > 0 mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended. cbs set.HandleSet[func()] + closed bool readsMu sync.Mutex reads map[testReadOperation]int // how many times a policy setting was read @@ -98,24 +99,20 @@ type TestStore struct { // The tb will be used to report coding errors detected by the [TestStore]. func NewTestStore(tb internal.TB) *TestStore { m := make(map[setting.Key]any) - return &TestStore{ + store := &TestStore{ tb: tb, done: make(chan struct{}), mr: m, mw: m, } + tb.Cleanup(store.Close) + return store } // NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans], // [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists]. func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore { - m := make(map[setting.Key]any) - store := &TestStore{ - tb: tb, - done: make(chan struct{}), - mr: m, - mw: m, - } + store := NewTestStore(tb) switch settings := any(settings).(type) { case []TestSetting[bool]: store.SetBooleans(settings...) @@ -308,7 +305,7 @@ func (s *TestStore) Resume() { s.mr = s.mw s.mu.Unlock() s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() case s.suspendCount < 0: s.tb.Fatal("negative suspendCount") default: @@ -333,7 +330,7 @@ func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetUInt64s sets the specified integer settings in s. @@ -352,7 +349,7 @@ func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetStrings sets the specified string settings in s. @@ -371,7 +368,7 @@ func (s *TestStore) SetStrings(settings ...TestSetting[string]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetStrings sets the specified string list settings in s. @@ -390,7 +387,7 @@ func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // Delete deletes the specified settings from s. @@ -402,7 +399,7 @@ func (s *TestStore) Delete(keys ...setting.Key) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // Clear deletes all settings from s. @@ -412,10 +409,10 @@ func (s *TestStore) Clear() { clear(s.mw) s.mu.Unlock() s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } -func (s *TestStore) notifyPolicyChanged() { +func (s *TestStore) NotifyPolicyChanged() { s.mu.RLock() if s.suspendCount != 0 { s.mu.RUnlock() @@ -439,9 +436,9 @@ func (s *TestStore) notifyPolicyChanged() { func (s *TestStore) Close() { s.mu.Lock() defer s.mu.Unlock() - if s.done != nil { + if !s.closed { close(s.done) - s.done = nil + s.closed = true } }