util/syspolicy: add caching handler (#10288)

Fixes tailscale/corp#15850
Co-authored-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Claire Wang <claire@tailscale.com>
pull/10302/head
Claire Wang 1 year ago committed by GitHub
parent 719ee4415e
commit b8a2aedccd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,98 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"errors"
"sync"
)
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
// otherwise the actual error is returned and the next read for that key will retry using the handler.
type CachingHandler struct {
mu sync.Mutex
strings map[string]string
uint64s map[string]uint64
bools map[string]bool
notFound map[string]bool
handler Handler
}
// NewCachingHandler creates a CachingHandler given a handler.
func NewCachingHandler(handler Handler) *CachingHandler {
return &CachingHandler{
handler: handler,
strings: make(map[string]string),
uint64s: make(map[string]uint64),
bools: make(map[string]bool),
notFound: make(map[string]bool),
}
}
// ReadString reads the policy settings value string given the key.
// ReadString first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadString(key string) (string, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.strings[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return "", ErrNoSuchKey
}
val, err := ch.handler.ReadString(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return "", err
} else if err != nil {
return "", err
}
ch.strings[key] = val
return val, nil
}
// ReadUInt64 reads the policy settings uint64 value given the key.
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.uint64s[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return 0, ErrNoSuchKey
}
val, err := ch.handler.ReadUInt64(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return 0, err
} else if err != nil {
return 0, err
}
ch.uint64s[key] = val
return val, nil
}
// ReadBoolean reads the policy settings boolean value given the key.
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.bools[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return false, ErrNoSuchKey
}
val, err := ch.handler.ReadBoolean(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return false, err
} else if err != nil {
return false, err
}
ch.bools[key] = val
return val, nil
}

@ -0,0 +1,262 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"testing"
)
func TestHandlerReadString(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue string
handlerError error
preserveHandler bool
wantValue string
wantErr error
strings map[string]string
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
strings: map[string]string{"test": "foo"},
wantValue: "foo",
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: "foo",
wantValue: "foo",
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
s: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.strings != nil {
cache.strings = tt.strings
}
got, err := cache.ReadString(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadString(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}
func TestHandlerReadUint64(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue uint64
handlerError error
preserveHandler bool
wantValue uint64
wantErr error
uint64s map[string]uint64
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
uint64s: map[string]uint64{"test": 1},
wantValue: 1,
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: 1,
wantValue: 1,
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
u64: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.uint64s != nil {
cache.uint64s = tt.uint64s
}
got, err := cache.ReadUInt64(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadUInt64(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}
func TestHandlerReadBool(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue bool
handlerError error
preserveHandler bool
wantValue bool
wantErr error
bools map[string]bool
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
bools: map[string]bool{"test": true},
wantValue: true,
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: true,
wantValue: true,
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
b: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.bools != nil {
cache.bools = tt.bools
}
got, err := cache.ReadBoolean(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadBoolean(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}

@ -12,7 +12,7 @@ import (
type windowsHandler struct{} type windowsHandler struct{}
func init() { func init() {
RegisterHandler(windowsHandler{}) RegisterHandler(NewCachingHandler(windowsHandler{}))
} }
func (windowsHandler) ReadString(key string) (string, error) { func (windowsHandler) ReadString(key string) (string, error) {

@ -13,12 +13,13 @@ import (
// methods that involve getting a policy value. // methods that involve getting a policy value.
// For keys and the corresponding values, check policy_keys.go. // For keys and the corresponding values, check policy_keys.go.
type testHandler struct { type testHandler struct {
t *testing.T t *testing.T
key Key key Key
s string s string
u64 uint64 u64 uint64
b bool b bool
err error err error
calls int // used for testing reads from cache vs. handler
} }
var someOtherError = errors.New("error other than not found") var someOtherError = errors.New("error other than not found")
@ -34,6 +35,7 @@ func (th *testHandler) ReadString(key string) (string, error) {
if key != string(th.key) { if key != string(th.key) {
th.t.Errorf("ReadString(%q) want %q", key, th.key) th.t.Errorf("ReadString(%q) want %q", key, th.key)
} }
th.calls++
return th.s, th.err return th.s, th.err
} }
@ -41,6 +43,7 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) {
if key != string(th.key) { if key != string(th.key) {
th.t.Errorf("ReadUint64(%q) want %q", key, th.key) th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
} }
th.calls++
return th.u64, th.err return th.u64, th.err
} }
@ -48,6 +51,7 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) {
if key != string(th.key) { if key != string(th.key) {
th.t.Errorf("ReadBool(%q) want %q", key, th.key) th.t.Errorf("ReadBool(%q) want %q", key, th.key)
} }
th.calls++
return th.b, th.err return th.b, th.err
} }

Loading…
Cancel
Save