From 24f322bc43cd0aa6f9492c2d03b3c0d330b0cc3b Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Thu, 12 Oct 2023 15:52:41 -0700 Subject: [PATCH] ipn/ipnlocal: do unexpired cert renewals in the background We were eagerly doing a synchronous renewal of the cert while trying to serve traffic. Instead of that, just do the cert renewal in the background and continue serving traffic as long as the cert is still valid. This regressed in c1ecae13ab708cef90905085f87729974f6c339d when we introduced ARI support and were trying to make the experience of `tailscale cert` better. However, that ended up regressing the experience for tsnet as it would not always doing the renewal synchronously. Fixes #9783 Signed-off-by: Maisem Ali --- ipn/ipnlocal/cert.go | 27 +++++++++++++-------------- ipn/ipnlocal/cert_js.go | 2 +- ipn/ipnlocal/serve.go | 4 ++-- ipn/localapi/cert.go | 2 +- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index de5dee3e9..1e3ee94ad 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -84,11 +84,9 @@ var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") // ACME process. ACME process is used for new domain certs, existing expired // certs or existing certs that should get renewed due to upcoming expiry. // -// syncRenewal changes renewal behavior for existing certs that are still valid -// but need renewal. When syncRenewal is set, the method blocks until a new -// cert is issued. When syncRenewal is not set, existing cert is returned right -// away and renewal is kicked off in a background goroutine. -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { +// If a cert is expired, it will be renewed synchronously otherwise it will be +// renewed asynchronously. +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } @@ -108,18 +106,16 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewa } if pair, err := getCertPEMCached(cs, domain, now); err == nil { - shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair) - if err != nil { + // If we got here, we have a valid unexpired cert. + // Check whether we should start an async renewal. + if shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair); err != nil { logf("error checking for certificate renewal: %v", err) - } else if !shouldRenew { - return pair, nil - } - if !syncRenewal { + } else if shouldRenew { logf("starting async renewal") // Start renewal in the background. go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) } - // Synchronous renewal happens below. + return pair, nil } pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) @@ -130,6 +126,8 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewa return pair, nil } +// shouldStartDomainRenewal reports whether the domain's cert should be renewed +// based on the current time, the cert's expiry, and the ARI check. func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { renewMu.Lock() defer renewMu.Unlock() @@ -365,8 +363,9 @@ type TLSCertKeyPair struct { func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") } -// getCertPEMCached returns a non-nil keyPair and true if a cached keypair for -// domain exists on disk in dir that is valid at the provided now time. +// getCertPEMCached returns a non-nil keyPair if a cached keypair for domain +// exists on disk in dir that is valid at the provided now time. +// // If the keypair is expired, it returns errCertExpired. // If the keypair doesn't exist, it returns ipn.ErrStateNotExist. func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKeyPair, err error) { diff --git a/ipn/ipnlocal/cert_js.go b/ipn/ipnlocal/cert_js.go index 24defb47b..a5fdfc4ba 100644 --- a/ipn/ipnlocal/cert_js.go +++ b/ipn/ipnlocal/cert_js.go @@ -12,6 +12,6 @@ type TLSCertKeyPair struct { CertPEM, KeyPEM []byte } -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { return nil, errors.New("not implemented for js/wasm") } diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index f5c416327..c3c6a7eeb 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -451,7 +451,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, sni, false) + pair, err := b.GetCertPEM(ctx, sni) if err != nil { return nil, err } @@ -757,7 +757,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, hi.ServerName, false) + pair, err := b.GetCertPEM(ctx, hi.ServerName) if err != nil { return nil, err } diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index e1704cb49..447c3bc3c 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal handler config wired wrong", 500) return } - pair, err := h.b.GetCertPEM(r.Context(), domain, true) + pair, err := h.b.GetCertPEM(r.Context(), domain) if err != nil { // TODO(bradfitz): 500 is a little lazy here. The errors returned from // GetCertPEM (and everywhere) should carry info info to get whether