client/web: fix CSRF handler order in web UI (#15143)

Fix the order of the CSRF handlers (HTTP plaintext context setting,
_then_ enforcement) in the construction of the web UI server. This
resolves false-positive "invalid Origin" 403 exceptions when attempting
to update settings in the web UI.

Add unit test to exercise the CSRF protection failure and success cases
for our web UI configuration.

Updates #14822
Updates #14872

Signed-off-by: Patrick O'Doherty <patrick@tailscale.com>
pull/15144/head
Patrick O'Doherty 10 months ago committed by GitHub
parent ae303d41dd
commit f5522e62d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -203,15 +203,25 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
} }
s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode) s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode)
var metric string // clientmetric to report on startup var metric string
s.apiHandler, metric = s.modeAPIHandler(s.mode)
s.apiHandler = s.withCSRF(s.apiHandler)
// Create handler for "/api" requests with CSRF protection. // Don't block startup on reporting metric.
// We don't require secure cookies, since the web client is regularly used // Report in separate go routine with 5 second timeout.
// on network appliances that are served on local non-https URLs. go func() {
// The client is secured by limiting the interface it listens on, ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// or by authenticating requests before they reach the web client. defer cancel()
s.lc.IncrementCounter(ctx, metric, 1)
}()
return s, nil
}
func (s *Server) withCSRF(h http.Handler) http.Handler {
csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false))
// ref https://github.com/tailscale/tailscale/pull/14822
// signal to the CSRF middleware that the request is being served over // signal to the CSRF middleware that the request is being served over
// plaintext HTTP to skip TLS-only header checks. // plaintext HTTP to skip TLS-only header checks.
withSetPlaintext := func(h http.Handler) http.Handler { withSetPlaintext := func(h http.Handler) http.Handler {
@ -221,27 +231,24 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
}) })
} }
switch s.mode { // NB: the order of the withSetPlaintext and csrfProtect calls is important
// to ensure that we signal to the CSRF middleware that the request is being
// served over plaintext HTTP and not over TLS as it presumes by default.
return withSetPlaintext(csrfProtect(h))
}
func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) {
switch mode {
case LoginServerMode: case LoginServerMode:
s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization"
metric = "web_login_client_initialization"
case ReadOnlyServerMode: case ReadOnlyServerMode:
s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization"
metric = "web_readonly_client_initialization"
case ManageServerMode: case ManageServerMode:
s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveAPI))) return http.HandlerFunc(s.serveAPI), "web_client_initialization"
metric = "web_client_initialization" default: // invalid mode
log.Fatalf("invalid mode: %v", mode)
} }
return nil, ""
// Don't block startup on reporting metric.
// Report in separate go routine with 5 second timeout.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.lc.IncrementCounter(ctx, metric, 1)
}()
return s, nil
} }
func (s *Server) Shutdown() { func (s *Server) Shutdown() {

@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/cookiejar"
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
"net/url" "net/url"
@ -20,6 +21,7 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/gorilla/csrf"
"tailscale.com/client/local" "tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn" "tailscale.com/ipn"
@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg
return nil, errors.New("unknown id") return nil, errors.New("unknown id")
} }
} }
func TestCSRFProtect(t *testing.T) {
s := &Server{}
mux := http.NewServeMux()
mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) {
token := csrf.Token(r)
_, err := io.WriteString(w, token)
if err != nil {
t.Fatal(err)
}
})
mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) {
_, err := io.WriteString(w, "ok")
if err != nil {
t.Fatal(err)
}
})
h := s.withCSRF(mux)
ser := httptest.NewServer(h)
defer ser.Close()
jar, err := cookiejar.New(nil)
if err != nil {
t.Fatalf("unable to construct cookie jar: %v", err)
}
client := ser.Client()
client.Jar = jar
// make GET request to populate cookie jar
resp, err := client.Get(ser.URL + "/test/csrf-token")
if err != nil {
t.Fatalf("unable to make request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %v", resp.Status)
}
tokenBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read body: %v", err)
}
csrfToken := strings.TrimSpace(string(tokenBytes))
if csrfToken == "" {
t.Fatal("empty csrf token")
}
// make a POST request without the CSRF header; ensure it fails
resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil)
if err != nil {
t.Fatalf("unable to make request: %v", err)
}
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("unexpected status: %v", resp.Status)
}
// make a POST request with the CSRF header; ensure it succeeds
req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil)
if err != nil {
t.Fatalf("error building request: %v", err)
}
req.Header.Set("X-CSRF-Token", csrfToken)
resp, err = client.Do(req)
if err != nil {
t.Fatalf("unable to make request: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %v", resp.Status)
}
defer resp.Body.Close()
out, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read body: %v", err)
}
if string(out) != "ok" {
t.Fatalf("unexpected body: %q", out)
}
}

Loading…
Cancel
Save