diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 2fc6d27e4..c02b09745 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -47,6 +47,7 @@ import ( "tailscale.com/tsnet" "tailscale.com/types/key" "tailscale.com/types/lazy" + "tailscale.com/types/opt" "tailscale.com/types/views" "tailscale.com/util/mak" "tailscale.com/util/must" @@ -61,20 +62,40 @@ type ctxConn struct{} // accessing the IDP over Funnel are persisted. const funnelClientsFile = "oidc-funnel-clients.json" +// oauthClientsFile is the new file name for OAuth clients when running in secure mode. +const oauthClientsFile = "oauth-clients.json" + +// deprecatedFunnelClientsFile is the name used when renaming the old file. +const deprecatedFunnelClientsFile = "deprecated-oidc-funnel-clients.json" + // oidcKeyFile is where the OIDC private key is persisted. const oidcKeyFile = "oidc-key.json" var ( - flagVerbose = flag.Bool("verbose", false, "be verbose") - flagPort = flag.Int("port", 443, "port to listen on") - flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") - flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") - flagFunnel = flag.Bool("funnel", false, "use Tailscale Funnel to make tsidp available on the public internet") - flagHostname = flag.String("hostname", "idp", "tsnet hostname to use instead of idp") - flagDir = flag.String("dir", "", "tsnet state directory; a default one will be created if not provided") + flagVerbose = flag.Bool("verbose", false, "be verbose") + flagPort = flag.Int("port", 443, "port to listen on") + flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") + flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") + flagFunnel = flag.Bool("funnel", false, "use Tailscale Funnel to make tsidp available on the public internet") + flagHostname = flag.String("hostname", "idp", "tsnet hostname to use instead of idp") + flagDir = flag.String("dir", "", "tsnet state directory; a default one will be created if not provided") + flagAllowInsecureRegistrationBool opt.Bool + flagAllowInsecureRegistration = opt.BoolFlag{Bool: &flagAllowInsecureRegistrationBool} ) +// getAllowInsecureRegistration returns whether to allow OAuth flows without pre-registered clients. +// Default is true for backward compatibility; explicitly set to false for strict OAuth compliance. +func getAllowInsecureRegistration() bool { + v, ok := flagAllowInsecureRegistration.Get() + if !ok { + // Flag not set, default to true (allow insecure for backward compatibility) + return true + } + return v +} + func main() { + flag.Var(&flagAllowInsecureRegistration, "allow-insecure-registration", "allow OAuth flows without pre-registered client credentials (default: true for backward compatibility; set to false for strict OAuth compliance)") flag.Parse() ctx := context.Background() if !envknob.UseWIPCode() { @@ -172,10 +193,11 @@ func main() { } srv := &idpServer{ - lc: lc, - funnel: *flagFunnel, - localTSMode: *flagUseLocalTailscaled, - rootPath: rootPath, + lc: lc, + funnel: *flagFunnel, + localTSMode: *flagUseLocalTailscaled, + rootPath: rootPath, + allowInsecureRegistration: getAllowInsecureRegistration(), } if *flagPort != 443 { @@ -184,20 +206,29 @@ func main() { srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, ".")) } - // Load funnel clients from disk if they exist, regardless of whether funnel is enabled - // This ensures OIDC clients persist across restarts - funnelClientsFilePath, err := getConfigFilePath(rootPath, funnelClientsFile) - if err != nil { - log.Fatalf("could not get funnel clients file path: %v", err) + // If allowInsecureRegistration is enabled, the old oidc-funnel-clients.json path is used. + // If allowInsecureRegistration is disabled, attempt to migrate the old path to oidc-clients.json and use this new path. + var clientsFilePath string + if !srv.allowInsecureRegistration { + clientsFilePath, err = migrateOAuthClients(rootPath) + if err != nil { + log.Fatalf("could not migrate OAuth clients: %v", err) + } + } else { + clientsFilePath, err = getConfigFilePath(rootPath, funnelClientsFile) + if err != nil { + log.Fatalf("could not get funnel clients file path: %v", err) + } } - f, err := os.Open(funnelClientsFilePath) + + f, err := os.Open(clientsFilePath) if err == nil { if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil { - log.Fatalf("could not parse %s: %v", funnelClientsFilePath, err) + log.Fatalf("could not parse %s: %v", clientsFilePath, err) } f.Close() } else if !errors.Is(err, os.ErrNotExist) { - log.Fatalf("could not open %s: %v", funnelClientsFilePath, err) + log.Fatalf("could not open %s: %v", clientsFilePath, err) } log.Printf("Running tsidp at %s ...", srv.serverURL) @@ -304,12 +335,13 @@ func serveOnLocalTailscaled(ctx context.Context, lc *local.Client, st *ipnstate. } type idpServer struct { - lc *local.Client - loopbackURL string - serverURL string // "https://foo.bar.ts.net" - funnel bool - localTSMode bool - rootPath string // root path, used for storing state files + lc *local.Client + loopbackURL string + serverURL string // "https://foo.bar.ts.net" + funnel bool + localTSMode bool + rootPath string // root path, used for storing state files + allowInsecureRegistration bool // If true, allow OAuth without pre-registered clients lazyMux lazy.SyncValue[*http.ServeMux] lazySigningKey lazy.SyncValue[*signingKey] @@ -393,14 +425,15 @@ func (ar *authRequest) allowRelyingParty(r *http.Request, lc *local.Client) erro } func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { + // This URL is visited by the user who is being authenticated. If they are // visiting the URL over Funnel, that means they are not part of the // tailnet that they are trying to be authenticated for. + // NOTE: Funnel request behavior is the same regardless of secure or insecure mode. if isFunnelRequest(r) { http.Error(w, "tsidp: unauthorized", http.StatusUnauthorized) return } - uq := r.URL.Query() redirectURI := uq.Get("redirect_uri") @@ -409,6 +442,86 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { return } + clientID := uq.Get("client_id") + if clientID == "" { + http.Error(w, "tsidp: must specify client_id", http.StatusBadRequest) + return + } + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, validate client_id exists but defer client_secret validation to token endpoint + // This follows RFC 6749 which specifies client authentication should occur at token endpoint, not authorization endpoint + + s.mu.Lock() + c, ok := s.funnelClients[clientID] + s.mu.Unlock() + if !ok { + http.Error(w, "tsidp: invalid client ID", http.StatusBadRequest) + return + } + + // Validate client_id matches (public identifier validation) + clientIDcmp := subtle.ConstantTimeCompare([]byte(clientID), []byte(c.ID)) + if clientIDcmp != 1 { + http.Error(w, "tsidp: invalid client ID", http.StatusBadRequest) + return + } + + // Validate redirect URI + if redirectURI != c.RedirectURI { + http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) + return + } + + // Get user information + var remoteAddr string + if s.localTSMode { + remoteAddr = r.Header.Get("X-Forwarded-For") + } else { + remoteAddr = r.RemoteAddr + } + + // Check who is visiting the authorize endpoint. + var who *apitype.WhoIsResponse + var err error + who, err = s.lc.WhoIs(r.Context(), remoteAddr) + if err != nil { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + code := rands.HexString(32) + ar := &authRequest{ + nonce: uq.Get("nonce"), + remoteUser: who, + redirectURI: redirectURI, + clientID: clientID, + funnelRP: c, // Store the validated client + } + + s.mu.Lock() + mak.Set(&s.code, code, ar) + s.mu.Unlock() + + q := make(url.Values) + q.Set("code", code) + if state := uq.Get("state"); state != "" { + q.Set("state", state) + } + parsedURL, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, "invalid redirect URI", http.StatusInternalServerError) + return + } + parsedURL.RawQuery = q.Encode() + u := parsedURL.String() + log.Printf("Redirecting to %q", u) + + http.Redirect(w, r, u, http.StatusFound) + return + } + var remoteAddr string if s.localTSMode { // in local tailscaled mode, the local tailscaled is forwarding us @@ -430,7 +543,7 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { nonce: uq.Get("nonce"), remoteUser: who, redirectURI: redirectURI, - clientID: uq.Get("client_id"), + clientID: clientID, } if r.URL.Path == "/authorize/funnel" { @@ -466,7 +579,13 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { if state := uq.Get("state"); state != "" { q.Set("state", state) } - u := redirectURI + "?" + q.Encode() + parsedURL, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, "invalid redirect URI", http.StatusInternalServerError) + return + } + parsedURL.RawQuery = q.Encode() + u := parsedURL.String() log.Printf("Redirecting to %q", u) http.Redirect(w, r, u, http.StatusFound) @@ -476,7 +595,13 @@ func (s *idpServer) newMux() *http.ServeMux { mux := http.NewServeMux() mux.HandleFunc(oidcJWKSPath, s.serveJWKS) mux.HandleFunc(oidcConfigPath, s.serveOpenIDConfig) - mux.HandleFunc("/authorize/", s.authorize) + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, use a single /authorize endpoint + mux.HandleFunc("/authorize", s.authorize) + } else { + // When insecure registration is allowed, preserve original behavior with path-based routing + mux.HandleFunc("/authorize/", s.authorize) + } mux.HandleFunc("/userinfo", s.serveUserInfo) mux.HandleFunc("/token", s.serveToken) mux.HandleFunc("/clients/", s.serveClients) @@ -513,6 +638,24 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { s.mu.Lock() delete(s.accessToken, tk) s.mu.Unlock() + return + } + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, validate that the token was issued to a valid client. + if ar.clientID == "" { + http.Error(w, "tsidp: no client associated with token", http.StatusBadRequest) + return + } + + // Validate client still exists + s.mu.Lock() + _, clientExists := s.funnelClients[ar.clientID] + s.mu.Unlock() + if !clientExists { + http.Error(w, "tsidp: client no longer exists", http.StatusUnauthorized) + return + } } ui := userInfo{} @@ -722,11 +865,58 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: code not found", http.StatusBadRequest) return } - if err := ar.allowRelyingParty(r, s.lc); err != nil { - log.Printf("Error allowing relying party: %v", err) - http.Error(w, err.Error(), http.StatusForbidden) - return + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, always validate client credentials regardless of request source + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + + // Try basic auth if form values are empty + if clientID == "" || clientSecret == "" { + if basicClientID, basicClientSecret, ok := r.BasicAuth(); ok { + if clientID == "" { + clientID = basicClientID + } + if clientSecret == "" { + clientSecret = basicClientSecret + } + } + } + + if clientID == "" || clientSecret == "" { + http.Error(w, "tsidp: client credentials required in when insecure registration is not allowed", http.StatusUnauthorized) + return + } + + // Validate against the stored auth request + if ar.clientID != clientID { + http.Error(w, "tsidp: client_id mismatch", http.StatusBadRequest) + return + } + + // Validate client credentials against stored clients + if ar.funnelRP == nil { + http.Error(w, "tsidp: no client information found", http.StatusBadRequest) + return + } + + clientIDcmp := subtle.ConstantTimeCompare([]byte(clientID), []byte(ar.funnelRP.ID)) + clientSecretcmp := subtle.ConstantTimeCompare([]byte(clientSecret), []byte(ar.funnelRP.Secret)) + if clientIDcmp != 1 || clientSecretcmp != 1 { + http.Error(w, "tsidp: invalid client credentials", http.StatusUnauthorized) + return + } + } else { + // Original behavior when insecure registration is allowed + // Only checks ClientID and Client Secret when over funnel. + // Local connections are allowed and tailnet connections only check matching nodeIDs. + if err := ar.allowRelyingParty(r, s.lc); err != nil { + log.Printf("Error allowing relying party: %v", err) + http.Error(w, err.Error(), http.StatusForbidden) + return + } } + if ar.redirectURI != r.FormValue("redirect_uri") { http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) return @@ -977,24 +1167,38 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: not found", http.StatusNotFound) return } - ap, err := netip.ParseAddrPort(r.RemoteAddr) - if err != nil { - log.Printf("Error parsing remote addr: %v", err) - return - } + var authorizeEndpoint string rpEndpoint := s.serverURL - if isFunnelRequest(r) { - authorizeEndpoint = fmt.Sprintf("%s/authorize/funnel", s.serverURL) - } else if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil { - authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) - } else if ap.Addr().IsLoopback() { - rpEndpoint = s.loopbackURL - authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL) + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, use a single authorization endpoint for all request types + // This will be the same regardless of if the user is on localhost, tailscale, or funnel. + authorizeEndpoint = fmt.Sprintf("%s/authorize", s.serverURL) + rpEndpoint = s.serverURL } else { - log.Printf("Error getting WhoIs: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return + // When insecure registration is allowed TSIDP uses the requestors nodeID + // (typically that of the resource server during auto discovery) when on the tailnet + // and adds it to the authorize URL as a replacement clientID for when the user authorizes. + // The behavior over funnel drops the nodeID & clientID replacement behvaior and does require a + // previously created clientID and client secret. + ap, err := netip.ParseAddrPort(r.RemoteAddr) + if err != nil { + log.Printf("Error parsing remote addr: %v", err) + return + } + if isFunnelRequest(r) { + authorizeEndpoint = fmt.Sprintf("%s/authorize/funnel", s.serverURL) + } else if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil { + authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) + } else if ap.Addr().IsLoopback() { + rpEndpoint = s.loopbackURL + authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL) + } else { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } w.Header().Set("Content-Type", "application/json") @@ -1148,20 +1352,27 @@ func (s *idpServer) serveDeleteClient(w http.ResponseWriter, r *http.Request, cl } // storeFunnelClientsLocked writes the current mapping of OIDC client ID/secret -// pairs for RPs that access the IDP over funnel. s.mu must be held while -// calling this. +// pairs for RPs that access the IDP. When insecure registration is NOT allowed, uses oauth-clients.json; +// otherwise uses oidc-funnel-clients.json. s.mu must be held while calling this. func (s *idpServer) storeFunnelClientsLocked() error { var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(s.funnelClients); err != nil { return err } - funnelClientsFilePath, err := getConfigFilePath(s.rootPath, funnelClientsFile) + var clientsFilePath string + var err error + if !s.allowInsecureRegistration { + clientsFilePath, err = getConfigFilePath(s.rootPath, oauthClientsFile) + } else { + clientsFilePath, err = getConfigFilePath(s.rootPath, funnelClientsFile) + } + if err != nil { return fmt.Errorf("storeFunnelClientsLocked: %v", err) } - return os.WriteFile(funnelClientsFilePath, buf.Bytes(), 0600) + return os.WriteFile(clientsFilePath, buf.Bytes(), 0600) } const ( @@ -1275,9 +1486,67 @@ func isFunnelRequest(r *http.Request) bool { return false } +// migrateOAuthClients migrates from oidc-funnel-clients.json to oauth-clients.json. +// If oauth-clients.json already exists, no migration is performed. +// If both files are missing a new configuration is created. +// The path to the new configuration file is returned. +func migrateOAuthClients(rootPath string) (string, error) { + // First, check for oauth-clients.json (new file) + oauthPath, err := getConfigFilePath(rootPath, oauthClientsFile) + if err != nil { + return "", fmt.Errorf("could not get oauth clients file path: %w", err) + } + if _, err := os.Stat(oauthPath); err == nil { + // oauth-clients.json already exists, use it + return oauthPath, nil + } + + // Check for old oidc-funnel-clients.json + oldPath, err := getConfigFilePath(rootPath, funnelClientsFile) + if err != nil { + return "", fmt.Errorf("could not get funnel clients file path: %w", err) + } + if _, err := os.Stat(oldPath); err == nil { + // Old file exists, migrate it + log.Printf("Migrating OAuth clients from %s to %s", oldPath, oauthPath) + + // Read the old file + data, err := os.ReadFile(oldPath) + if err != nil { + return "", fmt.Errorf("could not read old funnel clients file: %w", err) + } + + // Write to new location + if err := os.WriteFile(oauthPath, data, 0600); err != nil { + return "", fmt.Errorf("could not write new oauth clients file: %w", err) + } + + // Rename old file to deprecated name + deprecatedPath, err := getConfigFilePath(rootPath, deprecatedFunnelClientsFile) + if err != nil { + return "", fmt.Errorf("could not get deprecated file path: %w", err) + } + if err := os.Rename(oldPath, deprecatedPath); err != nil { + log.Printf("Warning: could not rename old file to deprecated name: %v", err) + } else { + log.Printf("Renamed old file to %s", deprecatedPath) + } + + return oauthPath, nil + } + + // Neither file exists, create empty oauth-clients.json + log.Printf("Creating empty OAuth clients file at %s", oauthPath) + if err := os.WriteFile(oauthPath, []byte("{}"), 0600); err != nil { + return "", fmt.Errorf("could not create empty oauth clients file: %w", err) + } + + return oauthPath, nil +} + // getConfigFilePath returns the path to the config file for the given file name. // The oidc-key.json and funnel-clients.json files were originally opened and written -// to without paths, and ended up in /root dir or home directory of the user running +// to without paths, and ended up in /root or home directory of the user running // the process. To maintain backward compatibility, we return the naked file name if that // file exists already, otherwise we return the full path in the rootPath. func getConfigFilePath(rootPath string, fileName string) (string, error) { diff --git a/cmd/tsidp/tsidp_test.go b/cmd/tsidp/tsidp_test.go index e5465d3cf..4f5af9e59 100644 --- a/cmd/tsidp/tsidp_test.go +++ b/cmd/tsidp/tsidp_test.go @@ -1,6 +1,19 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +// Package main tests for tsidp focus on OAuth security boundaries and +// correct implementation of the OpenID Connect identity provider. +// +// Test Strategy: +// - Tests are intentionally granular to provide clear failure signals when +// security-critical logic breaks +// - OAuth flow tests cover both strict mode (registered clients only) and +// legacy mode (local funnel clients) to ensure proper access controls +// - Helper functions like normalizeMap ensure deterministic comparisons +// despite JSON marshaling order variations +// - The privateKey global is reused across tests for performance (RSA key +// generation is expensive) + package main import ( @@ -16,21 +29,28 @@ import ( "net/netip" "net/url" "os" + "path/filepath" "reflect" "sort" "strings" + "sync" "testing" "time" "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/opt" "tailscale.com/types/views" ) -// normalizeMap recursively sorts []any values in a map[string]any +// normalizeMap recursively sorts []any values in a map[string]any to ensure +// deterministic test comparisons. This is necessary because JSON marshaling +// doesn't guarantee array order, and we need stable comparisons when testing +// claim merging and flattening logic. func normalizeMap(t *testing.T, m map[string]any) map[string]any { t.Helper() normalized := make(map[string]any, len(m)) @@ -66,7 +86,13 @@ func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage { return tailcfg.RawMessage(b) } -var privateKey *rsa.PrivateKey = nil +// privateKey is a shared RSA private key used across tests. It's lazily +// initialized on first use to avoid the expensive key generation cost +// for every test. Protected by privateKeyMu for thread safety. +var ( + privateKey *rsa.PrivateKey + privateKeyMu sync.Mutex +) func oidcTestingSigner(t *testing.T) jose.Signer { t.Helper() @@ -86,6 +112,9 @@ func oidcTestingPublicKey(t *testing.T) *rsa.PublicKey { func mustGeneratePrivateKey(t *testing.T) *rsa.PrivateKey { t.Helper() + privateKeyMu.Lock() + defer privateKeyMu.Unlock() + if privateKey != nil { return privateKey } @@ -181,7 +210,7 @@ func TestFlattenExtraClaims(t *testing.T) { {ExtraClaims: map[string]any{"foo": []any{"baz"}}}, }, expected: map[string]any{ - "foo": []any{"bar", "baz"}, // since first was scalar, second being a slice forces slice output + "foo": []any{"bar", "baz"}, // converts to slice when any rule provides a slice }, }, { @@ -462,6 +491,7 @@ func TestServeToken(t *testing.T) { omitCode bool redirectURI string remoteAddr string + strictMode bool expectError bool expected map[string]any }{ @@ -469,12 +499,14 @@ func TestServeToken(t *testing.T) { name: "GET not allowed", method: "GET", grantType: "authorization_code", + strictMode: false, expectError: true, }, { name: "unsupported grant type", method: "POST", grantType: "pkcs", + strictMode: false, expectError: true, }, { @@ -482,6 +514,7 @@ func TestServeToken(t *testing.T) { method: "POST", grantType: "authorization_code", code: "invalid-code", + strictMode: false, expectError: true, }, { @@ -489,6 +522,7 @@ func TestServeToken(t *testing.T) { method: "POST", grantType: "authorization_code", omitCode: true, + strictMode: false, expectError: true, }, { @@ -498,6 +532,7 @@ func TestServeToken(t *testing.T) { code: "valid-code", redirectURI: "https://invalid.example.com/callback", remoteAddr: "127.0.0.1:12345", + strictMode: false, expectError: true, }, { @@ -507,15 +542,17 @@ func TestServeToken(t *testing.T) { redirectURI: "https://rp.example.com/callback", code: "valid-code", remoteAddr: "192.168.0.1:12345", + strictMode: false, expectError: true, }, { - name: "extra claim included", + name: "extra claim included (non-strict)", method: "POST", grantType: "authorization_code", redirectURI: "https://rp.example.com/callback", code: "valid-code", remoteAddr: "127.0.0.1:12345", + strictMode: false, caps: tailcfg.PeerCapMap{ tailcfg.PeerCapabilityTsIDP: { mustMarshalJSON(t, capRule{ @@ -531,11 +568,12 @@ func TestServeToken(t *testing.T) { }, }, { - name: "attempt to overwrite protected claim", + name: "attempt to overwrite protected claim (non-strict)", method: "POST", grantType: "authorization_code", redirectURI: "https://rp.example.com/callback", code: "valid-code", + strictMode: false, caps: tailcfg.PeerCapMap{ tailcfg.PeerCapabilityTsIDP: { mustMarshalJSON(t, capRule{ @@ -554,6 +592,9 @@ func TestServeToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { now := time.Now() + // Use setupTestServer helper + s := setupTestServer(t, tt.strictMode) + // Fake user/node profile := &tailcfg.UserProfile{ LoginName: "alice@example.com", @@ -575,20 +616,27 @@ func TestServeToken(t *testing.T) { CapMap: tt.caps, } - s := &idpServer{ - code: map[string]*authRequest{ - "valid-code": { - clientID: "client-id", - nonce: "nonce123", - redirectURI: "https://rp.example.com/callback", - validTill: now.Add(5 * time.Minute), - remoteUser: remoteUser, - localRP: true, - }, - }, + // Setup auth request with appropriate configuration for strict mode + var funnelClientPtr *funnelClient + if tt.strictMode { + funnelClientPtr = &funnelClient{ + ID: "client-id", + Secret: "test-secret", + Name: "Test Client", + RedirectURI: "https://rp.example.com/callback", + } + s.funnelClients["client-id"] = funnelClientPtr + } + + s.code["valid-code"] = &authRequest{ + clientID: "client-id", + nonce: "nonce123", + redirectURI: "https://rp.example.com/callback", + validTill: now.Add(5 * time.Minute), + remoteUser: remoteUser, + localRP: !tt.strictMode, + funnelRP: funnelClientPtr, } - // Inject a working signer - s.lazySigner.Set(oidcTestingSigner(t)) form := url.Values{} form.Set("grant_type", tt.grantType) @@ -596,6 +644,11 @@ func TestServeToken(t *testing.T) { if !tt.omitCode { form.Set("code", tt.code) } + // Add client credentials for strict mode + if tt.strictMode { + form.Set("client_id", "client-id") + form.Set("client_secret", "test-secret") + } req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode())) req.RemoteAddr = tt.remoteAddr @@ -779,6 +832,7 @@ func TestExtraUserInfo(t *testing.T) { // Insert a valid token into the idpServer s := &idpServer{ + allowInsecureRegistration: true, // Default to allowing insecure registration for backward compatibility accessToken: map[string]*authRequest{ token: { validTill: tt.tokenValidTill, @@ -854,7 +908,7 @@ func TestFunnelClientsPersistence(t *testing.T) { t.Fatalf("failed to write test file: %v", err) } - t.Run("step1_load_from_existing_file", func(t *testing.T) { + t.Run("load_from_existing_file", func(t *testing.T) { srv := &idpServer{} // Simulate the funnel clients loading logic from main() @@ -887,7 +941,7 @@ func TestFunnelClientsPersistence(t *testing.T) { } }) - t.Run("step2_initialize_empty_when_no_file", func(t *testing.T) { + t.Run("initialize_empty_when_no_file", func(t *testing.T) { nonExistentFile := t.TempDir() + "/non-existent.json" srv := &idpServer{} @@ -913,7 +967,7 @@ func TestFunnelClientsPersistence(t *testing.T) { } }) - t.Run("step3_persist_and_reload_clients", func(t *testing.T) { + t.Run("persist_and_reload_clients", func(t *testing.T) { tmpFile2 := t.TempDir() + "/test-persistence.json" // Create initial server with one client @@ -962,4 +1016,1048 @@ func TestFunnelClientsPersistence(t *testing.T) { } } }) + + t.Run("strict_mode_file_handling", func(t *testing.T) { + tmpDir := t.TempDir() + + // Test strict mode uses oauth-clients.json + srv1 := setupTestServer(t, true) + srv1.rootPath = tmpDir + srv1.funnelClients["oauth-client"] = &funnelClient{ + ID: "oauth-client", + Secret: "oauth-secret", + Name: "OAuth Client", + RedirectURI: "https://oauth.example.com/callback", + } + + // Test storeFunnelClientsLocked in strict mode + srv1.mu.Lock() + err := srv1.storeFunnelClientsLocked() + srv1.mu.Unlock() + + if err != nil { + t.Fatalf("failed to store clients in strict mode: %v", err) + } + + // Verify oauth-clients.json was created + oauthPath := tmpDir + "/" + oauthClientsFile + if _, err := os.Stat(oauthPath); err != nil { + t.Errorf("expected oauth-clients.json to be created: %v", err) + } + + // Verify oidc-funnel-clients.json was NOT created + funnelPath := tmpDir + "/" + funnelClientsFile + if _, err := os.Stat(funnelPath); !os.IsNotExist(err) { + t.Error("expected oidc-funnel-clients.json NOT to be created in strict mode") + } + }) + + t.Run("non_strict_mode_file_handling", func(t *testing.T) { + tmpDir := t.TempDir() + + // Test non-strict mode uses oidc-funnel-clients.json + srv1 := setupTestServer(t, false) + srv1.rootPath = tmpDir + srv1.funnelClients["funnel-client"] = &funnelClient{ + ID: "funnel-client", + Secret: "funnel-secret", + Name: "Funnel Client", + RedirectURI: "https://funnel.example.com/callback", + } + + // Test storeFunnelClientsLocked in non-strict mode + srv1.mu.Lock() + err := srv1.storeFunnelClientsLocked() + srv1.mu.Unlock() + + if err != nil { + t.Fatalf("failed to store clients in non-strict mode: %v", err) + } + + // Verify oidc-funnel-clients.json was created + funnelPath := tmpDir + "/" + funnelClientsFile + if _, err := os.Stat(funnelPath); err != nil { + t.Errorf("expected oidc-funnel-clients.json to be created: %v", err) + } + + // Verify oauth-clients.json was NOT created + oauthPath := tmpDir + "/" + oauthClientsFile + if _, err := os.Stat(oauthPath); !os.IsNotExist(err) { + t.Error("expected oauth-clients.json NOT to be created in non-strict mode") + } + }) +} + +// Test helper functions for strict OAuth mode testing +func setupTestServer(t *testing.T, strictMode bool) *idpServer { + return setupTestServerWithClient(t, strictMode, nil) +} + +// setupTestServerWithClient creates a test server with an optional LocalClient. +// If lc is nil, the server will have no LocalClient (original behavior). +// If lc is provided, it will be used for WhoIs calls during testing. +func setupTestServerWithClient(t *testing.T, strictMode bool, lc *local.Client) *idpServer { + t.Helper() + + srv := &idpServer{ + allowInsecureRegistration: !strictMode, + code: make(map[string]*authRequest), + accessToken: make(map[string]*authRequest), + funnelClients: make(map[string]*funnelClient), + serverURL: "https://test.ts.net", + rootPath: t.TempDir(), + lc: lc, + } + + // Add a test client for funnel/strict mode testing + srv.funnelClients["test-client"] = &funnelClient{ + ID: "test-client", + Secret: "test-secret", + Name: "Test Client", + RedirectURI: "https://rp.example.com/callback", + } + + // Inject a working signer for token tests + srv.lazySigner.Set(oidcTestingSigner(t)) + + return srv +} + +func TestGetAllowInsecureRegistration(t *testing.T) { + tests := []struct { + name string + flagSet bool + flagValue bool + expectAllowInsecureRegistration bool + }{ + { + name: "flag explicitly set to false - insecure registration disabled (strict mode)", + flagSet: true, + flagValue: false, + expectAllowInsecureRegistration: false, + }, + { + name: "flag explicitly set to true - insecure registration enabled", + flagSet: true, + flagValue: true, + expectAllowInsecureRegistration: true, + }, + { + name: "flag unset - insecure registration enabled (default for backward compatibility)", + flagSet: false, + flagValue: false, // not used when unset + expectAllowInsecureRegistration: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original state + originalFlag := flagAllowInsecureRegistration + defer func() { + flagAllowInsecureRegistration = originalFlag + }() + + // Set up test state by creating a new BoolFlag and setting values + var b opt.Bool + flagAllowInsecureRegistration = opt.BoolFlag{Bool: &b} + if tt.flagSet { + flagAllowInsecureRegistration.Bool.Set(tt.flagValue) + } + // Note: when tt.flagSet is false, the Bool remains unset (which is what we want) + + got := getAllowInsecureRegistration() + if got != tt.expectAllowInsecureRegistration { + t.Errorf("getAllowInsecureRegistration() = %v, want %v", got, tt.expectAllowInsecureRegistration) + } + }) + } +} + +// TestMigrateOAuthClients verifies the migration from legacy funnel clients +// to OAuth clients. This migration is necessary when transitioning from +// non-strict to strict OAuth mode. The migration logic should: +// - Copy clients from oidc-funnel-clients.json to oauth-clients.json +// - Rename the old file to mark it as deprecated +// - Handle cases where files already exist or are missing +func TestMigrateOAuthClients(t *testing.T) { + tests := []struct { + name string + setupOldFile bool + setupNewFile bool + oldFileContent map[string]*funnelClient + newFileContent map[string]*funnelClient + expectError bool + expectNewFileExists bool + expectOldRenamed bool + }{ + { + name: "migrate from old file to new file", + setupOldFile: true, + oldFileContent: map[string]*funnelClient{ + "old-client": { + ID: "old-client", + Secret: "old-secret", + Name: "Old Client", + RedirectURI: "https://old.example.com/callback", + }, + }, + expectNewFileExists: true, + expectOldRenamed: true, + }, + { + name: "new file already exists - no migration", + setupNewFile: true, + newFileContent: map[string]*funnelClient{ + "existing-client": { + ID: "existing-client", + Secret: "existing-secret", + Name: "Existing Client", + RedirectURI: "https://existing.example.com/callback", + }, + }, + expectNewFileExists: true, + expectOldRenamed: false, + }, + { + name: "neither file exists - create empty new file", + expectNewFileExists: true, + expectOldRenamed: false, + }, + { + name: "both files exist - prefer new file", + setupOldFile: true, + setupNewFile: true, + oldFileContent: map[string]*funnelClient{ + "old-client": { + ID: "old-client", + Secret: "old-secret", + Name: "Old Client", + RedirectURI: "https://old.example.com/callback", + }, + }, + newFileContent: map[string]*funnelClient{ + "new-client": { + ID: "new-client", + Secret: "new-secret", + Name: "New Client", + RedirectURI: "https://new.example.com/callback", + }, + }, + expectNewFileExists: true, + expectOldRenamed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rootPath := t.TempDir() + + // Setup old file if needed + if tt.setupOldFile { + oldData, err := json.Marshal(tt.oldFileContent) + if err != nil { + t.Fatalf("failed to marshal old file content: %v", err) + } + oldPath := rootPath + "/" + funnelClientsFile + if err := os.WriteFile(oldPath, oldData, 0600); err != nil { + t.Fatalf("failed to create old file: %v", err) + } + } + + // Setup new file if needed + if tt.setupNewFile { + newData, err := json.Marshal(tt.newFileContent) + if err != nil { + t.Fatalf("failed to marshal new file content: %v", err) + } + newPath := rootPath + "/" + oauthClientsFile + if err := os.WriteFile(newPath, newData, 0600); err != nil { + t.Fatalf("failed to create new file: %v", err) + } + } + + // Call migrateOAuthClients + resultPath, err := migrateOAuthClients(rootPath) + + if tt.expectError && err == nil { + t.Fatalf("expected error but got none") + } + if !tt.expectError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.expectError { + return + } + + // Verify result path points to oauth-clients.json + expectedPath := filepath.Join(rootPath, oauthClientsFile) + if resultPath != expectedPath { + t.Errorf("expected result path %s, got %s", expectedPath, resultPath) + } + + // Verify new file exists if expected + if tt.expectNewFileExists { + if _, err := os.Stat(resultPath); err != nil { + t.Errorf("expected new file to exist at %s: %v", resultPath, err) + } + + // Verify content + data, err := os.ReadFile(resultPath) + if err != nil { + t.Fatalf("failed to read new file: %v", err) + } + + var clients map[string]*funnelClient + if err := json.Unmarshal(data, &clients); err != nil { + t.Fatalf("failed to unmarshal new file: %v", err) + } + + // Determine expected content + var expectedContent map[string]*funnelClient + if tt.setupNewFile { + expectedContent = tt.newFileContent + } else if tt.setupOldFile { + expectedContent = tt.oldFileContent + } else { + expectedContent = make(map[string]*funnelClient) + } + + if len(clients) != len(expectedContent) { + t.Errorf("expected %d clients, got %d", len(expectedContent), len(clients)) + } + + for id, expectedClient := range expectedContent { + actualClient, ok := clients[id] + if !ok { + t.Errorf("expected client %s not found", id) + continue + } + if actualClient.ID != expectedClient.ID || + actualClient.Secret != expectedClient.Secret || + actualClient.Name != expectedClient.Name || + actualClient.RedirectURI != expectedClient.RedirectURI { + t.Errorf("client %s mismatch: got %+v, want %+v", id, actualClient, expectedClient) + } + } + } + + // Verify old file renamed if expected + if tt.expectOldRenamed { + deprecatedPath := rootPath + "/" + deprecatedFunnelClientsFile + if _, err := os.Stat(deprecatedPath); err != nil { + t.Errorf("expected old file to be renamed to %s: %v", deprecatedPath, err) + } + + // Verify original old file is gone + oldPath := rootPath + "/" + funnelClientsFile + if _, err := os.Stat(oldPath); !os.IsNotExist(err) { + t.Errorf("expected old file %s to be removed", oldPath) + } + } + }) + } +} + +// TestGetConfigFilePath verifies backward compatibility for config file location. +// The function must check current directory first (legacy deployments) before +// falling back to rootPath (new installations) to prevent breaking existing +// tsidp deployments that have config files in unexpected locations. +func TestGetConfigFilePath(t *testing.T) { + tests := []struct { + name string + fileName string + createInCwd bool + createInRoot bool + expectInCwd bool + expectError bool + }{ + { + name: "file exists in current directory - use current directory", + fileName: "test-config.json", + createInCwd: true, + expectInCwd: true, + }, + { + name: "file does not exist - use root path", + fileName: "test-config.json", + createInCwd: false, + expectInCwd: false, + }, + { + name: "file exists in both - prefer current directory", + fileName: "test-config.json", + createInCwd: true, + createInRoot: true, + expectInCwd: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temporary directories + rootPath := t.TempDir() + originalWd, err := os.Getwd() + if err != nil { + t.Fatalf("failed to get working directory: %v", err) + } + + // Create a temporary working directory + tmpWd := t.TempDir() + if err := os.Chdir(tmpWd); err != nil { + t.Fatalf("failed to change to temp directory: %v", err) + } + defer func() { + os.Chdir(originalWd) + }() + + // Setup files as needed + if tt.createInCwd { + if err := os.WriteFile(tt.fileName, []byte("{}"), 0600); err != nil { + t.Fatalf("failed to create file in cwd: %v", err) + } + } + if tt.createInRoot { + rootFilePath := filepath.Join(rootPath, tt.fileName) + if err := os.WriteFile(rootFilePath, []byte("{}"), 0600); err != nil { + t.Fatalf("failed to create file in root: %v", err) + } + } + + // Call getConfigFilePath + resultPath, err := getConfigFilePath(rootPath, tt.fileName) + + if tt.expectError && err == nil { + t.Fatalf("expected error but got none") + } + if !tt.expectError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.expectError { + return + } + + // Verify result + if tt.expectInCwd { + if resultPath != tt.fileName { + t.Errorf("expected path %s, got %s", tt.fileName, resultPath) + } + } else { + expectedPath := filepath.Join(rootPath, tt.fileName) + if resultPath != expectedPath { + t.Errorf("expected path %s, got %s", expectedPath, resultPath) + } + } + }) + } +} + +// TestAuthorizeStrictMode verifies OAuth authorization endpoint security and validation logic. +// Tests both the security boundary (funnel rejection) and the business logic (strict mode validation). +func TestAuthorizeStrictMode(t *testing.T) { + tests := []struct { + name string + strictMode bool + clientID string + redirectURI string + state string + nonce string + setupClient bool + clientRedirect string + useFunnel bool // whether to simulate funnel request + mockWhoIsError bool // whether to make WhoIs return an error + expectError bool + expectCode int + expectRedirect bool + }{ + // Security boundary test: funnel rejection + { + name: "funnel requests are always rejected for security", + strictMode: true, + clientID: "test-client", + redirectURI: "https://rp.example.com/callback", + state: "random-state", + nonce: "random-nonce", + setupClient: true, + clientRedirect: "https://rp.example.com/callback", + useFunnel: true, + expectError: true, + expectCode: http.StatusUnauthorized, + }, + + // Strict mode parameter validation tests (non-funnel) + { + name: "strict mode - missing client_id", + strictMode: true, + clientID: "", + redirectURI: "https://rp.example.com/callback", + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - missing redirect_uri", + strictMode: true, + clientID: "test-client", + redirectURI: "", + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + + // Strict mode client validation tests (non-funnel) + { + name: "strict mode - invalid client_id", + strictMode: true, + clientID: "invalid-client", + redirectURI: "https://rp.example.com/callback", + setupClient: false, + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - redirect_uri mismatch", + strictMode: true, + clientID: "test-client", + redirectURI: "https://wrong.example.com/callback", + setupClient: true, + clientRedirect: "https://rp.example.com/callback", + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := setupTestServer(t, tt.strictMode) + + // For non-funnel tests, we'll test the parameter validation logic + // without needing to mock WhoIs, since the validation happens before WhoIs calls + + // Setup client if needed + if tt.setupClient { + srv.funnelClients["test-client"] = &funnelClient{ + ID: "test-client", + Secret: "test-secret", + Name: "Test Client", + RedirectURI: tt.clientRedirect, + } + } else if !tt.strictMode { + // For non-strict mode tests that don't need a specific client setup + // but might reference one, clear the default client + delete(srv.funnelClients, "test-client") + } + + // Create request + reqURL := "/authorize" + if !tt.strictMode { + // In non-strict mode, use the node-specific endpoint + reqURL = "/authorize/123" + } + + query := url.Values{} + if tt.clientID != "" { + query.Set("client_id", tt.clientID) + } + if tt.redirectURI != "" { + query.Set("redirect_uri", tt.redirectURI) + } + if tt.state != "" { + query.Set("state", tt.state) + } + if tt.nonce != "" { + query.Set("nonce", tt.nonce) + } + + reqURL += "?" + query.Encode() + req := httptest.NewRequest("GET", reqURL, nil) + req.RemoteAddr = "127.0.0.1:12345" + + // Set funnel header only when explicitly testing funnel behavior + if tt.useFunnel { + req.Header.Set("Tailscale-Funnel-Request", "true") + } + + rr := httptest.NewRecorder() + srv.authorize(rr, req) + + if tt.expectError { + if rr.Code != tt.expectCode { + t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + } + } else if tt.expectRedirect { + if rr.Code != http.StatusFound { + t.Errorf("expected redirect (302), got %d: %s", rr.Code, rr.Body.String()) + } + + location := rr.Header().Get("Location") + if location == "" { + t.Error("expected Location header in redirect response") + } else { + // Parse the redirect URL to verify it contains a code + redirectURL, err := url.Parse(location) + if err != nil { + t.Errorf("failed to parse redirect URL: %v", err) + } else { + code := redirectURL.Query().Get("code") + if code == "" { + t.Error("expected 'code' parameter in redirect URL") + } + + // Verify state is preserved if provided + if tt.state != "" { + returnedState := redirectURL.Query().Get("state") + if returnedState != tt.state { + t.Errorf("expected state '%s', got '%s'", tt.state, returnedState) + } + } + + // Verify the auth request was stored + srv.mu.Lock() + ar, ok := srv.code[code] + srv.mu.Unlock() + + if !ok { + t.Error("expected authorization request to be stored") + } else { + if ar.clientID != tt.clientID { + t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.clientID) + } + if ar.redirectURI != tt.redirectURI { + t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.redirectURI) + } + if ar.nonce != tt.nonce { + t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.nonce) + } + } + } + } + } else { + t.Errorf("unexpected test case: not expecting error or redirect") + } + }) + } +} + +// TestServeTokenWithClientValidation verifies OAuth token endpoint security in both strict and non-strict modes. +// In strict mode, the token endpoint must: +// - Require and validate client credentials (client_id + client_secret) +// - Only accept tokens from registered funnel clients +// - Validate that redirect_uri matches the registered client +// - Support both form-based and HTTP Basic authentication for client credentials +func TestServeTokenWithClientValidation(t *testing.T) { + tests := []struct { + name string + strictMode bool + method string + grantType string + code string + clientID string + clientSecret string + redirectURI string + useBasicAuth bool + setupAuthRequest bool + authRequestClient string + authRequestRedirect string + expectError bool + expectCode int + expectIDToken bool + }{ + { + name: "strict mode - valid token exchange with form credentials", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "test-client", + clientSecret: "test-secret", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectIDToken: true, + }, + { + name: "strict mode - valid token exchange with basic auth", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://rp.example.com/callback", + useBasicAuth: true, + clientID: "test-client", + clientSecret: "test-secret", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectIDToken: true, + }, + { + name: "strict mode - missing client credentials", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectError: true, + expectCode: http.StatusUnauthorized, + }, + { + name: "strict mode - client_id mismatch", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "wrong-client", + clientSecret: "test-secret", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - invalid client secret", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "test-client", + clientSecret: "wrong-secret", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectError: true, + expectCode: http.StatusUnauthorized, + }, + { + name: "strict mode - redirect_uri mismatch", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "test-client", + clientSecret: "test-secret", + redirectURI: "https://wrong.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "non-strict mode - no client validation required", + strictMode: false, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestRedirect: "https://rp.example.com/callback", + expectIDToken: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := setupTestServer(t, tt.strictMode) + + // Setup authorization request if needed + if tt.setupAuthRequest { + now := time.Now() + profile := &tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + Key: key.NodePublic{}, + Cap: 1, + DiscoKey: key.DiscoPublic{}, + } + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: profile, + CapMap: tailcfg.PeerCapMap{}, + } + + var funnelClientPtr *funnelClient + if tt.strictMode && tt.authRequestClient != "" { + funnelClientPtr = &funnelClient{ + ID: tt.authRequestClient, + Secret: "test-secret", + Name: "Test Client", + RedirectURI: tt.authRequestRedirect, + } + srv.funnelClients[tt.authRequestClient] = funnelClientPtr + } + + srv.code["valid-code"] = &authRequest{ + clientID: tt.authRequestClient, + nonce: "nonce123", + redirectURI: tt.authRequestRedirect, + validTill: now.Add(5 * time.Minute), + remoteUser: remoteUser, + localRP: !tt.strictMode, + funnelRP: funnelClientPtr, + } + } + + // Create form data + form := url.Values{} + form.Set("grant_type", tt.grantType) + form.Set("code", tt.code) + form.Set("redirect_uri", tt.redirectURI) + + if !tt.useBasicAuth { + if tt.clientID != "" { + form.Set("client_id", tt.clientID) + } + if tt.clientSecret != "" { + form.Set("client_secret", tt.clientSecret) + } + } + + req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.RemoteAddr = "127.0.0.1:12345" + + if tt.useBasicAuth && tt.clientID != "" && tt.clientSecret != "" { + req.SetBasicAuth(tt.clientID, tt.clientSecret) + } + + rr := httptest.NewRecorder() + srv.serveToken(rr, req) + + if tt.expectError { + if rr.Code != tt.expectCode { + t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + } + } else if tt.expectIDToken { + if rr.Code != http.StatusOK { + t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp.IDToken == "" { + t.Error("expected id_token in response") + } + if resp.AccessToken == "" { + t.Error("expected access_token in response") + } + if resp.TokenType != "Bearer" { + t.Errorf("expected token_type 'Bearer', got '%s'", resp.TokenType) + } + if resp.ExpiresIn != 300 { + t.Errorf("expected expires_in 300, got %d", resp.ExpiresIn) + } + + // Verify access token was stored + srv.mu.Lock() + _, ok := srv.accessToken[resp.AccessToken] + srv.mu.Unlock() + + if !ok { + t.Error("expected access token to be stored") + } + + // Verify authorization code was consumed + srv.mu.Lock() + _, ok = srv.code[tt.code] + srv.mu.Unlock() + + if ok { + t.Error("expected authorization code to be consumed") + } + } + }) + } +} + +// TestServeUserInfoWithClientValidation verifies UserInfo endpoint security in both strict and non-strict modes. +// In strict mode, the UserInfo endpoint must: +// - Validate that access tokens are associated with registered clients +// - Reject tokens for clients that have been deleted/unregistered +// - Enforce token expiration properly +// - Return appropriate user claims based on client capabilities +func TestServeUserInfoWithClientValidation(t *testing.T) { + tests := []struct { + name string + strictMode bool + setupToken bool + setupClient bool + clientID string + token string + tokenValidTill time.Time + expectError bool + expectCode int + expectUserInfo bool + }{ + { + name: "strict mode - valid token with existing client", + strictMode: true, + setupToken: true, + setupClient: true, + clientID: "test-client", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectUserInfo: true, + }, + { + name: "strict mode - valid token but client no longer exists", + strictMode: true, + setupToken: true, + setupClient: false, + clientID: "deleted-client", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectError: true, + expectCode: http.StatusUnauthorized, + }, + { + name: "strict mode - expired token", + strictMode: true, + setupToken: true, + setupClient: true, + clientID: "test-client", + token: "expired-token", + tokenValidTill: time.Now().Add(-5 * time.Minute), + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - invalid token", + strictMode: true, + setupToken: false, + token: "invalid-token", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - token without client association", + strictMode: true, + setupToken: true, + setupClient: false, + clientID: "", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "non-strict mode - no client validation required", + strictMode: false, + setupToken: true, + setupClient: false, + clientID: "", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectUserInfo: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := setupTestServer(t, tt.strictMode) + + // Setup client if needed + if tt.setupClient { + srv.funnelClients[tt.clientID] = &funnelClient{ + ID: tt.clientID, + Secret: "test-secret", + Name: "Test Client", + RedirectURI: "https://rp.example.com/callback", + } + } + + // Setup token if needed + if tt.setupToken { + profile := &tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + Key: key.NodePublic{}, + Cap: 1, + DiscoKey: key.DiscoPublic{}, + } + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: profile, + CapMap: tailcfg.PeerCapMap{}, + } + + srv.accessToken[tt.token] = &authRequest{ + clientID: tt.clientID, + validTill: tt.tokenValidTill, + remoteUser: remoteUser, + } + } + + // Create request + req := httptest.NewRequest("GET", "/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+tt.token) + req.RemoteAddr = "127.0.0.1:12345" + + rr := httptest.NewRecorder() + srv.serveUserInfo(rr, req) + + if tt.expectError { + if rr.Code != tt.expectCode { + t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + } + } else if tt.expectUserInfo { + if rr.Code != http.StatusOK { + t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse JSON response: %v", err) + } + + // Check required fields + expectedFields := []string{"sub", "name", "email", "picture", "username"} + for _, field := range expectedFields { + if _, ok := resp[field]; !ok { + t.Errorf("expected field '%s' in user info response", field) + } + } + + // Verify specific values + if resp["name"] != "Alice Example" { + t.Errorf("expected name 'Alice Example', got '%v'", resp["name"]) + } + if resp["email"] != "alice@example.com" { + t.Errorf("expected email 'alice@example.com', got '%v'", resp["email"]) + } + if resp["username"] != "alice" { + t.Errorf("expected username 'alice', got '%v'", resp["username"]) + } + } + }) + } }