diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 1de7418b6..849a57b49 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -884,6 +884,9 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { if !envknob.TKASkipSignatureCheck() { b.tkaFilterNetmapLocked(st.NetMap) } + if b.updatePersistFromNetMapLocked(st.NetMap, prefs) { + prefsChanged = true + } b.setNetMapLocked(st.NetMap) b.updateFilterLocked(st.NetMap, prefs.View()) } @@ -3349,23 +3352,36 @@ func hasCapability(nm *netmap.NetworkMap, cap string) bool { return false } +func (b *LocalBackend) updatePersistFromNetMapLocked(nm *netmap.NetworkMap, prefs *ipn.Prefs) (changed bool) { + if nm == nil || nm.SelfNode == nil { + return + } + up := nm.UserProfiles[nm.User] + if prefs.Persist.UserProfile.ID != up.ID { + // If the current profile doesn't match the + // network map's user profile, then we need to + // update the persisted UserProfile to match. + prefs.Persist.UserProfile = up + changed = true + } + if prefs.Persist.NodeID == "" { + // If the current profile doesn't have a NodeID, + // then we need to update the persisted NodeID to + // match. + prefs.Persist.NodeID = nm.SelfNode.StableID + changed = true + } + return changed +} + func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { b.dialer.SetNetMap(nm) var login string if nm != nil { - up := nm.UserProfiles[nm.User] - login = up.LoginName + login = nm.UserProfiles[nm.User].LoginName if login == "" { login = "" } - if cp := b.pm.CurrentProfile(); cp.ID != "" && cp.UserProfile.ID != up.ID { - // If the current profile doesn't match the - // network map's user profile, then we need to - // update the persisted UserProfile to match. - prefs := b.pm.CurrentPrefs().AsStruct() - prefs.Persist.UserProfile = up - b.pm.SetPrefs(prefs.View()) - } } b.netMap = nm if login != b.activeLogin { diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index 86f427289..2d9c3e42d 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -70,13 +70,30 @@ func (pm *profileManager) SetCurrentUser(uid string) error { return nil } -func (pm *profileManager) findProfileByUserID(userID tailcfg.UserID) *ipn.LoginProfile { +func (pm *profileManager) findProfilesByNodeID(nodeID tailcfg.StableNodeID) []*ipn.LoginProfile { + if nodeID.IsZero() { + return nil + } + var out []*ipn.LoginProfile + for _, p := range pm.knownProfiles { + if p.NodeID == nodeID { + out = append(out, p) + } + } + return out +} + +func (pm *profileManager) findProfilesByUserID(userID tailcfg.UserID) []*ipn.LoginProfile { + if userID.IsZero() { + return nil + } + var out []*ipn.LoginProfile for _, p := range pm.knownProfiles { if p.UserProfile.ID == userID { - return p + out = append(out, p) } } - return nil + return out } func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { @@ -124,37 +141,42 @@ func init() { // provided prefs, which may be accessed via CurrentPrefs. func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView) error { prefs := prefsIn.AsStruct().View() - ps := prefs.Persist() - if ps == nil || ps.LoginName == "" { + newPersist := prefs.Persist() + if newPersist == nil || newPersist.LoginName == "" { return pm.setPrefsLocked(prefs) } - up := ps.UserProfile + up := newPersist.UserProfile if up.LoginName == "" { - up.LoginName = ps.LoginName + // Backwards compatibility with old prefs files. + up.LoginName = newPersist.LoginName + } else { + newPersist.LoginName = up.LoginName } if up.DisplayName == "" { up.DisplayName = up.LoginName } cp := pm.currentProfile - wasNamedWithLoginName := cp.Name == cp.UserProfile.LoginName if pm.isNewProfile { pm.isNewProfile = false // Check if we already have a profile for this user. - existing := pm.findProfileByUserID(ps.UserProfile.ID) - if existing != nil && existing.ID != "" { - cp = existing - } else { + existing := pm.findProfilesByUserID(newPersist.UserProfile.ID) + // Also check if we have a profile with the same NodeID. + existing = append(existing, pm.findProfilesByNodeID(newPersist.NodeID)...) + if len(existing) == 0 { cp.ID, cp.Key = newUnusedID(pm.knownProfiles) - cp.Name = ps.LoginName + } else { + // Only one profile per user/nodeID should exist. + for _, p := range existing[1:] { + // Best effort cleanup. + pm.DeleteProfile(p.ID) + } + cp = existing[0] } - cp.UserProfile = ps.UserProfile cp.LocalUserID = pm.currentUserID - } else { - cp.UserProfile = ps.UserProfile - } - if wasNamedWithLoginName { - cp.Name = ps.LoginName } + cp.UserProfile = newPersist.UserProfile + cp.NodeID = newPersist.NodeID + cp.Name = up.LoginName pm.knownProfiles[cp.ID] = cp pm.currentProfile = cp if err := pm.writeKnownProfiles(); err != nil { diff --git a/ipn/ipnlocal/profiles_test.go b/ipn/ipnlocal/profiles_test.go index 51b4f1779..a7ffee5c2 100644 --- a/ipn/ipnlocal/profiles_test.go +++ b/ipn/ipnlocal/profiles_test.go @@ -5,6 +5,7 @@ package ipnlocal import ( + "fmt" "testing" "tailscale.com/ipn" @@ -61,21 +62,28 @@ func TestProfileManagement(t *testing.T) { } } logins := make(map[string]tailcfg.UserID) + nodeIDs := make(map[string]tailcfg.StableNodeID) setPrefs := func(t *testing.T, loginName string) ipn.PrefsView { t.Helper() p := pm.CurrentPrefs().AsStruct() - id := logins[loginName] - if id.IsZero() { - id = tailcfg.UserID(len(logins) + 1) - logins[loginName] = id + uid := logins[loginName] + if uid.IsZero() { + uid = tailcfg.UserID(len(logins) + 1) + logins[loginName] = uid + } + nid := nodeIDs[loginName] + if nid.IsZero() { + nid = tailcfg.StableNodeID(fmt.Sprint(len(nodeIDs) + 1)) + nodeIDs[loginName] = nid } p.Persist = &persist.Persist{ LoginName: loginName, PrivateNodeKey: key.NewNode(), UserProfile: tailcfg.UserProfile{ - ID: id, + ID: uid, LoginName: loginName, }, + NodeID: nid, } if err := pm.SetPrefs(p.View()); err != nil { t.Fatal(err) @@ -132,7 +140,7 @@ func TestProfileManagement(t *testing.T) { } checkProfiles(t) - t.Logf("Create new profile again") + t.Logf("Create new profile - 2") pm.NewProfile() wantCurProfile = "" wantProfiles[""] = emptyPrefs @@ -143,6 +151,19 @@ func TestProfileManagement(t *testing.T) { delete(wantProfiles, "") wantCurProfile = "user@2.example.com" checkProfiles(t) + + t.Logf("Tag the current the profile") + nodeIDs["tagged-node.2.ts.net"] = nodeIDs["user@2.example.com"] + wantProfiles["tagged-node.2.ts.net"] = setPrefs(t, "tagged-node.2.ts.net") + delete(wantProfiles, "user@2.example.com") + wantCurProfile = "tagged-node.2.ts.net" + checkProfiles(t) + + t.Logf("Relogin") + wantProfiles["user@2.example.com"] = setPrefs(t, "user@2.example.com") + delete(wantProfiles, "tagged-node.2.ts.net") + wantCurProfile = "user@2.example.com" + checkProfiles(t) } // TestProfileManagementWindows tests going into and out of Unattended mode on diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index 5fc2a2cd2..8a1653fe2 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -473,6 +473,7 @@ func TestStateMachine(t *testing.T) { t.Logf("\n\nLoginFinished") notifies.expect(3) cc.persist.LoginName = "user1" + cc.persist.UserProfile.LoginName = "user1" cc.send(nil, "", true, &netmap.NetworkMap{}) { nn := notifies.drain(3) @@ -698,6 +699,7 @@ func TestStateMachine(t *testing.T) { t.Logf("\n\nLoginFinished3") notifies.expect(3) cc.persist.LoginName = "user2" + cc.persist.UserProfile.LoginName = "user2" cc.send(nil, "", true, &netmap.NetworkMap{ MachineStatus: tailcfg.MachineAuthorized, }) @@ -833,6 +835,7 @@ func TestStateMachine(t *testing.T) { t.Logf("\n\nLoginDifferent URL visited") notifies.expect(3) cc.persist.LoginName = "user3" + cc.persist.UserProfile.LoginName = "user3" cc.send(nil, "", true, &netmap.NetworkMap{ MachineStatus: tailcfg.MachineAuthorized, }) diff --git a/ipn/prefs.go b/ipn/prefs.go index 00fb3f2e0..bde0d0b34 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -715,6 +715,13 @@ type LoginProfile struct { // This is updated whenever the server provides a new UserProfile. UserProfile tailcfg.UserProfile + // NodeID is the NodeID of the node that this profile is logged into. + // This should be stable across tagging and untagging nodes. + // It may seem redundant to check against both the UserProfile.UserID + // and the NodeID. However the NodeID can change if the node is deleted + // from the admin panel. + NodeID tailcfg.StableNodeID + // LocalUserID is the user ID of the user who created this profile. // It is only relevant on Windows where we have a multi-user system. // It is assigned once at profile creation time and never changes. diff --git a/types/persist/persist.go b/types/persist/persist.go index 3f4162eac..b128c5f70 100644 --- a/types/persist/persist.go +++ b/types/persist/persist.go @@ -38,6 +38,7 @@ type Persist struct { LoginName string UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate + NodeID tailcfg.StableNodeID } // PublicNodeKey returns the public key for the node key. @@ -68,7 +69,8 @@ func (p *Persist) Equals(p2 *Persist) bool { p.Provider == p2.Provider && p.LoginName == p2.LoginName && p.UserProfile == p2.UserProfile && - p.NetworkLockKey.Equal(p2.NetworkLockKey) + p.NetworkLockKey.Equal(p2.NetworkLockKey) && + p.NodeID == p2.NodeID } func (p *Persist) Pretty() string { diff --git a/types/persist/persist_clone.go b/types/persist/persist_clone.go index 4d7924665..aeb40afe5 100644 --- a/types/persist/persist_clone.go +++ b/types/persist/persist_clone.go @@ -33,4 +33,5 @@ var _PersistCloneNeedsRegeneration = Persist(struct { LoginName string UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate + NodeID tailcfg.StableNodeID }{}) diff --git a/types/persist/persist_test.go b/types/persist/persist_test.go index ac1401b87..7651fe02a 100644 --- a/types/persist/persist_test.go +++ b/types/persist/persist_test.go @@ -22,7 +22,7 @@ func fieldsOf(t reflect.Type) (fields []string) { } func TestPersistEqual(t *testing.T) { - persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey"} + persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey", "NodeID"} if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) { t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, persistHandles) @@ -123,6 +123,16 @@ func TestPersistEqual(t *testing.T) { &Persist{NetworkLockKey: key.NewNLPrivate()}, false, }, + { + &Persist{NodeID: "abc"}, + &Persist{NodeID: "abc"}, + true, + }, + { + &Persist{NodeID: ""}, + &Persist{NodeID: "abc"}, + false, + }, } for i, test := range tests { if got := test.a.Equals(test.b); got != test.want { diff --git a/types/persist/persist_view.go b/types/persist/persist_view.go index ee3976346..b961c07c9 100644 --- a/types/persist/persist_view.go +++ b/types/persist/persist_view.go @@ -71,6 +71,7 @@ func (v PersistView) Provider() string { return v.ж.Provider func (v PersistView) LoginName() string { return v.ж.LoginName } func (v PersistView) UserProfile() tailcfg.UserProfile { return v.ж.UserProfile } func (v PersistView) NetworkLockKey() key.NLPrivate { return v.ж.NetworkLockKey } +func (v PersistView) NodeID() tailcfg.StableNodeID { return v.ж.NodeID } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _PersistViewNeedsRegeneration = Persist(struct { @@ -82,4 +83,5 @@ var _PersistViewNeedsRegeneration = Persist(struct { LoginName string UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate + NodeID tailcfg.StableNodeID }{})