|
|
|
@ -41,6 +41,17 @@ var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL")
|
|
|
|
|
// Headscale, etc.
|
|
|
|
|
var tlsdialWarningPrinted sync.Map // map[string]bool
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
// rootCAOverride creates environment variable config TS_TLS_DIAL_ROOT_CA which
|
|
|
|
|
// will override the certificate authority used to verify the server instead
|
|
|
|
|
// of the system default
|
|
|
|
|
rootCAOverride = envknob.RegisterString("TS_TLS_DIAL_ROOT_CA")
|
|
|
|
|
// serverHostOverride creates environment variable TS_TLS_DIAL_CONNECT_TO which
|
|
|
|
|
// will override the server name the certificate is validated against AND the SNI
|
|
|
|
|
// name presented to the server, which may affect virtual hosts
|
|
|
|
|
serverHostOverride = envknob.RegisterString("TS_TLS_DIAL_CONNECT_TO")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// Config returns a tls.Config for connecting to a server.
|
|
|
|
|
// If base is non-nil, it's cloned as the base config before
|
|
|
|
|
// being configured and returned.
|
|
|
|
@ -52,6 +63,9 @@ func Config(host string, base *tls.Config) *tls.Config {
|
|
|
|
|
conf = base.Clone()
|
|
|
|
|
}
|
|
|
|
|
conf.ServerName = host
|
|
|
|
|
if len(serverHostOverride()) != 0 {
|
|
|
|
|
conf.ServerName = serverHostOverride()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if n := sslKeyLogFile; n != "" {
|
|
|
|
|
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
|
|
|
|
@ -93,6 +107,21 @@ func Config(host string, base *tls.Config) *tls.Config {
|
|
|
|
|
for _, cert := range cs.PeerCertificates[1:] {
|
|
|
|
|
opts.Intermediates.AddCert(cert)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check against user overriden root CA if provided
|
|
|
|
|
if overrideRoots() != nil {
|
|
|
|
|
opts.Roots = overrideRoots()
|
|
|
|
|
_, err := cs.PeerCertificates[0].Verify(opts)
|
|
|
|
|
if debug() {
|
|
|
|
|
log.Printf("tlsdial(override %q): %v", host, err)
|
|
|
|
|
}
|
|
|
|
|
if err == nil {
|
|
|
|
|
atomic.AddInt32(&counterFallbackOK, 1)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, errSys := cs.PeerCertificates[0].Verify(opts)
|
|
|
|
|
if debug() {
|
|
|
|
|
log.Printf("tlsdial(sys %q): %v", host, errSys)
|
|
|
|
@ -272,3 +301,27 @@ func bakedInRoots() *x509.CertPool {
|
|
|
|
|
})
|
|
|
|
|
return bakedInRootsOnce.p
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var overrideRootsOnce struct {
|
|
|
|
|
sync.Once
|
|
|
|
|
p *x509.CertPool
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func overrideRoots() *x509.CertPool {
|
|
|
|
|
if len(rootCAOverride()) == 0 {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
overrideRootsOnce.Do(func() {
|
|
|
|
|
pem, err := os.ReadFile(rootCAOverride())
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(fmt.Sprintf("Error loading custom root CA %s: %v", rootCAOverride(), err))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
p := x509.NewCertPool()
|
|
|
|
|
if !p.AppendCertsFromPEM(pem) {
|
|
|
|
|
panic(fmt.Sprintf("Invalid PEM in custom root CA %s", rootCAOverride()))
|
|
|
|
|
}
|
|
|
|
|
overrideRootsOnce.p = p
|
|
|
|
|
})
|
|
|
|
|
return overrideRootsOnce.p
|
|
|
|
|
}
|
|
|
|
|