diff --git a/util/set/slice.go b/util/set/slice.go new file mode 100644 index 000000000..589b903df --- /dev/null +++ b/util/set/slice.go @@ -0,0 +1,69 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "golang.org/x/exp/slices" + "tailscale.com/types/views" +) + +// Slice is a set of elements tracked in a slice of unique elements. +type Slice[T comparable] struct { + slice []T + set map[T]bool // nil until/unless slice is large enough +} + +// Slice returns the a view of the underlying slice. +// The elements are in order of insertion. +// The returned value is only valid until ss is modified again. +func (ss *Slice[T]) Slice() views.Slice[T] { return views.SliceOf(ss.slice) } + +// Contains reports whether v is in the set. +// The amortized cost is O(1). +func (ss *Slice[T]) Contains(v T) bool { + if ss.set != nil { + return ss.set[v] + } + return slices.Index(ss.slice, v) != -1 +} + +// Remove removes v from the set. +// The cost is O(n). +func (ss *Slice[T]) Remove(v T) { + if ss.set != nil { + if !ss.set[v] { + return + } + delete(ss.set, v) + } + if ix := slices.Index(ss.slice, v); ix != -1 { + ss.slice = append(ss.slice[:ix], ss.slice[ix+1:]...) + } +} + +// Add adds each element in vs to the set. +// The amortized cost is O(1) per element. +func (ss *Slice[T]) Add(vs ...T) { + for _, v := range vs { + if ss.Contains(v) { + continue + } + ss.slice = append(ss.slice, v) + if ss.set != nil { + ss.set[v] = true + } else if len(ss.slice) > 8 { + ss.set = make(map[T]bool, len(ss.slice)) + for _, v := range ss.slice { + ss.set[v] = true + } + } + } +} + +// AddSlice adds all elements in vs to the set. +func (ss *Slice[T]) AddSlice(vs views.Slice[T]) { + for i := 0; i < vs.Len(); i++ { + ss.Add(vs.At(i)) + } +} diff --git a/util/set/slice_test.go b/util/set/slice_test.go new file mode 100644 index 000000000..9134c2962 --- /dev/null +++ b/util/set/slice_test.go @@ -0,0 +1,56 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestSliceSet(t *testing.T) { + c := qt.New(t) + + var ss Slice[int] + c.Check(len(ss.slice), qt.Equals, 0) + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + c.Check(ss.Contains(1), qt.Equals, true) + c.Check(ss.Contains(2), qt.Equals, false) + + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(2) + ss.Add(3) + ss.Add(4) + ss.Add(5) + ss.Add(6) + ss.Add(7) + ss.Add(8) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(9) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + + ss.Remove(4) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 8) + c.Assert(ss.Contains(4), qt.IsFalse) + + // Ensure that the order of insertion is maintained + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9}) + ss.Add(4) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + c.Assert(ss.Contains(4), qt.IsTrue) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4}) + + ss.Add(1, 234, 556) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4, 234, 556}) +}