tsweb: accept a function to call before request handling

To complement the existing `onCompletion` callback, which is called
after request handler.

Updates tailscale/corp#17075

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
pull/12619/head
Anton Tolchanov 3 months ago committed by Anton Tolchanov
parent 6e55d8f6a1
commit 787ead835f

@ -250,19 +250,25 @@ type HandlerOptions struct {
// for each bucket based on the contained parameters.
BucketedStats *BucketedStatsOptions
// OnStart is called inline before ServeHTTP is called. Optional.
OnStart OnStartFunc
// OnError is called if the handler returned a HTTPError. This
// is intended to be used to present pretty error pages if
// the user agent is determined to be a browser.
OnError ErrorHandlerFunc
// OnCompletion is called when ServeHTTP is finished and gets
// useful data that the implementor can use for metrics.
// OnCompletion is called inline when ServeHTTP is finished and gets
// useful data that the implementor can use for metrics. Optional.
OnCompletion OnCompletionFunc
}
// ErrorHandlerFunc is called to present a error response.
type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError)
// OnStartFunc is called before ServeHTTP is called.
type OnStartFunc func(*http.Request, AccessLogRecord)
// OnCompletionFunc is called when ServeHTTP is finished and gets
// useful data that the implementor can use for metrics.
type OnCompletionFunc func(*http.Request, AccessLogRecord)
@ -336,6 +342,10 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
if fn := h.opts.OnStart; fn != nil {
fn(r, msg)
}
lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf}
// In case the handler panics, we want to recover and continue logging the

@ -17,6 +17,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/tstest"
"tailscale.com/util/must"
"tailscale.com/util/vizerror"
@ -485,8 +486,15 @@ func TestStdHandler(t *testing.T) {
Step: time.Second,
})
var onStartRecord, onCompletionRecord AccessLogRecord
rec := noopHijacker{httptest.NewRecorder(), false}
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
h := StdHandler(test.rh, HandlerOptions{
Logf: logf,
Now: clock.Now,
OnError: test.errHandler,
OnStart: func(r *http.Request, alr AccessLogRecord) { onStartRecord = alr },
OnCompletion: func(r *http.Request, alr AccessLogRecord) { onCompletionRecord = alr },
})
h.ServeHTTP(&rec, test.r)
res := rec.Result()
if res.StatusCode != test.wantCode {
@ -502,6 +510,13 @@ func TestStdHandler(t *testing.T) {
}
return e.Error()
})
if diff := cmp.Diff(onStartRecord, test.wantLog, errTransform, cmpopts.IgnoreFields(
AccessLogRecord{}, "Time", "Seconds", "Code", "Err")); diff != "" {
t.Errorf("onStart callback returned unexpected request log (-got+want):\n%s", diff)
}
if diff := cmp.Diff(onCompletionRecord, test.wantLog, errTransform); diff != "" {
t.Errorf("onCompletion callback returned incorrect request log (-got+want):\n%s", diff)
}
if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" {
t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff)
}

Loading…
Cancel
Save