diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 0123c12cc..cba2a2f44 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -13,6 +13,7 @@ package derphttp import ( "bufio" "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -31,7 +32,8 @@ import ( // // It automatically reconnects on error retry. That is, a failed Send or // Recv will report the error and not retry, but subsequent calls to -// Send/Recv will completely re-establish the connection. +// Send/Recv will completely re-establish the connection (unless Close +// has been called). type Client struct { privateKey key.Private logf logger.Logf @@ -46,6 +48,8 @@ type Client struct { client *derp.Client } +// NewClient returns a new DERP-over-HTTP client. It connects lazily. +// To trigger a connection use Connect. func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) { u, err := url.Parse(serverURL) if err != nil { @@ -58,13 +62,18 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Cli url: u, closed: make(chan struct{}), } - if _, err := c.connect("derphttp.NewClient"); err != nil { - c.logf("%v", err) - } return c, nil } -func (c *Client) connect(caller string) (client *derp.Client, err error) { +// Connect connects or reconnects to the server, unless already connected. +// It returns nil if there was already a good connection, or if one was made. +func (c *Client) Connect(ctx context.Context) error { + _, err := c.connect(ctx, "derphttp.Client.Connect") + return err +} + +func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) { + // TODO: use ctx for TCP+TLS+HTTP below select { case <-c.closed: return nil, ErrClientClosed @@ -84,7 +93,7 @@ func (c *Client) connect(caller string) (client *derp.Client, err error) { defer func() { if err != nil { err = fmt.Errorf("%s connect: %v", caller, err) - if netConn := netConn; netConn != nil { + if netConn != nil { netConn.Close() } } @@ -148,7 +157,7 @@ func (c *Client) connect(caller string) (client *derp.Client, err error) { } func (c *Client) Send(dstKey key.Public, b []byte) error { - client, err := c.connect("derphttp.Client.Send") + client, err := c.connect(context.TODO(), "derphttp.Client.Send") if err != nil { return err } @@ -159,7 +168,7 @@ func (c *Client) Send(dstKey key.Public, b []byte) error { } func (c *Client) Recv(b []byte) (int, error) { - client, err := c.connect("derphttp.Client.Recv") + client, err := c.connect(context.TODO(), "derphttp.Client.Recv") if err != nil { return 0, err } @@ -170,6 +179,8 @@ func (c *Client) Recv(b []byte) (int, error) { return n, err } +// Close closes the client. It will not automatically reconnect after +// being closed. func (c *Client) Close() error { select { case <-c.closed: diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 7b6d15dc0..537ec2098 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -14,6 +14,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -51,8 +52,9 @@ type Conn struct { indexedAddrsMu sync.Mutex indexedAddrs map[udpAddr]indexedAddrSet - stunReceiveMu sync.Mutex - stunReceive func(p []byte, fromAddr *net.UDPAddr) + // stunReceiveFunc holds the current STUN packet processing func. + // Its Loaded value is always non-nil. + stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr) derpMu sync.Mutex derp *derphttp.Client @@ -140,12 +142,21 @@ func Listen(opts Options) (*Conn, error) { logf: log.Printf, indexedAddrs: make(map[udpAddr]indexedAddrSet), } + c.ignoreSTUNPackets() c.pconn.Reset(packetConn.(*net.UDPConn)) c.startEpUpdate <- struct{}{} // STUN immediately on start go c.epUpdate(epUpdateCtx) return c, nil } +// ignoreSTUNPackets sets a STUN packet processing func that does nothing. +func (c *Conn) ignoreSTUNPackets() { + c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) +} + +// epUpdate runs in its own goroutine until ctx is shut down. +// Whenever c.startEpUpdate receives a value, it starts an +// STUN endpoint lookup. func (c *Conn) epUpdate(ctx context.Context) { var lastEndpoints []string var lastCancel func() @@ -186,18 +197,22 @@ func (c *Conn) epUpdate(ctx context.Context) { } } +// determineEndpoints returns the machine's endpoint addresses. It +// does a STUN lookup to determine its public address. func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) { - var alreadyMu sync.Mutex - already := make(map[string]struct{}) - var eps []string + var ( + alreadyMu sync.Mutex + already = make(map[string]bool) // endpoint -> true + ) + var eps []string // unique endpoints addAddr := func(s, reason string) { log.Printf("magicsock: found local %s (%s)\n", s, reason) alreadyMu.Lock() defer alreadyMu.Unlock() - if _, ok := already[s]; !ok { - already[s] = struct{}{} + if !already[s] { + already[s] = true eps = append(eps, s) } } @@ -209,17 +224,13 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) { Logf: c.logf, } - c.stunReceiveMu.Lock() - c.stunReceive = s.Receive - c.stunReceiveMu.Unlock() + c.stunReceiveFunc.Store(s.Receive) if err := s.Run(ctx); err != nil { return nil, err } - c.stunReceiveMu.Lock() - c.stunReceive = nil - c.stunReceiveMu.Unlock() + c.ignoreSTUNPackets() if localAddr := c.pconn.LocalAddr(); localAddr.IP.IsUnspecified() { localPort := fmt.Sprintf("%d", localAddr.Port) @@ -421,16 +432,9 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep device.Endpoint, addr *net.UDPAd if !stun.Is(b[:n]) { break } - c.stunReceiveMu.Lock() - fn := c.stunReceive - c.stunReceiveMu.Unlock() - - if fn != nil { - fn(b, addr) - } + c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b, addr) } - // TODO(crawshaw): remove all the indexed-addr logic addrSet, _ := c.findIndexedAddrSet(addr) if addrSet == nil { // The peer that sent this packet has roamed beyond the @@ -457,14 +461,14 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { return err } go func() { - var b [1 << 16]byte + var b [64 << 10]byte for { n, err := derp.Recv(b[:]) if err != nil { if err == derphttp.ErrClientClosed { return } - log.Printf("%v", err) + log.Printf("derp.Recv: %v", err) time.Sleep(250 * time.Millisecond) } @@ -696,16 +700,19 @@ func (a *AddrSet) Addrs() []wgcfg.Endpoint { return eps } -func (c *Conn) CreateEndpoint(key [32]byte, s string) (device.Endpoint, error) { +// CreateEndpoint is called by WireGuard to connect to an endpoint. +// The key is the public key of the peer and addrs is a +// comma-separated list of UDP ip:ports. +func (c *Conn) CreateEndpoint(key [32]byte, addrs string) (device.Endpoint, error) { pk := wgcfg.Key(key) - log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), s) + log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), addrs) a := &AddrSet{ publicKey: key, curAddr: -1, } - if s != "" { - for _, ep := range strings.Split(s, ",") { + if addrs != "" { + for _, ep := range strings.Split(addrs, ",") { addr, err := net.ResolveUDPAddr("udp", ep) if err != nil { return nil, err diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 299783fa4..d0635a2fe 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -36,7 +36,7 @@ func TestListen(t *testing.T) { defer conn.Close() go func() { - var pkt [1 << 16]byte + var pkt [64 << 10]byte for { _, _, _, err := conn.ReceiveIPv4(pkt[:]) if err != nil { diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 91402d7cd..e60f4a3e8 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -107,18 +107,14 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev tun.Device, routerGen R endpointsFn := func(endpoints []string) { e.mu.Lock() - if e.endpoints != nil { - e.endpoints = e.endpoints[:0] - } - e.endpoints = append(e.endpoints, endpoints...) + e.endpoints = append(e.endpoints[:0], endpoints...) e.mu.Unlock() e.RequestStatus() } magicsockOpts := magicsock.Options{ - Port: listenPort, - STUN: magicsock.DefaultSTUN, - // TODO(crawshaw): DERP: magicsock.DefaultDERP, + Port: listenPort, + STUN: magicsock.DefaultSTUN, EndpointsFunc: endpointsFn, } e.magicConn, err = magicsock.Listen(magicsockOpts)