diff --git a/util/uniq/slice.go b/util/uniq/slice.go index 5438a2556..49c240ca9 100644 --- a/util/uniq/slice.go +++ b/util/uniq/slice.go @@ -33,3 +33,31 @@ func ModifySlice[E comparable](slice *[]E) { *slice = (*slice)[:end] } } + +// ModifySliceFunc is the same as ModifySlice except that it allows using a +// custom comparison function. +// +// eq should report whether the two provided elements are equal. +func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { + // Remove duplicates + dst := 0 + for i := 1; i < len(*slice); i++ { + if eq((*slice)[dst], (*slice)[i]) { + continue + } + dst++ + (*slice)[dst] = (*slice)[i] + } + + // Zero out the elements we removed at the end of the slice + end := dst + 1 + var zero E + for i := end; i < len(*slice); i++ { + (*slice)[i] = zero + } + + // Truncate the slice + if end < len(*slice) { + *slice = (*slice)[:end] + } +} diff --git a/util/uniq/slice_test.go b/util/uniq/slice_test.go index 313e2b435..d4f8018d7 100644 --- a/util/uniq/slice_test.go +++ b/util/uniq/slice_test.go @@ -12,22 +12,23 @@ import ( "tailscale.com/util/uniq" ) -func runTests(t *testing.T, cb func(*[]int)) { +func runTests(t *testing.T, cb func(*[]uint32)) { tests := []struct { - in []int - want []int + // Use uint32 to be different from an int-typed slice index + in []uint32 + want []uint32 }{ - {in: []int{0, 1, 2}, want: []int{0, 1, 2}}, - {in: []int{0, 1, 2, 2}, want: []int{0, 1, 2}}, - {in: []int{0, 0, 1, 2}, want: []int{0, 1, 2}}, - {in: []int{0, 1, 0, 2}, want: []int{0, 1, 0, 2}}, - {in: []int{0}, want: []int{0}}, - {in: []int{0, 0}, want: []int{0}}, - {in: []int{}, want: []int{}}, + {in: []uint32{0, 1, 2}, want: []uint32{0, 1, 2}}, + {in: []uint32{0, 1, 2, 2}, want: []uint32{0, 1, 2}}, + {in: []uint32{0, 0, 1, 2}, want: []uint32{0, 1, 2}}, + {in: []uint32{0, 1, 0, 2}, want: []uint32{0, 1, 0, 2}}, + {in: []uint32{0}, want: []uint32{0}}, + {in: []uint32{0, 0}, want: []uint32{0}}, + {in: []uint32{}, want: []uint32{}}, } for _, test := range tests { - in := make([]int, len(test.in)) + in := make([]uint32, len(test.in)) copy(in, test.in) cb(&test.in) if !reflect.DeepEqual(test.in, test.want) { @@ -44,11 +45,19 @@ func runTests(t *testing.T, cb func(*[]int)) { } func TestModifySlice(t *testing.T) { - runTests(t, func(slice *[]int) { + runTests(t, func(slice *[]uint32) { uniq.ModifySlice(slice) }) } +func TestModifySliceFunc(t *testing.T) { + runTests(t, func(slice *[]uint32) { + uniq.ModifySliceFunc(slice, func(i, j uint32) bool { + return i == j + }) + }) +} + func Benchmark(b *testing.B) { benches := []struct { name string