diff --git a/go.sum b/go.sum index f9f1a9ac3..ef2c51a91 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,7 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/rpmpack v0.0.0-20191226140753-aa36bfddb3a0 h1:BW6OvS3kpT5UEPbCZ+KyX/OB4Ks9/MNMhWjqPPkZxsE= github.com/google/rpmpack v0.0.0-20191226140753-aa36bfddb3a0/go.mod h1:RaTPr0KUf2K7fnZYLNDrr8rxAamWs3iNywJLtQ2AzBg= diff --git a/tsweb/jsonhandler.go b/tsweb/jsonhandler.go index 3199138e6..9b3a0378b 100644 --- a/tsweb/jsonhandler.go +++ b/tsweb/jsonhandler.go @@ -19,7 +19,7 @@ type response struct { // JSONHandlerFunc is an HTTP ReturnHandler that writes JSON responses to the client. // // Return a HTTPError to show an error message, otherwise JSONHandlerFunc will -// only report "internal server error" to the user. +// only report "internal server error" to the user with status code 500. type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error) // ServeHTTPReturn implements the ReturnHandler interface. @@ -31,23 +31,12 @@ type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err er // return http.StatusBadRequest, nil, err // } // -// See jsonhandler_text.go for examples. +// See jsonhandler_test.go for examples. func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Content-Type", "application/json") var resp *response status, data, err := fn(r) - if status == 0 { - status = http.StatusInternalServerError - resp = &response{ - Status: "error", - Error: "internal server error", - } - } else if err == nil { - resp = &response{ - Status: "success", - Data: data, - } - } else { + if err != nil { if werr, ok := err.(HTTPError); ok { resp = &response{ Status: "error", @@ -61,12 +50,29 @@ func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request if werr.Msg != "" { err = fmt.Errorf("%s: %w", werr.Msg, err) } + // take status from the HTTPError to encourage error handling in one location + if status != 0 && status != werr.Code { + err = fmt.Errorf("[unexpected] non-zero status that does not match HTTPError status, status: %d, HTTPError.code: %d: %w", status, werr.Code, err) + } + status = werr.Code } else { + status = http.StatusInternalServerError resp = &response{ Status: "error", Error: "internal server error", } } + } else if status == 0 { + status = http.StatusInternalServerError + resp = &response{ + Status: "error", + Error: "internal server error", + } + } else if err == nil { + resp = &response{ + Status: "success", + Data: data, + } } b, jerr := json.Marshal(resp) diff --git a/tsweb/jsonhandler_test.go b/tsweb/jsonhandler_test.go index d7032f1c2..b36d55d89 100644 --- a/tsweb/jsonhandler_test.go +++ b/tsweb/jsonhandler_test.go @@ -11,6 +11,8 @@ import ( "net/http/httptest" "strings" "testing" + + "github.com/google/go-cmp/cmp" ) type Data struct { @@ -40,11 +42,11 @@ func TestNewJSONHandler(t *testing.T) { if d.Status == status { t.Logf("ok: %s", d.Status) } else { - t.Fatalf("wrong status: %s %s", d.Status, status) + t.Fatalf("wrong status: got: %s, want: %s", d.Status, status) } if w.Code != code { - t.Fatalf("wrong status code: %d %d", w.Code, code) + t.Fatalf("wrong status code: got: %d, want: %d", w.Code, code) } if w.Header().Get("Content-Type") != "application/json" { @@ -67,7 +69,7 @@ func TestNewJSONHandler(t *testing.T) { t.Run("403 HTTPError", func(t *testing.T) { h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { - return http.StatusForbidden, nil, fmt.Errorf("forbidden") + return 0, nil, Error(http.StatusForbidden, "forbidden", nil) }) w := httptest.NewRecorder() @@ -90,11 +92,11 @@ func TestNewJSONHandler(t *testing.T) { h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { body := new(Data) if err := json.NewDecoder(r.Body).Decode(body); err != nil { - return http.StatusBadRequest, nil, err + return 0, nil, Error(http.StatusBadRequest, err.Error(), err) } if body.Name == "" { - return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "name is empty", nil) + return 0, nil, Error(http.StatusBadRequest, "name is empty", nil) } return http.StatusOK, nil, nil @@ -126,13 +128,13 @@ func TestNewJSONHandler(t *testing.T) { h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { body := new(Data) if err := json.NewDecoder(r.Body).Decode(body); err != nil { - return http.StatusBadRequest, nil, err + return 0, nil, Error(http.StatusBadRequest, err.Error(), err) } if body.Name == "root" { - return http.StatusInternalServerError, nil, fmt.Errorf("invalid name") + return 0, nil, fmt.Errorf("invalid name") } if body.Price == 0 { - return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "price is empty", nil) + return 0, nil, Error(http.StatusBadRequest, "price is empty", nil) } return http.StatusOK, &Data{Price: body.Price * 2}, nil @@ -159,7 +161,7 @@ func TestNewJSONHandler(t *testing.T) { } }) - t.Run("500 internal server error", func(t *testing.T) { + t.Run("500 internal server error (unspecified error, not of type HTTPError)", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`)) h32.ServeHTTPReturn(w, r) @@ -189,4 +191,41 @@ func TestNewJSONHandler(t *testing.T) { }).ServeHTTPReturn(w, r) checkStatus(w, "error", http.StatusInternalServerError) }) + + t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", nil) + JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + return http.StatusForbidden, nil, Error(http.StatusForbidden, "403 forbidden", nil) + }).ServeHTTPReturn(w, r) + want := &Response{ + Status: "error", + Data: &Data{}, + Error: "403 forbidden", + } + got := checkStatus(w, "error", http.StatusForbidden) + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf(diff) + } + }) + + t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError do not agree", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/", nil) + err := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { + return http.StatusInternalServerError, nil, Error(http.StatusForbidden, "403 forbidden", nil) + }).ServeHTTPReturn(w, r) + if !strings.HasPrefix(err.Error(), "[unexpected]") { + t.Fatalf("returned error should have `[unexpected]` to note the disagreeing status codes: %v", err) + } + want := &Response{ + Status: "error", + Data: &Data{}, + Error: "403 forbidden", + } + got := checkStatus(w, "error", http.StatusForbidden) + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("(-want,+got):\n%s", diff) + } + }) }