diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index bd81a09c3..4debcdd8d 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -2881,20 +2881,16 @@ func TestSetExitNodeIDPolicy(t *testing.T) { }, } - syspolicy.RegisterWellKnownSettingsForTest(t) - for _, test := range tests { t.Run(test.name, func(t *testing.T) { - b := newTestBackend(t) - - policyStore := source.NewTestStore(t) + var polc policytest.Config if test.exitNodeIDKey { - policyStore.SetStrings(source.TestSettingOf(pkey.ExitNodeID, test.exitNodeID)) + polc.Set(pkey.ExitNodeID, test.exitNodeID) } if test.exitNodeIPKey { - policyStore.SetStrings(source.TestSettingOf(pkey.ExitNodeIP, test.exitNodeIP)) + polc.Set(pkey.ExitNodeIP, test.exitNodeIP) } - syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + b := newTestBackend(t, polc) if test.nm == nil { test.nm = new(netmap.NetworkMap) @@ -3026,15 +3022,13 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { }, } - syspolicy.RegisterWellKnownSettingsForTest(t) - policyStore := source.NewTestStoreOf(t, source.TestSettingOf( - pkey.ExitNodeID, "auto:any", - )) - syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - b := newTestLocalBackend(t) + sys := tsd.NewSystem() + sys.PolicyClient.Set(policytest.Config{ + pkey.ExitNodeID: "auto:any", + }) + b := newTestLocalBackendWithSys(t, sys) b.currentNode().SetNetMap(tt.netmap) b.lastSuggestedExitNode = tt.lastSuggestedExitNode b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, tt.report) @@ -3094,7 +3088,13 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { } func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { - b := newTestLocalBackend(t) + polc := policytest.Config{ + pkey.ExitNodeID: "auto:any", + } + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) + + b := newTestLocalBackendWithSys(t, sys) hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni @@ -3106,16 +3106,12 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { GetMachinePrivateKey: func() (key.MachinePrivate, error) { return k, nil }, - Dialer: tsdial.NewDialer(netmon.NewStatic()), - Logf: b.logf, + Dialer: tsdial.NewDialer(netmon.NewStatic()), + Logf: b.logf, + PolicyClient: polc, } cc = newClient(t, opts) b.cc = cc - syspolicy.RegisterWellKnownSettingsForTest(t) - policyStore := source.NewTestStoreOf(t, source.TestSettingOf( - pkey.ExitNodeID, "auto:any", - )) - syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes()) peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes()) selfNode := tailcfg.Node{ @@ -3219,12 +3215,14 @@ func TestSetControlClientStatusAutoExitNode(t *testing.T) { }, DERPMap: derpMap, } - b := newTestLocalBackend(t) - syspolicy.RegisterWellKnownSettingsForTest(t) - policyStore := source.NewTestStoreOf(t, source.TestSettingOf( - pkey.ExitNodeID, "auto:any", - )) - syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + + polc := policytest.Config{ + pkey.ExitNodeID: "auto:any", + } + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) + + b := newTestLocalBackendWithSys(t, sys) b.currentNode().SetNetMap(nm) // Peer 2 should be the initial exit node, as it's better than peer 1 // in terms of latency and DERP region. @@ -3461,21 +3459,20 @@ func TestApplySysPolicy(t *testing.T) { }, } - syspolicy.RegisterWellKnownSettingsForTest(t) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - settings := make([]source.TestSetting[string], 0, len(tt.stringPolicies)) - for p, v := range tt.stringPolicies { - settings = append(settings, source.TestSettingOf(p, v)) + var polc policytest.Config + for k, v := range tt.stringPolicies { + polc.Set(k, v) } - policyStore := source.NewTestStoreOf(t, settings...) - syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) t.Run("unit", func(t *testing.T) { prefs := tt.prefs.Clone() - lb := newTestLocalBackend(t) + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) + + lb := newTestLocalBackendWithSys(t, sys) gotAnyChange := lb.applySysPolicyLocked(prefs) if gotAnyChange && prefs.Equals(&tt.prefs) { @@ -3508,7 +3505,7 @@ func TestApplySysPolicy(t *testing.T) { pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) pm.prefs = usePrefs.View() - b := newTestBackend(t) + b := newTestBackend(t, polc) b.mu.Lock() b.pm = pm b.mu.Unlock() @@ -3607,24 +3604,26 @@ func TestPreferencePolicyInfo(t *testing.T) { }, } - syspolicy.RegisterWellKnownSettingsForTest(t) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for _, pp := range preferencePolicies { t.Run(string(pp.key), func(t *testing.T) { - s := source.TestSetting[string]{ - Key: pp.key, - Error: tt.policyError, - Value: tt.policyValue, + t.Parallel() + + var polc policytest.Config + if tt.policyError != nil { + polc.Set(pp.key, tt.policyError) + } else { + polc.Set(pp.key, tt.policyValue) } - policyStore := source.NewTestStoreOf(t, s) - syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) prefs := defaultPrefs.AsStruct() pp.set(prefs, tt.initialValue) - lb := newTestLocalBackend(t) + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) + + lb := newTestLocalBackendWithSys(t, sys) gotAnyChange := lb.applySysPolicyLocked(prefs) if gotAnyChange != tt.wantChange { @@ -6534,7 +6533,8 @@ func TestUpdatePrefsOnSysPolicyChange(t *testing.T) { store := source.NewTestStoreOf[string](t) syspolicy.MustRegisterStoreForTest(t, "TestSource", setting.DeviceScope, store) - lb := newLocalBackendWithTestControl(t, enableLogging, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + sys := tsd.NewSystem() + lb := newLocalBackendWithSysAndTestControl(t, enableLogging, sys, func(tb testing.TB, opts controlclient.Options) controlclient.Client { return newClient(tb, opts) }) if tt.initialPrefs != nil { diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index 57d1a4745..e2561cba9 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -35,6 +35,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/wgengine" ) @@ -870,7 +871,7 @@ func mustCreateURL(t *testing.T, u string) url.URL { return *uParsed } -func newTestBackend(t *testing.T) *LocalBackend { +func newTestBackend(t *testing.T, opts ...any) *LocalBackend { var logf logger.Logf = logger.Discard const debug = true if debug { @@ -878,6 +879,16 @@ func newTestBackend(t *testing.T) *LocalBackend { } sys := tsd.NewSystem() + + for _, o := range opts { + switch v := o.(type) { + case policyclient.Client: + sys.PolicyClient.Set(v) + default: + panic(fmt.Sprintf("unsupported option type %T", v)) + } + } + e, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ SetSubsystem: sys.Set, HealthTracker: sys.HealthTracker(), diff --git a/util/syspolicy/policytest/policytest.go b/util/syspolicy/policytest/policytest.go index e05d8938e..7ea0ad91f 100644 --- a/util/syspolicy/policytest/policytest.go +++ b/util/syspolicy/policytest/policytest.go @@ -19,7 +19,12 @@ import ( // It is used for testing purposes to simulate policy client behavior. // // It panics if a value is Set with one type and then accessed with a different -// expected type. +// expected type and/or value. Some accessors such as GetPreferenceOption and +// GetVisibility support either a ptype.PreferenceOption/ptype.Visibility in the +// map, or the string representation as supported by their UnmarshalText +// methods. +// +// The map value may be an error to return that error value from the accessor. type Config map[pkey.Key]any var _ policyclient.Client = Config{} @@ -33,70 +38,108 @@ func (c *Config) Set(key pkey.Key, value any) { func (c Config) GetStringArray(key pkey.Key, defaultVal []string) ([]string, error) { if val, ok := c[key]; ok { - if arr, ok := val.([]string); ok { - return arr, nil + switch val := val.(type) { + case []string: + return val, nil + case error: + return nil, val + default: + panic(fmt.Sprintf("key %s is not a []string; got %T", key, val)) } - panic(fmt.Sprintf("key %s is not a []string", key)) } return defaultVal, nil } func (c Config) GetString(key pkey.Key, defaultVal string) (string, error) { if val, ok := c[key]; ok { - if str, ok := val.(string); ok { - return str, nil + switch val := val.(type) { + case string: + return val, nil + case error: + return "", val + default: + panic(fmt.Sprintf("key %s is not a string; got %T", key, val)) } - panic(fmt.Sprintf("key %s is not a string", key)) } return defaultVal, nil } func (c Config) GetBoolean(key pkey.Key, defaultVal bool) (bool, error) { if val, ok := c[key]; ok { - if b, ok := val.(bool); ok { - return b, nil + switch val := val.(type) { + case bool: + return val, nil + case error: + return false, val + default: + panic(fmt.Sprintf("key %s is not a bool; got %T", key, val)) } - panic(fmt.Sprintf("key %s is not a bool", key)) } return defaultVal, nil } func (c Config) GetUint64(key pkey.Key, defaultVal uint64) (uint64, error) { if val, ok := c[key]; ok { - if u, ok := val.(uint64); ok { - return u, nil + switch val := val.(type) { + case uint64: + return val, nil + case error: + return 0, val + default: + panic(fmt.Sprintf("key %s is not a uint64; got %T", key, val)) } - panic(fmt.Sprintf("key %s is not a uint64", key)) } return defaultVal, nil } func (c Config) GetDuration(key pkey.Key, defaultVal time.Duration) (time.Duration, error) { if val, ok := c[key]; ok { - if d, ok := val.(time.Duration); ok { - return d, nil + switch val := val.(type) { + case time.Duration: + return val, nil + case error: + return 0, val + default: + panic(fmt.Sprintf("key %s is not a time.Duration; got %T", key, val)) } - panic(fmt.Sprintf("key %s is not a time.Duration", key)) } return defaultVal, nil } func (c Config) GetPreferenceOption(key pkey.Key, defaultVal ptype.PreferenceOption) (ptype.PreferenceOption, error) { if val, ok := c[key]; ok { - if p, ok := val.(ptype.PreferenceOption); ok { - return p, nil + switch val := val.(type) { + case ptype.PreferenceOption: + return val, nil + case error: + var zero ptype.PreferenceOption + return zero, val + case string: + var p ptype.PreferenceOption + err := p.UnmarshalText(([]byte)(val)) + return p, err + default: + panic(fmt.Sprintf("key %s is not a ptype.PreferenceOption", key)) } - panic(fmt.Sprintf("key %s is not a ptype.PreferenceOption", key)) } return defaultVal, nil } func (c Config) GetVisibility(key pkey.Key) (ptype.Visibility, error) { if val, ok := c[key]; ok { - if p, ok := val.(ptype.Visibility); ok { - return p, nil + switch val := val.(type) { + case ptype.Visibility: + return val, nil + case error: + var zero ptype.Visibility + return zero, val + case string: + var p ptype.Visibility + err := p.UnmarshalText(([]byte)(val)) + return p, err + default: + panic(fmt.Sprintf("key %s is not a ptype.Visibility", key)) } - panic(fmt.Sprintf("key %s is not a ptype.Visibility", key)) } return ptype.Visibility(ptype.ShowChoiceByPolicy), nil }