diff --git a/client/tailscale/localclient.go b/client/tailscale/localclient.go index 67bd0c5cf..53310e3d1 100644 --- a/client/tailscale/localclient.go +++ b/client/tailscale/localclient.go @@ -933,7 +933,20 @@ func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err e // // API maturity: this is considered a stable API. func (lc *LocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { - res, err := lc.send(ctx, "GET", "/localapi/v0/cert/"+domain+"?type=pair", 200, nil) + return lc.CertPairWithValidity(ctx, domain, 0) +} + +// CertPairWithValidity returns a cert and private key for the provided DNS +// domain. +// +// It returns a cached certificate from disk if it's still valid. +// When minValidity is non-zero, the returned certificate will be valid for at +// least the given duration, if permitted by the CA. If the certificate is +// valid, but for less than minValidity, it will be synchronously renewed. +// +// API maturity: this is considered a stable API. +func (lc *LocalClient) CertPairWithValidity(ctx context.Context, domain string, minValidity time.Duration) (certPEM, keyPEM []byte, err error) { + res, err := lc.send(ctx, "GET", fmt.Sprintf("/localapi/v0/cert/%s?type=pair&min_validity=%s", domain, minValidity), 200, nil) if err != nil { return nil, nil, err } diff --git a/cmd/tailscale/cli/cert.go b/cmd/tailscale/cli/cert.go index db0f057ce..9c8eca5b7 100644 --- a/cmd/tailscale/cli/cert.go +++ b/cmd/tailscale/cli/cert.go @@ -16,6 +16,7 @@ import ( "net/http" "os" "strings" + "time" "github.com/peterbourgon/ff/v3/ffcli" "software.sslmate.com/src/go-pkcs12" @@ -34,14 +35,16 @@ var certCmd = &ffcli.Command{ fs.StringVar(&certArgs.certFile, "cert-file", "", "output cert file or \"-\" for stdout; defaults to DOMAIN.crt if --cert-file and --key-file are both unset") fs.StringVar(&certArgs.keyFile, "key-file", "", "output key file or \"-\" for stdout; defaults to DOMAIN.key if --cert-file and --key-file are both unset") fs.BoolVar(&certArgs.serve, "serve-demo", false, "if true, serve on port :443 using the cert as a demo, instead of writing out the files to disk") + fs.DurationVar(&certArgs.minValidity, "min-validity", 0, "ensure the certificate is valid for at least this duration; the output certificate is never expired if this flag is unset or 0, but the lifetime may vary; the maximum allowed min-validity depends on the CA") return fs })(), } var certArgs struct { - certFile string - keyFile string - serve bool + certFile string + keyFile string + serve bool + minValidity time.Duration } func runCert(ctx context.Context, args []string) error { @@ -102,7 +105,7 @@ func runCert(ctx context.Context, args []string) error { certArgs.certFile = domain + ".crt" certArgs.keyFile = domain + ".key" } - certPEM, keyPEM, err := localClient.CertPair(ctx, domain) + certPEM, keyPEM, err := localClient.CertPairWithValidity(ctx, domain, certArgs.minValidity) if err != nil { return err } diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index 11ea05df3..d87374bbb 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -88,6 +88,17 @@ var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") // If a cert is expired, it will be renewed synchronously otherwise it will be // renewed asynchronously. func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { + return b.GetCertPEMWithValidity(ctx, domain, 0) +} + +// GetCertPEMWithValidity gets the TLSCertKeyPair for domain, either from cache +// or via the ACME process. ACME process is used for new domain certs, existing +// expired certs or existing certs that should get renewed sooner than +// minValidity. +// +// If a cert is expired, or expires sooner than minValidity, it will be renewed +// synchronously. Otherwise it will be renewed asynchronously. +func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string, minValidity time.Duration) (*TLSCertKeyPair, error) { if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } @@ -109,17 +120,28 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK if pair, err := getCertPEMCached(cs, domain, now); err == nil { // If we got here, we have a valid unexpired cert. // Check whether we should start an async renewal. - if shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair); err != nil { + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity) + if err != nil { logf("error checking for certificate renewal: %v", err) - } else if shouldRenew { + // Renewal check failed, but the current cert is valid and not + // expired, so it's safe to return. + return pair, nil + } + if !shouldRenew { + return pair, nil + } + if minValidity == 0 { logf("starting async renewal") - // Start renewal in the background. - go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) + // Start renewal in the background, return current valid cert. + go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) + return pair, nil } - return pair, nil + // If the caller requested a specific validity duration, fall through + // to synchronous renewal to fulfill that. + logf("starting sync renewal") } - pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) + pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -129,7 +151,14 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK // shouldStartDomainRenewal reports whether the domain's cert should be renewed // based on the current time, the cert's expiry, and the ARI check. -func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { +func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair, minValidity time.Duration) (bool, error) { + if minValidity != 0 { + cert, err := pair.parseCertificate() + if err != nil { + return false, fmt.Errorf("parsing certificate: %w", err) + } + return cert.NotAfter.Sub(now) < minValidity, nil + } renewMu.Lock() defer renewMu.Unlock() if renewAt, ok := renewCertAt[domain]; ok { @@ -157,11 +186,7 @@ func (b *LocalBackend) domainRenewed(domain string) { } func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) { - block, _ := pem.Decode(pair.CertPEM) - if block == nil { - return time.Time{}, fmt.Errorf("parsing certificate PEM") - } - cert, err := x509.ParseCertificate(block.Bytes) + cert, err := pair.parseCertificate() if err != nil { return time.Time{}, fmt.Errorf("parsing certificate: %w", err) } @@ -366,6 +391,17 @@ type TLSCertKeyPair struct { Cached bool // whether result came from cache } +func (kp TLSCertKeyPair) parseCertificate() (*x509.Certificate, error) { + block, _ := pem.Decode(kp.CertPEM) + if block == nil { + return nil, fmt.Errorf("error parsing certificate PEM") + } + if block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("PEM block is %q, not a CERTIFICATE", block.Type) + } + return x509.ParseCertificate(block.Bytes) +} + func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") } @@ -383,7 +419,7 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return cs.Read(domain, now) } -func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time) (*TLSCertKeyPair, error) { +func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() @@ -393,7 +429,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if p, err := getCertPEMCached(cs, domain, now); err == nil { // shouldStartDomainRenewal caches its result so it's OK to call this // frequently. - shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p) + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p, minValidity) if err != nil { logf("error checking for certificate renewal: %v", err) } else if !shouldRenew { diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index 447c3bc3c..323406f7b 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "strings" + "time" "tailscale.com/ipn/ipnlocal" ) @@ -23,7 +24,16 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal handler config wired wrong", 500) return } - pair, err := h.b.GetCertPEM(r.Context(), domain) + var minValidity time.Duration + if minValidityStr := r.URL.Query().Get("min_validity"); minValidityStr != "" { + var err error + minValidity, err = time.ParseDuration(minValidityStr) + if err != nil { + http.Error(w, fmt.Sprintf("invalid validity parameter: %v", err), http.StatusBadRequest) + return + } + } + pair, err := h.b.GetCertPEMWithValidity(r.Context(), domain, minValidity) if err != nil { // TODO(bradfitz): 500 is a little lazy here. The errors returned from // GetCertPEM (and everywhere) should carry info info to get whether