From d1a30be2757792419e783d5c245811f547d73d92 Mon Sep 17 00:00:00 2001 From: Zijie Lu Date: Tue, 9 Jun 2020 17:40:45 -0400 Subject: [PATCH] tsweb: JSONHandler: supports HTTPError Signed-off-by: Zijie Lu --- tsweb/jsonhandler.go | 25 ++++++++++++++++--------- tsweb/jsonhandler_test.go | 13 +++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/tsweb/jsonhandler.go b/tsweb/jsonhandler.go index 677d49eda..1093e222a 100644 --- a/tsweb/jsonhandler.go +++ b/tsweb/jsonhandler.go @@ -102,26 +102,33 @@ func JSONHandler(fn interface{}) http.Handler { panic("JSONHandler: number of input parameter should be 2 or 3") } + var e reflect.Value switch len(vs) { case 1: // todo support other error types - if vs[0].IsNil() { + if vs[0].IsZero() { writeResponse(w, http.StatusOK, responseSuccess(nil)) - } else { - err := vs[0].Interface().(error) - writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + return } + e = vs[0] case 2: - if vs[1].IsNil() { - if !vs[0].IsNil() { + if vs[1].IsZero() { + if !vs[0].IsZero() { writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface())) } - } else { - err := vs[1].Interface().(error) - writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + return } + e = vs[1] default: panic("JSONHandler: number of return values should be 1 or 2") } + + if e.Type().AssignableTo(reflect.TypeOf(HTTPError{})) { + err := e.Interface().(HTTPError) + writeResponse(w, err.Code, responseError(err.Error())) + } else { + err := e.Interface().(error) + writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + } }) } diff --git a/tsweb/jsonhandler_test.go b/tsweb/jsonhandler_test.go index 9d78d93ce..990d1f2f2 100644 --- a/tsweb/jsonhandler_test.go +++ b/tsweb/jsonhandler_test.go @@ -59,6 +59,19 @@ func TestNewJSONHandler(t *testing.T) { checkStatus(w, "success") }) + t.Run("2 1 HTTPError", func(t *testing.T) { + h := JSONHandler(func(w http.ResponseWriter, r *http.Request) HTTPError { + return Error(http.StatusForbidden, "forbidden", nil) + }) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + h.ServeHTTP(w, r) + if w.Code != http.StatusForbidden { + t.Fatalf("wrong code: %d %d", w.Code, http.StatusForbidden) + } + }) + // 2 2 h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) { return &Data{Name: "tailscale"}, nil