diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 2a7465de1..4ec1df798 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -99,6 +99,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { upf.StringVar(&upArgs.qrFormat, "qr-format", string(qrcodes.FormatAuto), fmt.Sprintf("QR code formatting (%s, %s, %s, %s)", qrcodes.FormatAuto, qrcodes.FormatASCII, qrcodes.FormatLarge, qrcodes.FormatSmall)) } upf.StringVar(&upArgs.authKeyOrFile, "auth-key", "", `node authorization key; if it begins with "file:", then it's a path to a file containing the authkey`) + upf.StringVar(&upArgs.audience, "audience", "", "Audience used to generate authkeys via workload identity federation") upf.StringVar(&upArgs.clientID, "client-id", "", "Client ID used to generate authkeys via workload identity federation") upf.StringVar(&upArgs.clientSecretOrFile, "client-secret", "", `Client Secret used to generate authkeys via OAuth; if it begins with "file:", then it's a path to a file containing the secret`) upf.StringVar(&upArgs.idTokenOrFile, "id-token", "", `ID token from the identity provider to exchange with the control server for workload identity federation; if it begins with "file:", then it's a path to a file containing the token`) @@ -149,7 +150,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { return upf } -// notFalseVar is is a flag.Value that can only be "true", if set. +// notFalseVar is a flag.Value that can only be "true", if set. type notFalseVar struct{} func (notFalseVar) IsBoolFlag() bool { return true } @@ -194,6 +195,7 @@ type upArgsT struct { netfilterMode string authKeyOrFile string // "secret" or "file:/path/to/secret" clientID string + audience string clientSecretOrFile string // "secret" or "file:/path/to/secret" idTokenOrFile string // "secret" or "file:/path/to/secret" hostname string @@ -628,7 +630,7 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE return err } - authKey, err = f(ctx, prefs.ControlURL, upArgs.clientID, idToken, strings.Split(upArgs.advertiseTags, ",")) + authKey, err = f(ctx, prefs.ControlURL, upArgs.clientID, idToken, upArgs.audience, strings.Split(upArgs.advertiseTags, ",")) if err != nil { return err } @@ -905,7 +907,7 @@ func addPrefFlagMapping(flagName string, prefNames ...string) { // correspond to an ipn.Pref. func preflessFlag(flagName string) bool { switch flagName { - case "auth-key", "force-reauth", "reset", "qr", "qr-format", "json", "timeout", "accept-risk", "host-routes", "client-id", "client-secret", "id-token": + case "auth-key", "force-reauth", "reset", "qr", "qr-format", "json", "timeout", "accept-risk", "host-routes", "client-id", "audience", "client-secret", "id-token": return true } return false diff --git a/cmd/tailscale/cli/up_test.go b/cmd/tailscale/cli/up_test.go index fe2f1b555..bb172f906 100644 --- a/cmd/tailscale/cli/up_test.go +++ b/cmd/tailscale/cli/up_test.go @@ -46,6 +46,7 @@ var validUpFlags = set.Of( "client-id", "client-secret", "id-token", + "audience", ) // TestUpFlagSetIsFrozen complains when new flags are added to tailscale up. diff --git a/feature/identityfederation/identityfederation.go b/feature/identityfederation/identityfederation.go index 47ebd1349..6a0ac75a5 100644 --- a/feature/identityfederation/identityfederation.go +++ b/feature/identityfederation/identityfederation.go @@ -19,6 +19,7 @@ import ( "tailscale.com/feature" "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" + "tailscale.com/wif" ) func init() { @@ -28,13 +29,17 @@ func init() { } // resolveAuthKey uses OIDC identity federation to exchange the provided ID token and client ID for an authkey. -func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { +func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) { if clientID == "" { return "", nil // Short-circuit, no client ID means not using identity federation } if idToken == "" { - return "", errors.New("federated identity authkeys require --id-token") + providerIdToken, err := wif.ObtainProviderToken(ctx, audience) + if err != nil { + return "", errors.New("federated identity authkeys require --id-token") + } + idToken = providerIdToken } if len(tags) == 0 { return "", errors.New("federated identity authkeys require --advertise-tags") diff --git a/feature/identityfederation/identityfederation_test.go b/feature/identityfederation/identityfederation_test.go index a673a4298..3ac53f339 100644 --- a/feature/identityfederation/identityfederation_test.go +++ b/feature/identityfederation/identityfederation_test.go @@ -16,6 +16,7 @@ func TestResolveAuthKey(t *testing.T) { name string clientID string idToken string + audience string tags []string wantAuthKey string wantErr string @@ -24,6 +25,7 @@ func TestResolveAuthKey(t *testing.T) { name: "success", clientID: "client-123", idToken: "token", + audience: "api://tailscale-wif", tags: []string{"tag:test"}, wantAuthKey: "tskey-auth-xyz", wantErr: "", @@ -32,6 +34,7 @@ func TestResolveAuthKey(t *testing.T) { name: "missing client id short-circuits without error", clientID: "", idToken: "token", + audience: "api://tailscale-wif", tags: []string{"tag:test"}, wantAuthKey: "", wantErr: "", @@ -40,6 +43,7 @@ func TestResolveAuthKey(t *testing.T) { name: "missing id token", clientID: "client-123", idToken: "", + audience: "api://tailscale-wif", tags: []string{"tag:test"}, wantErr: "federated identity authkeys require --id-token", }, @@ -47,6 +51,7 @@ func TestResolveAuthKey(t *testing.T) { name: "missing tags", clientID: "client-123", idToken: "token", + audience: "api://tailscale-wif", tags: []string{}, wantErr: "federated identity authkeys require --advertise-tags", }, @@ -54,6 +59,7 @@ func TestResolveAuthKey(t *testing.T) { name: "invalid client id attributes", clientID: "client-123?invalid=value", idToken: "token", + audience: "api://tailscale-wif", tags: []string{"tag:test"}, wantErr: `failed to parse optional config attributes: unknown optional config attribute "invalid"`, }, @@ -64,7 +70,7 @@ func TestResolveAuthKey(t *testing.T) { srv := mockedControlServer(t) defer srv.Close() - authKey, err := resolveAuthKey(context.Background(), srv.URL, tt.clientID, tt.idToken, tt.tags) + authKey, err := resolveAuthKey(context.Background(), srv.URL, tt.clientID, tt.idToken, tt.audience, tt.tags) if tt.wantErr != "" { if err == nil { t.Errorf("resolveAuthKey() error = nil, want %q", tt.wantErr) diff --git a/flake.nix b/flake.nix index dd8016b4e..5d93f4d4d 100644 --- a/flake.nix +++ b/flake.nix @@ -151,4 +151,4 @@ }); }; } -# nix-direnv cache busting line: sha256-knSIes9pFVkVfK5hcBG9BSR1ueH+yPpx4hv/UsyaW2M= +# nix-direnv cache busting line: sha256-HdRMXmKkibF4z8M+oYbnX77URcLZI9zMP3wRy1nvMiY= diff --git a/go.mod b/go.mod index c8be839c3..2627751e9 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/andybalholm/brotli v1.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/atotto/clipboard v0.1.4 - github.com/aws/aws-sdk-go-v2 v1.36.0 + github.com/aws/aws-sdk-go-v2 v1.41.0 github.com/aws/aws-sdk-go-v2/config v1.29.5 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 @@ -268,18 +268,18 @@ require ( github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.58 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 // indirect - github.com/aws/smithy-go v1.22.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 + github.com/aws/smithy-go v1.24.0 github.com/beorn7/perks v1.0.1 // indirect github.com/bkielbasa/cyclop v1.2.1 // indirect github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb // indirect diff --git a/go.mod.sri b/go.mod.sri index fd2ab9d7a..3a1b01510 100644 --- a/go.mod.sri +++ b/go.mod.sri @@ -1 +1 @@ -sha256-knSIes9pFVkVfK5hcBG9BSR1ueH+yPpx4hv/UsyaW2M= +sha256-HdRMXmKkibF4z8M+oYbnX77URcLZI9zMP3wRy1nvMiY= diff --git a/go.sum b/go.sum index 19703e072..c2202c15e 100644 --- a/go.sum +++ b/go.sum @@ -141,8 +141,8 @@ github.com/ashanbrown/makezero v1.1.1 h1:iCQ87C0V0vSyO+M9E/FZYbu65auqH0lnsOkf5Fc github.com/ashanbrown/makezero v1.1.1/go.mod h1:i1bJLCRSCHOcOa9Y6MyF2FTfMZMFdHvxKHxgO5Z1axI= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= -github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 h1:zAxi9p3wsZMIaVCdoiQp2uZ9k1LsZvmAnoTBeZPXom0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8/go.mod h1:3XkePX5dSaxveLAYY7nsbsZZrKxCyEuE5pM4ziFxyGg= github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= @@ -153,20 +153,20 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPd github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 h1:/BsEGAyMai+KdXS+CMHlLhB5miAO19wOqE6tj8azWPM= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58/go.mod h1:KHM3lfl/sAJBCoLI1Lsg5w4SD2VDYWwQi7vxbKhw7TI= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 h1:8IwBjuLdqIO1dGB+dZ9zJEl8wzY3bVYxcs0Xyu/Lsc0= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31/go.mod h1:8tMBcuVjL4kP/ECEIWTCWtwV2kj6+ouEKl4cqR4iWLw= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 h1:siiQ+jummya9OLPDEyHVb2dLW4aOMe22FGDd0sAfuSw= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5/go.mod h1:iHVx2J9pWzITdP5MJY6qWfG34TfD9EA+Qi3eV6qQCXw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 h1:tkVNm99nkJnFo1H9IIQb5QkCiPcvCDn3Pos+IeTbGRA= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12/go.mod h1:dIVlquSPUMqEJtx2/W17SM2SuESRaVEhEV9alcMqxjw= github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 h1:JBod0SnNqcWQ0+uAyzeRFG1zCHotW8DukumYYyNy0zo= @@ -177,10 +177,10 @@ github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uU github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= -github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= -github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02 h1:bXAPYSbdYbS5VTy92NIUbeDI1qyggi+JYh5op9IFlcQ= github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02/go.mod h1:k08r+Yj1PRAmuayFiRK6MYuR5Ve4IuZtTfxErMIh0+c= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/internal/client/tailscale/identityfederation.go b/internal/client/tailscale/identityfederation.go index b8eb0fc9c..7bbd10d93 100644 --- a/internal/client/tailscale/identityfederation.go +++ b/internal/client/tailscale/identityfederation.go @@ -16,7 +16,7 @@ import ( // clientID is the federated client ID used for token exchange // idToken is the Identity token from the identity provider // tags is the list of tags to be associated with the auth key -var HookResolveAuthKeyViaWIF feature.Hook[func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error)] +var HookResolveAuthKeyViaWIF feature.Hook[func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error)] // HookExchangeJWTForTokenViaWIF resolves to [identityfederation.exchangeJWTForToken] when the // corresponding feature tag is enabled in the build process. diff --git a/shell.nix b/shell.nix index c494ce47c..1f8faf235 100644 --- a/shell.nix +++ b/shell.nix @@ -16,4 +16,4 @@ ) { src = ./.; }).shellNix -# nix-direnv cache busting line: sha256-knSIes9pFVkVfK5hcBG9BSR1ueH+yPpx4hv/UsyaW2M= +# nix-direnv cache busting line: sha256-HdRMXmKkibF4z8M+oYbnX77URcLZI9zMP3wRy1nvMiY= diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index ea165e932..ad7b326a7 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -806,7 +806,7 @@ func (s *Server) resolveAuthKey() (string, error) { if clientID == "" && idToken != "" { return "", fmt.Errorf("ID token for workload identity federation found, but client ID is empty") } - authKey, err = resolveViaWIF(s.shutdownCtx, s.ControlURL, clientID, idToken, s.AdvertiseTags) + authKey, err = resolveViaWIF(s.shutdownCtx, s.ControlURL, clientID, idToken, "", s.AdvertiseTags) if err != nil { return "", err } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 838d5f3f5..14e07ba9b 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -1405,7 +1405,7 @@ func TestResolveAuthKey(t *testing.T) { oauthAvailable bool wifAvailable bool resolveViaOAuth func(ctx context.Context, clientSecret string, tags []string) (string, error) - resolveViaWIF func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) + resolveViaWIF func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) wantAuthKey string wantErr bool wantErrContains string @@ -1437,7 +1437,7 @@ func TestResolveAuthKey(t *testing.T) { clientID: "client-id-123", idToken: "id-token-456", wifAvailable: true, - resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { + resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) { if clientID != "client-id-123" { return "", fmt.Errorf("unexpected client ID: %s", clientID) } @@ -1454,7 +1454,7 @@ func TestResolveAuthKey(t *testing.T) { clientID: "client-id-123", idToken: "id-token-456", wifAvailable: true, - resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { + resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) { return "", fmt.Errorf("resolution failed") }, wantErrContains: "resolution failed", @@ -1464,7 +1464,7 @@ func TestResolveAuthKey(t *testing.T) { clientID: "", idToken: "id-token-456", wifAvailable: true, - resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { + resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) { return "", fmt.Errorf("should not be called") }, wantErrContains: "empty", @@ -1474,7 +1474,7 @@ func TestResolveAuthKey(t *testing.T) { clientID: "client-id-123", idToken: "", wifAvailable: true, - resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { + resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) { return "", fmt.Errorf("should not be called") }, wantErrContains: "empty", @@ -1490,7 +1490,7 @@ func TestResolveAuthKey(t *testing.T) { return "tskey-auth-via-oauth", nil }, wifAvailable: true, - resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { + resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) { return "", fmt.Errorf("should not be called") }, wantAuthKey: "tskey-auth-via-oauth", @@ -1505,7 +1505,7 @@ func TestResolveAuthKey(t *testing.T) { return "", fmt.Errorf("resolution failed") }, wifAvailable: true, - resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { + resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) { return "", fmt.Errorf("should not be called") }, wantErrContains: "failed", diff --git a/wif/wif.go b/wif/wif.go new file mode 100644 index 000000000..d3bdcf6a8 --- /dev/null +++ b/wif/wif.go @@ -0,0 +1,337 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package wif deals with obtaining ID tokens from provider VMs +// to be used as part of Workload Identity Federation +package wif + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go" + "tailscale.com/util/httpm" +) + +type Environment string + +const ( + EnvGitHub Environment = "github" + EnvAWS Environment = "aws" + EnvGCP Environment = "gcp" + EnvNone Environment = "none" +) + +const ( + metadataDialTimeout = 200 * time.Millisecond + metadataResponseTimeout = 200 * time.Millisecond + metadataClientTimeout = 400 * time.Millisecond + metadataIdleConnTimeout = 10 * time.Second + providerDialTimeout = 500 * time.Millisecond + providerResponseTimeout = 500 * time.Millisecond + providerClientTimeout = 2 * time.Second + githubClientTimeout = 10 * time.Second +) + +// ObtainProviderToken tries to detect what provider the client is running in +// and then tries to obtain an ID token for the audience that is passed as an argument +// To detect the environment, we do it in the following intentional order: +// 1. GitHub Actions (strongest env signals; may run atop any cloud) +// 2. AWS via IMDSv2 token endpoint (does not require env vars) +// 3. GCP via metadata header semantics +func ObtainProviderToken(ctx context.Context, audience string) (token string, err error) { + env := detectEnvironment(ctx) + + switch env { + case EnvGitHub: + return acquireGitHubActionsIDToken(ctx, audience) + case EnvAWS: + return acquireAWSWebIdentityToken(ctx, audience) + case EnvGCP: + return acquireGCPMetadataIDToken(ctx, audience) + default: + return "", errors.New("could not detect environment; provide --id-token explicitly") + } +} + +func detectEnvironment(ctx context.Context) Environment { + if os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL") != "" && + os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") != "" { + return EnvGitHub + } + + client := metadataHTTPClient() + + if detectAWSIMDSv2(ctx, client) { + return EnvAWS + } + + if detectGCPMetadata(ctx, client) { + return EnvGCP + } + + return EnvNone +} + +func metadataHTTPClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: nil, + DialContext: (&net.Dialer{ + Timeout: metadataDialTimeout, + }).DialContext, + ResponseHeaderTimeout: metadataResponseTimeout, + IdleConnTimeout: metadataIdleConnTimeout, + }, + Timeout: metadataClientTimeout, + } +} + +func providerHTTPClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: nil, + DialContext: (&net.Dialer{ + Timeout: providerDialTimeout, + }).DialContext, + ResponseHeaderTimeout: providerResponseTimeout, + }, + Timeout: providerClientTimeout, + } +} + +func detectAWSIMDSv2(ctx context.Context, client *http.Client) bool { + req, err := http.NewRequestWithContext(ctx, httpm.PUT, "http://169.254.169.254/latest/api/token", nil) + if err != nil { + return false + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "60") + + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +func detectGCPMetadata(ctx context.Context, client *http.Client) bool { + req, err := http.NewRequestWithContext(ctx, httpm.GET, "http://metadata.google.internal", nil) + if err != nil { + return false + } + req.Header.Set("Metadata-Flavor", "Google") + + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.Header.Get("Metadata-Flavor") == "Google" +} + +type githubOIDCResponse struct { + Value string `json:"value"` +} + +func acquireGitHubActionsIDToken(ctx context.Context, audience string) (jwt string, err error) { + reqURL := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL") + reqTok := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") + if reqURL == "" || reqTok == "" { + return "", errors.New("missing ACTIONS_ID_TOKEN_REQUEST_URL/TOKEN (ensure workflow has permissions: id-token: write)") + } + + u, err := url.Parse(reqURL) + if err != nil { + return "", fmt.Errorf("parse ACTIONS_ID_TOKEN_REQUEST_URL: %w", err) + } + if strings.TrimSpace(audience) != "" { + q := u.Query() + q.Set("audience", strings.TrimSpace(audience)) + u.RawQuery = q.Encode() + } + + req, err := http.NewRequestWithContext(ctx, httpm.GET, u.String(), nil) + if err != nil { + return "", fmt.Errorf("build request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+reqTok) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: githubClientTimeout} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request github oidc token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode/100 != 2 { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return "", fmt.Errorf("github oidc token endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + var tr githubOIDCResponse + if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil { + return "", fmt.Errorf("decode github oidc response: %w", err) + } + if strings.TrimSpace(tr.Value) == "" { + return "", errors.New("github oidc response contained empty token") + } + + // GitHub response doesn't provide exp directly; caller can parse JWT if needed. + return tr.Value, nil +} + +func acquireAWSWebIdentityToken(ctx context.Context, audience string) (jwt string, err error) { + duration := 5 * time.Minute + + region, err := detectAWSRegion(ctx) + if err != nil { + return "", err + } + + // LoadDefaultConfig wires up the default credential chain (incl. IMDS). + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return "", fmt.Errorf("load aws config: %w", err) + } + + // Verify credentials are available before proceeding. + if _, err := cfg.Credentials.Retrieve(ctx); err != nil { + return "", fmt.Errorf("AWS credentials unavailable (instance profile/IMDS?): %w", err) + } + + stsClient := sts.NewFromConfig(cfg) + + in := &sts.GetWebIdentityTokenInput{ + Audience: []string{strings.TrimSpace(audience)}, + SigningAlgorithm: aws.String("RS256"), + DurationSeconds: aws.Int32(int32(duration / time.Second)), + } + + out, err := stsClient.GetWebIdentityToken(ctx, in) + if err != nil { + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + return "", fmt.Errorf("aws sts:GetWebIdentityToken failed (%s): %w", apiErr.ErrorCode(), err) + } + return "", fmt.Errorf("aws sts:GetWebIdentityToken failed: %w", err) + } + + if out.WebIdentityToken == nil || strings.TrimSpace(*out.WebIdentityToken) == "" { + return "", fmt.Errorf("aws sts:GetWebIdentityToken returned empty token") + } + + return *out.WebIdentityToken, nil +} + +func acquireGCPMetadataIDToken(ctx context.Context, audience string) (jwt string, err error) { + u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity" + v := url.Values{} + v.Set("audience", strings.TrimSpace(audience)) + v.Set("format", "full") + fullURL := u + "?" + v.Encode() + + req, err := http.NewRequestWithContext(ctx, httpm.GET, fullURL, nil) + if err != nil { + return "", fmt.Errorf("build request: %w", err) + } + req.Header.Set("Metadata-Flavor", "Google") + + client := providerHTTPClient() + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("call gcp metadata identity endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode/100 != 2 { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return "", fmt.Errorf("gcp metadata identity endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + b, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) + if err != nil { + return "", fmt.Errorf("read gcp id token: %w", err) + } + jwt = strings.TrimSpace(string(b)) + if jwt == "" { + return "", fmt.Errorf("gcp metadata returned empty token") + } + + return jwt, nil +} + +func detectAWSRegion(ctx context.Context) (string, error) { + client := providerHTTPClient() + + req, err := http.NewRequestWithContext(ctx, httpm.PUT, "http://169.254.169.254/latest/api/token", nil) + if err != nil { + return "", fmt.Errorf("build imds token request: %w", err) + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "60") + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("call imds token endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return "", fmt.Errorf("imds token endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + tokenBytes, err := io.ReadAll(io.LimitReader(resp.Body, 1024)) + if err != nil { + return "", fmt.Errorf("read imds token: %w", err) + } + token := strings.TrimSpace(string(tokenBytes)) + if token == "" { + return "", fmt.Errorf("imds token endpoint returned empty token") + } + + // Get instance identity document + req2, err := http.NewRequestWithContext(ctx, httpm.GET, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) + if err != nil { + return "", fmt.Errorf("build identity document request: %w", err) + } + req2.Header.Set("X-aws-ec2-metadata-token", token) + + resp2, err := client.Do(req2) + if err != nil { + return "", fmt.Errorf("call identity document endpoint: %w", err) + } + defer resp2.Body.Close() + + if resp2.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp2.Body, 512)) + return "", fmt.Errorf("identity document endpoint returned %s: %s", resp2.Status, strings.TrimSpace(string(b))) + } + + var doc struct { + Region string `json:"region"` + } + if err := json.NewDecoder(resp2.Body).Decode(&doc); err != nil { + return "", fmt.Errorf("decode identity document: %w", err) + } + if doc.Region == "" { + return "", fmt.Errorf("region not found in instance identity document") + } + + return doc.Region, nil +}