control/controlclient: clean up various things in prep for state overhaul

We want the overall state (used only for tests) to be computed from
the individual states of each component, rather than moving the state
around by hand in dozens of places.

In working towards that, we found a lot of things to clean up.

Updates #cleanup

Change-Id: Ieaaae5355dfae789a8ec7a56ce212f1d7e3a92db
Co-authored-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/9098/merge
Brad Fitzpatrick 1 year ago committed by Brad Fitzpatrick
parent 0c1e3ff625
commit 003e4aff71

@ -26,43 +26,27 @@ import (
type LoginGoal struct { type LoginGoal struct {
_ structs.Incomparable _ structs.Incomparable
wantLoggedIn bool // true if we *want* to be logged in
token *tailcfg.Oauth2Token // oauth token to use when logging in token *tailcfg.Oauth2Token // oauth token to use when logging in
flags LoginFlags // flags to use when logging in flags LoginFlags // flags to use when logging in
url string // auth url that needs to be visited url string // auth url that needs to be visited
loggedOutResult chan<- error
}
func (g *LoginGoal) sendLogoutError(err error) {
if g.loggedOutResult == nil {
return
}
select {
case g.loggedOutResult <- err:
default:
}
} }
var _ Client = (*Auto)(nil) var _ Client = (*Auto)(nil)
// waitUnpause waits until the client is unpaused then returns. It only // waitUnpause waits until either the client is unpaused or the Auto client is
// returns an error if the client is closed. // shut down. It reports whether the client should keep running (i.e. it's not
func (c *Auto) waitUnpause(routineLogName string) error { // closed).
func (c *Auto) waitUnpause(routineLogName string) (keepRunning bool) {
c.mu.Lock() c.mu.Lock()
if !c.paused { if !c.paused {
c.mu.Unlock() defer c.mu.Unlock()
return nil return !c.closed
} }
unpaused := c.unpausedChanLocked() unpaused := c.unpausedChanLocked()
c.mu.Unlock() c.mu.Unlock()
c.logf("%s: awaiting unpause", routineLogName) c.logf("%s: awaiting unpause", routineLogName)
select { return <-unpaused
case <-unpaused:
c.logf("%s: unpaused", routineLogName)
return nil
case <-c.quit:
return errors.New("quit")
}
} }
// updateRoutine is responsible for informing the server of worthy changes to // updateRoutine is responsible for informing the server of worthy changes to
@ -76,7 +60,7 @@ func (c *Auto) updateRoutine() {
var lastUpdateGenInformed updateGen var lastUpdateGenInformed updateGen
for { for {
if err := c.waitUnpause("updateRoutine"); err != nil { if !c.waitUnpause("updateRoutine") {
c.logf("updateRoutine: exiting") c.logf("updateRoutine: exiting")
return return
} }
@ -86,19 +70,11 @@ func (c *Auto) updateRoutine() {
needUpdate := gen > 0 && gen != lastUpdateGenInformed && c.loggedIn needUpdate := gen > 0 && gen != lastUpdateGenInformed && c.loggedIn
c.mu.Unlock() c.mu.Unlock()
if needUpdate { if !needUpdate {
select {
case <-c.quit:
c.logf("updateRoutine: exiting")
return
default:
}
} else {
// Nothing to do, wait for a signal. // Nothing to do, wait for a signal.
select { select {
case <-c.quit: case <-ctx.Done():
c.logf("updateRoutine: exiting") continue
return
case <-c.updateCh: case <-c.updateCh:
continue continue
} }
@ -141,7 +117,6 @@ type Auto struct {
logf logger.Logf logf logger.Logf
closed bool closed bool
updateCh chan struct{} // readable when we should inform the server of a change updateCh chan struct{} // readable when we should inform the server of a change
newMapCh chan struct{} // readable when we must restart a map request
observer Observer // called to update Client status; always non-nil observer Observer // called to update Client status; always non-nil
observerQueue execQueue observerQueue execQueue
@ -149,6 +124,8 @@ type Auto struct {
mu sync.Mutex // mutex guards the following fields mu sync.Mutex // mutex guards the following fields
wantLoggedIn bool // whether the user wants to be logged in per last method call
urlToVisit string // the last url we were told to visit
expiry time.Time expiry time.Time
// lastUpdateGen is the gen of last update we had an update worth sending to // lastUpdateGen is the gen of last update we had an update worth sending to
@ -156,17 +133,16 @@ type Auto struct {
lastUpdateGen updateGen lastUpdateGen updateGen
paused bool // whether we should stop making HTTP requests paused bool // whether we should stop making HTTP requests
unpauseWaiters []chan struct{} unpauseWaiters []chan bool // chans that gets sent true (once) on wake, or false on Shutdown
loggedIn bool // true if currently logged in loggedIn bool // true if currently logged in
loginGoal *LoginGoal // non-nil if some login activity is desired loginGoal *LoginGoal // non-nil if some login activity is desired
synced bool // true if our netmap is up-to-date inMapPoll bool // true once we get the first MapResponse in a stream; false when HTTP response ends
state State state State // TODO(bradfitz): delete this, make it computed by method from other state
authCtx context.Context // context used for auth requests authCtx context.Context // context used for auth requests
mapCtx context.Context // context used for netmap and update requests mapCtx context.Context // context used for netmap and update requests
authCancel func() // cancel authCtx authCancel func() // cancel authCtx
mapCancel func() // cancel mapCtx mapCancel func() // cancel mapCtx
quit chan struct{} // when closed, goroutines should all exit
authDone chan struct{} // when closed, authRoutine is done authDone chan struct{} // when closed, authRoutine is done
mapDone chan struct{} // when closed, mapRoutine is done mapDone chan struct{} // when closed, mapRoutine is done
updateDone chan struct{} // when closed, updateRoutine is done updateDone chan struct{} // when closed, updateRoutine is done
@ -207,8 +183,6 @@ func NewNoStart(opts Options) (_ *Auto, err error) {
clock: opts.Clock, clock: opts.Clock,
logf: opts.Logf, logf: opts.Logf,
updateCh: make(chan struct{}, 1), updateCh: make(chan struct{}, 1),
newMapCh: make(chan struct{}, 1),
quit: make(chan struct{}),
authDone: make(chan struct{}), authDone: make(chan struct{}),
mapDone: make(chan struct{}), mapDone: make(chan struct{}),
updateDone: make(chan struct{}), updateDone: make(chan struct{}),
@ -237,16 +211,15 @@ func (c *Auto) SetPaused(paused bool) {
c.logf("setPaused(%v)", paused) c.logf("setPaused(%v)", paused)
c.paused = paused c.paused = paused
if paused { if paused {
// Only cancel the map routine. (The auth routine isn't expensive
// so it's fine to keep it running.)
c.cancelMapCtxLocked() c.cancelMapCtxLocked()
} else { c.cancelAuthCtxLocked()
return
}
for _, ch := range c.unpauseWaiters { for _, ch := range c.unpauseWaiters {
close(ch) ch <- true
} }
c.unpauseWaiters = nil c.unpauseWaiters = nil
} }
}
// Start starts the client's goroutines. // Start starts the client's goroutines.
// //
@ -322,20 +295,10 @@ func (c *Auto) cancelMapCtxLocked() {
func (c *Auto) restartMap() { func (c *Auto) restartMap() {
c.mu.Lock() c.mu.Lock()
c.cancelMapCtxLocked() c.cancelMapCtxLocked()
synced := c.synced synced := c.inMapPoll
c.mu.Unlock() c.mu.Unlock()
c.logf("[v1] restartMap: synced=%v", synced) c.logf("[v1] restartMap: synced=%v", synced)
select {
case c.newMapCh <- struct{}{}:
c.logf("[v1] restartMap: wrote to channel")
default:
// if channel write failed, then there was already
// an outstanding newMapCh request. One is enough,
// since it'll always use the latest endpoints.
c.logf("[v1] restartMap: channel was full")
}
c.updateControl() c.updateControl()
} }
@ -344,23 +307,20 @@ func (c *Auto) authRoutine() {
bo := backoff.NewBackoff("authRoutine", c.logf, 30*time.Second) bo := backoff.NewBackoff("authRoutine", c.logf, 30*time.Second)
for { for {
if !c.waitUnpause("authRoutine") {
c.logf("authRoutine: exiting")
return
}
c.mu.Lock() c.mu.Lock()
goal := c.loginGoal goal := c.loginGoal
ctx := c.authCtx ctx := c.authCtx
if goal != nil { if goal != nil {
c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, goal.wantLoggedIn) c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, true)
} else { } else {
c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused) c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused)
} }
c.mu.Unlock() c.mu.Unlock()
select {
case <-c.quit:
c.logf("[v1] authRoutine: quit")
return
default:
}
report := func(err error, msg string) { report := func(err error, msg string) {
c.logf("[v1] %s: %v", msg, err) c.logf("[v1] %s: %v", msg, err)
// don't send status updates for context errors, // don't send status updates for context errors,
@ -378,28 +338,8 @@ func (c *Auto) authRoutine() {
continue continue
} }
if !goal.wantLoggedIn {
health.SetAuthRoutineInError(nil)
err := c.direct.TryLogout(ctx)
goal.sendLogoutError(err)
if err != nil {
report(err, "TryLogout")
bo.BackOff(ctx, err)
continue
}
// success
c.mu.Lock()
c.loggedIn = false
c.loginGoal = nil
c.state = StateNotAuthenticated
c.synced = false
c.mu.Unlock()
c.sendStatus("authRoutine-wantout", nil, "", nil)
bo.BackOff(ctx, nil)
} else { // ie. goal.wantLoggedIn
c.mu.Lock() c.mu.Lock()
c.urlToVisit = goal.url
if goal.url != "" { if goal.url != "" {
c.state = StateURLVisitRequired c.state = StateURLVisitRequired
} else { } else {
@ -428,13 +368,12 @@ func (c *Auto) authRoutine() {
// However, not all control servers get this right, // However, not all control servers get this right,
// and logging about it here just generates noise. // and logging about it here just generates noise.
c.mu.Lock() c.mu.Lock()
c.urlToVisit = url
c.loginGoal = &LoginGoal{ c.loginGoal = &LoginGoal{
wantLoggedIn: true,
flags: LoginDefault, flags: LoginDefault,
url: url, url: url,
} }
c.state = StateURLVisitRequired c.state = StateURLVisitRequired
c.synced = false
c.mu.Unlock() c.mu.Unlock()
c.sendStatus("authRoutine-url", err, url, nil) c.sendStatus("authRoutine-url", err, url, nil)
@ -451,6 +390,7 @@ func (c *Auto) authRoutine() {
// success // success
health.SetAuthRoutineInError(nil) health.SetAuthRoutineInError(nil)
c.mu.Lock() c.mu.Lock()
c.urlToVisit = ""
c.loggedIn = true c.loggedIn = true
c.loginGoal = nil c.loginGoal = nil
c.state = StateAuthenticated c.state = StateAuthenticated
@ -461,7 +401,6 @@ func (c *Auto) authRoutine() {
bo.BackOff(ctx, nil) bo.BackOff(ctx, nil)
} }
} }
}
// ExpiryForTests returns the credential expiration time, or the zero value if // ExpiryForTests returns the credential expiration time, or the zero value if
// the expiration time isn't known. It's used in tests only. // the expiration time isn't known. It's used in tests only.
@ -477,12 +416,12 @@ func (c *Auto) DirectForTest() *Direct {
return c.direct return c.direct
} }
// unpausedChanLocked returns a new channel that is closed when the // unpausedChanLocked returns a new channel that gets sent
// current Auto pause is unpaused. // either a true when unpaused or false on Auto.Shutdown.
// //
// c.mu must be held // c.mu must be held
func (c *Auto) unpausedChanLocked() <-chan struct{} { func (c *Auto) unpausedChanLocked() <-chan bool {
unpaused := make(chan struct{}) unpaused := make(chan bool, 1)
c.unpauseWaiters = append(c.unpauseWaiters, unpaused) c.unpauseWaiters = append(c.unpauseWaiters, unpaused)
return unpaused return unpaused
} }
@ -498,7 +437,7 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) {
c.mu.Lock() c.mu.Lock()
ctx := c.mapCtx ctx := c.mapCtx
c.synced = true c.inMapPoll = true
if c.loggedIn { if c.loggedIn {
c.state = StateSynchronized c.state = StateSynchronized
} }
@ -524,7 +463,7 @@ func (c *Auto) mapRoutine() {
} }
for { for {
if err := c.waitUnpause("mapRoutine"); err != nil { if !c.waitUnpause("mapRoutine") {
c.logf("mapRoutine: exiting") c.logf("mapRoutine: exiting")
return return
} }
@ -535,13 +474,6 @@ func (c *Auto) mapRoutine() {
ctx := c.mapCtx ctx := c.mapCtx
c.mu.Unlock() c.mu.Unlock()
select {
case <-c.quit:
c.logf("mapRoutine: quit")
return
default:
}
report := func(err error, msg string) { report := func(err error, msg string) {
c.logf("[v1] %s: %v", msg, err) c.logf("[v1] %s: %v", msg, err)
err = fmt.Errorf("%s: %w", msg, err) err = fmt.Errorf("%s: %w", msg, err)
@ -555,24 +487,20 @@ func (c *Auto) mapRoutine() {
if !loggedIn { if !loggedIn {
// Wait for something interesting to happen // Wait for something interesting to happen
c.mu.Lock() c.mu.Lock()
c.synced = false c.inMapPoll = false
// c.state is set by authRoutine()
c.mu.Unlock() c.mu.Unlock()
select { <-ctx.Done()
case <-ctx.Done():
c.logf("[v1] mapRoutine: context done.") c.logf("[v1] mapRoutine: context done.")
case <-c.newMapCh: continue
c.logf("[v1] mapRoutine: new map needed while idle.")
} }
} else {
health.SetOutOfPollNetMap() health.SetOutOfPollNetMap()
err := c.direct.PollNetMap(ctx, mrs) err := c.direct.PollNetMap(ctx, mrs)
health.SetOutOfPollNetMap() health.SetOutOfPollNetMap()
c.mu.Lock() c.mu.Lock()
c.synced = false c.inMapPoll = false
if c.state == StateSynchronized { if c.state == StateSynchronized {
c.state = StateAuthenticated c.state = StateAuthenticated
} }
@ -582,12 +510,9 @@ func (c *Auto) mapRoutine() {
if paused { if paused {
mrs.bo.BackOff(ctx, nil) mrs.bo.BackOff(ctx, nil)
c.logf("mapRoutine: paused") c.logf("mapRoutine: paused")
continue } else {
}
report(err, "PollNetMap")
mrs.bo.BackOff(ctx, err) mrs.bo.BackOff(ctx, err)
continue report(err, "PollNetMap")
} }
} }
} }
@ -637,6 +562,7 @@ func (c *Auto) SetTKAHead(headHash string) {
c.updateControl() c.updateControl()
} }
// sendStatus can not be called with the c.mu held.
func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkMap) { func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkMap) {
c.mu.Lock() c.mu.Lock()
if c.closed { if c.closed {
@ -645,13 +571,13 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM
} }
state := c.state state := c.state
loggedIn := c.loggedIn loggedIn := c.loggedIn
synced := c.synced inMapPoll := c.inMapPoll
c.mu.Unlock() c.mu.Unlock()
c.logf("[v1] sendStatus: %s: %v", who, state) c.logf("[v1] sendStatus: %s: %v", who, state)
var p persist.PersistView var p persist.PersistView
if nm != nil && loggedIn && synced { if nm != nil && loggedIn && inMapPoll {
p = c.direct.GetPersist() p = c.direct.GetPersist()
} else { } else {
// don't send netmap status, as it's misleading when we're // don't send netmap status, as it's misleading when we're
@ -677,40 +603,45 @@ func (c *Auto) Login(t *tailcfg.Oauth2Token, flags LoginFlags) {
c.logf("client.Login(%v, %v)", t != nil, flags) c.logf("client.Login(%v, %v)", t != nil, flags)
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
c.wantLoggedIn = true
c.loginGoal = &LoginGoal{ c.loginGoal = &LoginGoal{
wantLoggedIn: true,
token: t, token: t,
flags: flags, flags: flags,
} }
c.mu.Unlock() c.cancelMapCtxLocked()
c.cancelAuthCtxLocked()
c.cancelAuthCtx()
} }
var ErrClientClosed = errors.New("client closed")
func (c *Auto) Logout(ctx context.Context) error { func (c *Auto) Logout(ctx context.Context) error {
c.logf("client.Logout()") c.logf("client.Logout()")
errc := make(chan error, 1)
c.mu.Lock() c.mu.Lock()
c.loginGoal = &LoginGoal{ c.wantLoggedIn = false
wantLoggedIn: false, c.loginGoal = nil
loggedOutResult: errc, closed := c.closed
}
c.mu.Unlock() c.mu.Unlock()
c.cancelAuthCtx()
c.cancelMapCtx()
timer, timerChannel := c.clock.NewTimer(10 * time.Second) if closed {
defer timer.Stop() return ErrClientClosed
select { }
case err := <-errc:
if err := c.direct.TryLogout(ctx); err != nil {
return err return err
case <-ctx.Done():
return ctx.Err()
case <-timerChannel:
return context.DeadlineExceeded
} }
c.mu.Lock()
c.loggedIn = false
c.state = StateNotAuthenticated
c.cancelAuthCtxLocked()
c.cancelMapCtxLocked()
c.mu.Unlock()
c.sendStatus("authRoutine-wantout", nil, "", nil)
return nil
} }
func (c *Auto) SetExpirySooner(ctx context.Context, expiry time.Time) error { func (c *Auto) SetExpirySooner(ctx context.Context, expiry time.Time) error {
@ -738,14 +669,16 @@ func (c *Auto) Shutdown() {
c.closed = true c.closed = true
c.cancelAuthCtxLocked() c.cancelAuthCtxLocked()
c.cancelMapCtxLocked() c.cancelMapCtxLocked()
go c.observerQueue.shutdown() for _, w := range c.unpauseWaiters {
w <- false
}
c.unpauseWaiters = nil
} }
c.mu.Unlock() c.mu.Unlock()
c.logf("client.Shutdown") c.logf("client.Shutdown")
if !closed { if !closed {
c.unregisterHealthWatch() c.unregisterHealthWatch()
close(c.quit)
<-c.authDone <-c.authDone
<-c.mapDone <-c.mapDone
<-c.updateDone <-c.updateDone

@ -50,12 +50,7 @@ func TestStatusEqual(t *testing.T) {
true, true,
}, },
{ {
&Status{state: StateNew}, &Status{},
&Status{state: StateNew},
true,
},
{
&Status{state: StateNew},
&Status{state: StateAuthenticated}, &Status{state: StateAuthenticated},
false, false,
}, },

Loading…
Cancel
Save