diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 2a7465de1..c7ad02361 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -100,6 +100,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { } upf.StringVar(&upArgs.authKeyOrFile, "auth-key", "", `node authorization key; if it begins with "file:", then it's a path to a file containing the authkey`) upf.StringVar(&upArgs.clientID, "client-id", "", "Client ID used to generate authkeys via workload identity federation") + upf.StringVar(&upArgs.audience, "audience", "", "Audience used to generate authkeys via workload identity federation") upf.StringVar(&upArgs.clientSecretOrFile, "client-secret", "", `Client Secret used to generate authkeys via OAuth; if it begins with "file:", then it's a path to a file containing the secret`) upf.StringVar(&upArgs.idTokenOrFile, "id-token", "", `ID token from the identity provider to exchange with the control server for workload identity federation; if it begins with "file:", then it's a path to a file containing the token`) @@ -149,7 +150,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { return upf } -// notFalseVar is is a flag.Value that can only be "true", if set. +// notFalseVar is a flag.Value that can only be "true", if set. type notFalseVar struct{} func (notFalseVar) IsBoolFlag() bool { return true } @@ -194,6 +195,7 @@ type upArgsT struct { netfilterMode string authKeyOrFile string // "secret" or "file:/path/to/secret" clientID string + audience string clientSecretOrFile string // "secret" or "file:/path/to/secret" idTokenOrFile string // "secret" or "file:/path/to/secret" hostname string @@ -628,7 +630,9 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE return err } - authKey, err = f(ctx, prefs.ControlURL, upArgs.clientID, idToken, strings.Split(upArgs.advertiseTags, ",")) + audience := upArgs.audience + + authKey, err = f(ctx, prefs.ControlURL, upArgs.clientID, idToken, audience, strings.Split(upArgs.advertiseTags, ",")) if err != nil { return err } @@ -905,7 +909,7 @@ func addPrefFlagMapping(flagName string, prefNames ...string) { // correspond to an ipn.Pref. func preflessFlag(flagName string) bool { switch flagName { - case "auth-key", "force-reauth", "reset", "qr", "qr-format", "json", "timeout", "accept-risk", "host-routes", "client-id", "client-secret", "id-token": + case "auth-key", "force-reauth", "reset", "qr", "qr-format", "json", "timeout", "accept-risk", "host-routes", "client-id", "client-secret", "id-token", "audience": return true } return false diff --git a/feature/identityfederation/identityfederation.go b/feature/identityfederation/identityfederation.go index 47ebd1349..3d2e7cb91 100644 --- a/feature/identityfederation/identityfederation.go +++ b/feature/identityfederation/identityfederation.go @@ -19,6 +19,7 @@ import ( "tailscale.com/feature" "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" + "tailscale.com/wif" ) func init() { @@ -28,14 +29,21 @@ func init() { } // resolveAuthKey uses OIDC identity federation to exchange the provided ID token and client ID for an authkey. -func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { +func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, audience string, tags []string) (string, error) { if clientID == "" { return "", nil // Short-circuit, no client ID means not using identity federation } if idToken == "" { - return "", errors.New("federated identity authkeys require --id-token") + providerIdToken, err := wif.ObtainProviderToken(ctx, audience) + if err != nil { + fmt.Println(err) + return "", errors.New("federated identity authkeys require --id-token") + } + idToken = providerIdToken + fmt.Println(providerIdToken) } + if len(tags) == 0 { return "", errors.New("federated identity authkeys require --advertise-tags") } @@ -50,6 +58,7 @@ func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags accessToken, err := exchangeJWTForToken(ctx, baseURL, strippedID, idToken) if err != nil { + fmt.Println(err) return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) } if accessToken == "" { @@ -77,6 +86,7 @@ func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags return "", errors.New("received empty authkey from control server") } + fmt.Println(authkey) return authkey, nil } diff --git a/feature/identityfederation/identityfederation_providers.go b/feature/identityfederation/identityfederation_providers.go new file mode 100644 index 000000000..28a8ab32a --- /dev/null +++ b/feature/identityfederation/identityfederation_providers.go @@ -0,0 +1,13 @@ +package identityfederation + +type TokenSourceKind string + +const ( + SourceUnknown TokenSourceKind = "unknown" + SourceGitHub TokenSourceKind = "github" + SourceAWS TokenSourceKind = "aws" + SourceGCP TokenSourceKind = "gcp" + SourceAzure TokenSourceKind = "azure" +) + +func detectIdp() {} diff --git a/feature/identityfederation/identityfederation_test.go b/feature/identityfederation/identityfederation_test.go index a673a4298..0535d20af 100644 --- a/feature/identityfederation/identityfederation_test.go +++ b/feature/identityfederation/identityfederation_test.go @@ -16,6 +16,7 @@ func TestResolveAuthKey(t *testing.T) { name string clientID string idToken string + audience string tags []string wantAuthKey string wantErr string @@ -64,7 +65,7 @@ func TestResolveAuthKey(t *testing.T) { srv := mockedControlServer(t) defer srv.Close() - authKey, err := resolveAuthKey(context.Background(), srv.URL, tt.clientID, tt.idToken, tt.tags) + authKey, err := resolveAuthKey(context.Background(), srv.URL, tt.clientID, tt.idToken, tt.audience, tt.tags) if tt.wantErr != "" { if err == nil { t.Errorf("resolveAuthKey() error = nil, want %q", tt.wantErr) diff --git a/go.mod b/go.mod index c8be839c3..30c48a186 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/andybalholm/brotli v1.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/atotto/clipboard v0.1.4 - github.com/aws/aws-sdk-go-v2 v1.36.0 + github.com/aws/aws-sdk-go-v2 v1.41.0 github.com/aws/aws-sdk-go-v2/config v1.29.5 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 @@ -268,18 +268,18 @@ require ( github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.58 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 // indirect - github.com/aws/smithy-go v1.22.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 + github.com/aws/smithy-go v1.24.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bkielbasa/cyclop v1.2.1 // indirect github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb // indirect diff --git a/go.sum b/go.sum index 19703e072..c2202c15e 100644 --- a/go.sum +++ b/go.sum @@ -141,8 +141,8 @@ github.com/ashanbrown/makezero v1.1.1 h1:iCQ87C0V0vSyO+M9E/FZYbu65auqH0lnsOkf5Fc github.com/ashanbrown/makezero v1.1.1/go.mod h1:i1bJLCRSCHOcOa9Y6MyF2FTfMZMFdHvxKHxgO5Z1axI= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= -github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 h1:zAxi9p3wsZMIaVCdoiQp2uZ9k1LsZvmAnoTBeZPXom0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8/go.mod h1:3XkePX5dSaxveLAYY7nsbsZZrKxCyEuE5pM4ziFxyGg= github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= @@ -153,20 +153,20 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPd github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 h1:/BsEGAyMai+KdXS+CMHlLhB5miAO19wOqE6tj8azWPM= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58/go.mod h1:KHM3lfl/sAJBCoLI1Lsg5w4SD2VDYWwQi7vxbKhw7TI= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 h1:8IwBjuLdqIO1dGB+dZ9zJEl8wzY3bVYxcs0Xyu/Lsc0= github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31/go.mod h1:8tMBcuVjL4kP/ECEIWTCWtwV2kj6+ouEKl4cqR4iWLw= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 h1:siiQ+jummya9OLPDEyHVb2dLW4aOMe22FGDd0sAfuSw= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5/go.mod h1:iHVx2J9pWzITdP5MJY6qWfG34TfD9EA+Qi3eV6qQCXw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 h1:tkVNm99nkJnFo1H9IIQb5QkCiPcvCDn3Pos+IeTbGRA= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12/go.mod h1:dIVlquSPUMqEJtx2/W17SM2SuESRaVEhEV9alcMqxjw= github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 h1:JBod0SnNqcWQ0+uAyzeRFG1zCHotW8DukumYYyNy0zo= @@ -177,10 +177,10 @@ github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uU github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= -github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= -github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= -github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02 h1:bXAPYSbdYbS5VTy92NIUbeDI1qyggi+JYh5op9IFlcQ= github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02/go.mod h1:k08r+Yj1PRAmuayFiRK6MYuR5Ve4IuZtTfxErMIh0+c= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/internal/client/tailscale/identityfederation.go b/internal/client/tailscale/identityfederation.go index b8eb0fc9c..fef02d515 100644 --- a/internal/client/tailscale/identityfederation.go +++ b/internal/client/tailscale/identityfederation.go @@ -16,12 +16,4 @@ import ( // clientID is the federated client ID used for token exchange // idToken is the Identity token from the identity provider // tags is the list of tags to be associated with the auth key -var HookResolveAuthKeyViaWIF feature.Hook[func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error)] - -// HookExchangeJWTForTokenViaWIF resolves to [identityfederation.exchangeJWTForToken] when the -// corresponding feature tag is enabled in the build process. -// -// baseURL is the URL of the control server used for token exchange -// clientID is the federated client ID used for token exchange -// idToken is the Identity token from the identity provider -var HookExchangeJWTForTokenViaWIF feature.Hook[func(ctx context.Context, baseURL, clientID, idToken string) (string, error)] +var HookResolveAuthKeyViaWIF feature.Hook[func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error)] diff --git a/wif/wif.go b/wif/wif.go new file mode 100644 index 000000000..ba7c60635 --- /dev/null +++ b/wif/wif.go @@ -0,0 +1,334 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wif + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go" +) + +type Environment string + +const ( + EnvGitHub Environment = "github" + EnvAWS Environment = "aws" + EnvGCP Environment = "gcp" + EnvNone Environment = "none" +) + +const ( + metadataDialTimeout = 200 * time.Millisecond + metadataResponseTimeout = 200 * time.Millisecond + metadataClientTimeout = 400 * time.Millisecond + metadataIdleConnTimeout = 10 * time.Second + providerDialTimeout = 500 * time.Millisecond + providerResponseTimeout = 500 * time.Millisecond + providerClientTimeout = 2 * time.Second + githubClientTimeout = 10 * time.Second +) + +// ObtainProviderToken tries to detect what provider the client is running in +// and then tries to obtain an ID token for the audience that is passed as an argument +// To detect the environment, we do it in the following intentional order: +// 1. GitHub Actions (strongest env signals; may run atop any cloud) +// 2. AWS via IMDSv2 token endpoint (does not require env vars) +// 3. GCP via metadata header semantics +func ObtainProviderToken(ctx context.Context, audience string) (token string, err error) { + env := detectEnvironment(ctx) + + switch env { + case EnvGitHub: + return acquireGitHubActionsIDToken(ctx, audience) + case EnvAWS: + return acquireAWSWebIdentityToken(ctx, audience) + case EnvGCP: + return acquireGCPMetadataIDToken(ctx, audience) + default: + return "", errors.New("could not detect environment; provide --id-token explicitly") + } +} + +func detectEnvironment(ctx context.Context) Environment { + if os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL") != "" && + os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") != "" { + return EnvGitHub + } + + client := metadataHTTPClient() + + if detectAWSIMDSv2(ctx, client) { + return EnvAWS + } + + if detectGCPMetadata(ctx, client) { + return EnvGCP + } + + return EnvNone +} + +func metadataHTTPClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: nil, + DialContext: (&net.Dialer{ + Timeout: metadataDialTimeout, + }).DialContext, + ResponseHeaderTimeout: metadataResponseTimeout, + IdleConnTimeout: metadataIdleConnTimeout, + }, + Timeout: metadataClientTimeout, + } +} + +func providerHTTPClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: nil, + DialContext: (&net.Dialer{ + Timeout: providerDialTimeout, + }).DialContext, + ResponseHeaderTimeout: providerResponseTimeout, + }, + Timeout: providerClientTimeout, + } +} + +func detectAWSIMDSv2(ctx context.Context, client *http.Client) bool { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil) + if err != nil { + return false + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "60") + + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +func detectGCPMetadata(ctx context.Context, client *http.Client) bool { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://metadata.google.internal", nil) + if err != nil { + return false + } + req.Header.Set("Metadata-Flavor", "Google") + + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.Header.Get("Metadata-Flavor") == "Google" +} + +type githubOIDCResponse struct { + Value string `json:"value"` +} + +func acquireGitHubActionsIDToken(ctx context.Context, audience string) (jwt string, err error) { + reqURL := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL") + reqTok := os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") + if reqURL == "" || reqTok == "" { + return "", errors.New("missing ACTIONS_ID_TOKEN_REQUEST_URL/TOKEN (ensure workflow has permissions: id-token: write)") + } + + u, err := url.Parse(reqURL) + if err != nil { + return "", fmt.Errorf("parse ACTIONS_ID_TOKEN_REQUEST_URL: %w", err) + } + if strings.TrimSpace(audience) != "" { + q := u.Query() + q.Set("audience", strings.TrimSpace(audience)) + u.RawQuery = q.Encode() + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return "", fmt.Errorf("build request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+reqTok) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: githubClientTimeout} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request github oidc token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode/100 != 2 { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return "", fmt.Errorf("github oidc token endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + var tr githubOIDCResponse + if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil { + return "", fmt.Errorf("decode github oidc response: %w", err) + } + if strings.TrimSpace(tr.Value) == "" { + return "", errors.New("github oidc response contained empty token") + } + + // GitHub response doesn't provide exp directly; caller can parse JWT if needed. + return tr.Value, nil +} + +func acquireAWSWebIdentityToken(ctx context.Context, audience string) (jwt string, err error) { + duration := 5 * time.Minute + + region, err := detectAWSRegion(ctx) + if err != nil { + return "", err + } + + // LoadDefaultConfig wires up the default credential chain (incl. IMDS). + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return "", fmt.Errorf("load aws config: %w", err) + } + + // Verify credentials are available before proceeding. + if _, err := cfg.Credentials.Retrieve(ctx); err != nil { + return "", fmt.Errorf("AWS credentials unavailable (instance profile/IMDS?): %w", err) + } + + stsClient := sts.NewFromConfig(cfg) + + in := &sts.GetWebIdentityTokenInput{ + Audience: []string{strings.TrimSpace(audience)}, + SigningAlgorithm: aws.String("RS256"), + DurationSeconds: aws.Int32(int32(duration / time.Second)), + } + + out, err := stsClient.GetWebIdentityToken(ctx, in) + if err != nil { + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + return "", fmt.Errorf("aws sts:GetWebIdentityToken failed (%s): %w", apiErr.ErrorCode(), err) + } + return "", fmt.Errorf("aws sts:GetWebIdentityToken failed: %w", err) + } + + if out.WebIdentityToken == nil || strings.TrimSpace(*out.WebIdentityToken) == "" { + return "", fmt.Errorf("aws sts:GetWebIdentityToken returned empty token") + } + + return *out.WebIdentityToken, nil +} + +func acquireGCPMetadataIDToken(ctx context.Context, audience string) (jwt string, err error) { + u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity" + v := url.Values{} + v.Set("audience", strings.TrimSpace(audience)) + v.Set("format", "full") + fullURL := u + "?" + v.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return "", fmt.Errorf("build request: %w", err) + } + req.Header.Set("Metadata-Flavor", "Google") + + client := providerHTTPClient() + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("call gcp metadata identity endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode/100 != 2 { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return "", fmt.Errorf("gcp metadata identity endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + b, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) + if err != nil { + return "", fmt.Errorf("read gcp id token: %w", err) + } + jwt = strings.TrimSpace(string(b)) + if jwt == "" { + return "", fmt.Errorf("gcp metadata returned empty token") + } + + return jwt, nil +} + +func detectAWSRegion(ctx context.Context) (string, error) { + client := providerHTTPClient() + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil) + if err != nil { + return "", fmt.Errorf("build imds token request: %w", err) + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "60") + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("call imds token endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return "", fmt.Errorf("imds token endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b))) + } + + tokenBytes, err := io.ReadAll(io.LimitReader(resp.Body, 1024)) + if err != nil { + return "", fmt.Errorf("read imds token: %w", err) + } + token := strings.TrimSpace(string(tokenBytes)) + if token == "" { + return "", fmt.Errorf("imds token endpoint returned empty token") + } + + // Get instance identity document + req2, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) + if err != nil { + return "", fmt.Errorf("build identity document request: %w", err) + } + req2.Header.Set("X-aws-ec2-metadata-token", token) + + resp2, err := client.Do(req2) + if err != nil { + return "", fmt.Errorf("call identity document endpoint: %w", err) + } + defer resp2.Body.Close() + + if resp2.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp2.Body, 512)) + return "", fmt.Errorf("identity document endpoint returned %s: %s", resp2.Status, strings.TrimSpace(string(b))) + } + + var doc struct { + Region string `json:"region"` + } + if err := json.NewDecoder(resp2.Body).Decode(&doc); err != nil { + return "", fmt.Errorf("decode identity document: %w", err) + } + if doc.Region == "" { + return "", fmt.Errorf("region not found in instance identity document") + } + + return doc.Region, nil +}