tsweb: implement interceptor for error page presentation

Updates https://github.com/tailscale/corp/issues/5605

Signed-off-by: Tom DNetto <tom@tailscale.com>
pull/4825/head
Tom DNetto 2 years ago committed by Tom
parent b788a5ba7e
commit 32c6823cf5

@ -197,8 +197,16 @@ type HandlerOptions struct {
// codes for handled responses.
// The keys are HTTP numeric response codes e.g. 200, 404, ...
StatusCodeCountersFull *expvar.Map
// OnError is called if the handler returned a HTTPError. This
// is intended to be used to present pretty error pages if
// the user agent is determined to be a browser.
OnError ErrorHandlerFunc
}
// ErrorHandlerFunc is called to present a error response.
type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError)
// ReturnHandlerFunc is an adapter to allow the use of ordinary
// functions as ReturnHandlers. If f is a function with the
// appropriate signature, ReturnHandlerFunc(f) is a ReturnHandler that
@ -287,7 +295,11 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.opts.Logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr)
msg.Code = http.StatusInternalServerError
}
http.Error(lw, hErr.Msg, msg.Code)
if h.opts.OnError != nil {
h.opts.OnError(lw, r, hErr)
} else {
http.Error(lw, hErr.Msg, msg.Code)
}
case err != nil:
// Handler returned a generic error. Serve an internal server
// error, if necessary.

@ -76,11 +76,12 @@ func TestStdHandler(t *testing.T) {
// cancel()
tests := []struct {
name string
rh ReturnHandler
r *http.Request
wantCode int
wantLog AccessLogRecord
name string
rh ReturnHandler
r *http.Request
errHandler ErrorHandlerFunc
wantCode int
wantLog AccessLogRecord
}{
{
name: "handler returns 200",
@ -238,6 +239,26 @@ func TestStdHandler(t *testing.T) {
Code: 101,
},
},
{
name: "error handler gets run",
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
r: req(bgCtx, "http://example.com/"),
wantCode: 200,
errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
http.Error(w, e.Msg, 200)
},
wantLog: AccessLogRecord{
When: clock.Start,
Seconds: 1.0,
Proto: "HTTP/1.1",
TLS: false,
Host: "example.com",
Method: "GET",
Code: 404,
Err: "not found",
RequestURI: "/",
},
},
}
for _, test := range tests {
@ -253,7 +274,7 @@ func TestStdHandler(t *testing.T) {
clock.Reset()
rec := noopHijacker{httptest.NewRecorder(), false}
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now})
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
h.ServeHTTP(&rec, test.r)
res := rec.Result()
if res.StatusCode != test.wantCode {

Loading…
Cancel
Save