mirror of https://github.com/tailscale/tailscale/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
2064 lines
57 KiB
Go
2064 lines
57 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
// Package main tests for tsidp focus on OAuth security boundaries and
|
|
// correct implementation of the OpenID Connect identity provider.
|
|
//
|
|
// Test Strategy:
|
|
// - Tests are intentionally granular to provide clear failure signals when
|
|
// security-critical logic breaks
|
|
// - OAuth flow tests cover both strict mode (registered clients only) and
|
|
// legacy mode (local funnel clients) to ensure proper access controls
|
|
// - Helper functions like normalizeMap ensure deterministic comparisons
|
|
// despite JSON marshaling order variations
|
|
// - The privateKey global is reused across tests for performance (RSA key
|
|
// generation is expensive)
|
|
|
|
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/netip"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"gopkg.in/square/go-jose.v2"
|
|
"gopkg.in/square/go-jose.v2/jwt"
|
|
"tailscale.com/client/local"
|
|
"tailscale.com/client/tailscale/apitype"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/key"
|
|
"tailscale.com/types/opt"
|
|
"tailscale.com/types/views"
|
|
)
|
|
|
|
// normalizeMap returns a new map with the same keys as m in which every
// top-level []any value is replaced by a sorted []any holding the "%v" string
// form of each element. All other values (scalars, nil, nested maps, typed
// slices such as []string) are copied through unchanged, and m itself is not
// modified.
//
// Tests use it on both the "got" and "want" side so that reflect.DeepEqual
// does not depend on the order in which merged claim values were collected.
func normalizeMap(t *testing.T, m map[string]any) map[string]any {
	t.Helper()

	out := make(map[string]any, len(m))
	for key, value := range m {
		list, isList := value.([]any)
		if !isList {
			out[key] = value
			continue
		}

		// Stringify first so that mixed element types sort consistently.
		strs := make([]string, 0, len(list))
		for _, elem := range list {
			strs = append(strs, fmt.Sprintf("%v", elem))
		}
		sort.Strings(strs)

		// Store as []any again so the result compares equal to decoded JSON.
		canon := make([]any, 0, len(strs))
		for _, s := range strs {
			canon = append(canon, s)
		}
		out[key] = canon
	}
	return out
}
|
|
|
|
func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage {
|
|
t.Helper()
|
|
b, err := json.Marshal(v)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return tailcfg.RawMessage(b)
|
|
}
|
|
|
|
// privateKeyMu guards privateKey.
var privateKeyMu sync.Mutex

// privateKey caches the RSA private key shared by the tests in this file so
// the expensive key generation happens at most once. It starts out nil and is
// filled in by mustGeneratePrivateKey on first use.
var privateKey *rsa.PrivateKey
|
|
|
|
func oidcTestingSigner(t *testing.T) jose.Signer {
|
|
t.Helper()
|
|
privKey := mustGeneratePrivateKey(t)
|
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privKey}, nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create signer: %v", err)
|
|
}
|
|
return sig
|
|
}
|
|
|
|
func oidcTestingPublicKey(t *testing.T) *rsa.PublicKey {
|
|
t.Helper()
|
|
privKey := mustGeneratePrivateKey(t)
|
|
return &privKey.PublicKey
|
|
}
|
|
|
|
func mustGeneratePrivateKey(t *testing.T) *rsa.PrivateKey {
|
|
t.Helper()
|
|
privateKeyMu.Lock()
|
|
defer privateKeyMu.Unlock()
|
|
|
|
if privateKey != nil {
|
|
return privateKey
|
|
}
|
|
|
|
var err error
|
|
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("failed to generate key: %v", err)
|
|
}
|
|
|
|
return privateKey
|
|
}
|
|
|
|
func TestFlattenExtraClaims(t *testing.T) {
|
|
log.SetOutput(io.Discard) // suppress log output during tests
|
|
|
|
tests := []struct {
|
|
name string
|
|
input []capRule
|
|
expected map[string]any
|
|
}{
|
|
{
|
|
name: "empty extra claims",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{}},
|
|
},
|
|
expected: map[string]any{},
|
|
},
|
|
{
|
|
name: "string and number values",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"featureA": "read",
|
|
"featureB": 42,
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"featureA": "read",
|
|
"featureB": "42",
|
|
},
|
|
},
|
|
{
|
|
name: "slice of strings and ints",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"roles": []any{"admin", "user", 1},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"roles": []any{"admin", "user", "1"},
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate values deduplicated (slice input)",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar", "baz"},
|
|
},
|
|
},
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []any{"bar", "qux"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar", "baz", "qux"},
|
|
},
|
|
},
|
|
{
|
|
name: "ignore unsupported map type, keep valid scalar",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"invalid": map[string]any{"bad": "yes"},
|
|
"valid": "ok",
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"valid": "ok",
|
|
},
|
|
},
|
|
{
|
|
name: "scalar first, slice second",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"foo": "bar"}},
|
|
{ExtraClaims: map[string]any{"foo": []any{"baz"}}},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar", "baz"}, // converts to slice when any rule provides a slice
|
|
},
|
|
},
|
|
{
|
|
name: "conflicting scalar and unsupported map",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"foo": "bar"}},
|
|
{ExtraClaims: map[string]any{"foo": map[string]any{"bad": "entry"}}},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": "bar", // map should be ignored
|
|
},
|
|
},
|
|
{
|
|
name: "multiple slices with overlap",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"roles": []any{"admin", "user"}}},
|
|
{ExtraClaims: map[string]any{"roles": []any{"admin", "guest"}}},
|
|
},
|
|
expected: map[string]any{
|
|
"roles": []any{"admin", "user", "guest"},
|
|
},
|
|
},
|
|
{
|
|
name: "slice with unsupported values",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{
|
|
"mixed": []any{"ok", 42, map[string]string{"oops": "fail"}},
|
|
}},
|
|
},
|
|
expected: map[string]any{
|
|
"mixed": []any{"ok", "42"}, // map is ignored
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate scalar value",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"env": "prod"}},
|
|
{ExtraClaims: map[string]any{"env": "prod"}},
|
|
},
|
|
expected: map[string]any{
|
|
"env": "prod", // not converted to slice
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := flattenExtraClaims(tt.input)
|
|
|
|
gotNormalized := normalizeMap(t, got)
|
|
expectedNormalized := normalizeMap(t, tt.expected)
|
|
|
|
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
|
t.Errorf("mismatch\nGot:\n%s\nWant:\n%s", gotNormalized, expectedNormalized)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtraClaims(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
claim tailscaleClaims
|
|
extraClaims []capRule
|
|
expected map[string]any
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "extra claim",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
"foo": []any{"bar"},
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate claim distinct values",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
},
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"foobar"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
"foo": []any{"foobar", "bar"},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple extra claims",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
},
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"bar": []string{"foo"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
"foo": []any{"bar"},
|
|
"bar": []any{"foo"},
|
|
},
|
|
},
|
|
{
|
|
name: "overwrite claim",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"username": "foobar",
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "foobar",
|
|
},
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "empty extra claims",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{{ExtraClaims: map[string]any{}}},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
claims, err := withExtraClaims(tt.claim, tt.extraClaims)
|
|
if err != nil && !tt.expectError {
|
|
t.Fatalf("claim.withExtraClaims() unexpected error = %v", err)
|
|
} else if err == nil && tt.expectError {
|
|
t.Fatalf("expected error, got nil")
|
|
} else if err != nil && tt.expectError {
|
|
return // just as expected
|
|
}
|
|
|
|
// Marshal to JSON then unmarshal back to map[string]any
|
|
gotClaims, err := json.Marshal(claims)
|
|
if err != nil {
|
|
t.Errorf("json.Marshal(claims) error = %v", err)
|
|
}
|
|
|
|
var gotClaimsMap map[string]any
|
|
if err := json.Unmarshal(gotClaims, &gotClaimsMap); err != nil {
|
|
t.Fatalf("json.Unmarshal(gotClaims) error = %v", err)
|
|
}
|
|
|
|
gotNormalized := normalizeMap(t, gotClaimsMap)
|
|
expectedNormalized := normalizeMap(t, tt.expected)
|
|
|
|
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
|
t.Errorf("claims mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestServeToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
caps tailcfg.PeerCapMap
|
|
method string
|
|
grantType string
|
|
code string
|
|
omitCode bool
|
|
redirectURI string
|
|
remoteAddr string
|
|
strictMode bool
|
|
expectError bool
|
|
expected map[string]any
|
|
}{
|
|
{
|
|
name: "GET not allowed",
|
|
method: "GET",
|
|
grantType: "authorization_code",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "unsupported grant type",
|
|
method: "POST",
|
|
grantType: "pkcs",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid code",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "invalid-code",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "omit code from form",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
omitCode: true,
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid redirect uri",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://invalid.example.com/callback",
|
|
remoteAddr: "127.0.0.1:12345",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid remoteAddr",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
code: "valid-code",
|
|
remoteAddr: "192.168.0.1:12345",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "extra claim included (non-strict)",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
code: "valid-code",
|
|
remoteAddr: "127.0.0.1:12345",
|
|
strictMode: false,
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": "bar",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
{
|
|
name: "attempt to overwrite protected claim (non-strict)",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
code: "valid-code",
|
|
strictMode: false,
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"sub": "should-not-overwrite",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
now := time.Now()
|
|
|
|
// Use setupTestServer helper
|
|
s := setupTestServer(t, tt.strictMode)
|
|
|
|
// Fake user/node
|
|
profile := &tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
Key: key.NodePublic{},
|
|
Cap: 1,
|
|
DiscoKey: key.DiscoPublic{},
|
|
}
|
|
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: profile,
|
|
CapMap: tt.caps,
|
|
}
|
|
|
|
// Setup auth request with appropriate configuration for strict mode
|
|
var funnelClientPtr *funnelClient
|
|
if tt.strictMode {
|
|
funnelClientPtr = &funnelClient{
|
|
ID: "client-id",
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: "https://rp.example.com/callback",
|
|
}
|
|
s.funnelClients["client-id"] = funnelClientPtr
|
|
}
|
|
|
|
s.code["valid-code"] = &authRequest{
|
|
clientID: "client-id",
|
|
nonce: "nonce123",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
validTill: now.Add(5 * time.Minute),
|
|
remoteUser: remoteUser,
|
|
localRP: !tt.strictMode,
|
|
funnelRP: funnelClientPtr,
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("grant_type", tt.grantType)
|
|
form.Set("redirect_uri", tt.redirectURI)
|
|
if !tt.omitCode {
|
|
form.Set("code", tt.code)
|
|
}
|
|
// Add client credentials for strict mode
|
|
if tt.strictMode {
|
|
form.Set("client_id", "client-id")
|
|
form.Set("client_secret", "test-secret")
|
|
}
|
|
|
|
req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode()))
|
|
req.RemoteAddr = tt.remoteAddr
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
rr := httptest.NewRecorder()
|
|
|
|
s.serveToken(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code == http.StatusOK {
|
|
t.Fatalf("expected error, got 200 OK: %s", rr.Body.String())
|
|
}
|
|
return
|
|
}
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp struct {
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
tok, err := jwt.ParseSigned(resp.IDToken)
|
|
if err != nil {
|
|
t.Fatalf("failed to parse ID token: %v", err)
|
|
}
|
|
|
|
out := make(map[string]any)
|
|
if err := tok.Claims(oidcTestingPublicKey(t), &out); err != nil {
|
|
t.Fatalf("failed to extract claims: %v", err)
|
|
}
|
|
|
|
for k, want := range tt.expected {
|
|
got, ok := out[k]
|
|
if !ok {
|
|
t.Errorf("missing expected claim %q", k)
|
|
continue
|
|
}
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Errorf("claim %q: got %v, want %v", k, got, want)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtraUserInfo(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
caps tailcfg.PeerCapMap
|
|
tokenValidTill time.Time
|
|
expected map[string]any
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "extra claim",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar"},
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate claim distinct values",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar", "foobar"},
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar", "foobar"},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple extra claims",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": "bar",
|
|
"bar": "foo",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": "bar",
|
|
"bar": "foo",
|
|
},
|
|
},
|
|
{
|
|
name: "empty extra claims",
|
|
caps: tailcfg.PeerCapMap{},
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
expected: map[string]any{},
|
|
},
|
|
{
|
|
name: "attempt to overwrite protected claim",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"sub": "should-not-overwrite",
|
|
"foo": "ok",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "extra claim omitted",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: false,
|
|
ExtraClaims: map[string]any{
|
|
"foo": "ok",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{},
|
|
},
|
|
{
|
|
name: "expired token",
|
|
caps: tailcfg.PeerCapMap{},
|
|
tokenValidTill: time.Now().Add(-1 * time.Minute),
|
|
expected: map[string]any{},
|
|
expectError: true,
|
|
},
|
|
}
|
|
token := "valid-token"
|
|
|
|
// Create a fake tailscale Node
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
// Construct the remote user
|
|
profile := tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: &profile,
|
|
CapMap: tt.caps,
|
|
}
|
|
|
|
// Insert a valid token into the idpServer
|
|
s := &idpServer{
|
|
allowInsecureRegistration: true, // Default to allowing insecure registration for backward compatibility
|
|
accessToken: map[string]*authRequest{
|
|
token: {
|
|
validTill: tt.tokenValidTill,
|
|
remoteUser: remoteUser,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Construct request
|
|
req := httptest.NewRequest("GET", "/userinfo", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Call the method under test
|
|
s.serveUserInfo(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code == http.StatusOK {
|
|
t.Fatalf("expected error, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
return
|
|
}
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp map[string]any
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to parse JSON response: %v", err)
|
|
}
|
|
|
|
// Construct expected
|
|
tt.expected["sub"] = remoteUser.Node.User.String()
|
|
tt.expected["name"] = profile.DisplayName
|
|
tt.expected["email"] = profile.LoginName
|
|
tt.expected["picture"] = profile.ProfilePicURL
|
|
tt.expected["username"], _, _ = strings.Cut(profile.LoginName, "@")
|
|
|
|
gotNormalized := normalizeMap(t, resp)
|
|
expectedNormalized := normalizeMap(t, tt.expected)
|
|
|
|
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
|
t.Errorf("UserInfo mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFunnelClientsPersistence(t *testing.T) {
|
|
testClients := map[string]*funnelClient{
|
|
"test-client-1": {
|
|
ID: "test-client-1",
|
|
Secret: "test-secret-1",
|
|
Name: "Test Client 1",
|
|
RedirectURI: "https://example.com/callback",
|
|
},
|
|
"test-client-2": {
|
|
ID: "test-client-2",
|
|
Secret: "test-secret-2",
|
|
Name: "Test Client 2",
|
|
RedirectURI: "https://example2.com/callback",
|
|
},
|
|
}
|
|
|
|
testData, err := json.Marshal(testClients)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal test data: %v", err)
|
|
}
|
|
|
|
tmpFile := t.TempDir() + "/oidc-funnel-clients.json"
|
|
if err := os.WriteFile(tmpFile, testData, 0600); err != nil {
|
|
t.Fatalf("failed to write test file: %v", err)
|
|
}
|
|
|
|
t.Run("load_from_existing_file", func(t *testing.T) {
|
|
srv := &idpServer{}
|
|
|
|
// Simulate the funnel clients loading logic from main()
|
|
srv.funnelClients = make(map[string]*funnelClient)
|
|
f, err := os.Open(tmpFile)
|
|
if err == nil {
|
|
if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil {
|
|
t.Fatalf("could not parse %s: %v", tmpFile, err)
|
|
}
|
|
f.Close()
|
|
} else if !errors.Is(err, os.ErrNotExist) {
|
|
t.Fatalf("could not open %s: %v", tmpFile, err)
|
|
}
|
|
|
|
// Verify clients were loaded correctly
|
|
if len(srv.funnelClients) != 2 {
|
|
t.Errorf("expected 2 clients, got %d", len(srv.funnelClients))
|
|
}
|
|
|
|
client1, ok := srv.funnelClients["test-client-1"]
|
|
if !ok {
|
|
t.Error("expected test-client-1 to be loaded")
|
|
} else {
|
|
if client1.Name != "Test Client 1" {
|
|
t.Errorf("expected client name 'Test Client 1', got '%s'", client1.Name)
|
|
}
|
|
if client1.Secret != "test-secret-1" {
|
|
t.Errorf("expected client secret 'test-secret-1', got '%s'", client1.Secret)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("initialize_empty_when_no_file", func(t *testing.T) {
|
|
nonExistentFile := t.TempDir() + "/non-existent.json"
|
|
|
|
srv := &idpServer{}
|
|
|
|
// Simulate the funnel clients loading logic from main()
|
|
srv.funnelClients = make(map[string]*funnelClient)
|
|
f, err := os.Open(nonExistentFile)
|
|
if err == nil {
|
|
if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil {
|
|
t.Fatalf("could not parse %s: %v", nonExistentFile, err)
|
|
}
|
|
f.Close()
|
|
} else if !errors.Is(err, os.ErrNotExist) {
|
|
t.Fatalf("could not open %s: %v", nonExistentFile, err)
|
|
}
|
|
|
|
// Verify map is initialized but empty
|
|
if srv.funnelClients == nil {
|
|
t.Error("expected funnelClients map to be initialized")
|
|
}
|
|
if len(srv.funnelClients) != 0 {
|
|
t.Errorf("expected empty map, got %d clients", len(srv.funnelClients))
|
|
}
|
|
})
|
|
|
|
t.Run("persist_and_reload_clients", func(t *testing.T) {
|
|
tmpFile2 := t.TempDir() + "/test-persistence.json"
|
|
|
|
// Create initial server with one client
|
|
srv1 := &idpServer{
|
|
funnelClients: make(map[string]*funnelClient),
|
|
}
|
|
srv1.funnelClients["new-client"] = &funnelClient{
|
|
ID: "new-client",
|
|
Secret: "new-secret",
|
|
Name: "New Client",
|
|
RedirectURI: "https://new.example.com/callback",
|
|
}
|
|
|
|
// Save clients to file (simulating saveFunnelClients)
|
|
data, err := json.Marshal(srv1.funnelClients)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal clients: %v", err)
|
|
}
|
|
if err := os.WriteFile(tmpFile2, data, 0600); err != nil {
|
|
t.Fatalf("failed to write clients file: %v", err)
|
|
}
|
|
|
|
// Create new server instance and load clients
|
|
srv2 := &idpServer{}
|
|
srv2.funnelClients = make(map[string]*funnelClient)
|
|
f, err := os.Open(tmpFile2)
|
|
if err == nil {
|
|
if err := json.NewDecoder(f).Decode(&srv2.funnelClients); err != nil {
|
|
t.Fatalf("could not parse %s: %v", tmpFile2, err)
|
|
}
|
|
f.Close()
|
|
} else if !errors.Is(err, os.ErrNotExist) {
|
|
t.Fatalf("could not open %s: %v", tmpFile2, err)
|
|
}
|
|
|
|
// Verify the client was persisted correctly
|
|
loadedClient, ok := srv2.funnelClients["new-client"]
|
|
if !ok {
|
|
t.Error("expected new-client to be loaded after persistence")
|
|
} else {
|
|
if loadedClient.Name != "New Client" {
|
|
t.Errorf("expected client name 'New Client', got '%s'", loadedClient.Name)
|
|
}
|
|
if loadedClient.Secret != "new-secret" {
|
|
t.Errorf("expected client secret 'new-secret', got '%s'", loadedClient.Secret)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("strict_mode_file_handling", func(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
// Test strict mode uses oauth-clients.json
|
|
srv1 := setupTestServer(t, true)
|
|
srv1.rootPath = tmpDir
|
|
srv1.funnelClients["oauth-client"] = &funnelClient{
|
|
ID: "oauth-client",
|
|
Secret: "oauth-secret",
|
|
Name: "OAuth Client",
|
|
RedirectURI: "https://oauth.example.com/callback",
|
|
}
|
|
|
|
// Test storeFunnelClientsLocked in strict mode
|
|
srv1.mu.Lock()
|
|
err := srv1.storeFunnelClientsLocked()
|
|
srv1.mu.Unlock()
|
|
|
|
if err != nil {
|
|
t.Fatalf("failed to store clients in strict mode: %v", err)
|
|
}
|
|
|
|
// Verify oauth-clients.json was created
|
|
oauthPath := tmpDir + "/" + oauthClientsFile
|
|
if _, err := os.Stat(oauthPath); err != nil {
|
|
t.Errorf("expected oauth-clients.json to be created: %v", err)
|
|
}
|
|
|
|
// Verify oidc-funnel-clients.json was NOT created
|
|
funnelPath := tmpDir + "/" + funnelClientsFile
|
|
if _, err := os.Stat(funnelPath); !os.IsNotExist(err) {
|
|
t.Error("expected oidc-funnel-clients.json NOT to be created in strict mode")
|
|
}
|
|
})
|
|
|
|
t.Run("non_strict_mode_file_handling", func(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
// Test non-strict mode uses oidc-funnel-clients.json
|
|
srv1 := setupTestServer(t, false)
|
|
srv1.rootPath = tmpDir
|
|
srv1.funnelClients["funnel-client"] = &funnelClient{
|
|
ID: "funnel-client",
|
|
Secret: "funnel-secret",
|
|
Name: "Funnel Client",
|
|
RedirectURI: "https://funnel.example.com/callback",
|
|
}
|
|
|
|
// Test storeFunnelClientsLocked in non-strict mode
|
|
srv1.mu.Lock()
|
|
err := srv1.storeFunnelClientsLocked()
|
|
srv1.mu.Unlock()
|
|
|
|
if err != nil {
|
|
t.Fatalf("failed to store clients in non-strict mode: %v", err)
|
|
}
|
|
|
|
// Verify oidc-funnel-clients.json was created
|
|
funnelPath := tmpDir + "/" + funnelClientsFile
|
|
if _, err := os.Stat(funnelPath); err != nil {
|
|
t.Errorf("expected oidc-funnel-clients.json to be created: %v", err)
|
|
}
|
|
|
|
// Verify oauth-clients.json was NOT created
|
|
oauthPath := tmpDir + "/" + oauthClientsFile
|
|
if _, err := os.Stat(oauthPath); !os.IsNotExist(err) {
|
|
t.Error("expected oauth-clients.json NOT to be created in non-strict mode")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test helper functions for strict OAuth mode testing
|
|
func setupTestServer(t *testing.T, strictMode bool) *idpServer {
|
|
return setupTestServerWithClient(t, strictMode, nil)
|
|
}
|
|
|
|
// setupTestServerWithClient creates a test server with an optional LocalClient.
|
|
// If lc is nil, the server will have no LocalClient (original behavior).
|
|
// If lc is provided, it will be used for WhoIs calls during testing.
|
|
func setupTestServerWithClient(t *testing.T, strictMode bool, lc *local.Client) *idpServer {
|
|
t.Helper()
|
|
|
|
srv := &idpServer{
|
|
allowInsecureRegistration: !strictMode,
|
|
code: make(map[string]*authRequest),
|
|
accessToken: make(map[string]*authRequest),
|
|
funnelClients: make(map[string]*funnelClient),
|
|
serverURL: "https://test.ts.net",
|
|
rootPath: t.TempDir(),
|
|
lc: lc,
|
|
}
|
|
|
|
// Add a test client for funnel/strict mode testing
|
|
srv.funnelClients["test-client"] = &funnelClient{
|
|
ID: "test-client",
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: "https://rp.example.com/callback",
|
|
}
|
|
|
|
// Inject a working signer for token tests
|
|
srv.lazySigner.Set(oidcTestingSigner(t))
|
|
|
|
return srv
|
|
}
|
|
|
|
func TestGetAllowInsecureRegistration(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
flagSet bool
|
|
flagValue bool
|
|
expectAllowInsecureRegistration bool
|
|
}{
|
|
{
|
|
name: "flag explicitly set to false - insecure registration disabled (strict mode)",
|
|
flagSet: true,
|
|
flagValue: false,
|
|
expectAllowInsecureRegistration: false,
|
|
},
|
|
{
|
|
name: "flag explicitly set to true - insecure registration enabled",
|
|
flagSet: true,
|
|
flagValue: true,
|
|
expectAllowInsecureRegistration: true,
|
|
},
|
|
{
|
|
name: "flag unset - insecure registration enabled (default for backward compatibility)",
|
|
flagSet: false,
|
|
flagValue: false, // not used when unset
|
|
expectAllowInsecureRegistration: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Save original state
|
|
originalFlag := flagAllowInsecureRegistration
|
|
defer func() {
|
|
flagAllowInsecureRegistration = originalFlag
|
|
}()
|
|
|
|
// Set up test state by creating a new BoolFlag and setting values
|
|
var b opt.Bool
|
|
flagAllowInsecureRegistration = opt.BoolFlag{Bool: &b}
|
|
if tt.flagSet {
|
|
flagAllowInsecureRegistration.Bool.Set(tt.flagValue)
|
|
}
|
|
// Note: when tt.flagSet is false, the Bool remains unset (which is what we want)
|
|
|
|
got := getAllowInsecureRegistration()
|
|
if got != tt.expectAllowInsecureRegistration {
|
|
t.Errorf("getAllowInsecureRegistration() = %v, want %v", got, tt.expectAllowInsecureRegistration)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestMigrateOAuthClients verifies the migration from legacy funnel clients
|
|
// to OAuth clients. This migration is necessary when transitioning from
|
|
// non-strict to strict OAuth mode. The migration logic should:
|
|
// - Copy clients from oidc-funnel-clients.json to oauth-clients.json
|
|
// - Rename the old file to mark it as deprecated
|
|
// - Handle cases where files already exist or are missing
|
|
func TestMigrateOAuthClients(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupOldFile bool
|
|
setupNewFile bool
|
|
oldFileContent map[string]*funnelClient
|
|
newFileContent map[string]*funnelClient
|
|
expectError bool
|
|
expectNewFileExists bool
|
|
expectOldRenamed bool
|
|
}{
|
|
{
|
|
name: "migrate from old file to new file",
|
|
setupOldFile: true,
|
|
oldFileContent: map[string]*funnelClient{
|
|
"old-client": {
|
|
ID: "old-client",
|
|
Secret: "old-secret",
|
|
Name: "Old Client",
|
|
RedirectURI: "https://old.example.com/callback",
|
|
},
|
|
},
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: true,
|
|
},
|
|
{
|
|
name: "new file already exists - no migration",
|
|
setupNewFile: true,
|
|
newFileContent: map[string]*funnelClient{
|
|
"existing-client": {
|
|
ID: "existing-client",
|
|
Secret: "existing-secret",
|
|
Name: "Existing Client",
|
|
RedirectURI: "https://existing.example.com/callback",
|
|
},
|
|
},
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: false,
|
|
},
|
|
{
|
|
name: "neither file exists - create empty new file",
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: false,
|
|
},
|
|
{
|
|
name: "both files exist - prefer new file",
|
|
setupOldFile: true,
|
|
setupNewFile: true,
|
|
oldFileContent: map[string]*funnelClient{
|
|
"old-client": {
|
|
ID: "old-client",
|
|
Secret: "old-secret",
|
|
Name: "Old Client",
|
|
RedirectURI: "https://old.example.com/callback",
|
|
},
|
|
},
|
|
newFileContent: map[string]*funnelClient{
|
|
"new-client": {
|
|
ID: "new-client",
|
|
Secret: "new-secret",
|
|
Name: "New Client",
|
|
RedirectURI: "https://new.example.com/callback",
|
|
},
|
|
},
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
rootPath := t.TempDir()
|
|
|
|
// Setup old file if needed
|
|
if tt.setupOldFile {
|
|
oldData, err := json.Marshal(tt.oldFileContent)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal old file content: %v", err)
|
|
}
|
|
oldPath := rootPath + "/" + funnelClientsFile
|
|
if err := os.WriteFile(oldPath, oldData, 0600); err != nil {
|
|
t.Fatalf("failed to create old file: %v", err)
|
|
}
|
|
}
|
|
|
|
// Setup new file if needed
|
|
if tt.setupNewFile {
|
|
newData, err := json.Marshal(tt.newFileContent)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal new file content: %v", err)
|
|
}
|
|
newPath := rootPath + "/" + oauthClientsFile
|
|
if err := os.WriteFile(newPath, newData, 0600); err != nil {
|
|
t.Fatalf("failed to create new file: %v", err)
|
|
}
|
|
}
|
|
|
|
// Call migrateOAuthClients
|
|
resultPath, err := migrateOAuthClients(rootPath)
|
|
|
|
if tt.expectError && err == nil {
|
|
t.Fatalf("expected error but got none")
|
|
}
|
|
if !tt.expectError && err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if tt.expectError {
|
|
return
|
|
}
|
|
|
|
// Verify result path points to oauth-clients.json
|
|
expectedPath := filepath.Join(rootPath, oauthClientsFile)
|
|
if resultPath != expectedPath {
|
|
t.Errorf("expected result path %s, got %s", expectedPath, resultPath)
|
|
}
|
|
|
|
// Verify new file exists if expected
|
|
if tt.expectNewFileExists {
|
|
if _, err := os.Stat(resultPath); err != nil {
|
|
t.Errorf("expected new file to exist at %s: %v", resultPath, err)
|
|
}
|
|
|
|
// Verify content
|
|
data, err := os.ReadFile(resultPath)
|
|
if err != nil {
|
|
t.Fatalf("failed to read new file: %v", err)
|
|
}
|
|
|
|
var clients map[string]*funnelClient
|
|
if err := json.Unmarshal(data, &clients); err != nil {
|
|
t.Fatalf("failed to unmarshal new file: %v", err)
|
|
}
|
|
|
|
// Determine expected content
|
|
var expectedContent map[string]*funnelClient
|
|
if tt.setupNewFile {
|
|
expectedContent = tt.newFileContent
|
|
} else if tt.setupOldFile {
|
|
expectedContent = tt.oldFileContent
|
|
} else {
|
|
expectedContent = make(map[string]*funnelClient)
|
|
}
|
|
|
|
if len(clients) != len(expectedContent) {
|
|
t.Errorf("expected %d clients, got %d", len(expectedContent), len(clients))
|
|
}
|
|
|
|
for id, expectedClient := range expectedContent {
|
|
actualClient, ok := clients[id]
|
|
if !ok {
|
|
t.Errorf("expected client %s not found", id)
|
|
continue
|
|
}
|
|
if actualClient.ID != expectedClient.ID ||
|
|
actualClient.Secret != expectedClient.Secret ||
|
|
actualClient.Name != expectedClient.Name ||
|
|
actualClient.RedirectURI != expectedClient.RedirectURI {
|
|
t.Errorf("client %s mismatch: got %+v, want %+v", id, actualClient, expectedClient)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Verify old file renamed if expected
|
|
if tt.expectOldRenamed {
|
|
deprecatedPath := rootPath + "/" + deprecatedFunnelClientsFile
|
|
if _, err := os.Stat(deprecatedPath); err != nil {
|
|
t.Errorf("expected old file to be renamed to %s: %v", deprecatedPath, err)
|
|
}
|
|
|
|
// Verify original old file is gone
|
|
oldPath := rootPath + "/" + funnelClientsFile
|
|
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
|
|
t.Errorf("expected old file %s to be removed", oldPath)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetConfigFilePath verifies backward compatibility for config file location.
|
|
// The function must check current directory first (legacy deployments) before
|
|
// falling back to rootPath (new installations) to prevent breaking existing
|
|
// tsidp deployments that have config files in unexpected locations.
|
|
func TestGetConfigFilePath(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
fileName string
|
|
createInCwd bool
|
|
createInRoot bool
|
|
expectInCwd bool
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "file exists in current directory - use current directory",
|
|
fileName: "test-config.json",
|
|
createInCwd: true,
|
|
expectInCwd: true,
|
|
},
|
|
{
|
|
name: "file does not exist - use root path",
|
|
fileName: "test-config.json",
|
|
createInCwd: false,
|
|
expectInCwd: false,
|
|
},
|
|
{
|
|
name: "file exists in both - prefer current directory",
|
|
fileName: "test-config.json",
|
|
createInCwd: true,
|
|
createInRoot: true,
|
|
expectInCwd: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create temporary directories
|
|
rootPath := t.TempDir()
|
|
originalWd, err := os.Getwd()
|
|
if err != nil {
|
|
t.Fatalf("failed to get working directory: %v", err)
|
|
}
|
|
|
|
// Create a temporary working directory
|
|
tmpWd := t.TempDir()
|
|
if err := os.Chdir(tmpWd); err != nil {
|
|
t.Fatalf("failed to change to temp directory: %v", err)
|
|
}
|
|
defer func() {
|
|
os.Chdir(originalWd)
|
|
}()
|
|
|
|
// Setup files as needed
|
|
if tt.createInCwd {
|
|
if err := os.WriteFile(tt.fileName, []byte("{}"), 0600); err != nil {
|
|
t.Fatalf("failed to create file in cwd: %v", err)
|
|
}
|
|
}
|
|
if tt.createInRoot {
|
|
rootFilePath := filepath.Join(rootPath, tt.fileName)
|
|
if err := os.WriteFile(rootFilePath, []byte("{}"), 0600); err != nil {
|
|
t.Fatalf("failed to create file in root: %v", err)
|
|
}
|
|
}
|
|
|
|
// Call getConfigFilePath
|
|
resultPath, err := getConfigFilePath(rootPath, tt.fileName)
|
|
|
|
if tt.expectError && err == nil {
|
|
t.Fatalf("expected error but got none")
|
|
}
|
|
if !tt.expectError && err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if tt.expectError {
|
|
return
|
|
}
|
|
|
|
// Verify result
|
|
if tt.expectInCwd {
|
|
if resultPath != tt.fileName {
|
|
t.Errorf("expected path %s, got %s", tt.fileName, resultPath)
|
|
}
|
|
} else {
|
|
expectedPath := filepath.Join(rootPath, tt.fileName)
|
|
if resultPath != expectedPath {
|
|
t.Errorf("expected path %s, got %s", expectedPath, resultPath)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAuthorizeStrictMode verifies OAuth authorization endpoint security and validation logic.
|
|
// Tests both the security boundary (funnel rejection) and the business logic (strict mode validation).
|
|
func TestAuthorizeStrictMode(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
strictMode bool
|
|
clientID string
|
|
redirectURI string
|
|
state string
|
|
nonce string
|
|
setupClient bool
|
|
clientRedirect string
|
|
useFunnel bool // whether to simulate funnel request
|
|
mockWhoIsError bool // whether to make WhoIs return an error
|
|
expectError bool
|
|
expectCode int
|
|
expectRedirect bool
|
|
}{
|
|
// Security boundary test: funnel rejection
|
|
{
|
|
name: "funnel requests are always rejected for security",
|
|
strictMode: true,
|
|
clientID: "test-client",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
state: "random-state",
|
|
nonce: "random-nonce",
|
|
setupClient: true,
|
|
clientRedirect: "https://rp.example.com/callback",
|
|
useFunnel: true,
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
|
|
// Strict mode parameter validation tests (non-funnel)
|
|
{
|
|
name: "strict mode - missing client_id",
|
|
strictMode: true,
|
|
clientID: "",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - missing redirect_uri",
|
|
strictMode: true,
|
|
clientID: "test-client",
|
|
redirectURI: "",
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
|
|
// Strict mode client validation tests (non-funnel)
|
|
{
|
|
name: "strict mode - invalid client_id",
|
|
strictMode: true,
|
|
clientID: "invalid-client",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupClient: false,
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - redirect_uri mismatch",
|
|
strictMode: true,
|
|
clientID: "test-client",
|
|
redirectURI: "https://wrong.example.com/callback",
|
|
setupClient: true,
|
|
clientRedirect: "https://rp.example.com/callback",
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv := setupTestServer(t, tt.strictMode)
|
|
|
|
// For non-funnel tests, we'll test the parameter validation logic
|
|
// without needing to mock WhoIs, since the validation happens before WhoIs calls
|
|
|
|
// Setup client if needed
|
|
if tt.setupClient {
|
|
srv.funnelClients["test-client"] = &funnelClient{
|
|
ID: "test-client",
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: tt.clientRedirect,
|
|
}
|
|
} else if !tt.strictMode {
|
|
// For non-strict mode tests that don't need a specific client setup
|
|
// but might reference one, clear the default client
|
|
delete(srv.funnelClients, "test-client")
|
|
}
|
|
|
|
// Create request
|
|
reqURL := "/authorize"
|
|
if !tt.strictMode {
|
|
// In non-strict mode, use the node-specific endpoint
|
|
reqURL = "/authorize/123"
|
|
}
|
|
|
|
query := url.Values{}
|
|
if tt.clientID != "" {
|
|
query.Set("client_id", tt.clientID)
|
|
}
|
|
if tt.redirectURI != "" {
|
|
query.Set("redirect_uri", tt.redirectURI)
|
|
}
|
|
if tt.state != "" {
|
|
query.Set("state", tt.state)
|
|
}
|
|
if tt.nonce != "" {
|
|
query.Set("nonce", tt.nonce)
|
|
}
|
|
|
|
reqURL += "?" + query.Encode()
|
|
req := httptest.NewRequest("GET", reqURL, nil)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
// Set funnel header only when explicitly testing funnel behavior
|
|
if tt.useFunnel {
|
|
req.Header.Set("Tailscale-Funnel-Request", "true")
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
srv.authorize(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code != tt.expectCode {
|
|
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
|
|
}
|
|
} else if tt.expectRedirect {
|
|
if rr.Code != http.StatusFound {
|
|
t.Errorf("expected redirect (302), got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
location := rr.Header().Get("Location")
|
|
if location == "" {
|
|
t.Error("expected Location header in redirect response")
|
|
} else {
|
|
// Parse the redirect URL to verify it contains a code
|
|
redirectURL, err := url.Parse(location)
|
|
if err != nil {
|
|
t.Errorf("failed to parse redirect URL: %v", err)
|
|
} else {
|
|
code := redirectURL.Query().Get("code")
|
|
if code == "" {
|
|
t.Error("expected 'code' parameter in redirect URL")
|
|
}
|
|
|
|
// Verify state is preserved if provided
|
|
if tt.state != "" {
|
|
returnedState := redirectURL.Query().Get("state")
|
|
if returnedState != tt.state {
|
|
t.Errorf("expected state '%s', got '%s'", tt.state, returnedState)
|
|
}
|
|
}
|
|
|
|
// Verify the auth request was stored
|
|
srv.mu.Lock()
|
|
ar, ok := srv.code[code]
|
|
srv.mu.Unlock()
|
|
|
|
if !ok {
|
|
t.Error("expected authorization request to be stored")
|
|
} else {
|
|
if ar.clientID != tt.clientID {
|
|
t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.clientID)
|
|
}
|
|
if ar.redirectURI != tt.redirectURI {
|
|
t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.redirectURI)
|
|
}
|
|
if ar.nonce != tt.nonce {
|
|
t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.nonce)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
t.Errorf("unexpected test case: not expecting error or redirect")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestServeTokenWithClientValidation verifies OAuth token endpoint security in both strict and non-strict modes.
|
|
// In strict mode, the token endpoint must:
|
|
// - Require and validate client credentials (client_id + client_secret)
|
|
// - Only accept tokens from registered funnel clients
|
|
// - Validate that redirect_uri matches the registered client
|
|
// - Support both form-based and HTTP Basic authentication for client credentials
|
|
func TestServeTokenWithClientValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
strictMode bool
|
|
method string
|
|
grantType string
|
|
code string
|
|
clientID string
|
|
clientSecret string
|
|
redirectURI string
|
|
useBasicAuth bool
|
|
setupAuthRequest bool
|
|
authRequestClient string
|
|
authRequestRedirect string
|
|
expectError bool
|
|
expectCode int
|
|
expectIDToken bool
|
|
}{
|
|
{
|
|
name: "strict mode - valid token exchange with form credentials",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectIDToken: true,
|
|
},
|
|
{
|
|
name: "strict mode - valid token exchange with basic auth",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
useBasicAuth: true,
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectIDToken: true,
|
|
},
|
|
{
|
|
name: "strict mode - missing client credentials",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "strict mode - client_id mismatch",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "wrong-client",
|
|
clientSecret: "test-secret",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - invalid client secret",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "test-client",
|
|
clientSecret: "wrong-secret",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "strict mode - redirect_uri mismatch",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
redirectURI: "https://wrong.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "non-strict mode - no client validation required",
|
|
strictMode: false,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectIDToken: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv := setupTestServer(t, tt.strictMode)
|
|
|
|
// Setup authorization request if needed
|
|
if tt.setupAuthRequest {
|
|
now := time.Now()
|
|
profile := &tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
Key: key.NodePublic{},
|
|
Cap: 1,
|
|
DiscoKey: key.DiscoPublic{},
|
|
}
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: profile,
|
|
CapMap: tailcfg.PeerCapMap{},
|
|
}
|
|
|
|
var funnelClientPtr *funnelClient
|
|
if tt.strictMode && tt.authRequestClient != "" {
|
|
funnelClientPtr = &funnelClient{
|
|
ID: tt.authRequestClient,
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: tt.authRequestRedirect,
|
|
}
|
|
srv.funnelClients[tt.authRequestClient] = funnelClientPtr
|
|
}
|
|
|
|
srv.code["valid-code"] = &authRequest{
|
|
clientID: tt.authRequestClient,
|
|
nonce: "nonce123",
|
|
redirectURI: tt.authRequestRedirect,
|
|
validTill: now.Add(5 * time.Minute),
|
|
remoteUser: remoteUser,
|
|
localRP: !tt.strictMode,
|
|
funnelRP: funnelClientPtr,
|
|
}
|
|
}
|
|
|
|
// Create form data
|
|
form := url.Values{}
|
|
form.Set("grant_type", tt.grantType)
|
|
form.Set("code", tt.code)
|
|
form.Set("redirect_uri", tt.redirectURI)
|
|
|
|
if !tt.useBasicAuth {
|
|
if tt.clientID != "" {
|
|
form.Set("client_id", tt.clientID)
|
|
}
|
|
if tt.clientSecret != "" {
|
|
form.Set("client_secret", tt.clientSecret)
|
|
}
|
|
}
|
|
|
|
req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
if tt.useBasicAuth && tt.clientID != "" && tt.clientSecret != "" {
|
|
req.SetBasicAuth(tt.clientID, tt.clientSecret)
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
srv.serveToken(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code != tt.expectCode {
|
|
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
|
|
}
|
|
} else if tt.expectIDToken {
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp struct {
|
|
IDToken string `json:"id_token"`
|
|
AccessToken string `json:"access_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
}
|
|
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
if resp.IDToken == "" {
|
|
t.Error("expected id_token in response")
|
|
}
|
|
if resp.AccessToken == "" {
|
|
t.Error("expected access_token in response")
|
|
}
|
|
if resp.TokenType != "Bearer" {
|
|
t.Errorf("expected token_type 'Bearer', got '%s'", resp.TokenType)
|
|
}
|
|
if resp.ExpiresIn != 300 {
|
|
t.Errorf("expected expires_in 300, got %d", resp.ExpiresIn)
|
|
}
|
|
|
|
// Verify access token was stored
|
|
srv.mu.Lock()
|
|
_, ok := srv.accessToken[resp.AccessToken]
|
|
srv.mu.Unlock()
|
|
|
|
if !ok {
|
|
t.Error("expected access token to be stored")
|
|
}
|
|
|
|
// Verify authorization code was consumed
|
|
srv.mu.Lock()
|
|
_, ok = srv.code[tt.code]
|
|
srv.mu.Unlock()
|
|
|
|
if ok {
|
|
t.Error("expected authorization code to be consumed")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestServeUserInfoWithClientValidation verifies UserInfo endpoint security in both strict and non-strict modes.
|
|
// In strict mode, the UserInfo endpoint must:
|
|
// - Validate that access tokens are associated with registered clients
|
|
// - Reject tokens for clients that have been deleted/unregistered
|
|
// - Enforce token expiration properly
|
|
// - Return appropriate user claims based on client capabilities
|
|
func TestServeUserInfoWithClientValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
strictMode bool
|
|
setupToken bool
|
|
setupClient bool
|
|
clientID string
|
|
token string
|
|
tokenValidTill time.Time
|
|
expectError bool
|
|
expectCode int
|
|
expectUserInfo bool
|
|
}{
|
|
{
|
|
name: "strict mode - valid token with existing client",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: true,
|
|
clientID: "test-client",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectUserInfo: true,
|
|
},
|
|
{
|
|
name: "strict mode - valid token but client no longer exists",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: false,
|
|
clientID: "deleted-client",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "strict mode - expired token",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: true,
|
|
clientID: "test-client",
|
|
token: "expired-token",
|
|
tokenValidTill: time.Now().Add(-5 * time.Minute),
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - invalid token",
|
|
strictMode: true,
|
|
setupToken: false,
|
|
token: "invalid-token",
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - token without client association",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: false,
|
|
clientID: "",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "non-strict mode - no client validation required",
|
|
strictMode: false,
|
|
setupToken: true,
|
|
setupClient: false,
|
|
clientID: "",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectUserInfo: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv := setupTestServer(t, tt.strictMode)
|
|
|
|
// Setup client if needed
|
|
if tt.setupClient {
|
|
srv.funnelClients[tt.clientID] = &funnelClient{
|
|
ID: tt.clientID,
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: "https://rp.example.com/callback",
|
|
}
|
|
}
|
|
|
|
// Setup token if needed
|
|
if tt.setupToken {
|
|
profile := &tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
Key: key.NodePublic{},
|
|
Cap: 1,
|
|
DiscoKey: key.DiscoPublic{},
|
|
}
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: profile,
|
|
CapMap: tailcfg.PeerCapMap{},
|
|
}
|
|
|
|
srv.accessToken[tt.token] = &authRequest{
|
|
clientID: tt.clientID,
|
|
validTill: tt.tokenValidTill,
|
|
remoteUser: remoteUser,
|
|
}
|
|
}
|
|
|
|
// Create request
|
|
req := httptest.NewRequest("GET", "/userinfo", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tt.token)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
rr := httptest.NewRecorder()
|
|
srv.serveUserInfo(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code != tt.expectCode {
|
|
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
|
|
}
|
|
} else if tt.expectUserInfo {
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp map[string]any
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to parse JSON response: %v", err)
|
|
}
|
|
|
|
// Check required fields
|
|
expectedFields := []string{"sub", "name", "email", "picture", "username"}
|
|
for _, field := range expectedFields {
|
|
if _, ok := resp[field]; !ok {
|
|
t.Errorf("expected field '%s' in user info response", field)
|
|
}
|
|
}
|
|
|
|
// Verify specific values
|
|
if resp["name"] != "Alice Example" {
|
|
t.Errorf("expected name 'Alice Example', got '%v'", resp["name"])
|
|
}
|
|
if resp["email"] != "alice@example.com" {
|
|
t.Errorf("expected email 'alice@example.com', got '%v'", resp["email"])
|
|
}
|
|
if resp["username"] != "alice" {
|
|
t.Errorf("expected username 'alice', got '%v'", resp["username"])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|