diff --git a/client/web/web.go b/client/web/web.go index 6203b4c18..e9810ccd0 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -203,15 +203,25 @@ func NewServer(opts ServerOpts) (s *Server, err error) { } s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode) - var metric string // clientmetric to report on startup + var metric string + s.apiHandler, metric = s.modeAPIHandler(s.mode) + s.apiHandler = s.withCSRF(s.apiHandler) - // Create handler for "/api" requests with CSRF protection. - // We don't require secure cookies, since the web client is regularly used - // on network appliances that are served on local non-https URLs. - // The client is secured by limiting the interface it listens on, - // or by authenticating requests before they reach the web client. + // Don't block startup on reporting metric. + // Report in separate go routine with 5 second timeout. + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s.lc.IncrementCounter(ctx, metric, 1) + }() + + return s, nil +} + +func (s *Server) withCSRF(h http.Handler) http.Handler { csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) + // ref https://github.com/tailscale/tailscale/pull/14822 // signal to the CSRF middleware that the request is being served over // plaintext HTTP to skip TLS-only header checks. withSetPlaintext := func(h http.Handler) http.Handler { @@ -221,27 +231,24 @@ func NewServer(opts ServerOpts) (s *Server, err error) { }) } - switch s.mode { + // NB: the order of the withSetPlaintext and csrfProtect calls is important + // to ensure that we signal to the CSRF middleware that the request is being + // served over plaintext HTTP and not over TLS as it presumes by default. + return withSetPlaintext(csrfProtect(h)) +} + +func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) { + switch mode { case LoginServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) - metric = "web_login_client_initialization" + return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization" case ReadOnlyServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) - metric = "web_readonly_client_initialization" + return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization" case ManageServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveAPI))) - metric = "web_client_initialization" + return http.HandlerFunc(s.serveAPI), "web_client_initialization" + default: // invalid mode + log.Fatalf("invalid mode: %v", mode) } - - // Don't block startup on reporting metric. - // Report in separate go routine with 5 second timeout. - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s.lc.IncrementCounter(ctx, metric, 1) - }() - - return s, nil + return nil, "" } func (s *Server) Shutdown() { diff --git a/client/web/web_test.go b/client/web/web_test.go index b9242f6ac..291356260 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/netip" "net/url" @@ -20,6 +21,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/gorilla/csrf" "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" @@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg return nil, errors.New("unknown id") } } + +func TestCSRFProtect(t *testing.T) { + s := &Server{} + + mux := http.NewServeMux() + mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) { + token := csrf.Token(r) + _, err := io.WriteString(w, token) + if err != nil { + t.Fatal(err) + } + }) + mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) { + _, err := io.WriteString(w, "ok") + if err != nil { + t.Fatal(err) + } + }) + h := s.withCSRF(mux) + ser := httptest.NewServer(h) + defer ser.Close() + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("unable to construct cookie jar: %v", err) + } + + client := ser.Client() + client.Jar = jar + + // make GET request to populate cookie jar + resp, err := client.Get(ser.URL + "/test/csrf-token") + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + tokenBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + + csrfToken := strings.TrimSpace(string(tokenBytes)) + if csrfToken == "" { + t.Fatal("empty csrf token") + } + + // make a POST request without the CSRF header; ensure it fails + resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("unexpected status: %v", resp.Status) + } + + // make a POST request with the CSRF header; ensure it succeeds + req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil) + if err != nil { + t.Fatalf("error building request: %v", err) + } + req.Header.Set("X-CSRF-Token", csrfToken) + resp, err = client.Do(req) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + if string(out) != "ok" { + t.Fatalf("unexpected body: %q", out) + } +}