From 24b8a57b1e9c61154d45d87402fadcb56ff27843 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 2 Sep 2025 16:50:10 -0700 Subject: [PATCH] util/syspolicy/policytest: move policy test helper to its own package Updates #16998 Updates #12614 Change-Id: I9fd27d653ebee547951705dc5597481e85b60747 Signed-off-by: Brad Fitzpatrick --- ipn/ipnlocal/local_test.go | 62 +------------ util/syspolicy/policytest/policytest.go | 117 ++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 59 deletions(-) create mode 100644 util/syspolicy/policytest/policytest.go diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index a3a26af04..bd81a09c3 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -64,7 +64,7 @@ import ( "tailscale.com/util/set" "tailscale.com/util/syspolicy" "tailscale.com/util/syspolicy/pkey" - "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/policytest" "tailscale.com/util/syspolicy/setting" "tailscale.com/util/syspolicy/source" "tailscale.com/wgengine" @@ -1183,7 +1183,7 @@ func TestConfigureExitNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - var pol testPolicy + var pol policytest.Config // Configure policy settings, if any. if tt.exitNodeIDPolicy != nil { pol.Set(pkey.ExitNodeID, string(*tt.exitNodeIDPolicy)) @@ -5539,62 +5539,6 @@ func TestReadWriteRouteInfo(t *testing.T) { } } -// testPolicy is a [policyclient.Client] with a static mapping of values. -// The map value must be of the correct type (string, []string, bool, etc). -// -// It is used for testing purposes to simulate policy client behavior. -// It panics if the values are the wrong type. -type testPolicy struct { - v map[pkey.Key]any - policyclient.NoPolicyClient -} - -func (sp *testPolicy) Set(key pkey.Key, value any) { - if sp.v == nil { - sp.v = make(map[pkey.Key]any) - } - sp.v[key] = value -} - -func (sp testPolicy) GetStringArray(key pkey.Key, defaultVal []string) ([]string, error) { - if val, ok := sp.v[key]; ok { - if arr, ok := val.([]string); ok { - return arr, nil - } - panic(fmt.Sprintf("key %s is not a []string", key)) - } - return defaultVal, nil -} - -func (sp testPolicy) GetString(key pkey.Key, defaultVal string) (string, error) { - if val, ok := sp.v[key]; ok { - if str, ok := val.(string); ok { - return str, nil - } - panic(fmt.Sprintf("key %s is not a string", key)) - } - return defaultVal, nil -} - -func (sp testPolicy) GetBoolean(key pkey.Key, defaultVal bool) (bool, error) { - if val, ok := sp.v[key]; ok { - if b, ok := val.(bool); ok { - return b, nil - } - panic(fmt.Sprintf("key %s is not a bool", key)) - } - return defaultVal, nil -} - -func (sp testPolicy) HasAnyOf(keys ...pkey.Key) (bool, error) { - for _, key := range keys { - if _, ok := sp.v[key]; ok { - return true, nil - } - } - return false, nil -} - func TestFillAllowedSuggestions(t *testing.T) { tests := []struct { name string @@ -5628,7 +5572,7 @@ func TestFillAllowedSuggestions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var pol testPolicy + var pol policytest.Config pol.Set(pkey.AllowedSuggestedExitNodes, tt.allowPolicy) got := fillAllowedSuggestions(pol) diff --git a/util/syspolicy/policytest/policytest.go b/util/syspolicy/policytest/policytest.go new file mode 100644 index 000000000..e05d8938e --- /dev/null +++ b/util/syspolicy/policytest/policytest.go @@ -0,0 +1,117 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policytest contains test helpers for the syspolicy packages. +package policytest + +import ( + "fmt" + "time" + + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/ptype" +) + +// Config is a [policyclient.Client] implementation with a static mapping of +// values. +// +// It is used for testing purposes to simulate policy client behavior. +// +// It panics if a value is Set with one type and then accessed with a different +// expected type. +type Config map[pkey.Key]any + +var _ policyclient.Client = Config{} + +func (c *Config) Set(key pkey.Key, value any) { + if *c == nil { + *c = make(map[pkey.Key]any) + } + (*c)[key] = value +} + +func (c Config) GetStringArray(key pkey.Key, defaultVal []string) ([]string, error) { + if val, ok := c[key]; ok { + if arr, ok := val.([]string); ok { + return arr, nil + } + panic(fmt.Sprintf("key %s is not a []string", key)) + } + return defaultVal, nil +} + +func (c Config) GetString(key pkey.Key, defaultVal string) (string, error) { + if val, ok := c[key]; ok { + if str, ok := val.(string); ok { + return str, nil + } + panic(fmt.Sprintf("key %s is not a string", key)) + } + return defaultVal, nil +} + +func (c Config) GetBoolean(key pkey.Key, defaultVal bool) (bool, error) { + if val, ok := c[key]; ok { + if b, ok := val.(bool); ok { + return b, nil + } + panic(fmt.Sprintf("key %s is not a bool", key)) + } + return defaultVal, nil +} + +func (c Config) GetUint64(key pkey.Key, defaultVal uint64) (uint64, error) { + if val, ok := c[key]; ok { + if u, ok := val.(uint64); ok { + return u, nil + } + panic(fmt.Sprintf("key %s is not a uint64", key)) + } + return defaultVal, nil +} + +func (c Config) GetDuration(key pkey.Key, defaultVal time.Duration) (time.Duration, error) { + if val, ok := c[key]; ok { + if d, ok := val.(time.Duration); ok { + return d, nil + } + panic(fmt.Sprintf("key %s is not a time.Duration", key)) + } + return defaultVal, nil +} + +func (c Config) GetPreferenceOption(key pkey.Key, defaultVal ptype.PreferenceOption) (ptype.PreferenceOption, error) { + if val, ok := c[key]; ok { + if p, ok := val.(ptype.PreferenceOption); ok { + return p, nil + } + panic(fmt.Sprintf("key %s is not a ptype.PreferenceOption", key)) + } + return defaultVal, nil +} + +func (c Config) GetVisibility(key pkey.Key) (ptype.Visibility, error) { + if val, ok := c[key]; ok { + if p, ok := val.(ptype.Visibility); ok { + return p, nil + } + panic(fmt.Sprintf("key %s is not a ptype.Visibility", key)) + } + return ptype.Visibility(ptype.ShowChoiceByPolicy), nil +} + +func (c Config) HasAnyOf(keys ...pkey.Key) (bool, error) { + for _, key := range keys { + if _, ok := c[key]; ok { + return true, nil + } + } + return false, nil +} + +func (sp Config) RegisterChangeCallback(callback func(policyclient.PolicyChange)) (func(), error) { + return func() {}, nil +} + +func (sp Config) SetDebugLoggingEnabled(enabled bool) {}