tsweb: propagate RequestID via context and entire request

The recent addition of RequestID was only populated if the
HTTP Request had returned an error. This meant that the underlying
handler has no access to this request id and any logs it may have
emitted were impossible to correlate to that request id. Therefore,
this PR adds a middleware to generate request ids and pass them
through the request context. The tsweb.StdHandler automatically
populates this request id if the middleware is being used. Finally,
inner handlers can use the context to retrieve that same request id
and use it so that all logs and events can be correlated.

Updates #2549

Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
marwan/displayname
Marwan Sulaiman 12 months ago committed by Marwan Sulaiman
parent c27aa9e7ff
commit b819f66eb1

@ -18,6 +18,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
L github.com/google/nftables/expr from github.com/google/nftables+ L github.com/google/nftables/expr from github.com/google/nftables+
L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+
L github.com/google/nftables/xt from github.com/google/nftables/expr+ L github.com/google/nftables/xt from github.com/google/nftables/expr+
github.com/google/uuid from tailscale.com/tsweb
github.com/hdevalence/ed25519consensus from tailscale.com/tka github.com/hdevalence/ed25519consensus from tailscale.com/tka
L github.com/josharian/native from github.com/mdlayher/netlink+ L github.com/josharian/native from github.com/mdlayher/netlink+
L 💣 github.com/jsimonetti/rtnetlink from tailscale.com/net/interfaces+ L 💣 github.com/jsimonetti/rtnetlink from tailscale.com/net/interfaces+
@ -218,6 +219,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
crypto/tls from golang.org/x/crypto/acme+ crypto/tls from golang.org/x/crypto/acme+
crypto/x509 from crypto/tls+ crypto/x509 from crypto/tls+
crypto/x509/pkix from crypto/x509+ crypto/x509/pkix from crypto/x509+
database/sql/driver from github.com/google/uuid
embed from crypto/internal/nistec+ embed from crypto/internal/nistec+
encoding from encoding/json+ encoding from encoding/json+
encoding/asn1 from crypto/x509+ encoding/asn1 from crypto/x509+

@ -45,11 +45,11 @@ type AccessLogRecord struct {
Bytes int `json:"bytes,omitempty"` Bytes int `json:"bytes,omitempty"`
// Error encountered during request processing. // Error encountered during request processing.
Err string `json:"err,omitempty"` Err string `json:"err,omitempty"`
// RequestID is a unique ID for this request. When a request fails due to an // RequestID is a unique ID for this request. If the *http.Request context
// error, the ID is generated and displayed to the client immediately after // carries this value via SetRequestID, then it will be displayed to the
// the error text, as well as logged here. This makes it easier to correlate // client immediately after the error text, as well as logged here. This
// support requests with server logs. If a RequestID generator is not // makes it easier to correlate support requests with server logs. If a
// configured, RequestID will be empty. // RequestID generator is not configured, RequestID will be empty.
RequestID RequestID `json:"request_id,omitempty"` RequestID RequestID `json:"request_id,omitempty"`
} }

@ -0,0 +1,63 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tsweb
import (
"context"
"net/http"
"github.com/google/uuid"
)
// RequestID is an opaque identifier for a HTTP request, used to correlate
// user-visible errors with backend server logs. The RequestID is typically
// threaded through an HTTP Middleware (WithRequestID) and then can be extracted
// by HTTP Handlers to include in their logs.
//
// RequestID is an opaque identifier for a HTTP request, used to correlate
// user-visible errors with backend server logs. If present in the context, the
// RequestID will be printed alongside the message text and logged in the
// AccessLogRecord.
//
// A RequestID has the format "REQ-1{ID}", and the ID should be treated as an
// opaque string. The current implementation uses a UUID.
type RequestID string
// RequestIDHeader is a custom HTTP header that the WithRequestID middleware
// uses to determine whether to re-use a given request ID from the client
// or generate a new one.
const RequestIDHeader = "X-Tailscale-Request-Id"
// SetRequestID is an HTTP middleware that injects a RequestID in the
// *http.Request Context. The value of that request id is either retrieved from
// the RequestIDHeader or a randomly generated one if not exists. Inner
// handlers can retrieve this ID from the RequestIDFromContext function.
func SetRequestID(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get(RequestIDHeader)
if id == "" {
// REQ-1 indicates the version of the RequestID pattern. It is
// currently arbitrary but allows for forward compatible
// transitions if needed.
id = "REQ-1" + uuid.NewString()
}
ctx := withRequestID(r.Context(), RequestID(id))
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
})
}
type requestIDKey struct{}
// RequestIDFromContext retrieves the RequestID from context that can be set by
// the SetRequestID function.
func RequestIDFromContext(ctx context.Context) RequestID {
val, _ := ctx.Value(requestIDKey{}).(RequestID)
return val
}
// withRequestID sets the given request id value in the given context.
func withRequestID(ctx context.Context, rid RequestID) context.Context {
return context.WithValue(ctx, requestIDKey{}, rid)
}

@ -177,8 +177,7 @@ type ReturnHandler interface {
type HandlerOptions struct { type HandlerOptions struct {
QuietLoggingIfSuccessful bool // if set, do not log successfully handled HTTP requests (200 and 304 status codes) QuietLoggingIfSuccessful bool // if set, do not log successfully handled HTTP requests (200 and 304 status codes)
Logf logger.Logf Logf logger.Logf
Now func() time.Time // if nil, defaults to time.Now Now func() time.Time // if nil, defaults to time.Now
GenerateRequestID func(*http.Request) RequestID // if nil, no request IDs are generated
// If non-nil, StatusCodeCounters maintains counters // If non-nil, StatusCodeCounters maintains counters
// of status codes for handled responses. // of status codes for handled responses.
@ -204,6 +203,13 @@ type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError)
// calls f. // calls f.
type ReturnHandlerFunc func(http.ResponseWriter, *http.Request) error type ReturnHandlerFunc func(http.ResponseWriter, *http.Request) error
// A Middleware is a function that wraps an http.Handler to extend or modify
// its behaviour.
//
// The implementation of the wrapper is responsible for delegating its input
// request to the underlying handler, if appropriate.
type Middleware func(h http.Handler) http.Handler
// ServeHTTPReturn calls f(w, r). // ServeHTTPReturn calls f(w, r).
func (f ReturnHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error { func (f ReturnHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
return f(w, r) return f(w, r)
@ -240,6 +246,7 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
RequestURI: r.URL.RequestURI(), RequestURI: r.URL.RequestURI(),
UserAgent: r.UserAgent(), UserAgent: r.UserAgent(),
Referer: r.Referer(), Referer: r.Referer(),
RequestID: RequestIDFromContext(r.Context()),
} }
lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf} lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf}
@ -275,11 +282,6 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
msg.Code = 499 // nginx convention: Client Closed Request msg.Code = 499 // nginx convention: Client Closed Request
msg.Err = context.Canceled.Error() msg.Err = context.Canceled.Error()
case hErrOK: case hErrOK:
if hErr.RequestID == "" && h.opts.GenerateRequestID != nil {
hErr.RequestID = h.opts.GenerateRequestID(r)
}
msg.RequestID = hErr.RequestID
// Handler asked us to send an error. Do so, if we haven't // Handler asked us to send an error. Do so, if we haven't
// already sent a response. // already sent a response.
msg.Err = hErr.Msg msg.Err = hErr.Msg
@ -310,17 +312,15 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
lw.WriteHeader(msg.Code) lw.WriteHeader(msg.Code)
fmt.Fprintln(lw, hErr.Msg) fmt.Fprintln(lw, hErr.Msg)
if hErr.RequestID != "" { if msg.RequestID != "" {
fmt.Fprintln(lw, hErr.RequestID) fmt.Fprintln(lw, msg.RequestID)
} }
} }
case err != nil: case err != nil:
const internalServerError = "internal server error" const internalServerError = "internal server error"
errorMessage := internalServerError errorMessage := internalServerError
if h.opts.GenerateRequestID != nil { if msg.RequestID != "" {
msg.RequestID = h.opts.GenerateRequestID(r) errorMessage += "\n" + string(msg.RequestID)
errorMessage = errorMessage + "\n" + string(msg.RequestID)
} }
// Handler returned a generic error. Serve an internal server // Handler returned a generic error. Serve an internal server
// error, if necessary. // error, if necessary.
@ -422,45 +422,18 @@ func (l loggingResponseWriter) Flush() {
f.Flush() f.Flush()
} }
// RequestID is an opaque identifier for a HTTP request, used to correlate
// user-visible errors with backend server logs. If present in a HTTPError, the
// RequestID will be printed alongside the message text and logged in the
// AccessLogRecord. If an HTTPError has no RequestID (or a non-HTTPError error
// is returned), but the StdHandler has a RequestID generator function, then a
// RequestID will be generated before responding to the client and logging the
// error.
//
// In the event that there is no ErrorHandlerFunc and a non-HTTPError is
// returned to a StdHandler, the response body will be formatted like
// "internal server error\n{RequestID}\n".
//
// There is no particular format required for a RequestID, but ideally it should
// be obvious to an end-user that it is something to record for support
// purposes. One possible example for a RequestID format is:
// REQ-{server identifier}-{timestamp}-{random hex string}.
type RequestID string
// HTTPError is an error with embedded HTTP response information. // HTTPError is an error with embedded HTTP response information.
// //
// It is the error type to be (optionally) used by Handler.ServeHTTPReturn. // It is the error type to be (optionally) used by Handler.ServeHTTPReturn.
type HTTPError struct { type HTTPError struct {
Code int // HTTP response code to send to client; 0 means 500 Code int // HTTP response code to send to client; 0 means 500
Msg string // Response body to send to client Msg string // Response body to send to client
Err error // Detailed error to log on the server Err error // Detailed error to log on the server
RequestID RequestID // Optional identifier to connect client-visible errors with server logs Header http.Header // Optional set of HTTP headers to set in the response
Header http.Header // Optional set of HTTP headers to set in the response
} }
// Error implements the error interface. // Error implements the error interface.
func (e HTTPError) Error() string { func (e HTTPError) Error() string { return fmt.Sprintf("httperror{%d, %q, %v}", e.Code, e.Msg, e.Err) }
if e.RequestID != "" {
return fmt.Sprintf("httperror{%d, %q, %v, RequestID=%q}", e.Code, e.Msg, e.Err, e.RequestID)
} else {
// Backwards compatibility
return fmt.Sprintf("httperror{%d, %q, %v}", e.Code, e.Msg, e.Err)
}
}
func (e HTTPError) Unwrap() error { return e.Err } func (e HTTPError) Unwrap() error { return e.Err }
// Error returns an HTTPError containing the given information. // Error returns an HTTPError containing the given information.
@ -502,8 +475,8 @@ func BrowserHeaderHandler(h http.Handler) http.Handler {
// BrowserHeaderHandlerFunc wraps the provided http.HandlerFunc with a call to // BrowserHeaderHandlerFunc wraps the provided http.HandlerFunc with a call to
// AddBrowserHeaders. // AddBrowserHeaders.
func BrowserHeaderHandlerFunc(h http.HandlerFunc) http.HandlerFunc { func BrowserHeaderHandlerFunc(h http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
AddBrowserHeaders(w) AddBrowserHeaders(w)
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
}) }
} }

@ -67,20 +67,17 @@ func TestStdHandler(t *testing.T) {
bgCtx = context.Background() bgCtx = context.Background()
// canceledCtx, cancel = context.WithCancel(bgCtx) // canceledCtx, cancel = context.WithCancel(bgCtx)
startTime = time.Unix(1687870000, 1234) startTime = time.Unix(1687870000, 1234)
setExampleRequestID = func(_ *http.Request) RequestID { return exampleRequestID }
) )
// cancel() // cancel()
tests := []struct { tests := []struct {
name string name string
rh ReturnHandler rh ReturnHandler
r *http.Request r *http.Request
errHandler ErrorHandlerFunc errHandler ErrorHandlerFunc
generateRequestID func(*http.Request) RequestID wantCode int
wantCode int wantLog AccessLogRecord
wantLog AccessLogRecord wantBody string
wantBody string
}{ }{
{ {
name: "handler returns 200", name: "handler returns 200",
@ -100,11 +97,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns 200 with request ID", name: "handler returns 200 with request ID",
rh: handlerCode(200), rh: handlerCode(200),
r: req(bgCtx, "http://example.com/"), r: req(bgCtx, "http://example.com/"),
generateRequestID: setExampleRequestID, wantCode: 200,
wantCode: 200,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -134,11 +130,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns 404 with request ID", name: "handler returns 404 with request ID",
rh: handlerCode(404), rh: handlerCode(404),
r: req(bgCtx, "http://example.com/foo"), r: req(bgCtx, "http://example.com/foo"),
generateRequestID: setExampleRequestID, wantCode: 404,
wantCode: 404,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -169,11 +164,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns 404 via HTTPError with request ID", name: "handler returns 404 via HTTPError with request ID",
rh: handlerErr(0, Error(404, "not found", testErr)), rh: handlerErr(0, Error(404, "not found", testErr)),
r: req(bgCtx, "http://example.com/foo"), r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
generateRequestID: setExampleRequestID, wantCode: 404,
wantCode: 404,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -207,11 +201,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns 404 with request ID and nil child error", name: "handler returns 404 with request ID and nil child error",
rh: handlerErr(0, Error(404, "not found", nil)), rh: handlerErr(0, Error(404, "not found", nil)),
r: req(bgCtx, "http://example.com/foo"), r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
generateRequestID: setExampleRequestID, wantCode: 404,
wantCode: 404,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -245,11 +238,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns user-visible error with request ID", name: "handler returns user-visible error with request ID",
rh: handlerErr(0, vizerror.New("visible error")), rh: handlerErr(0, vizerror.New("visible error")),
r: req(bgCtx, "http://example.com/foo"), r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
generateRequestID: setExampleRequestID, wantCode: 500,
wantCode: 500,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -283,11 +275,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns user-visible error wrapped by private error with request ID", name: "handler returns user-visible error wrapped by private error with request ID",
rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))), rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
r: req(bgCtx, "http://example.com/foo"), r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
generateRequestID: setExampleRequestID, wantCode: 500,
wantCode: 500,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -321,11 +312,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns generic error with request ID", name: "handler returns generic error with request ID",
rh: handlerErr(0, testErr), rh: handlerErr(0, testErr),
r: req(bgCtx, "http://example.com/foo"), r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
generateRequestID: setExampleRequestID, wantCode: 500,
wantCode: 500,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -358,11 +348,10 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "handler returns error after writing response with request ID", name: "handler returns error after writing response with request ID",
rh: handlerErr(200, testErr), rh: handlerErr(200, testErr),
r: req(bgCtx, "http://example.com/foo"), r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
generateRequestID: setExampleRequestID, wantCode: 200,
wantCode: 200,
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
Seconds: 1.0, Seconds: 1.0,
@ -455,13 +444,13 @@ func TestStdHandler(t *testing.T) {
}, },
{ {
name: "error handler gets run with request ID", name: "error handler gets run with request ID",
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
r: req(bgCtx, "http://example.com/"), r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/"),
generateRequestID: setExampleRequestID, wantCode: 200,
wantCode: 200,
errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) { errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, e.RequestID), 200) requestID := RequestIDFromContext(r.Context())
http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, requestID), 200)
}, },
wantLog: AccessLogRecord{ wantLog: AccessLogRecord{
When: startTime, When: startTime,
@ -477,37 +466,6 @@ func TestStdHandler(t *testing.T) {
}, },
wantBody: "not found with request ID " + exampleRequestID + "\n", wantBody: "not found with request ID " + exampleRequestID + "\n",
}, },
{
name: "request ID can use information from request",
rh: handlerErr(0, Error(400, "bad request", nil)),
r: func() *http.Request {
r := req(bgCtx, "http://example.com/")
r.AddCookie(&http.Cookie{Name: "want_request_id", Value: "asdf1234"})
return r
}(),
generateRequestID: func(r *http.Request) RequestID {
c, _ := r.Cookie("want_request_id")
if c == nil {
return ""
}
return RequestID(c.Value)
},
wantCode: 400,
wantLog: AccessLogRecord{
When: startTime,
Seconds: 1.0,
Proto: "HTTP/1.1",
TLS: false,
Host: "example.com",
RequestURI: "/",
Method: "GET",
Code: 400,
Err: "bad request",
RequestID: "asdf1234",
},
wantBody: "bad request\nasdf1234\n",
},
} }
for _, test := range tests { for _, test := range tests {
@ -526,7 +484,7 @@ func TestStdHandler(t *testing.T) {
}) })
rec := noopHijacker{httptest.NewRecorder(), false} rec := noopHijacker{httptest.NewRecorder(), false}
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, GenerateRequestID: test.generateRequestID, OnError: test.errHandler}) h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
h.ServeHTTP(&rec, test.r) h.ServeHTTP(&rec, test.r)
res := rec.Result() res := rec.Result()
if res.StatusCode != test.wantCode { if res.StatusCode != test.wantCode {

Loading…
Cancel
Save