@ -11,6 +11,7 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/netip"
"net/url"
@ -20,6 +21,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/csrf"
"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn"
@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg
return nil , errors . New ( "unknown id" )
}
}
func TestCSRFProtect ( t * testing . T ) {
s := & Server { }
mux := http . NewServeMux ( )
mux . HandleFunc ( "GET /test/csrf-token" , func ( w http . ResponseWriter , r * http . Request ) {
token := csrf . Token ( r )
_ , err := io . WriteString ( w , token )
if err != nil {
t . Fatal ( err )
}
} )
mux . HandleFunc ( "POST /test/csrf-protected" , func ( w http . ResponseWriter , r * http . Request ) {
_ , err := io . WriteString ( w , "ok" )
if err != nil {
t . Fatal ( err )
}
} )
h := s . withCSRF ( mux )
ser := httptest . NewServer ( h )
defer ser . Close ( )
jar , err := cookiejar . New ( nil )
if err != nil {
t . Fatalf ( "unable to construct cookie jar: %v" , err )
}
client := ser . Client ( )
client . Jar = jar
// make GET request to populate cookie jar
resp , err := client . Get ( ser . URL + "/test/csrf-token" )
if err != nil {
t . Fatalf ( "unable to make request: %v" , err )
}
defer resp . Body . Close ( )
if resp . StatusCode != http . StatusOK {
t . Fatalf ( "unexpected status: %v" , resp . Status )
}
tokenBytes , err := io . ReadAll ( resp . Body )
if err != nil {
t . Fatalf ( "unable to read body: %v" , err )
}
csrfToken := strings . TrimSpace ( string ( tokenBytes ) )
if csrfToken == "" {
t . Fatal ( "empty csrf token" )
}
// make a POST request without the CSRF header; ensure it fails
resp , err = client . Post ( ser . URL + "/test/csrf-protected" , "text/plain" , nil )
if err != nil {
t . Fatalf ( "unable to make request: %v" , err )
}
if resp . StatusCode != http . StatusForbidden {
t . Fatalf ( "unexpected status: %v" , resp . Status )
}
// make a POST request with the CSRF header; ensure it succeeds
req , err := http . NewRequest ( "POST" , ser . URL + "/test/csrf-protected" , nil )
if err != nil {
t . Fatalf ( "error building request: %v" , err )
}
req . Header . Set ( "X-CSRF-Token" , csrfToken )
resp , err = client . Do ( req )
if err != nil {
t . Fatalf ( "unable to make request: %v" , err )
}
if resp . StatusCode != http . StatusOK {
t . Fatalf ( "unexpected status: %v" , resp . Status )
}
defer resp . Body . Close ( )
out , err := io . ReadAll ( resp . Body )
if err != nil {
t . Fatalf ( "unable to read body: %v" , err )
}
if string ( out ) != "ok" {
t . Fatalf ( "unexpected body: %q" , out )
}
}