From 7f68e097ddc314c662a96ca5b3908b9f1a5930cf Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 29 May 2020 22:33:08 -0700 Subject: [PATCH] net/netcheck: fix HTTPS fallback bug from earlier today My earlier 3fa58303d0dd206609362d68736f1039801ffd8d tried to implement the net/http.Tranhsport.DialTLSContext hook, but I didn't return a *tls.Conn, so we ended up sending a plaintext HTTP request to an HTTPS port. The response ended up being Go telling as such, not the /derp/latency-check handler's response (which is currently still a 404). But we didn't even get the 404. This happened to work well enough because Go's built-in error response was still a valid HTTP response that we can measure for timing purposes, but it's not a great answer. Notably, it means we wouldn't be able to get a future handler to run server-side and count those latency requests. --- derp/derphttp/derphttp_client.go | 52 ++++++++++++++++++++++++++------ net/netcheck/netcheck.go | 22 ++++++++------ 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 147f10b1a..097960226 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -142,8 +142,8 @@ func (c *Client) useHTTPS() bool { return true } -// TLSServerName returns which TLS cert name to expect for the given node. -func (c *Client) TLSServerName(node *tailcfg.DERPNode) string { +// tlsServerName returns which TLS cert name to expect for the given node. +func (c *Client) tlsServerName(node *tailcfg.DERPNode) string { if c.url != nil { return c.url.Host } @@ -217,7 +217,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien tcpConn, err = c.dialURL(ctx) } else { c.logf("%s: connecting to derp-%d (%v)", caller, reg.RegionID, reg.RegionCode) - tcpConn, node, err = c.DialRegion(ctx, reg) + tcpConn, node, err = c.dialRegion(ctx, reg) } if err != nil { return nil, err @@ -249,11 +249,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to if c.useHTTPS() { - tlsConf := tlsdial.Config(c.TLSServerName(node), c.TLSConfig) - if node != nil && node.DERPTestPort != 0 { - tlsConf.InsecureSkipVerify = true - } - httpConn = tls.Client(tcpConn, tlsConf) + httpConn = c.tlsClient(tcpConn, node) } else { httpConn = tcpConn } @@ -329,10 +325,10 @@ func (c *Client) dialURL(ctx context.Context) (net.Conn, error) { return tcpConn, nil } -// DialRegion returns a TCP connection to the provided region, trying +// dialRegion returns a TCP connection to the provided region, trying // each node in order (with dialNode) until one connects or ctx is // done. -func (c *Client) DialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) { +func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) { if len(reg.Nodes) == 0 { return nil, nil, fmt.Errorf("no nodes for %s", c.targetString(reg)) } @@ -352,6 +348,42 @@ func (c *Client) DialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C return nil, nil, firstErr } +func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { + tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig) + if node != nil && node.DERPTestPort != 0 { + tlsConf.InsecureSkipVerify = true + } + return tls.Client(nc, tlsConf) +} + +func (c *Client) DialRegionTLS(ctx context.Context, reg *tailcfg.DERPRegion) (tlsConn *tls.Conn, connClose io.Closer, err error) { + tcpConn, node, err := c.dialRegion(ctx, reg) + if err != nil { + return nil, nil, err + } + done := make(chan bool) // unbufferd + defer close(done) + + tlsConn = c.tlsClient(tcpConn, node) + go func() { + select { + case <-done: + case <-ctx.Done(): + tcpConn.Close() + } + }() + err = tlsConn.Handshake() + if err != nil { + return nil, nil, err + } + select { + case done <- true: + return tlsConn, tcpConn, nil + case <-ctx.Done(): + return nil, nil, ctx.Err() + } +} + func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) { var stdDialer dialer = netns.Dialer() var dialer = stdDialer diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 59f426d97..a5f908a1e 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -8,6 +8,7 @@ package netcheck import ( "bytes" "context" + "crypto/tls" "errors" "fmt" "io" @@ -786,23 +787,26 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio var ip netaddr.IP dc := derphttp.NewNetcheckClient(c.logf) - nc, node, err := dc.DialRegion(ctx, reg) + tlsConn, tcpConn, err := dc.DialRegionTLS(ctx, reg) if err != nil { return 0, ip, err } - defer nc.Close() + defer tcpConn.Close() - if ta, ok := nc.RemoteAddr().(*net.TCPAddr); ok { + if ta, ok := tlsConn.RemoteAddr().(*net.TCPAddr); ok { ip, _ = netaddr.FromStdIP(ta.IP) } if ip == (netaddr.IP{}) { - return 0, ip, fmt.Errorf("no unexpected RemoteAddr %#v", nc.RemoteAddr()) + return 0, ip, fmt.Errorf("no unexpected RemoteAddr %#v", tlsConn.RemoteAddr()) } - connc := make(chan net.Conn, 1) - connc <- nc + connc := make(chan *tls.Conn, 1) + connc <- tlsConn tr := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("unexpected DialContext dial") + }, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { select { case nc := <-connc: @@ -814,9 +818,7 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio } hc := &http.Client{Transport: tr} - host := dc.TLSServerName(node) - u := fmt.Sprintf("https://%s/derp/latency-check", host) - req, err := http.NewRequestWithContext(ctx, "GET", u, nil) + req, err := http.NewRequestWithContext(ctx, "GET", "https://derp-unused-hostname.tld/derp/latency-check", nil) if err != nil { return 0, ip, err } @@ -827,7 +829,7 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio } defer resp.Body.Close() - _, err = io.Copy(ioutil.Discard, resp.Body) + _, err = io.Copy(ioutil.Discard, io.LimitReader(resp.Body, 8<<10)) if err != nil { return 0, ip, err }