ipn/ipnlocal: account for currentUserID when iterating over knownProfiles

We were not checking the currentUserID in all code paths that looped over
knownProfiles. This only impacted multi-user Windows setups.

Updates #713

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/6442/head
Maisem Ali 2 years ago committed by Maisem Ali
parent 18c7c3981a
commit aeac4bc8e2

@ -51,6 +51,7 @@ func (pm *profileManager) SetCurrentUser(uid string) error {
if pm.currentUserID == uid { if pm.currentUserID == uid {
return nil return nil
} }
pm.currentUserID = uid
cpk := ipn.CurrentProfileKey(uid) cpk := ipn.CurrentProfileKey(uid)
if b, err := pm.store.ReadState(cpk); err == nil { if b, err := pm.store.ReadState(cpk); err == nil {
pk := ipn.StateKey(string(b)) pk := ipn.StateKey(string(b))
@ -66,34 +67,36 @@ func (pm *profileManager) SetCurrentUser(uid string) error {
} else { } else {
return err return err
} }
pm.currentUserID = uid
return nil return nil
} }
func (pm *profileManager) findProfilesByNodeID(nodeID tailcfg.StableNodeID) []*ipn.LoginProfile { // matchingProfiles returns all profiles that match the given predicate and
if nodeID.IsZero() { // belong to the currentUserID.
return nil func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out []*ipn.LoginProfile) {
}
var out []*ipn.LoginProfile
for _, p := range pm.knownProfiles { for _, p := range pm.knownProfiles {
if p.NodeID == nodeID { if p.LocalUserID == pm.currentUserID && f(p) {
out = append(out, p) out = append(out, p)
} }
} }
return out return out
} }
func (pm *profileManager) findProfilesByUserID(userID tailcfg.UserID) []*ipn.LoginProfile { func (pm *profileManager) findProfilesByNodeID(nodeID tailcfg.StableNodeID) []*ipn.LoginProfile {
if userID.IsZero() { if nodeID.IsZero() {
return nil return nil
} }
var out []*ipn.LoginProfile return pm.matchingProfiles(func(p *ipn.LoginProfile) bool {
for _, p := range pm.knownProfiles { return p.NodeID == nodeID
if p.UserProfile.ID == userID { })
out = append(out, p)
} }
func (pm *profileManager) findProfilesByUserID(userID tailcfg.UserID) []*ipn.LoginProfile {
if userID.IsZero() {
return nil
} }
return out return pm.matchingProfiles(func(p *ipn.LoginProfile) bool {
return p.UserProfile.ID == userID
})
} }
// ProfileIDForName returns the profile ID for the profile with the // ProfileIDForName returns the profile ID for the profile with the
@ -107,21 +110,29 @@ func (pm *profileManager) ProfileIDForName(name string) ipn.ProfileID {
} }
func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile {
for _, p := range pm.knownProfiles { out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool {
if p.Name == name { return p.Name == name
return p })
if len(out) == 0 {
return nil
} }
if len(out) > 1 {
pm.logf("[unxpected] multiple profiles with the same name")
} }
return nil return out[0]
} }
func (pm *profileManager) findProfileByKey(key ipn.StateKey) *ipn.LoginProfile { func (pm *profileManager) findProfileByKey(key ipn.StateKey) *ipn.LoginProfile {
for _, p := range pm.knownProfiles { out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool {
if p.Key == key { return p.Key == key
return p })
if len(out) == 0 {
return nil
} }
if len(out) > 1 {
pm.logf("[unxpected] multiple profiles with the same key")
} }
return nil return out[0]
} }
func (pm *profileManager) setUnattendedModeAsConfigured() error { func (pm *profileManager) setUnattendedModeAsConfigured() error {
@ -244,16 +255,15 @@ func (pm *profileManager) writePrefsToStore(key ipn.StateKey, prefs ipn.PrefsVie
// Profiles returns the list of known profiles. // Profiles returns the list of known profiles.
func (pm *profileManager) Profiles() []ipn.LoginProfile { func (pm *profileManager) Profiles() []ipn.LoginProfile {
var profiles []ipn.LoginProfile profiles := pm.matchingProfiles(func(*ipn.LoginProfile) bool { return true })
for _, p := range pm.knownProfiles { slices.SortFunc(profiles, func(a, b *ipn.LoginProfile) bool {
if p.LocalUserID == pm.currentUserID {
profiles = append(profiles, *p)
}
}
slices.SortFunc(profiles, func(a, b ipn.LoginProfile) bool {
return a.Name < b.Name return a.Name < b.Name
}) })
return profiles out := make([]ipn.LoginProfile, 0, len(profiles))
for _, p := range profiles {
out = append(out, *p)
}
return out
} }
// SwitchProfile switches to the profile with the given id. // SwitchProfile switches to the profile with the given id.
@ -431,6 +441,7 @@ func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]*ipn.LoginProfil
} }
func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, stateKey ipn.StateKey, goos string) (*profileManager, error) { func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, stateKey ipn.StateKey, goos string) (*profileManager, error) {
logf = logger.WithPrefix(logf, "pm: ")
if stateKey == "" { if stateKey == "" {
var err error var err error
stateKey, err = readAutoStartKey(store, goos) stateKey, err = readAutoStartKey(store, goos)

@ -16,6 +16,76 @@ import (
"tailscale.com/types/persist" "tailscale.com/types/persist"
) )
func TestProfileList(t *testing.T) {
store := new(mem.Store)
pm, err := newProfileManagerWithGOOS(store, logger.Discard, "", "linux")
if err != nil {
t.Fatal(err)
}
id := 0
newProfile := func(t *testing.T, loginName string) ipn.PrefsView {
id++
t.Helper()
pm.NewProfile()
p := pm.CurrentPrefs().AsStruct()
p.Persist = &persist.Persist{
NodeID: tailcfg.StableNodeID(fmt.Sprint(id)),
LoginName: loginName,
PrivateNodeKey: key.NewNode(),
UserProfile: tailcfg.UserProfile{
ID: tailcfg.UserID(id),
LoginName: loginName,
},
}
if err := pm.SetPrefs(p.View()); err != nil {
t.Fatal(err)
}
return p.View()
}
checkProfiles := func(t *testing.T, want ...string) {
t.Helper()
got := pm.Profiles()
if len(got) != len(want) {
t.Fatalf("got %d profiles, want %d", len(got), len(want))
}
for i, w := range want {
if got[i].Name != w {
t.Errorf("got profile %d name %q, want %q", i, got[i].Name, w)
}
}
}
pm.SetCurrentUser("user1")
newProfile(t, "alice")
newProfile(t, "bob")
checkProfiles(t, "alice", "bob")
pm.SetCurrentUser("user2")
checkProfiles(t)
newProfile(t, "carol")
carol := pm.currentProfile
checkProfiles(t, "carol")
pm.SetCurrentUser("user1")
checkProfiles(t, "alice", "bob")
if lp := pm.findProfileByKey(carol.Key); lp != nil {
t.Fatalf("found profile for user2 in user1's profile list")
}
if lp := pm.findProfileByName(carol.Name); lp != nil {
t.Fatalf("found profile for user2 in user1's profile list")
}
if lp := pm.findProfilesByNodeID(carol.NodeID); lp != nil {
t.Fatalf("found profile for user2 in user1's profile list")
}
if lp := pm.findProfilesByUserID(carol.UserProfile.ID); lp != nil {
t.Fatalf("found profile for user2 in user1's profile list")
}
pm.SetCurrentUser("user2")
checkProfiles(t, "carol")
}
// TestProfileManagement tests creating, loading, and switching profiles. // TestProfileManagement tests creating, loading, and switching profiles.
func TestProfileManagement(t *testing.T) { func TestProfileManagement(t *testing.T) {
store := new(mem.Store) store := new(mem.Store)

Loading…
Cancel
Save