From aef3c0350c3b10b41418c5532e3a33a8cc11dc7b Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 18 Apr 2021 19:49:25 -0700 Subject: [PATCH] control/controlclient: break direct.go into map.go (+tests), add mapSession So the NetworkMap-from-incremental-MapResponses can be tested easily. And because direct.go was getting too big. No change in behavior at this point. Just movement. Signed-off-by: Brad Fitzpatrick --- control/controlclient/direct.go | 238 +++--------------------- control/controlclient/direct_test.go | 156 ---------------- control/controlclient/filter.go | 20 -- control/controlclient/map.go | 268 +++++++++++++++++++++++++++ control/controlclient/map_test.go | 167 +++++++++++++++++ 5 files changed, 461 insertions(+), 388 deletions(-) delete mode 100644 control/controlclient/filter.go create mode 100644 control/controlclient/map.go create mode 100644 control/controlclient/map_test.go diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index eb4816ff6..6f1c91a60 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -23,7 +23,6 @@ import ( "path/filepath" "reflect" "runtime" - "sort" "strconv" "strings" "sync" @@ -49,7 +48,6 @@ import ( "tailscale.com/util/dnsname" "tailscale.com/util/systemd" "tailscale.com/version" - "tailscale.com/wgengine/filter" "tailscale.com/wgengine/monitor" ) @@ -729,11 +727,12 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm } }() - var lastDNSConfig = new(tailcfg.DNSConfig) - var lastDERPMap *tailcfg.DERPMap - var lastUserProfile = map[tailcfg.UserID]tailcfg.UserProfile{} - var lastParsedPacketFilter []filter.Match - var collectServices bool + sess := newMapSession() + sess.logf = c.logf + sess.vlogf = vlogf + sess.persist = persist + sess.machinePubKey = machinePubKey + sess.keepSharerAndUserSplit = c.keepSharerAndUserSplit // If allowStream, then the server will use an HTTP long poll to // return incremental results. There is always one response right @@ -742,7 +741,6 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm // the same format before just closing the connection. // We can use this same read loop either way. var msg []byte - var previousPeers []*tailcfg.Node // for delta-purposes for i := 0; i < maxPolls || maxPolls < 0; i++ { vlogf("netmap: starting size read after %v (poll %v)", time.Since(t0).Round(time.Millisecond), i) var siz [4]byte @@ -789,16 +787,6 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm continue } - undeltaPeers(&resp, previousPeers) - previousPeers = cloneNodes(resp.Peers) // defensive/lazy clone, since this escapes to who knows where - for _, up := range resp.UserProfiles { - lastUserProfile[up.ID] = up - } - - if resp.DERPMap != nil { - vlogf("netmap: new map contains DERP map") - lastDERPMap = resp.DERPMap - } if resp.Debug != nil { if resp.Debug.LogHeapPprof { go logheap.LogHeap(resp.Debug.LogHeapURL) @@ -809,17 +797,30 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm setControlAtomic(&controlUseDERPRoute, resp.Debug.DERPRoute) setControlAtomic(&controlTrimWGConfig, resp.Debug.TrimWGConfig) } + + nm := sess.netmapForResponse(&resp) + // Temporarily (2020-06-29) support removing all but // discovery-supporting nodes during development, for // less noise. if Debug.OnlyDisco { - filtered := resp.Peers[:0] - for _, p := range resp.Peers { - if !p.DiscoKey.IsZero() { - filtered = append(filtered, p) + anyOld, numDisco := false, 0 + for _, p := range nm.Peers { + if p.DiscoKey.IsZero() { + anyOld = true + } else { + numDisco++ + } + } + if anyOld { + filtered := make([]*tailcfg.Node, 0, numDisco) + for _, p := range nm.Peers { + if !p.DiscoKey.IsZero() { + filtered = append(filtered, p) + } } + nm.Peers = filtered } - resp.Peers = filtered } if Debug.StripEndpoints { for _, p := range resp.Peers { @@ -830,18 +831,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm } } if Debug.StripCaps { - resp.Node.Capabilities = nil - } - - if pf := resp.PacketFilter; pf != nil { - lastParsedPacketFilter = c.parsePacketFilter(pf) - } - if c := resp.DNSConfig; c != nil { - lastDNSConfig = c - } - - if v, ok := resp.CollectServices.Get(); ok { - collectServices = v + nm.SelfNode.Capabilities = nil } // Get latest localPort. This might've changed if @@ -849,67 +839,9 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm // the end-to-end test. // TODO(bradfitz): remove the NetworkMap.LocalPort field entirely. c.mu.Lock() - localPort = c.localPort + nm.LocalPort = c.localPort c.mu.Unlock() - nm := &netmap.NetworkMap{ - SelfNode: resp.Node, - NodeKey: tailcfg.NodeKey(persist.PrivateNodeKey.Public()), - PrivateKey: persist.PrivateNodeKey, - MachineKey: machinePubKey, - Expiry: resp.Node.KeyExpiry, - Name: resp.Node.Name, - Addresses: resp.Node.Addresses, - Peers: resp.Peers, - LocalPort: localPort, - User: resp.Node.User, - UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile), - Domain: resp.Domain, - DNS: *lastDNSConfig, - Hostinfo: resp.Node.Hostinfo, - PacketFilter: lastParsedPacketFilter, - CollectServices: collectServices, - DERPMap: lastDERPMap, - Debug: resp.Debug, - } - addUserProfile := func(userID tailcfg.UserID) { - if _, dup := nm.UserProfiles[userID]; dup { - // Already populated it from a previous peer. - return - } - if up, ok := lastUserProfile[userID]; ok { - nm.UserProfiles[userID] = up - } - } - addUserProfile(nm.User) - magicDNSSuffix := nm.MagicDNSSuffix() - nm.SelfNode.InitDisplayNames(magicDNSSuffix) - for _, peer := range resp.Peers { - peer.InitDisplayNames(magicDNSSuffix) - if !peer.Sharer.IsZero() { - if c.keepSharerAndUserSplit { - addUserProfile(peer.Sharer) - } else { - peer.User = peer.Sharer - } - } - addUserProfile(peer.User) - } - if resp.Node.MachineAuthorized { - nm.MachineStatus = tailcfg.MachineAuthorized - } else { - nm.MachineStatus = tailcfg.MachineUnauthorized - } - if len(resp.DNS) > 0 { - nm.DNS.Nameservers = resp.DNS - } - if len(resp.SearchPaths) > 0 { - nm.DNS.Domains = resp.SearchPaths - } - if Debug.ProxyDNS { - nm.DNS.Proxied = true - } - // Printing the netmap can be extremely verbose, but is very // handy for debugging. Let's limit how often we do it. // Code elsewhere prints netmap diffs every time, so this @@ -1104,124 +1036,6 @@ func envBool(k string) bool { var clockNow = time.Now -// undeltaPeers updates mapRes.Peers to be complete based on the -// provided previous peer list and the PeersRemoved and PeersChanged -// fields in mapRes, as well as the PeerSeenChange and OnlineChange -// maps. -// -// It then also nils out the delta fields. -func undeltaPeers(mapRes *tailcfg.MapResponse, prev []*tailcfg.Node) { - if len(mapRes.Peers) > 0 { - // Not delta encoded. - if !nodesSorted(mapRes.Peers) { - log.Printf("netmap: undeltaPeers: MapResponse.Peers not sorted; sorting") - sortNodes(mapRes.Peers) - } - return - } - - var removed map[tailcfg.NodeID]bool - if pr := mapRes.PeersRemoved; len(pr) > 0 { - removed = make(map[tailcfg.NodeID]bool, len(pr)) - for _, id := range pr { - removed[id] = true - } - } - changed := mapRes.PeersChanged - - if !nodesSorted(changed) { - log.Printf("netmap: undeltaPeers: MapResponse.PeersChanged not sorted; sorting") - sortNodes(changed) - } - if !nodesSorted(prev) { - // Internal error (unrelated to the network) if we get here. - log.Printf("netmap: undeltaPeers: [unexpected] prev not sorted; sorting") - sortNodes(prev) - } - - newFull := prev - if len(removed) > 0 || len(changed) > 0 { - newFull = make([]*tailcfg.Node, 0, len(prev)-len(removed)) - for len(prev) > 0 && len(changed) > 0 { - pID := prev[0].ID - cID := changed[0].ID - if removed[pID] { - prev = prev[1:] - continue - } - switch { - case pID < cID: - newFull = append(newFull, prev[0]) - prev = prev[1:] - case pID == cID: - newFull = append(newFull, changed[0]) - prev, changed = prev[1:], changed[1:] - case cID < pID: - newFull = append(newFull, changed[0]) - changed = changed[1:] - } - } - newFull = append(newFull, changed...) - for _, n := range prev { - if !removed[n.ID] { - newFull = append(newFull, n) - } - } - sortNodes(newFull) - } - - if len(mapRes.PeerSeenChange) != 0 || len(mapRes.OnlineChange) != 0 { - peerByID := make(map[tailcfg.NodeID]*tailcfg.Node, len(newFull)) - for _, n := range newFull { - peerByID[n.ID] = n - } - now := clockNow() - for nodeID, seen := range mapRes.PeerSeenChange { - if n, ok := peerByID[nodeID]; ok { - if seen { - n.LastSeen = &now - } else { - n.LastSeen = nil - } - } - } - for nodeID, online := range mapRes.OnlineChange { - if n, ok := peerByID[nodeID]; ok { - online := online - n.Online = &online - } - } - } - - mapRes.Peers = newFull - mapRes.PeersChanged = nil - mapRes.PeersRemoved = nil -} - -func nodesSorted(v []*tailcfg.Node) bool { - for i, n := range v { - if i > 0 && n.ID <= v[i-1].ID { - return false - } - } - return true -} - -func sortNodes(v []*tailcfg.Node) { - sort.Slice(v, func(i, j int) bool { return v[i].ID < v[j].ID }) -} - -func cloneNodes(v1 []*tailcfg.Node) []*tailcfg.Node { - if v1 == nil { - return nil - } - v2 := make([]*tailcfg.Node, len(v1)) - for i, n := range v1 { - v2[i] = n.Clone() - } - return v2 -} - // opt.Bool configs from control. var ( controlUseDERPRoute atomic.Value diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index 69c6ff31e..4d2607616 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -6,169 +6,13 @@ package controlclient import ( "encoding/json" - "fmt" - "reflect" - "strings" "testing" - "time" "inet.af/netaddr" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) -func TestUndeltaPeers(t *testing.T) { - defer func(old func() time.Time) { clockNow = old }(clockNow) - - var curTime time.Time - clockNow = func() time.Time { - return curTime - } - online := func(v bool) func(*tailcfg.Node) { - return func(n *tailcfg.Node) { - n.Online = &v - } - } - seenAt := func(t time.Time) func(*tailcfg.Node) { - return func(n *tailcfg.Node) { - n.LastSeen = &t - } - } - n := func(id tailcfg.NodeID, name string, mod ...func(*tailcfg.Node)) *tailcfg.Node { - n := &tailcfg.Node{ID: id, Name: name} - for _, f := range mod { - f(n) - } - return n - } - peers := func(nv ...*tailcfg.Node) []*tailcfg.Node { return nv } - tests := []struct { - name string - mapRes *tailcfg.MapResponse - curTime time.Time - prev []*tailcfg.Node - want []*tailcfg.Node - }{ - { - name: "full_peers", - mapRes: &tailcfg.MapResponse{ - Peers: peers(n(1, "foo"), n(2, "bar")), - }, - want: peers(n(1, "foo"), n(2, "bar")), - }, - { - name: "full_peers_ignores_deltas", - mapRes: &tailcfg.MapResponse{ - Peers: peers(n(1, "foo"), n(2, "bar")), - PeersRemoved: []tailcfg.NodeID{2}, - }, - want: peers(n(1, "foo"), n(2, "bar")), - }, - { - name: "add_and_update", - prev: peers(n(1, "foo"), n(2, "bar")), - mapRes: &tailcfg.MapResponse{ - PeersChanged: peers(n(0, "zero"), n(2, "bar2"), n(3, "three")), - }, - want: peers(n(0, "zero"), n(1, "foo"), n(2, "bar2"), n(3, "three")), - }, - { - name: "remove", - prev: peers(n(1, "foo"), n(2, "bar")), - mapRes: &tailcfg.MapResponse{ - PeersRemoved: []tailcfg.NodeID{1}, - }, - want: peers(n(2, "bar")), - }, - { - name: "add_and_remove", - prev: peers(n(1, "foo"), n(2, "bar")), - mapRes: &tailcfg.MapResponse{ - PeersChanged: peers(n(1, "foo2")), - PeersRemoved: []tailcfg.NodeID{2}, - }, - want: peers(n(1, "foo2")), - }, - { - name: "unchanged", - prev: peers(n(1, "foo"), n(2, "bar")), - mapRes: &tailcfg.MapResponse{}, - want: peers(n(1, "foo"), n(2, "bar")), - }, - { - name: "online_change", - prev: peers(n(1, "foo"), n(2, "bar")), - mapRes: &tailcfg.MapResponse{ - OnlineChange: map[tailcfg.NodeID]bool{ - 1: true, - }, - }, - want: peers( - n(1, "foo", online(true)), - n(2, "bar"), - ), - }, - { - name: "online_change_offline", - prev: peers(n(1, "foo"), n(2, "bar")), - mapRes: &tailcfg.MapResponse{ - OnlineChange: map[tailcfg.NodeID]bool{ - 1: false, - 2: true, - }, - }, - want: peers( - n(1, "foo", online(false)), - n(2, "bar", online(true)), - ), - }, - { - name: "peer_seen_at", - prev: peers(n(1, "foo", seenAt(time.Unix(111, 0))), n(2, "bar")), - curTime: time.Unix(123, 0), - mapRes: &tailcfg.MapResponse{ - PeerSeenChange: map[tailcfg.NodeID]bool{ - 1: false, - 2: true, - }, - }, - want: peers( - n(1, "foo"), - n(2, "bar", seenAt(time.Unix(123, 0))), - ), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if !tt.curTime.IsZero() { - curTime = tt.curTime - } - undeltaPeers(tt.mapRes, tt.prev) - if !reflect.DeepEqual(tt.mapRes.Peers, tt.want) { - t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.mapRes.Peers), formatNodes(tt.want)) - } - }) - } -} - -func formatNodes(nodes []*tailcfg.Node) string { - var sb strings.Builder - for i, n := range nodes { - if i > 0 { - sb.WriteString(", ") - } - var extra string - if n.Online != nil { - extra += fmt.Sprintf(", online=%v", *n.Online) - } - if n.LastSeen != nil { - extra += fmt.Sprintf(", lastSeen=%v", n.LastSeen.Unix()) - } - fmt.Fprintf(&sb, "(%d, %q%s)", n.ID, n.Name, extra) - } - return sb.String() -} - func TestNewDirect(t *testing.T) { hi := NewHostinfo() ni := tailcfg.NetInfo{LinkType: "wired"} diff --git a/control/controlclient/filter.go b/control/controlclient/filter.go deleted file mode 100644 index 708d39ba0..000000000 --- a/control/controlclient/filter.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package controlclient - -import ( - "tailscale.com/tailcfg" - "tailscale.com/wgengine/filter" -) - -// Parse a backward-compatible FilterRule used by control's wire -// format, producing the most current filter format. -func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) []filter.Match { - mm, err := filter.MatchesFromFilterRules(pf) - if err != nil { - c.logf("parsePacketFilter: %s\n", err) - } - return mm -} diff --git a/control/controlclient/map.go b/control/controlclient/map.go new file mode 100644 index 000000000..5f3d7a463 --- /dev/null +++ b/control/controlclient/map.go @@ -0,0 +1,268 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlclient + +import ( + "log" + "sort" + + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/persist" + "tailscale.com/wgengine/filter" +) + +// mapSession holds the state over a long-polled "map" request to the +// control plane. +// +// It accepts incremental tailcfg.MapResponse values to +// netMapForResponse and returns fully inflated NetworkMaps, filling +// in the omitted data implicit from prior MapResponse values from +// within the same session (the same long-poll HTTP response to the +// one MapRequest). +type mapSession struct { + // Immutable fields. + logf logger.Logf + vlogf logger.Logf + persist persist.Persist + machinePubKey tailcfg.MachineKey + keepSharerAndUserSplit bool // see Options.KeepSharerAndUserSplit + + // Fields storing state over the the coards of multiple MapResponses. + lastDNSConfig *tailcfg.DNSConfig + lastDERPMap *tailcfg.DERPMap + lastUserProfile map[tailcfg.UserID]tailcfg.UserProfile + lastParsedPacketFilter []filter.Match + collectServices bool + previousPeers []*tailcfg.Node // for delta-purposes + + // netMapBuilding is non-nil during a netmapForResponse call, + // containing the value to be returned, once fully populated. + netMapBuilding *netmap.NetworkMap +} + +func newMapSession() *mapSession { + ms := &mapSession{ + logf: logger.Discard, + vlogf: logger.Discard, + lastDNSConfig: new(tailcfg.DNSConfig), + lastUserProfile: map[tailcfg.UserID]tailcfg.UserProfile{}, + } + return ms +} + +func (ms *mapSession) addUserProfile(userID tailcfg.UserID) { + nm := ms.netMapBuilding + if _, dup := nm.UserProfiles[userID]; dup { + // Already populated it from a previous peer. + return + } + if up, ok := ms.lastUserProfile[userID]; ok { + nm.UserProfiles[userID] = up + } +} + +// netmapForResponse returns a fully populated NetworkMap from a full +// or incremental MapResponse within the session, filling in omitted +// information from prior MapResponse values. +func (ms *mapSession) netmapForResponse(resp *tailcfg.MapResponse) *netmap.NetworkMap { + undeltaPeers(resp, ms.previousPeers) + + ms.previousPeers = cloneNodes(resp.Peers) // defensive/lazy clone, since this escapes to who knows where + for _, up := range resp.UserProfiles { + ms.lastUserProfile[up.ID] = up + } + + if resp.DERPMap != nil { + ms.vlogf("netmap: new map contains DERP map") + ms.lastDERPMap = resp.DERPMap + } + + if pf := resp.PacketFilter; pf != nil { + var err error + ms.lastParsedPacketFilter, err = filter.MatchesFromFilterRules(pf) + if err != nil { + ms.logf("parsePacketFilter: %v", err) + } + } + if c := resp.DNSConfig; c != nil { + ms.lastDNSConfig = c + } + + if v, ok := resp.CollectServices.Get(); ok { + ms.collectServices = v + } + + nm := &netmap.NetworkMap{ + SelfNode: resp.Node, + NodeKey: tailcfg.NodeKey(ms.persist.PrivateNodeKey.Public()), + PrivateKey: ms.persist.PrivateNodeKey, + MachineKey: ms.machinePubKey, + Expiry: resp.Node.KeyExpiry, + Name: resp.Node.Name, + Addresses: resp.Node.Addresses, + Peers: resp.Peers, + User: resp.Node.User, + UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile), + Domain: resp.Domain, + DNS: *ms.lastDNSConfig, + Hostinfo: resp.Node.Hostinfo, + PacketFilter: ms.lastParsedPacketFilter, + CollectServices: ms.collectServices, + DERPMap: ms.lastDERPMap, + Debug: resp.Debug, + } + ms.netMapBuilding = nm + + ms.addUserProfile(nm.User) + magicDNSSuffix := nm.MagicDNSSuffix() + nm.SelfNode.InitDisplayNames(magicDNSSuffix) + for _, peer := range resp.Peers { + peer.InitDisplayNames(magicDNSSuffix) + if !peer.Sharer.IsZero() { + if ms.keepSharerAndUserSplit { + ms.addUserProfile(peer.Sharer) + } else { + peer.User = peer.Sharer + } + } + ms.addUserProfile(peer.User) + } + if resp.Node.MachineAuthorized { + nm.MachineStatus = tailcfg.MachineAuthorized + } else { + nm.MachineStatus = tailcfg.MachineUnauthorized + } + if len(resp.DNS) > 0 { + nm.DNS.Nameservers = resp.DNS + } + if len(resp.SearchPaths) > 0 { + nm.DNS.Domains = resp.SearchPaths + } + if Debug.ProxyDNS { + nm.DNS.Proxied = true + } + ms.netMapBuilding = nil + return nm +} + +// undeltaPeers updates mapRes.Peers to be complete based on the +// provided previous peer list and the PeersRemoved and PeersChanged +// fields in mapRes, as well as the PeerSeenChange and OnlineChange +// maps. +// +// It then also nils out the delta fields. +func undeltaPeers(mapRes *tailcfg.MapResponse, prev []*tailcfg.Node) { + if len(mapRes.Peers) > 0 { + // Not delta encoded. + if !nodesSorted(mapRes.Peers) { + log.Printf("netmap: undeltaPeers: MapResponse.Peers not sorted; sorting") + sortNodes(mapRes.Peers) + } + return + } + + var removed map[tailcfg.NodeID]bool + if pr := mapRes.PeersRemoved; len(pr) > 0 { + removed = make(map[tailcfg.NodeID]bool, len(pr)) + for _, id := range pr { + removed[id] = true + } + } + changed := mapRes.PeersChanged + + if !nodesSorted(changed) { + log.Printf("netmap: undeltaPeers: MapResponse.PeersChanged not sorted; sorting") + sortNodes(changed) + } + if !nodesSorted(prev) { + // Internal error (unrelated to the network) if we get here. + log.Printf("netmap: undeltaPeers: [unexpected] prev not sorted; sorting") + sortNodes(prev) + } + + newFull := prev + if len(removed) > 0 || len(changed) > 0 { + newFull = make([]*tailcfg.Node, 0, len(prev)-len(removed)) + for len(prev) > 0 && len(changed) > 0 { + pID := prev[0].ID + cID := changed[0].ID + if removed[pID] { + prev = prev[1:] + continue + } + switch { + case pID < cID: + newFull = append(newFull, prev[0]) + prev = prev[1:] + case pID == cID: + newFull = append(newFull, changed[0]) + prev, changed = prev[1:], changed[1:] + case cID < pID: + newFull = append(newFull, changed[0]) + changed = changed[1:] + } + } + newFull = append(newFull, changed...) + for _, n := range prev { + if !removed[n.ID] { + newFull = append(newFull, n) + } + } + sortNodes(newFull) + } + + if len(mapRes.PeerSeenChange) != 0 || len(mapRes.OnlineChange) != 0 { + peerByID := make(map[tailcfg.NodeID]*tailcfg.Node, len(newFull)) + for _, n := range newFull { + peerByID[n.ID] = n + } + now := clockNow() + for nodeID, seen := range mapRes.PeerSeenChange { + if n, ok := peerByID[nodeID]; ok { + if seen { + n.LastSeen = &now + } else { + n.LastSeen = nil + } + } + } + for nodeID, online := range mapRes.OnlineChange { + if n, ok := peerByID[nodeID]; ok { + online := online + n.Online = &online + } + } + } + + mapRes.Peers = newFull + mapRes.PeersChanged = nil + mapRes.PeersRemoved = nil +} + +func nodesSorted(v []*tailcfg.Node) bool { + for i, n := range v { + if i > 0 && n.ID <= v[i-1].ID { + return false + } + } + return true +} + +func sortNodes(v []*tailcfg.Node) { + sort.Slice(v, func(i, j int) bool { return v[i].ID < v[j].ID }) +} + +func cloneNodes(v1 []*tailcfg.Node) []*tailcfg.Node { + if v1 == nil { + return nil + } + v2 := make([]*tailcfg.Node, len(v1)) + for i, n := range v1 { + v2[i] = n.Clone() + } + return v2 +} diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go new file mode 100644 index 000000000..137dd2863 --- /dev/null +++ b/control/controlclient/map_test.go @@ -0,0 +1,167 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlclient + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" + + "tailscale.com/tailcfg" +) + +func TestUndeltaPeers(t *testing.T) { + defer func(old func() time.Time) { clockNow = old }(clockNow) + + var curTime time.Time + clockNow = func() time.Time { + return curTime + } + online := func(v bool) func(*tailcfg.Node) { + return func(n *tailcfg.Node) { + n.Online = &v + } + } + seenAt := func(t time.Time) func(*tailcfg.Node) { + return func(n *tailcfg.Node) { + n.LastSeen = &t + } + } + n := func(id tailcfg.NodeID, name string, mod ...func(*tailcfg.Node)) *tailcfg.Node { + n := &tailcfg.Node{ID: id, Name: name} + for _, f := range mod { + f(n) + } + return n + } + peers := func(nv ...*tailcfg.Node) []*tailcfg.Node { return nv } + tests := []struct { + name string + mapRes *tailcfg.MapResponse + curTime time.Time + prev []*tailcfg.Node + want []*tailcfg.Node + }{ + { + name: "full_peers", + mapRes: &tailcfg.MapResponse{ + Peers: peers(n(1, "foo"), n(2, "bar")), + }, + want: peers(n(1, "foo"), n(2, "bar")), + }, + { + name: "full_peers_ignores_deltas", + mapRes: &tailcfg.MapResponse{ + Peers: peers(n(1, "foo"), n(2, "bar")), + PeersRemoved: []tailcfg.NodeID{2}, + }, + want: peers(n(1, "foo"), n(2, "bar")), + }, + { + name: "add_and_update", + prev: peers(n(1, "foo"), n(2, "bar")), + mapRes: &tailcfg.MapResponse{ + PeersChanged: peers(n(0, "zero"), n(2, "bar2"), n(3, "three")), + }, + want: peers(n(0, "zero"), n(1, "foo"), n(2, "bar2"), n(3, "three")), + }, + { + name: "remove", + prev: peers(n(1, "foo"), n(2, "bar")), + mapRes: &tailcfg.MapResponse{ + PeersRemoved: []tailcfg.NodeID{1}, + }, + want: peers(n(2, "bar")), + }, + { + name: "add_and_remove", + prev: peers(n(1, "foo"), n(2, "bar")), + mapRes: &tailcfg.MapResponse{ + PeersChanged: peers(n(1, "foo2")), + PeersRemoved: []tailcfg.NodeID{2}, + }, + want: peers(n(1, "foo2")), + }, + { + name: "unchanged", + prev: peers(n(1, "foo"), n(2, "bar")), + mapRes: &tailcfg.MapResponse{}, + want: peers(n(1, "foo"), n(2, "bar")), + }, + { + name: "online_change", + prev: peers(n(1, "foo"), n(2, "bar")), + mapRes: &tailcfg.MapResponse{ + OnlineChange: map[tailcfg.NodeID]bool{ + 1: true, + }, + }, + want: peers( + n(1, "foo", online(true)), + n(2, "bar"), + ), + }, + { + name: "online_change_offline", + prev: peers(n(1, "foo"), n(2, "bar")), + mapRes: &tailcfg.MapResponse{ + OnlineChange: map[tailcfg.NodeID]bool{ + 1: false, + 2: true, + }, + }, + want: peers( + n(1, "foo", online(false)), + n(2, "bar", online(true)), + ), + }, + { + name: "peer_seen_at", + prev: peers(n(1, "foo", seenAt(time.Unix(111, 0))), n(2, "bar")), + curTime: time.Unix(123, 0), + mapRes: &tailcfg.MapResponse{ + PeerSeenChange: map[tailcfg.NodeID]bool{ + 1: false, + 2: true, + }, + }, + want: peers( + n(1, "foo"), + n(2, "bar", seenAt(time.Unix(123, 0))), + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !tt.curTime.IsZero() { + curTime = tt.curTime + } + undeltaPeers(tt.mapRes, tt.prev) + if !reflect.DeepEqual(tt.mapRes.Peers, tt.want) { + t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.mapRes.Peers), formatNodes(tt.want)) + } + }) + } +} + +func formatNodes(nodes []*tailcfg.Node) string { + var sb strings.Builder + for i, n := range nodes { + if i > 0 { + sb.WriteString(", ") + } + var extra string + if n.Online != nil { + extra += fmt.Sprintf(", online=%v", *n.Online) + } + if n.LastSeen != nil { + extra += fmt.Sprintf(", lastSeen=%v", n.LastSeen.Unix()) + } + fmt.Fprintf(&sb, "(%d, %q%s)", n.ID, n.Name, extra) + } + return sb.String() +}