diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index f2f2821e4..56f8d360a 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -14,6 +14,7 @@ import ( "tailscale.com/control/controlclient" "tailscale.com/ipn" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/empty" "tailscale.com/types/logger" @@ -277,7 +278,7 @@ func TestStateMachine(t *testing.T) { c := qt.New(t) logf := t.Logf - store := new(ipn.MemoryStore) + store := new(testStateStorage) e, err := wgengine.NewFakeUserspaceEngine(logf, 0) if err != nil { t.Fatalf("NewFakeUserspaceEngine: %v", err) @@ -523,6 +524,7 @@ func TestStateMachine(t *testing.T) { // The user changes their preference to WantRunning after all. t.Logf("\n\nWantRunning -> true") + store.awaitWrite() notifies.expect(2) b.EditPrefs(&ipn.MaskedPrefs{ WantRunningSet: true, @@ -537,6 +539,7 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[0].State, qt.Not(qt.IsNil)) c.Assert(nn[1].Prefs, qt.Not(qt.IsNil)) c.Assert(ipn.Starting, qt.Equals, *nn[0].State) + c.Assert(store.sawWrite(), qt.IsTrue) } // Test the fast-path frontend reconnection. @@ -561,6 +564,7 @@ func TestStateMachine(t *testing.T) { b.state = ipn.Starting // User wants to logout. + store.awaitWrite() t.Logf("\n\nLogout (async)") notifies.expect(2) b.Logout() @@ -573,6 +577,7 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[1].Prefs.LoggedOut, qt.IsTrue) c.Assert(nn[1].Prefs.WantRunning, qt.IsFalse) c.Assert(ipn.Stopped, qt.Equals, b.State()) + c.Assert(store.sawWrite(), qt.IsTrue) } // Let's make the logout succeed. @@ -861,3 +866,29 @@ func TestStateMachine(t *testing.T) { c.Assert(ipn.Starting, qt.Equals, b.State()) } } + +type testStateStorage struct { + mem ipn.MemoryStore + written syncs.AtomicBool +} + +func (s *testStateStorage) ReadState(id ipn.StateKey) ([]byte, error) { + return s.mem.ReadState(id) +} + +func (s *testStateStorage) WriteState(id ipn.StateKey, bs []byte) error { + s.written.Set(true) + return s.mem.WriteState(id, bs) +} + +// awaitWrite clears the "I've seen writes" bit, in prep for a future +// call to sawWrite to see if a write arrived. +func (s *testStateStorage) awaitWrite() { s.written.Set(false) } + +// sawWrite reports whether there's been a WriteState call since the most +// recent awaitWrite call. +func (s *testStateStorage) sawWrite() bool { + v := s.written.Get() + s.awaitWrite() + return v +}