tsweb: make JSONHandlerFunc implement ReturnHandler, not http.Handler

This way something is capable of logging errors on the server.

Fixes #766

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
crawshaw/jsonhandler
David Crawshaw 4 years ago committed by David Crawshaw
parent 3aeb2e204c
commit dea3ef0597

@ -6,6 +6,7 @@ package tsweb
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
) )
@ -15,23 +16,23 @@ type response struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
// TODO: Header // JSONHandlerFunc is an HTTP ReturnHandler that writes JSON responses to the client.
//
// JSONHandlerFunc only take *http.Request as argument to avoid any misuse of http.ResponseWriter. // Return a HTTPError to show an error message, otherwise JSONHandlerFunc will
// The function's results must be (status int, data interface{}, err error). // only report "internal server error" to the user.
// Return a HTTPError to show an error message, otherwise JSONHandler will only show "internal server error".
type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error) type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error)
// ServeHTTP calls the JSONHandlerFunc and automatically marshals http responses. // ServeHTTPReturn implements the ReturnHandler interface.
// //
// Use the following code to unmarshal the request body // Use the following code to unmarshal the request body
//
// body := new(DataType) // body := new(DataType)
// if err := json.NewDecoder(r.Body).Decode(body); err != nil { // if err := json.NewDecoder(r.Body).Decode(body); err != nil {
// return http.StatusBadRequest, nil, err // return http.StatusBadRequest, nil, err
// } // }
// //
// Check jsonhandler_text.go for examples // See jsonhandler_text.go for examples.
func (fn JSONHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
var resp *response var resp *response
status, data, err := fn(r) status, data, err := fn(r)
@ -53,6 +54,10 @@ func (fn JSONHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Error: werr.Msg, Error: werr.Msg,
Data: data, Data: data,
} }
// Unwrap the HTTPError here because we are communicating with
// the client in this handler. We don't want the wrapping
// ReturnHandler to do it too.
err = werr.Err
} else { } else {
resp = &response{ resp = &response{
Status: "error", Status: "error",
@ -61,13 +66,17 @@ func (fn JSONHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
b, err := json.Marshal(resp) b, jerr := json.Marshal(resp)
if err != nil { if jerr != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"status":"error","error":"json marshal error"}`)) w.Write([]byte(`{"status":"error","error":"json marshal error"}`))
return if err != nil {
return fmt.Errorf("%w, and then we could not respond: %v", err, jerr)
}
return jerr
} }
w.WriteHeader(status) w.WriteHeader(status)
w.Write(b) w.Write(b)
return err
} }

@ -61,7 +61,7 @@ func TestNewJSONHandler(t *testing.T) {
t.Run("200 simple", func(t *testing.T) { t.Run("200 simple", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
h21.ServeHTTP(w, r) h21.ServeHTTPReturn(w, r)
checkStatus(w, "success", http.StatusOK) checkStatus(w, "success", http.StatusOK)
}) })
@ -72,7 +72,7 @@ func TestNewJSONHandler(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
h.ServeHTTP(w, r) h.ServeHTTPReturn(w, r)
checkStatus(w, "error", http.StatusForbidden) checkStatus(w, "error", http.StatusForbidden)
}) })
@ -83,7 +83,7 @@ func TestNewJSONHandler(t *testing.T) {
t.Run("200 get data", func(t *testing.T) { t.Run("200 get data", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
h22.ServeHTTP(w, r) h22.ServeHTTPReturn(w, r)
checkStatus(w, "success", http.StatusOK) checkStatus(w, "success", http.StatusOK)
}) })
@ -102,21 +102,21 @@ func TestNewJSONHandler(t *testing.T) {
t.Run("200 post data", func(t *testing.T) { t.Run("200 post data", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
h31.ServeHTTP(w, r) h31.ServeHTTPReturn(w, r)
checkStatus(w, "success", http.StatusOK) checkStatus(w, "success", http.StatusOK)
}) })
t.Run("400 bad json", func(t *testing.T) { t.Run("400 bad json", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
h31.ServeHTTP(w, r) h31.ServeHTTPReturn(w, r)
checkStatus(w, "error", http.StatusBadRequest) checkStatus(w, "error", http.StatusBadRequest)
}) })
t.Run("400 post data error", func(t *testing.T) { t.Run("400 post data error", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
h31.ServeHTTP(w, r) h31.ServeHTTPReturn(w, r)
resp := checkStatus(w, "error", http.StatusBadRequest) resp := checkStatus(w, "error", http.StatusBadRequest)
if resp.Error != "name is empty" { if resp.Error != "name is empty" {
t.Fatalf("wrong error") t.Fatalf("wrong error")
@ -141,7 +141,7 @@ func TestNewJSONHandler(t *testing.T) {
t.Run("200 post data", func(t *testing.T) { t.Run("200 post data", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
h32.ServeHTTP(w, r) h32.ServeHTTPReturn(w, r)
resp := checkStatus(w, "success", http.StatusOK) resp := checkStatus(w, "success", http.StatusOK)
t.Log(resp.Data) t.Log(resp.Data)
if resp.Data.Price != 20 { if resp.Data.Price != 20 {
@ -152,7 +152,7 @@ func TestNewJSONHandler(t *testing.T) {
t.Run("400 post data error", func(t *testing.T) { t.Run("400 post data error", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
h32.ServeHTTP(w, r) h32.ServeHTTPReturn(w, r)
resp := checkStatus(w, "error", http.StatusBadRequest) resp := checkStatus(w, "error", http.StatusBadRequest)
if resp.Error != "price is empty" { if resp.Error != "price is empty" {
t.Fatalf("wrong error") t.Fatalf("wrong error")
@ -162,7 +162,7 @@ func TestNewJSONHandler(t *testing.T) {
t.Run("500 internal server error", func(t *testing.T) { t.Run("500 internal server error", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
h32.ServeHTTP(w, r) h32.ServeHTTPReturn(w, r)
resp := checkStatus(w, "error", http.StatusInternalServerError) resp := checkStatus(w, "error", http.StatusInternalServerError)
if resp.Error != "internal server error" { if resp.Error != "internal server error" {
t.Fatalf("wrong error") t.Fatalf("wrong error")
@ -174,7 +174,7 @@ func TestNewJSONHandler(t *testing.T) {
r := httptest.NewRequest("POST", "/", nil) r := httptest.NewRequest("POST", "/", nil)
JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
return http.StatusOK, make(chan int), nil return http.StatusOK, make(chan int), nil
}).ServeHTTP(w, r) }).ServeHTTPReturn(w, r)
resp := checkStatus(w, "error", http.StatusInternalServerError) resp := checkStatus(w, "error", http.StatusInternalServerError)
if resp.Error != "json marshal error" { if resp.Error != "json marshal error" {
t.Fatalf("wrong error") t.Fatalf("wrong error")
@ -186,7 +186,7 @@ func TestNewJSONHandler(t *testing.T) {
r := httptest.NewRequest("POST", "/", nil) r := httptest.NewRequest("POST", "/", nil)
JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) { JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
return return
}).ServeHTTP(w, r) }).ServeHTTPReturn(w, r)
checkStatus(w, "error", http.StatusInternalServerError) checkStatus(w, "error", http.StatusInternalServerError)
}) })
} }

Loading…
Cancel
Save