From 32c6823cf5a2eb56823ac07e054430cb157d51fd Mon Sep 17 00:00:00 2001 From: Tom DNetto Date: Wed, 8 Jun 2022 13:05:15 -0700 Subject: [PATCH] tsweb: implement interceptor for error page presentation Updates https://github.com/tailscale/corp/issues/5605 Signed-off-by: Tom DNetto --- tsweb/tsweb.go | 14 +++++++++++++- tsweb/tsweb_test.go | 33 +++++++++++++++++++++++++++------ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index e6fbffbf5..278b9d111 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -197,8 +197,16 @@ type HandlerOptions struct { // codes for handled responses. // The keys are HTTP numeric response codes e.g. 200, 404, ... StatusCodeCountersFull *expvar.Map + + // OnError is called if the handler returned a HTTPError. This + // is intended to be used to present pretty error pages if + // the user agent is determined to be a browser. + OnError ErrorHandlerFunc } +// ErrorHandlerFunc is called to present a error response. +type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError) + // ReturnHandlerFunc is an adapter to allow the use of ordinary // functions as ReturnHandlers. If f is a function with the // appropriate signature, ReturnHandlerFunc(f) is a ReturnHandler that @@ -287,7 +295,11 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.opts.Logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr) msg.Code = http.StatusInternalServerError } - http.Error(lw, hErr.Msg, msg.Code) + if h.opts.OnError != nil { + h.opts.OnError(lw, r, hErr) + } else { + http.Error(lw, hErr.Msg, msg.Code) + } case err != nil: // Handler returned a generic error. Serve an internal server // error, if necessary. diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 8df574254..8aef7073c 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -76,11 +76,12 @@ func TestStdHandler(t *testing.T) { // cancel() tests := []struct { - name string - rh ReturnHandler - r *http.Request - wantCode int - wantLog AccessLogRecord + name string + rh ReturnHandler + r *http.Request + errHandler ErrorHandlerFunc + wantCode int + wantLog AccessLogRecord }{ { name: "handler returns 200", @@ -238,6 +239,26 @@ func TestStdHandler(t *testing.T) { Code: 101, }, }, + { + name: "error handler gets run", + rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler + r: req(bgCtx, "http://example.com/"), + wantCode: 200, + errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) { + http.Error(w, e.Msg, 200) + }, + wantLog: AccessLogRecord{ + When: clock.Start, + Seconds: 1.0, + Proto: "HTTP/1.1", + TLS: false, + Host: "example.com", + Method: "GET", + Code: 404, + Err: "not found", + RequestURI: "/", + }, + }, } for _, test := range tests { @@ -253,7 +274,7 @@ func TestStdHandler(t *testing.T) { clock.Reset() rec := noopHijacker{httptest.NewRecorder(), false} - h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now}) + h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler}) h.ServeHTTP(&rec, test.r) res := rec.Result() if res.StatusCode != test.wantCode {