tsweb: rewrite JSONHandler without using reflect (#684)

Closes #656 #657

Signed-off-by: Zijie Lu <zijie@tailscale.com>
pull/692/head
halulu 4 years ago committed by GitHub
parent 93ffc565e5
commit 1835bb6f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,6 @@ package tsweb
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"reflect"
) )
type response struct { type response struct {
@ -16,119 +15,59 @@ type response struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
func responseSuccess(data interface{}) *response { // TODO: Header
return &response{
Status: "success",
Data: data,
}
}
func responseError(e string) *response { // JSONHandlerFunc only take *http.Request as argument to avoid any misuse of http.ResponseWriter.
return &response{ // The function's results must be (status int, data interface{}, err error).
Status: "error", // Return a HTTPError to show an error message, otherwise JSONHandler will only show "internal server error".
Error: e, type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error)
}
}
func writeResponse(w http.ResponseWriter, s int, resp *response) { // ServeHTTP calls the JSONHandlerFunc and automatically marshals http responses.
b, _ := json.Marshal(resp) //
// Use the following code to unmarshal the request body
// body := new(DataType)
// if err := json.NewDecoder(r.Body).Decode(body); err != nil {
// return http.StatusBadRequest, nil, err
// }
//
// Check jsonhandler_text.go for examples
func (fn JSONHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(s) var resp *response
w.Write(b) status, data, err := fn(r)
} if status == 0 {
status = http.StatusInternalServerError
func checkFn(t reflect.Type) { resp = &response{
h := reflect.TypeOf(http.HandlerFunc(nil)) Status: "error",
switch t.NumIn() { Error: "internal server error",
case 2, 3:
if !t.In(0).AssignableTo(h.In(0)) {
panic("first argument must be http.ResponseWriter")
}
if !t.In(1).AssignableTo(h.In(1)) {
panic("second argument must be *http.Request")
} }
default: } else if err == nil {
panic("JSONHandler: number of input parameter should be 2 or 3") resp = &response{
Status: "success",
Data: data,
} }
} else {
switch t.NumOut() { if werr, ok := err.(HTTPError); ok {
case 1: resp = &response{
if !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { Status: "error",
panic("return value must be error") Error: werr.Msg,
Data: data,
} }
case 2: } else {
if !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { resp = &response{
panic("second return value must be error") Status: "error",
Error: "internal server error",
} }
default:
panic("JSONHandler: number of return values should be 1 or 2")
} }
} }
// JSONHandler wraps an HTTP handler function with a version that automatically b, err := json.Marshal(resp)
// unmarshals and marshals requests and responses respectively into fn's arguments
// and results.
//
// The fn parameter is a function. It must take two or three input arguments.
// The first two arguments must be http.ResponseWriter and *http.Request.
// The optional third argument can be of any type representing the JSON input.
// The function's results can be either (error) or (T, error), where T is the
// JSON-marshalled result type.
//
// For example:
// fn := func(w http.ResponseWriter, r *http.Request, in *Req) (*Res, error) { ... }
func JSONHandler(fn interface{}) http.Handler {
v := reflect.ValueOf(fn)
t := v.Type()
checkFn(t)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wv := reflect.ValueOf(w)
rv := reflect.ValueOf(r)
var vs []reflect.Value
switch t.NumIn() {
case 2:
vs = v.Call([]reflect.Value{wv, rv})
case 3:
dv := reflect.New(t.In(2))
err := json.NewDecoder(r.Body).Decode(dv.Interface())
if err != nil { if err != nil {
writeResponse(w, http.StatusBadRequest, responseError("bad json")) w.WriteHeader(http.StatusInternalServerError)
return w.Write([]byte(`{"status":"error","error":"json marshal error"}`))
}
vs = v.Call([]reflect.Value{wv, rv, dv.Elem()})
default:
panic("JSONHandler: number of input parameter should be 2 or 3")
}
var e reflect.Value
switch len(vs) {
case 1:
// todo support other error types
if vs[0].IsZero() {
writeResponse(w, http.StatusOK, responseSuccess(nil))
return
}
e = vs[0]
case 2:
if vs[1].IsZero() {
if !vs[0].IsZero() {
writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface()))
}
return return
} }
e = vs[1]
default:
panic("JSONHandler: number of return values should be 1 or 2")
}
if e.Type().AssignableTo(reflect.TypeOf(HTTPError{})) { w.WriteHeader(status)
err := e.Interface().(HTTPError) w.Write(b)
writeResponse(w, err.Code, responseError(err.Error()))
} else {
err := e.Interface().(error)
writeResponse(w, http.StatusBadRequest, responseError(err.Error()))
}
})
} }

@ -5,9 +5,8 @@
package tsweb package tsweb
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -26,7 +25,7 @@ type Response struct {
} }
func TestNewJSONHandler(t *testing.T) { func TestNewJSONHandler(t *testing.T) {
checkStatus := func(w *httptest.ResponseRecorder, status string) *Response { checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response {
d := &Response{ d := &Response{
Data: &Data{}, Data: &Data{},
} }
@ -44,6 +43,10 @@ func TestNewJSONHandler(t *testing.T) {
t.Fatalf("wrong status: %s %s", d.Status, status) t.Fatalf("wrong status: %s %s", d.Status, status)
} }
if w.Code != code {
t.Fatalf("wrong status code: %d %d", w.Code, code)
}
if w.Header().Get("Content-Type") != "application/json" { if w.Header().Get("Content-Type") != "application/json" {
t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type")) t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type"))
} }
@ -51,163 +54,139 @@ func TestNewJSONHandler(t *testing.T) {
return d return d
} }
// 2 1 h21 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
h21 := JSONHandler(func(w http.ResponseWriter, r *http.Request) error { return http.StatusOK, nil, nil
return nil
}) })
t.Run("2 1 simple", func(t *testing.T) { t.Run("200 simple", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
h21.ServeHTTP(w, r) h21.ServeHTTP(w, r)
checkStatus(w, "success") checkStatus(w, "success", http.StatusOK)
}) })
t.Run("2 1 HTTPError", func(t *testing.T) { t.Run("403 HTTPError", func(t *testing.T) {
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) HTTPError { h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
return Error(http.StatusForbidden, "forbidden", nil) return http.StatusForbidden, nil, fmt.Errorf("forbidden")
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
if w.Code != http.StatusForbidden { checkStatus(w, "error", http.StatusForbidden)
t.Fatalf("wrong code: %d %d", w.Code, http.StatusForbidden)
}
}) })
// 2 2 h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) { return http.StatusOK, &Data{Name: "tailscale"}, nil
return &Data{Name: "tailscale"}, nil
}) })
t.Run("2 2 get data", func(t *testing.T) {
t.Run("200 get data", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
h22.ServeHTTP(w, r) h22.ServeHTTP(w, r)
checkStatus(w, "success") checkStatus(w, "success", http.StatusOK)
}) })
// 3 1 h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
h31 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) error { body := new(Data)
if d.Name == "" { if err := json.NewDecoder(r.Body).Decode(body); err != nil {
return errors.New("name is empty") return http.StatusBadRequest, nil, err
} }
return nil if body.Name == "" {
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "name is empty", nil)
}
return http.StatusOK, nil, nil
}) })
t.Run("3 1 post data", func(t *testing.T) { t.Run("200 post data", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
h31.ServeHTTP(w, r) h31.ServeHTTP(w, r)
checkStatus(w, "success") checkStatus(w, "success", http.StatusOK)
}) })
t.Run("3 1 bad json", func(t *testing.T) { t.Run("400 bad json", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
h31.ServeHTTP(w, r) h31.ServeHTTP(w, r)
checkStatus(w, "error") checkStatus(w, "error", http.StatusBadRequest)
}) })
t.Run("3 1 post data error", func(t *testing.T) { t.Run("400 post data error", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
h31.ServeHTTP(w, r) h31.ServeHTTP(w, r)
resp := checkStatus(w, "error") resp := checkStatus(w, "error", http.StatusBadRequest)
if resp.Error != "name is empty" { if resp.Error != "name is empty" {
t.Fatalf("wrong error") t.Fatalf("wrong error")
} }
}) })
// 3 2 h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
h32 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) (*Data, error) { body := new(Data)
if d.Price == 0 { if err := json.NewDecoder(r.Body).Decode(body); err != nil {
return nil, errors.New("price is empty") return http.StatusBadRequest, nil, err
}
if body.Name == "root" {
return http.StatusInternalServerError, nil, fmt.Errorf("invalid name")
}
if body.Price == 0 {
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "price is empty", nil)
} }
return &Data{Price: d.Price * 2}, nil return http.StatusOK, &Data{Price: body.Price * 2}, nil
}) })
t.Run("3 2 post data", func(t *testing.T) {
t.Run("200 post data", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
h32.ServeHTTP(w, r) h32.ServeHTTP(w, r)
resp := checkStatus(w, "success") resp := checkStatus(w, "success", http.StatusOK)
t.Log(resp.Data) t.Log(resp.Data)
if resp.Data.Price != 20 { if resp.Data.Price != 20 {
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
} }
}) })
t.Run("3 2 post data error", func(t *testing.T) { t.Run("400 post data error", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
h32.ServeHTTP(w, r) h32.ServeHTTP(w, r)
resp := checkStatus(w, "error") resp := checkStatus(w, "error", http.StatusBadRequest)
if resp.Error != "price is empty" { if resp.Error != "price is empty" {
t.Fatalf("wrong error") t.Fatalf("wrong error")
} }
}) })
// fn check t.Run("500 internal server error", func(t *testing.T) {
shouldPanic := func() { w := httptest.NewRecorder()
r := recover() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
if r == nil { h32.ServeHTTP(w, r)
t.Fatalf("should panic") resp := checkStatus(w, "error", http.StatusInternalServerError)
} if resp.Error != "internal server error" {
t.Log(r) t.Fatalf("wrong error")
} }
t.Run("2 0 panic", func(t *testing.T) {
defer shouldPanic()
JSONHandler(func(w http.ResponseWriter, r *http.Request) {})
})
t.Run("2 1 panic return value", func(t *testing.T) {
defer shouldPanic()
JSONHandler(func(w http.ResponseWriter, r *http.Request) string {
return ""
})
}) })
t.Run("2 1 panic arguments", func(t *testing.T) { t.Run("500 misuse", func(t *testing.T) {
defer shouldPanic() w := httptest.NewRecorder()
JSONHandler(func(r *http.Request, w http.ResponseWriter) error { r := httptest.NewRequest("POST", "/", nil)
return nil JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
}) return http.StatusOK, make(chan int), nil
}) }).ServeHTTP(w, r)
resp := checkStatus(w, "error", http.StatusInternalServerError)
t.Run("3 1 panic arguments", func(t *testing.T) { if resp.Error != "json marshal error" {
defer shouldPanic() t.Fatalf("wrong error")
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) error { }
return nil
})
})
t.Run("3 2 panic return value", func(t *testing.T) {
defer shouldPanic()
//lint:ignore ST1008 intentional
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) (error, string) {
return nil, "panic"
})
})
t.Run("2 2 forbidden", func(t *testing.T) {
code := http.StatusForbidden
body := []byte("forbidden")
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) {
w.WriteHeader(code)
w.Write(body)
return nil, nil
}) })
t.Run("500 empty status code", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("POST", "/", nil)
h.ServeHTTP(w, r) JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
if w.Code != http.StatusForbidden { return
t.Fatalf("wrong code: %d %d", w.Code, code) }).ServeHTTP(w, r)
} checkStatus(w, "error", http.StatusInternalServerError)
if !bytes.Equal(w.Body.Bytes(), []byte("forbidden")) {
t.Fatalf("wrong body: %s %s", w.Body.Bytes(), body)
}
}) })
} }

Loading…
Cancel
Save