Add custom TLS options

Signed-off-by: Kevin Allen kallen@bostondynamics.com
pull/8070/head
Kevin Allen 1 year ago
parent 8864112a0c
commit d31a4d92e6

@ -41,6 +41,17 @@ var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL")
// Headscale, etc.
var tlsdialWarningPrinted sync.Map // map[string]bool
var (
// rootCAOverride creates environment variable config TS_TLS_DIAL_ROOT_CA which
// will override the certificate authority used to verify the server instead
// of the system default
rootCAOverride = envknob.RegisterString("TS_TLS_DIAL_ROOT_CA")
// serverHostOverride creates environment variable TS_TLS_DIAL_CONNECT_TO which
// will override the server name the certificate is validated against AND the SNI
// name presented to the server, which may affect virtual hosts
serverHostOverride = envknob.RegisterString("TS_TLS_DIAL_CONNECT_TO")
)
// Config returns a tls.Config for connecting to a server.
// If base is non-nil, it's cloned as the base config before
// being configured and returned.
@ -52,6 +63,9 @@ func Config(host string, base *tls.Config) *tls.Config {
conf = base.Clone()
}
conf.ServerName = host
if len(serverHostOverride()) != 0 {
conf.ServerName = serverHostOverride()
}
if n := sslKeyLogFile; n != "" {
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
@ -93,6 +107,21 @@ func Config(host string, base *tls.Config) *tls.Config {
for _, cert := range cs.PeerCertificates[1:] {
opts.Intermediates.AddCert(cert)
}
// Check against user overriden root CA if provided
if overrideRoots() != nil {
opts.Roots = overrideRoots()
_, err := cs.PeerCertificates[0].Verify(opts)
if debug() {
log.Printf("tlsdial(override %q): %v", host, err)
}
if err == nil {
atomic.AddInt32(&counterFallbackOK, 1)
return nil
}
return err
}
_, errSys := cs.PeerCertificates[0].Verify(opts)
if debug() {
log.Printf("tlsdial(sys %q): %v", host, errSys)
@ -272,3 +301,27 @@ func bakedInRoots() *x509.CertPool {
})
return bakedInRootsOnce.p
}
var overrideRootsOnce struct {
sync.Once
p *x509.CertPool
}
func overrideRoots() *x509.CertPool {
if len(rootCAOverride()) == 0 {
return nil
}
overrideRootsOnce.Do(func() {
pem, err := os.ReadFile(rootCAOverride())
if err != nil {
panic(fmt.Sprintf("Error loading custom root CA %s: %v", rootCAOverride(), err))
}
p := x509.NewCertPool()
if !p.AppendCertsFromPEM(pem) {
panic(fmt.Sprintf("Invalid PEM in custom root CA %s", rootCAOverride()))
}
overrideRootsOnce.p = p
})
return overrideRootsOnce.p
}

Loading…
Cancel
Save