mirror of https://github.com/tailscale/tailscale/
util/reload: add new package to handle periodic value loading
This can be used to reload a value periodically, whether from disk or another source, while handling jitter and graceful shutdown. Updates tailscale/corp#1297 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: Iee2b4385c9abae59805f642a7308837877cb5b3fpull/11478/head
parent
2f7e7be2ea
commit
512fc0b502
@ -0,0 +1,189 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package reload contains functions that allow periodically reloading a value
|
||||
// (e.g. a config file) from various sources.
|
||||
package reload
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// DefaultInterval is the default value for ReloadOpts.Interval if none is
|
||||
// provided.
|
||||
const DefaultInterval = 5 * time.Minute
|
||||
|
||||
// ReloadOpts specifies options for reloading a value. Various helper functions
|
||||
// in this package can be used to create one of these specialized for a given
|
||||
// use-case.
|
||||
type ReloadOpts[T any] struct {
|
||||
// Read is called to obtain the data to be unmarshaled; e.g. by reading
|
||||
// from a file, or making a network request, etc.
|
||||
//
|
||||
// An error from this function is fatal when calling New, but only a
|
||||
// warning during reload.
|
||||
//
|
||||
// This value is required.
|
||||
Read func(context.Context) ([]byte, error)
|
||||
|
||||
// Unmarshal is called with the data that the Read function returns and
|
||||
// should return a parsed form of the given value, or an error.
|
||||
//
|
||||
// An error from this function is fatal when calling New, but only a
|
||||
// warning during reload.
|
||||
//
|
||||
// This value is required.
|
||||
Unmarshal func([]byte) (T, error)
|
||||
|
||||
// Logf is a logger used to print errors that occur on reload. If nil,
|
||||
// no messages are printed.
|
||||
Logf logger.Logf
|
||||
|
||||
// Interval is the interval at which to reload the given data from the
|
||||
// source; if zero, DefaultInterval will be used.
|
||||
Interval time.Duration
|
||||
|
||||
// IntervalJitter is the jitter to be added to the given Interval; if
|
||||
// provided, a duration between 0 and this value will be added to each
|
||||
// Interval when waiting.
|
||||
IntervalJitter time.Duration
|
||||
}
|
||||
|
||||
func (r *ReloadOpts[T]) logf(format string, args ...any) {
|
||||
if r.Logf != nil {
|
||||
r.Logf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReloadOpts[T]) intervalWithJitter() time.Duration {
|
||||
tt := r.Interval
|
||||
if tt == 0 {
|
||||
tt = DefaultInterval
|
||||
}
|
||||
if r.IntervalJitter == 0 {
|
||||
return tt
|
||||
}
|
||||
|
||||
jitter := time.Duration(rand.Intn(int(r.IntervalJitter)))
|
||||
return tt + jitter
|
||||
}
|
||||
|
||||
// New creates and starts reloading the provided value as per opts. It returns
|
||||
// a function that, when called, returns the current stored value, or an error
|
||||
// that indicates something went wrong.
|
||||
//
|
||||
// The value will be present immediately upon return.
|
||||
func New[T any](ctx context.Context, opts ReloadOpts[T]) (func() T, error) {
|
||||
// Create our reloader, which hasn't started.
|
||||
reloader, err := newUnstarted(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start it
|
||||
go reloader.run()
|
||||
|
||||
// Return the load function now that we're all set up.
|
||||
return reloader.store.Load, nil
|
||||
}
|
||||
|
||||
type reloader[T any] struct {
|
||||
ctx context.Context
|
||||
store syncs.AtomicValue[T]
|
||||
opts ReloadOpts[T]
|
||||
}
|
||||
|
||||
// newUnstarted creates a reloader that hasn't yet been started.
|
||||
func newUnstarted[T any](ctx context.Context, opts ReloadOpts[T]) (*reloader[T], error) {
|
||||
if opts.Read == nil {
|
||||
return nil, fmt.Errorf("the Read function is required")
|
||||
}
|
||||
if opts.Unmarshal == nil {
|
||||
return nil, fmt.Errorf("the Unmarshal function is required")
|
||||
}
|
||||
|
||||
// Start by reading and unmarshaling the value.
|
||||
data, err := opts.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading initial value: %w", err)
|
||||
}
|
||||
|
||||
initial, err := opts.Unmarshal(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling initial value: %v", err)
|
||||
}
|
||||
|
||||
reloader := &reloader[T]{
|
||||
ctx: ctx,
|
||||
opts: opts,
|
||||
}
|
||||
reloader.store.Store(initial)
|
||||
return reloader, nil
|
||||
}
|
||||
|
||||
func (r *reloader[T]) run() {
|
||||
// Create a timer that we re-set each time we fire.
|
||||
timer := time.NewTimer(r.opts.intervalWithJitter())
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
r.opts.logf("run context is done")
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
if err := r.updateOnce(); err != nil {
|
||||
r.opts.logf("error refreshing data: %v", err)
|
||||
}
|
||||
|
||||
// Re-arm the timer after we're done; this is safe
|
||||
// since the only way this loop woke up was by reading
|
||||
// from timer.C
|
||||
timer.Reset(r.opts.intervalWithJitter())
|
||||
}
|
||||
}
|
||||
|
||||
func (r *reloader[T]) updateOnce() error {
|
||||
data, err := r.opts.Read(r.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading data: %w", err)
|
||||
}
|
||||
next, err := r.opts.Unmarshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshaling data: %w", err)
|
||||
}
|
||||
|
||||
oldValue := r.store.Swap(next)
|
||||
if !reflect.DeepEqual(oldValue, next) {
|
||||
r.opts.logf("stored new value: %+v", next)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FromJSONFile creates a ReloadOpts describing reloading a value of type T
|
||||
// from the given JSON file on-disk.
|
||||
func FromJSONFile[T any](path string) ReloadOpts[T] {
|
||||
return ReloadOpts[T]{
|
||||
Read: func(_ context.Context) ([]byte, error) {
|
||||
return os.ReadFile(path)
|
||||
},
|
||||
Unmarshal: func(b []byte) (T, error) {
|
||||
var ret, zero T
|
||||
if err := json.Unmarshal(b, &ret); err != nil {
|
||||
return zero, err
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
}
|
@ -0,0 +1,171 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package reload
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
func TestReloader(t *testing.T) {
|
||||
buf := []byte("hello world")
|
||||
|
||||
ctx := context.Background()
|
||||
r, err := newUnstarted[string](ctx, ReloadOpts[string]{
|
||||
Logf: t.Logf,
|
||||
Read: func(context.Context) ([]byte, error) {
|
||||
return buf, nil
|
||||
},
|
||||
Unmarshal: func(b []byte) (string, error) {
|
||||
return "The value is: " + string(b), nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We should have an initial value.
|
||||
const wantInitial = "The value is: hello world"
|
||||
if v := r.store.Load(); v != wantInitial {
|
||||
t.Errorf("got initial value %q, want %q", v, wantInitial)
|
||||
}
|
||||
|
||||
// Reloading should result in a new value
|
||||
buf = []byte("new value")
|
||||
if err := r.updateOnce(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const wantReload = "The value is: new value"
|
||||
if v := r.store.Load(); v != wantReload {
|
||||
t.Errorf("got reloaded value %q, want %q", v, wantReload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloader_InitialError(t *testing.T) {
|
||||
fakeErr := errors.New("fake error")
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := newUnstarted[string](ctx, ReloadOpts[string]{
|
||||
Logf: t.Logf,
|
||||
Read: func(context.Context) ([]byte, error) { return nil, fakeErr },
|
||||
Unmarshal: func(b []byte) (string, error) { panic("unused because Read fails") },
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected non-nil error")
|
||||
}
|
||||
if !errors.Is(err, fakeErr) {
|
||||
t.Errorf("wanted errors.Is(%v, fakeErr)=true", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloader_ReloadError(t *testing.T) {
|
||||
fakeErr := errors.New("fake error")
|
||||
shouldError := false
|
||||
|
||||
ctx := context.Background()
|
||||
r, err := newUnstarted[string](ctx, ReloadOpts[string]{
|
||||
Logf: t.Logf,
|
||||
Read: func(context.Context) ([]byte, error) {
|
||||
return []byte("hello"), nil
|
||||
},
|
||||
Unmarshal: func(b []byte) (string, error) {
|
||||
if shouldError {
|
||||
return "", fakeErr
|
||||
}
|
||||
return string(b), nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := r.store.Load(); got != "hello" {
|
||||
t.Fatalf("got value %q, want \"hello\"", got)
|
||||
}
|
||||
|
||||
shouldError = true
|
||||
|
||||
if err := r.updateOnce(); err == nil {
|
||||
t.Errorf("expected error from updateOnce")
|
||||
}
|
||||
if got := r.store.Load(); got != "hello" {
|
||||
t.Fatalf("got value %q, want \"hello\"", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloader_Run(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var ncalls atomic.Int64
|
||||
load, err := New[string](ctx, ReloadOpts[string]{
|
||||
Logf: tstest.WhileTestRunningLogger(t),
|
||||
Interval: 10 * time.Millisecond,
|
||||
Read: func(context.Context) ([]byte, error) {
|
||||
return []byte("hello"), nil
|
||||
},
|
||||
Unmarshal: func(b []byte) (string, error) {
|
||||
callNum := ncalls.Add(1)
|
||||
if callNum == 3 {
|
||||
cancel()
|
||||
}
|
||||
return fmt.Sprintf("call %d: %s", callNum, b), nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := "call 1: hello"
|
||||
if got := load(); got != want {
|
||||
t.Fatalf("got value %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// Wait for the periodic refresh to cancel our context
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("test timed out")
|
||||
}
|
||||
|
||||
// Depending on how goroutines get scheduled, we can either read call 2
|
||||
// (if we woke up before the run goroutine stores call 3), or call 3
|
||||
// (if we woke up after the run goroutine stores the next value). Check
|
||||
// for both.
|
||||
want1, want2 := "call 2: hello", "call 3: hello"
|
||||
if got := load(); got != want1 && got != want2 {
|
||||
t.Fatalf("got value %q, want %q or %q", got, want1, want2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromJSONFile(t *testing.T) {
|
||||
type testStruct struct {
|
||||
Value string
|
||||
Number int
|
||||
}
|
||||
fpath := filepath.Join(t.TempDir(), "test.json")
|
||||
if err := os.WriteFile(fpath, []byte(`{"Value": "hello", "Number": 1234}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
r, err := newUnstarted(ctx, FromJSONFile[*testStruct](fpath))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := r.store.Load()
|
||||
want := &testStruct{Value: "hello", Number: 1234}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %+v, want %+v", got, want)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue