derp: clean up derphttp client code, use contexts

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/115/head
Brad Fitzpatrick 5 years ago
parent cdc10b74f1
commit 752146a70f

@ -97,10 +97,13 @@ func (s *Server) isClosed() bool {
return s.closed return s.closed
} }
// Accept adds a new connection to the server. // Accept adds a new connection to the server and serves it.
//
// The provided bufio ReadWriter must be already connected to nc. // The provided bufio ReadWriter must be already connected to nc.
// Accept blocks until the Server is closed or the connection closes // Accept blocks until the Server is closed or the connection closes
// on its own. // on its own.
//
// Accept closes nc.
func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) { func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) {
closed := make(chan struct{}) closed := make(chan struct{})

@ -12,16 +12,18 @@ package derphttp
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sync" "sync"
"time"
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -37,14 +39,14 @@ import (
type Client struct { type Client struct {
privateKey key.Private privateKey key.Private
logf logger.Logf logf logger.Logf
closed chan struct{}
url *url.URL url *url.URL
resp *http.Response
netConnMu sync.Mutex ctx context.Context // closed via cancelCtx in Client.Close
netConn net.Conn cancelCtx context.CancelFunc
clientMu sync.Mutex mu sync.Mutex
closed bool
netConn io.Closer
client *derp.Client client *derp.Client
} }
@ -55,12 +57,16 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Cli
if err != nil { if err != nil {
return nil, fmt.Errorf("derphttp.NewClient: %v", err) return nil, fmt.Errorf("derphttp.NewClient: %v", err)
} }
if urlPort(u) == "" {
return nil, fmt.Errorf("derphttp.NewClient: invalid URL scheme %q", u.Scheme)
}
ctx, cancel := context.WithCancel(context.Background())
c := &Client{ c := &Client{
privateKey: privateKey, privateKey: privateKey,
logf: logf, logf: logf,
url: u, url: u,
closed: make(chan struct{}), ctx: ctx,
cancelCtx: cancel,
} }
return c, nil return c, nil
} }
@ -72,71 +78,119 @@ func (c *Client) Connect(ctx context.Context) error {
return err return err
} }
func urlPort(u *url.URL) string {
if p := u.Port(); p != "" {
return p
}
switch u.Scheme {
case "https":
return "443"
case "http":
return "80"
}
return ""
}
func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) { func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
// TODO: use ctx for TCP+TLS+HTTP below c.mu.Lock()
select { defer c.mu.Unlock()
case <-c.closed: if c.closed {
return nil, ErrClientClosed return nil, ErrClientClosed
default:
} }
c.clientMu.Lock()
defer c.clientMu.Unlock()
if c.client != nil { if c.client != nil {
return c.client, nil return c.client, nil
} }
c.logf("%s: connecting", caller) c.logf("%s: connecting to %v", caller, c.url)
var netConn net.Conn // timeout is the fallback maximum time (if ctx doesn't limit
defer func() { // it further) to do all of: DNS + TCP + TLS + HTTP Upgrade +
if err != nil { // DERP upgrade.
err = fmt.Errorf("%s connect: %v", caller, err) const timeout = 10 * time.Second
if netConn != nil { ctx, cancel := context.WithTimeout(ctx, timeout)
netConn.Close() go func() {
} select {
case <-ctx.Done():
log.Printf("XXXX normal")
// Either timeout fired (handled below), or
// we're returning via the defer cancel()
// below.
case <-c.ctx.Done():
log.Printf("XXXX dead2")
// Propagate a Client.Close call into
// cancelling this context.
cancel()
} }
}() }()
defer cancel()
if c.url.Scheme == "https" { var tcpConn net.Conn
port := c.url.Port() defer func() {
if port == "" { if err != nil {
port = "443" if ctx.Err() != nil {
err = fmt.Errorf("%v: %v", ctx.Err(), err)
} }
config := &tls.Config{} err = fmt.Errorf("%s connect to %v: %v", caller, c.url, err)
var tlsConn *tls.Conn if tcpConn != nil {
tlsConn, err = tls.Dial("tcp", net.JoinHostPort(c.url.Host, port), config) go tcpConn.Close()
if tlsConn != nil {
netConn = tlsConn
} }
} else {
netConn, err = net.Dial("tcp", c.url.Host)
} }
}()
var d net.Dialer
log.Printf("Dialing: %q", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.netConnMu.Lock() // Now that we have a TCP connection, force close it.
c.netConn = netConn done := make(chan struct{})
c.netConnMu.Unlock() defer close(done)
go func() {
select {
case <-done:
// Normal path. Upgrade occurred in time.
case <-ctx.Done():
select {
case <-done:
// Normal path. Upgrade occurred in time.
// But the ctx.Done() is also done because
// the "defer cancel()" above scheduled
// before this goroutine.
default:
// The TLS or HTTP or DERP exchanges didn't complete
// in time. Force close the TCP connection to force
// them to fail quickly.
tcpConn.Close()
}
}
}()
var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to
if c.url.Scheme == "https" {
httpConn = tls.Client(tcpConn, &tls.Config{ServerName: c.url.Host})
} else {
httpConn = tcpConn
}
conn := bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn)) brw := bufio.NewReadWriter(bufio.NewReader(httpConn), bufio.NewWriter(httpConn))
req, err := http.NewRequest("GET", c.url.String(), nil) req, err := http.NewRequest("GET", c.url.String(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Upgrade", "WebSocket") req.Header.Set("Upgrade", "DERP")
req.Header.Set("Connection", "Upgrade") req.Header.Set("Connection", "Upgrade")
if err := req.Write(conn); err != nil {
if err := req.Write(brw); err != nil {
return nil, err return nil, err
} }
if err := conn.Flush(); err != nil { if err := brw.Flush(); err != nil {
return nil, err return nil, err
} }
resp, err := http.ReadResponse(conn.Reader, req) resp, err := http.ReadResponse(brw.Reader, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,14 +199,14 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
resp.Body.Close() resp.Body.Close()
return nil, fmt.Errorf("GET failed: %v: %s", err, b) return nil, fmt.Errorf("GET failed: %v: %s", err, b)
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
derpClient, err := derp.NewClient(c.privateKey, netConn, conn, c.logf) derpClient, err := derp.NewClient(c.privateKey, httpConn, brw, c.logf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.resp = resp
c.client = derpClient c.client = derpClient
c.netConn = tcpConn
return c.client, nil return c.client, nil
} }
@ -162,7 +216,7 @@ func (c *Client) Send(dstKey key.Public, b []byte) error {
return err return err
} }
if err := client.Send(dstKey, b); err != nil { if err := client.Send(dstKey, b); err != nil {
c.close() c.Close()
} }
return err return err
} }
@ -174,7 +228,7 @@ func (c *Client) Recv(b []byte) (derp.ReceivedMessage, error) {
} }
m, err := client.Recv(b) m, err := client.Recv(b)
if err != nil { if err != nil {
c.close() c.Close()
} }
return m, err return m, err
} }
@ -182,35 +236,20 @@ func (c *Client) Recv(b []byte) (derp.ReceivedMessage, error) {
// Close closes the client. It will not automatically reconnect after // Close closes the client. It will not automatically reconnect after
// being closed. // being closed.
func (c *Client) Close() error { func (c *Client) Close() error {
select { c.cancelCtx() // not in lock, so it can cancel Connect, which holds mu
case <-c.closed:
return ErrClientClosed
default:
}
close(c.closed)
c.close()
return nil
}
func (c *Client) close() {
c.netConnMu.Lock()
netConn := c.netConn
c.netConnMu.Unlock()
if netConn != nil { c.mu.Lock()
netConn.Close() defer c.mu.Unlock()
if c.closed {
return ErrClientClosed
} }
c.closed = true
c.clientMu.Lock() if c.netConn != nil {
defer c.clientMu.Unlock() c.netConn.Close()
if c.client == nil { c.netConn = nil
return
} }
c.resp = nil
c.client = nil c.client = nil
c.netConnMu.Lock() return nil
c.netConn = nil
c.netConnMu.Unlock()
} }
var ErrClientClosed = errors.New("derphttp.Client closed") var ErrClientClosed = errors.New("derphttp.Client closed")

@ -5,6 +5,7 @@
package derphttp package derphttp
import ( import (
"log"
"net/http" "net/http"
"tailscale.com/derp" "tailscale.com/derp"
@ -12,11 +13,11 @@ import (
func Handler(s *derp.Server) http.Handler { func Handler(s *derp.Server) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "WebSocket" { if p := r.Header.Get("Upgrade"); p != "WebSocket" && p != "DERP" {
http.Error(w, "DERP requires connection upgrade", http.StatusUpgradeRequired) http.Error(w, "DERP requires connection upgrade", http.StatusUpgradeRequired)
return return
} }
w.Header().Set("Upgrade", "WebSocket") w.Header().Set("Upgrade", "DERP")
w.Header().Set("Connection", "Upgrade") w.Header().Set("Connection", "Upgrade")
w.WriteHeader(http.StatusSwitchingProtocols) w.WriteHeader(http.StatusSwitchingProtocols)
@ -27,6 +28,7 @@ func Handler(s *derp.Server) http.Handler {
} }
netConn, conn, err := h.Hijack() netConn, conn, err := h.Hijack()
if err != nil { if err != nil {
log.Printf("Hijack failed: %v", err)
http.Error(w, "HTTP does not support general TCP support", 500) http.Error(w, "HTTP does not support general TCP support", 500)
return return
} }

Loading…
Cancel
Save