ipn/ipnlocal: handle untagging nodes better

We would end up with duplicate profiles for the node as the UserID
would have chnaged. In order to correctly deduplicate profiles, we
need to look at both the UserID and the NodeID. A single machine can
only ever have 1 profile per NodeID and 1 profile per UserID.

Note: UserID of a Node can change when the node is tagged/untagged,
and the NodeID of a device can change when the node is deleted so we
need to check for both.

Updates #713

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/6384/head
Maisem Ali 2 years ago committed by Maisem Ali
parent f18dde6ad1
commit dd50dcd067

@ -884,6 +884,9 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) {
if !envknob.TKASkipSignatureCheck() { if !envknob.TKASkipSignatureCheck() {
b.tkaFilterNetmapLocked(st.NetMap) b.tkaFilterNetmapLocked(st.NetMap)
} }
if b.updatePersistFromNetMapLocked(st.NetMap, prefs) {
prefsChanged = true
}
b.setNetMapLocked(st.NetMap) b.setNetMapLocked(st.NetMap)
b.updateFilterLocked(st.NetMap, prefs.View()) b.updateFilterLocked(st.NetMap, prefs.View())
} }
@ -3349,23 +3352,36 @@ func hasCapability(nm *netmap.NetworkMap, cap string) bool {
return false return false
} }
func (b *LocalBackend) updatePersistFromNetMapLocked(nm *netmap.NetworkMap, prefs *ipn.Prefs) (changed bool) {
if nm == nil || nm.SelfNode == nil {
return
}
up := nm.UserProfiles[nm.User]
if prefs.Persist.UserProfile.ID != up.ID {
// If the current profile doesn't match the
// network map's user profile, then we need to
// update the persisted UserProfile to match.
prefs.Persist.UserProfile = up
changed = true
}
if prefs.Persist.NodeID == "" {
// If the current profile doesn't have a NodeID,
// then we need to update the persisted NodeID to
// match.
prefs.Persist.NodeID = nm.SelfNode.StableID
changed = true
}
return changed
}
func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) {
b.dialer.SetNetMap(nm) b.dialer.SetNetMap(nm)
var login string var login string
if nm != nil { if nm != nil {
up := nm.UserProfiles[nm.User] login = nm.UserProfiles[nm.User].LoginName
login = up.LoginName
if login == "" { if login == "" {
login = "<missing-profile>" login = "<missing-profile>"
} }
if cp := b.pm.CurrentProfile(); cp.ID != "" && cp.UserProfile.ID != up.ID {
// If the current profile doesn't match the
// network map's user profile, then we need to
// update the persisted UserProfile to match.
prefs := b.pm.CurrentPrefs().AsStruct()
prefs.Persist.UserProfile = up
b.pm.SetPrefs(prefs.View())
}
} }
b.netMap = nm b.netMap = nm
if login != b.activeLogin { if login != b.activeLogin {

@ -70,13 +70,30 @@ func (pm *profileManager) SetCurrentUser(uid string) error {
return nil return nil
} }
func (pm *profileManager) findProfileByUserID(userID tailcfg.UserID) *ipn.LoginProfile { func (pm *profileManager) findProfilesByNodeID(nodeID tailcfg.StableNodeID) []*ipn.LoginProfile {
if nodeID.IsZero() {
return nil
}
var out []*ipn.LoginProfile
for _, p := range pm.knownProfiles {
if p.NodeID == nodeID {
out = append(out, p)
}
}
return out
}
func (pm *profileManager) findProfilesByUserID(userID tailcfg.UserID) []*ipn.LoginProfile {
if userID.IsZero() {
return nil
}
var out []*ipn.LoginProfile
for _, p := range pm.knownProfiles { for _, p := range pm.knownProfiles {
if p.UserProfile.ID == userID { if p.UserProfile.ID == userID {
return p out = append(out, p)
} }
} }
return nil return out
} }
func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile {
@ -124,37 +141,42 @@ func init() {
// provided prefs, which may be accessed via CurrentPrefs. // provided prefs, which may be accessed via CurrentPrefs.
func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView) error { func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView) error {
prefs := prefsIn.AsStruct().View() prefs := prefsIn.AsStruct().View()
ps := prefs.Persist() newPersist := prefs.Persist()
if ps == nil || ps.LoginName == "" { if newPersist == nil || newPersist.LoginName == "" {
return pm.setPrefsLocked(prefs) return pm.setPrefsLocked(prefs)
} }
up := ps.UserProfile up := newPersist.UserProfile
if up.LoginName == "" { if up.LoginName == "" {
up.LoginName = ps.LoginName // Backwards compatibility with old prefs files.
up.LoginName = newPersist.LoginName
} else {
newPersist.LoginName = up.LoginName
} }
if up.DisplayName == "" { if up.DisplayName == "" {
up.DisplayName = up.LoginName up.DisplayName = up.LoginName
} }
cp := pm.currentProfile cp := pm.currentProfile
wasNamedWithLoginName := cp.Name == cp.UserProfile.LoginName
if pm.isNewProfile { if pm.isNewProfile {
pm.isNewProfile = false pm.isNewProfile = false
// Check if we already have a profile for this user. // Check if we already have a profile for this user.
existing := pm.findProfileByUserID(ps.UserProfile.ID) existing := pm.findProfilesByUserID(newPersist.UserProfile.ID)
if existing != nil && existing.ID != "" { // Also check if we have a profile with the same NodeID.
cp = existing existing = append(existing, pm.findProfilesByNodeID(newPersist.NodeID)...)
} else { if len(existing) == 0 {
cp.ID, cp.Key = newUnusedID(pm.knownProfiles) cp.ID, cp.Key = newUnusedID(pm.knownProfiles)
cp.Name = ps.LoginName } else {
// Only one profile per user/nodeID should exist.
for _, p := range existing[1:] {
// Best effort cleanup.
pm.DeleteProfile(p.ID)
}
cp = existing[0]
} }
cp.UserProfile = ps.UserProfile
cp.LocalUserID = pm.currentUserID cp.LocalUserID = pm.currentUserID
} else {
cp.UserProfile = ps.UserProfile
}
if wasNamedWithLoginName {
cp.Name = ps.LoginName
} }
cp.UserProfile = newPersist.UserProfile
cp.NodeID = newPersist.NodeID
cp.Name = up.LoginName
pm.knownProfiles[cp.ID] = cp pm.knownProfiles[cp.ID] = cp
pm.currentProfile = cp pm.currentProfile = cp
if err := pm.writeKnownProfiles(); err != nil { if err := pm.writeKnownProfiles(); err != nil {

@ -5,6 +5,7 @@
package ipnlocal package ipnlocal
import ( import (
"fmt"
"testing" "testing"
"tailscale.com/ipn" "tailscale.com/ipn"
@ -61,21 +62,28 @@ func TestProfileManagement(t *testing.T) {
} }
} }
logins := make(map[string]tailcfg.UserID) logins := make(map[string]tailcfg.UserID)
nodeIDs := make(map[string]tailcfg.StableNodeID)
setPrefs := func(t *testing.T, loginName string) ipn.PrefsView { setPrefs := func(t *testing.T, loginName string) ipn.PrefsView {
t.Helper() t.Helper()
p := pm.CurrentPrefs().AsStruct() p := pm.CurrentPrefs().AsStruct()
id := logins[loginName] uid := logins[loginName]
if id.IsZero() { if uid.IsZero() {
id = tailcfg.UserID(len(logins) + 1) uid = tailcfg.UserID(len(logins) + 1)
logins[loginName] = id logins[loginName] = uid
}
nid := nodeIDs[loginName]
if nid.IsZero() {
nid = tailcfg.StableNodeID(fmt.Sprint(len(nodeIDs) + 1))
nodeIDs[loginName] = nid
} }
p.Persist = &persist.Persist{ p.Persist = &persist.Persist{
LoginName: loginName, LoginName: loginName,
PrivateNodeKey: key.NewNode(), PrivateNodeKey: key.NewNode(),
UserProfile: tailcfg.UserProfile{ UserProfile: tailcfg.UserProfile{
ID: id, ID: uid,
LoginName: loginName, LoginName: loginName,
}, },
NodeID: nid,
} }
if err := pm.SetPrefs(p.View()); err != nil { if err := pm.SetPrefs(p.View()); err != nil {
t.Fatal(err) t.Fatal(err)
@ -132,7 +140,7 @@ func TestProfileManagement(t *testing.T) {
} }
checkProfiles(t) checkProfiles(t)
t.Logf("Create new profile again") t.Logf("Create new profile - 2")
pm.NewProfile() pm.NewProfile()
wantCurProfile = "" wantCurProfile = ""
wantProfiles[""] = emptyPrefs wantProfiles[""] = emptyPrefs
@ -143,6 +151,19 @@ func TestProfileManagement(t *testing.T) {
delete(wantProfiles, "") delete(wantProfiles, "")
wantCurProfile = "user@2.example.com" wantCurProfile = "user@2.example.com"
checkProfiles(t) checkProfiles(t)
t.Logf("Tag the current the profile")
nodeIDs["tagged-node.2.ts.net"] = nodeIDs["user@2.example.com"]
wantProfiles["tagged-node.2.ts.net"] = setPrefs(t, "tagged-node.2.ts.net")
delete(wantProfiles, "user@2.example.com")
wantCurProfile = "tagged-node.2.ts.net"
checkProfiles(t)
t.Logf("Relogin")
wantProfiles["user@2.example.com"] = setPrefs(t, "user@2.example.com")
delete(wantProfiles, "tagged-node.2.ts.net")
wantCurProfile = "user@2.example.com"
checkProfiles(t)
} }
// TestProfileManagementWindows tests going into and out of Unattended mode on // TestProfileManagementWindows tests going into and out of Unattended mode on

@ -473,6 +473,7 @@ func TestStateMachine(t *testing.T) {
t.Logf("\n\nLoginFinished") t.Logf("\n\nLoginFinished")
notifies.expect(3) notifies.expect(3)
cc.persist.LoginName = "user1" cc.persist.LoginName = "user1"
cc.persist.UserProfile.LoginName = "user1"
cc.send(nil, "", true, &netmap.NetworkMap{}) cc.send(nil, "", true, &netmap.NetworkMap{})
{ {
nn := notifies.drain(3) nn := notifies.drain(3)
@ -698,6 +699,7 @@ func TestStateMachine(t *testing.T) {
t.Logf("\n\nLoginFinished3") t.Logf("\n\nLoginFinished3")
notifies.expect(3) notifies.expect(3)
cc.persist.LoginName = "user2" cc.persist.LoginName = "user2"
cc.persist.UserProfile.LoginName = "user2"
cc.send(nil, "", true, &netmap.NetworkMap{ cc.send(nil, "", true, &netmap.NetworkMap{
MachineStatus: tailcfg.MachineAuthorized, MachineStatus: tailcfg.MachineAuthorized,
}) })
@ -833,6 +835,7 @@ func TestStateMachine(t *testing.T) {
t.Logf("\n\nLoginDifferent URL visited") t.Logf("\n\nLoginDifferent URL visited")
notifies.expect(3) notifies.expect(3)
cc.persist.LoginName = "user3" cc.persist.LoginName = "user3"
cc.persist.UserProfile.LoginName = "user3"
cc.send(nil, "", true, &netmap.NetworkMap{ cc.send(nil, "", true, &netmap.NetworkMap{
MachineStatus: tailcfg.MachineAuthorized, MachineStatus: tailcfg.MachineAuthorized,
}) })

@ -715,6 +715,13 @@ type LoginProfile struct {
// This is updated whenever the server provides a new UserProfile. // This is updated whenever the server provides a new UserProfile.
UserProfile tailcfg.UserProfile UserProfile tailcfg.UserProfile
// NodeID is the NodeID of the node that this profile is logged into.
// This should be stable across tagging and untagging nodes.
// It may seem redundant to check against both the UserProfile.UserID
// and the NodeID. However the NodeID can change if the node is deleted
// from the admin panel.
NodeID tailcfg.StableNodeID
// LocalUserID is the user ID of the user who created this profile. // LocalUserID is the user ID of the user who created this profile.
// It is only relevant on Windows where we have a multi-user system. // It is only relevant on Windows where we have a multi-user system.
// It is assigned once at profile creation time and never changes. // It is assigned once at profile creation time and never changes.

@ -38,6 +38,7 @@ type Persist struct {
LoginName string LoginName string
UserProfile tailcfg.UserProfile UserProfile tailcfg.UserProfile
NetworkLockKey key.NLPrivate NetworkLockKey key.NLPrivate
NodeID tailcfg.StableNodeID
} }
// PublicNodeKey returns the public key for the node key. // PublicNodeKey returns the public key for the node key.
@ -68,7 +69,8 @@ func (p *Persist) Equals(p2 *Persist) bool {
p.Provider == p2.Provider && p.Provider == p2.Provider &&
p.LoginName == p2.LoginName && p.LoginName == p2.LoginName &&
p.UserProfile == p2.UserProfile && p.UserProfile == p2.UserProfile &&
p.NetworkLockKey.Equal(p2.NetworkLockKey) p.NetworkLockKey.Equal(p2.NetworkLockKey) &&
p.NodeID == p2.NodeID
} }
func (p *Persist) Pretty() string { func (p *Persist) Pretty() string {

@ -33,4 +33,5 @@ var _PersistCloneNeedsRegeneration = Persist(struct {
LoginName string LoginName string
UserProfile tailcfg.UserProfile UserProfile tailcfg.UserProfile
NetworkLockKey key.NLPrivate NetworkLockKey key.NLPrivate
NodeID tailcfg.StableNodeID
}{}) }{})

@ -22,7 +22,7 @@ func fieldsOf(t reflect.Type) (fields []string) {
} }
func TestPersistEqual(t *testing.T) { func TestPersistEqual(t *testing.T) {
persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey"} persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey", "NodeID"}
if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) { if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) {
t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n", t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
have, persistHandles) have, persistHandles)
@ -123,6 +123,16 @@ func TestPersistEqual(t *testing.T) {
&Persist{NetworkLockKey: key.NewNLPrivate()}, &Persist{NetworkLockKey: key.NewNLPrivate()},
false, false,
}, },
{
&Persist{NodeID: "abc"},
&Persist{NodeID: "abc"},
true,
},
{
&Persist{NodeID: ""},
&Persist{NodeID: "abc"},
false,
},
} }
for i, test := range tests { for i, test := range tests {
if got := test.a.Equals(test.b); got != test.want { if got := test.a.Equals(test.b); got != test.want {

@ -71,6 +71,7 @@ func (v PersistView) Provider() string { return v.ж.Provider
func (v PersistView) LoginName() string { return v.ж.LoginName } func (v PersistView) LoginName() string { return v.ж.LoginName }
func (v PersistView) UserProfile() tailcfg.UserProfile { return v.ж.UserProfile } func (v PersistView) UserProfile() tailcfg.UserProfile { return v.ж.UserProfile }
func (v PersistView) NetworkLockKey() key.NLPrivate { return v.ж.NetworkLockKey } func (v PersistView) NetworkLockKey() key.NLPrivate { return v.ж.NetworkLockKey }
func (v PersistView) NodeID() tailcfg.StableNodeID { return v.ж.NodeID }
// A compilation failure here means this code must be regenerated, with the command at the top of this file. // A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _PersistViewNeedsRegeneration = Persist(struct { var _PersistViewNeedsRegeneration = Persist(struct {
@ -82,4 +83,5 @@ var _PersistViewNeedsRegeneration = Persist(struct {
LoginName string LoginName string
UserProfile tailcfg.UserProfile UserProfile tailcfg.UserProfile
NetworkLockKey key.NLPrivate NetworkLockKey key.NLPrivate
NodeID tailcfg.StableNodeID
}{}) }{})

Loading…
Cancel
Save