diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index 959fa17c1..3172ae352 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -480,11 +480,42 @@ func (c *Auto) unpausedChanLocked() <-chan struct{} { return unpaused } +// mapRoutineState is the state of Auto.mapRoutine while it's running. +type mapRoutineState struct { + c *Auto + bo *backoff.Backoff +} + +func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) { + c := mrs.c + health.SetInPollNetMap(true) + + c.mu.Lock() + ctx := c.mapCtx + c.synced = true + if c.loggedIn { + c.state = StateSynchronized + } + c.expiry = ptr.To(nm.Expiry) + stillAuthed := c.loggedIn + c.logf("[v1] mapRoutine: netmap received: %s", c.state) + c.mu.Unlock() + + if stillAuthed { + c.sendStatus("mapRoutine-got-netmap", nil, "", nm) + } + // Reset the backoff timer if we got a netmap. + mrs.bo.BackOff(ctx, nil) +} + // mapRoutine is responsible for keeping a read-only streaming connection to the // control server, and keeping the netmap up to date. func (c *Auto) mapRoutine() { defer close(c.mapDone) - bo := backoff.NewBackoff("mapRoutine", c.logf, 30*time.Second) + mrs := &mapRoutineState{ + c: c, + bo: backoff.NewBackoff("mapRoutine", c.logf, 30*time.Second), + } for { if err := c.waitUnpause("mapRoutine"); err != nil { @@ -531,25 +562,7 @@ func (c *Auto) mapRoutine() { } else { health.SetInPollNetMap(false) - err := c.direct.PollNetMap(ctx, func(nm *netmap.NetworkMap) { - health.SetInPollNetMap(true) - - c.mu.Lock() - c.synced = true - if c.loggedIn { - c.state = StateSynchronized - } - c.expiry = ptr.To(nm.Expiry) - stillAuthed := c.loggedIn - c.logf("[v1] mapRoutine: netmap received: %s", c.state) - c.mu.Unlock() - - if stillAuthed { - c.sendStatus("mapRoutine-got-netmap", nil, "", nm) - } - // Reset the backoff timer if we got a netmap. - bo.BackOff(ctx, nil) - }) + err := c.direct.PollNetMap(ctx, mrs) health.SetInPollNetMap(false) c.mu.Lock() @@ -561,16 +574,14 @@ func (c *Auto) mapRoutine() { c.mu.Unlock() if paused { + mrs.bo.BackOff(ctx, nil) c.logf("mapRoutine: paused") continue } - if err != nil { - report(err, "PollNetMap") - bo.BackOff(ctx, err) - continue - } - bo.BackOff(ctx, nil) + report(err, "PollNetMap") + mrs.bo.BackOff(ctx, err) + continue } } } diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index d777deb8e..1762c4eec 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -180,6 +180,16 @@ type Decompressor interface { Close() } +// NetmapUpdater is the interface needed by the controlclient to enact change in +// the world as a function of updates received from the network. +type NetmapUpdater interface { + UpdateFullNetmap(*netmap.NetworkMap) + + // TODO(bradfitz): add methods to do fine-grained updates, mutating just + // parts of peers, without implementations of NetmapUpdater needing to do + // the diff themselves between the previous full & next full network maps. +} + // NewDirect returns a new Direct client. func NewDirect(opts Options) (*Direct, error) { if opts.ServerURL == "" { @@ -767,24 +777,31 @@ func (c *Direct) SetEndpoints(endpoints []tailcfg.Endpoint) (changed bool) { return c.newEndpoints(endpoints) } -// PollNetMap makes a /map request to download the network map, calling cb with -// each new netmap. -// It always returns a non-nil error describing the reason for the failure -// or why the request ended. -func (c *Direct) PollNetMap(ctx context.Context, cb func(*netmap.NetworkMap)) error { - return c.sendMapRequest(ctx, true, cb) +// PollNetMap makes a /map request to download the network map, calling +// NetmapUpdater on each update from the control plane. +// +// It always returns a non-nil error describing the reason for the failure or +// why the request ended. +func (c *Direct) PollNetMap(ctx context.Context, nu NetmapUpdater) error { + return c.sendMapRequest(ctx, true, nu) +} + +type rememberLastNetmapUpdater struct { + last *netmap.NetworkMap +} + +func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { + nu.last = nm } // FetchNetMapForTest fetches the netmap once. func (c *Direct) FetchNetMapForTest(ctx context.Context) (*netmap.NetworkMap, error) { - var ret *netmap.NetworkMap - err := c.sendMapRequest(ctx, false, func(nm *netmap.NetworkMap) { - ret = nm - }) - if err == nil && ret == nil { + var nu rememberLastNetmapUpdater + err := c.sendMapRequest(ctx, false, &nu) + if err == nil && nu.last == nil { return nil, errors.New("[unexpected] sendMapRequest success without callback") } - return ret, err + return nu.last, err } // SendUpdate makes a /map request to update the server of our latest state, but @@ -805,8 +822,8 @@ const pollTimeout = 120 * time.Second // and as such always returns a non-nil error. // // If cb is nil, OmitPeers will be set to true. -func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, cb func(*netmap.NetworkMap)) error { - if isStreaming && cb == nil { +func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu NetmapUpdater) error { + if isStreaming && nu == nil { panic("cb must be non-nil if isStreaming is true") } @@ -868,7 +885,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, cb func(* Stream: isStreaming, Hostinfo: hi, DebugFlags: c.debugFlags, - OmitPeers: cb == nil, + OmitPeers: nu == nil, TKAHead: c.tkaHead, } var extraDebugFlags []string @@ -939,7 +956,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, cb func(* health.NoteMapRequestHeard(request) - if cb == nil { + if nu == nil { io.Copy(io.Discard, res.Body) return nil } @@ -1135,7 +1152,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, cb func(* c.expiry = &nm.Expiry c.mu.Unlock() - cb(nm) + nu.UpdateFullNetmap(nm) } if ctx.Err() != nil { return ctx.Err()