tsweb: log all cancellations as 499s (#12894)

Updates #12141

Signed-off-by: Paul Scott <paul@tailscale.com>
pull/12907/head
Paul Scott 4 months ago committed by GitHub
parent 57856fc0d5
commit ba7f2d129e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -338,7 +338,7 @@ func (opts ErrorOptions) withDefaults() ErrorOptions {
opts.Logf = logger.Discard opts.Logf = logger.Discard
} }
if opts.OnError == nil { if opts.OnError == nil {
opts.OnError = writeHTTPError opts.OnError = WriteHTTPError
} }
return opts return opts
} }
@ -405,7 +405,7 @@ func ErrorHandler(h ReturnHandler, opts ErrorOptions) http.Handler {
// errCallback is added to logHandler's request context so that errorHandler can // errCallback is added to logHandler's request context so that errorHandler can
// pass errors back up the stack to logHandler. // pass errors back up the stack to logHandler.
var errCallback = ctxkey.New[func(string)]("tailscale.com/tsweb.errCallback", nil) var errCallback = ctxkey.New[func(HTTPError)]("tailscale.com/tsweb.errCallback", nil)
// logHandler is a http.Handler which logs the HTTP request. // logHandler is a http.Handler which logs the HTTP request.
// It injects an errCallback for errorHandler to augment the log message with // It injects an errCallback for errorHandler to augment the log message with
@ -471,9 +471,25 @@ func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// Let errorHandler tell us what error it wrote to the client. // Let errorHandler tell us what error it wrote to the client.
r = r.WithContext(errCallback.WithValue(ctx, func(e string) { r = r.WithContext(errCallback.WithValue(ctx, func(e HTTPError) {
if msg.Err == "" { // Keep the deepest error.
msg.Err = e // Keep the first error. if msg.Err != "" {
return
}
// Log the error.
if e.Msg != "" && e.Err != nil {
msg.Err = e.Msg + ": " + e.Err.Error()
} else if e.Err != nil {
msg.Err = e.Err.Error()
} else if e.Msg != "" {
msg.Err = e.Msg
}
// We log the code from the loggingResponseWriter, except for
// cancellation where we override with 499.
if reqCancelled(r, e.Err) {
msg.Code = 499
} }
})) }))
@ -502,23 +518,29 @@ func (h logHandler) logRequest(r *http.Request, lw *loggingResponseWriter, msg A
msg.Bytes = lw.bytes msg.Bytes = lw.bytes
msg.Seconds = h.opts.Now().Sub(msg.Time).Seconds() msg.Seconds = h.opts.Now().Sub(msg.Time).Seconds()
switch { switch {
case msg.Code != 0:
// Keep explicit codes from a few particular errors.
case lw.hijacked: case lw.hijacked:
// Connection no longer belongs to us, just log that we // Connection no longer belongs to us, just log that we
// switched protocols away from HTTP. // switched protocols away from HTTP.
msg.Code = http.StatusSwitchingProtocols msg.Code = http.StatusSwitchingProtocols
case lw.code == 0: case lw.code == 0:
if r.Context().Err() != nil { // If the handler didn't write and didn't send a header, that still means 200.
// We didn't write a response before the client disconnected. // (See https://play.golang.org/p/4P7nx_Tap7p)
msg.Code = 499 msg.Code = 200
} else {
// If the handler didn't write and didn't send a header, that still means 200.
// (See https://play.golang.org/p/4P7nx_Tap7p)
msg.Code = 200
}
default: default:
msg.Code = lw.code msg.Code = lw.code
} }
// Keep track of the original response code when we've overridden it.
if lw.code != 0 && msg.Code != lw.code {
if msg.Err == "" {
msg.Err = fmt.Sprintf("(original code %d)", lw.code)
} else {
msg.Err = fmt.Sprintf("%s (original code %d)", msg.Err, lw.code)
}
}
if !h.opts.QuietLoggingIfSuccessful || (msg.Code != http.StatusOK && msg.Code != http.StatusNotModified) { if !h.opts.QuietLoggingIfSuccessful || (msg.Code != http.StatusOK && msg.Code != http.StatusNotModified) {
h.opts.Logf("%s", msg) h.opts.Logf("%s", msg)
} }
@ -564,6 +586,7 @@ var responseCodeCache sync.Map
// response code that gets sent, if any. // response code that gets sent, if any.
type loggingResponseWriter struct { type loggingResponseWriter struct {
http.ResponseWriter http.ResponseWriter
ctx context.Context
code int code int
bytes int bytes int
hijacked bool hijacked bool
@ -582,6 +605,7 @@ func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Reque
} }
return &loggingResponseWriter{ return &loggingResponseWriter{
ResponseWriter: w, ResponseWriter: w,
ctx: r.Context(),
logf: logf, logf: logf,
} }
} }
@ -592,7 +616,9 @@ func (l *loggingResponseWriter) WriteHeader(statusCode int) {
l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode) l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode)
return return
} }
l.code = statusCode if l.ctx.Err() == nil {
l.code = statusCode
}
l.ResponseWriter.WriteHeader(statusCode) l.ResponseWriter.WriteHeader(statusCode)
} }
@ -682,8 +708,13 @@ func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *lo
} }
} else if v, ok := vizerror.As(err); ok { } else if v, ok := vizerror.As(err); ok {
hErr = Error(http.StatusInternalServerError, v.Error(), nil) hErr = Error(http.StatusInternalServerError, v.Error(), nil)
} else if errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) { } else if reqCancelled(r, err) {
hErr = Error(499, "", err) // Nginx convention // 499 is the Nginx convention meaning "Client Closed Connection".
if errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) {
hErr = Error(499, "", err)
} else {
hErr = Error(499, "", fmt.Errorf("%w: %w", context.Canceled, err))
}
} else { } else {
// Omit the friendly message so HTTP logs show the bare error that was // Omit the friendly message so HTTP logs show the bare error that was
// returned and we know it's not a HTTPError. // returned and we know it's not a HTTPError.
@ -692,13 +723,7 @@ func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *lo
// Tell the logger what error we wrote back to the client. // Tell the logger what error we wrote back to the client.
if pb := errCallback.Value(r.Context()); pb != nil { if pb := errCallback.Value(r.Context()); pb != nil {
if hErr.Msg != "" && hErr.Err != nil { pb(hErr)
pb(hErr.Msg + ": " + hErr.Err.Error())
} else if hErr.Err != nil {
pb(hErr.Err.Error())
} else if hErr.Msg != "" {
pb(hErr.Msg)
}
logged = true logged = true
} }
@ -775,21 +800,32 @@ func (e *panicError) Unwrap() error {
return err return err
} }
// writeHTTPError is the default error response formatter. // reqCancelled returns true if err is http.ErrAbortHandler or r.Context.Err()
func writeHTTPError(w http.ResponseWriter, r *http.Request, hErr HTTPError) { // is context.Canceled.
func reqCancelled(r *http.Request, err error) bool {
return errors.Is(err, http.ErrAbortHandler) || r.Context().Err() == context.Canceled
}
// WriteHTTPError is the default error response formatter.
func WriteHTTPError(w http.ResponseWriter, r *http.Request, e HTTPError) {
// Don't write a response if we've hit a cancellation/abort.
if r.Context().Err() != nil || errors.Is(e.Err, http.ErrAbortHandler) {
return
}
// Default headers set by http.Error. // Default headers set by http.Error.
h := w.Header() h := w.Header()
h.Set("Content-Type", "text/plain; charset=utf-8") h.Set("Content-Type", "text/plain; charset=utf-8")
h.Set("X-Content-Type-Options", "nosniff") h.Set("X-Content-Type-Options", "nosniff")
// Custom headers from the error. // Custom headers from the error.
for k, vs := range hErr.Header { for k, vs := range e.Header {
h[k] = vs h[k] = vs
} }
// Write the msg back to the user. // Write the msg back to the user.
w.WriteHeader(hErr.Code) w.WriteHeader(e.Code)
fmt.Fprint(w, hErr.Msg) fmt.Fprint(w, e.Msg)
// If it's a plaintext message, add line breaks and RequestID. // If it's a plaintext message, add line breaks and RequestID.
if strings.HasPrefix(h.Get("Content-Type"), "text/plain") { if strings.HasPrefix(h.Get("Content-Type"), "text/plain") {

@ -13,6 +13,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
@ -493,6 +494,25 @@ func TestStdHandler(t *testing.T) {
wantBody: "not found with request ID " + exampleRequestID + "\n", wantBody: "not found with request ID " + exampleRequestID + "\n",
}, },
{
name: "inner_cancelled",
rh: handlerErr(0, context.Canceled), // return canceled error, but the request was not cancelled
r: req(bgCtx, "http://example.com/"),
wantCode: 500,
wantLog: AccessLogRecord{
Time: startTime,
Seconds: 1.0,
Proto: "HTTP/1.1",
TLS: false,
Host: "example.com",
Method: "GET",
Code: 500,
Err: "context canceled",
RequestURI: "/",
},
wantBody: "Internal Server Error\n",
},
{ {
name: "nested", name: "nested",
rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
@ -705,6 +725,7 @@ func TestStdHandler_Canceled(t *testing.T) {
close(handlerOpen) close(handlerOpen)
ctx := r.Context() ctx := r.Context()
<-ctx.Done() <-ctx.Done()
w.WriteHeader(200) // Ignored.
return ctx.Err() return ctx.Err()
}), }),
HandlerOptions{ HandlerOptions{
@ -718,6 +739,8 @@ func TestStdHandler_Canceled(t *testing.T) {
}, },
}, },
) )
s := httptest.NewServer(h)
t.Cleanup(s.Close)
// Create a context which gets canceled after the handler starts processing // Create a context which gets canceled after the handler starts processing
// the request. // the request.
@ -727,9 +750,80 @@ func TestStdHandler_Canceled(t *testing.T) {
cancelReq() cancelReq()
}() }()
// Send a request to our server.
req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil)
if err != nil {
t.Fatalf("making request: %s", err)
}
res, err := http.DefaultClient.Do(req)
if !errors.Is(err, context.Canceled) {
t.Errorf("got error %v, want context.Canceled", err)
}
if res != nil {
t.Errorf("got response %#v, want nil", res)
}
// Check that we got the expected log record.
got := <-r
got.Seconds = 0
got.RemoteAddr = ""
got.Host = ""
got.UserAgent = ""
want := AccessLogRecord{
Time: now,
Code: 499,
Method: "GET",
Err: "context canceled",
Proto: "HTTP/1.1",
RequestURI: "/",
}
if d := cmp.Diff(want, got); d != "" {
t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d)
}
// Check that we rendered no response to the client after
// logHandler.OnCompletion has been called.
if e != nil {
t.Errorf("got OnError callback with %#v, want no callback", e)
}
}
func TestStdHandler_CanceledAfterHeader(t *testing.T) {
now := time.Now()
r := make(chan AccessLogRecord)
var e *HTTPError
handlerOpen := make(chan struct{})
h := StdHandler(
ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusNoContent)
close(handlerOpen)
ctx := r.Context()
<-ctx.Done()
return ctx.Err()
}),
HandlerOptions{
Logf: t.Logf,
Now: func() time.Time { return now },
OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
e = &h
},
OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
r <- alr
},
},
)
s := httptest.NewServer(h) s := httptest.NewServer(h)
t.Cleanup(s.Close) t.Cleanup(s.Close)
// Create a context which gets canceled after the handler starts processing
// the request.
ctx, cancelReq := context.WithCancel(context.Background())
go func() {
<-handlerOpen
cancelReq()
}()
// Send a request to our server. // Send a request to our server.
req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil) req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil)
if err != nil { if err != nil {
@ -753,7 +847,7 @@ func TestStdHandler_Canceled(t *testing.T) {
Time: now, Time: now,
Code: 499, Code: 499,
Method: "GET", Method: "GET",
Err: "context canceled", Err: "context canceled (original code 204)",
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
RequestURI: "/", RequestURI: "/",
} }
@ -766,7 +860,98 @@ func TestStdHandler_Canceled(t *testing.T) {
if e != nil { if e != nil {
t.Errorf("got OnError callback with %#v, want no callback", e) t.Errorf("got OnError callback with %#v, want no callback", e)
} }
}
func TestStdHandler_ConnectionClosedDuringBody(t *testing.T) {
now := time.Now()
// Start a HTTP server that returns 1MB of data.
// We next put a reverse-proxy in front of this server.
rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for range 1024 {
w.Write(make([]byte, 1024))
}
}))
defer rs.Close()
r := make(chan AccessLogRecord)
var e *HTTPError
responseStarted := make(chan struct{})
// Create another server which proxies our 1MB server.
// The [httputil.ReverseProxy] will panic with [http.ErrAbortHandler] when
// it fails to copy the response to the client.
h := StdHandler(
ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
(&httputil.ReverseProxy{
Director: func(r *http.Request) {
r.URL = must.Get(url.Parse(rs.URL))
},
ModifyResponse: func(r *http.Response) error {
close(responseStarted)
return nil
},
}).ServeHTTP(w, r)
return nil
}),
HandlerOptions{
Logf: t.Logf,
Now: func() time.Time { return now },
OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
e = &h
},
OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
r <- alr
},
},
)
s := httptest.NewServer(h)
t.Cleanup(s.Close)
// Create a context which gets canceled after the handler starts processing
// the request.
ctx, cancelReq := context.WithCancel(context.Background())
go func() {
<-responseStarted
cancelReq()
}()
// Send a request to our server.
req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil)
if err != nil {
t.Fatalf("making request: %s", err)
}
res, err := http.DefaultClient.Do(req)
if !errors.Is(err, context.Canceled) {
t.Errorf("got error %v, want context.Canceled", err)
}
if res != nil {
t.Errorf("got response %#v, want nil", res)
}
// Check that we got the expected log record.
got := <-r
got.Seconds = 0
got.RemoteAddr = ""
got.Host = ""
got.UserAgent = ""
want := AccessLogRecord{
Time: now,
Code: 499,
Method: "GET",
Err: "net/http: abort Handler (original code 200)",
Proto: "HTTP/1.1",
RequestURI: "/",
}
if d := cmp.Diff(want, got, cmpopts.IgnoreFields(AccessLogRecord{}, "Bytes")); d != "" {
t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d)
}
// Check that we rendered no response to the client after
// logHandler.OnCompletion has been called.
if e != nil {
t.Errorf("got OnError callback with %#v, want no callback", e)
}
} }
func TestStdHandler_OnErrorPanic(t *testing.T) { func TestStdHandler_OnErrorPanic(t *testing.T) {

Loading…
Cancel
Save