diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index 5d14c1d94..1e996c357 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -51,6 +51,7 @@ func (pm *profileManager) SetCurrentUser(uid string) error { if pm.currentUserID == uid { return nil } + pm.currentUserID = uid cpk := ipn.CurrentProfileKey(uid) if b, err := pm.store.ReadState(cpk); err == nil { pk := ipn.StateKey(string(b)) @@ -66,34 +67,36 @@ func (pm *profileManager) SetCurrentUser(uid string) error { } else { return err } - pm.currentUserID = uid return nil } -func (pm *profileManager) findProfilesByNodeID(nodeID tailcfg.StableNodeID) []*ipn.LoginProfile { - if nodeID.IsZero() { - return nil - } - var out []*ipn.LoginProfile +// matchingProfiles returns all profiles that match the given predicate and +// belong to the currentUserID. +func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out []*ipn.LoginProfile) { for _, p := range pm.knownProfiles { - if p.NodeID == nodeID { + if p.LocalUserID == pm.currentUserID && f(p) { out = append(out, p) } } return out } +func (pm *profileManager) findProfilesByNodeID(nodeID tailcfg.StableNodeID) []*ipn.LoginProfile { + if nodeID.IsZero() { + return nil + } + return pm.matchingProfiles(func(p *ipn.LoginProfile) bool { + return p.NodeID == nodeID + }) +} + func (pm *profileManager) findProfilesByUserID(userID tailcfg.UserID) []*ipn.LoginProfile { if userID.IsZero() { return nil } - var out []*ipn.LoginProfile - for _, p := range pm.knownProfiles { - if p.UserProfile.ID == userID { - out = append(out, p) - } - } - return out + return pm.matchingProfiles(func(p *ipn.LoginProfile) bool { + return p.UserProfile.ID == userID + }) } // ProfileIDForName returns the profile ID for the profile with the @@ -107,21 +110,29 @@ func (pm *profileManager) ProfileIDForName(name string) ipn.ProfileID { } func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { - for _, p := range pm.knownProfiles { - if p.Name == name { - return p - } + out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool { + return p.Name == name + }) + if len(out) == 0 { + return nil } - return nil + if len(out) > 1 { + pm.logf("[unxpected] multiple profiles with the same name") + } + return out[0] } func (pm *profileManager) findProfileByKey(key ipn.StateKey) *ipn.LoginProfile { - for _, p := range pm.knownProfiles { - if p.Key == key { - return p - } + out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool { + return p.Key == key + }) + if len(out) == 0 { + return nil } - return nil + if len(out) > 1 { + pm.logf("[unxpected] multiple profiles with the same key") + } + return out[0] } func (pm *profileManager) setUnattendedModeAsConfigured() error { @@ -244,16 +255,15 @@ func (pm *profileManager) writePrefsToStore(key ipn.StateKey, prefs ipn.PrefsVie // Profiles returns the list of known profiles. func (pm *profileManager) Profiles() []ipn.LoginProfile { - var profiles []ipn.LoginProfile - for _, p := range pm.knownProfiles { - if p.LocalUserID == pm.currentUserID { - profiles = append(profiles, *p) - } - } - slices.SortFunc(profiles, func(a, b ipn.LoginProfile) bool { + profiles := pm.matchingProfiles(func(*ipn.LoginProfile) bool { return true }) + slices.SortFunc(profiles, func(a, b *ipn.LoginProfile) bool { return a.Name < b.Name }) - return profiles + out := make([]ipn.LoginProfile, 0, len(profiles)) + for _, p := range profiles { + out = append(out, *p) + } + return out } // SwitchProfile switches to the profile with the given id. @@ -431,6 +441,7 @@ func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]*ipn.LoginProfil } func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, stateKey ipn.StateKey, goos string) (*profileManager, error) { + logf = logger.WithPrefix(logf, "pm: ") if stateKey == "" { var err error stateKey, err = readAutoStartKey(store, goos) diff --git a/ipn/ipnlocal/profiles_test.go b/ipn/ipnlocal/profiles_test.go index a7ffee5c2..75e98c7d2 100644 --- a/ipn/ipnlocal/profiles_test.go +++ b/ipn/ipnlocal/profiles_test.go @@ -16,6 +16,76 @@ import ( "tailscale.com/types/persist" ) +func TestProfileList(t *testing.T) { + store := new(mem.Store) + + pm, err := newProfileManagerWithGOOS(store, logger.Discard, "", "linux") + if err != nil { + t.Fatal(err) + } + id := 0 + newProfile := func(t *testing.T, loginName string) ipn.PrefsView { + id++ + t.Helper() + pm.NewProfile() + p := pm.CurrentPrefs().AsStruct() + p.Persist = &persist.Persist{ + NodeID: tailcfg.StableNodeID(fmt.Sprint(id)), + LoginName: loginName, + PrivateNodeKey: key.NewNode(), + UserProfile: tailcfg.UserProfile{ + ID: tailcfg.UserID(id), + LoginName: loginName, + }, + } + if err := pm.SetPrefs(p.View()); err != nil { + t.Fatal(err) + } + return p.View() + } + checkProfiles := func(t *testing.T, want ...string) { + t.Helper() + got := pm.Profiles() + if len(got) != len(want) { + t.Fatalf("got %d profiles, want %d", len(got), len(want)) + } + for i, w := range want { + if got[i].Name != w { + t.Errorf("got profile %d name %q, want %q", i, got[i].Name, w) + } + } + } + + pm.SetCurrentUser("user1") + newProfile(t, "alice") + newProfile(t, "bob") + checkProfiles(t, "alice", "bob") + + pm.SetCurrentUser("user2") + checkProfiles(t) + newProfile(t, "carol") + carol := pm.currentProfile + checkProfiles(t, "carol") + + pm.SetCurrentUser("user1") + checkProfiles(t, "alice", "bob") + if lp := pm.findProfileByKey(carol.Key); lp != nil { + t.Fatalf("found profile for user2 in user1's profile list") + } + if lp := pm.findProfileByName(carol.Name); lp != nil { + t.Fatalf("found profile for user2 in user1's profile list") + } + if lp := pm.findProfilesByNodeID(carol.NodeID); lp != nil { + t.Fatalf("found profile for user2 in user1's profile list") + } + if lp := pm.findProfilesByUserID(carol.UserProfile.ID); lp != nil { + t.Fatalf("found profile for user2 in user1's profile list") + } + + pm.SetCurrentUser("user2") + checkProfiles(t, "carol") +} + // TestProfileManagement tests creating, loading, and switching profiles. func TestProfileManagement(t *testing.T) { store := new(mem.Store)