diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index b074d9fb3..dc382fa25 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -356,7 +356,11 @@ func tailscaledArgs(cfg *settings) []string { args := []string{"--socket=" + cfg.Socket} switch { case cfg.InKubernetes && cfg.KubeSecret != "": - args = append(args, "--state=kube:"+cfg.KubeSecret, "--statedir=/tmp") + args = append(args, "--state=kube:"+cfg.KubeSecret) + if cfg.StateDir == "" { + cfg.StateDir = "/tmp" + } + fallthrough case cfg.StateDir != "": args = append(args, "--statedir="+cfg.StateDir) default: diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index fb2ffc12c..b30878563 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -32,6 +32,8 @@ import ( "golang.org/x/crypto/acme" "tailscale.com/envknob" + "tailscale.com/hostinfo" + "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/types/logger" "tailscale.com/version" @@ -94,9 +96,9 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK log.Printf("acme %T: %s", v, j) } - if pair, ok := getCertPEMCached(dir, domain, now); ok { + if pair, err := b.getCertPEMCached(dir, domain, now); err == nil { future := now.AddDate(0, 0, 14) - if shouldStartDomainRenewal(dir, domain, future) { + if b.shouldStartDomainRenewal(dir, domain, future) { logf("starting async renewal") // Start renewal in the background. go b.getCertPEM(context.Background(), logf, traceACME, dir, domain, future) @@ -112,7 +114,7 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK return pair, nil } -func shouldStartDomainRenewal(dir, domain string, future time.Time) bool { +func (b *LocalBackend) shouldStartDomainRenewal(dir, domain string, future time.Time) bool { renewMu.Lock() defer renewMu.Unlock() now := time.Now() @@ -122,8 +124,91 @@ func shouldStartDomainRenewal(dir, domain string, future time.Time) bool { return false } lastRenewCheck[domain] = now - _, ok := getCertPEMCached(dir, domain, future) - return !ok + _, err := b.getCertPEMCached(dir, domain, future) + return errors.Is(err, errCertExpired) +} + +// certStore provides a way to perist and retrieve TLS certificates. +// As of 2023-02-01, we use store certs in directories on disk everywhere +// except on Kubernetes, where we use the state store. +type certStore interface { + // Read returns the cert and key for domain, if they exist and are valid + // for now. If they're expired, it returns errCertExpired. + // If they don't exist, it returns ipn.ErrStateNotExist. + Read(domain string, now time.Time) (*TLSCertKeyPair, error) + // WriteCert writes the cert for domain. + WriteCert(domain string, cert []byte) error + // WriteKey writes the key for domain. + WriteKey(domain string, key []byte) error +} + +var errCertExpired = errors.New("cert expired") + +func (b *LocalBackend) getCertStore(dir string) certStore { + if hostinfo.GetEnvType() == hostinfo.Kubernetes && dir == "/tmp" { + return certStateStore{b.store} + } + return certFileStore(dir) +} + +// certFileStore implements certStore by storing the cert & key files in the named directory. +type certFileStore string // dir + +func (f certFileStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) { + certPEM, err := os.ReadFile(keyFile(string(f), domain)) + if err != nil { + if os.IsNotExist(err) { + return nil, ipn.ErrStateNotExist + } + return nil, err + } + keyPEM, err := os.ReadFile(certFile(string(f), domain)) + if err != nil { + if os.IsNotExist(err) { + return nil, ipn.ErrStateNotExist + } + return nil, err + } + if !validCertPEM(domain, keyPEM, certPEM, now) { + return nil, errCertExpired + } + return &TLSCertKeyPair{CertPEM: certPEM, KeyPEM: keyPEM, Cached: true}, nil +} + +func (f certFileStore) WriteCert(domain string, cert []byte) error { + return os.WriteFile(keyFile(string(f), domain), cert, 0644) +} + +func (f certFileStore) WriteKey(domain string, key []byte) error { + return os.WriteFile(keyFile(string(f), domain), key, 0600) +} + +// certStateStore implements certStore by storing the cert & key files in an ipn.StateStore. +type certStateStore struct { + ipn.StateStore +} + +func (s certStateStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) { + certPEM, err := s.ReadState(ipn.StateKey(domain + ".crt")) + if err != nil { + return nil, err + } + keyPEM, err := s.ReadState(ipn.StateKey(domain + ".key")) + if err != nil { + return nil, err + } + if !validCertPEM(domain, keyPEM, certPEM, now) { + return nil, errCertExpired + } + return &TLSCertKeyPair{CertPEM: certPEM, KeyPEM: keyPEM, Cached: true}, nil +} + +func (s certStateStore) WriteCert(domain string, cert []byte) error { + return s.WriteState(ipn.StateKey(domain+".crt"), cert) +} + +func (s certStateStore) WriteKey(domain string, key []byte) error { + return s.WriteState(ipn.StateKey(domain+".key"), key) } // TLSCertKeyPair is a TLS public and private key, and whether they were obtained @@ -137,30 +222,27 @@ type TLSCertKeyPair struct { func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") } -// getCertPEMCached returns a non-nil keyPair and true if a cached -// keypair for domain exists on disk in dir that is valid at the -// provided now time. -func getCertPEMCached(dir, domain string, now time.Time) (p *TLSCertKeyPair, ok bool) { +// getCertPEMCached returns a non-nil keyPair and true if a cached keypair for +// domain exists on disk in dir that is valid at the provided now time. +// If the keypair is expired, it returns errCertExpired. +// If the keypair doesn't exist, it returns ipn.ErrStateNotExist. +func (b *LocalBackend) getCertPEMCached(dir, domain string, now time.Time) (p *TLSCertKeyPair, err error) { if !validLookingCertDomain(domain) { // Before we read files from disk using it, validate it's halfway // reasonable looking. - return nil, false + return nil, fmt.Errorf("invalid domain %q", domain) } - if keyPEM, err := os.ReadFile(keyFile(dir, domain)); err == nil { - certPEM, _ := os.ReadFile(certFile(dir, domain)) - if validCertPEM(domain, keyPEM, certPEM, now) { - return &TLSCertKeyPair{CertPEM: certPEM, KeyPEM: keyPEM, Cached: true}, true - } - } - return nil, false + return b.getCertStore(dir).Read(domain, now) } func (b *LocalBackend) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() - if p, ok := getCertPEMCached(dir, domain, now); ok { + if p, err := b.getCertPEMCached(dir, domain, now); err == nil { return p, nil + } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { + return nil, err } key, err := acmeKey(dir) @@ -274,7 +356,8 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, logf logger.Logf, traceAC if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil { return nil, err } - if err := os.WriteFile(keyFile(dir, domain), privPEM.Bytes(), 0600); err != nil { + certStore := b.getCertStore(dir) + if err := certStore.WriteKey(domain, privPEM.Bytes()); err != nil { return nil, err } @@ -297,7 +380,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, logf logger.Logf, traceAC return nil, err } } - if err := os.WriteFile(certFile(dir, domain), certPEM.Bytes(), 0644); err != nil { + if err := certStore.WriteCert(domain, certPEM.Bytes()); err != nil { return nil, err }