cmd/derper: support forwarding packets amongst set of peer DERP servers

Updates #388

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
reviewable/pr396/r8
Brad Fitzpatrick 5 years ago committed by Brad Fitzpatrick
parent e441d3218e
commit 1cb7dab881

@ -134,10 +134,9 @@ func main() {
s.SetMeshKey(key) s.SetMeshKey(key)
log.Printf("DERP mesh key configured") log.Printf("DERP mesh key configured")
} }
if err := startMesh(s); err != nil {
// TODO(bradfitz): parse & use the *meshWith log.Fatalf("startMesh: %v", err)
_ = *meshWith }
expvar.Publish("derp", s.ExpVar()) expvar.Publish("derp", s.ExpVar())
// Create our own mux so we don't expose /debug/ stuff to the world. // Create our own mux so we don't expose /debug/ stuff to the world.

@ -0,0 +1,147 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"errors"
"fmt"
"log"
"strings"
"sync"
"time"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
func startMesh(s *derp.Server) error {
if *meshWith == "" {
return nil
}
if !s.HasMeshKey() {
return errors.New("--mesh-with requires --mesh-psk-file")
}
for _, host := range strings.Split(*meshWith, ",") {
if err := startMeshWithHost(s, host); err != nil {
return err
}
}
return nil
}
func startMeshWithHost(s *derp.Server, host string) error {
logf := logger.WithPrefix(log.Printf, fmt.Sprintf("mesh(%q): ", host))
c, err := derphttp.NewClient(s.PrivateKey(), "https://"+host+"/derp", logf)
if err != nil {
return err
}
c.MeshKey = s.MeshKey()
go runMeshClient(s, host, c, logf)
return nil
}
func runMeshClient(s *derp.Server, host string, c *derphttp.Client, logf logger.Logf) {
const retryInterval = 5 * time.Second
const statusInterval = 10 * time.Second
var (
mu sync.Mutex
present = map[key.Public]bool{}
loggedConnected = false
)
clear := func() {
mu.Lock()
defer mu.Unlock()
if len(present) == 0 {
return
}
logf("reconnected; clearing %d forwarding mappings", len(present))
for k := range present {
s.RemovePacketForwarder(k, c)
}
present = map[key.Public]bool{}
}
lastConnGen := 0
lastStatus := time.Now()
logConnectedLocked := func() {
if loggedConnected {
return
}
logf("connected; %d peers", len(present))
loggedConnected = true
}
const logConnectedDelay = 200 * time.Millisecond
timer := time.AfterFunc(2*time.Second, func() {
mu.Lock()
defer mu.Unlock()
logConnectedLocked()
})
defer timer.Stop()
updatePeer := func(k key.Public, isPresent bool) {
if isPresent {
s.AddPacketForwarder(k, c)
} else {
s.RemovePacketForwarder(k, c)
}
mu.Lock()
defer mu.Unlock()
if isPresent {
present[k] = true
if !loggedConnected {
timer.Reset(logConnectedDelay)
}
} else {
// If we got a peerGone message, that means the initial connection's
// flood of peerPresent messages is done, so we can log already:
logConnectedLocked()
delete(present, k)
}
}
for {
err := c.WatchConnectionChanges()
if err != nil {
clear()
logf("WatchConnectionChanges: %v", err)
time.Sleep(retryInterval)
continue
}
if c.ServerPublicKey() == s.PublicKey() {
logf("detected self-connect; ignoring host")
return
}
for {
var buf [64 << 10]byte
m, connGen, err := c.RecvDetail(buf[:])
if err != nil {
clear()
logf("Recv: %v", err)
time.Sleep(retryInterval)
break
}
if connGen != lastConnGen {
lastConnGen = connGen
clear()
}
switch m := m.(type) {
case derp.PeerPresentMessage:
updatePeer(key.Public(m), true)
case derp.PeerGoneMessage:
updatePeer(key.Public(m), false)
default:
continue
}
if now := time.Now(); now.Sub(lastStatus) > statusInterval {
lastStatus = now
logf("%d peers", len(present))
}
}
}
}

@ -72,6 +72,7 @@ const (
frameClientInfo = frameType(0x02) // 32B pub key + 24B nonce + naclbox(json) frameClientInfo = frameType(0x02) // 32B pub key + 24B nonce + naclbox(json)
frameServerInfo = frameType(0x03) // 24B nonce + naclbox(json) frameServerInfo = frameType(0x03) // 24B nonce + naclbox(json)
frameSendPacket = frameType(0x04) // 32B dest pub key + packet bytes frameSendPacket = frameType(0x04) // 32B dest pub key + packet bytes
frameForwardPacket = frameType(0x0a) // 32B src pub key + 32B dst pub key + packet bytes
frameRecvPacket = frameType(0x05) // v0/1: packet bytes, v2: 32B src pub key + packet bytes frameRecvPacket = frameType(0x05) // v0/1: packet bytes, v2: 32B src pub key + packet bytes
frameKeepAlive = frameType(0x06) // no payload, no-op (to be replaced with ping/pong) frameKeepAlive = frameType(0x06) // no payload, no-op (to be replaced with ping/pong)
frameNotePreferred = frameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node frameNotePreferred = frameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node

@ -19,6 +19,7 @@ import (
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
// Client is a DERP client.
type Client struct { type Client struct {
serverKey key.Public // of the DERP server; not a machine or node key serverKey key.Public // of the DERP server; not a machine or node key
privateKey key.Private privateKey key.Private
@ -170,6 +171,9 @@ func (c *Client) sendClientKey() error {
return writeFrame(c.bw, frameClientInfo, buf) return writeFrame(c.bw, frameClientInfo, buf)
} }
// ServerPublicKey returns the server's public key.
func (c *Client) ServerPublicKey() key.Public { return c.serverKey }
// Send sends a packet to the Tailscale node identified by dstKey. // Send sends a packet to the Tailscale node identified by dstKey.
// //
// It is an error if the packet is larger than 64KB. // It is an error if the packet is larger than 64KB.
@ -201,6 +205,40 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
return c.bw.Flush() return c.bw.Flush()
} }
func (c *Client) ForwardPacket(srcKey, dstKey key.Public, pkt []byte) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("derp.ForwardPacket: %w", err)
}
}()
if len(pkt) > MaxPacketSize {
return fmt.Errorf("packet too big: %d", len(pkt))
}
c.wmu.Lock()
defer c.wmu.Unlock()
timer := time.AfterFunc(5*time.Second, c.writeTimeoutFired)
defer timer.Stop()
if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil {
return err
}
if _, err := c.bw.Write(srcKey[:]); err != nil {
return err
}
if _, err := c.bw.Write(dstKey[:]); err != nil {
return err
}
if _, err := c.bw.Write(pkt); err != nil {
return err
}
return c.bw.Flush()
}
func (c *Client) writeTimeoutFired() { c.nc.Close() }
// NotePreferred sends a packet that tells the server whether this // NotePreferred sends a packet that tells the server whether this
// client is the user's preferred server. This is only used in the // client is the user's preferred server. This is only used in the
// server for stats. // server for stats.

@ -50,30 +50,46 @@ type Server struct {
meshKey string meshKey string
// Counters: // Counters:
packetsSent, bytesSent expvar.Int packetsSent, bytesSent expvar.Int
packetsRecv, bytesRecv expvar.Int packetsRecv, bytesRecv expvar.Int
packetsDropped expvar.Int packetsDropped expvar.Int
packetsDroppedReason metrics.LabelMap packetsDroppedReason metrics.LabelMap
packetsDroppedUnknown *expvar.Int // unknown dst pubkey packetsDroppedUnknown *expvar.Int // unknown dst pubkey
packetsDroppedGone *expvar.Int // dst conn shutting down packetsDroppedFwdUnknown *expvar.Int // unknown dst pubkey on forward
packetsDroppedQueueHead *expvar.Int // queue full, drop head packet packetsDroppedGone *expvar.Int // dst conn shutting down
packetsDroppedQueueTail *expvar.Int // queue full, drop tail packet packetsDroppedQueueHead *expvar.Int // queue full, drop head packet
packetsDroppedWrite *expvar.Int // error writing to dst conn packetsDroppedQueueTail *expvar.Int // queue full, drop tail packet
peerGoneFrames expvar.Int // number of peer gone frames sent packetsDroppedWrite *expvar.Int // error writing to dst conn
accepts expvar.Int packetsForwardedOut expvar.Int
curClients expvar.Int packetsForwardedIn expvar.Int
curHomeClients expvar.Int // ones with preferred peerGoneFrames expvar.Int // number of peer gone frames sent
clientsReplaced expvar.Int accepts expvar.Int
unknownFrames expvar.Int curClients expvar.Int
homeMovesIn expvar.Int // established clients announce home server moves in curHomeClients expvar.Int // ones with preferred
homeMovesOut expvar.Int // established clients announce home server moves out clientsReplaced expvar.Int
unknownFrames expvar.Int
homeMovesIn expvar.Int // established clients announce home server moves in
homeMovesOut expvar.Int // established clients announce home server moves out
multiForwarderCreated expvar.Int
multiForwarderDeleted expvar.Int
mu sync.Mutex mu sync.Mutex
closed bool closed bool
netConns map[Conn]chan struct{} // chan is closed when conn closes netConns map[Conn]chan struct{} // chan is closed when conn closes
clients map[key.Public]*sclient clients map[key.Public]*sclient
clientsEver map[key.Public]bool // never deleted from, for stats; fine for now clientsEver map[key.Public]bool // never deleted from, for stats; fine for now
watchers map[*sclient]bool // mesh peer -> true watchers map[*sclient]bool // mesh peer -> true
clientsMesh map[key.Public]PacketForwarder // clients connected to mesh peers; nil means only in clients, not remote
}
// PacketForwarder is something that can forward packets.
//
// It's mostly an inteface for circular dependency reasons; the
// typical implementation is derphttp.Client. The other implementation
// is a multiForwarder, which this package creates as needed if a
// public key gets more than one PacketForwarder registered for it.
type PacketForwarder interface {
ForwardPacket(src, dst key.Public, payload []byte) error
} }
// Conn is the subset of the underlying net.Conn the DERP Server needs. // Conn is the subset of the underlying net.Conn the DERP Server needs.
@ -101,11 +117,13 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server {
packetsDroppedReason: metrics.LabelMap{Label: "reason"}, packetsDroppedReason: metrics.LabelMap{Label: "reason"},
clients: make(map[key.Public]*sclient), clients: make(map[key.Public]*sclient),
clientsEver: make(map[key.Public]bool), clientsEver: make(map[key.Public]bool),
clientsMesh: map[key.Public]PacketForwarder{},
netConns: make(map[Conn]chan struct{}), netConns: make(map[Conn]chan struct{}),
memSys0: ms.Sys, memSys0: ms.Sys,
watchers: map[*sclient]bool{}, watchers: map[*sclient]bool{},
} }
s.packetsDroppedUnknown = s.packetsDroppedReason.Get("unknown_dest") s.packetsDroppedUnknown = s.packetsDroppedReason.Get("unknown_dest")
s.packetsDroppedFwdUnknown = s.packetsDroppedReason.Get("unknown_dest_on_fwd")
s.packetsDroppedGone = s.packetsDroppedReason.Get("gone") s.packetsDroppedGone = s.packetsDroppedReason.Get("gone")
s.packetsDroppedQueueHead = s.packetsDroppedReason.Get("queue_head") s.packetsDroppedQueueHead = s.packetsDroppedReason.Get("queue_head")
s.packetsDroppedQueueTail = s.packetsDroppedReason.Get("queue_tail") s.packetsDroppedQueueTail = s.packetsDroppedReason.Get("queue_tail")
@ -210,6 +228,9 @@ func (s *Server) registerClient(c *sclient) {
} }
s.clients[c.key] = c s.clients[c.key] = c
s.clientsEver[c.key] = true s.clientsEver[c.key] = true
if _, ok := s.clientsMesh[c.key]; !ok {
s.clientsMesh[c.key] = nil // just for varz of total users in cluster
}
s.curClients.Add(1) s.curClients.Add(1)
s.broadcastPeerStateChangeLocked(c.key, true) s.broadcastPeerStateChangeLocked(c.key, true)
} }
@ -238,6 +259,9 @@ func (s *Server) unregisterClient(c *sclient) {
if c.canMesh { if c.canMesh {
delete(s.watchers, c) delete(s.watchers, c)
} }
if v, ok := s.clientsMesh[c.key]; ok && v == nil {
delete(s.clientsMesh, c.key)
}
s.broadcastPeerStateChangeLocked(c.key, false) s.broadcastPeerStateChangeLocked(c.key, false)
s.curClients.Add(-1) s.curClients.Add(-1)
@ -271,8 +295,6 @@ func (s *Server) addWatcher(c *sclient) {
if c.key == s.publicKey { if c.key == s.publicKey {
// We're connecting to ourself. Do nothing. // We're connecting to ourself. Do nothing.
// TODO(bradfitz): have client notice and disconnect
// so an idle TCP connection isn't kept open.
return return
} }
@ -378,6 +400,8 @@ func (c *sclient) run(ctx context.Context) error {
err = c.handleFrameNotePreferred(ft, fl) err = c.handleFrameNotePreferred(ft, fl)
case frameSendPacket: case frameSendPacket:
err = c.handleFrameSendPacket(ft, fl) err = c.handleFrameSendPacket(ft, fl)
case frameForwardPacket:
err = c.handleFrameForwardPacket(ft, fl)
case frameWatchConns: case frameWatchConns:
err = c.handleFrameWatchConns(ft, fl) err = c.handleFrameWatchConns(ft, fl)
default: default:
@ -417,6 +441,42 @@ func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error {
return nil return nil
} }
// handleFrameForwardPacket reads a "forward packet" frame from the client
// (which must be a trusted client, a peer in our mesh).
func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
if !c.canMesh {
return fmt.Errorf("insufficient permissions")
}
s := c.s
srcKey, dstKey, contents, err := s.recvForwardPacket(c.br, fl)
if err != nil {
return fmt.Errorf("client %x: recvForwardPacket: %v", c.key, err)
}
s.packetsForwardedIn.Add(1)
s.mu.Lock()
dst := s.clients[dstKey]
// TODO(bradfitz): think about the sentTo/Issue 150 optimization
// in the context of DERP meshes.
s.mu.Unlock()
if dst == nil {
s.packetsDropped.Add(1)
s.packetsDroppedFwdUnknown.Add(1)
if debug {
c.logf("dropping forwarded packet for unknown %x", dstKey)
}
return nil
}
return c.sendPkt(dst, pkt{
bs: contents,
src: srcKey,
})
}
// handleFrameSendPacket reads a "send packet" frame from the client.
func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
s := c.s s := c.s
@ -425,9 +485,12 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
return fmt.Errorf("client %x: recvPacket: %v", c.key, err) return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
} }
var fwd PacketForwarder
s.mu.Lock() s.mu.Lock()
dst := s.clients[dstKey] dst := s.clients[dstKey]
if dst != nil { if dst == nil {
fwd = s.clientsMesh[dstKey]
} else {
// Track that we've sent to this peer, so if/when we // Track that we've sent to this peer, so if/when we
// disconnect first, the server can inform all our old // disconnect first, the server can inform all our old
// recipients that we're gone. (Issue 150 optimization) // recipients that we're gone. (Issue 150 optimization)
@ -436,6 +499,14 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
s.mu.Unlock() s.mu.Unlock()
if dst == nil { if dst == nil {
if fwd != nil {
s.packetsForwardedOut.Add(1)
if err := fwd.ForwardPacket(c.key, dstKey, contents); err != nil {
// TODO:
return nil
}
return nil
}
s.packetsDropped.Add(1) s.packetsDropped.Add(1)
s.packetsDroppedUnknown.Add(1) s.packetsDroppedUnknown.Add(1)
if debug { if debug {
@ -450,6 +521,13 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
if dst.info.Version >= protocolSrcAddrs { if dst.info.Version >= protocolSrcAddrs {
p.src = c.key p.src = c.key
} }
return c.sendPkt(dst, p)
}
func (c *sclient) sendPkt(dst *sclient, p pkt) error {
s := c.s
dstKey := dst.key
// Attempt to queue for sending up to 3 times. On each attempt, if // Attempt to queue for sending up to 3 times. On each attempt, if
// the queue is full, try to drop from queue head to prioritize // the queue is full, try to drop from queue head to prioritize
// fresher packets. // fresher packets.
@ -615,6 +693,29 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Publi
// zpub is the key.Public zero value. // zpub is the key.Public zero value.
var zpub key.Public var zpub key.Public
func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.Public, contents []byte, err error) {
if frameLen < keyLen*2 {
return zpub, zpub, nil, errors.New("short send packet frame")
}
if _, err := io.ReadFull(br, srcKey[:]); err != nil {
return zpub, zpub, nil, err
}
if _, err := io.ReadFull(br, dstKey[:]); err != nil {
return zpub, zpub, nil, err
}
packetLen := frameLen - keyLen*2
if packetLen > MaxPacketSize {
return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
}
contents = make([]byte, packetLen)
if _, err := io.ReadFull(br, contents); err != nil {
return zpub, zpub, nil, err
}
// TODO: was s.packetsRecv.Add(1)
// TODO: was s.bytesRecv.Add(int64(len(contents)))
return srcKey, dstKey, contents, nil
}
// sclient is a client connection to the server. // sclient is a client connection to the server.
// //
// (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go) // (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go)
@ -889,6 +990,108 @@ func (c *sclient) sendPacket(srcKey key.Public, contents []byte) (err error) {
return err return err
} }
// AddPacketForwarder registers fwd as a packet forwarder for dst.
// fwd must be comparable.
func (s *Server) AddPacketForwarder(dst key.Public, fwd PacketForwarder) {
s.mu.Lock()
defer s.mu.Unlock()
if prev, ok := s.clientsMesh[dst]; ok {
if prev == fwd {
// Duplicate registration of same forwarder. Ignore.
return
}
if m, ok := prev.(multiForwarder); ok {
if _, ok := m[fwd]; !ok {
// Duplicate registration of same forwarder in set; ignore.
return
}
m[fwd] = m.maxVal() + 1
return
}
// Otherwise, the existing value is not a set and not a dup, so make it a set.
fwd = multiForwarder{
prev: 1, // existed 1st, higher priority
fwd: 2, // the passed in fwd is in 2nd place
}
s.multiForwarderCreated.Add(1)
}
s.clientsMesh[dst] = fwd
}
// RemovePacketForwarder removes fwd as a packet forwarder for dst.
// fwd must be comparable.
func (s *Server) RemovePacketForwarder(dst key.Public, fwd PacketForwarder) {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.clientsMesh[dst]
if !ok {
return
}
if m, ok := v.(multiForwarder); ok {
if len(m) < 2 {
panic("unexpected")
}
delete(m, fwd)
// If fwd was in m and we no longer need to be a
// multiForwarder, replace the entry with the
// remaining PacketForwarder.
if len(m) == 1 {
var remain PacketForwarder
for k := range m {
remain = k
}
s.clientsMesh[dst] = remain
s.multiForwarderDeleted.Add(1)
}
return
}
if v != fwd {
// Delete of an entry that wasn't in the
// map. Harmless, so ignore.
// (This might happen if a user is moving around
// between nodes and/or the server sent duplicate
// connection change broadcasts.)
return
}
if _, isLocal := s.clients[dst]; isLocal {
s.clientsMesh[dst] = nil
} else {
delete(s.clientsMesh, dst)
}
}
// multiForwarder is a PacketForwarder that represents a set of
// forwarding options. It's used in the rare cases that a client is
// connected to multiple DERP nodes in a region. That shouldn't really
// happen except for perhaps during brief moments while the client is
// reconfiguring, in which case we don't want to forget where the
// client is. The map value is unique connection number; the lowest
// one has been seen the longest. It's used to make sure we forward
// packets consistently to the same node and don't pick randomly.
type multiForwarder map[PacketForwarder]uint8
func (m multiForwarder) maxVal() (max uint8) {
for _, v := range m {
if v > max {
max = v
}
}
return
}
func (m multiForwarder) ForwardPacket(src, dst key.Public, payload []byte) error {
var fwd PacketForwarder
var lowest uint8
for k, v := range m {
if fwd == nil || v < lowest {
fwd = k
lowest = v
}
}
return fwd.ForwardPacket(src, dst, payload)
}
func (s *Server) expVarFunc(f func() interface{}) expvar.Func { func (s *Server) expVarFunc(f func() interface{}) expvar.Func {
return expvar.Func(func() interface{} { return expvar.Func(func() interface{} {
s.mu.Lock() s.mu.Lock()
@ -905,6 +1108,8 @@ func (s *Server) ExpVar() expvar.Var {
m.Set("gauge_watchers", s.expVarFunc(func() interface{} { return len(s.watchers) })) m.Set("gauge_watchers", s.expVarFunc(func() interface{} { return len(s.watchers) }))
m.Set("gauge_current_connnections", &s.curClients) m.Set("gauge_current_connnections", &s.curClients)
m.Set("gauge_current_home_connnections", &s.curHomeClients) m.Set("gauge_current_home_connnections", &s.curHomeClients)
m.Set("gauge_clients_total", expvar.Func(func() interface{} { return len(s.clientsMesh) }))
m.Set("gauge_clients_remote", expvar.Func(func() interface{} { return len(s.clientsMesh) - len(s.clients) }))
m.Set("accepts", &s.accepts) m.Set("accepts", &s.accepts)
m.Set("clients_replaced", &s.clientsReplaced) m.Set("clients_replaced", &s.clientsReplaced)
m.Set("bytes_received", &s.bytesRecv) m.Set("bytes_received", &s.bytesRecv)
@ -917,5 +1122,9 @@ func (s *Server) ExpVar() expvar.Var {
m.Set("home_moves_in", &s.homeMovesIn) m.Set("home_moves_in", &s.homeMovesIn)
m.Set("home_moves_out", &s.homeMovesOut) m.Set("home_moves_out", &s.homeMovesOut)
m.Set("peer_gone_frames", &s.peerGoneFrames) m.Set("peer_gone_frames", &s.peerGoneFrames)
m.Set("packets_forwarded_out", &s.packetsForwardedOut)
m.Set("packets_forwarded_in", &s.packetsForwardedIn)
m.Set("multiforwarder_created", &s.multiForwarderCreated)
m.Set("multiforwarder_deleted", &s.multiForwarderDeleted)
return m return m
} }

@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"reflect"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -619,3 +620,116 @@ func TestWatch(t *testing.T) {
w2.wantGone(t, c1.pub) w2.wantGone(t, c1.pub)
w3.wantGone(t, c1.pub) w3.wantGone(t, c1.pub)
} }
type testFwd int
func (testFwd) ForwardPacket(key.Public, key.Public, []byte) error { panic("not called in tests") }
func pubAll(b byte) (ret key.Public) {
for i := range ret {
ret[i] = b
}
return
}
func TestForwarderRegistration(t *testing.T) {
s := &Server{
clients: make(map[key.Public]*sclient),
clientsMesh: map[key.Public]PacketForwarder{},
}
want := func(want map[key.Public]PacketForwarder) {
t.Helper()
if got := s.clientsMesh; !reflect.DeepEqual(got, want) {
t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want)
}
}
wantCounter := func(c *expvar.Int, want int) {
t.Helper()
if got := c.Value(); got != int64(want) {
t.Errorf("counter = %v; want %v", got, want)
}
}
u1 := pubAll(1)
u2 := pubAll(2)
u3 := pubAll(3)
s.AddPacketForwarder(u1, testFwd(1))
s.AddPacketForwarder(u2, testFwd(2))
want(map[key.Public]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Verify a remove of non-registered forwarder is no-op.
s.RemovePacketForwarder(u2, testFwd(999))
want(map[key.Public]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Verify a remove of non-registered user is no-op.
s.RemovePacketForwarder(u3, testFwd(1))
want(map[key.Public]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Actual removal.
s.RemovePacketForwarder(u2, testFwd(2))
want(map[key.Public]PacketForwarder{
u1: testFwd(1),
})
// Adding a dup for a user.
wantCounter(&s.multiForwarderCreated, 0)
s.AddPacketForwarder(u1, testFwd(100))
want(map[key.Public]PacketForwarder{
u1: multiForwarder{
testFwd(1): 1,
testFwd(100): 2,
},
})
wantCounter(&s.multiForwarderCreated, 1)
// Removing a forwarder in a multi set that doesn't exist; does nothing.
s.RemovePacketForwarder(u1, testFwd(55))
want(map[key.Public]PacketForwarder{
u1: multiForwarder{
testFwd(1): 1,
testFwd(100): 2,
},
})
// Removing a forwarder in a multi set that does exist should collapse it away
// from being a multiForwarder.
wantCounter(&s.multiForwarderDeleted, 0)
s.RemovePacketForwarder(u1, testFwd(1))
want(map[key.Public]PacketForwarder{
u1: testFwd(100),
})
wantCounter(&s.multiForwarderDeleted, 1)
// Removing an entry for a client that's still connected locally should result
// in a nil forwarder.
u1c := &sclient{
key: u1,
logf: logger.Discard,
}
s.clients[u1] = u1c
s.RemovePacketForwarder(u1, testFwd(100))
want(map[key.Public]PacketForwarder{
u1: nil,
})
// But once that client disconnects, it should go away.
s.unregisterClient(u1c)
want(map[key.Public]PacketForwarder{})
// But if it already has a forwarder, it's not removed.
s.AddPacketForwarder(u1, testFwd(2))
s.unregisterClient(u1c)
want(map[key.Public]PacketForwarder{
u1: testFwd(2),
})
}

@ -114,6 +114,9 @@ func (c *Client) Connect(ctx context.Context) error {
} }
// ServerPublicKey returns the server's public key. // ServerPublicKey returns the server's public key.
//
// It only returns a non-zero value once a connection has succeeded
// from an earlier call.
func (c *Client) ServerPublicKey() key.Public { func (c *Client) ServerPublicKey() key.Public {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -293,6 +296,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
} }
} }
c.serverPubKey = derpClient.ServerPublicKey()
c.client = derpClient c.client = derpClient
c.netConn = tcpConn c.netConn = tcpConn
c.connGen++ c.connGen++
@ -484,6 +488,17 @@ func (c *Client) Send(dstKey key.Public, b []byte) error {
return err return err
} }
func (c *Client) ForwardPacket(from, to key.Public, b []byte) error {
client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket")
if err != nil {
return err
}
if err := client.ForwardPacket(from, to, b); err != nil {
c.closeForReconnect(client)
}
return err
}
// NotePreferred notes whether this Client is the caller's preferred // NotePreferred notes whether this Client is the caller's preferred
// (home) DERP node. It's only used for stats. // (home) DERP node. It's only used for stats.
func (c *Client) NotePreferred(v bool) { func (c *Client) NotePreferred(v bool) {

Loading…
Cancel
Save