diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index ef19a6571..33bbac8df 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -101,11 +101,13 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK } if pair, err := getCertPEMCached(cs, domain, now); err == nil { - future := now.AddDate(0, 0, 14) - if b.shouldStartDomainRenewal(cs, domain, future) { + shouldRenew, err := shouldStartDomainRenewal(domain, now, pair) + if err != nil { + logf("error checking for certificate renewal: %v", err) + } else if shouldRenew { logf("starting async renewal") // Start renewal in the background. - go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, future) + go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) } return pair, nil } @@ -118,18 +120,41 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK return pair, nil } -func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, future time.Time) bool { +func shouldStartDomainRenewal(domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { renewMu.Lock() defer renewMu.Unlock() - now := time.Now() if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute { // We checked very recently. Don't bother reparsing & // validating the x509 cert. - return false + return false, nil } lastRenewCheck[domain] = now - _, err := getCertPEMCached(cs, domain, future) - return errors.Is(err, errCertExpired) + + block, _ := pem.Decode(pair.CertPEM) + if block == nil { + return false, fmt.Errorf("parsing certificate PEM") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false, fmt.Errorf("parsing certificate: %w", err) + } + + certLifetime := cert.NotAfter.Sub(cert.NotBefore) + if certLifetime < 0 { + return false, fmt.Errorf("negative certificate lifetime %v", certLifetime) + } + + // Per https://github.com/tailscale/tailscale/issues/8204, check + // whether we're more than 2/3 of the way through the certificate's + // lifetime, which is the officially-recommended best practice by Let's + // Encrypt. + renewalDuration := certLifetime * 2 / 3 + renewAt := cert.NotBefore.Add(renewalDuration) + + if now.After(renewAt) { + return true, nil + } + return false, nil } // certStore provides a way to perist and retrieve TLS certificates. diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index 6cc2f13c4..d29a68776 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -6,12 +6,19 @@ package ipnlocal import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/x509" + "crypto/x509/pkix" "embed" + "encoding/pem" + "math/big" "testing" "time" "github.com/google/go-cmp/cmp" + "golang.org/x/exp/maps" "tailscale.com/ipn/store/mem" ) @@ -100,3 +107,94 @@ func TestCertStoreRoundTrip(t *testing.T) { }) } } + +func TestShouldStartDomainRenewal(t *testing.T) { + reset := func() { + renewMu.Lock() + defer renewMu.Unlock() + maps.Clear(lastRenewCheck) + } + + mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + b, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + panic(err) + } + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: b, + }) + + return &TLSCertKeyPair{ + Cached: false, + CertPEM: certPEM, + KeyPEM: []byte("unused"), + } + } + + now := time.Unix(1685714838, 0) + subject := pkix.Name{ + Organization: []string{"Tailscale, Inc."}, + Country: []string{"CA"}, + Province: []string{"ON"}, + Locality: []string{"Toronto"}, + StreetAddress: []string{"290 Bremner Blvd"}, + PostalCode: []string{"M5V 3L9"}, + } + + testCases := []struct { + name string + notBefore time.Time + lifetime time.Duration + want bool + wantErr string + }{ + { + name: "should renew", + notBefore: now.AddDate(0, 0, -89), + lifetime: 90 * 24 * time.Hour, + want: true, + }, + { + name: "short-lived renewal", + notBefore: now.AddDate(0, 0, -7), + lifetime: 10 * 24 * time.Hour, + want: true, + }, + { + name: "no renew", + notBefore: now.AddDate(0, 0, -59), // 59 days ago == not 2/3rds of the way through 90 days yet + lifetime: 90 * 24 * time.Hour, + want: false, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + reset() + + ret, err := shouldStartDomainRenewal("example.com", now, mustMakePair(&x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: subject, + NotBefore: tt.notBefore, + NotAfter: tt.notBefore.Add(tt.lifetime), + })) + + if tt.wantErr != "" { + if err == nil { + t.Errorf("wanted error, got nil") + } else if err.Error() != tt.wantErr { + t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr) + } + } else { + if ret != tt.want { + t.Errorf("got ret=%v, want %v", ret, tt.want) + } + } + }) + } +}