diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index 6e281d072..afd7ba37b 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -34,6 +34,7 @@ import ( "golang.org/x/crypto/acme" "tailscale.com/envknob" + "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnstate" "tailscale.com/types/logger" "tailscale.com/util/strs" @@ -79,13 +80,6 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "cert access denied", http.StatusForbidden) return } - dir, err := h.certDir() - if err != nil { - h.logf("certDir: %v", err) - http.Error(w, "failed to get cert dir", 500) - return - } - domain, ok := strs.CutPrefix(r.URL.Path, "/localapi/v0/cert/") if !ok { http.Error(w, "internal handler config wired wrong", 500) @@ -95,8 +89,24 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid domain", 400) return } - now := time.Now() + pair, err := h.getCertPEM(r.Context(), domain) + if err != nil { + http.Error(w, fmt.Sprint(err), 500) + return + } + serveKeyPair(w, r, pair) +} + +// getCertPEM gets the KeyPair for domain, either from cache, via the ACME +// process, or from cache and kicking off an async ACME renewal. +func (h *Handler) getCertPEM(ctx context.Context, domain string) (*keyPair, error) { logf := logger.WithPrefix(h.logf, fmt.Sprintf("cert(%q): ", domain)) + dir, err := h.certDir() + if err != nil { + logf("failed to get certDir: %v", err) + return nil, err + } + now := time.Now() traceACME := func(v any) { if !acmeDebug() { return @@ -105,24 +115,22 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { log.Printf("acme %T: %s", v, j) } - if pair, ok := h.getCertPEMCached(dir, domain, now); ok { + if pair, ok := getCertPEMCached(dir, domain, now); ok { future := now.AddDate(0, 0, 14) if h.shouldStartDomainRenewal(dir, domain, future) { logf("starting async renewal") // Start renewal in the background. - go h.getCertPEM(context.Background(), logf, traceACME, dir, domain, future) + go getCertPEM(context.Background(), h.b, logf, traceACME, dir, domain, future) } - serveKeyPair(w, r, pair) - return + return pair, nil } - pair, err := h.getCertPEM(r.Context(), logf, traceACME, dir, domain, now) + pair, err := getCertPEM(ctx, h.b, logf, traceACME, dir, domain, now) if err != nil { logf("getCertPEM: %v", err) - http.Error(w, fmt.Sprint(err), 500) - return + return nil, err } - serveKeyPair(w, r, pair) + return pair, nil } func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time) bool { @@ -135,7 +143,7 @@ func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time) return false } lastRenewCheck[domain] = now - _, ok := h.getCertPEMCached(dir, domain, future) + _, ok := getCertPEMCached(dir, domain, future) return !ok } @@ -154,10 +162,12 @@ func serveKeyPair(w http.ResponseWriter, r *http.Request, p *keyPair) { } } +// keyPair is a TLS public and private key, and whether they were obtained +// from cache or freshly obtained. type keyPair struct { - certPEM []byte - keyPEM []byte - cached bool + certPEM []byte // public key, in PEM form + keyPEM []byte // private key, in PEM form + cached bool // whether result came from cache } func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } @@ -166,7 +176,7 @@ func certFile(dir, domain string) string { return filepath.Join(dir, domain+".cr // getCertPEMCached returns a non-nil keyPair and true if a cached // keypair for domain exists on disk in dir that is valid at the // provided now time. -func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) { +func getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) { if !validLookingCertDomain(domain) { // Before we read files from disk using it, validate it's halfway // reasonable looking. @@ -181,11 +191,11 @@ func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPai return nil, false } -func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) { +func getCertPEM(ctx context.Context, lb *ipnlocal.LocalBackend, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() - if p, ok := h.getCertPEMCached(dir, domain, now); ok { + if p, ok := getCertPEMCached(dir, domain, now); ok { return p, nil } @@ -223,7 +233,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu } // Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for. - st := h.b.StatusWithoutPeers() + st := lb.StatusWithoutPeers() if err := checkCertDomain(st, domain); err != nil { return nil, err } @@ -260,7 +270,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu } if !ok { logf("starting SetDNS call...") - err = h.b.SetDNS(ctx, key, rec) + err = lb.SetDNS(ctx, key, rec) if err != nil { return nil, fmt.Errorf("SetDNS %q => %q: %w", key, rec, err) }