// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package tsweb import ( "bytes" "compress/gzip" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/google/go-cmp/cmp" ) type Data struct { Name string Price int } type Response struct { Status string Error string Data *Data } func TestNewJSONHandler(t *testing.T) { checkStatus := func(t *testing.T, w *httptest.ResponseRecorder, status string, code int) *Response { d := &Response{ Data: &Data{}, } bodyBytes := w.Body.Bytes() if w.Result().Header.Get("Content-Encoding") == "gzip" { zr, err := gzip.NewReader(bytes.NewReader(bodyBytes)) if err != nil { t.Fatalf("gzip read error at start: %v", err) } bodyBytes, err = io.ReadAll(zr) if err != nil { t.Fatalf("gzip read error: %v", err) } } t.Logf("%s", bodyBytes) err := json.Unmarshal(bodyBytes, d) if err != nil { t.Logf(err.Error()) return nil } if d.Status == status { t.Logf("ok: %s", d.Status) } else { t.Fatalf("wrong status: got: %s, want: %s", d.Status, status) } if w.Code != code { t.Fatalf("wrong status code: got: %d, want: %d", w.Code, code) } if w.Header().Get("Content-Type") != "application/json" { t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type")) } return d } h21 := JSONHandlerFunc(func(r *http.Request) (int, any, error) { return http.StatusOK, nil, nil }) t.Run("200 simple", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h21.ServeHTTPReturn(w, r) checkStatus(t, w, "success", http.StatusOK) }) t.Run("403 HTTPError", func(t *testing.T) { h := JSONHandlerFunc(func(r *http.Request) (int, any, error) { return 0, nil, Error(http.StatusForbidden, "forbidden", nil) }) w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h.ServeHTTPReturn(w, r) checkStatus(t, w, "error", http.StatusForbidden) }) h22 := JSONHandlerFunc(func(r *http.Request) (int, any, error) { return http.StatusOK, &Data{Name: "tailscale"}, nil }) t.Run("200 get data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) h22.ServeHTTPReturn(w, r) checkStatus(t, w, "success", http.StatusOK) }) h31 := JSONHandlerFunc(func(r *http.Request) (int, any, error) { body := new(Data) if err := json.NewDecoder(r.Body).Decode(body); err != nil { return 0, nil, Error(http.StatusBadRequest, err.Error(), err) } if body.Name == "" { return 0, nil, Error(http.StatusBadRequest, "name is empty", nil) } return http.StatusOK, nil, nil }) t.Run("200 post data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) h31.ServeHTTPReturn(w, r) checkStatus(t, w, "success", http.StatusOK) }) t.Run("400 bad json", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) h31.ServeHTTPReturn(w, r) checkStatus(t, w, "error", http.StatusBadRequest) }) t.Run("400 post data error", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h31.ServeHTTPReturn(w, r) resp := checkStatus(t, w, "error", http.StatusBadRequest) if resp.Error != "name is empty" { t.Fatalf("wrong error") } }) h32 := JSONHandlerFunc(func(r *http.Request) (int, any, error) { body := new(Data) if err := json.NewDecoder(r.Body).Decode(body); err != nil { return 0, nil, Error(http.StatusBadRequest, err.Error(), err) } if body.Name == "root" { return 0, nil, fmt.Errorf("invalid name") } if body.Price == 0 { return 0, nil, Error(http.StatusBadRequest, "price is empty", nil) } return http.StatusOK, &Data{Price: body.Price * 2}, nil }) t.Run("200 post data", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) h32.ServeHTTPReturn(w, r) resp := checkStatus(t, w, "success", http.StatusOK) t.Log(resp.Data) if resp.Data.Price != 20 { t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) } }) t.Run("gzipped", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) r.Header.Set("Accept-Encoding", "gzip") h32.ServeHTTPReturn(w, r) res := w.Result() if ct := res.Header.Get("Content-Encoding"); ct != "gzip" { t.Fatalf("encoding = %q; want gzip", ct) } resp := checkStatus(t, w, "success", http.StatusOK) t.Log(resp.Data) if resp.Data.Price != 20 { t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) } }) t.Run("gzipped_400", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) r.Header.Set("Accept-Encoding", "gzip") value := []string{"foo", "foo", "foo"} JSONHandlerFunc(func(r *http.Request) (int, any, error) { return 400, value, nil }).ServeHTTPReturn(w, r) res := w.Result() if ct := res.Header.Get("Content-Encoding"); ct != "gzip" { t.Fatalf("encoding = %q; want gzip", ct) } if res.StatusCode != 400 { t.Errorf("Status = %v; want 400", res.StatusCode) } }) t.Run("400 post data error", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) h32.ServeHTTPReturn(w, r) resp := checkStatus(t, w, "error", http.StatusBadRequest) if resp.Error != "price is empty" { t.Fatalf("wrong error") } }) t.Run("500 internal server error (unspecified error, not of type HTTPError)", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`)) h32.ServeHTTPReturn(w, r) resp := checkStatus(t, w, "error", http.StatusInternalServerError) if resp.Error != "internal server error" { t.Fatalf("wrong error") } }) t.Run("500 misuse", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) JSONHandlerFunc(func(r *http.Request) (int, any, error) { return http.StatusOK, make(chan int), nil }).ServeHTTPReturn(w, r) resp := checkStatus(t, w, "error", http.StatusInternalServerError) if resp.Error != "json marshal error" { t.Fatalf("wrong error") } }) t.Run("500 empty status code", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) JSONHandlerFunc(func(r *http.Request) (status int, data any, err error) { return }).ServeHTTPReturn(w, r) checkStatus(t, w, "error", http.StatusInternalServerError) }) t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) JSONHandlerFunc(func(r *http.Request) (int, any, error) { return http.StatusForbidden, nil, Error(http.StatusForbidden, "403 forbidden", nil) }).ServeHTTPReturn(w, r) want := &Response{ Status: "error", Data: &Data{}, Error: "403 forbidden", } got := checkStatus(t, w, "error", http.StatusForbidden) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf(diff) } }) t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError do not agree", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/", nil) err := JSONHandlerFunc(func(r *http.Request) (int, any, error) { return http.StatusInternalServerError, nil, Error(http.StatusForbidden, "403 forbidden", nil) }).ServeHTTPReturn(w, r) if !strings.HasPrefix(err.Error(), "[unexpected]") { t.Fatalf("returned error should have `[unexpected]` to note the disagreeing status codes: %v", err) } want := &Response{ Status: "error", Data: &Data{}, Error: "403 forbidden", } got := checkStatus(t, w, "error", http.StatusForbidden) if diff := cmp.Diff(want, got); diff != "" { t.Fatalf("(-want,+got):\n%s", diff) } }) } func TestAcceptsEncoding(t *testing.T) { tests := []struct { in, enc string want bool }{ {"", "gzip", false}, {"gzip", "gzip", true}, {"foo,gzip", "gzip", true}, {"foo, gzip", "gzip", true}, {"foo, gzip ", "gzip", true}, {"gzip, foo ", "gzip", true}, {"gzip, foo ", "br", false}, {"gzip, foo ", "fo", false}, {"gzip;q=1.2, foo ", "gzip", true}, {" gzip;q=1.2, foo ", "gzip", true}, } for i, tt := range tests { h := make(http.Header) if tt.in != "" { h.Set("Accept-Encoding", tt.in) } got := AcceptsEncoding(&http.Request{Header: h}, tt.enc) if got != tt.want { t.Errorf("%d. got %v; want %v", i, got, tt.want) } } }