From 1cb7dab8810df2b1be416926bce135e937f52afb Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 3 Jun 2020 14:42:20 -0700 Subject: [PATCH] cmd/derper: support forwarding packets amongst set of peer DERP servers Updates #388 Signed-off-by: Brad Fitzpatrick --- cmd/derper/derper.go | 7 +- cmd/derper/mesh.go | 147 ++++++++++++++++++ derp/derp.go | 1 + derp/derp_client.go | 38 +++++ derp/derp_server.go | 253 ++++++++++++++++++++++++++++--- derp/derp_test.go | 114 ++++++++++++++ derp/derphttp/derphttp_client.go | 15 ++ 7 files changed, 549 insertions(+), 26 deletions(-) create mode 100644 cmd/derper/mesh.go diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index edbd68462..ef7a56f52 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -134,10 +134,9 @@ func main() { s.SetMeshKey(key) log.Printf("DERP mesh key configured") } - - // TODO(bradfitz): parse & use the *meshWith - _ = *meshWith - + if err := startMesh(s); err != nil { + log.Fatalf("startMesh: %v", err) + } expvar.Publish("derp", s.ExpVar()) // Create our own mux so we don't expose /debug/ stuff to the world. diff --git a/cmd/derper/mesh.go b/cmd/derper/mesh.go new file mode 100644 index 000000000..b5b5f2190 --- /dev/null +++ b/cmd/derper/mesh.go @@ -0,0 +1,147 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "errors" + "fmt" + "log" + "strings" + "sync" + "time" + + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/types/key" + "tailscale.com/types/logger" +) + +func startMesh(s *derp.Server) error { + if *meshWith == "" { + return nil + } + if !s.HasMeshKey() { + return errors.New("--mesh-with requires --mesh-psk-file") + } + for _, host := range strings.Split(*meshWith, ",") { + if err := startMeshWithHost(s, host); err != nil { + return err + } + } + return nil +} + +func startMeshWithHost(s *derp.Server, host string) error { + logf := logger.WithPrefix(log.Printf, fmt.Sprintf("mesh(%q): ", host)) + c, err := derphttp.NewClient(s.PrivateKey(), "https://"+host+"/derp", logf) + if err != nil { + return err + } + c.MeshKey = s.MeshKey() + go runMeshClient(s, host, c, logf) + return nil +} + +func runMeshClient(s *derp.Server, host string, c *derphttp.Client, logf logger.Logf) { + const retryInterval = 5 * time.Second + const statusInterval = 10 * time.Second + var ( + mu sync.Mutex + present = map[key.Public]bool{} + loggedConnected = false + ) + clear := func() { + mu.Lock() + defer mu.Unlock() + if len(present) == 0 { + return + } + logf("reconnected; clearing %d forwarding mappings", len(present)) + for k := range present { + s.RemovePacketForwarder(k, c) + } + present = map[key.Public]bool{} + } + lastConnGen := 0 + lastStatus := time.Now() + logConnectedLocked := func() { + if loggedConnected { + return + } + logf("connected; %d peers", len(present)) + loggedConnected = true + } + + const logConnectedDelay = 200 * time.Millisecond + timer := time.AfterFunc(2*time.Second, func() { + mu.Lock() + defer mu.Unlock() + logConnectedLocked() + }) + defer timer.Stop() + + updatePeer := func(k key.Public, isPresent bool) { + if isPresent { + s.AddPacketForwarder(k, c) + } else { + s.RemovePacketForwarder(k, c) + } + + mu.Lock() + defer mu.Unlock() + if isPresent { + present[k] = true + if !loggedConnected { + timer.Reset(logConnectedDelay) + } + } else { + // If we got a peerGone message, that means the initial connection's + // flood of peerPresent messages is done, so we can log already: + logConnectedLocked() + delete(present, k) + } + } + + for { + err := c.WatchConnectionChanges() + if err != nil { + clear() + logf("WatchConnectionChanges: %v", err) + time.Sleep(retryInterval) + continue + } + + if c.ServerPublicKey() == s.PublicKey() { + logf("detected self-connect; ignoring host") + return + } + for { + var buf [64 << 10]byte + m, connGen, err := c.RecvDetail(buf[:]) + if err != nil { + clear() + logf("Recv: %v", err) + time.Sleep(retryInterval) + break + } + if connGen != lastConnGen { + lastConnGen = connGen + clear() + } + switch m := m.(type) { + case derp.PeerPresentMessage: + updatePeer(key.Public(m), true) + case derp.PeerGoneMessage: + updatePeer(key.Public(m), false) + default: + continue + } + if now := time.Now(); now.Sub(lastStatus) > statusInterval { + lastStatus = now + logf("%d peers", len(present)) + } + } + } +} diff --git a/derp/derp.go b/derp/derp.go index 08e4f20ae..9616a0814 100644 --- a/derp/derp.go +++ b/derp/derp.go @@ -72,6 +72,7 @@ const ( frameClientInfo = frameType(0x02) // 32B pub key + 24B nonce + naclbox(json) frameServerInfo = frameType(0x03) // 24B nonce + naclbox(json) frameSendPacket = frameType(0x04) // 32B dest pub key + packet bytes + frameForwardPacket = frameType(0x0a) // 32B src pub key + 32B dst pub key + packet bytes frameRecvPacket = frameType(0x05) // v0/1: packet bytes, v2: 32B src pub key + packet bytes frameKeepAlive = frameType(0x06) // no payload, no-op (to be replaced with ping/pong) frameNotePreferred = frameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node diff --git a/derp/derp_client.go b/derp/derp_client.go index 0f927a673..f3f38435b 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -19,6 +19,7 @@ import ( "tailscale.com/types/logger" ) +// Client is a DERP client. type Client struct { serverKey key.Public // of the DERP server; not a machine or node key privateKey key.Private @@ -170,6 +171,9 @@ func (c *Client) sendClientKey() error { return writeFrame(c.bw, frameClientInfo, buf) } +// ServerPublicKey returns the server's public key. +func (c *Client) ServerPublicKey() key.Public { return c.serverKey } + // Send sends a packet to the Tailscale node identified by dstKey. // // It is an error if the packet is larger than 64KB. @@ -201,6 +205,40 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) { return c.bw.Flush() } +func (c *Client) ForwardPacket(srcKey, dstKey key.Public, pkt []byte) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("derp.ForwardPacket: %w", err) + } + }() + + if len(pkt) > MaxPacketSize { + return fmt.Errorf("packet too big: %d", len(pkt)) + } + + c.wmu.Lock() + defer c.wmu.Unlock() + + timer := time.AfterFunc(5*time.Second, c.writeTimeoutFired) + defer timer.Stop() + + if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil { + return err + } + if _, err := c.bw.Write(srcKey[:]); err != nil { + return err + } + if _, err := c.bw.Write(dstKey[:]); err != nil { + return err + } + if _, err := c.bw.Write(pkt); err != nil { + return err + } + return c.bw.Flush() +} + +func (c *Client) writeTimeoutFired() { c.nc.Close() } + // NotePreferred sends a packet that tells the server whether this // client is the user's preferred server. This is only used in the // server for stats. diff --git a/derp/derp_server.go b/derp/derp_server.go index 0bfa2c21b..b0affe3dc 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -50,30 +50,46 @@ type Server struct { meshKey string // Counters: - packetsSent, bytesSent expvar.Int - packetsRecv, bytesRecv expvar.Int - packetsDropped expvar.Int - packetsDroppedReason metrics.LabelMap - packetsDroppedUnknown *expvar.Int // unknown dst pubkey - packetsDroppedGone *expvar.Int // dst conn shutting down - packetsDroppedQueueHead *expvar.Int // queue full, drop head packet - packetsDroppedQueueTail *expvar.Int // queue full, drop tail packet - packetsDroppedWrite *expvar.Int // error writing to dst conn - peerGoneFrames expvar.Int // number of peer gone frames sent - accepts expvar.Int - curClients expvar.Int - curHomeClients expvar.Int // ones with preferred - clientsReplaced expvar.Int - unknownFrames expvar.Int - homeMovesIn expvar.Int // established clients announce home server moves in - homeMovesOut expvar.Int // established clients announce home server moves out + packetsSent, bytesSent expvar.Int + packetsRecv, bytesRecv expvar.Int + packetsDropped expvar.Int + packetsDroppedReason metrics.LabelMap + packetsDroppedUnknown *expvar.Int // unknown dst pubkey + packetsDroppedFwdUnknown *expvar.Int // unknown dst pubkey on forward + packetsDroppedGone *expvar.Int // dst conn shutting down + packetsDroppedQueueHead *expvar.Int // queue full, drop head packet + packetsDroppedQueueTail *expvar.Int // queue full, drop tail packet + packetsDroppedWrite *expvar.Int // error writing to dst conn + packetsForwardedOut expvar.Int + packetsForwardedIn expvar.Int + peerGoneFrames expvar.Int // number of peer gone frames sent + accepts expvar.Int + curClients expvar.Int + curHomeClients expvar.Int // ones with preferred + clientsReplaced expvar.Int + unknownFrames expvar.Int + homeMovesIn expvar.Int // established clients announce home server moves in + homeMovesOut expvar.Int // established clients announce home server moves out + multiForwarderCreated expvar.Int + multiForwarderDeleted expvar.Int mu sync.Mutex closed bool netConns map[Conn]chan struct{} // chan is closed when conn closes clients map[key.Public]*sclient - clientsEver map[key.Public]bool // never deleted from, for stats; fine for now - watchers map[*sclient]bool // mesh peer -> true + clientsEver map[key.Public]bool // never deleted from, for stats; fine for now + watchers map[*sclient]bool // mesh peer -> true + clientsMesh map[key.Public]PacketForwarder // clients connected to mesh peers; nil means only in clients, not remote +} + +// PacketForwarder is something that can forward packets. +// +// It's mostly an inteface for circular dependency reasons; the +// typical implementation is derphttp.Client. The other implementation +// is a multiForwarder, which this package creates as needed if a +// public key gets more than one PacketForwarder registered for it. +type PacketForwarder interface { + ForwardPacket(src, dst key.Public, payload []byte) error } // Conn is the subset of the underlying net.Conn the DERP Server needs. @@ -101,11 +117,13 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server { packetsDroppedReason: metrics.LabelMap{Label: "reason"}, clients: make(map[key.Public]*sclient), clientsEver: make(map[key.Public]bool), + clientsMesh: map[key.Public]PacketForwarder{}, netConns: make(map[Conn]chan struct{}), memSys0: ms.Sys, watchers: map[*sclient]bool{}, } s.packetsDroppedUnknown = s.packetsDroppedReason.Get("unknown_dest") + s.packetsDroppedFwdUnknown = s.packetsDroppedReason.Get("unknown_dest_on_fwd") s.packetsDroppedGone = s.packetsDroppedReason.Get("gone") s.packetsDroppedQueueHead = s.packetsDroppedReason.Get("queue_head") s.packetsDroppedQueueTail = s.packetsDroppedReason.Get("queue_tail") @@ -210,6 +228,9 @@ func (s *Server) registerClient(c *sclient) { } s.clients[c.key] = c s.clientsEver[c.key] = true + if _, ok := s.clientsMesh[c.key]; !ok { + s.clientsMesh[c.key] = nil // just for varz of total users in cluster + } s.curClients.Add(1) s.broadcastPeerStateChangeLocked(c.key, true) } @@ -238,6 +259,9 @@ func (s *Server) unregisterClient(c *sclient) { if c.canMesh { delete(s.watchers, c) } + if v, ok := s.clientsMesh[c.key]; ok && v == nil { + delete(s.clientsMesh, c.key) + } s.broadcastPeerStateChangeLocked(c.key, false) s.curClients.Add(-1) @@ -271,8 +295,6 @@ func (s *Server) addWatcher(c *sclient) { if c.key == s.publicKey { // We're connecting to ourself. Do nothing. - // TODO(bradfitz): have client notice and disconnect - // so an idle TCP connection isn't kept open. return } @@ -378,6 +400,8 @@ func (c *sclient) run(ctx context.Context) error { err = c.handleFrameNotePreferred(ft, fl) case frameSendPacket: err = c.handleFrameSendPacket(ft, fl) + case frameForwardPacket: + err = c.handleFrameForwardPacket(ft, fl) case frameWatchConns: err = c.handleFrameWatchConns(ft, fl) default: @@ -417,6 +441,42 @@ func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error { return nil } +// handleFrameForwardPacket reads a "forward packet" frame from the client +// (which must be a trusted client, a peer in our mesh). +func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { + if !c.canMesh { + return fmt.Errorf("insufficient permissions") + } + s := c.s + + srcKey, dstKey, contents, err := s.recvForwardPacket(c.br, fl) + if err != nil { + return fmt.Errorf("client %x: recvForwardPacket: %v", c.key, err) + } + s.packetsForwardedIn.Add(1) + + s.mu.Lock() + dst := s.clients[dstKey] + // TODO(bradfitz): think about the sentTo/Issue 150 optimization + // in the context of DERP meshes. + s.mu.Unlock() + + if dst == nil { + s.packetsDropped.Add(1) + s.packetsDroppedFwdUnknown.Add(1) + if debug { + c.logf("dropping forwarded packet for unknown %x", dstKey) + } + return nil + } + + return c.sendPkt(dst, pkt{ + bs: contents, + src: srcKey, + }) +} + +// handleFrameSendPacket reads a "send packet" frame from the client. func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { s := c.s @@ -425,9 +485,12 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { return fmt.Errorf("client %x: recvPacket: %v", c.key, err) } + var fwd PacketForwarder s.mu.Lock() dst := s.clients[dstKey] - if dst != nil { + if dst == nil { + fwd = s.clientsMesh[dstKey] + } else { // Track that we've sent to this peer, so if/when we // disconnect first, the server can inform all our old // recipients that we're gone. (Issue 150 optimization) @@ -436,6 +499,14 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { s.mu.Unlock() if dst == nil { + if fwd != nil { + s.packetsForwardedOut.Add(1) + if err := fwd.ForwardPacket(c.key, dstKey, contents); err != nil { + // TODO: + return nil + } + return nil + } s.packetsDropped.Add(1) s.packetsDroppedUnknown.Add(1) if debug { @@ -450,6 +521,13 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { if dst.info.Version >= protocolSrcAddrs { p.src = c.key } + return c.sendPkt(dst, p) +} + +func (c *sclient) sendPkt(dst *sclient, p pkt) error { + s := c.s + dstKey := dst.key + // Attempt to queue for sending up to 3 times. On each attempt, if // the queue is full, try to drop from queue head to prioritize // fresher packets. @@ -615,6 +693,29 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Publi // zpub is the key.Public zero value. var zpub key.Public +func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.Public, contents []byte, err error) { + if frameLen < keyLen*2 { + return zpub, zpub, nil, errors.New("short send packet frame") + } + if _, err := io.ReadFull(br, srcKey[:]); err != nil { + return zpub, zpub, nil, err + } + if _, err := io.ReadFull(br, dstKey[:]); err != nil { + return zpub, zpub, nil, err + } + packetLen := frameLen - keyLen*2 + if packetLen > MaxPacketSize { + return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize) + } + contents = make([]byte, packetLen) + if _, err := io.ReadFull(br, contents); err != nil { + return zpub, zpub, nil, err + } + // TODO: was s.packetsRecv.Add(1) + // TODO: was s.bytesRecv.Add(int64(len(contents))) + return srcKey, dstKey, contents, nil +} + // sclient is a client connection to the server. // // (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go) @@ -889,6 +990,108 @@ func (c *sclient) sendPacket(srcKey key.Public, contents []byte) (err error) { return err } +// AddPacketForwarder registers fwd as a packet forwarder for dst. +// fwd must be comparable. +func (s *Server) AddPacketForwarder(dst key.Public, fwd PacketForwarder) { + s.mu.Lock() + defer s.mu.Unlock() + if prev, ok := s.clientsMesh[dst]; ok { + if prev == fwd { + // Duplicate registration of same forwarder. Ignore. + return + } + if m, ok := prev.(multiForwarder); ok { + if _, ok := m[fwd]; !ok { + // Duplicate registration of same forwarder in set; ignore. + return + } + m[fwd] = m.maxVal() + 1 + return + } + // Otherwise, the existing value is not a set and not a dup, so make it a set. + fwd = multiForwarder{ + prev: 1, // existed 1st, higher priority + fwd: 2, // the passed in fwd is in 2nd place + } + s.multiForwarderCreated.Add(1) + } + s.clientsMesh[dst] = fwd +} + +// RemovePacketForwarder removes fwd as a packet forwarder for dst. +// fwd must be comparable. +func (s *Server) RemovePacketForwarder(dst key.Public, fwd PacketForwarder) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.clientsMesh[dst] + if !ok { + return + } + if m, ok := v.(multiForwarder); ok { + if len(m) < 2 { + panic("unexpected") + } + delete(m, fwd) + // If fwd was in m and we no longer need to be a + // multiForwarder, replace the entry with the + // remaining PacketForwarder. + if len(m) == 1 { + var remain PacketForwarder + for k := range m { + remain = k + } + s.clientsMesh[dst] = remain + s.multiForwarderDeleted.Add(1) + } + return + } + if v != fwd { + // Delete of an entry that wasn't in the + // map. Harmless, so ignore. + // (This might happen if a user is moving around + // between nodes and/or the server sent duplicate + // connection change broadcasts.) + return + } + + if _, isLocal := s.clients[dst]; isLocal { + s.clientsMesh[dst] = nil + } else { + delete(s.clientsMesh, dst) + } +} + +// multiForwarder is a PacketForwarder that represents a set of +// forwarding options. It's used in the rare cases that a client is +// connected to multiple DERP nodes in a region. That shouldn't really +// happen except for perhaps during brief moments while the client is +// reconfiguring, in which case we don't want to forget where the +// client is. The map value is unique connection number; the lowest +// one has been seen the longest. It's used to make sure we forward +// packets consistently to the same node and don't pick randomly. +type multiForwarder map[PacketForwarder]uint8 + +func (m multiForwarder) maxVal() (max uint8) { + for _, v := range m { + if v > max { + max = v + } + } + return +} + +func (m multiForwarder) ForwardPacket(src, dst key.Public, payload []byte) error { + var fwd PacketForwarder + var lowest uint8 + for k, v := range m { + if fwd == nil || v < lowest { + fwd = k + lowest = v + } + } + return fwd.ForwardPacket(src, dst, payload) +} + func (s *Server) expVarFunc(f func() interface{}) expvar.Func { return expvar.Func(func() interface{} { s.mu.Lock() @@ -905,6 +1108,8 @@ func (s *Server) ExpVar() expvar.Var { m.Set("gauge_watchers", s.expVarFunc(func() interface{} { return len(s.watchers) })) m.Set("gauge_current_connnections", &s.curClients) m.Set("gauge_current_home_connnections", &s.curHomeClients) + m.Set("gauge_clients_total", expvar.Func(func() interface{} { return len(s.clientsMesh) })) + m.Set("gauge_clients_remote", expvar.Func(func() interface{} { return len(s.clientsMesh) - len(s.clients) })) m.Set("accepts", &s.accepts) m.Set("clients_replaced", &s.clientsReplaced) m.Set("bytes_received", &s.bytesRecv) @@ -917,5 +1122,9 @@ func (s *Server) ExpVar() expvar.Var { m.Set("home_moves_in", &s.homeMovesIn) m.Set("home_moves_out", &s.homeMovesOut) m.Set("peer_gone_frames", &s.peerGoneFrames) + m.Set("packets_forwarded_out", &s.packetsForwardedOut) + m.Set("packets_forwarded_in", &s.packetsForwardedIn) + m.Set("multiforwarder_created", &s.multiForwarderCreated) + m.Set("multiforwarder_deleted", &s.multiForwarderDeleted) return m } diff --git a/derp/derp_test.go b/derp/derp_test.go index e7deae0fb..e22d53c31 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "net" + "reflect" "sync" "testing" "time" @@ -619,3 +620,116 @@ func TestWatch(t *testing.T) { w2.wantGone(t, c1.pub) w3.wantGone(t, c1.pub) } + +type testFwd int + +func (testFwd) ForwardPacket(key.Public, key.Public, []byte) error { panic("not called in tests") } + +func pubAll(b byte) (ret key.Public) { + for i := range ret { + ret[i] = b + } + return +} + +func TestForwarderRegistration(t *testing.T) { + s := &Server{ + clients: make(map[key.Public]*sclient), + clientsMesh: map[key.Public]PacketForwarder{}, + } + want := func(want map[key.Public]PacketForwarder) { + t.Helper() + if got := s.clientsMesh; !reflect.DeepEqual(got, want) { + t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) + } + } + wantCounter := func(c *expvar.Int, want int) { + t.Helper() + if got := c.Value(); got != int64(want) { + t.Errorf("counter = %v; want %v", got, want) + } + } + + u1 := pubAll(1) + u2 := pubAll(2) + u3 := pubAll(3) + + s.AddPacketForwarder(u1, testFwd(1)) + s.AddPacketForwarder(u2, testFwd(2)) + want(map[key.Public]PacketForwarder{ + u1: testFwd(1), + u2: testFwd(2), + }) + + // Verify a remove of non-registered forwarder is no-op. + s.RemovePacketForwarder(u2, testFwd(999)) + want(map[key.Public]PacketForwarder{ + u1: testFwd(1), + u2: testFwd(2), + }) + + // Verify a remove of non-registered user is no-op. + s.RemovePacketForwarder(u3, testFwd(1)) + want(map[key.Public]PacketForwarder{ + u1: testFwd(1), + u2: testFwd(2), + }) + + // Actual removal. + s.RemovePacketForwarder(u2, testFwd(2)) + want(map[key.Public]PacketForwarder{ + u1: testFwd(1), + }) + + // Adding a dup for a user. + wantCounter(&s.multiForwarderCreated, 0) + s.AddPacketForwarder(u1, testFwd(100)) + want(map[key.Public]PacketForwarder{ + u1: multiForwarder{ + testFwd(1): 1, + testFwd(100): 2, + }, + }) + wantCounter(&s.multiForwarderCreated, 1) + + // Removing a forwarder in a multi set that doesn't exist; does nothing. + s.RemovePacketForwarder(u1, testFwd(55)) + want(map[key.Public]PacketForwarder{ + u1: multiForwarder{ + testFwd(1): 1, + testFwd(100): 2, + }, + }) + + // Removing a forwarder in a multi set that does exist should collapse it away + // from being a multiForwarder. + wantCounter(&s.multiForwarderDeleted, 0) + s.RemovePacketForwarder(u1, testFwd(1)) + want(map[key.Public]PacketForwarder{ + u1: testFwd(100), + }) + wantCounter(&s.multiForwarderDeleted, 1) + + // Removing an entry for a client that's still connected locally should result + // in a nil forwarder. + u1c := &sclient{ + key: u1, + logf: logger.Discard, + } + s.clients[u1] = u1c + s.RemovePacketForwarder(u1, testFwd(100)) + want(map[key.Public]PacketForwarder{ + u1: nil, + }) + + // But once that client disconnects, it should go away. + s.unregisterClient(u1c) + want(map[key.Public]PacketForwarder{}) + + // But if it already has a forwarder, it's not removed. + s.AddPacketForwarder(u1, testFwd(2)) + s.unregisterClient(u1c) + want(map[key.Public]PacketForwarder{ + u1: testFwd(2), + }) +} diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index ba166be03..c7975eb89 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -114,6 +114,9 @@ func (c *Client) Connect(ctx context.Context) error { } // ServerPublicKey returns the server's public key. +// +// It only returns a non-zero value once a connection has succeeded +// from an earlier call. func (c *Client) ServerPublicKey() key.Public { c.mu.Lock() defer c.mu.Unlock() @@ -293,6 +296,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien } } + c.serverPubKey = derpClient.ServerPublicKey() c.client = derpClient c.netConn = tcpConn c.connGen++ @@ -484,6 +488,17 @@ func (c *Client) Send(dstKey key.Public, b []byte) error { return err } +func (c *Client) ForwardPacket(from, to key.Public, b []byte) error { + client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket") + if err != nil { + return err + } + if err := client.ForwardPacket(from, to, b); err != nil { + c.closeForReconnect(client) + } + return err +} + // NotePreferred notes whether this Client is the caller's preferred // (home) DERP node. It's only used for stats. func (c *Client) NotePreferred(v bool) {