// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package tsweb import ( "encoding/json" "fmt" "net/http" "net/http/httptest" "strings" "testing" "github.com/google/go-cmp/cmp" ) type Data struct { Name string Price int } type Response struct { Status string Error string Data *Data } func TestNewJSONHandler(t *testing.T) { checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response { d := &Response{ Data: &Data{}, } t.Logf("%s", w.Body.Bytes()) err := json.Unmarshal(w.Body.Bytes(), d) if err != nil { t.Logf(err.Error()) return nil } if d.Status == status { t.Logf("ok: %s", d.Status) } else { t.Fatalf("wrong status: got: %s, want: %s", d.Status, status) } if w.Code != code { t.Fatalf("wrong status code: got: %d, want: %d", w.Code, code) } if w.Header().Get("Content-Type") != "application/json" { t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type")) } return d } h21 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { return http.StatusOK, nil, nil }) t.Run("200 simple", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h21.ServeHTTPReturn(w, r) checkStatus(w, "success", http.StatusOK) }) t.Run("403 HTTPError", func(t *testing.T) { h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { return 0, nil, Error(http.StatusForbidden, "forbidden", nil) }) w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h.ServeHTTPReturn(w, r) checkStatus(w, "error", http.StatusForbidden) }) h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { return http.StatusOK, &Data{Name: "tailscale"}, nil }) t.Run("200 get data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h22.ServeHTTPReturn(w, r) checkStatus(w, "success", http.StatusOK) }) h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { body := new(Data) if err := json.NewDecoder(r.Body).Decode(body); err != nil { return 0, nil, Error(http.StatusBadRequest, err.Error(), err) } if body.Name == "" { return 0, nil, Error(http.StatusBadRequest, "name is empty", nil) } return http.StatusOK, nil, nil }) t.Run("200 post data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) h31.ServeHTTPReturn(w, r) checkStatus(w, "success", http.StatusOK) }) t.Run("400 bad json", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) h31.ServeHTTPReturn(w, r) checkStatus(w, "error", http.StatusBadRequest) }) t.Run("400 post data error", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h31.ServeHTTPReturn(w, r) resp := checkStatus(w, "error", http.StatusBadRequest) if resp.Error != "name is empty" { t.Fatalf("wrong error") } }) h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { body := new(Data) if err := json.NewDecoder(r.Body).Decode(body); err != nil { return 0, nil, Error(http.StatusBadRequest, err.Error(), err) } if body.Name == "root" { return 0, nil, fmt.Errorf("invalid name") } if body.Price == 0 { return 0, nil, Error(http.StatusBadRequest, "price is empty", nil) } return http.StatusOK, &Data{Price: body.Price * 2}, nil }) t.Run("200 post data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) h32.ServeHTTPReturn(w, r) resp := checkStatus(w, "success", http.StatusOK) t.Log(resp.Data) if resp.Data.Price != 20 { t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) } }) t.Run("400 post data error", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h32.ServeHTTPReturn(w, r) resp := checkStatus(w, "error", http.StatusBadRequest) if resp.Error != "price is empty" { t.Fatalf("wrong error") } }) t.Run("500 internal server error (unspecified error, not of type HTTPError)", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`)) h32.ServeHTTPReturn(w, r) resp := checkStatus(w, "error", http.StatusInternalServerError) if resp.Error != "internal server error" { t.Fatalf("wrong error") } }) t.Run("500 misuse", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { return http.StatusOK, make(chan int), nil }).ServeHTTPReturn(w, r) resp := checkStatus(w, "error", http.StatusInternalServerError) if resp.Error != "json marshal error" { t.Fatalf("wrong error") } }) t.Run("500 empty status code", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) { return }).ServeHTTPReturn(w, r) checkStatus(w, "error", http.StatusInternalServerError) }) t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { return http.StatusForbidden, nil, Error(http.StatusForbidden, "403 forbidden", nil) }).ServeHTTPReturn(w, r) want := &Response{ Status: "error", Data: &Data{}, Error: "403 forbidden", } got := checkStatus(w, "error", http.StatusForbidden) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf(diff) } }) t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError do not agree", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) err := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) { return http.StatusInternalServerError, nil, Error(http.StatusForbidden, "403 forbidden", nil) }).ServeHTTPReturn(w, r) if !strings.HasPrefix(err.Error(), "[unexpected]") { t.Fatalf("returned error should have `[unexpected]` to note the disagreeing status codes: %v", err) } want := &Response{ Status: "error", Data: &Data{}, Error: "403 forbidden", } got := checkStatus(w, "error", http.StatusForbidden) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf("(-want,+got):\n%s", diff) } }) }