feature/identityfederation: strip query params on clientID (#17666)

Updates #9192

Signed-off-by: mcoulombe <max@tailscale.com>
pull/17672/head
Max Coulombe 1 month ago committed by GitHub
parent a760cbe33f
commit 34e992f59d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -42,12 +42,12 @@ func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags
baseURL = ipn.DefaultControlURL baseURL = ipn.DefaultControlURL
} }
ephemeral, preauth, err := parseOptionalAttributes(clientID) strippedID, ephemeral, preauth, err := parseOptionalAttributes(clientID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse optional config attributes: %w", err) return "", fmt.Errorf("failed to parse optional config attributes: %w", err)
} }
accessToken, err := exchangeJWTForToken(ctx, baseURL, clientID, idToken) accessToken, err := exchangeJWTForToken(ctx, baseURL, strippedID, idToken)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
} }
@ -79,15 +79,15 @@ func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags
return authkey, nil return authkey, nil
} }
func parseOptionalAttributes(clientID string) (ephemeral bool, preauthorized bool, err error) { func parseOptionalAttributes(clientID string) (strippedID string, ephemeral bool, preauthorized bool, err error) {
_, attrs, found := strings.Cut(clientID, "?") strippedID, attrs, found := strings.Cut(clientID, "?")
if !found { if !found {
return true, false, nil return clientID, true, false, nil
} }
parsed, err := url.ParseQuery(attrs) parsed, err := url.ParseQuery(attrs)
if err != nil { if err != nil {
return false, false, fmt.Errorf("failed to parse optional config attributes: %w", err) return "", false, false, fmt.Errorf("failed to parse optional config attributes: %w", err)
} }
for k := range parsed { for k := range parsed {
@ -97,11 +97,14 @@ func parseOptionalAttributes(clientID string) (ephemeral bool, preauthorized boo
case "preauthorized": case "preauthorized":
preauthorized, err = strconv.ParseBool(parsed.Get(k)) preauthorized, err = strconv.ParseBool(parsed.Get(k))
default: default:
return false, false, fmt.Errorf("unknown optional config attribute %q", k) return "", false, false, fmt.Errorf("unknown optional config attribute %q", k)
} }
} }
if err != nil {
return "", false, false, err
}
return ephemeral, preauthorized, err return strippedID, ephemeral, preauthorized, nil
} }
// exchangeJWTForToken exchanges a JWT for a Tailscale access token. // exchangeJWTForToken exchanges a JWT for a Tailscale access token.

@ -87,6 +87,7 @@ func TestParseOptionalAttributes(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
clientID string clientID string
wantClientID string
wantEphemeral bool wantEphemeral bool
wantPreauth bool wantPreauth bool
wantErr string wantErr string
@ -94,6 +95,7 @@ func TestParseOptionalAttributes(t *testing.T) {
{ {
name: "default values", name: "default values",
clientID: "client-123", clientID: "client-123",
wantClientID: "client-123",
wantEphemeral: true, wantEphemeral: true,
wantPreauth: false, wantPreauth: false,
wantErr: "", wantErr: "",
@ -101,6 +103,7 @@ func TestParseOptionalAttributes(t *testing.T) {
{ {
name: "custom values", name: "custom values",
clientID: "client-123?ephemeral=false&preauthorized=true", clientID: "client-123?ephemeral=false&preauthorized=true",
wantClientID: "client-123",
wantEphemeral: false, wantEphemeral: false,
wantPreauth: true, wantPreauth: true,
wantErr: "", wantErr: "",
@ -108,6 +111,7 @@ func TestParseOptionalAttributes(t *testing.T) {
{ {
name: "unknown attribute", name: "unknown attribute",
clientID: "client-123?unknown=value", clientID: "client-123?unknown=value",
wantClientID: "",
wantEphemeral: false, wantEphemeral: false,
wantPreauth: false, wantPreauth: false,
wantErr: `unknown optional config attribute "unknown"`, wantErr: `unknown optional config attribute "unknown"`,
@ -115,6 +119,7 @@ func TestParseOptionalAttributes(t *testing.T) {
{ {
name: "invalid value", name: "invalid value",
clientID: "client-123?ephemeral=invalid", clientID: "client-123?ephemeral=invalid",
wantClientID: "",
wantEphemeral: false, wantEphemeral: false,
wantPreauth: false, wantPreauth: false,
wantErr: `strconv.ParseBool: parsing "invalid": invalid syntax`, wantErr: `strconv.ParseBool: parsing "invalid": invalid syntax`,
@ -123,7 +128,7 @@ func TestParseOptionalAttributes(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ephemeral, preauth, err := parseOptionalAttributes(tt.clientID) strippedID, ephemeral, preauth, err := parseOptionalAttributes(tt.clientID)
if tt.wantErr != "" { if tt.wantErr != "" {
if err == nil { if err == nil {
t.Errorf("parseOptionalAttributes() error = nil, want %q", tt.wantErr) t.Errorf("parseOptionalAttributes() error = nil, want %q", tt.wantErr)
@ -138,6 +143,9 @@ func TestParseOptionalAttributes(t *testing.T) {
return return
} }
} }
if strippedID != tt.wantClientID {
t.Errorf("parseOptionalAttributes() strippedID = %v, want %v", strippedID, tt.wantClientID)
}
if ephemeral != tt.wantEphemeral { if ephemeral != tt.wantEphemeral {
t.Errorf("parseOptionalAttributes() ephemeral = %v, want %v", ephemeral, tt.wantEphemeral) t.Errorf("parseOptionalAttributes() ephemeral = %v, want %v", ephemeral, tt.wantEphemeral)
} }

Loading…
Cancel
Save