@ -34,6 +34,7 @@ import (
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme"
"tailscale.com/envknob"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/ipnstate"
"tailscale.com/ipn/ipnstate"
"tailscale.com/types/logger"
"tailscale.com/types/logger"
"tailscale.com/util/strs"
"tailscale.com/util/strs"
@ -79,13 +80,6 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http . Error ( w , "cert access denied" , http . StatusForbidden )
http . Error ( w , "cert access denied" , http . StatusForbidden )
return
return
}
}
dir , err := h . certDir ( )
if err != nil {
h . logf ( "certDir: %v" , err )
http . Error ( w , "failed to get cert dir" , 500 )
return
}
domain , ok := strs . CutPrefix ( r . URL . Path , "/localapi/v0/cert/" )
domain , ok := strs . CutPrefix ( r . URL . Path , "/localapi/v0/cert/" )
if ! ok {
if ! ok {
http . Error ( w , "internal handler config wired wrong" , 500 )
http . Error ( w , "internal handler config wired wrong" , 500 )
@ -95,8 +89,24 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http . Error ( w , "invalid domain" , 400 )
http . Error ( w , "invalid domain" , 400 )
return
return
}
}
now := time . Now ( )
pair , err := h . getCertPEM ( r . Context ( ) , domain )
if err != nil {
http . Error ( w , fmt . Sprint ( err ) , 500 )
return
}
serveKeyPair ( w , r , pair )
}
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
// process, or from cache and kicking off an async ACME renewal.
func ( h * Handler ) getCertPEM ( ctx context . Context , domain string ) ( * keyPair , error ) {
logf := logger . WithPrefix ( h . logf , fmt . Sprintf ( "cert(%q): " , domain ) )
logf := logger . WithPrefix ( h . logf , fmt . Sprintf ( "cert(%q): " , domain ) )
dir , err := h . certDir ( )
if err != nil {
logf ( "failed to get certDir: %v" , err )
return nil , err
}
now := time . Now ( )
traceACME := func ( v any ) {
traceACME := func ( v any ) {
if ! acmeDebug ( ) {
if ! acmeDebug ( ) {
return
return
@ -105,24 +115,22 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
log . Printf ( "acme %T: %s" , v , j )
log . Printf ( "acme %T: %s" , v , j )
}
}
if pair , ok := h . getCertPEMCached ( dir , domain , now ) ; ok {
if pair , ok := getCertPEMCached ( dir , domain , now ) ; ok {
future := now . AddDate ( 0 , 0 , 14 )
future := now . AddDate ( 0 , 0 , 14 )
if h . shouldStartDomainRenewal ( dir , domain , future ) {
if h . shouldStartDomainRenewal ( dir , domain , future ) {
logf ( "starting async renewal" )
logf ( "starting async renewal" )
// Start renewal in the background.
// Start renewal in the background.
go h . getCertPEM ( context . Background ( ) , logf , traceACME , dir , domain , future )
go getCertPEM ( context . Background ( ) , h . b , logf , traceACME , dir , domain , future )
}
}
serveKeyPair ( w , r , pair )
return pair , nil
return
}
}
pair , err := h . getCertPEM ( r . Context ( ) , logf , traceACME , dir , domain , now )
pair , err := getCertPEM ( ctx , h . b , logf , traceACME , dir , domain , now )
if err != nil {
if err != nil {
logf ( "getCertPEM: %v" , err )
logf ( "getCertPEM: %v" , err )
http . Error ( w , fmt . Sprint ( err ) , 500 )
return nil , err
return
}
}
serveKeyPair ( w , r , pair )
return pair , nil
}
}
func ( h * Handler ) shouldStartDomainRenewal ( dir , domain string , future time . Time ) bool {
func ( h * Handler ) shouldStartDomainRenewal ( dir , domain string , future time . Time ) bool {
@ -135,7 +143,7 @@ func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time)
return false
return false
}
}
lastRenewCheck [ domain ] = now
lastRenewCheck [ domain ] = now
_ , ok := h . getCertPEMCached ( dir , domain , future )
_ , ok := getCertPEMCached ( dir , domain , future )
return ! ok
return ! ok
}
}
@ -154,10 +162,12 @@ func serveKeyPair(w http.ResponseWriter, r *http.Request, p *keyPair) {
}
}
}
}
// keyPair is a TLS public and private key, and whether they were obtained
// from cache or freshly obtained.
type keyPair struct {
type keyPair struct {
certPEM [ ] byte
certPEM [ ] byte // public key, in PEM form
keyPEM [ ] byte
keyPEM [ ] byte // private key, in PEM form
cached bool
cached bool // whether result came from cache
}
}
func keyFile ( dir , domain string ) string { return filepath . Join ( dir , domain + ".key" ) }
func keyFile ( dir , domain string ) string { return filepath . Join ( dir , domain + ".key" ) }
@ -166,7 +176,7 @@ func certFile(dir, domain string) string { return filepath.Join(dir, domain+".cr
// getCertPEMCached returns a non-nil keyPair and true if a cached
// getCertPEMCached returns a non-nil keyPair and true if a cached
// keypair for domain exists on disk in dir that is valid at the
// keypair for domain exists on disk in dir that is valid at the
// provided now time.
// provided now time.
func ( h * Handler ) getCertPEMCached ( dir , domain string , now time . Time ) ( p * keyPair , ok bool ) {
func getCertPEMCached ( dir , domain string , now time . Time ) ( p * keyPair , ok bool ) {
if ! validLookingCertDomain ( domain ) {
if ! validLookingCertDomain ( domain ) {
// Before we read files from disk using it, validate it's halfway
// Before we read files from disk using it, validate it's halfway
// reasonable looking.
// reasonable looking.
@ -181,11 +191,11 @@ func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPai
return nil , false
return nil , false
}
}
func ( h * Handler ) getCertPEM ( ctx context . Context , logf logger . Logf , traceACME func ( any ) , dir , domain string , now time . Time ) ( * keyPair , error ) {
func getCertPEM ( ctx context . Context , lb * ipnlocal . LocalBackend , logf logger . Logf , traceACME func ( any ) , dir , domain string , now time . Time ) ( * keyPair , error ) {
acmeMu . Lock ( )
acmeMu . Lock ( )
defer acmeMu . Unlock ( )
defer acmeMu . Unlock ( )
if p , ok := h . getCertPEMCached ( dir , domain , now ) ; ok {
if p , ok := getCertPEMCached ( dir , domain , now ) ; ok {
return p , nil
return p , nil
}
}
@ -223,7 +233,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
}
}
// Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for.
// Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for.
st := h. b. StatusWithoutPeers ( )
st := l b. StatusWithoutPeers ( )
if err := checkCertDomain ( st , domain ) ; err != nil {
if err := checkCertDomain ( st , domain ) ; err != nil {
return nil , err
return nil , err
}
}
@ -260,7 +270,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
}
}
if ! ok {
if ! ok {
logf ( "starting SetDNS call..." )
logf ( "starting SetDNS call..." )
err = h. b. SetDNS ( ctx , key , rec )
err = l b. SetDNS ( ctx , key , rec )
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "SetDNS %q => %q: %w" , key , rec , err )
return nil , fmt . Errorf ( "SetDNS %q => %q: %w" , key , rec , err )
}
}