From e7bf6e716bdd4938bf51bf6c17ec442637ecb95b Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Fri, 19 Jul 2024 11:35:22 -0500 Subject: [PATCH] cmd/tailscale: add --min-validity flag to the cert command (#12822) Some users run "tailscale cert" in a cron job to renew their certificates on disk. The time until the next cron job run may be long enough for the old cert to expire with our default heristics. Add a `--min-validity` flag which ensures that the returned cert is valid for at least the provided duration (unless it's longer than the cert lifetime set by Let's Encrypt). Updates #8725 Signed-off-by: Andrew Lytvynov --- client/tailscale/localclient.go | 15 +++++++- cmd/tailscale/cli/cert.go | 11 +++--- ipn/ipnlocal/cert.go | 64 +++++++++++++++++++++++++-------- ipn/localapi/cert.go | 12 ++++++- 4 files changed, 82 insertions(+), 20 deletions(-) diff --git a/client/tailscale/localclient.go b/client/tailscale/localclient.go index 67bd0c5cf..53310e3d1 100644 --- a/client/tailscale/localclient.go +++ b/client/tailscale/localclient.go @@ -933,7 +933,20 @@ func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err e // // API maturity: this is considered a stable API. func (lc *LocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { - res, err := lc.send(ctx, "GET", "/localapi/v0/cert/"+domain+"?type=pair", 200, nil) + return lc.CertPairWithValidity(ctx, domain, 0) +} + +// CertPairWithValidity returns a cert and private key for the provided DNS +// domain. +// +// It returns a cached certificate from disk if it's still valid. +// When minValidity is non-zero, the returned certificate will be valid for at +// least the given duration, if permitted by the CA. If the certificate is +// valid, but for less than minValidity, it will be synchronously renewed. +// +// API maturity: this is considered a stable API. +func (lc *LocalClient) CertPairWithValidity(ctx context.Context, domain string, minValidity time.Duration) (certPEM, keyPEM []byte, err error) { + res, err := lc.send(ctx, "GET", fmt.Sprintf("/localapi/v0/cert/%s?type=pair&min_validity=%s", domain, minValidity), 200, nil) if err != nil { return nil, nil, err } diff --git a/cmd/tailscale/cli/cert.go b/cmd/tailscale/cli/cert.go index db0f057ce..9c8eca5b7 100644 --- a/cmd/tailscale/cli/cert.go +++ b/cmd/tailscale/cli/cert.go @@ -16,6 +16,7 @@ import ( "net/http" "os" "strings" + "time" "github.com/peterbourgon/ff/v3/ffcli" "software.sslmate.com/src/go-pkcs12" @@ -34,14 +35,16 @@ var certCmd = &ffcli.Command{ fs.StringVar(&certArgs.certFile, "cert-file", "", "output cert file or \"-\" for stdout; defaults to DOMAIN.crt if --cert-file and --key-file are both unset") fs.StringVar(&certArgs.keyFile, "key-file", "", "output key file or \"-\" for stdout; defaults to DOMAIN.key if --cert-file and --key-file are both unset") fs.BoolVar(&certArgs.serve, "serve-demo", false, "if true, serve on port :443 using the cert as a demo, instead of writing out the files to disk") + fs.DurationVar(&certArgs.minValidity, "min-validity", 0, "ensure the certificate is valid for at least this duration; the output certificate is never expired if this flag is unset or 0, but the lifetime may vary; the maximum allowed min-validity depends on the CA") return fs })(), } var certArgs struct { - certFile string - keyFile string - serve bool + certFile string + keyFile string + serve bool + minValidity time.Duration } func runCert(ctx context.Context, args []string) error { @@ -102,7 +105,7 @@ func runCert(ctx context.Context, args []string) error { certArgs.certFile = domain + ".crt" certArgs.keyFile = domain + ".key" } - certPEM, keyPEM, err := localClient.CertPair(ctx, domain) + certPEM, keyPEM, err := localClient.CertPairWithValidity(ctx, domain, certArgs.minValidity) if err != nil { return err } diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index 11ea05df3..d87374bbb 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -88,6 +88,17 @@ var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") // If a cert is expired, it will be renewed synchronously otherwise it will be // renewed asynchronously. func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { + return b.GetCertPEMWithValidity(ctx, domain, 0) +} + +// GetCertPEMWithValidity gets the TLSCertKeyPair for domain, either from cache +// or via the ACME process. ACME process is used for new domain certs, existing +// expired certs or existing certs that should get renewed sooner than +// minValidity. +// +// If a cert is expired, or expires sooner than minValidity, it will be renewed +// synchronously. Otherwise it will be renewed asynchronously. +func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string, minValidity time.Duration) (*TLSCertKeyPair, error) { if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } @@ -109,17 +120,28 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK if pair, err := getCertPEMCached(cs, domain, now); err == nil { // If we got here, we have a valid unexpired cert. // Check whether we should start an async renewal. - if shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair); err != nil { + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity) + if err != nil { logf("error checking for certificate renewal: %v", err) - } else if shouldRenew { + // Renewal check failed, but the current cert is valid and not + // expired, so it's safe to return. + return pair, nil + } + if !shouldRenew { + return pair, nil + } + if minValidity == 0 { logf("starting async renewal") - // Start renewal in the background. - go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) + // Start renewal in the background, return current valid cert. + go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) + return pair, nil } - return pair, nil + // If the caller requested a specific validity duration, fall through + // to synchronous renewal to fulfill that. + logf("starting sync renewal") } - pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) + pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -129,7 +151,14 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK // shouldStartDomainRenewal reports whether the domain's cert should be renewed // based on the current time, the cert's expiry, and the ARI check. -func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { +func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair, minValidity time.Duration) (bool, error) { + if minValidity != 0 { + cert, err := pair.parseCertificate() + if err != nil { + return false, fmt.Errorf("parsing certificate: %w", err) + } + return cert.NotAfter.Sub(now) < minValidity, nil + } renewMu.Lock() defer renewMu.Unlock() if renewAt, ok := renewCertAt[domain]; ok { @@ -157,11 +186,7 @@ func (b *LocalBackend) domainRenewed(domain string) { } func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) { - block, _ := pem.Decode(pair.CertPEM) - if block == nil { - return time.Time{}, fmt.Errorf("parsing certificate PEM") - } - cert, err := x509.ParseCertificate(block.Bytes) + cert, err := pair.parseCertificate() if err != nil { return time.Time{}, fmt.Errorf("parsing certificate: %w", err) } @@ -366,6 +391,17 @@ type TLSCertKeyPair struct { Cached bool // whether result came from cache } +func (kp TLSCertKeyPair) parseCertificate() (*x509.Certificate, error) { + block, _ := pem.Decode(kp.CertPEM) + if block == nil { + return nil, fmt.Errorf("error parsing certificate PEM") + } + if block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("PEM block is %q, not a CERTIFICATE", block.Type) + } + return x509.ParseCertificate(block.Bytes) +} + func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") } @@ -383,7 +419,7 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return cs.Read(domain, now) } -func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time) (*TLSCertKeyPair, error) { +func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() @@ -393,7 +429,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if p, err := getCertPEMCached(cs, domain, now); err == nil { // shouldStartDomainRenewal caches its result so it's OK to call this // frequently. - shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p) + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p, minValidity) if err != nil { logf("error checking for certificate renewal: %v", err) } else if !shouldRenew { diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index 447c3bc3c..323406f7b 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "strings" + "time" "tailscale.com/ipn/ipnlocal" ) @@ -23,7 +24,16 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal handler config wired wrong", 500) return } - pair, err := h.b.GetCertPEM(r.Context(), domain) + var minValidity time.Duration + if minValidityStr := r.URL.Query().Get("min_validity"); minValidityStr != "" { + var err error + minValidity, err = time.ParseDuration(minValidityStr) + if err != nil { + http.Error(w, fmt.Sprintf("invalid validity parameter: %v", err), http.StatusBadRequest) + return + } + } + pair, err := h.b.GetCertPEMWithValidity(r.Context(), domain, minValidity) if err != nil { // TODO(bradfitz): 500 is a little lazy here. The errors returned from // GetCertPEM (and everywhere) should carry info info to get whether