diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 2f3b1eae3..56680cc92 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -372,6 +372,34 @@ type ReturnHandlerFunc func(http.ResponseWriter, *http.Request) error // request to the underlying handler, if appropriate. type Middleware func(h http.Handler) http.Handler +// MiddlewareStack combines multiple middleware into a single middleware for +// decorating a [http.Handler]. The first middleware argument will be the first +// to process an incoming request, before passing the request onto subsequent +// middleware and eventually the wrapped handler. +// +// For example: +// +// MiddlewareStack(A, B)(h).ServeHTTP(w, r) +// +// calls in sequence: +// +// a.ServeHTTP(w, r) +// -> b.ServeHTTP(w, r) +// -> h.ServeHTTP(w, r) +// +// (where the lowercase handlers were generated by the uppercase middleware). +func MiddlewareStack(mw ...Middleware) Middleware { + if len(mw) == 1 { + return mw[0] + } + return func(h http.Handler) http.Handler { + for i := len(mw) - 1; i >= 0; i-- { + h = mw[i](h) + } + return h + } +} + // ServeHTTPReturn calls f(w, r). func (f ReturnHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error { return f(w, r) diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 2bf2b7341..7214ff3e6 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -14,6 +14,7 @@ import ( "net/http" "net/http/httptest" "net/http/httputil" + "net/textproto" "net/url" "strings" "testing" @@ -1246,3 +1247,40 @@ func TestBucket(t *testing.T) { }) } } + +func ExampleMiddlewareStack() { + // setHeader returns a middleware that sets header k = vs. + setHeader := func(k string, vs ...string) Middleware { + k = textproto.CanonicalMIMEHeaderKey(k) + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header()[k] = vs + h.ServeHTTP(w, r) + }) + } + } + + // h is a http.Handler which prints the A, B & C response headers, wrapped + // in a few middleware which set those headers. + var h http.Handler = MiddlewareStack( + setHeader("A", "mw1"), + MiddlewareStack( + setHeader("A", "mw2.1"), + setHeader("B", "mw2.2"), + setHeader("C", "mw2.3"), + setHeader("C", "mw2.4"), + ), + setHeader("B", "mw3"), + )(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Println("A", w.Header().Get("A")) + fmt.Println("B", w.Header().Get("B")) + fmt.Println("C", w.Header().Get("C")) + })) + + // Invoke the handler. + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("", "/", nil)) + // Output: + // A mw2.1 + // B mw3 + // C mw2.4 +}