diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index de5dee3e9..1e3ee94ad 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -84,11 +84,9 @@ var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") // ACME process. ACME process is used for new domain certs, existing expired // certs or existing certs that should get renewed due to upcoming expiry. // -// syncRenewal changes renewal behavior for existing certs that are still valid -// but need renewal. When syncRenewal is set, the method blocks until a new -// cert is issued. When syncRenewal is not set, existing cert is returned right -// away and renewal is kicked off in a background goroutine. -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { +// If a cert is expired, it will be renewed synchronously otherwise it will be +// renewed asynchronously. +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } @@ -108,18 +106,16 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewa } if pair, err := getCertPEMCached(cs, domain, now); err == nil { - shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair) - if err != nil { + // If we got here, we have a valid unexpired cert. + // Check whether we should start an async renewal. + if shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair); err != nil { logf("error checking for certificate renewal: %v", err) - } else if !shouldRenew { - return pair, nil - } - if !syncRenewal { + } else if shouldRenew { logf("starting async renewal") // Start renewal in the background. go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) } - // Synchronous renewal happens below. + return pair, nil } pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) @@ -130,6 +126,8 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewa return pair, nil } +// shouldStartDomainRenewal reports whether the domain's cert should be renewed +// based on the current time, the cert's expiry, and the ARI check. func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { renewMu.Lock() defer renewMu.Unlock() @@ -365,8 +363,9 @@ type TLSCertKeyPair struct { func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") } -// getCertPEMCached returns a non-nil keyPair and true if a cached keypair for -// domain exists on disk in dir that is valid at the provided now time. +// getCertPEMCached returns a non-nil keyPair if a cached keypair for domain +// exists on disk in dir that is valid at the provided now time. +// // If the keypair is expired, it returns errCertExpired. // If the keypair doesn't exist, it returns ipn.ErrStateNotExist. func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKeyPair, err error) { diff --git a/ipn/ipnlocal/cert_js.go b/ipn/ipnlocal/cert_js.go index 24defb47b..a5fdfc4ba 100644 --- a/ipn/ipnlocal/cert_js.go +++ b/ipn/ipnlocal/cert_js.go @@ -12,6 +12,6 @@ type TLSCertKeyPair struct { CertPEM, KeyPEM []byte } -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { return nil, errors.New("not implemented for js/wasm") } diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index f5c416327..c3c6a7eeb 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -451,7 +451,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, sni, false) + pair, err := b.GetCertPEM(ctx, sni) if err != nil { return nil, err } @@ -757,7 +757,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, hi.ServerName, false) + pair, err := b.GetCertPEM(ctx, hi.ServerName) if err != nil { return nil, err } diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index e1704cb49..447c3bc3c 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal handler config wired wrong", 500) return } - pair, err := h.b.GetCertPEM(r.Context(), domain, true) + pair, err := h.b.GetCertPEM(r.Context(), domain) if err != nil { // TODO(bradfitz): 500 is a little lazy here. The errors returned from // GetCertPEM (and everywhere) should carry info info to get whether