tsweb: Add MiddlewareStack func to apply lists of Middleware (#12907)

Fixes #12909

Signed-off-by: Paul Scott <paul@tailscale.com>
pull/12922/head
Paul Scott 3 months ago committed by GitHub
parent 43375c6efb
commit 855da47777
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -372,6 +372,34 @@ type ReturnHandlerFunc func(http.ResponseWriter, *http.Request) error
// request to the underlying handler, if appropriate. // request to the underlying handler, if appropriate.
type Middleware func(h http.Handler) http.Handler type Middleware func(h http.Handler) http.Handler
// MiddlewareStack combines multiple middleware into a single middleware for
// decorating a [http.Handler]. The first middleware argument will be the first
// to process an incoming request, before passing the request onto subsequent
// middleware and eventually the wrapped handler.
//
// For example:
//
// MiddlewareStack(A, B)(h).ServeHTTP(w, r)
//
// calls in sequence:
//
// a.ServeHTTP(w, r)
// -> b.ServeHTTP(w, r)
// -> h.ServeHTTP(w, r)
//
// (where the lowercase handlers were generated by the uppercase middleware).
func MiddlewareStack(mw ...Middleware) Middleware {
if len(mw) == 1 {
return mw[0]
}
return func(h http.Handler) http.Handler {
for i := len(mw) - 1; i >= 0; i-- {
h = mw[i](h)
}
return h
}
}
// ServeHTTPReturn calls f(w, r). // ServeHTTPReturn calls f(w, r).
func (f ReturnHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error { func (f ReturnHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
return f(w, r) return f(w, r)

@ -14,6 +14,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
"net/textproto"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
@ -1246,3 +1247,40 @@ func TestBucket(t *testing.T) {
}) })
} }
} }
func ExampleMiddlewareStack() {
// setHeader returns a middleware that sets header k = vs.
setHeader := func(k string, vs ...string) Middleware {
k = textproto.CanonicalMIMEHeaderKey(k)
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header()[k] = vs
h.ServeHTTP(w, r)
})
}
}
// h is a http.Handler which prints the A, B & C response headers, wrapped
// in a few middleware which set those headers.
var h http.Handler = MiddlewareStack(
setHeader("A", "mw1"),
MiddlewareStack(
setHeader("A", "mw2.1"),
setHeader("B", "mw2.2"),
setHeader("C", "mw2.3"),
setHeader("C", "mw2.4"),
),
setHeader("B", "mw3"),
)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println("A", w.Header().Get("A"))
fmt.Println("B", w.Header().Get("B"))
fmt.Println("C", w.Header().Get("C"))
}))
// Invoke the handler.
h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("", "/", nil))
// Output:
// A mw2.1
// B mw3
// C mw2.4
}

Loading…
Cancel
Save