all: make use of ctxkey everywhere (#10846)

Also perform minor cleanups on the ctxkey package itself.
Provide guidance on when to use ctxkey.Key[T] over ctxkey.New.
Also, allow for interface kinds because the value wrapping trick
also happens to fix edge cases with interfaces in Go.

Updates #cleanup

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
pull/10872/head
Joe Tsai 4 months ago committed by GitHub
parent 7732377cd7
commit c25968e1c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -142,6 +142,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
tailscale.com/util/cloudenv from tailscale.com/hostinfo+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+
W tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy W tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy
tailscale.com/util/cmpx from tailscale.com/cmd/derper+ tailscale.com/util/cmpx from tailscale.com/cmd/derper+
tailscale.com/util/ctxkey from tailscale.com/tsweb+
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
tailscale.com/util/dnsname from tailscale.com/hostinfo+ tailscale.com/util/dnsname from tailscale.com/hostinfo+
tailscale.com/util/httpm from tailscale.com/client/tailscale tailscale.com/util/httpm from tailscale.com/client/tailscale

@ -6,7 +6,6 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log" "log"
@ -24,22 +23,11 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tsnet" "tailscale.com/tsnet"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/ctxkey"
"tailscale.com/util/set" "tailscale.com/util/set"
) )
type whoIsKey struct{} var whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil))
// whoIsFromRequest returns the WhoIsResponse previously stashed by a call to
// addWhoIsToRequest.
func whoIsFromRequest(r *http.Request) *apitype.WhoIsResponse {
return r.Context().Value(whoIsKey{}).(*apitype.WhoIsResponse)
}
// addWhoIsToRequest stashes who in r's context, retrievable by a call to
// whoIsFromRequest.
func addWhoIsToRequest(r *http.Request, who *apitype.WhoIsResponse) *http.Request {
return r.WithContext(context.WithValue(r.Context(), whoIsKey{}, who))
}
var counterNumRequestsProxied = clientmetric.NewCounter("k8s_auth_proxy_requests_proxied") var counterNumRequestsProxied = clientmetric.NewCounter("k8s_auth_proxy_requests_proxied")
@ -127,7 +115,7 @@ func (h *apiserverProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
counterNumRequestsProxied.Add(1) counterNumRequestsProxied.Add(1)
h.rp.ServeHTTP(w, addWhoIsToRequest(r, who)) h.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
} }
// runAPIServerProxy runs an HTTP server that authenticates requests using the // runAPIServerProxy runs an HTTP server that authenticates requests using the
@ -240,7 +228,7 @@ type impersonateRule struct {
// in the context by the apiserverProxy. // in the context by the apiserverProxy.
func addImpersonationHeaders(r *http.Request, log *zap.SugaredLogger) error { func addImpersonationHeaders(r *http.Request, log *zap.SugaredLogger) error {
log = log.With("remote", r.RemoteAddr) log = log.With("remote", r.RemoteAddr)
who := whoIsFromRequest(r) who := whoIsKey.Value(r.Context())
rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, capabilityName) rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, capabilityName)
if len(rules) == 0 && err == nil { if len(rules) == 0 && err == nil {
// Try the old capability name for backwards compatibility. // Try the old capability name for backwards compatibility.

@ -95,7 +95,7 @@ func TestImpersonationHeaders(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
r := must.Get(http.NewRequest("GET", "https://op.ts.net/api/foo", nil)) r := must.Get(http.NewRequest("GET", "https://op.ts.net/api/foo", nil))
r = addWhoIsToRequest(r, &apitype.WhoIsResponse{ r = r.WithContext(whoIsKey.WithValue(r.Context(), &apitype.WhoIsResponse{
Node: &tailcfg.Node{ Node: &tailcfg.Node{
Name: "node.ts.net", Name: "node.ts.net",
Tags: tc.tags, Tags: tc.tags,
@ -104,7 +104,7 @@ func TestImpersonationHeaders(t *testing.T) {
LoginName: tc.emailish, LoginName: tc.emailish,
}, },
CapMap: tc.capMap, CapMap: tc.capMap,
}) }))
addImpersonationHeaders(r, zl.Sugar()) addImpersonationHeaders(r, zl.Sugar())
if d := cmp.Diff(tc.wantHeaders, r.Header); d != "" { if d := cmp.Diff(tc.wantHeaders, r.Header); d != "" {

@ -66,6 +66,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar
tailscale.com/types/tkatype from tailscale.com/tailcfg+ tailscale.com/types/tkatype from tailscale.com/tailcfg+
tailscale.com/types/views from tailscale.com/net/tsaddr+ tailscale.com/types/views from tailscale.com/net/tsaddr+
tailscale.com/util/cmpx from tailscale.com/tailcfg+ tailscale.com/util/cmpx from tailscale.com/tailcfg+
tailscale.com/util/ctxkey from tailscale.com/tsweb+
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
tailscale.com/util/dnsname from tailscale.com/tailcfg tailscale.com/util/dnsname from tailscale.com/tailcfg
tailscale.com/util/lineread from tailscale.com/version/distro tailscale.com/util/lineread from tailscale.com/version/distro

@ -143,6 +143,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/util/cloudenv from tailscale.com/net/dnscache+ tailscale.com/util/cloudenv from tailscale.com/net/dnscache+
tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+ tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+
tailscale.com/util/cmpx from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/cmpx from tailscale.com/cmd/tailscale/cli+
tailscale.com/util/ctxkey from tailscale.com/types/logger
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+
tailscale.com/util/groupmember from tailscale.com/client/web tailscale.com/util/groupmember from tailscale.com/client/web
@ -267,7 +268,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
image/png from github.com/skip2/go-qrcode image/png from github.com/skip2/go-qrcode
io from bufio+ io from bufio+
io/fs from crypto/x509+ io/fs from crypto/x509+
io/ioutil from golang.org/x/sys/cpu+ io/ioutil from github.com/godbus/dbus/v5+
log from expvar+ log from expvar+
log/internal from log log/internal from log
maps from tailscale.com/types/views+ maps from tailscale.com/types/views+

@ -344,6 +344,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/util/cloudenv from tailscale.com/net/dns/resolver+ tailscale.com/util/cloudenv from tailscale.com/net/dns/resolver+
tailscale.com/util/cmpver from tailscale.com/net/dns+ tailscale.com/util/cmpver from tailscale.com/net/dns+
tailscale.com/util/cmpx from tailscale.com/derp/derphttp+ tailscale.com/util/cmpx from tailscale.com/derp/derphttp+
tailscale.com/util/ctxkey from tailscale.com/ipn/ipnlocal+
💣 tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+ 💣 tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics+ L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics+
tailscale.com/util/dnsname from tailscale.com/hostinfo+ tailscale.com/util/dnsname from tailscale.com/hostinfo+

@ -34,6 +34,7 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/lazy" "tailscale.com/types/lazy"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/ctxkey"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/version" "tailscale.com/version"
) )
@ -48,8 +49,7 @@ const (
// current etag of a resource. // current etag of a resource.
var ErrETagMismatch = errors.New("etag mismatch") var ErrETagMismatch = errors.New("etag mismatch")
// serveHTTPContextKey is the context.Value key for a *serveHTTPContext. var serveHTTPContextKey ctxkey.Key[*serveHTTPContext]
type serveHTTPContextKey struct{}
type serveHTTPContext struct { type serveHTTPContext struct {
SrcAddr netip.AddrPort SrcAddr netip.AddrPort
@ -433,7 +433,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
hs := &http.Server{ hs := &http.Server{
Handler: http.HandlerFunc(b.serveWebHandler), Handler: http.HandlerFunc(b.serveWebHandler),
BaseContext: func(_ net.Listener) context.Context { BaseContext: func(_ net.Listener) context.Context {
return context.WithValue(context.Background(), serveHTTPContextKey{}, &serveHTTPContext{ return serveHTTPContextKey.WithValue(context.Background(), &serveHTTPContext{
SrcAddr: srcAddr, SrcAddr: srcAddr,
DestPort: dport, DestPort: dport,
}) })
@ -500,11 +500,6 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
return nil return nil
} }
func getServeHTTPContext(r *http.Request) (c *serveHTTPContext, ok bool) {
c, ok = r.Context().Value(serveHTTPContextKey{}).(*serveHTTPContext)
return c, ok
}
func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) { func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) {
var z ipn.HTTPHandlerView // zero value var z ipn.HTTPHandlerView // zero value
@ -521,7 +516,7 @@ func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView,
hostname = r.TLS.ServerName hostname = r.TLS.ServerName
} }
sctx, ok := getServeHTTPContext(r) sctx, ok := serveHTTPContextKey.ValueOk(r.Context())
if !ok { if !ok {
b.logf("[unexpected] localbackend: no serveHTTPContext in request") b.logf("[unexpected] localbackend: no serveHTTPContext in request")
return z, "", false return z, "", false
@ -684,7 +679,7 @@ func addProxyForwardedHeaders(r *httputil.ProxyRequest) {
if r.In.TLS != nil { if r.In.TLS != nil {
r.Out.Header.Set("X-Forwarded-Proto", "https") r.Out.Header.Set("X-Forwarded-Proto", "https")
} }
if c, ok := getServeHTTPContext(r.Out); ok { if c, ok := serveHTTPContextKey.ValueOk(r.Out.Context()); ok {
r.Out.Header.Set("X-Forwarded-For", c.SrcAddr.Addr().String()) r.Out.Header.Set("X-Forwarded-For", c.SrcAddr.Addr().String())
} }
} }
@ -696,7 +691,7 @@ func (b *LocalBackend) addTailscaleIdentityHeaders(r *httputil.ProxyRequest) {
r.Out.Header.Del("Tailscale-User-Profile-Pic") r.Out.Header.Del("Tailscale-User-Profile-Pic")
r.Out.Header.Del("Tailscale-Headers-Info") r.Out.Header.Del("Tailscale-Headers-Info")
c, ok := getServeHTTPContext(r.Out) c, ok := serveHTTPContextKey.ValueOk(r.Out.Context())
if !ok { if !ok {
return return
} }

@ -158,7 +158,7 @@ func TestGetServeHandler(t *testing.T) {
TLS: &tls.ConnectionState{ServerName: serverName}, TLS: &tls.ConnectionState{ServerName: serverName},
} }
port := cmpx.Or(tt.port, 443) port := cmpx.Or(tt.port, 443)
req = req.WithContext(context.WithValue(req.Context(), serveHTTPContextKey{}, &serveHTTPContext{ req = req.WithContext(serveHTTPContextKey.WithValue(req.Context(), &serveHTTPContext{
DestPort: port, DestPort: port,
})) }))
@ -428,7 +428,7 @@ func TestServeHTTPProxy(t *testing.T) {
URL: &url.URL{Path: "/"}, URL: &url.URL{Path: "/"},
TLS: &tls.ConnectionState{ServerName: "example.ts.net"}, TLS: &tls.ConnectionState{ServerName: "example.ts.net"},
} }
req = req.WithContext(context.WithValue(req.Context(), serveHTTPContextKey{}, &serveHTTPContext{ req = req.WithContext(serveHTTPContextKey.WithValue(req.Context(), &serveHTTPContext{
DestPort: 443, DestPort: 443,
SrcAddr: netip.MustParseAddrPort(tt.srcIP + ":1234"), // random src port for tests SrcAddr: netip.MustParseAddrPort(tt.srcIP + ":1234"), // random src port for tests
})) }))

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"github.com/google/uuid" "github.com/google/uuid"
"tailscale.com/util/ctxkey"
) )
// RequestID is an opaque identifier for a HTTP request, used to correlate // RequestID is an opaque identifier for a HTTP request, used to correlate
@ -24,6 +25,9 @@ import (
// opaque string. The current implementation uses a UUID. // opaque string. The current implementation uses a UUID.
type RequestID string type RequestID string
// RequestIDKey stores and loads [RequestID] values within a [context.Context].
var RequestIDKey ctxkey.Key[RequestID]
// RequestIDHeader is a custom HTTP header that the WithRequestID middleware // RequestIDHeader is a custom HTTP header that the WithRequestID middleware
// uses to determine whether to re-use a given request ID from the client // uses to determine whether to re-use a given request ID from the client
// or generate a new one. // or generate a new one.
@ -42,22 +46,16 @@ func SetRequestID(h http.Handler) http.Handler {
// transitions if needed. // transitions if needed.
id = "REQ-1" + uuid.NewString() id = "REQ-1" + uuid.NewString()
} }
ctx := withRequestID(r.Context(), RequestID(id)) ctx := RequestIDKey.WithValue(r.Context(), RequestID(id))
r = r.WithContext(ctx) r = r.WithContext(ctx)
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
}) })
} }
type requestIDKey struct{}
// RequestIDFromContext retrieves the RequestID from context that can be set by // RequestIDFromContext retrieves the RequestID from context that can be set by
// the SetRequestID function. // the SetRequestID function.
//
// Deprecated: Use [RequestIDKey.Value] instead.
func RequestIDFromContext(ctx context.Context) RequestID { func RequestIDFromContext(ctx context.Context) RequestID {
val, _ := ctx.Value(requestIDKey{}).(RequestID) return RequestIDKey.Value(ctx)
return val
}
// withRequestID sets the given request id value in the given context.
func withRequestID(ctx context.Context, rid RequestID) context.Context {
return context.WithValue(ctx, requestIDKey{}, rid)
} }

@ -166,7 +166,7 @@ func TestStdHandler(t *testing.T) {
{ {
name: "handler returns 404 via HTTPError with request ID", name: "handler returns 404 via HTTPError with request ID",
rh: handlerErr(0, Error(404, "not found", testErr)), rh: handlerErr(0, Error(404, "not found", testErr)),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 404, wantCode: 404,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
@ -203,7 +203,7 @@ func TestStdHandler(t *testing.T) {
{ {
name: "handler returns 404 with request ID and nil child error", name: "handler returns 404 with request ID and nil child error",
rh: handlerErr(0, Error(404, "not found", nil)), rh: handlerErr(0, Error(404, "not found", nil)),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 404, wantCode: 404,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
@ -240,7 +240,7 @@ func TestStdHandler(t *testing.T) {
{ {
name: "handler returns user-visible error with request ID", name: "handler returns user-visible error with request ID",
rh: handlerErr(0, vizerror.New("visible error")), rh: handlerErr(0, vizerror.New("visible error")),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 500, wantCode: 500,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
@ -277,7 +277,7 @@ func TestStdHandler(t *testing.T) {
{ {
name: "handler returns user-visible error wrapped by private error with request ID", name: "handler returns user-visible error wrapped by private error with request ID",
rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))), rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 500, wantCode: 500,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
@ -314,7 +314,7 @@ func TestStdHandler(t *testing.T) {
{ {
name: "handler returns generic error with request ID", name: "handler returns generic error with request ID",
rh: handlerErr(0, testErr), rh: handlerErr(0, testErr),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 500, wantCode: 500,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
@ -350,7 +350,7 @@ func TestStdHandler(t *testing.T) {
{ {
name: "handler returns error after writing response with request ID", name: "handler returns error after writing response with request ID",
rh: handlerErr(200, testErr), rh: handlerErr(200, testErr),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 200, wantCode: 200,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
@ -446,7 +446,7 @@ func TestStdHandler(t *testing.T) {
{ {
name: "error handler gets run with request ID", name: "error handler gets run with request ID",
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/"), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/"),
wantCode: 200, wantCode: 200,
errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) { errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
requestID := RequestIDFromContext(r.Context()) requestID := RequestIDFromContext(r.Context())

@ -21,6 +21,7 @@ import (
"context" "context"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/util/ctxkey"
) )
// Logf is the basic Tailscale logger type: a printf-like func. // Logf is the basic Tailscale logger type: a printf-like func.
@ -28,13 +29,16 @@ import (
// Logf functions must be safe for concurrent use. // Logf functions must be safe for concurrent use.
type Logf func(format string, args ...any) type Logf func(format string, args ...any)
// LogfKey stores and loads [Logf] values within a [context.Context].
var LogfKey = ctxkey.New("", Logf(log.Printf))
// A Context is a context.Context that should contain a custom log function, obtainable from FromContext. // A Context is a context.Context that should contain a custom log function, obtainable from FromContext.
// If no log function is present, FromContext will return log.Printf. // If no log function is present, FromContext will return log.Printf.
// To construct a Context, use Add // To construct a Context, use Add
//
// Deprecated: Do not use.
type Context context.Context type Context context.Context
type logfKey struct{}
// jenc is a json.Encode + bytes.Buffer pair wired up to be reused in a pool. // jenc is a json.Encode + bytes.Buffer pair wired up to be reused in a pool.
type jenc struct { type jenc struct {
buf bytes.Buffer buf bytes.Buffer
@ -79,17 +83,17 @@ func (logf Logf) JSON(level int, recType string, v any) {
} }
// FromContext extracts a log function from ctx. // FromContext extracts a log function from ctx.
//
// Deprecated: Use [LogfKey.Value] instead.
func FromContext(ctx Context) Logf { func FromContext(ctx Context) Logf {
v := ctx.Value(logfKey{}) return LogfKey.Value(ctx)
if v == nil {
return log.Printf
}
return v.(Logf)
} }
// Ctx constructs a Context from ctx with fn as its custom log function. // Ctx constructs a Context from ctx with fn as its custom log function.
//
// Deprecated: Use [LogfKey.WithValue] instead.
func Ctx(ctx context.Context, fn Logf) Context { func Ctx(ctx context.Context, fn Logf) Context {
return context.WithValue(ctx, logfKey{}, fn) return LogfKey.WithValue(ctx, fn)
} }
// WithPrefix wraps f, prefixing each format with the provided prefix. // WithPrefix wraps f, prefixing each format with the provided prefix.

@ -6,13 +6,13 @@
// Example usage: // Example usage:
// //
// // Create a context key. // // Create a context key.
// var TimeoutKey = ctxkey.New("fsrv.Timeout", 5*time.Second) // var TimeoutKey = ctxkey.New("mapreduce.Timeout", 5*time.Second)
// //
// // Store a context value. // // Store a context value.
// ctx = fsrv.TimeoutKey.WithValue(ctx, 10*time.Second) // ctx = mapreduce.TimeoutKey.WithValue(ctx, 10*time.Second)
// //
// // Load a context value. // // Load a context value.
// timeout := fsrv.TimeoutKey.Value(ctx) // timeout := mapreduce.TimeoutKey.Value(ctx)
// ... // use timeout of type time.Duration // ... // use timeout of type time.Duration
// //
// This is inspired by https://go.dev/issue/49189. // This is inspired by https://go.dev/issue/49189.
@ -24,20 +24,23 @@ import (
"reflect" "reflect"
) )
// TODO(https://go.dev/issue/60088): Use reflect.TypeFor instead.
func reflectTypeFor[T any]() reflect.Type {
return reflect.TypeOf((*T)(nil)).Elem()
}
// Key is a generic key type associated with a specific value type. // Key is a generic key type associated with a specific value type.
// //
// A zero Key is valid where the Value type itself is used as the context key. // A zero Key is valid where the Value type itself is used as the context key.
// This pattern should only be used with locally declared Go types. // This pattern should only be used with locally declared Go types,
// The Value type must not be an interface type. // otherwise different packages risk producing key conflicts.
// //
// Example usage: // Example usage:
// //
// type peerInfo struct { ... } // peerInfo is an unexported type // type peerInfo struct { ... } // peerInfo is a locally declared type
// var peerInfoKey = ctxkey.Key[peerInfo] // var peerInfoKey ctxkey.Key[peerInfo]
// ctx = peerInfoKey.WithValue(ctx, info) // store a context value // ctx = peerInfoKey.WithValue(ctx, info) // store a context value
// info = peerInfoKey.Value(ctx) // load a context value // info = peerInfoKey.Value(ctx) // load a context value
//
// In general, any exported keys should be produced using [New].
type Key[Value any] struct { type Key[Value any] struct {
name *stringer[string] name *stringer[string]
defVal *Value defVal *Value
@ -49,6 +52,7 @@ type Key[Value any] struct {
// The provided name is an arbitrary name only used for human debugging. // The provided name is an arbitrary name only used for human debugging.
// As a convention, it is recommended that the name be the dot-delimited // As a convention, it is recommended that the name be the dot-delimited
// combination of the package name of the caller with the variable name. // combination of the package name of the caller with the variable name.
// If the name is not provided, then the name of the Value type is used.
// Every key is unique, even if provided the same name. // Every key is unique, even if provided the same name.
// //
// Example usage: // Example usage:
@ -56,32 +60,25 @@ type Key[Value any] struct {
// package mapreduce // package mapreduce
// var NumWorkersKey = ctxkey.New("mapreduce.NumWorkers", runtime.NumCPU()) // var NumWorkersKey = ctxkey.New("mapreduce.NumWorkers", runtime.NumCPU())
func New[Value any](name string, defaultValue Value) Key[Value] { func New[Value any](name string, defaultValue Value) Key[Value] {
// Allocate a new stringer to ensure that every invocation of New
// creates a universally unique context key even for the same name
// since newly allocated pointers are globally unique within a process.
key := Key[Value]{name: new(stringer[string])}
if name == "" { if name == "" {
var v Value name = reflectTypeFor[Value]().String()
name = reflect.TypeOf(v).String() // TODO(https://go.dev/issue/60088): Use reflect.TypeFor.
} }
var defVal *Value key.name.v = name
switch v := reflect.ValueOf(&defaultValue).Elem(); { if v := reflect.ValueOf(defaultValue); v.IsValid() && !v.IsZero() {
case v.Kind() == reflect.Interface: key.defVal = &defaultValue
panic(fmt.Sprintf("value type %v must not be an interface", v.Type()))
case !v.IsZero():
defVal = &defaultValue
} }
// Allocate a *stringer to ensure that every invocation of New return key
// creates a universally unique context key even for the same name.
return Key[Value]{name: &stringer[string]{name}, defVal: defVal}
} }
// contextKey returns the context key to use. // contextKey returns the context key to use.
func (key Key[Value]) contextKey() any { func (key Key[Value]) contextKey() any {
if key.name == nil { if key.name == nil {
// Use the reflect.Type of the Value (implies key not created by New). // Use the reflect.Type of the Value (implies key not created by New).
var v Value return reflectTypeFor[Value]()
t := reflect.TypeOf(v)
if t == nil {
panic(fmt.Sprintf("value type %v must not be an interface", reflect.TypeOf(&v).Elem()))
}
return t
} else { } else {
// Use the name pointer directly (implies key created by New). // Use the name pointer directly (implies key created by New).
return key.name return key.name
@ -122,8 +119,7 @@ func (key Key[Value]) Has(ctx context.Context) (ok bool) {
// String returns the name of the key. // String returns the name of the key.
func (key Key[Value]) String() string { func (key Key[Value]) String() string {
if key.name == nil { if key.name == nil {
var v Value return reflectTypeFor[Value]().String()
return reflect.TypeOf(v).String() // TODO(https://go.dev/issue/60088): Use reflect.TypeFor.
} }
return key.name.String() return key.name.String()
} }
@ -134,6 +130,11 @@ func (key Key[Value]) String() string {
// Note that the [context] package lacks a dependency on [reflect], // Note that the [context] package lacks a dependency on [reflect],
// so it cannot print arbitrary values. By implementing [fmt.Stringer], // so it cannot print arbitrary values. By implementing [fmt.Stringer],
// we functionally teach a context how to print itself. // we functionally teach a context how to print itself.
//
// Wrapping values within a struct has an added bonus that interface kinds
// are properly handled. Without wrapping, we would be unable to distinguish
// between a nil value that was explicitly set or not.
// However, the presence of a stringer indicates an explicit nil value.
type stringer[T any] struct{ v T } type stringer[T any] struct{ v T }
func (v stringer[T]) String() string { return fmt.Sprint(v.v) } func (v stringer[T]) String() string { return fmt.Sprint(v.v) }

@ -6,6 +6,7 @@ package ctxkey
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"regexp" "regexp"
"testing" "testing"
"time" "time"
@ -69,6 +70,27 @@ func TestKey(t *testing.T) {
c.Assert(k5 == k6, qt.Equals, true) c.Assert(k5 == k6, qt.Equals, true)
c.Assert(k6.Has(ctx), qt.Equals, true) c.Assert(k6.Has(ctx), qt.Equals, true)
ctx = k6.WithValue(ctx, "fizz") ctx = k6.WithValue(ctx, "fizz")
// Test interface value types.
var k7 Key[any]
c.Assert(k7.Has(ctx), qt.Equals, false)
ctx = k7.WithValue(ctx, "whatever")
c.Assert(k7.Value(ctx), qt.DeepEquals, "whatever")
ctx = k7.WithValue(ctx, []int{1, 2, 3})
c.Assert(k7.Value(ctx), qt.DeepEquals, []int{1, 2, 3})
ctx = k7.WithValue(ctx, nil)
c.Assert(k7.Has(ctx), qt.Equals, true)
c.Assert(k7.Value(ctx), qt.DeepEquals, nil)
k8 := New[error]("error", io.EOF)
c.Assert(k8.Has(ctx), qt.Equals, false)
c.Assert(k8.Value(ctx), qt.Equals, io.EOF)
ctx = k8.WithValue(ctx, nil)
c.Assert(k8.Value(ctx), qt.Equals, nil)
c.Assert(k8.Has(ctx), qt.Equals, true)
err := fmt.Errorf("read error: %w", io.ErrUnexpectedEOF)
ctx = k8.WithValue(ctx, err)
c.Assert(k8.Value(ctx), qt.Equals, err)
c.Assert(k8.Has(ctx), qt.Equals, true)
} }
func TestStringer(t *testing.T) { func TestStringer(t *testing.T) {

Loading…
Cancel
Save