diff --git a/tsweb/jsonhandler.go b/tsweb/jsonhandler.go new file mode 100644 index 000000000..f8531bf20 --- /dev/null +++ b/tsweb/jsonhandler.go @@ -0,0 +1,125 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tsweb + +import ( + "encoding/json" + "net/http" + "reflect" +) + +type response struct { + Status string `json:"status"` + Error string `json:"error,omitempty"` + Data interface{} `json:"data,omitempty"` +} + +func responseSuccess(data interface{}) *response { + return &response{ + Status: "success", + Data: data, + } +} + +func responseError(e string) *response { + return &response{ + Status: "error", + Error: e, + } +} + +func writeResponse(w http.ResponseWriter, s int, resp *response) { + b, _ := json.Marshal(resp) + w.WriteHeader(s) + w.Header().Set("Content-Type", "application/json") + w.Write(b) +} + +func checkFn(t reflect.Type) { + h := reflect.TypeOf(http.HandlerFunc(nil)) + switch t.NumIn() { + case 2, 3: + if !t.In(0).AssignableTo(h.In(0)) { + panic("first argument must be http.ResponseWriter") + } + if !t.In(1).AssignableTo(h.In(1)) { + panic("second argument must be *http.Request") + } + default: + panic("JSONHandler: number of input parameter should be 2 or 3") + } + + switch t.NumOut() { + case 1: + if !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + panic("return value must be error") + } + case 2: + if !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + panic("second return value must be error") + } + default: + panic("JSONHandler: number of return values should be 1 or 2") + } +} + +// JSONHandler wraps an HTTP handler function with a version that automatically +// unmarshals and marshals requests and responses respectively into fn's arguments +// and results. +// +// The fn parameter is a function. It must take two or three input arguments. +// The first two arguments must be http.ResponseWriter and *http.Request. +// The optional third argument can be of any type representing the JSON input. +// The function's results can be either (error) or (T, error), where T is the +// JSON-marshalled result type. +// +// For example: +// fn := func(w http.ResponseWriter, r *http.Request, in *Req) (*Res, error) { ... } +func JSONHandler(fn interface{}) http.Handler { + v := reflect.ValueOf(fn) + t := v.Type() + checkFn(t) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wv := reflect.ValueOf(w) + rv := reflect.ValueOf(r) + var vs []reflect.Value + + switch t.NumIn() { + case 2: + vs = v.Call([]reflect.Value{wv, rv}) + case 3: + dv := reflect.New(t.In(2)) + err := json.NewDecoder(r.Body).Decode(dv.Interface()) + if err != nil { + writeResponse(w, http.StatusBadRequest, responseError("bad json")) + return + } + vs = v.Call([]reflect.Value{wv, rv, dv.Elem()}) + default: + panic("JSONHandler: number of input parameter should be 2 or 3") + } + + switch len(vs) { + case 1: + // todo support other error types + if vs[0].IsNil() { + writeResponse(w, http.StatusOK, responseSuccess(nil)) + } else { + err := vs[0].Interface().(error) + writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + } + case 2: + if vs[1].IsNil() { + writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface())) + } else { + err := vs[1].Interface().(error) + writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + } + default: + panic("JSONHandler: number of return values should be 1 or 2") + } + }) +} diff --git a/tsweb/jsonhandler_test.go b/tsweb/jsonhandler_test.go new file mode 100644 index 000000000..8379491ed --- /dev/null +++ b/tsweb/jsonhandler_test.go @@ -0,0 +1,175 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tsweb + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type Data struct { + Name string + Price int +} + +type Response struct { + Status string + Error string + Data *Data +} + +func TestNewJSONHandler(t *testing.T) { + checkStatus := func(w *httptest.ResponseRecorder, status string) *Response { + d := &Response{ + Data: &Data{}, + } + + t.Logf("%s", w.Body.Bytes()) + err := json.Unmarshal(w.Body.Bytes(), d) + if err != nil { + t.Logf(err.Error()) + return nil + } + + if d.Status == status { + t.Logf("ok: %s", d.Status) + } else { + t.Fatalf("wrong status: %s %s", d.Status, status) + } + + return d + } + + // 2 1 + h21 := JSONHandler(func(w http.ResponseWriter, r *http.Request) error { + return nil + }) + + t.Run("2 1 simple", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + h21.ServeHTTP(w, r) + checkStatus(w, "success") + }) + + // 2 2 + h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) { + return &Data{Name: "tailscale"}, nil + }) + t.Run("2 2 get data", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + h22.ServeHTTP(w, r) + checkStatus(w, "success") + }) + + // 3 1 + h31 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) error { + if d.Name == "" { + return errors.New("name is empty") + } + + return nil + }) + t.Run("3 1 post data", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) + h31.ServeHTTP(w, r) + checkStatus(w, "success") + }) + + t.Run("3 1 bad json", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) + h31.ServeHTTP(w, r) + checkStatus(w, "error") + }) + + t.Run("3 1 post data error", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) + h31.ServeHTTP(w, r) + resp := checkStatus(w, "error") + if resp.Error != "name is empty" { + t.Fatalf("wrong error") + } + }) + + // 3 2 + h32 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) (*Data, error) { + if d.Price == 0 { + return nil, errors.New("price is empty") + } + + return &Data{Price: d.Price * 2}, nil + }) + t.Run("3 2 post data", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) + h32.ServeHTTP(w, r) + resp := checkStatus(w, "success") + t.Log(resp.Data) + if resp.Data.Price != 20 { + t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) + } + }) + + t.Run("3 2 post data error", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) + h32.ServeHTTP(w, r) + resp := checkStatus(w, "error") + if resp.Error != "price is empty" { + t.Fatalf("wrong error") + } + }) + + // fn check + shouldPanic := func() { + r := recover() + if r == nil { + t.Fatalf("should panic") + } + t.Log(r) + } + + t.Run("2 0 panic", func(t *testing.T) { + defer shouldPanic() + JSONHandler(func(w http.ResponseWriter, r *http.Request) {}) + }) + + t.Run("2 1 panic return value", func(t *testing.T) { + defer shouldPanic() + JSONHandler(func(w http.ResponseWriter, r *http.Request) string { + return "" + }) + }) + + t.Run("2 1 panic arguments", func(t *testing.T) { + defer shouldPanic() + JSONHandler(func(r *http.Request, w http.ResponseWriter) error { + return nil + }) + }) + + t.Run("3 1 panic arguments", func(t *testing.T) { + defer shouldPanic() + JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) error { + return nil + }) + }) + + t.Run("3 2 panic return value", func(t *testing.T) { + defer shouldPanic() + //lint:ignore ST1008 intentional + JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) (error, string) { + return nil, "panic" + }) + }) +}