prober: allow custom tls.Config for TLS probes (#17186)

Updates https://github.com/tailscale/corp/issues/28569

Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
pull/17082/merge
Andrew Lytvynov 3 months ago committed by GitHub
parent 73bbd7caca
commit 70dfdac609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -8,6 +8,7 @@ import (
"cmp" "cmp"
"context" "context"
crand "crypto/rand" crand "crypto/rand"
"crypto/tls"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors" "errors"
@ -68,7 +69,7 @@ type derpProber struct {
ProbeMap ProbeClass ProbeMap ProbeClass
// Probe classes for probing individual derpers. // Probe classes for probing individual derpers.
tlsProbeFn func(string) ProbeClass tlsProbeFn func(string, *tls.Config) ProbeClass
udpProbeFn func(string, int) ProbeClass udpProbeFn func(string, int) ProbeClass
meshProbeFn func(string, string) ProbeClass meshProbeFn func(string, string) ProbeClass
bwProbeFn func(string, string, int64) ProbeClass bwProbeFn func(string, string, int64) ProbeClass
@ -206,7 +207,7 @@ func (d *derpProber) probeMapFn(ctx context.Context) error {
if d.probes[n] == nil { if d.probes[n] == nil {
log.Printf("adding DERP TLS probe for %s (%s) every %v", server.Name, region.RegionName, d.tlsInterval) log.Printf("adding DERP TLS probe for %s (%s) every %v", server.Name, region.RegionName, d.tlsInterval)
derpPort := cmp.Or(server.DERPPort, 443) derpPort := cmp.Or(server.DERPPort, 443)
d.probes[n] = d.p.Run(n, d.tlsInterval, labels, d.tlsProbeFn(fmt.Sprintf("%s:%d", server.HostName, derpPort))) d.probes[n] = d.p.Run(n, d.tlsInterval, labels, d.tlsProbeFn(fmt.Sprintf("%s:%d", server.HostName, derpPort), nil))
} }
} }

@ -74,7 +74,7 @@ func TestDerpProber(t *testing.T) {
p: p, p: p,
derpMapURL: srv.URL, derpMapURL: srv.URL,
tlsInterval: time.Second, tlsInterval: time.Second,
tlsProbeFn: func(_ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, tlsProbeFn: func(_ string, _ *tls.Config) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) },
udpInterval: time.Second, udpInterval: time.Second,
udpProbeFn: func(_ string, _ int) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, udpProbeFn: func(_ string, _ int) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) },
meshInterval: time.Second, meshInterval: time.Second,

@ -5,6 +5,7 @@ package prober_test
import ( import (
"context" "context"
"crypto/tls"
"flag" "flag"
"fmt" "fmt"
"log" "log"
@ -40,7 +41,7 @@ func ExampleForEachAddr() {
// This function is called every time we discover a new IP address to check. // This function is called every time we discover a new IP address to check.
makeTLSProbe := func(addr netip.Addr) []*prober.Probe { makeTLSProbe := func(addr netip.Addr) []*prober.Probe {
pf := prober.TLSWithIP(*hostname, netip.AddrPortFrom(addr, 443)) pf := prober.TLSWithIP(netip.AddrPortFrom(addr, 443), &tls.Config{ServerName: *hostname})
if *verbose { if *verbose {
logger := logger.WithPrefix(log.Printf, fmt.Sprintf("[tls %s]: ", addr)) logger := logger.WithPrefix(log.Printf, fmt.Sprintf("[tls %s]: ", addr))
pf = probeLogWrapper(logger, pf) pf = probeLogWrapper(logger, pf)

@ -9,9 +9,9 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/netip" "net/netip"
"slices"
"time" "time"
"tailscale.com/util/multierr" "tailscale.com/util/multierr"
@ -28,33 +28,31 @@ const letsEncryptStartedStaplingCRL int64 = 1746576000 // 2025-05-07 00:00:00 UT
// The ProbeFunc connects to a hostPort (host:port string), does a TLS // The ProbeFunc connects to a hostPort (host:port string), does a TLS
// handshake, verifies that the hostname matches the presented certificate, // handshake, verifies that the hostname matches the presented certificate,
// checks certificate validity time and OCSP revocation status. // checks certificate validity time and OCSP revocation status.
func TLS(hostPort string) ProbeClass { //
// The TLS config is optional and may be nil.
func TLS(hostPort string, config *tls.Config) ProbeClass {
return ProbeClass{ return ProbeClass{
Probe: func(ctx context.Context) error { Probe: func(ctx context.Context) error {
certDomain, _, err := net.SplitHostPort(hostPort) return probeTLS(ctx, config, hostPort)
if err != nil {
return err
}
return probeTLS(ctx, certDomain, hostPort)
}, },
Class: "tls", Class: "tls",
} }
} }
// TLSWithIP is like TLS, but dials the provided dialAddr instead // TLSWithIP is like TLS, but dials the provided dialAddr instead of using DNS
// of using DNS resolution. The certDomain is the expected name in // resolution. Use config.ServerName to send SNI and validate the name in the
// the cert (and the SNI name to send). // cert.
func TLSWithIP(certDomain string, dialAddr netip.AddrPort) ProbeClass { func TLSWithIP(dialAddr netip.AddrPort, config *tls.Config) ProbeClass {
return ProbeClass{ return ProbeClass{
Probe: func(ctx context.Context) error { Probe: func(ctx context.Context) error {
return probeTLS(ctx, certDomain, dialAddr.String()) return probeTLS(ctx, config, dialAddr.String())
}, },
Class: "tls", Class: "tls",
} }
} }
func probeTLS(ctx context.Context, certDomain string, dialHostPort string) error { func probeTLS(ctx context.Context, config *tls.Config, dialHostPort string) error {
dialer := &tls.Dialer{Config: &tls.Config{ServerName: certDomain}} dialer := &tls.Dialer{Config: config}
conn, err := dialer.DialContext(ctx, "tcp", dialHostPort) conn, err := dialer.DialContext(ctx, "tcp", dialHostPort)
if err != nil { if err != nil {
return fmt.Errorf("connecting to %q: %w", dialHostPort, err) return fmt.Errorf("connecting to %q: %w", dialHostPort, err)
@ -108,6 +106,10 @@ func validateConnState(ctx context.Context, cs *tls.ConnectionState) (returnerr
} }
if len(leafCert.CRLDistributionPoints) == 0 { if len(leafCert.CRLDistributionPoints) == 0 {
if !slices.Contains(leafCert.Issuer.Organization, "Let's Encrypt") {
// LE certs contain a CRL, but certs from other CAs might not.
return
}
if leafCert.NotBefore.Before(time.Unix(letsEncryptStartedStaplingCRL, 0)) { if leafCert.NotBefore.Before(time.Unix(letsEncryptStartedStaplingCRL, 0)) {
// Certificate might not have a CRL. // Certificate might not have a CRL.
return return

@ -83,7 +83,7 @@ func TestTLSConnection(t *testing.T) {
srv.StartTLS() srv.StartTLS()
defer srv.Close() defer srv.Close()
err = probeTLS(context.Background(), "fail.example.com", srv.Listener.Addr().String()) err = probeTLS(context.Background(), &tls.Config{ServerName: "fail.example.com"}, srv.Listener.Addr().String())
// The specific error message here is platform-specific ("certificate is not trusted" // The specific error message here is platform-specific ("certificate is not trusted"
// on macOS and "certificate signed by unknown authority" on Linux), so only check // on macOS and "certificate signed by unknown authority" on Linux), so only check
// that it contains the word 'certificate'. // that it contains the word 'certificate'.
@ -269,40 +269,54 @@ func TestCRL(t *testing.T) {
name string name string
cert *x509.Certificate cert *x509.Certificate
crlBytes []byte crlBytes []byte
issuer pkix.Name
wantErr string wantErr string
}{ }{
{ {
"ValidCert", "ValidCert",
leafCertParsed, leafCertParsed,
emptyRlBytes, emptyRlBytes,
caCert.Issuer,
"", "",
}, },
{ {
"RevokedCert", "RevokedCert",
leafCertParsed, leafCertParsed,
rlBytes, rlBytes,
caCert.Issuer,
"has been revoked on", "has been revoked on",
}, },
{ {
"EmptyCRL", "EmptyCRL",
leafCertParsed, leafCertParsed,
emptyRlBytes, emptyRlBytes,
caCert.Issuer,
"", "",
}, },
{ {
"NoCRL", "NoCRLLetsEncrypt",
leafCertParsed, leafCertParsed,
nil, nil,
pkix.Name{CommonName: "tlsprobe.test", Organization: []string{"Let's Encrypt"}},
"no CRL server presented in leaf cert for", "no CRL server presented in leaf cert for",
}, },
{
"NoCRLOtherCA",
leafCertParsed,
nil,
caCert.Issuer,
"",
},
{ {
"NotBeforeCRLStaplingDate", "NotBeforeCRLStaplingDate",
noCRLStapledParsed, noCRLStapledParsed,
nil, nil,
caCert.Issuer,
"", "",
}, },
} { } {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tt.cert.Issuer = tt.issuer
cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert, caCert}} cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert, caCert}}
if tt.crlBytes != nil { if tt.crlBytes != nil {
crlServer.crlBytes = tt.crlBytes crlServer.crlBytes = tt.crlBytes

Loading…
Cancel
Save