From d298d5b1f85df98036c4352f9c1093620c132c66 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 18 Feb 2020 13:32:04 -0800 Subject: [PATCH] wgengine/magicsock: support multiple derp servers, and not just for handshakes Signed-off-by: Brad Fitzpatrick --- wgengine/magicsock/derpmap.go | 49 ++++ wgengine/magicsock/magicsock.go | 405 ++++++++++++++++++++------- wgengine/magicsock/magicsock_test.go | 6 + 3 files changed, 365 insertions(+), 95 deletions(-) create mode 100644 wgengine/magicsock/derpmap.go diff --git a/wgengine/magicsock/derpmap.go b/wgengine/magicsock/derpmap.go new file mode 100644 index 000000000..a28635f40 --- /dev/null +++ b/wgengine/magicsock/derpmap.go @@ -0,0 +1,49 @@ +// Copyright 2019 Tailscale & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package magicsock + +import ( + "fmt" + "net" +) + +// derpMagicIPStr is a fake WireGuard endpoint IP address that means +// to use DERP. When used, the port number of the WireGuard endpoint +// is the DERP server number to use. +const derpMagicIPStr = "127.3.3.40" // 3340 are the keys above DERP on the keyboard +var derpMagicIP = net.IPv4(127, 3, 3, 40) // net.IP version of above + +var ( + derpHostOfIndex = map[int]string{} // index (fake port number) -> hostname + derpIndexOfHost = map[string]int{} // derpHostOfIndex reversed +) + +func init() { + // Just one zone for now: + addDerper(1, "derp.tailscale.com") +} + +func addDerper(i int, host string) { + if other, dup := derpHostOfIndex[i]; dup { + panic(fmt.Sprintf("duplicate DERP index %v (host %q and %q)", i, other, host)) + } + if other, dup := derpIndexOfHost[host]; dup { + panic(fmt.Sprintf("duplicate DERP host %q (index %v and %v)", host, other, i)) + } + derpHostOfIndex[i] = host + derpIndexOfHost[host] = i +} + +// derpHost returns the hostname of a DERP server index (a fake port +// number used with derpMagicIP). 
It always returns a non-empty string. +func derpHost(i int) string { + if h, ok := derpHostOfIndex[i]; ok { + return h + } + if 1 <= i && i <= 64<<10 { + return fmt.Sprintf("derp%v.tailscale.com", i) + } + return "derp.tailscale.com" +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index cd2846934..c6d955766 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -9,6 +9,7 @@ package magicsock import ( "context" "encoding/binary" + "errors" "fmt" "log" "net" @@ -31,11 +32,12 @@ import ( type Conn struct { pconn *RebindingUDPConn pconnPort uint16 + privateKey key.Private stunServers []string - derpServer string startEpUpdate chan struct{} // send to trigger endpoint update epFunc func(endpoints []string) logf func(format string, args ...interface{}) + donec chan struct{} // closed on Conn.Close epUpdateCtx context.Context // endpoint updater context epUpdateCancel func() // the func to cancel epUpdateCtx @@ -58,8 +60,12 @@ type Conn struct { // Its Loaded value is always non-nil. stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr) - derpMu sync.Mutex - derp *derphttp.Client + udpRecvCh chan udpReadResult + derpRecvCh chan derpReadResult + + derpMu sync.Mutex + derpConn map[int]*derphttp.Client // magic derp port (see derpmap.go) to its client + derpWriteCh map[int]chan<- derpWriteRequest } // udpAddr is the key in the indexedAddrs map. @@ -81,8 +87,6 @@ type indexedAddrSet struct { // The current default (zero) means to auto-select a random free port. const DefaultPort = 0 -const DefaultDERP = "https://derp.tailscale.com/derp" - var DefaultSTUN = []string{ "stun.l.google.com:19302", "stun3.l.google.com:19302", @@ -95,7 +99,6 @@ type Options struct { Port uint16 STUN []string - DERP string // EndpointsFunc optionally provides a func to be called when // endpoints change. The called func does not own the slice. 
@@ -136,14 +139,16 @@ func Listen(opts Options) (*Conn, error) { epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background()) c := &Conn{ pconn: new(RebindingUDPConn), + donec: make(chan struct{}), stunServers: append([]string{}, opts.STUN...), - derpServer: opts.DERP, startEpUpdate: make(chan struct{}, 1), epUpdateCtx: epUpdateCtx, epUpdateCancel: epUpdateCancel, epFunc: opts.endpointsFunc(), logf: log.Printf, indexedAddrs: make(map[udpAddr]indexedAddrSet), + derpRecvCh: make(chan derpReadResult), + udpRecvCh: make(chan udpReadResult), } c.ignoreSTUNPackets() c.pconn.Reset(packetConn.(*net.UDPConn)) @@ -355,55 +360,225 @@ func (c *Conn) LocalPort() uint16 { return uint16(laddr.Port) } -func (c *Conn) Send(b []byte, ep device.Endpoint) error { - a := ep.(*AddrSet) - +func shouldSprayPacket(b []byte) bool { + if len(b) < 4 { + return false + } msgType := binary.LittleEndian.Uint32(b[:4]) switch msgType { - case device.MessageInitiationType, device.MessageResponseType, device.MessageCookieReplyType: - // Part of the wireguard handshake. - // Send to every potential endpoint we have for a peer. - a.mu.Lock() - roamAddr := a.roamAddr - a.mu.Unlock() - - var err error - var success bool - if roamAddr != nil { - _, err = c.pconn.WriteTo(b, roamAddr) - if err == nil { - success = true - } + case device.MessageInitiationType, + device.MessageResponseType, + device.MessageCookieReplyType: // TODO: necessary? + return true + } + return false +} + +// appendDests appends to dsts the destinations that b should be +// written to in order to reach as. Some of the returned UDPAddrs may +// be fake addrs representing DERP servers. +// +// It also returns as's current roamAddr, if any. 
+func appendDests(dsts []*net.UDPAddr, as *AddrSet, b []byte) (_ []*net.UDPAddr, roamAddr *net.UDPAddr) { + spray := shouldSprayPacket(b) + + as.mu.Lock() + defer as.mu.Unlock() + + roamAddr = as.roamAddr + if roamAddr != nil { + dsts = append(dsts, roamAddr) + if !spray { + return dsts, roamAddr } - for i := len(a.addrs) - 1; i >= 0; i-- { - addr := &a.addrs[i] - _, err = c.pconn.WriteTo(b, addr) - if err == nil { - success = true - } + } + for i := len(as.addrs) - 1; i >= 0; i-- { + addr := &as.addrs[i] + if spray || as.curAddr == -1 || as.curAddr == i { + dsts = append(dsts, addr) + } + if !spray && len(dsts) != 0 { + break + } + } + return dsts, roamAddr +} + +var errNoDestinations = errors.New("magicsock: no destinations") + +func (c *Conn) Send(b []byte, ep device.Endpoint) error { + as := ep.(*AddrSet) + + var addrBuf [8]*net.UDPAddr + dsts, roamAddr := appendDests(addrBuf[:0], as, b) + + if len(dsts) == 0 { + return errNoDestinations + } + + var success bool + var ret error + for _, addr := range dsts { + err := c.sendAddr(addr, as.publicKey, b) + if err == nil { + success = true + } else if ret == nil { + ret = err + } + if err != nil && addr != roamAddr { + log.Printf("magicsock: Conn.Send(%v): %v", addr, err) } + } + if success { + return nil + } + return ret +} - if msgType == device.MessageInitiationType { - // Send initial handshake messages via DERP. - c.derpMu.Lock() - derp := c.derp - c.derpMu.Unlock() +var errConnClosed = errors.New("Conn closed") - if derp != nil { - if err := derp.Send(a.publicKey, b); err != nil { - log.Printf("derp send failed: %v", err) - } +var errDropDerpPacket = errors.New("too many DERP packets queued; dropping") + +// sendAddr sends packet b to addr, which is either a real UDP address +// or a fake UDP address representing a DERP server (see derpmap.go). +// The provided public key identifies the recipient. 
+func (c *Conn) sendAddr(addr *net.UDPAddr, pubKey key.Public, b []byte) error { + if ch := c.derpWriteChanOfAddr(addr); ch != nil { + errc := make(chan error, 1) + select { + case <-c.donec: + return errConnClosed + case ch <- derpWriteRequest{addr, pubKey, b, errc}: + select { + case <-c.donec: + return errConnClosed + case err := <-errc: + return err // usually nil } + default: + // Too many writes queued. Drop packet. + return errDropDerpPacket } + } + _, err := c.pconn.WriteTo(b, addr) + return err +} - if success { +// bufferedDerpWritesBeforeDrop is how many packet writes can be +// queued up for the DERP client to write on the wire before we start +// dropping. +// +// TODO: this is currently arbitrary. Figure out something better? +const bufferedDerpWritesBeforeDrop = 4 + +// derpWriteChanOfAddr returns a DERP write channel for fake UDP addresses that +// represent DERP servers, creating the client and channel as necessary. For real UDP +// addresses, it returns nil. +func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest { + if !addr.IP.Equal(derpMagicIP) { + return nil + } + c.derpMu.Lock() + defer c.derpMu.Unlock() + ch, ok := c.derpWriteCh[addr.Port] + if !ok { + if c.derpWriteCh == nil { + c.derpWriteCh = make(map[int]chan<- derpWriteRequest) + c.derpConn = make(map[int]*derphttp.Client) + } + host := derpHost(addr.Port) + dc, err := derphttp.NewClient(c.privateKey, "https://"+host+"/derp", log.Printf) + if err != nil { + log.Printf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err) return nil } + + bidiCh := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop) + ch = bidiCh + c.derpConn[addr.Port] = dc + c.derpWriteCh[addr.Port] = ch + go c.runDerpReader(addr, dc) + go c.runDerpWriter(addr, dc, bidiCh) } + return ch +} - // Write to the highest-priority address we have seen so far. 
- _, err := c.pconn.WriteTo(b, a.dst()) - return err +// derpReadResult is the type sent by runDerpReader to ReceiveIPv4 +// when a DERP packet is available. +type derpReadResult struct { + derpAddr *net.UDPAddr + n int // length of data received + + // copyBuf is called to copy the data to dst. It returns how + // much data was copied, which will be n if dst is large + // enough. + copyBuf func(dst []byte) int +} + +// runDerpReader runs in a goroutine for the life of a DERP +// connection, handling received packets. +func (c *Conn) runDerpReader(derpFakeAddr *net.UDPAddr, dc *derphttp.Client) { + didCopy := make(chan struct{}, 1) + var buf [64 << 10]byte + var bufValid int // bytes in buf that are valid + copyFn := func(dst []byte) int { + n := copy(dst, buf[:bufValid]) + didCopy <- struct{}{} + return n + } + + for { + var err error // no := on next line to not shadow bufValid + bufValid, err = dc.Recv(buf[:]) + if err != nil { + if err == derphttp.ErrClientClosed { + return + } + select { + case <-c.donec: + return + default: + } + log.Printf("derp.Recv: %v", err) + time.Sleep(250 * time.Millisecond) + continue + } + log.Printf("got derp %v packet: %q", derpFakeAddr, buf[:bufValid]) + select { + case <-c.donec: + return + case c.derpRecvCh <- derpReadResult{derpFakeAddr, bufValid, copyFn}: + <-didCopy + } + } +} + +type derpWriteRequest struct { + addr *net.UDPAddr + pubKey key.Public + b []byte + errc chan<- error +} + +// runDerpWriter runs in a goroutine for the life of a DERP +// connection, sending queued packets. 
+func (c *Conn) runDerpWriter(derpFakeAddr *net.UDPAddr, dc *derphttp.Client, ch <-chan derpWriteRequest) { + for { + select { + case <-c.donec: + return + case wr := <-ch: + err := dc.Send(wr.pubKey, wr.b) + if err != nil { + log.Printf("magicsock: derp.Send(%v): %v", wr.addr, err) + } + select { + case wr.errc <- err: + case <-c.donec: + return + } + } + } } func (c *Conn) findIndexedAddrSet(addr *net.UDPAddr) (addrSet *AddrSet, index int) { @@ -421,21 +596,75 @@ func (c *Conn) findIndexedAddrSet(addr *net.UDPAddr) (addrSet *AddrSet, index in return indAddr.addr, indAddr.index } +type udpReadResult struct { + n int + err error + addr *net.UDPAddr +} + +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancellation of network operations. +var aLongTimeAgo = time.Unix(233431200, 0) + func (c *Conn) ReceiveIPv4(b []byte) (n int, ep device.Endpoint, addr *net.UDPAddr, err error) { - // Read a packet, and process any STUN packets before returning. - for { - var pAddr net.Addr - n, pAddr, err = c.pconn.ReadFrom(b) - if err != nil { - return n, nil, nil, err + go func() { + // Read a packet, and process any STUN packets before returning. + for { + var pAddr net.Addr + n, pAddr, err = c.pconn.ReadFrom(b) + if err != nil { + select { + case c.udpRecvCh <- udpReadResult{err: err}: + case <-c.donec: + } + return + } + if stun.Is(b[:n]) { + c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b, addr) + continue + } + + addr := pAddr.(*net.UDPAddr) + addr.IP = addr.IP.To4() + select { + case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: + case <-c.donec: + } + return } - addr = pAddr.(*net.UDPAddr) - addr.IP = addr.IP.To4() + }() - if !stun.Is(b[:n]) { - break + select { + case dm := <-c.derpRecvCh: + // Cancel the pconn read goroutine + c.pconn.SetReadDeadline(aLongTimeAgo) + select { + case <-c.udpRecvCh: + // It's likely an error, since we just canceled the read. 
+ // But there's a small window where the pconn.ReadFrom could've + // succeeded but not yet sent, and we got into the derp recv path + // first. In that case this udpReadResult is a real non-err packet + // and we need to choose which to use. Currently, arbitrarily, we + // select DERP and discard this result entirely. + // The main point of this receive, though, is to make sure that the goroutine + // is done with our b []byte buf. + c.pconn.SetReadDeadline(time.Time{}) + case <-c.donec: + return 0, nil, nil, errors.New("Conn closed") } - c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b, addr) + n, addr = dm.n, dm.derpAddr + ncopy := dm.copyBuf(b) + if ncopy != n { + err = fmt.Errorf("received DERP packet of length %d that's too big for WireGuard ReceiveIPv4 buf size %d", n, ncopy) + log.Printf("magicsock: %v", err) + return 0, nil, nil, err + } + + case um := <-c.udpRecvCh: + if um.err != nil { + return 0, nil, nil, err + } + n, addr = um.n, um.addr } addrSet, _ := c.findIndexedAddrSet(addr) @@ -455,51 +684,23 @@ func (c *Conn) ReceiveIPv6(buff []byte) (int, device.Endpoint, *net.UDPAddr, err } func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { - if c.derpServer == "" { - return nil - } - - derp, err := derphttp.NewClient(key.Private(privateKey), c.derpServer, log.Printf) - if err != nil { - return err - } - go func() { - var b [64 << 10]byte - for { - n, err := derp.Recv(b[:]) - if err != nil { - if err == derphttp.ErrClientClosed { - return - } - log.Printf("derp.Recv: %v", err) - time.Sleep(250 * time.Millisecond) - } - - c.reSTUN() - - addr := c.pconn.LocalAddr() - if _, err := c.pconn.WriteToUDP(b[:n], addr); err != nil { - log.Printf("%v", err) - } - } - }() - - c.derpMu.Lock() - if c.derp != nil { - if err := c.derp.Close(); err != nil { - log.Printf("derp.Close: %v", err) - } - } - c.derp = derp - c.derpMu.Unlock() - + c.privateKey = key.Private(privateKey) return nil } func (c *Conn) SetMark(value uint32) error { return 
nil } func (c *Conn) Close() error { + select { + case <-c.donec: + return nil + default: + } + close(c.donec) c.epUpdateCancel() + for _, dc := range c.derpConn { + dc.Close() + } return c.pconn.Close() } @@ -543,8 +744,16 @@ type AddrSet struct { publicKey key.Public // peer public key used for DERP communication addrs []net.UDPAddr // ordered priority list provided by wgengine - mu sync.Mutex // guards roamAddr and curAddr - roamAddr *net.UDPAddr // peer addr determined from incoming packets + mu sync.Mutex // guards roamAddr and curAddr + + // roamAddr is non-nil if/when we receive a correctly signed + // WireGuard packet from an unexpected address. If so, we + // remember it and send responses there in the future, but + // this should hopefully never be used (or at least used + // rarely) in the case that all the components of Tailscale + // are correctly learning/sharing the network map details. + roamAddr *net.UDPAddr + // curAddr is an index into addrs of the highest-priority // address a valid packet has been received from so far. // If no valid packet from addrs has been received, curAddr is -1. 
@@ -641,7 +850,7 @@ func (a *AddrSet) UpdateDst(new *net.UDPAddr) error { a.roamAddr = new case a.roamAddr != nil: - log.Printf("magicsock: rx %s from known %s (%d), replacs roaming address %s", pk, new, index, a.roamAddr) + log.Printf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr) a.roamAddr = nil a.curAddr = index @@ -814,6 +1023,12 @@ func (c *RebindingUDPConn) Close() error { return c.pconn.Close() } +func (c *RebindingUDPConn) SetReadDeadline(t time.Time) { + c.mu.Lock() + defer c.mu.Unlock() + c.pconn.SetReadDeadline(t) +} + func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { for { c.mu.Lock() diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index d0635a2fe..364041d98 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -71,3 +71,9 @@ func pickPort(t *testing.T) uint16 { defer conn.Close() return uint16(conn.LocalAddr().(*net.UDPAddr).Port) } + +func TestDerpIPConstant(t *testing.T) { + if derpMagicIPStr != derpMagicIP.String() { + t.Errorf("str %q != IP %v", derpMagicIPStr, derpMagicIP) + } +}