util/syspolicy: add rsop package that provides access to the resultant policy

In this PR we add syspolicy/rsop package that facilitates policy source registration
and provides access to the resultant policy merged from all registered sources for a
given scope.

Updates #12687

Signed-off-by: Nick Khyl <nickk@tailscale.com>
pull/13727/head
Nick Khyl 1 month ago committed by Nick Khyl
parent 2aa9125ac4
commit ff5f233c3a

@ -13,6 +13,9 @@ import (
"tailscale.com/version" "tailscale.com/version"
) )
// Init facilitates deferred invocation of initializers.
var Init lazy.DeferredInit
// OSForTesting is the operating system override used for testing. // OSForTesting is the operating system override used for testing.
// It follows the same naming convention as [version.OS]. // It follows the same naming convention as [version.OS].
var OSForTesting lazy.SyncValue[string] var OSForTesting lazy.SyncValue[string]

@ -0,0 +1,107 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rsop
import (
"reflect"
"slices"
"sync"
"time"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/setting"
)
// Change represents a change from the Old to the New value of type T.
type Change[T any] struct {
New, Old T
}
// PolicyChangeCallback is a function called whenever a policy changes.
type PolicyChangeCallback func(*PolicyChange)
// PolicyChange describes a policy change.
type PolicyChange struct {
snapshots Change[*setting.Snapshot]
}
// New returns the [setting.Snapshot] after the change.
func (c PolicyChange) New() *setting.Snapshot {
return c.snapshots.New
}
// Old returns the [setting.Snapshot] before the change.
func (c PolicyChange) Old() *setting.Snapshot {
return c.snapshots.Old
}
// HasChanged reports whether a policy setting with the specified [setting.Key], has changed.
func (c PolicyChange) HasChanged(key setting.Key) bool {
new, newErr := c.snapshots.New.GetErr(key)
old, oldErr := c.snapshots.Old.GetErr(key)
if newErr != nil && oldErr != nil {
return false
}
if newErr != nil || oldErr != nil {
return true
}
switch newVal := new.(type) {
case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration:
return newVal != old
case []string:
oldVal, ok := old.([]string)
return !ok || !slices.Equal(newVal, oldVal)
default:
loggerx.Errorf("[unexpected] %q has an unsupported value type: %T", key, newVal)
return !reflect.DeepEqual(new, old)
}
}
// policyChangeCallbacks are the callbacks to invoke when the effective policy changes.
// It is safe for concurrent use.
type policyChangeCallbacks struct {
mu sync.Mutex
cbs set.HandleSet[PolicyChangeCallback]
}
// Register adds the specified callback to be invoked whenever the policy changes.
func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) {
c.mu.Lock()
handle := c.cbs.Add(callback)
c.mu.Unlock()
return func() {
c.mu.Lock()
delete(c.cbs, handle)
c.mu.Unlock()
}
}
// Invoke calls the registered callback functions with the specified policy change info.
func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) {
var wg sync.WaitGroup
defer wg.Wait()
c.mu.Lock()
defer c.mu.Unlock()
wg.Add(len(c.cbs))
change := &PolicyChange{snapshots: snapshots}
for _, cb := range c.cbs {
go func() {
defer wg.Done()
cb(change)
}()
}
}
// Close awaits the completion of active callbacks and prevents any further invocations.
func (c *policyChangeCallbacks) Close() {
c.mu.Lock()
defer c.mu.Unlock()
if c.cbs != nil {
clear(c.cbs)
c.cbs = nil
}
}

@ -0,0 +1,449 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rsop
import (
"errors"
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
// ErrPolicyClosed is returned by [Policy.Reload], [Policy.addSource],
// [Policy.removeSource] and [Policy.replaceSource] if the policy has been closed.
var ErrPolicyClosed = errors.New("effective policy closed")
// The minimum and maximum wait times after detecting a policy change
// before reloading the policy. This only affects policy reloads triggered
// by a change in the underlying [source.Store] and does not impact
// synchronous, caller-initiated reloads, such as when [Policy.Reload] is called.
//
// Policy changes occurring within [policyReloadMinDelay] of each other
// will be batched together, resulting in a single policy reload
// no later than [policyReloadMaxDelay] after the first detected change.
// In other words, the effective policy will be reloaded no more often than once
// every 5 seconds, but at most 15 seconds after an underlying [source.Store]
// has issued a policy change callback.
//
// See [Policy.watchReload].
var (
policyReloadMinDelay = 5 * time.Second
policyReloadMaxDelay = 15 * time.Second
)
// Policy provides access to the current effective [setting.Snapshot] for a given
// scope and allows to reload it from the underlying [source.Store] list. It also allows to
// subscribe and receive a callback whenever the effective [setting.Snapshot] is changed.
//
// It is safe for concurrent use.
type Policy struct {
scope setting.PolicyScope
reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required
closeCh chan struct{} // closed to signal that the Policy is being closed
doneCh chan struct{} // closed by [Policy.closeInternal]
// effective is the most recent version of the [setting.Snapshot]
// containing policy settings merged from all applicable sources.
effective atomic.Pointer[setting.Snapshot]
changeCallbacks policyChangeCallbacks
mu sync.Mutex
watcherStarted bool // whether [Policy.watchReload] was started
sources source.ReadableSources
closing bool // whether [Policy.Close] was called (even if we're still closing)
}
// newPolicy returns a new [Policy] for the specified [setting.PolicyScope]
// that tracks changes and merges policy settings read from the specified sources.
func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (_ *Policy, err error) {
readableSources := make(source.ReadableSources, 0, len(sources))
defer func() {
if err != nil {
readableSources.Close()
}
}()
for _, s := range sources {
reader, err := s.Reader()
if err != nil {
return nil, fmt.Errorf("failed to get a store reader: %w", err)
}
session, err := reader.OpenSession()
if err != nil {
return nil, fmt.Errorf("failed to open a reading session: %w", err)
}
readableSources = append(readableSources, source.ReadableSource{Source: s, ReadingSession: session})
}
// Sort policy sources by their precedence from lower to higher.
// For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}.
readableSources.StableSort()
p := &Policy{
scope: scope,
sources: readableSources,
reloadCh: make(chan reloadRequest, 1),
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
if _, err := p.reloadNow(false); err != nil {
p.Close()
return nil, err
}
p.startWatchReloadIfNeeded()
return p, nil
}
// IsValid reports whether p is in a valid state and has not been closed.
//
// Since p's state can be changed by other goroutines at any time, this should
// only be used as an optimization.
func (p *Policy) IsValid() bool {
select {
case <-p.closeCh:
return false
default:
return true
}
}
// Scope returns the [setting.PolicyScope] that this policy applies to.
func (p *Policy) Scope() setting.PolicyScope {
return p.scope
}
// Get returns the effective [setting.Snapshot].
func (p *Policy) Get() *setting.Snapshot {
return p.effective.Load()
}
// RegisterChangeCallback adds a function to be called whenever the effective
// policy changes. The returned function can be used to unregister the callback.
func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) {
return p.changeCallbacks.Register(callback)
}
// Reload synchronously re-reads policy settings from the underlying list of policy sources,
// constructing a new merged [setting.Snapshot] even if the policy remains unchanged.
// In most scenarios, there's no need to re-read the policy manually.
// Instead, it is recommended to register a policy change callback, or to use
// the most recent [setting.Snapshot] returned by the [Policy.Get] method.
//
// It must not be called with p.mu held.
func (p *Policy) Reload() (*setting.Snapshot, error) {
return p.reload(true)
}
// reload is like Reload, but allows to specify whether to re-read policy settings
// from unchanged policy sources.
//
// It must not be called with p.mu held.
func (p *Policy) reload(force bool) (*setting.Snapshot, error) {
if !p.startWatchReloadIfNeeded() {
return p.Get(), nil
}
respCh := make(chan reloadResponse, 1)
select {
case p.reloadCh <- reloadRequest{force: force, respCh: respCh}:
// continue
case <-p.closeCh:
return nil, ErrPolicyClosed
}
select {
case resp := <-respCh:
return resp.policy, resp.err
case <-p.closeCh:
return nil, ErrPolicyClosed
}
}
// reloadAsync requests an asynchronous background policy reload.
// The policy will be reloaded no later than in [policyReloadMaxDelay].
//
// It must not be called with p.mu held.
func (p *Policy) reloadAsync() {
if !p.startWatchReloadIfNeeded() {
return
}
select {
case p.reloadCh <- reloadRequest{}:
// Sent.
default:
// A reload request is already en route.
}
}
// reloadNow loads and merges policies from all sources, updating the effective policy.
// If the force parameter is true, it forcibly reloads policies
// from the underlying policy store, even if no policy changes were detected.
//
// Except for the initial policy reload during the [Policy] creation,
// this method should only be called from the [Policy.watchReload] goroutine.
func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) {
new, err := p.readAndMerge(force)
if err != nil {
return nil, err
}
old := p.effective.Swap(new)
// A nil old value indicates the initial policy load rather than a policy change.
// Additionally, we should not invoke the policy change callbacks unless the
// policy items have actually changed.
if old != nil && !old.EqualItems(new) {
snapshots := Change[*setting.Snapshot]{New: new, Old: old}
p.changeCallbacks.Invoke(snapshots)
}
return new, nil
}
// Done returns a channel that is closed when the [Policy] is closed.
func (p *Policy) Done() <-chan struct{} {
return p.doneCh
}
// readAndMerge reads and merges policy settings from all applicable sources,
// returning a [setting.Snapshot] with the merged result.
// If the force parameter is true, it re-reads policy settings from each source
// even if no policy change was observed, and returns an error if the read
// operation fails.
func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) {
p.mu.Lock()
defer p.mu.Unlock()
// Start with an empty policy in the target scope.
effective := setting.NewSnapshot(nil, setting.SummaryWith(p.scope))
// Then merge policy settings from all sources.
// Policy sources with the highest precedence (e.g., the device policy) are merged last,
// overriding any conflicting policy settings with lower precedence.
for _, s := range p.sources {
var policy *setting.Snapshot
if force {
var err error
if policy, err = s.ReadSettings(); err != nil {
return nil, err
}
} else {
policy = s.GetSettings()
}
effective = setting.MergeSnapshots(effective, policy)
}
return effective, nil
}
// addSource adds the specified source to the list of sources used by p,
// and triggers a synchronous policy refresh. It returns an error
// if the source is not a valid source for this effective policy,
// or if the effective policy is being closed,
// or if policy refresh fails with an error.
func (p *Policy) addSource(source *source.Source) error {
return p.applySourcesChange(source, nil)
}
// removeSource removes the specified source from the list of sources used by p,
// and triggers a synchronous policy refresh. It returns an error if the
// effective policy is being closed, or if policy refresh fails with an error.
func (p *Policy) removeSource(source *source.Source) error {
return p.applySourcesChange(nil, source)
}
// replaceSource replaces the old source with the new source atomically,
// and triggers a synchronous policy refresh. It returns an error
// if the source is not a valid source for this effective policy,
// or if the effective policy is being closed,
// or if policy refresh fails with an error.
func (p *Policy) replaceSource(old, new *source.Source) error {
return p.applySourcesChange(new, old)
}
func (p *Policy) applySourcesChange(toAdd, toRemove *source.Source) error {
if toAdd == toRemove {
return nil
}
if toAdd != nil && !toAdd.Scope().Contains(p.scope) {
return errors.New("scope mismatch")
}
changed, err := func() (changed bool, err error) {
p.mu.Lock()
defer p.mu.Unlock()
if toAdd != nil && !p.sources.Contains(toAdd) {
reader, err := toAdd.Reader()
if err != nil {
return false, fmt.Errorf("failed to get a store reader: %w", err)
}
session, err := reader.OpenSession()
if err != nil {
return false, fmt.Errorf("failed to open a reading session: %w", err)
}
addAt := p.sources.InsertionIndexOf(toAdd)
toAdd := source.ReadableSource{
Source: toAdd,
ReadingSession: session,
}
p.sources = slices.Insert(p.sources, addAt, toAdd)
go p.watchPolicyChanges(toAdd)
changed = true
}
if toRemove != nil {
if deleteAt := p.sources.IndexOf(toRemove); deleteAt != -1 {
p.sources.DeleteAt(deleteAt)
changed = true
}
}
return changed, nil
}()
if changed {
_, err = p.reload(false)
}
return err // may be nil or non-nil
}
func (p *Policy) watchPolicyChanges(s source.ReadableSource) {
for {
select {
case _, ok := <-s.ReadingSession.PolicyChanged():
if !ok {
p.mu.Lock()
abruptlyClosed := slices.Contains(p.sources, s)
p.mu.Unlock()
if abruptlyClosed {
// The underlying [source.Source] was closed abruptly without
// being properly removed or replaced by another policy source.
// We can't keep this [Policy] up to date, so we should close it.
p.Close()
}
return
}
// The PolicyChanged channel was signaled.
// Request an asynchronous policy reload.
p.reloadAsync()
case <-p.closeCh:
// The [Policy] is being closed.
return
}
}
}
// startWatchReloadIfNeeded starts [Policy.watchReload] in a new goroutine
// if the list of policy sources is not empty, it hasn't been started yet,
// and the [Policy] is not being closed.
// It reports whether [Policy.watchReload] has ever been started.
//
// It must not be called with p.mu held.
func (p *Policy) startWatchReloadIfNeeded() bool {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.sources) != 0 && !p.watcherStarted && !p.closing {
go p.watchReload()
for i := range p.sources {
go p.watchPolicyChanges(p.sources[i])
}
p.watcherStarted = true
}
return p.watcherStarted
}
// reloadRequest describes a policy reload request.
type reloadRequest struct {
// force policy reload regardless of whether a policy change was detected.
force bool
// respCh is an optional channel. If non-nil, it makes the reload request
// synchronous and receives the result.
respCh chan<- reloadResponse
}
// reloadResponse is a result of a synchronous policy reload.
type reloadResponse struct {
policy *setting.Snapshot
err error
}
// watchReload processes incoming synchronous and asynchronous policy reload requests.
//
// Synchronous requests (with a non-nil respCh) are served immediately.
//
// Asynchronous requests are debounced and throttled: they are executed at least
// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay]
// after the first request in a batch.
func (p *Policy) watchReload() {
defer p.closeInternal()
force := false // whether a forced refresh was requested
var delayCh, timeoutCh <-chan time.Time
reload := func(respCh chan<- reloadResponse) {
delayCh, timeoutCh = nil, nil
policy, err := p.reloadNow(force)
if err != nil {
loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err)
}
if respCh != nil {
respCh <- reloadResponse{policy: policy, err: err}
}
force = false
}
loop:
for {
select {
case req := <-p.reloadCh:
if req.force {
force = true
}
if req.respCh != nil {
reload(req.respCh)
continue
}
if delayCh == nil {
timeoutCh = time.After(policyReloadMinDelay)
}
delayCh = time.After(policyReloadMaxDelay)
case <-delayCh:
reload(nil)
case <-timeoutCh:
reload(nil)
case <-p.closeCh:
break loop
}
}
}
func (p *Policy) closeInternal() {
p.mu.Lock()
defer p.mu.Unlock()
p.sources.Close()
p.changeCallbacks.Close()
close(p.doneCh)
deletePolicy(p)
}
// Close initiates the closing of the policy.
// The [Policy.Done] channel is closed to signal that the operation has been completed.
func (p *Policy) Close() {
p.mu.Lock()
alreadyClosing := p.closing
watcherStarted := p.watcherStarted
p.closing = true
p.mu.Unlock()
if alreadyClosing {
return
}
close(p.closeCh)
if !watcherStarted {
// Normally, closing p.closeCh signals [Policy.watchReload] to exit,
// and [Policy.closeInternal] performs the actual closing when
// [Policy.watchReload] returns. However, if the watcher was never
// started, we need to call [Policy.closeInternal] manually.
go p.closeInternal()
}
}

@ -0,0 +1,986 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rsop
import (
"errors"
"slices"
"sort"
"strconv"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/tstest"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
func TestGetEffectivePolicyNoSource(t *testing.T) {
tests := []struct {
name string
scope setting.PolicyScope
}{
{
name: "DevicePolicy",
scope: setting.DeviceScope,
},
{
name: "CurrentProfilePolicy",
scope: setting.CurrentProfileScope,
},
{
name: "CurrentUserPolicy",
scope: setting.CurrentUserScope,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var policy *Policy
t.Cleanup(func() {
if policy != nil {
policy.Close()
<-policy.Done()
}
})
// Make sure we don't create any goroutines.
// We intentionally call ResourceCheck after t.Cleanup, so that when the test exits,
// the resource check runs before the test cleanup closes the policy.
// This helps to report any unexpectedly created goroutines.
// The goal is to ensure that using the syspolicy package, and particularly
// the rsop sub-package, is not wasteful and does not create unnecessary goroutines
// on platforms without registered policy sources.
tstest.ResourceCheck(t)
policy, err := PolicyFor(tt.scope)
if err != nil {
t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err)
}
if got := policy.Get(); got.Len() != 0 {
t.Errorf("Snapshot: got %v; want empty", got)
}
if got, err := policy.Reload(); err != nil {
t.Errorf("Reload failed: %v", err)
} else if got.Len() != 0 {
t.Errorf("Snapshot: got %v; want empty", got)
}
})
}
}
func TestRegisterSourceAndGetEffectivePolicy(t *testing.T) {
type sourceConfig struct {
name string
scope setting.PolicyScope
settingKey setting.Key
settingValue string
wantEffective bool
}
tests := []struct {
name string
scope setting.PolicyScope
initialSources []sourceConfig
additionalSources []sourceConfig
wantSnapshot *setting.Snapshot
}{
{
name: "DevicePolicy/NoSources",
scope: setting.DeviceScope,
wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope),
},
{
name: "UserScope/NoSources",
scope: setting.CurrentUserScope,
wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope),
},
{
name: "DevicePolicy/OneInitialSource",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
},
{
name: "DevicePolicy/OneAdditionalSource",
scope: setting.DeviceScope,
additionalSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
},
{
name: "DevicePolicy/ManyInitialSources/NoConflicts",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
{
name: "TestSourceB",
scope: setting.DeviceScope,
settingKey: "TestKeyB",
settingValue: "TestValueB",
wantEffective: true,
},
{
name: "TestSourceC",
scope: setting.DeviceScope,
settingKey: "TestKeyC",
settingValue: "TestValueC",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
"TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
}, setting.DeviceScope),
},
{
name: "DevicePolicy/ManyInitialSources/Conflicts",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
{
name: "TestSourceB",
scope: setting.DeviceScope,
settingKey: "TestKeyB",
settingValue: "TestValueB",
wantEffective: true,
},
{
name: "TestSourceC",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueC",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
}, setting.DeviceScope),
},
{
name: "DevicePolicy/MixedSources/Conflicts",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
{
name: "TestSourceB",
scope: setting.DeviceScope,
settingKey: "TestKeyB",
settingValue: "TestValueB",
wantEffective: true,
},
{
name: "TestSourceC",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueC",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceD",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueD",
wantEffective: true,
},
{
name: "TestSourceE",
scope: setting.DeviceScope,
settingKey: "TestKeyC",
settingValue: "TestValueE",
wantEffective: true,
},
{
name: "TestSourceF",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueF",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
"TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)),
}, setting.DeviceScope),
},
{
name: "UserScope/Init-DeviceSource",
scope: setting.CurrentUserScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
}, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
},
{
name: "UserScope/Init-DeviceSource/Add-UserSource",
scope: setting.CurrentUserScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceUser",
scope: setting.CurrentUserScope,
settingKey: "TestKeyB",
settingValue: "UserValue",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)),
}, setting.CurrentUserScope),
},
{
name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource",
scope: setting.CurrentUserScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceProfile",
scope: setting.CurrentProfileScope,
settingKey: "TestKeyB",
settingValue: "ProfileValue",
wantEffective: true,
},
{
name: "TestSourceUser",
scope: setting.CurrentUserScope,
settingKey: "TestKeyB",
settingValue: "UserValue",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)),
}, setting.CurrentUserScope),
},
{
name: "DevicePolicy/User-Source-does-not-apply",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceUser",
scope: setting.CurrentUserScope,
settingKey: "TestKeyA",
settingValue: "UserValue",
wantEffective: false, // Registering a user source should have no impact on the device policy.
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
}, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Register all settings that we use in this test.
var definitions []*setting.Definition
for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) {
definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue))
}
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
// Add the initial policy sources.
var wantSources []*source.Source
for _, s := range tt.initialSources {
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
source := source.NewSource(s.name, s.scope, store)
if err := registerSource(source); err != nil {
t.Fatalf("Failed to register policy source: %v", source)
}
if s.wantEffective {
wantSources = append(wantSources, source)
}
t.Cleanup(func() { unregisterSource(source) })
}
// Retrieve the effective policy.
policy, err := policyForTest(t, tt.scope)
if err != nil {
t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err)
}
checkPolicySources(t, policy, wantSources)
// Add additional setting sources.
for _, s := range tt.additionalSources {
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
source := source.NewSource(s.name, s.scope, store)
if err := registerSource(source); err != nil {
t.Fatalf("Failed to register additional policy source: %v", source)
}
if s.wantEffective {
wantSources = append(wantSources, source)
}
t.Cleanup(func() { unregisterSource(source) })
}
checkPolicySources(t, policy, wantSources)
// Verify the final effective settings snapshots.
if got := policy.Get(); !got.Equal(tt.wantSnapshot) {
t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot)
}
})
}
}
func TestPolicyFor(t *testing.T) {
tests := []struct {
name string
scopeA, scopeB setting.PolicyScope
closePolicy bool // indicates whether to close policyA before retrieving policyB
wantSame bool // specifies whether policyA and policyB should reference the same [Policy] instance
}{
{
name: "Device/Device",
scopeA: setting.DeviceScope,
scopeB: setting.DeviceScope,
wantSame: true,
},
{
name: "Device/CurrentProfile",
scopeA: setting.DeviceScope,
scopeB: setting.CurrentProfileScope,
wantSame: false,
},
{
name: "Device/CurrentUser",
scopeA: setting.DeviceScope,
scopeB: setting.CurrentUserScope,
wantSame: false,
},
{
name: "CurrentProfile/CurrentProfile",
scopeA: setting.CurrentProfileScope,
scopeB: setting.CurrentProfileScope,
wantSame: true,
},
{
name: "CurrentProfile/CurrentUser",
scopeA: setting.CurrentProfileScope,
scopeB: setting.CurrentUserScope,
wantSame: false,
},
{
name: "CurrentUser/CurrentUser",
scopeA: setting.CurrentUserScope,
scopeB: setting.CurrentUserScope,
wantSame: true,
},
{
name: "UserA/UserA",
scopeA: setting.UserScopeOf("UserA"),
scopeB: setting.UserScopeOf("UserA"),
wantSame: true,
},
{
name: "UserA/UserB",
scopeA: setting.UserScopeOf("UserA"),
scopeB: setting.UserScopeOf("UserB"),
wantSame: false,
},
{
name: "New-after-close",
scopeA: setting.DeviceScope,
scopeB: setting.DeviceScope,
closePolicy: true,
wantSame: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policyA, err := policyForTest(t, tt.scopeA)
if err != nil {
t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeA, err)
}
if tt.closePolicy {
policyA.Close()
}
policyB, err := policyForTest(t, tt.scopeB)
if err != nil {
t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeB, err)
}
if gotSame := policyA == policyB; gotSame != tt.wantSame {
t.Fatalf("Got same: %v; want same %v", gotSame, tt.wantSame)
}
})
}
}
func TestPolicyChangeHasChanged(t *testing.T) {
tests := []struct {
name string
old, new map[setting.Key]setting.RawItem
wantChanged []setting.Key
wantUnchanged []setting.Key
}{
{
name: "String-Settings",
old: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf("Old"),
"UnchangedSetting": setting.RawItemOf("Value"),
},
new: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf("New"),
"UnchangedSetting": setting.RawItemOf("Value"),
},
wantChanged: []setting.Key{"ChangedSetting"},
wantUnchanged: []setting.Key{"UnchangedSetting"},
},
{
name: "UInt64-Settings",
old: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf(uint64(0)),
"UnchangedSetting": setting.RawItemOf(uint64(42)),
},
new: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf(uint64(1)),
"UnchangedSetting": setting.RawItemOf(uint64(42)),
},
wantChanged: []setting.Key{"ChangedSetting"},
wantUnchanged: []setting.Key{"UnchangedSetting"},
},
{
name: "StringSlice-Settings",
old: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf([]string{"Chicago"}),
"UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}),
},
new: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf([]string{"New York"}),
"UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}),
},
wantChanged: []setting.Key{"ChangedSetting"},
wantUnchanged: []setting.Key{"UnchangedSetting"},
},
{
name: "Int8-Settings", // We don't have actual int8 settings, but this should still work.
old: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf(int8(0)),
"UnchangedSetting": setting.RawItemOf(int8(42)),
},
new: map[setting.Key]setting.RawItem{
"ChangedSetting": setting.RawItemOf(int8(1)),
"UnchangedSetting": setting.RawItemOf(int8(42)),
},
wantChanged: []setting.Key{"ChangedSetting"},
wantUnchanged: []setting.Key{"UnchangedSetting"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
old := setting.NewSnapshot(tt.old)
new := setting.NewSnapshot(tt.new)
change := PolicyChange{Change[*setting.Snapshot]{old, new}}
for _, wantChanged := range tt.wantChanged {
if !change.HasChanged(wantChanged) {
t.Errorf("%q changed: got false; want true", wantChanged)
}
}
for _, wantUnchanged := range tt.wantUnchanged {
if change.HasChanged(wantUnchanged) {
t.Errorf("%q unchanged: got true; want false", wantUnchanged)
}
}
})
}
}
func TestChangePolicySetting(t *testing.T) {
setForTest(t, &policyReloadMinDelay, 100*time.Millisecond)
setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond)
// Register policy settings used in this test.
settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue)
settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue)
if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
// Register a test policy store and create a effective policy that reads the policy settings from it.
store := source.NewTestStoreOf[string](t)
if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil {
t.Fatalf("Failed to register policy store: %v", err)
}
policy, err := policyForTest(t, setting.DeviceScope)
if err != nil {
t.Fatalf("Failed to get effective policy: %v", err)
}
// The policy setting is not configured yet.
if _, ok := policy.Get().GetSetting(settingA.Key()); ok {
t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key())
}
// Subscribe to the policy change callback...
policyChanged := make(chan *PolicyChange)
unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc })
t.Cleanup(unregister)
// ...make the change, and measure the time between initiating the change
// and receiving the callback.
start := time.Now()
const wantValueA = "TestValueA"
store.SetStrings(source.TestSettingOf(settingA.Key(), wantValueA))
change := <-policyChanged
gotDelay := time.Since(start)
// Ensure there is at least a [policyReloadMinDelay] delay between
// a change and the policy reload along with the callback invocation.
// This prevents reloading policy settings too frequently
// when multiple settings change within a short period of time.
if gotDelay < policyReloadMinDelay {
t.Errorf("Delay: got %v; want >= %v", gotDelay, policyReloadMinDelay)
}
// Verify that the [PolicyChange] passed to the policy change callback
// contains the correct information regarding the policy setting changes.
if !change.HasChanged(settingA.Key()) {
t.Errorf("Policy setting %q has not changed", settingA.Key())
}
if change.HasChanged(settingB.Key()) {
t.Errorf("Policy setting %q was unexpectedly changed", settingB.Key())
}
if _, ok := change.Old().GetSetting(settingA.Key()); ok {
t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key())
}
if gotValue := change.New().Get(settingA.Key()); gotValue != wantValueA {
t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA)
}
// And also verify that the current (most recent) [setting.Snapshot]
// includes the change we just made.
if gotValue := policy.Get().Get(settingA.Key()); gotValue != wantValueA {
t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA)
}
// Now, let's change another policy setting value N times.
const N = 10
wantValueB := strconv.Itoa(N)
start = time.Now()
for i := range N {
store.SetStrings(source.TestSettingOf(settingB.Key(), strconv.Itoa(i+1)))
}
// The callback should be invoked only once, even though the policy setting
// has changed N times.
change = <-policyChanged
gotDelay = time.Since(start)
gotCallbacks := 1
drain:
for {
select {
case <-policyChanged:
gotCallbacks++
case <-time.After(policyReloadMaxDelay):
break drain
}
}
if wantCallbacks := 1; gotCallbacks > wantCallbacks {
t.Errorf("Callbacks: got %d; want %d", gotCallbacks, wantCallbacks)
}
// Additionally, the policy change callback should be received no sooner
// than [policyReloadMinDelay] and no later than [policyReloadMaxDelay].
if gotDelay < policyReloadMinDelay || gotDelay > policyReloadMaxDelay {
t.Errorf("Delay: got %v; want >= %v && <= %v", gotDelay, policyReloadMinDelay, policyReloadMaxDelay)
}
// Verify that the [PolicyChange] received via the callback
// contains the final policy setting value.
if !change.HasChanged(settingB.Key()) {
t.Errorf("Policy setting %q has not changed", settingB.Key())
}
if change.HasChanged(settingA.Key()) {
t.Errorf("Policy setting %q was unexpectedly changed", settingA.Key())
}
if _, ok := change.Old().GetSetting(settingB.Key()); ok {
t.Fatalf("Policy setting %q unexpectedly exists", settingB.Key())
}
if gotValue := change.New().Get(settingB.Key()); gotValue != wantValueB {
t.Errorf("Policy setting %q: got %q; want %q", settingB.Key(), gotValue, wantValueB)
}
// Lastly, if a policy store issues a change notification, but the effective policy
// remains unchanged, the [Policy] should ignore it without invoking the change callbacks.
store.NotifyPolicyChanged()
select {
case <-policyChanged:
t.Fatal("Unexpected policy changed notification")
case <-time.After(policyReloadMaxDelay):
}
}
func TestClosePolicySource(t *testing.T) {
testSetting := setting.NewDefinition("TestSetting", setting.DeviceSetting, setting.StringValue)
if err := setting.SetDefinitionsForTest(t, testSetting); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
wantSettingValue := "TestValue"
store := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), wantSettingValue))
if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil {
t.Fatalf("Failed to register policy store: %v", err)
}
policy, err := policyForTest(t, setting.DeviceScope)
if err != nil {
t.Fatalf("Failed to get effective policy: %v", err)
}
initialSnapshot, err := policy.Reload()
if err != nil {
t.Fatalf("Failed to reload policy: %v", err)
}
if gotSettingValue, err := initialSnapshot.GetErr(testSetting.Key()); err != nil {
t.Fatalf("Failed to get %q setting value: %v", testSetting.Key(), err)
} else if gotSettingValue != wantSettingValue {
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue)
}
store.Close()
// Closing a policy source abruptly without removing it first should invalidate and close the policy.
<-policy.Done()
if policy.IsValid() {
t.Fatal("The policy was not properly closed")
}
// The resulting policy snapshot should remain valid and unchanged.
finalSnapshot := policy.Get()
if !finalSnapshot.Equal(initialSnapshot) {
t.Fatal("Policy snapshot has changed")
}
if gotSettingValue, err := finalSnapshot.GetErr(testSetting.Key()); err != nil {
t.Fatalf("Failed to get final %q setting value: %v", testSetting.Key(), err)
} else if gotSettingValue != wantSettingValue {
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue)
}
// However, any further requests to reload the policy should fail.
if _, err := policy.Reload(); err == nil || !errors.Is(err, ErrPolicyClosed) {
t.Fatalf("Reload: gotErr: %v; wantErr: %v", err, ErrPolicyClosed)
}
}
func TestRemovePolicySource(t *testing.T) {
// Register policy settings used in this test.
settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue)
settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue)
if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
// Register two policy stores.
storeA := source.NewTestStoreOf(t, source.TestSettingOf(settingA.Key(), "A"))
storeRegA, err := RegisterStoreForTest(t, "TestSourceA", setting.DeviceScope, storeA)
if err != nil {
t.Fatalf("Failed to register policy store A: %v", err)
}
storeB := source.NewTestStoreOf(t, source.TestSettingOf(settingB.Key(), "B"))
storeRegB, err := RegisterStoreForTest(t, "TestSourceB", setting.DeviceScope, storeB)
if err != nil {
t.Fatalf("Failed to register policy store A: %v", err)
}
// Create a effective [Policy] that reads policy settings from the two stores.
policy, err := policyForTest(t, setting.DeviceScope)
if err != nil {
t.Fatalf("Failed to get effective policy: %v", err)
}
// Verify that the [Policy] uses both stores and includes policy settings from each.
if gotSources, wantSources := len(policy.sources), 2; gotSources != wantSources {
t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources)
}
if got, want := policy.Get().Get(settingA.Key()), "A"; got != want {
t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want)
}
if got, want := policy.Get().Get(settingB.Key()), "B"; got != want {
t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want)
}
// Unregister Store A and verify that the effective policy remains valid.
// It should no longer use the removed store or include any policy settings from it.
if err := storeRegA.Unregister(); err != nil {
t.Fatalf("Failed to unregister Store A: %v", err)
}
if !policy.IsValid() {
t.Fatalf("Policy was unexpectedly closed")
}
if gotSources, wantSources := len(policy.sources), 1; gotSources != wantSources {
t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources)
}
if got, want := policy.Get().Get(settingA.Key()), any(nil); got != want {
t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want)
}
if got, want := policy.Get().Get(settingB.Key()), "B"; got != want {
t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want)
}
// Unregister Store B and verify that the effective policy is still valid.
// However, it should be empty since there are no associated policy sources.
if err := storeRegB.Unregister(); err != nil {
t.Fatalf("Failed to unregister Store B: %v", err)
}
if !policy.IsValid() {
t.Fatalf("Policy was unexpectedly closed")
}
if gotSources, wantSources := len(policy.sources), 0; gotSources != wantSources {
t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources)
}
if got := policy.Get(); got.Len() != 0 {
t.Fatalf("Settings: got %v; want {Empty}", got)
}
}
func TestReplacePolicySource(t *testing.T) {
setForTest(t, &policyReloadMinDelay, 100*time.Millisecond)
setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond)
// Register policy settings used in this test.
testSetting := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue)
if err := setting.SetDefinitionsForTest(t, testSetting); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
// Create two policy stores.
initialStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "InitialValue"))
newStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue"))
unchangedStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue"))
// Register the initial store and create a effective [Policy] that reads policy settings from it.
reg, err := RegisterStoreForTest(t, "TestStore", setting.DeviceScope, initialStore)
if err != nil {
t.Fatalf("Failed to register the initial store: %v", err)
}
policy, err := policyForTest(t, setting.DeviceScope)
if err != nil {
t.Fatalf("Failed to get effective policy: %v", err)
}
// Verify that the test setting has its initial value.
if got, want := policy.Get().Get(testSetting.Key()), "InitialValue"; got != want {
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want)
}
// Subscribe to the policy change callback.
policyChanged := make(chan *PolicyChange, 1)
unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc })
t.Cleanup(unregister)
// Now, let's replace the initial store with the new store.
reg, err = reg.ReplaceStore(newStore)
if err != nil {
t.Fatalf("Failed to replace the policy store: %v", err)
}
t.Cleanup(func() { reg.Unregister() })
// We should receive a policy change notification as the setting value has changed.
<-policyChanged
// Verify that the test setting has the new value.
if got, want := policy.Get().Get(testSetting.Key()), "NewValue"; got != want {
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want)
}
// Replacing a policy store with an identical one containing the same
// values for the same settings should not be considered a policy change.
reg, err = reg.ReplaceStore(unchangedStore)
if err != nil {
t.Fatalf("Failed to replace the policy store: %v", err)
}
t.Cleanup(func() { reg.Unregister() })
select {
case <-policyChanged:
t.Fatal("Unexpected policy changed notification")
default:
<-time.After(policyReloadMaxDelay)
}
}
func TestAddClosedPolicySource(t *testing.T) {
store := source.NewTestStoreOf[string](t)
if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil {
t.Fatalf("Failed to register policy store: %v", err)
}
store.Close()
_, err := policyForTest(t, setting.DeviceScope)
if err == nil || !errors.Is(err, source.ErrStoreClosed) {
t.Fatalf("got: %v; want: %v", err, source.ErrStoreClosed)
}
}
func TestClosePolicyMoreThanOnce(t *testing.T) {
tests := []struct {
name string
numSources int
}{
{
name: "NoSources",
numSources: 0,
},
{
name: "OneSource",
numSources: 1,
},
{
name: "ManySources",
numSources: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for i := range tt.numSources {
store := source.NewTestStoreOf[string](t)
if _, err := RegisterStoreForTest(t, "TestSource #"+strconv.Itoa(i), setting.DeviceScope, store); err != nil {
t.Fatalf("Failed to register policy store: %v", err)
}
}
policy, err := policyForTest(t, setting.DeviceScope)
if err != nil {
t.Fatalf("failed to get effective policy: %v", err)
}
const N = 10000
var wg sync.WaitGroup
for range N {
wg.Add(1)
go func() {
wg.Done()
policy.Close()
<-policy.Done()
}()
}
wg.Wait()
})
}
}
func checkPolicySources(tb testing.TB, gotPolicy *Policy, wantSources []*source.Source) {
tb.Helper()
sort.SliceStable(wantSources, func(i, j int) bool {
return wantSources[i].Compare(wantSources[j]) < 0
})
gotSources := make([]*source.Source, len(gotPolicy.sources))
for i := range gotPolicy.sources {
gotSources[i] = gotPolicy.sources[i].Source
}
type sourceSummary struct{ Name, Scope string }
toSourceSummary := cmp.Transformer("source", func(s *source.Source) sourceSummary { return sourceSummary{s.Name(), s.Scope().String()} })
if diff := cmp.Diff(wantSources, gotSources, toSourceSummary, cmpopts.EquateEmpty()); diff != "" {
tb.Errorf("Policy Sources mismatch: %v", diff)
}
}
// policyForTest is like [PolicyFor], but it deletes the policy
// when tb and all its subtests complete.
func policyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) {
tb.Helper()
policy, err := PolicyFor(target)
if err != nil {
return nil, err
}
tb.Cleanup(func() {
policy.Close()
<-policy.Done()
deletePolicy(policy)
})
return policy, nil
}
func setForTest[T any](tb testing.TB, target *T, newValue T) {
oldValue := *target
tb.Cleanup(func() { *target = oldValue })
*target = newValue
}

@ -0,0 +1,174 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package rsop facilitates [source.Store] registration via [RegisterStore]
// and provides access to the effective policy merged from all registered sources
// via [PolicyFor].
package rsop
import (
"errors"
"fmt"
"slices"
"sync"
"tailscale.com/syncs"
"tailscale.com/util/slicesx"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
var (
policyMu sync.Mutex // protects [policySources] and [effectivePolicies]
policySources []*source.Source // all registered policy sources
effectivePolicies []*Policy // all active (non-closed) effective policies returned by [PolicyFor]
// effectivePolicyLRU is an LRU cache of [Policy] by [setting.Scope].
// Although there could be multiple [setting.PolicyScope] instances with the same [setting.Scope],
// such as two user scopes for different users, there is only one [setting.DeviceScope], only one
// [setting.CurrentProfileScope], and in most cases, only one active user scope.
// Therefore, cache misses that require falling back to [effectivePolicies] are extremely rare.
// It's a fixed-size array of atomic values and can be accessed without [policyMu] held.
effectivePolicyLRU [setting.NumScopes]syncs.AtomicValue[*Policy]
)
// PolicyFor returns the [Policy] for the specified scope,
// creating it from the registered [source.Store]s if it doesn't already exist.
func PolicyFor(scope setting.PolicyScope) (*Policy, error) {
if err := internal.Init.Do(); err != nil {
return nil, err
}
policy := effectivePolicyLRU[scope.Kind()].Load()
if policy != nil && policy.Scope() == scope && policy.IsValid() {
return policy, nil
}
return policyForSlow(scope)
}
func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) {
defer func() {
// Always update the LRU cache on exit if we found (or created)
// a policy for the specified scope.
if policy != nil {
effectivePolicyLRU[scope.Kind()].Store(policy)
}
}()
policyMu.Lock()
defer policyMu.Unlock()
if policy, ok := findPolicyByScopeLocked(scope); ok {
return policy, nil
}
// If there is no existing effective policy for the specified scope,
// we need to create one using the policy sources registered for that scope.
sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool {
return source.Scope().Contains(scope)
})
policy, err = newPolicy(scope, sources...)
if err != nil {
return nil, err
}
effectivePolicies = append(effectivePolicies, policy)
return policy, nil
}
// findPolicyByScopeLocked returns a policy with the specified scope and true if
// one exists in the [effectivePolicies] list, otherwise it returns nil, false.
// [policyMu] must be held.
func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) {
for _, policy := range effectivePolicies {
if policy.Scope() == target && policy.IsValid() {
return policy, true
}
}
return nil, false
}
// deletePolicy deletes the specified effective policy from [effectivePolicies]
// and [effectivePolicyLRU].
func deletePolicy(policy *Policy) {
policyMu.Lock()
defer policyMu.Unlock()
if i := slices.Index(effectivePolicies, policy); i != -1 {
effectivePolicies = slices.Delete(effectivePolicies, i, i+1)
}
effectivePolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil)
}
// registerSource registers the specified [source.Source] to be used by the package.
// It updates existing [Policy]s returned by [PolicyFor] to use this source if
// they are within the source's [setting.PolicyScope].
func registerSource(source *source.Source) error {
policyMu.Lock()
defer policyMu.Unlock()
if slices.Contains(policySources, source) {
// already registered
return nil
}
policySources = append(policySources, source)
return forEachEffectivePolicyLocked(func(policy *Policy) error {
if !source.Scope().Contains(policy.Scope()) {
// Policy settings in the specified source do not apply
// to the scope of this effective policy.
// For example, a user policy source is being registered
// while the effective policy is for the device (or another user).
return nil
}
return policy.addSource(source)
})
}
// replaceSource is like [unregisterSource](old) followed by [registerSource](new),
// but performed atomically: the effective policy will contain settings
// either from the old source or the new source, never both and never neither.
func replaceSource(old, new *source.Source) error {
policyMu.Lock()
defer policyMu.Unlock()
oldIndex := slices.Index(policySources, old)
if oldIndex == -1 {
return fmt.Errorf("the source is not registered: %v", old)
}
policySources[oldIndex] = new
return forEachEffectivePolicyLocked(func(policy *Policy) error {
if !old.Scope().Contains(policy.Scope()) || !new.Scope().Contains(policy.Scope()) {
return nil
}
return policy.replaceSource(old, new)
})
}
// unregisterSource unregisters the specified [source.Source],
// so that it won't be used by any new or existing [Policy].
func unregisterSource(source *source.Source) error {
policyMu.Lock()
defer policyMu.Unlock()
index := slices.Index(policySources, source)
if index == -1 {
return nil
}
policySources = slices.Delete(policySources, index, index+1)
return forEachEffectivePolicyLocked(func(policy *Policy) error {
if !source.Scope().Contains(policy.Scope()) {
return nil
}
return policy.removeSource(source)
})
}
// forEachEffectivePolicyLocked calls fn for every non-closed [Policy] in [effectivePolicies].
// It accumulates the returned errors and returns an error that wraps all errors returned by fn.
// The [policyMu] mutex must be held while this function is executed.
func forEachEffectivePolicyLocked(fn func(p *Policy) error) error {
var errs []error
for _, policy := range effectivePolicies {
if policy.IsValid() {
err := fn(policy)
if err != nil && !errors.Is(err, ErrPolicyClosed) {
errs = append(errs, err)
}
}
}
return errors.Join(errs...)
}

@ -0,0 +1,94 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rsop
import (
"errors"
"sync"
"sync/atomic"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore]
// or [StoreRegistration.Unregister] is called more than once.
var ErrAlreadyConsumed = errors.New("the store registration is no longer valid")
// StoreRegistration is a [source.Store] registered for use in the specified scope.
// It can be used to unregister the store, or replace it with another one.
type StoreRegistration struct {
source *source.Source
m sync.Mutex // protects the [StoreRegistration.consumeSlow] path
consumed atomic.Bool // can be read without holding m, but must be written with m held
}
// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
return newStoreRegistration(name, scope, store)
}
// RegisterStoreForTest is like [RegisterStore], but unregisters the store when
// tb and all its subtests complete.
func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
reg, err := RegisterStore(name, scope, store)
if err == nil {
tb.Cleanup(func() {
if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) {
tb.Fatalf("Unregister failed: %v", err)
}
})
}
return reg, err // may be nil or non-nil
}
func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
source := source.NewSource(name, scope, store)
if err := registerSource(source); err != nil {
return nil, err
}
return &StoreRegistration{source: source}, nil
}
// ReplaceStore replaces the registered store with the new one,
// returning a new [StoreRegistration] or an error.
func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) {
var res *StoreRegistration
err := r.consume(func() error {
newSource := source.NewSource(r.source.Name(), r.source.Scope(), new)
if err := replaceSource(r.source, newSource); err != nil {
return err
}
res = &StoreRegistration{source: newSource}
return nil
})
return res, err
}
// Unregister reverts the registration.
func (r *StoreRegistration) Unregister() error {
return r.consume(func() error { return unregisterSource(r.source) })
}
// consume invokes fn, consuming r if no error is returned.
// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call.
func (r *StoreRegistration) consume(fn func() error) (err error) {
if r.consumed.Load() {
return ErrAlreadyConsumed
}
return r.consumeSlow(fn)
}
func (r *StoreRegistration) consumeSlow(fn func() error) (err error) {
r.m.Lock()
defer r.m.Unlock()
if r.consumed.Load() {
return ErrAlreadyConsumed
}
if err = fn(); err == nil {
r.consumed.Store(true)
}
return err // may be nil or non-nil
}

@ -8,6 +8,7 @@ import (
"strings" "strings"
"tailscale.com/types/lazy" "tailscale.com/types/lazy"
"tailscale.com/util/syspolicy/internal"
) )
var ( var (
@ -35,6 +36,8 @@ type PolicyScope struct {
// when querying policy settings. // when querying policy settings.
// It returns [DeviceScope], unless explicitly changed with [SetDefaultScope]. // It returns [DeviceScope], unless explicitly changed with [SetDefaultScope].
func DefaultScope() PolicyScope { func DefaultScope() PolicyScope {
// Allow deferred package init functions to override the default scope.
internal.Init.Do()
return lazyDefaultScope.Get(func() PolicyScope { return DeviceScope }) return lazyDefaultScope.Get(func() PolicyScope { return DeviceScope })
} }

@ -243,6 +243,9 @@ func registerLocked(d *Definition) {
func settingDefinitions() (DefinitionMap, error) { func settingDefinitions() (DefinitionMap, error) {
return definitions.GetErr(func() (DefinitionMap, error) { return definitions.GetErr(func() (DefinitionMap, error) {
if err := internal.Init.Do(); err != nil {
return nil, err
}
definitionsMu.Lock() definitionsMu.Lock()
defer definitionsMu.Unlock() defer definitionsMu.Unlock()
definitionsUsed = true definitionsUsed = true

@ -89,6 +89,7 @@ type TestStore struct {
suspendCount int // change callback are suspended if > 0 suspendCount int // change callback are suspended if > 0
mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended. mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
cbs set.HandleSet[func()] cbs set.HandleSet[func()]
closed bool
readsMu sync.Mutex readsMu sync.Mutex
reads map[testReadOperation]int // how many times a policy setting was read reads map[testReadOperation]int // how many times a policy setting was read
@ -98,24 +99,20 @@ type TestStore struct {
// The tb will be used to report coding errors detected by the [TestStore]. // The tb will be used to report coding errors detected by the [TestStore].
func NewTestStore(tb internal.TB) *TestStore { func NewTestStore(tb internal.TB) *TestStore {
m := make(map[setting.Key]any) m := make(map[setting.Key]any)
return &TestStore{ store := &TestStore{
tb: tb, tb: tb,
done: make(chan struct{}), done: make(chan struct{}),
mr: m, mr: m,
mw: m, mw: m,
} }
tb.Cleanup(store.Close)
return store
} }
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans], // NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists]. // [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore { func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
m := make(map[setting.Key]any) store := NewTestStore(tb)
store := &TestStore{
tb: tb,
done: make(chan struct{}),
mr: m,
mw: m,
}
switch settings := any(settings).(type) { switch settings := any(settings).(type) {
case []TestSetting[bool]: case []TestSetting[bool]:
store.SetBooleans(settings...) store.SetBooleans(settings...)
@ -308,7 +305,7 @@ func (s *TestStore) Resume() {
s.mr = s.mw s.mr = s.mw
s.mu.Unlock() s.mu.Unlock()
s.storeLock.Unlock() s.storeLock.Unlock()
s.notifyPolicyChanged() s.NotifyPolicyChanged()
case s.suspendCount < 0: case s.suspendCount < 0:
s.tb.Fatal("negative suspendCount") s.tb.Fatal("negative suspendCount")
default: default:
@ -333,7 +330,7 @@ func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
s.mu.Unlock() s.mu.Unlock()
} }
s.storeLock.Unlock() s.storeLock.Unlock()
s.notifyPolicyChanged() s.NotifyPolicyChanged()
} }
// SetUInt64s sets the specified integer settings in s. // SetUInt64s sets the specified integer settings in s.
@ -352,7 +349,7 @@ func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
s.mu.Unlock() s.mu.Unlock()
} }
s.storeLock.Unlock() s.storeLock.Unlock()
s.notifyPolicyChanged() s.NotifyPolicyChanged()
} }
// SetStrings sets the specified string settings in s. // SetStrings sets the specified string settings in s.
@ -371,7 +368,7 @@ func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
s.mu.Unlock() s.mu.Unlock()
} }
s.storeLock.Unlock() s.storeLock.Unlock()
s.notifyPolicyChanged() s.NotifyPolicyChanged()
} }
// SetStrings sets the specified string list settings in s. // SetStrings sets the specified string list settings in s.
@ -390,7 +387,7 @@ func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
s.mu.Unlock() s.mu.Unlock()
} }
s.storeLock.Unlock() s.storeLock.Unlock()
s.notifyPolicyChanged() s.NotifyPolicyChanged()
} }
// Delete deletes the specified settings from s. // Delete deletes the specified settings from s.
@ -402,7 +399,7 @@ func (s *TestStore) Delete(keys ...setting.Key) {
s.mu.Unlock() s.mu.Unlock()
} }
s.storeLock.Unlock() s.storeLock.Unlock()
s.notifyPolicyChanged() s.NotifyPolicyChanged()
} }
// Clear deletes all settings from s. // Clear deletes all settings from s.
@ -412,10 +409,10 @@ func (s *TestStore) Clear() {
clear(s.mw) clear(s.mw)
s.mu.Unlock() s.mu.Unlock()
s.storeLock.Unlock() s.storeLock.Unlock()
s.notifyPolicyChanged() s.NotifyPolicyChanged()
} }
func (s *TestStore) notifyPolicyChanged() { func (s *TestStore) NotifyPolicyChanged() {
s.mu.RLock() s.mu.RLock()
if s.suspendCount != 0 { if s.suspendCount != 0 {
s.mu.RUnlock() s.mu.RUnlock()
@ -439,9 +436,9 @@ func (s *TestStore) notifyPolicyChanged() {
func (s *TestStore) Close() { func (s *TestStore) Close() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.done != nil { if !s.closed {
close(s.done) close(s.done)
s.done = nil s.closed = true
} }
} }

Loading…
Cancel
Save