diff --git a/cmd/derper/bootstrap_dns.go b/cmd/derper/bootstrap_dns.go index 83961a5ef..e7d96f466 100644 --- a/cmd/derper/bootstrap_dns.go +++ b/cmd/derper/bootstrap_dns.go @@ -8,13 +8,13 @@ import ( "encoding/json" "expvar" "log" - "math/rand" "net" "net/http" "strings" "time" "tailscale.com/syncs" + "tailscale.com/util/slicesx" ) const refreshTimeout = time.Minute @@ -57,7 +57,7 @@ func refreshBootstrapDNS() { // to IPv6 for k := range dnsEntries { ips := dnsEntries[k] - rand.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] }) + slicesx.Shuffle(ips) dnsEntries[k] = ips } j, err := json.MarshalIndent(dnsEntries, "", "\t") diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index c2e7dcaf4..e94fe8a49 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -87,6 +87,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/multierr from tailscale.com/health tailscale.com/util/set from tailscale.com/health tailscale.com/util/singleflight from tailscale.com/net/dnscache + tailscale.com/util/slicesx from tailscale.com/cmd/derper+ tailscale.com/util/vizerror from tailscale.com/tsweb W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/derp+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index e0db11c38..066cacd6d 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -122,6 +122,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/quarantine from tailscale.com/cmd/tailscale/cli tailscale.com/util/set from tailscale.com/health+ tailscale.com/util/singleflight from tailscale.com/net/dnscache + tailscale.com/util/slicesx from tailscale.com/net/dnscache+ 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/cmd/tailscale/cli+ tailscale.com/version/distro from tailscale.com/cmd/tailscale/cli+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 9e6e0c45d..4615e0e77 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -305,6 +305,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/set from tailscale.com/health+ tailscale.com/util/singleflight from tailscale.com/control/controlclient+ + tailscale.com/util/slicesx from tailscale.com/net/dnscache+ tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/uniq from tailscale.com/wgengine/magicsock+ tailscale.com/util/vizerror from tailscale.com/tsweb diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 52de322af..5a5978f5f 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -24,6 +24,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/util/cloudenv" "tailscale.com/util/singleflight" + "tailscale.com/util/slicesx" ) var zaddr netip.Addr @@ -577,7 +578,7 @@ func (dc *dialCall) raceDial(ctx context.Context, ips []netip.Addr) (net.Conn, e iv4 = append(iv4, ip) } } - ips = interleaveSlices(iv6, iv4) + ips = slicesx.Interleave(iv6, iv4) go func() { for i, ip := range ips { @@ -636,21 +637,6 @@ func (dc *dialCall) raceDial(ctx context.Context, ips []netip.Addr) (net.Conn, e } } -// interleaveSlices combines two slices of the form [a, b, c] and [x, y, z] -// into a slice with elements interleaved; i.e. [a, x, b, y, c, z]. -func interleaveSlices[T any](a, b []T) []T { - var ( - i int - ret = make([]T, 0, len(a)+len(b)) - ) - for i = 0; i < len(a) && i < len(b); i++ { - ret = append(ret, a[i], b[i]) - } - ret = append(ret, a[i:]...) - ret = append(ret, b[i:]...) - return ret -} - func v4addrs(aa []netip.Addr) (ret []netip.Addr) { for _, a := range aa { a = a.Unmap() diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index f143995ef..3f6360a10 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -141,30 +141,6 @@ func TestResolverAllHostStaticResult(t *testing.T) { } } -func TestInterleaveSlices(t *testing.T) { - testCases := []struct { - name string - a, b []int - want []int - }{ - {name: "equal", a: []int{1, 3, 5}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 5, 6}}, - {name: "short_b", a: []int{1, 3, 5}, b: []int{2, 4}, want: []int{1, 2, 3, 4, 5}}, - {name: "short_a", a: []int{1, 3}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 6}}, - {name: "len_1", a: []int{1}, b: []int{2, 4, 6}, want: []int{1, 2, 4, 6}}, - {name: "nil_a", a: nil, b: []int{2, 4, 6}, want: []int{2, 4, 6}}, - {name: "nil_all", a: nil, b: nil, want: []int{}}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - merged := interleaveSlices(tc.a, tc.b) - if !reflect.DeepEqual(merged, tc.want) { - t.Errorf("got %v; want %v", merged, tc.want) - } - }) - } -} - func TestShouldTryBootstrap(t *testing.T) { oldDebug := debug t.Cleanup(func() { debug = oldDebug }) diff --git a/net/dnsfallback/dnsfallback.go b/net/dnsfallback/dnsfallback.go index 175a74b53..584a10807 100644 --- a/net/dnsfallback/dnsfallback.go +++ b/net/dnsfallback/dnsfallback.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "log" - "math/rand" "net" "net/http" "net/netip" @@ -31,6 +30,7 @@ import ( "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/logger" + "tailscale.com/util/slicesx" ) func Lookup(ctx context.Context, host string) ([]netip.Addr, error) { @@ -56,8 +56,8 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) { } } } - rand.Shuffle(len(cands4), func(i, j int) { cands4[i], cands4[j] = cands4[j], cands4[i] }) - rand.Shuffle(len(cands6), func(i, j int) { cands6[i], cands6[j] = cands6[j], cands6[i] }) + slicesx.Shuffle(cands4) + slicesx.Shuffle(cands6) const maxCands = 6 var cands []nameIP // up to maxCands alternating v4/v6 as long as we have both @@ -87,7 +87,7 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) { continue } if ips := dm[host]; len(ips) > 0 { - rand.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] }) + slicesx.Shuffle(ips) logf("bootstrapDNS(%q, %q) for %q = %v", cand.dnsName, cand.ip, host, ips) return ips, nil } diff --git a/util/slicesx/slicesx.go b/util/slicesx/slicesx.go new file mode 100644 index 000000000..ce55594db --- /dev/null +++ b/util/slicesx/slicesx.go @@ -0,0 +1,44 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package slicesx contains some helpful generic slice functions. +package slicesx + +import "math/rand" + +// Interleave combines two slices of the form [a, b, c] and [x, y, z] into a +// slice with elements interleaved; i.e. [a, x, b, y, c, z]. +func Interleave[S ~[]T, T any](a, b S) S { + // Avoid allocating an empty slice. + if a == nil && b == nil { + return nil + } + + var ( + i int + ret = make([]T, 0, len(a)+len(b)) + ) + for i = 0; i < len(a) && i < len(b); i++ { + ret = append(ret, a[i], b[i]) + } + ret = append(ret, a[i:]...) + ret = append(ret, b[i:]...) + return ret +} + +// Shuffle randomly shuffles a slice in-place, similar to rand.Shuffle. +func Shuffle[S ~[]T, T any](s S) { + // TODO(andrew): use a pooled Rand? + + // This is the same Fisher-Yates shuffle implementation as rand.Shuffle + n := len(s) + i := n - 1 + for ; i > 1<<31-1-1; i-- { + j := int(rand.Int63n(int64(i + 1))) + s[i], s[j] = s[j], s[i] + } + for ; i > 0; i-- { + j := int(rand.Int31n(int32(i + 1))) + s[i], s[j] = s[j], s[i] + } +} diff --git a/util/slicesx/slicesx_test.go b/util/slicesx/slicesx_test.go new file mode 100644 index 000000000..1d6062d6a --- /dev/null +++ b/util/slicesx/slicesx_test.go @@ -0,0 +1,66 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package slicesx + +import ( + "reflect" + "testing" + + "golang.org/x/exp/slices" +) + +func TestInterleave(t *testing.T) { + testCases := []struct { + name string + a, b []int + want []int + }{ + {name: "equal", a: []int{1, 3, 5}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 5, 6}}, + {name: "short_b", a: []int{1, 3, 5}, b: []int{2, 4}, want: []int{1, 2, 3, 4, 5}}, + {name: "short_a", a: []int{1, 3}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 6}}, + {name: "len_1", a: []int{1}, b: []int{2, 4, 6}, want: []int{1, 2, 4, 6}}, + {name: "nil_a", a: nil, b: []int{2, 4, 6}, want: []int{2, 4, 6}}, + {name: "nil_all", a: nil, b: nil, want: nil}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + merged := Interleave(tc.a, tc.b) + if !reflect.DeepEqual(merged, tc.want) { + t.Errorf("got %v; want %v", merged, tc.want) + } + }) + } +} + +func BenchmarkInterleave(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Interleave( + []int{1, 2, 3}, + []int{9, 8, 7}, + ) + } +} +func TestShuffle(t *testing.T) { + var sl []int + for i := 0; i < 100; i++ { + sl = append(sl, i) + } + + var wasShuffled bool + for try := 0; try < 10; try++ { + shuffled := slices.Clone(sl) + Shuffle(shuffled) + if !reflect.DeepEqual(shuffled, sl) { + wasShuffled = true + break + } + } + + if !wasShuffled { + t.Errorf("expected shuffle after 10 tries") + } +}