util/slicesx: add EqualSameNil, like slices.Equal but same nilness

Then use it in tailcfg which had it duplicated a couple times.

I think we have it a few other places too.

And use slices.Equal in wgengine/router too. (found while looking for callers)

Updates #cleanup

Change-Id: If5350eee9b3ef071882a3db29a305081e4cd9d23
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/9610/head
Brad Fitzpatrick 1 year ago committed by Brad Fitzpatrick
parent 72e53749c1
commit 5f5c9142cc

@ -24,6 +24,7 @@ import (
"tailscale.com/types/tkatype"
"tailscale.com/util/cmpx"
"tailscale.com/util/dnsname"
"tailscale.com/util/slicesx"
)
// CapabilityVersion represents the client's capability level. That
@ -1939,10 +1940,10 @@ func (n *Node) Equal(n2 *Node) bool {
n.Machine == n2.Machine &&
n.DiscoKey == n2.DiscoKey &&
eqPtr(n.Online, n2.Online) &&
eqCIDRs(n.Addresses, n2.Addresses) &&
eqCIDRs(n.AllowedIPs, n2.AllowedIPs) &&
eqCIDRs(n.PrimaryRoutes, n2.PrimaryRoutes) &&
eqStrings(n.Endpoints, n2.Endpoints) &&
slicesx.EqualSameNil(n.Addresses, n2.Addresses) &&
slicesx.EqualSameNil(n.AllowedIPs, n2.AllowedIPs) &&
slicesx.EqualSameNil(n.PrimaryRoutes, n2.PrimaryRoutes) &&
slicesx.EqualSameNil(n.Endpoints, n2.Endpoints) &&
n.DERP == n2.DERP &&
n.Cap == n2.Cap &&
n.Hostinfo.Equal(n2.Hostinfo) &&
@ -1954,7 +1955,7 @@ func (n *Node) Equal(n2 *Node) bool {
n.ComputedName == n2.ComputedName &&
n.computedHostIfDifferent == n2.computedHostIfDifferent &&
n.ComputedNameWithHost == n2.ComputedNameWithHost &&
eqStrings(n.Tags, n2.Tags) &&
slicesx.EqualSameNil(n.Tags, n2.Tags) &&
n.Expired == n2.Expired &&
eqPtr(n.SelfNodeV4MasqAddrForThisPeer, n2.SelfNodeV4MasqAddrForThisPeer) &&
eqPtr(n.SelfNodeV6MasqAddrForThisPeer, n2.SelfNodeV6MasqAddrForThisPeer) &&
@ -1971,30 +1972,6 @@ func eqPtr[T comparable](a, b *T) bool {
return *a == *b
}
func eqStrings(a, b []string) bool {
if len(a) != len(b) || ((a == nil) != (b == nil)) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func eqCIDRs(a, b []netip.Prefix) bool {
if len(a) != len(b) || ((a == nil) != (b == nil)) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func eqTimePtr(a, b *time.Time) bool {
return ((a == nil) == (b == nil)) && (a == nil || a.Equal(*b))
}

@ -57,3 +57,23 @@ func Partition[S ~[]T, T any](s S, cb func(T) bool) (trues, falses S) {
}
return
}
// EqualSameNil reports whether two slices are equal: the same length, same
// nilness (notably when length zero), and all elements equal. If the lengths
// are different or their nilness differs, Equal returns false. Otherwise, the
// elements are compared in increasing index order, and the comparison stops at
// the first unequal pair. Floating point NaNs are not considered equal.
//
// It is identical to the standard library's slices.Equal but adds the matching
// nilness check.
func EqualSameNil[S ~[]E, E comparable](s1, s2 S) bool {
if len(s1) != len(s2) || (s1 == nil) != (s2 == nil) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}

@ -7,6 +7,8 @@ import (
"reflect"
"slices"
"testing"
qt "github.com/frankban/quicktest"
)
func TestInterleave(t *testing.T) {
@ -84,3 +86,14 @@ func TestPartition(t *testing.T) {
t.Errorf("odds: got %v, want %v", odds, wantOdds)
}
}
func TestEqualSameNil(t *testing.T) {
c := qt.New(t)
c.Check(EqualSameNil([]string{"a"}, []string{"a"}), qt.Equals, true)
c.Check(EqualSameNil([]string{"a"}, []string{"b"}), qt.Equals, false)
c.Check(EqualSameNil([]string{"a"}, []string{}), qt.Equals, false)
c.Check(EqualSameNil([]string{}, []string{}), qt.Equals, true)
c.Check(EqualSameNil(nil, []string{}), qt.Equals, false)
c.Check(EqualSameNil([]string{}, nil), qt.Equals, false)
c.Check(EqualSameNil[[]string](nil, nil), qt.Equals, true)
}

@ -12,6 +12,7 @@ import (
"net/netip"
"os"
"os/exec"
"slices"
"strings"
"sync"
"syscall"
@ -196,7 +197,7 @@ func (ft *firewallTweaker) doAsyncSet() {
ft.mu.Lock()
for { // invariant: ft.mu must be locked when beginning this block
val := ft.wantLocal
if ft.known && strsEqual(ft.lastLocal, val) && ft.wantKillswitch == ft.lastKillswitch && routesEqual(ft.localRoutes, ft.lastLocalRoutes) {
if ft.known && slices.Equal(ft.lastLocal, val) && ft.wantKillswitch == ft.lastKillswitch && slices.Equal(ft.localRoutes, ft.lastLocalRoutes) {
ft.running = false
ft.logf("ending netsh goroutine")
ft.mu.Unlock()
@ -341,28 +342,3 @@ func (ft *firewallTweaker) doSet(local []string, killswitch bool, clear bool, pr
// in via stdin encoded in json.
return ft.fwProcEncoder.Encode(allowedRoutes)
}
func routesEqual(a, b []netip.Prefix) bool {
if len(a) != len(b) {
return false
}
// Routes are pre-sorted.
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func strsEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

Loading…
Cancel
Save