diff --git a/tsweb/jsonhandler.go b/tsweb/jsonhandler.go index 677d49eda..1093e222a 100644 --- a/tsweb/jsonhandler.go +++ b/tsweb/jsonhandler.go @@ -102,26 +102,33 @@ func JSONHandler(fn interface{}) http.Handler { panic("JSONHandler: number of input parameter should be 2 or 3") } + var e reflect.Value switch len(vs) { case 1: // todo support other error types - if vs[0].IsNil() { + if vs[0].IsZero() { writeResponse(w, http.StatusOK, responseSuccess(nil)) - } else { - err := vs[0].Interface().(error) - writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + return } + e = vs[0] case 2: - if vs[1].IsNil() { - if !vs[0].IsNil() { + if vs[1].IsZero() { + if !vs[0].IsZero() { writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface())) } - } else { - err := vs[1].Interface().(error) - writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + return } + e = vs[1] default: panic("JSONHandler: number of return values should be 1 or 2") } + + if e.Type().AssignableTo(reflect.TypeOf(HTTPError{})) { + err := e.Interface().(HTTPError) + writeResponse(w, err.Code, responseError(err.Error())) + } else { + err := e.Interface().(error) + writeResponse(w, http.StatusBadRequest, responseError(err.Error())) + } }) } diff --git a/tsweb/jsonhandler_test.go b/tsweb/jsonhandler_test.go index 9d78d93ce..990d1f2f2 100644 --- a/tsweb/jsonhandler_test.go +++ b/tsweb/jsonhandler_test.go @@ -59,6 +59,19 @@ func TestNewJSONHandler(t *testing.T) { checkStatus(w, "success") }) + t.Run("2 1 HTTPError", func(t *testing.T) { + h := JSONHandler(func(w http.ResponseWriter, r *http.Request) HTTPError { + return Error(http.StatusForbidden, "forbidden", nil) + }) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + h.ServeHTTP(w, r) + if w.Code != http.StatusForbidden { + t.Fatalf("wrong code: %d %d", w.Code, http.StatusForbidden) + } + }) + // 2 2 h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) { return &Data{Name: "tailscale"}, nil