diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 097960226..e4d5a1332 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -142,14 +142,11 @@ func (c *Client) useHTTPS() bool { return true } -// tlsServerName returns which TLS cert name to expect for the given node. +// tlsServerName returns the tls.Config.ServerName value (for the TLS ClientHello). func (c *Client) tlsServerName(node *tailcfg.DERPNode) string { if c.url != nil { return c.url.Host } - if node.CertName != "" { - return node.CertName - } return node.HostName } @@ -350,8 +347,13 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig) - if node != nil && node.DERPTestPort != 0 { - tlsConf.InsecureSkipVerify = true + if node != nil { + if node.DERPTestPort != 0 { + tlsConf.InsecureSkipVerify = true + } + if node.CertName != "" { + tlsdial.SetConfigExpectedCert(tlsConf, node.CertName) + } } return tls.Client(nc, tlsConf) } diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go index dfea2a4e3..fb7aafb3a 100644 --- a/net/tlsdial/tlsdial.go +++ b/net/tlsdial/tlsdial.go @@ -11,9 +11,14 @@ // control, DERP). package tlsdial -import "crypto/tls" +import ( + "crypto/tls" + "crypto/x509" + "errors" + "time" +) -// Config returns a tls.Config for dialing the given host. +// Config returns a tls.Config for connecting to a server. // If base is non-nil, it's cloned as the base config before // being configured and returned. func Config(host string, base *tls.Config) *tls.Config { @@ -27,3 +32,45 @@ func Config(host string, base *tls.Config) *tls.Config { return conf } + +// SetConfigExpectedCert modifies c to expect and verify that the server returns +// a certificate for the provided certDNSName. +func SetConfigExpectedCert(c *tls.Config, certDNSName string) { + if c.ServerName == certDNSName { + return + } + if c.ServerName == "" { + c.ServerName = certDNSName + return + } + if c.VerifyPeerCertificate != nil { + panic("refusing to override tls.Config.VerifyPeerCertificate") + } + // Set InsecureSkipVerify to prevent crypto/tls from doing its + // own cert verification, but do the same work that it'd do + // (but using certDNSName) in the VerifyPeerCertificate hook. + c.InsecureSkipVerify = true + c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return errors.New("no certs presented") + } + certs := make([]*x509.Certificate, len(rawCerts)) + for i, asn1Data := range rawCerts { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return err + } + certs[i] = cert + } + opts := x509.VerifyOptions{ + CurrentTime: time.Now(), + DNSName: certDNSName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } +}