diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index 51d0a88c3..0f2dc42fc 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -34,7 +34,9 @@ spec: securityContext: {{- toYaml . | nindent 8 }} {{- end }} + {{- if or .Values.oauth.clientSecret .Values.oauth.audience }} volumes: + {{- if .Values.oauth.clientSecret }} - name: oauth {{- with .Values.oauthSecretVolume }} {{- toYaml . | nindent 10 }} @@ -42,6 +44,17 @@ spec: secret: secretName: operator-oauth {{- end }} + {{- else }} + - name: oidc-jwt + projected: + defaultMode: 420 + sources: + - serviceAccountToken: + audience: {{ .Values.oauth.audience }} + expirationSeconds: 3600 + path: token + {{- end }} + {{- end }} containers: - name: operator {{- with .Values.operatorConfig.securityContext }} @@ -72,10 +85,15 @@ spec: value: {{ .Values.loginServer }} - name: OPERATOR_INGRESS_CLASS_NAME value: {{ .Values.ingressClass.name }} + {{- if .Values.oauth.clientSecret }} - name: CLIENT_ID_FILE value: /oauth/client_id - name: CLIENT_SECRET_FILE value: /oauth/client_secret + {{- else if .Values.oauth.audience }} + - name: CLIENT_ID + value: {{ .Values.oauth.clientId }} + {{- end }} {{- $proxyTag := printf ":%s" ( .Values.proxyConfig.image.tag | default .Chart.AppVersion )}} - name: PROXY_IMAGE value: {{ coalesce .Values.proxyConfig.image.repo .Values.proxyConfig.image.repository }}{{- if .Values.proxyConfig.image.digest -}}{{ printf "@%s" .Values.proxyConfig.image.digest}}{{- else -}}{{ printf "%s" $proxyTag }}{{- end }} @@ -100,10 +118,18 @@ spec: {{- with .Values.operatorConfig.extraEnv }} {{- toYaml . | nindent 12 }} {{- end }} + {{- if or .Values.oauth.clientSecret .Values.oauth.audience }} volumeMounts: + {{- if .Values.oauth.clientSecret }} - name: oauth mountPath: /oauth readOnly: true + {{- else }} + - name: oidc-jwt + mountPath: /var/run/secrets/tailscale/serviceaccount + readOnly: true + {{- end }} + {{- end }} {{- with .Values.operatorConfig.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml index b44fde0a1..b85c78915 100644 --- a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml @@ -1,7 +1,7 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -{{ if and .Values.oauth .Values.oauth.clientId -}} +{{ if and .Values.oauth .Values.oauth.clientId .Values.oauth.clientSecret -}} apiVersion: v1 kind: Secret metadata: diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index cdedb92e8..eb11fc7f2 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -1,13 +1,20 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -# Operator oauth credentials. If set a Kubernetes Secret with the provided -# values will be created in the operator namespace. If unset a Secret named -# operator-oauth must be precreated or oauthSecretVolume needs to be adjusted. -# This block will be overridden by oauthSecretVolume, if set. -oauth: {} - # clientId: "" - # clientSecret: "" +# Operator oauth credentials. If unset a Secret named operator-oauth must be +# precreated or oauthSecretVolume needs to be adjusted. This block will be +# overridden by oauthSecretVolume, if set. +oauth: + # The Client ID the operator will authenticate with. + clientId: "" + # If set a Kubernetes Secret with the provided value will be created in + # the operator namespace, and mounted into the operator Pod. Takes precedence + # over oauth.audience. + clientSecret: "" + # The audience for oauth.clientId if using a workload identity federation + # OAuth client. Mutually exclusive with oauth.clientSecret. + # See https://tailscale.com/kb/1581/workload-identity-federation. + audience: "" # URL of the control plane to be used by all resources managed by the operator. loginServer: "" diff --git a/cmd/k8s-operator/generate/main.go b/cmd/k8s-operator/generate/main.go index 5fd5d551b..08bdc350d 100644 --- a/cmd/k8s-operator/generate/main.go +++ b/cmd/k8s-operator/generate/main.go @@ -69,7 +69,7 @@ func main() { }() log.Print("Templating Helm chart contents") helmTmplCmd := exec.Command("./tool/helm", "template", "operator", "./cmd/k8s-operator/deploy/chart", - "--namespace=tailscale") + "--namespace=tailscale", "--set=oauth.clientSecret=''") helmTmplCmd.Dir = repoRoot var out bytes.Buffer helmTmplCmd.Stdout = &out diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index cc97b1be2..d5ff07780 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -164,22 +164,24 @@ func main() { runReconcilers(rOpts) } -// initTSNet initializes the tsnet.Server and logs in to Tailscale. It uses the -// CLIENT_ID_FILE and CLIENT_SECRET_FILE environment variables to authenticate -// with Tailscale. +// initTSNet initializes the tsnet.Server and logs in to Tailscale. If CLIENT_ID +// is set, it authenticates to the Tailscale API using the federated OIDC workload +// identity flow. Otherwise, it uses the CLIENT_ID_FILE and CLIENT_SECRET_FILE +// environment variables to authenticate with static credentials. func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, tsClient) { var ( - clientIDPath = defaultEnv("CLIENT_ID_FILE", "") - clientSecretPath = defaultEnv("CLIENT_SECRET_FILE", "") + clientID = defaultEnv("CLIENT_ID", "") // Used for workload identity federation. + clientIDPath = defaultEnv("CLIENT_ID_FILE", "") // Used for static client credentials. + clientSecretPath = defaultEnv("CLIENT_SECRET_FILE", "") // Used for static client credentials. hostname = defaultEnv("OPERATOR_HOSTNAME", "tailscale-operator") kubeSecret = defaultEnv("OPERATOR_SECRET", "") operatorTags = defaultEnv("OPERATOR_INITIAL_TAGS", "tag:k8s-operator") ) startlog := zlog.Named("startup") - if clientIDPath == "" || clientSecretPath == "" { - startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") + if clientID == "" && (clientIDPath == "" || clientSecretPath == "") { + startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") // TODO(tomhjp): error message can mention WIF once it's publicly available. } - tsc, err := newTSClient(context.Background(), clientIDPath, clientSecretPath, loginServer) + tsc, err := newTSClient(zlog.Named("ts-api-client"), clientID, clientIDPath, clientSecretPath, loginServer) if err != nil { startlog.Fatalf("error creating Tailscale client: %v", err) } diff --git a/cmd/k8s-operator/tsclient.go b/cmd/k8s-operator/tsclient.go index 50620c26d..d22fa1797 100644 --- a/cmd/k8s-operator/tsclient.go +++ b/cmd/k8s-operator/tsclient.go @@ -8,8 +8,13 @@ package main import ( "context" "fmt" + "net/http" "os" + "sync" + "time" + "go.uber.org/zap" + "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" @@ -20,30 +25,53 @@ import ( // call should be performed on the default tailnet for the provided credentials. const ( defaultTailnet = "-" + oidcJWTPath = "/var/run/secrets/tailscale/serviceaccount/token" ) -func newTSClient(ctx context.Context, clientIDPath, clientSecretPath, loginServer string) (tsClient, error) { - clientID, err := os.ReadFile(clientIDPath) - if err != nil { - return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err) - } - clientSecret, err := os.ReadFile(clientSecretPath) - if err != nil { - return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err) - } - const tokenURLPath = "/api/v2/oauth/token" - tokenURL := fmt.Sprintf("%s%s", ipn.DefaultControlURL, tokenURLPath) +func newTSClient(logger *zap.SugaredLogger, clientID, clientIDPath, clientSecretPath, loginServer string) (*tailscale.Client, error) { + baseURL := ipn.DefaultControlURL if loginServer != "" { - tokenURL = fmt.Sprintf("%s%s", loginServer, tokenURLPath) + baseURL = loginServer } - credentials := clientcredentials.Config{ - ClientID: string(clientID), - ClientSecret: string(clientSecret), - TokenURL: tokenURL, + + var httpClient *http.Client + if clientID == "" { + // Use static client credentials mounted to disk. + id, err := os.ReadFile(clientIDPath) + if err != nil { + return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err) + } + secret, err := os.ReadFile(clientSecretPath) + if err != nil { + return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err) + } + credentials := clientcredentials.Config{ + ClientID: string(id), + ClientSecret: string(secret), + TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token"), + } + tokenSrc := credentials.TokenSource(context.Background()) + httpClient = oauth2.NewClient(context.Background(), tokenSrc) + } else { + // Use workload identity federation. + tokenSrc := &jwtTokenSource{ + logger: logger, + jwtPath: oidcJWTPath, + baseCfg: clientcredentials.Config{ + ClientID: clientID, + TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token-exchange"), + }, + } + httpClient = &http.Client{ + Transport: &oauth2.Transport{ + Source: tokenSrc, + }, + } } + c := tailscale.NewClient(defaultTailnet, nil) c.UserAgent = "tailscale-k8s-operator" - c.HTTPClient = credentials.Client(ctx) + c.HTTPClient = httpClient if loginServer != "" { c.BaseURL = loginServer } @@ -63,3 +91,43 @@ type tsClient interface { // DeleteVIPService is a method for deleting a Tailscale Service. DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error } + +// jwtTokenSource implements the [oauth2.TokenSource] interface, but with the +// ability to regenerate a fresh underlying token source each time a new value +// of the JWT parameter is needed due to expiration. +type jwtTokenSource struct { + logger *zap.SugaredLogger + jwtPath string // Path to the file containing an automatically refreshed JWT. + baseCfg clientcredentials.Config // Holds config that doesn't change for the lifetime of the process. + + mu sync.Mutex // Guards underlying. + underlying oauth2.TokenSource // The oauth2 client implementation. Does its own separate caching of the access token. +} + +func (s *jwtTokenSource) Token() (*oauth2.Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.underlying != nil { + t, err := s.underlying.Token() + if err == nil && t != nil && t.Valid() { + return t, nil + } + } + + s.logger.Debugf("Refreshing JWT from %s", s.jwtPath) + tk, err := os.ReadFile(s.jwtPath) + if err != nil { + return nil, fmt.Errorf("error reading JWT from %q: %w", s.jwtPath, err) + } + + // Shallow copy of the base config. + credentials := s.baseCfg + credentials.EndpointParams = map[string][]string{ + "jwt": {string(tk)}, + } + + src := credentials.TokenSource(context.Background()) + s.underlying = oauth2.ReuseTokenSourceWithExpiry(nil, src, time.Minute) + return s.underlying.Token() +} diff --git a/cmd/k8s-operator/tsclient_test.go b/cmd/k8s-operator/tsclient_test.go new file mode 100644 index 000000000..16de512d5 --- /dev/null +++ b/cmd/k8s-operator/tsclient_test.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + "golang.org/x/oauth2" +) + +func TestNewStaticClient(t *testing.T) { + const ( + clientIDFile = "client-id" + clientSecretFile = "client-secret" + ) + + tmp := t.TempDir() + clientIDPath := filepath.Join(tmp, clientIDFile) + if err := os.WriteFile(clientIDPath, []byte("test-client-id"), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", clientIDPath, err) + } + clientSecretPath := filepath.Join(tmp, clientSecretFile) + if err := os.WriteFile(clientSecretPath, []byte("test-client-secret"), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", clientSecretPath, err) + } + + srv := testAPI(t, 3600) + cl, err := newTSClient(zap.NewNop().Sugar(), "", clientIDPath, clientSecretPath, srv.URL) + if err != nil { + t.Fatalf("error creating Tailscale client: %v", err) + } + + resp, err := cl.HTTPClient.Get(srv.URL) + if err != nil { + t.Fatalf("error making test API call: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + want := "Bearer " + testToken("/api/v2/oauth/token", "test-client-id", "test-client-secret", "") + if string(got) != want { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestNewWorkloadIdentityClient(t *testing.T) { + // 5 seconds is within expiryDelta leeway, so the access token will + // immediately be considered expired and get refreshed on each access. + srv := testAPI(t, 5) + cl, err := newTSClient(zap.NewNop().Sugar(), "test-client-id", "", "", srv.URL) + if err != nil { + t.Fatalf("error creating Tailscale client: %v", err) + } + + // Modify the path where the JWT will be read from. + oauth2Transport, ok := cl.HTTPClient.Transport.(*oauth2.Transport) + if !ok { + t.Fatalf("expected oauth2.Transport, got %T", cl.HTTPClient.Transport) + } + jwtTokenSource, ok := oauth2Transport.Source.(*jwtTokenSource) + if !ok { + t.Fatalf("expected jwtTokenSource, got %T", oauth2Transport.Source) + } + tmp := t.TempDir() + jwtPath := filepath.Join(tmp, "token") + jwtTokenSource.jwtPath = jwtPath + + for _, jwt := range []string{"test-jwt", "updated-test-jwt"} { + if err := os.WriteFile(jwtPath, []byte(jwt), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", jwtPath, err) + } + resp, err := cl.HTTPClient.Get(srv.URL) + if err != nil { + t.Fatalf("error making test API call: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + if want := "Bearer " + testToken("/api/v2/oauth/token-exchange", "test-client-id", "", jwt); string(got) != want { + t.Errorf("got %q; want %q", got, want) + } + } +} + +func testAPI(t *testing.T, expirationSeconds int) *httptest.Server { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("test server got request: %s %s", r.Method, r.URL.Path) + switch r.URL.Path { + case "/api/v2/oauth/token", "/api/v2/oauth/token-exchange": + id, secret, ok := r.BasicAuth() + if !ok { + t.Fatal("missing or invalid basic auth") + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": testToken(r.URL.Path, id, secret, r.FormValue("jwt")), + "token_type": "Bearer", + "expires_in": expirationSeconds, + }); err != nil { + t.Fatalf("error writing response: %v", err) + } + case "/": + // Echo back the authz header for test assertions. + _, err := w.Write([]byte(r.Header.Get("Authorization"))) + if err != nil { + t.Fatalf("error writing response: %v", err) + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + t.Cleanup(srv.Close) + return srv +} + +func testToken(path, id, secret, jwt string) string { + return fmt.Sprintf("%s|%s|%s|%s", path, id, secret, jwt) +}