From b5ff68a9684149d800840e31fe9b3ea45d016b7e Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 20 Aug 2023 19:52:52 -0700 Subject: [PATCH] control/controlclient: flesh out mapSession to break up gigantic method Now mapSession has a bunch more fields and methods, rather than being just one massive func with a ton of local variables. So far there are no major new optimizations, though. It should behave the same as before. This has been done with an eye towards testability (so tests can set all the callback funcs as needed, or not, without a huge Direct client or long-running HTTP requests), but this change doesn't add new tests yet. That will follow in the changes which flesh out the NetmapUpdater interface. Updates #1909 Change-Id: Iad4e7442d5bbbe2614bd4b1dc4b02e27504898df Signed-off-by: Brad Fitzpatrick --- control/controlclient/direct.go | 179 +++++++++++++----------------- control/controlclient/map.go | 138 ++++++++++++++++++++++- control/controlclient/map_test.go | 81 ++++++++++++-- 3 files changed, 284 insertions(+), 114 deletions(-) diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 76db38c13..5e731ce79 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -54,7 +54,6 @@ import ( "tailscale.com/types/ptr" "tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" - "tailscale.com/util/cmpx" "tailscale.com/util/multierr" "tailscale.com/util/singleflight" "tailscale.com/util/systemd" @@ -806,10 +805,10 @@ func (c *Direct) SendUpdate(ctx context.Context) error { return c.sendMapRequest(ctx, false, nil) } -// If we go more than pollTimeout without hearing from the server, +// If we go more than watchdogTimeout without hearing from the server, // end the long poll. We should be receiving a keep alive ping // every minute. -const pollTimeout = 120 * time.Second +const watchdogTimeout = 120 * time.Second // sendMapRequest makes a /map request to download the network map, calling cb // with each new netmap. If isStreaming, it will poll forever and only returns @@ -956,39 +955,48 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap return nil } - timeout, timeoutChannel := c.clock.NewTimer(pollTimeout) - timeoutReset := make(chan struct{}) - pollDone := make(chan struct{}) - defer close(pollDone) - go func() { - for { - select { - case <-pollDone: - vlogf("netmap: ending timeout goroutine") - return - case <-timeoutChannel: - c.logf("map response long-poll timed out!") - cancel() - return - case <-timeoutReset: - if !timeout.Stop() { - select { - case <-timeoutChannel: - case <-pollDone: - vlogf("netmap: ending timeout goroutine") - return - } - } - vlogf("netmap: reset timeout timer") - timeout.Reset(pollTimeout) - } - } - }() + var mapResIdx int // 0 for first message, then 1+ for deltas - sess := newMapSession(persist.PrivateNodeKey()) + sess := newMapSession(persist.PrivateNodeKey(), nu) + defer sess.Close() + sess.cancel = cancel sess.logf = c.logf sess.vlogf = vlogf + sess.altClock = c.clock sess.machinePubKey = machinePubKey + sess.onDebug = c.handleDebugMessage + sess.onConciseNetMapSummary = func(summary string) { + // Occasionally print the netmap header. + // This is handy for debugging, and our logs processing + // pipeline depends on it. (TODO: Remove this dependency.) + now := c.clock.Now() + if now.Sub(c.lastPrintMap) < 5*time.Minute { + return + } + c.lastPrintMap = now + c.logf("[v1] new network map[%d]:\n%s", mapResIdx, summary) + } + sess.onSelfNodeChanged = func(nm *netmap.NetworkMap) { + c.mu.Lock() + defer c.mu.Unlock() + // If we are the ones who last updated persist, then we can update it + // again. Otherwise, we should not touch it. Also, it's only worth + // change it if the Node info changed. + if persist == c.persist { + newPersist := persist.AsStruct() + newPersist.NodeID = nm.SelfNode.StableID + newPersist.UserProfile = nm.UserProfiles[nm.User()] + + c.persist = newPersist.View() + persist = c.persist + } + c.expiry = nm.Expiry + } + sess.StartWatchdog() + + // gotNonKeepAliveMessage is whether we've yet received a MapResponse message without + // KeepAlive set. + var gotNonKeepAliveMessage bool // If allowStream, then the server will use an HTTP long poll to // return incremental results. There is always one response right @@ -997,8 +1005,8 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap // the same format before just closing the connection. // We can use this same read loop either way. var msg []byte - for i := 0; i == 0 || isStreaming; i++ { - vlogf("netmap: starting size read after %v (poll %v)", time.Since(t0).Round(time.Millisecond), i) + for ; mapResIdx == 0 || isStreaming; mapResIdx++ { + vlogf("netmap: starting size read after %v (poll %v)", time.Since(t0).Round(time.Millisecond), mapResIdx) var siz [4]byte if _, err := io.ReadFull(res.Body, siz[:]); err != nil { vlogf("netmap: size read error after %v: %v", time.Since(t0).Round(time.Millisecond), err) @@ -1062,7 +1070,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap } select { - case timeoutReset <- struct{}{}: + case sess.watchdogReset <- struct{}{}: vlogf("netmap: sent timer reset") case <-ctx.Done(): c.logf("[v1] netmap: not resetting timer; context done: %v", ctx.Err()) @@ -1074,75 +1082,19 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap } metricMapResponseMap.Add(1) - if i > 0 { + if gotNonKeepAliveMessage { + // If we've already seen a non-keep-alive message, this is a delta update. metricMapResponseMapDelta.Add(1) + } else if resp.Node == nil { + // The very first non-keep-alive message should have Node populated. + c.logf("initial MapResponse lacked Node") + return errors.New("initial MapResponse lacked node") } + gotNonKeepAliveMessage = true - if debug := resp.Debug; debug != nil { - if code := debug.Exit; code != nil { - c.logf("exiting process with status %v per controlplane", *code) - os.Exit(*code) - } - if debug.DisableLogTail { - logtail.Disable() - envknob.SetNoLogsNoSupport() - } - if sleep := time.Duration(debug.SleepSeconds * float64(time.Second)); sleep > 0 { - if err := sleepAsRequested(ctx, c.logf, timeoutReset, sleep, c.clock); err != nil { - return err - } - } - } - - // For responses that mutate the self node, check for updated nodeAttrs. - if resp.Node != nil { - setControlKnobsFromNodeAttrs(resp.Node.Capabilities) - } - - // Call Node.InitDisplayNames on any changed nodes. - initDisplayNames(cmpx.Or(resp.Node, sess.lastNode).View(), &resp) - - nm := sess.netmapForResponse(&resp) - if nm.SelfNode == nil { - c.logf("MapResponse lacked node") - return errors.New("MapResponse lacked node") - } - - if DevKnob.StripEndpoints() { - for _, p := range resp.Peers { - p.Endpoints = nil - } - } - if DevKnob.StripCaps() { - nm.SelfNode.Capabilities = nil - } - - // Occasionally print the netmap header. - // This is handy for debugging, and our logs processing - // pipeline depends on it. (TODO: Remove this dependency.) - // Code elsewhere prints netmap diffs every time they are received. - now := c.clock.Now() - if now.Sub(c.lastPrintMap) >= 5*time.Minute { - c.lastPrintMap = now - c.logf("[v1] new network map[%d]:\n%s", i, nm.VeryConcise()) - } - - c.mu.Lock() - // If we are the ones who last updated persist, then we can update it - // again. Otherwise, we should not touch it. Also, it's only worth - // change it if the Node info changed. - if persist == c.persist && resp.Node != nil { - newPersist := persist.AsStruct() - newPersist.NodeID = nm.SelfNode.StableID - newPersist.UserProfile = nm.UserProfiles[nm.User()] - - c.persist = newPersist.View() - persist = c.persist + if err := sess.HandleNonKeepAliveMapResponse(ctx, &resp); err != nil { + return err } - c.expiry = nm.Expiry - c.mu.Unlock() - - nu.UpdateFullNetmap(nm) } if ctx.Err() != nil { return ctx.Err() @@ -1150,6 +1102,23 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap return nil } +func (c *Direct) handleDebugMessage(ctx context.Context, debug *tailcfg.Debug, watchdogReset chan<- struct{}) error { + if code := debug.Exit; code != nil { + c.logf("exiting process with status %v per controlplane", *code) + os.Exit(*code) + } + if debug.DisableLogTail { + logtail.Disable() + envknob.SetNoLogsNoSupport() + } + if sleep := time.Duration(debug.SleepSeconds * float64(time.Second)); sleep > 0 { + if err := sleepAsRequested(ctx, c.logf, watchdogReset, sleep, c.clock); err != nil { + return err + } + } + return nil +} + // initDisplayNames mutates any tailcfg.Nodes in resp to populate their display names, // calling InitDisplayNames on each. // @@ -1538,7 +1507,11 @@ func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr } } -func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- struct{}, d time.Duration, clock tstime.Clock) error { +// sleepAsRequest implements the sleep for a tailcfg.Debug message requesting +// that the client sleep. The complication is that while we're sleeping (if for +// a long time), we need to periodically reset the watchdog timer before it +// expires. +func sleepAsRequested(ctx context.Context, logf logger.Logf, watchdogReset chan<- struct{}, d time.Duration, clock tstime.Clock) error { const maxSleep = 5 * time.Minute if d > maxSleep { logf("sleeping for %v, capped from server-requested %v ...", maxSleep, d) @@ -1547,7 +1520,7 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- logf("sleeping for server-requested %v ...", d) } - ticker, tickerChannel := clock.NewTicker(pollTimeout / 2) + ticker, tickerChannel := clock.NewTicker(watchdogTimeout / 2) defer ticker.Stop() timer, timerChannel := clock.NewTimer(d) defer timer.Stop() @@ -1559,7 +1532,7 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- return nil case <-tickerChannel: select { - case timeoutReset <- struct{}{}: + case watchdogReset <- struct{}{}: case <-timerChannel: return nil case <-ctx.Done(): diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 9ee923e8e..6841dfa46 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -4,6 +4,7 @@ package controlclient import ( + "context" "fmt" "log" "net/netip" @@ -11,10 +12,12 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/views" + "tailscale.com/util/cmpx" "tailscale.com/wgengine/filter" ) @@ -28,10 +31,35 @@ import ( // one MapRequest). type mapSession struct { // Immutable fields. + nu NetmapUpdater // called on changes (in addition to the optional hooks below) privateNodeKey key.NodePrivate logf logger.Logf vlogf logger.Logf machinePubKey key.MachinePublic + altClock tstime.Clock // if nil, regular time is used + cancel context.CancelFunc // always non-nil, shuts down caller's base long poll context + watchdogReset chan struct{} // send to request that the long poll activity watchdog timeout be reset + + // sessionAliveCtx is a Background-based context that's alive for the + // duration of the mapSession that we own the lifetime of. It's closed by + // sessionAliveCtxClose. + sessionAliveCtx context.Context + sessionAliveCtxClose context.CancelFunc // closes sessionAliveCtx + + // Optional hooks, set once before use. + + // onDebug specifies what to do with a *tailcfg.Debug message. + // If the watchdogReset chan is nil, it's not used. Otherwise it can be sent to + // to request that the long poll activity watchdog timeout be reset. + onDebug func(_ context.Context, _ *tailcfg.Debug, watchdogReset chan<- struct{}) error + + // onConciseNetMapSummary, if non-nil, is called with the Netmap.VeryConcise summary + // whenever a map response is received. + onConciseNetMapSummary func(string) + + // onSelfNodeChanged is called before the NetmapUpdater if the self node was + // changed. + onSelfNodeChanged func(*netmap.NetworkMap) // Fields storing state over the course of multiple MapResponses. lastNode *tailcfg.Node @@ -49,23 +77,127 @@ type mapSession struct { lastPopBrowserURL string stickyDebug tailcfg.Debug // accumulated opt.Bool values lastTKAInfo *tailcfg.TKAInfo + lastNetmapSummary string // from NetworkMap.VeryConcise // netMapBuilding is non-nil during a netmapForResponse call, // containing the value to be returned, once fully populated. netMapBuilding *netmap.NetworkMap } -func newMapSession(privateNodeKey key.NodePrivate) *mapSession { +// newMapSession returns a mostly unconfigured new mapSession. +// +// Modify its optional fields on the returned value before use. +// +// It must have its Close method called to release resources. +func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater) *mapSession { ms := &mapSession{ + nu: nu, privateNodeKey: privateNodeKey, - logf: logger.Discard, - vlogf: logger.Discard, lastDNSConfig: new(tailcfg.DNSConfig), lastUserProfile: map[tailcfg.UserID]tailcfg.UserProfile{}, + watchdogReset: make(chan struct{}), + + // Non-nil no-op defaults, to be optionally overridden by the caller. + logf: logger.Discard, + vlogf: logger.Discard, + cancel: func() {}, + onDebug: func(context.Context, *tailcfg.Debug, chan<- struct{}) error { return nil }, + onConciseNetMapSummary: func(string) {}, + onSelfNodeChanged: func(*netmap.NetworkMap) {}, } + ms.sessionAliveCtx, ms.sessionAliveCtxClose = context.WithCancel(context.Background()) return ms } +func (ms *mapSession) clock() tstime.Clock { + return cmpx.Or[tstime.Clock](ms.altClock, tstime.StdClock{}) +} + +// StartWatchdog starts the session's watchdog timer. +// If there's no activity in too long, it tears down the connection. +// Call Close to release these resources. +func (ms *mapSession) StartWatchdog() { + timer, timedOutChan := ms.clock().NewTimer(watchdogTimeout) + go func() { + defer timer.Stop() + for { + select { + case <-ms.sessionAliveCtx.Done(): + ms.vlogf("netmap: ending timeout goroutine") + return + case <-timedOutChan: + ms.logf("map response long-poll timed out!") + ms.cancel() + return + case <-ms.watchdogReset: + if !timer.Stop() { + select { + case <-timedOutChan: + case <-ms.sessionAliveCtx.Done(): + ms.vlogf("netmap: ending timeout goroutine") + return + } + } + ms.vlogf("netmap: reset timeout timer") + timer.Reset(watchdogTimeout) + } + } + }() +} + +func (ms *mapSession) Close() { + ms.sessionAliveCtxClose() +} + +// HandleNonKeepAliveMapResponse handles a non-KeepAlive MapResponse (full or +// incremental). +// +// All fields that are valid on a KeepAlive MapResponse have already been +// handled. +// +// TODO(bradfitz): make this handle all fields later. For now (2023-08-20) this +// is [re]factoring progress enough. +func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error { + if debug := resp.Debug; debug != nil { + if err := ms.onDebug(ctx, debug, ms.watchdogReset); err != nil { + return err + } + } + + if DevKnob.StripEndpoints() { + for _, p := range resp.Peers { + p.Endpoints = nil + } + for _, p := range resp.PeersChanged { + p.Endpoints = nil + } + } + + // For responses that mutate the self node, check for updated nodeAttrs. + if resp.Node != nil { + if DevKnob.StripCaps() { + resp.Node.Capabilities = nil + } + setControlKnobsFromNodeAttrs(resp.Node.Capabilities) + } + + // Call Node.InitDisplayNames on any changed nodes. + initDisplayNames(cmpx.Or(resp.Node, ms.lastNode).View(), resp) + + nm := ms.netmapForResponse(resp) + + ms.lastNetmapSummary = nm.VeryConcise() + ms.onConciseNetMapSummary(ms.lastNetmapSummary) + + // If the self node changed, we might need to update persist. + if resp.Node != nil { + ms.onSelfNodeChanged(nm) + } + + ms.nu.UpdateFullNetmap(nm) + return nil +} + func (ms *mapSession) addUserProfile(userID tailcfg.UserID) { if userID == 0 { return diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index 81b75f96c..473a5fe93 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -4,10 +4,13 @@ package controlclient import ( + "context" "encoding/json" "fmt" + "net/netip" "reflect" "strings" + "sync/atomic" "testing" "time" @@ -330,8 +333,9 @@ func formatNodes(nodes []*tailcfg.Node) string { return sb.String() } -func newTestMapSession(t *testing.T) *mapSession { - ms := newMapSession(key.NewNode()) +func newTestMapSession(t testing.TB, nu NetmapUpdater) *mapSession { + ms := newMapSession(key.NewNode(), nu) + t.Cleanup(ms.Close) ms.logf = t.Logf return ms } @@ -346,7 +350,7 @@ func TestNetmapForResponse(t *testing.T) { }, }, } - ms := newTestMapSession(t) + ms := newTestMapSession(t, nil) nm1 := ms.netmapForResponse(&tailcfg.MapResponse{ Node: new(tailcfg.Node), PacketFilter: somePacketFilter, @@ -367,7 +371,7 @@ func TestNetmapForResponse(t *testing.T) { }) t.Run("implicit_dnsconfig", func(t *testing.T) { someDNSConfig := &tailcfg.DNSConfig{Domains: []string{"foo", "bar"}} - ms := newTestMapSession(t) + ms := newTestMapSession(t, nil) nm1 := ms.netmapForResponse(&tailcfg.MapResponse{ Node: new(tailcfg.Node), DNSConfig: someDNSConfig, @@ -384,7 +388,7 @@ func TestNetmapForResponse(t *testing.T) { } }) t.Run("collect_services", func(t *testing.T) { - ms := newTestMapSession(t) + ms := newTestMapSession(t, nil) var nm *netmap.NetworkMap wantCollect := func(v bool) { t.Helper() @@ -417,7 +421,7 @@ func TestNetmapForResponse(t *testing.T) { wantCollect(true) }) t.Run("implicit_domain", func(t *testing.T) { - ms := newTestMapSession(t) + ms := newTestMapSession(t, nil) var nm *netmap.NetworkMap want := func(v string) { t.Helper() @@ -445,7 +449,7 @@ func TestNetmapForResponse(t *testing.T) { ComputedName: "foo", ComputedNameWithHost: "foo", } - ms := newTestMapSession(t) + ms := newTestMapSession(t, nil) mapRes := &tailcfg.MapResponse{ Node: someNode, } @@ -564,7 +568,7 @@ func TestDeltaDERPMap(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ms := newTestMapSession(t) + ms := newTestMapSession(t, nil) for stepi, s := range tt.steps { nm := ms.netmapForResponse(&tailcfg.MapResponse{DERPMap: s.got}) if !reflect.DeepEqual(nm.DERPMap, s.want) { @@ -574,3 +578,64 @@ func TestDeltaDERPMap(t *testing.T) { }) } } + +type countingNetmapUpdater struct { + full atomic.Int64 +} + +func (nu *countingNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { + nu.full.Add(1) +} + +func BenchmarkMapSessionDelta(b *testing.B) { + for _, size := range []int{10, 100, 1_000, 10_000} { + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + ctx := context.Background() + nu := &countingNetmapUpdater{} + ms := newTestMapSession(b, nu) + res := &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + ID: 1, + Name: "foo.bar.ts.net.", + }, + } + for i := 0; i < size; i++ { + res.Peers = append(res.Peers, &tailcfg.Node{ + ID: tailcfg.NodeID(i + 2), + Name: fmt.Sprintf("peer%d.bar.ts.net.", i), + DERP: "127.3.3.40:10", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.100.2.3/32"), netip.MustParsePrefix("fd7a:115c:a1e0::123/128")}, + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.100.2.3/32"), netip.MustParsePrefix("fd7a:115c:a1e0::123/128")}, + Endpoints: []string{"192.168.1.2:345", "192.168.1.3:678"}, + Hostinfo: (&tailcfg.Hostinfo{ + OS: "fooOS", + Hostname: "MyHostname", + Services: []tailcfg.Service{ + {Proto: "peerapi4", Port: 1234}, + {Proto: "peerapi6", Port: 1234}, + {Proto: "peerapi-dns-proxy", Port: 1}, + }, + }).View(), + LastSeen: ptr.To(time.Unix(int64(i), 0)), + }) + } + ms.HandleNonKeepAliveMapResponse(ctx, res) + + b.ResetTimer() + b.ReportAllocs() + + // Now for the core of the benchmark loop, just toggle + // a single node's online status. + for i := 0; i < b.N; i++ { + if err := ms.HandleNonKeepAliveMapResponse(ctx, &tailcfg.MapResponse{ + OnlineChange: map[tailcfg.NodeID]bool{ + 2: i%2 == 0, + }, + }); err != nil { + b.Fatal(err) + } + } + }) + } + +}