From 241a541864151d08153a5a51227e44a5c4c81e0b Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Fri, 12 Jan 2024 17:35:48 -0800 Subject: [PATCH] util/ctxkey: add package for type-safe context keys (#10841) The lack of type-safety in context.WithValue leads to the common pattern of defining of package-scoped type to ensure global uniqueness: type fooKey struct{} func withFoo(ctx context, v Foo) context.Context { return context.WithValue(ctx, fooKey{}, v) } func fooValue(ctx context) Foo { v, _ := ctx.Value(fooKey{}).(Foo) return v } where usage becomes: ctx = withFoo(ctx, foo) foo := fooValue(ctx) With many different context keys, this can be quite tedious. Using generics, we can simplify this as: var fooKey = ctxkey.New("mypkg.fooKey", Foo{}) where usage becomes: ctx = fooKey.WithValue(ctx, foo) foo := fooKey.Value(ctx) See https://go.dev/issue/49189 Updates #cleanup Signed-off-by: Joe Tsai --- util/ctxkey/key.go | 139 ++++++++++++++++++++++++++++++++++++++++ util/ctxkey/key_test.go | 82 ++++++++++++++++++++++++ 2 files changed, 221 insertions(+) create mode 100644 util/ctxkey/key.go create mode 100644 util/ctxkey/key_test.go diff --git a/util/ctxkey/key.go b/util/ctxkey/key.go new file mode 100644 index 000000000..87383cf58 --- /dev/null +++ b/util/ctxkey/key.go @@ -0,0 +1,139 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// ctxkey provides type-safe key-value pairs for use with [context.Context]. +// +// Example usage: +// +// // Create a context key. +// var TimeoutKey = ctxkey.New("fsrv.Timeout", 5*time.Second) +// +// // Store a context value. +// ctx = fsrv.TimeoutKey.WithValue(ctx, 10*time.Second) +// +// // Load a context value. +// timeout := fsrv.TimeoutKey.Value(ctx) +// ... // use timeout of type time.Duration +// +// This is inspired by https://go.dev/issue/49189. +package ctxkey + +import ( + "context" + "fmt" + "reflect" +) + +// Key is a generic key type associated with a specific value type. +// +// A zero Key is valid where the Value type itself is used as the context key. +// This pattern should only be used with locally declared Go types. +// The Value type must not be an interface type. +// +// Example usage: +// +// type peerInfo struct { ... } // peerInfo is an unexported type +// var peerInfoKey = ctxkey.Key[peerInfo] +// ctx = peerInfoKey.WithValue(ctx, info) // store a context value +// info = peerInfoKey.Value(ctx) // load a context value +// +// In general, any exported keys should be produced using [New]. +type Key[Value any] struct { + name *stringer[string] + defVal *Value +} + +// New constructs a new context key with an associated value type +// where the default value for an unpopulated value is the provided value. +// +// The provided name is an arbitrary name only used for human debugging. +// As a convention, it is recommended that the name be the dot-delimited +// combination of the package name of the caller with the variable name. +// Every key is unique, even if provided the same name. +// +// Example usage: +// +// package mapreduce +// var NumWorkersKey = ctxkey.New("mapreduce.NumWorkers", runtime.NumCPU()) +func New[Value any](name string, defaultValue Value) Key[Value] { + if name == "" { + var v Value + name = reflect.TypeOf(v).String() // TODO(https://go.dev/issue/60088): Use reflect.TypeFor. + } + var defVal *Value + switch v := reflect.ValueOf(&defaultValue).Elem(); { + case v.Kind() == reflect.Interface: + panic(fmt.Sprintf("value type %v must not be an interface", v.Type())) + case !v.IsZero(): + defVal = &defaultValue + } + // Allocate a *stringer to ensure that every invocation of New + // creates a universally unique context key even for the same name. + return Key[Value]{name: &stringer[string]{name}, defVal: defVal} +} + +// contextKey returns the context key to use. +func (key Key[Value]) contextKey() any { + if key.name == nil { + // Use the reflect.Type of the Value (implies key not created by New). + var v Value + t := reflect.TypeOf(v) + if t == nil { + panic(fmt.Sprintf("value type %v must not be an interface", reflect.TypeOf(&v).Elem())) + } + return t + } else { + // Use the name pointer directly (implies key created by New). + return key.name + } +} + +// WithValue returns a copy of parent in which the value associated with key is val. +// +// It is a type-safe equivalent of [context.WithValue]. +func (key Key[Value]) WithValue(parent context.Context, val Value) context.Context { + return context.WithValue(parent, key.contextKey(), stringer[Value]{val}) +} + +// ValueOk returns the value in the context associated with this key +// and also reports whether it was present. +// If the value is not present, it returns the default value. +func (key Key[Value]) ValueOk(ctx context.Context) (v Value, ok bool) { + vv, ok := ctx.Value(key.contextKey()).(stringer[Value]) + if !ok && key.defVal != nil { + vv.v = *key.defVal + } + return vv.v, ok +} + +// Value returns the value in the context associated with this key. +// If the value is not present, it returns the default value. +func (key Key[Value]) Value(ctx context.Context) (v Value) { + v, _ = key.ValueOk(ctx) + return v +} + +// Has reports whether the context has a value for this key. +func (key Key[Value]) Has(ctx context.Context) (ok bool) { + _, ok = key.ValueOk(ctx) + return ok +} + +// String returns the name of the key. +func (key Key[Value]) String() string { + if key.name == nil { + var v Value + return reflect.TypeOf(v).String() // TODO(https://go.dev/issue/60088): Use reflect.TypeFor. + } + return key.name.String() +} + +// stringer implements [fmt.Stringer] on a generic T. +// +// This assists in debugging such that printing a context prints key and value. +// Note that the [context] package lacks a dependency on [reflect], +// so it cannot print arbitrary values. By implementing [fmt.Stringer], +// we functionally teach a context how to print itself. +type stringer[T any] struct{ v T } + +func (v stringer[T]) String() string { return fmt.Sprint(v.v) } diff --git a/util/ctxkey/key_test.go b/util/ctxkey/key_test.go new file mode 100644 index 000000000..8797576f2 --- /dev/null +++ b/util/ctxkey/key_test.go @@ -0,0 +1,82 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ctxkey + +import ( + "context" + "fmt" + "regexp" + "testing" + "time" + + qt "github.com/frankban/quicktest" +) + +func TestKey(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + // Test keys with the same name as being distinct. + k1 := New("same.Name", "") + c.Assert(k1.String(), qt.Equals, "same.Name") + k2 := New("same.Name", "") + c.Assert(k2.String(), qt.Equals, "same.Name") + c.Assert(k1 == k2, qt.Equals, false) + ctx = k1.WithValue(ctx, "hello") + c.Assert(k1.Has(ctx), qt.Equals, true) + c.Assert(k1.Value(ctx), qt.Equals, "hello") + c.Assert(k2.Has(ctx), qt.Equals, false) + c.Assert(k2.Value(ctx), qt.Equals, "") + ctx = k2.WithValue(ctx, "goodbye") + c.Assert(k1.Has(ctx), qt.Equals, true) + c.Assert(k1.Value(ctx), qt.Equals, "hello") + c.Assert(k2.Has(ctx), qt.Equals, true) + c.Assert(k2.Value(ctx), qt.Equals, "goodbye") + + // Test default value. + k3 := New("mapreduce.Timeout", time.Hour) + c.Assert(k3.Has(ctx), qt.Equals, false) + c.Assert(k3.Value(ctx), qt.Equals, time.Hour) + ctx = k3.WithValue(ctx, time.Minute) + c.Assert(k3.Has(ctx), qt.Equals, true) + c.Assert(k3.Value(ctx), qt.Equals, time.Minute) + + // Test incomparable value. + k4 := New("slice", []int(nil)) + c.Assert(k4.Has(ctx), qt.Equals, false) + c.Assert(k4.Value(ctx), qt.DeepEquals, []int(nil)) + ctx = k4.WithValue(ctx, []int{1, 2, 3}) + c.Assert(k4.Has(ctx), qt.Equals, true) + c.Assert(k4.Value(ctx), qt.DeepEquals, []int{1, 2, 3}) + + // Accessors should be allocation free. + c.Assert(testing.AllocsPerRun(100, func() { + k1.Value(ctx) + k1.Has(ctx) + k1.ValueOk(ctx) + }), qt.Equals, 0.0) + + // Test keys that are created without New. + var k5 Key[string] + c.Assert(k5.String(), qt.Equals, "string") + c.Assert(k1 == k5, qt.Equals, false) // should be different from key created by New + c.Assert(k5.Has(ctx), qt.Equals, false) + ctx = k5.WithValue(ctx, "fizz") + c.Assert(k5.Value(ctx), qt.Equals, "fizz") + var k6 Key[string] + c.Assert(k6.String(), qt.Equals, "string") + c.Assert(k5 == k6, qt.Equals, true) + c.Assert(k6.Has(ctx), qt.Equals, true) + ctx = k6.WithValue(ctx, "fizz") +} + +func TestStringer(t *testing.T) { + t.SkipNow() // TODO(https://go.dev/cl/555697): Enable this after fix is merged upstream. + c := qt.New(t) + ctx := context.Background() + c.Assert(fmt.Sprint(New("foo.Bar", "").WithValue(ctx, "baz")), qt.Matches, regexp.MustCompile("foo.Bar.*baz")) + c.Assert(fmt.Sprint(New("", []int{}).WithValue(ctx, []int{1, 2, 3})), qt.Matches, regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", []int{1, 2, 3}))) + c.Assert(fmt.Sprint(New("", 0).WithValue(ctx, 5)), qt.Matches, regexp.MustCompile("int.*5")) + c.Assert(fmt.Sprint(Key[time.Duration]{}.WithValue(ctx, time.Hour)), qt.Matches, regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", time.Hour))) +}