cmd/tailscale: add --min-validity flag to the cert command (#12822)

Some users run "tailscale cert" in a cron job to renew their
certificates on disk. The time until the next cron job run may be long
enough for the old cert to expire with our default heristics.

Add a `--min-validity` flag which ensures that the returned cert is
valid for at least the provided duration (unless it's longer than the
cert lifetime set by Let's Encrypt).

Updates #8725

Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
pull/12865/head
Andrew Lytvynov 4 months ago committed by GitHub
parent 32ce18716b
commit e7bf6e716b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -933,7 +933,20 @@ func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err e
//
// API maturity: this is considered a stable API.
func (lc *LocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) {
res, err := lc.send(ctx, "GET", "/localapi/v0/cert/"+domain+"?type=pair", 200, nil)
return lc.CertPairWithValidity(ctx, domain, 0)
}
// CertPairWithValidity returns a cert and private key for the provided DNS
// domain.
//
// It returns a cached certificate from disk if it's still valid.
// When minValidity is non-zero, the returned certificate will be valid for at
// least the given duration, if permitted by the CA. If the certificate is
// valid, but for less than minValidity, it will be synchronously renewed.
//
// API maturity: this is considered a stable API.
func (lc *LocalClient) CertPairWithValidity(ctx context.Context, domain string, minValidity time.Duration) (certPEM, keyPEM []byte, err error) {
res, err := lc.send(ctx, "GET", fmt.Sprintf("/localapi/v0/cert/%s?type=pair&min_validity=%s", domain, minValidity), 200, nil)
if err != nil {
return nil, nil, err
}

@ -16,6 +16,7 @@ import (
"net/http"
"os"
"strings"
"time"
"github.com/peterbourgon/ff/v3/ffcli"
"software.sslmate.com/src/go-pkcs12"
@ -34,14 +35,16 @@ var certCmd = &ffcli.Command{
fs.StringVar(&certArgs.certFile, "cert-file", "", "output cert file or \"-\" for stdout; defaults to DOMAIN.crt if --cert-file and --key-file are both unset")
fs.StringVar(&certArgs.keyFile, "key-file", "", "output key file or \"-\" for stdout; defaults to DOMAIN.key if --cert-file and --key-file are both unset")
fs.BoolVar(&certArgs.serve, "serve-demo", false, "if true, serve on port :443 using the cert as a demo, instead of writing out the files to disk")
fs.DurationVar(&certArgs.minValidity, "min-validity", 0, "ensure the certificate is valid for at least this duration; the output certificate is never expired if this flag is unset or 0, but the lifetime may vary; the maximum allowed min-validity depends on the CA")
return fs
})(),
}
var certArgs struct {
certFile string
keyFile string
serve bool
certFile string
keyFile string
serve bool
minValidity time.Duration
}
func runCert(ctx context.Context, args []string) error {
@ -102,7 +105,7 @@ func runCert(ctx context.Context, args []string) error {
certArgs.certFile = domain + ".crt"
certArgs.keyFile = domain + ".key"
}
certPEM, keyPEM, err := localClient.CertPair(ctx, domain)
certPEM, keyPEM, err := localClient.CertPairWithValidity(ctx, domain, certArgs.minValidity)
if err != nil {
return err
}

@ -88,6 +88,17 @@ var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
// If a cert is expired, it will be renewed synchronously otherwise it will be
// renewed asynchronously.
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
return b.GetCertPEMWithValidity(ctx, domain, 0)
}
// GetCertPEMWithValidity gets the TLSCertKeyPair for domain, either from cache
// or via the ACME process. ACME process is used for new domain certs, existing
// expired certs or existing certs that should get renewed sooner than
// minValidity.
//
// If a cert is expired, or expires sooner than minValidity, it will be renewed
// synchronously. Otherwise it will be renewed asynchronously.
func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string, minValidity time.Duration) (*TLSCertKeyPair, error) {
if !validLookingCertDomain(domain) {
return nil, errors.New("invalid domain")
}
@ -109,17 +120,28 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
if pair, err := getCertPEMCached(cs, domain, now); err == nil {
// If we got here, we have a valid unexpired cert.
// Check whether we should start an async renewal.
if shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair); err != nil {
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity)
if err != nil {
logf("error checking for certificate renewal: %v", err)
} else if shouldRenew {
// Renewal check failed, but the current cert is valid and not
// expired, so it's safe to return.
return pair, nil
}
if !shouldRenew {
return pair, nil
}
if minValidity == 0 {
logf("starting async renewal")
// Start renewal in the background.
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
// Start renewal in the background, return current valid cert.
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity)
return pair, nil
}
return pair, nil
// If the caller requested a specific validity duration, fall through
// to synchronous renewal to fulfill that.
logf("starting sync renewal")
}
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity)
if err != nil {
logf("getCertPEM: %v", err)
return nil, err
@ -129,7 +151,14 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
// shouldStartDomainRenewal reports whether the domain's cert should be renewed
// based on the current time, the cert's expiry, and the ARI check.
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair, minValidity time.Duration) (bool, error) {
if minValidity != 0 {
cert, err := pair.parseCertificate()
if err != nil {
return false, fmt.Errorf("parsing certificate: %w", err)
}
return cert.NotAfter.Sub(now) < minValidity, nil
}
renewMu.Lock()
defer renewMu.Unlock()
if renewAt, ok := renewCertAt[domain]; ok {
@ -157,11 +186,7 @@ func (b *LocalBackend) domainRenewed(domain string) {
}
func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) {
block, _ := pem.Decode(pair.CertPEM)
if block == nil {
return time.Time{}, fmt.Errorf("parsing certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
cert, err := pair.parseCertificate()
if err != nil {
return time.Time{}, fmt.Errorf("parsing certificate: %w", err)
}
@ -366,6 +391,17 @@ type TLSCertKeyPair struct {
Cached bool // whether result came from cache
}
func (kp TLSCertKeyPair) parseCertificate() (*x509.Certificate, error) {
block, _ := pem.Decode(kp.CertPEM)
if block == nil {
return nil, fmt.Errorf("error parsing certificate PEM")
}
if block.Type != "CERTIFICATE" {
return nil, fmt.Errorf("PEM block is %q, not a CERTIFICATE", block.Type)
}
return x509.ParseCertificate(block.Bytes)
}
func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") }
func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") }
@ -383,7 +419,7 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey
return cs.Read(domain, now)
}
func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time) (*TLSCertKeyPair, error) {
func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
acmeMu.Lock()
defer acmeMu.Unlock()
@ -393,7 +429,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
if p, err := getCertPEMCached(cs, domain, now); err == nil {
// shouldStartDomainRenewal caches its result so it's OK to call this
// frequently.
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p)
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p, minValidity)
if err != nil {
logf("error checking for certificate renewal: %v", err)
} else if !shouldRenew {

@ -9,6 +9,7 @@ import (
"fmt"
"net/http"
"strings"
"time"
"tailscale.com/ipn/ipnlocal"
)
@ -23,7 +24,16 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal handler config wired wrong", 500)
return
}
pair, err := h.b.GetCertPEM(r.Context(), domain)
var minValidity time.Duration
if minValidityStr := r.URL.Query().Get("min_validity"); minValidityStr != "" {
var err error
minValidity, err = time.ParseDuration(minValidityStr)
if err != nil {
http.Error(w, fmt.Sprintf("invalid validity parameter: %v", err), http.StatusBadRequest)
return
}
}
pair, err := h.b.GetCertPEMWithValidity(r.Context(), domain, minValidity)
if err != nil {
// TODO(bradfitz): 500 is a little lazy here. The errors returned from
// GetCertPEM (and everywhere) should carry info info to get whether

Loading…
Cancel
Save