From 1835bb6f85f0e25458576de14d15a26955b43f9c Mon Sep 17 00:00:00 2001 From: halulu Date: Tue, 18 Aug 2020 17:37:01 -0400 Subject: [PATCH] tsweb: rewrite JSONHandler without using reflect (#684) Closes #656 #657 Signed-off-by: Zijie Lu --- tsweb/jsonhandler.go | 153 +++++++++++----------------------- tsweb/jsonhandler_test.go | 167 +++++++++++++++++--------------------- 2 files changed, 119 insertions(+), 201 deletions(-) diff --git a/tsweb/jsonhandler.go b/tsweb/jsonhandler.go index d249b9b14..8602a9ad9 100644 --- a/tsweb/jsonhandler.go +++ b/tsweb/jsonhandler.go @@ -7,7 +7,6 @@ package tsweb import ( "encoding/json" "net/http" - "reflect" ) type response struct { @@ -16,119 +15,59 @@ type response struct { Data interface{} `json:"data,omitempty"` } -func responseSuccess(data interface{}) *response { - return &response{ - Status: "success", - Data: data, - } -} - -func responseError(e string) *response { - return &response{ - Status: "error", - Error: e, - } -} +// TODO: Header -func writeResponse(w http.ResponseWriter, s int, resp *response) { - b, _ := json.Marshal(resp) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(s) - w.Write(b) -} - -func checkFn(t reflect.Type) { - h := reflect.TypeOf(http.HandlerFunc(nil)) - switch t.NumIn() { - case 2, 3: - if !t.In(0).AssignableTo(h.In(0)) { - panic("first argument must be http.ResponseWriter") - } - if !t.In(1).AssignableTo(h.In(1)) { - panic("second argument must be *http.Request") - } - default: - panic("JSONHandler: number of input parameter should be 2 or 3") - } - - switch t.NumOut() { - case 1: - if !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - panic("return value must be error") - } - case 2: - if !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - panic("second return value must be error") - } - default: - panic("JSONHandler: number of return values should be 1 or 2") - } -} +// JSONHandlerFunc only take *http.Request as argument to avoid any misuse of http.ResponseWriter. +// The function's results must be (status int, data interface{}, err error). +// Return a HTTPError to show an error message, otherwise JSONHandler will only show "internal server error". +type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error) -// JSONHandler wraps an HTTP handler function with a version that automatically -// unmarshals and marshals requests and responses respectively into fn's arguments -// and results. +// ServeHTTP calls the JSONHandlerFunc and automatically marshals http responses. // -// The fn parameter is a function. It must take two or three input arguments. -// The first two arguments must be http.ResponseWriter and *http.Request. -// The optional third argument can be of any type representing the JSON input. -// The function's results can be either (error) or (T, error), where T is the -// JSON-marshalled result type. +// Use the following code to unmarshal the request body +// body := new(DataType) +// if err := json.NewDecoder(r.Body).Decode(body); err != nil { +// return http.StatusBadRequest, nil, err +// } // -// For example: -// fn := func(w http.ResponseWriter, r *http.Request, in *Req) (*Res, error) { ... } -func JSONHandler(fn interface{}) http.Handler { - v := reflect.ValueOf(fn) - t := v.Type() - checkFn(t) - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wv := reflect.ValueOf(w) - rv := reflect.ValueOf(r) - var vs []reflect.Value - - switch t.NumIn() { - case 2: - vs = v.Call([]reflect.Value{wv, rv}) - case 3: - dv := reflect.New(t.In(2)) - err := json.NewDecoder(r.Body).Decode(dv.Interface()) - if err != nil { - writeResponse(w, http.StatusBadRequest, responseError("bad json")) - return - } - vs = v.Call([]reflect.Value{wv, rv, dv.Elem()}) - default: - panic("JSONHandler: number of input parameter should be 2 or 3") +// Check jsonhandler_text.go for examples +func (fn JSONHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + var resp *response + status, data, err := fn(r) + if status == 0 { + status = http.StatusInternalServerError + resp = &response{ + Status: "error", + Error: "internal server error", } - - var e reflect.Value - switch len(vs) { - case 1: - // todo support other error types - if vs[0].IsZero() { - writeResponse(w, http.StatusOK, responseSuccess(nil)) - return + } else if err == nil { + resp = &response{ + Status: "success", + Data: data, + } + } else { + if werr, ok := err.(HTTPError); ok { + resp = &response{ + Status: "error", + Error: werr.Msg, + Data: data, } - e = vs[0] - case 2: - if vs[1].IsZero() { - if !vs[0].IsZero() { - writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface())) - } - return + } else { + resp = &response{ + Status: "error", + Error: "internal server error", } - e = vs[1] - default: - panic("JSONHandler: number of return values should be 1 or 2") } + } - if e.Type().AssignableTo(reflect.TypeOf(HTTPError{})) { - err := e.Interface().(HTTPError) - writeResponse(w, err.Code, responseError(err.Error())) - } else { - err := e.Interface().(error) - writeResponse(w, http.StatusBadRequest, responseError(err.Error())) - } - }) + b, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"status":"error","error":"json marshal error"}`)) + return + } + + w.WriteHeader(status) + w.Write(b) } diff --git a/tsweb/jsonhandler_test.go b/tsweb/jsonhandler_test.go index 91d7d1227..be7d9d215 100644 --- a/tsweb/jsonhandler_test.go +++ b/tsweb/jsonhandler_test.go @@ -5,9 +5,8 @@ package tsweb import ( - "bytes" "encoding/json" - "errors" + "fmt" "net/http" "net/http/httptest" "strings" @@ -26,7 +25,7 @@ type Response struct { } func TestNewJSONHandler(t *testing.T) { - checkStatus := func(w *httptest.ResponseRecorder, status string) *Response { + checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response { d := &Response{ Data: &Data{}, } @@ -44,6 +43,10 @@ func TestNewJSONHandler(t *testing.T) { t.Fatalf("wrong status: %s %s", d.Status, status) } + if w.Code != code { + t.Fatalf("wrong status code: %d %d", w.Code, code) + } + if w.Header().Get("Content-Type") != "application/json" { t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type")) } @@ -51,163 +54,139 @@ func TestNewJSONHandler(t *testing.T) { return d } - // 2 1 - h21 := JSONHandler(func(w http.ResponseWriter, r *http.Request) error { - return nil + h21 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + return http.StatusOK, nil, nil }) - t.Run("2 1 simple", func(t *testing.T) { + t.Run("200 simple", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h21.ServeHTTP(w, r) - checkStatus(w, "success") + checkStatus(w, "success", http.StatusOK) }) - t.Run("2 1 HTTPError", func(t *testing.T) { - h := JSONHandler(func(w http.ResponseWriter, r *http.Request) HTTPError { - return Error(http.StatusForbidden, "forbidden", nil) + t.Run("403 HTTPError", func(t *testing.T) { + h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + return http.StatusForbidden, nil, fmt.Errorf("forbidden") }) w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h.ServeHTTP(w, r) - if w.Code != http.StatusForbidden { - t.Fatalf("wrong code: %d %d", w.Code, http.StatusForbidden) - } + checkStatus(w, "error", http.StatusForbidden) }) - // 2 2 - h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) { - return &Data{Name: "tailscale"}, nil + h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + return http.StatusOK, &Data{Name: "tailscale"}, nil }) - t.Run("2 2 get data", func(t *testing.T) { + + t.Run("200 get data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h22.ServeHTTP(w, r) - checkStatus(w, "success") + checkStatus(w, "success", http.StatusOK) }) - // 3 1 - h31 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) error { - if d.Name == "" { - return errors.New("name is empty") + h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + body := new(Data) + if err := json.NewDecoder(r.Body).Decode(body); err != nil { + return http.StatusBadRequest, nil, err } - return nil + if body.Name == "" { + return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "name is empty", nil) + } + + return http.StatusOK, nil, nil }) - t.Run("3 1 post data", func(t *testing.T) { + t.Run("200 post data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) h31.ServeHTTP(w, r) - checkStatus(w, "success") + checkStatus(w, "success", http.StatusOK) }) - t.Run("3 1 bad json", func(t *testing.T) { + t.Run("400 bad json", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) h31.ServeHTTP(w, r) - checkStatus(w, "error") + checkStatus(w, "error", http.StatusBadRequest) }) - t.Run("3 1 post data error", func(t *testing.T) { + t.Run("400 post data error", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h31.ServeHTTP(w, r) - resp := checkStatus(w, "error") + resp := checkStatus(w, "error", http.StatusBadRequest) if resp.Error != "name is empty" { t.Fatalf("wrong error") } }) - // 3 2 - h32 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) (*Data, error) { - if d.Price == 0 { - return nil, errors.New("price is empty") + h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + body := new(Data) + if err := json.NewDecoder(r.Body).Decode(body); err != nil { + return http.StatusBadRequest, nil, err + } + if body.Name == "root" { + return http.StatusInternalServerError, nil, fmt.Errorf("invalid name") + } + if body.Price == 0 { + return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "price is empty", nil) } - return &Data{Price: d.Price * 2}, nil + return http.StatusOK, &Data{Price: body.Price * 2}, nil }) - t.Run("3 2 post data", func(t *testing.T) { + + t.Run("200 post data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) h32.ServeHTTP(w, r) - resp := checkStatus(w, "success") + resp := checkStatus(w, "success", http.StatusOK) t.Log(resp.Data) if resp.Data.Price != 20 { t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) } }) - t.Run("3 2 post data error", func(t *testing.T) { + t.Run("400 post data error", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h32.ServeHTTP(w, r) - resp := checkStatus(w, "error") + resp := checkStatus(w, "error", http.StatusBadRequest) if resp.Error != "price is empty" { t.Fatalf("wrong error") } }) - // fn check - shouldPanic := func() { - r := recover() - if r == nil { - t.Fatalf("should panic") + t.Run("500 internal server error", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`)) + h32.ServeHTTP(w, r) + resp := checkStatus(w, "error", http.StatusInternalServerError) + if resp.Error != "internal server error" { + t.Fatalf("wrong error") } - t.Log(r) - } - - t.Run("2 0 panic", func(t *testing.T) { - defer shouldPanic() - JSONHandler(func(w http.ResponseWriter, r *http.Request) {}) }) - t.Run("2 1 panic return value", func(t *testing.T) { - defer shouldPanic() - JSONHandler(func(w http.ResponseWriter, r *http.Request) string { - return "" - }) - }) - - t.Run("2 1 panic arguments", func(t *testing.T) { - defer shouldPanic() - JSONHandler(func(r *http.Request, w http.ResponseWriter) error { - return nil - }) - }) - - t.Run("3 1 panic arguments", func(t *testing.T) { - defer shouldPanic() - JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) error { - return nil - }) - }) - - t.Run("3 2 panic return value", func(t *testing.T) { - defer shouldPanic() - //lint:ignore ST1008 intentional - JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) (error, string) { - return nil, "panic" - }) + t.Run("500 misuse", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", nil) + JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + return http.StatusOK, make(chan int), nil + }).ServeHTTP(w, r) + resp := checkStatus(w, "error", http.StatusInternalServerError) + if resp.Error != "json marshal error" { + t.Fatalf("wrong error") + } }) - t.Run("2 2 forbidden", func(t *testing.T) { - code := http.StatusForbidden - body := []byte("forbidden") - h := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) { - w.WriteHeader(code) - w.Write(body) - return nil, nil - }) - + t.Run("500 empty status code", func(t *testing.T) { w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) - h.ServeHTTP(w, r) - if w.Code != http.StatusForbidden { - t.Fatalf("wrong code: %d %d", w.Code, code) - } - if !bytes.Equal(w.Body.Bytes(), []byte("forbidden")) { - t.Fatalf("wrong body: %s %s", w.Body.Bytes(), body) - } + r := httptest.NewRequest("POST", "/", nil) + JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) { + return + }).ServeHTTP(w, r) + checkStatus(w, "error", http.StatusInternalServerError) }) }