diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index 1b3264253..ff29ffc39 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -82,11 +82,6 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK return nil, errors.New("invalid domain") } logf := logger.WithPrefix(b.logf, fmt.Sprintf("cert(%q): ", domain)) - dir, err := b.certDir() - if err != nil { - logf("failed to get certDir: %v", err) - return nil, err - } now := time.Now() traceACME := func(v any) { if !acmeDebug() { @@ -96,17 +91,22 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK log.Printf("acme %T: %s", v, j) } - if pair, err := b.getCertPEMCached(dir, domain, now); err == nil { + cs, err := b.getCertStore() + if err != nil { + return nil, err + } + + if pair, err := getCertPEMCached(cs, domain, now); err == nil { future := now.AddDate(0, 0, 14) - if b.shouldStartDomainRenewal(dir, domain, future) { + if b.shouldStartDomainRenewal(cs, domain, future) { logf("starting async renewal") // Start renewal in the background. - go b.getCertPEM(context.Background(), logf, traceACME, dir, domain, future) + go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, future) } return pair, nil } - pair, err := b.getCertPEM(ctx, logf, traceACME, dir, domain, now) + pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -114,7 +114,7 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK return pair, nil } -func (b *LocalBackend) shouldStartDomainRenewal(dir, domain string, future time.Time) bool { +func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, future time.Time) bool { renewMu.Lock() defer renewMu.Unlock() now := time.Now() @@ -124,7 +124,7 @@ func (b *LocalBackend) shouldStartDomainRenewal(dir, domain string, future time. return false } lastRenewCheck[domain] = now - _, err := b.getCertPEMCached(dir, domain, future) + _, err := getCertPEMCached(cs, domain, future) return errors.Is(err, errCertExpired) } @@ -140,15 +140,24 @@ type certStore interface { WriteCert(domain string, cert []byte) error // WriteKey writes the key for domain. WriteKey(domain string, key []byte) error + // ACMEKey returns the value previously stored via WriteACMEKey. + // It is a PEM encoded ECDSA key. + ACMEKey() ([]byte, error) + // WriteACMEKey stores the provided PEM encoded ECDSA key. + WriteACMEKey([]byte) error } var errCertExpired = errors.New("cert expired") -func (b *LocalBackend) getCertStore(dir string) certStore { +func (b *LocalBackend) getCertStore() (certStore, error) { + dir, err := b.certDir() + if err != nil { + return nil, err + } if hostinfo.GetEnvType() == hostinfo.Kubernetes && dir == "/tmp" { - return certStateStore{StateStore: b.store} + return certStateStore{StateStore: b.store}, nil } - return certFileStore{dir: dir} + return certFileStore{dir: dir}, nil } // certFileStore implements certStore by storing the cert & key files in the named directory. @@ -160,6 +169,25 @@ type certFileStore struct { testRoots *x509.CertPool } +const acmePEMName = "acme-account.key.pem" + +func (f certFileStore) ACMEKey() ([]byte, error) { + pemName := filepath.Join(f.dir, acmePEMName) + v, err := os.ReadFile(pemName) + if err != nil { + if os.IsNotExist(err) { + return nil, ipn.ErrStateNotExist + } + return nil, err + } + return v, nil +} + +func (f certFileStore) WriteACMEKey(b []byte) error { + pemName := filepath.Join(f.dir, acmePEMName) + return os.WriteFile(pemName, b, 0600) +} + func (f certFileStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) { certPEM, err := os.ReadFile(certFile(f.dir, domain)) if err != nil { @@ -221,6 +249,14 @@ func (s certStateStore) WriteKey(domain string, key []byte) error { return s.WriteState(ipn.StateKey(domain+".key"), key) } +func (s certStateStore) ACMEKey() ([]byte, error) { + return s.ReadState(ipn.StateKey(acmePEMName)) +} + +func (s certStateStore) WriteACMEKey(key []byte) error { + return s.WriteState(ipn.StateKey(acmePEMName), key) +} + // TLSCertKeyPair is a TLS public and private key, and whether they were obtained // from cache or freshly obtained. type TLSCertKeyPair struct { @@ -236,26 +272,26 @@ func certFile(dir, domain string) string { return filepath.Join(dir, domain+".cr // domain exists on disk in dir that is valid at the provided now time. // If the keypair is expired, it returns errCertExpired. // If the keypair doesn't exist, it returns ipn.ErrStateNotExist. -func (b *LocalBackend) getCertPEMCached(dir, domain string, now time.Time) (p *TLSCertKeyPair, err error) { +func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKeyPair, err error) { if !validLookingCertDomain(domain) { // Before we read files from disk using it, validate it's halfway // reasonable looking. return nil, fmt.Errorf("invalid domain %q", domain) } - return b.getCertStore(dir).Read(domain, now) + return cs.Read(domain, now) } -func (b *LocalBackend) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*TLSCertKeyPair, error) { +func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() - if p, err := b.getCertPEMCached(dir, domain, now); err == nil { + if p, err := getCertPEMCached(cs, domain, now); err == nil { return p, nil } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { return nil, err } - key, err := acmeKey(dir) + key, err := acmeKey(cs) if err != nil { return nil, fmt.Errorf("acmeKey: %w", err) } @@ -366,8 +402,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, logf logger.Logf, traceAC if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil { return nil, err } - certStore := b.getCertStore(dir) - if err := certStore.WriteKey(domain, privPEM.Bytes()); err != nil { + if err := cs.WriteKey(domain, privPEM.Bytes()); err != nil { return nil, err } @@ -390,7 +425,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, logf logger.Logf, traceAC return nil, err } } - if err := certStore.WriteCert(domain, certPEM.Bytes()); err != nil { + if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { return nil, err } @@ -444,14 +479,15 @@ func parsePrivateKey(der []byte) (crypto.Signer, error) { return nil, errors.New("acme/autocert: failed to parse private key") } -func acmeKey(dir string) (crypto.Signer, error) { - pemName := filepath.Join(dir, "acme-account.key.pem") - if v, err := os.ReadFile(pemName); err == nil { +func acmeKey(cs certStore) (crypto.Signer, error) { + if v, err := cs.ACMEKey(); err == nil { priv, _ := pem.Decode(v) if priv == nil || !strings.Contains(priv.Type, "PRIVATE") { return nil, errors.New("acme/autocert: invalid account key found in cache") } return parsePrivateKey(priv.Bytes) + } else if err != nil && !errors.Is(err, ipn.ErrStateNotExist) { + return nil, err } privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -462,7 +498,7 @@ func acmeKey(dir string) (crypto.Signer, error) { if err := encodeECDSAKey(&pemBuf, privKey); err != nil { return nil, err } - if err := os.WriteFile(pemName, pemBuf.Bytes(), 0600); err != nil { + if err := cs.WriteACMEKey(pemBuf.Bytes()); err != nil { return nil, err } return privKey, nil