diff --git a/feature/identityfederation/identityfederation.go b/feature/identityfederation/identityfederation.go index a4470fc27..ab1b65f12 100644 --- a/feature/identityfederation/identityfederation.go +++ b/feature/identityfederation/identityfederation.go @@ -42,12 +42,12 @@ func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags baseURL = ipn.DefaultControlURL } - ephemeral, preauth, err := parseOptionalAttributes(clientID) + strippedID, ephemeral, preauth, err := parseOptionalAttributes(clientID) if err != nil { return "", fmt.Errorf("failed to parse optional config attributes: %w", err) } - accessToken, err := exchangeJWTForToken(ctx, baseURL, clientID, idToken) + accessToken, err := exchangeJWTForToken(ctx, baseURL, strippedID, idToken) if err != nil { return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) } @@ -79,15 +79,15 @@ func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags return authkey, nil } -func parseOptionalAttributes(clientID string) (ephemeral bool, preauthorized bool, err error) { - _, attrs, found := strings.Cut(clientID, "?") +func parseOptionalAttributes(clientID string) (strippedID string, ephemeral bool, preauthorized bool, err error) { + strippedID, attrs, found := strings.Cut(clientID, "?") if !found { - return true, false, nil + return clientID, true, false, nil } parsed, err := url.ParseQuery(attrs) if err != nil { - return false, false, fmt.Errorf("failed to parse optional config attributes: %w", err) + return "", false, false, fmt.Errorf("failed to parse optional config attributes: %w", err) } for k := range parsed { @@ -97,11 +97,14 @@ func parseOptionalAttributes(clientID string) (ephemeral bool, preauthorized boo case "preauthorized": preauthorized, err = strconv.ParseBool(parsed.Get(k)) default: - return false, false, fmt.Errorf("unknown optional config attribute %q", k) + return "", false, false, fmt.Errorf("unknown optional config attribute %q", k) } } + if err != nil { + return "", false, false, err + } - return ephemeral, preauthorized, err + return strippedID, ephemeral, preauthorized, nil } // exchangeJWTForToken exchanges a JWT for a Tailscale access token. diff --git a/feature/identityfederation/identityfederation_test.go b/feature/identityfederation/identityfederation_test.go index 7b75852a8..a673a4298 100644 --- a/feature/identityfederation/identityfederation_test.go +++ b/feature/identityfederation/identityfederation_test.go @@ -87,6 +87,7 @@ func TestParseOptionalAttributes(t *testing.T) { tests := []struct { name string clientID string + wantClientID string wantEphemeral bool wantPreauth bool wantErr string @@ -94,6 +95,7 @@ func TestParseOptionalAttributes(t *testing.T) { { name: "default values", clientID: "client-123", + wantClientID: "client-123", wantEphemeral: true, wantPreauth: false, wantErr: "", @@ -101,6 +103,7 @@ func TestParseOptionalAttributes(t *testing.T) { { name: "custom values", clientID: "client-123?ephemeral=false&preauthorized=true", + wantClientID: "client-123", wantEphemeral: false, wantPreauth: true, wantErr: "", @@ -108,6 +111,7 @@ func TestParseOptionalAttributes(t *testing.T) { { name: "unknown attribute", clientID: "client-123?unknown=value", + wantClientID: "", wantEphemeral: false, wantPreauth: false, wantErr: `unknown optional config attribute "unknown"`, @@ -115,6 +119,7 @@ func TestParseOptionalAttributes(t *testing.T) { { name: "invalid value", clientID: "client-123?ephemeral=invalid", + wantClientID: "", wantEphemeral: false, wantPreauth: false, wantErr: `strconv.ParseBool: parsing "invalid": invalid syntax`, @@ -123,7 +128,7 @@ func TestParseOptionalAttributes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ephemeral, preauth, err := parseOptionalAttributes(tt.clientID) + strippedID, ephemeral, preauth, err := parseOptionalAttributes(tt.clientID) if tt.wantErr != "" { if err == nil { t.Errorf("parseOptionalAttributes() error = nil, want %q", tt.wantErr) @@ -138,6 +143,9 @@ func TestParseOptionalAttributes(t *testing.T) { return } } + if strippedID != tt.wantClientID { + t.Errorf("parseOptionalAttributes() strippedID = %v, want %v", strippedID, tt.wantClientID) + } if ephemeral != tt.wantEphemeral { t.Errorf("parseOptionalAttributes() ephemeral = %v, want %v", ephemeral, tt.wantEphemeral) }