// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package art import ( "bytes" "fmt" "math/rand" "net/netip" "sort" "strings" "testing" "github.com/google/go-cmp/cmp" "tailscale.com/types/ptr" ) func TestInversePrefix(t *testing.T) { t.Parallel() for i := 0; i < 256; i++ { for len := 0; len < 9; len++ { addr := i & (0xFF << (8 - len)) idx := prefixIndex(uint8(addr), len) addr2, len2 := inversePrefixIndex(idx) if addr2 != uint8(addr) || len2 != len { t.Errorf("inverse(index(%d/%d)) != %d/%d", addr, len, addr2, len2) } } } } func TestHostIndex(t *testing.T) { t.Parallel() for i := 0; i < 256; i++ { got := hostIndex(uint8(i)) want := prefixIndex(uint8(i), 8) if got != want { t.Errorf("hostIndex(%d) = %d, want %d", i, got, want) } } } func TestStrideTableInsert(t *testing.T) { t.Parallel() // Verify that strideTable's lookup results after a bunch of inserts exactly // match those of a naive implementation that just scans all prefixes on // every lookup. The naive implementation is very slow, but its behavior is // easy to verify by inspection. pfxs := shufflePrefixes(allPrefixes())[:100] slow := slowTable[int]{pfxs} fast := strideTable[int]{} t.Logf("slow table:\n%s", slow.String()) for _, pfx := range pfxs { fast.insert(pfx.addr, pfx.len, pfx.val) t.Logf("after insert %d/%d:\n%s", pfx.addr, pfx.len, fast.tableDebugString()) } for i := 0; i < 256; i++ { addr := uint8(i) slowVal := slow.get(addr) fastVal := fast.get(addr) if slowVal != fastVal { t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal) } } } func TestStrideTableInsertShuffled(t *testing.T) { t.Parallel() // The order in which routes are inserted into a route table does not // influence the final shape of the table, as long as the same set of // prefixes is being inserted. This test verifies that strideTable behaves // this way. // // In addition to the basic shuffle test, we also check that this behavior // is maintained if all inserted routes have the same value pointer. This // shouldn't matter (the strideTable still needs to correctly account for // each inserted route, regardless of associated value), but during initial // development a subtle bug made the table corrupt itself in that setup, so // this test includes a regression test for that. routes := shufflePrefixes(allPrefixes())[:100] zero := 0 rt := strideTable[int]{} rtZero := strideTable[int]{} for _, route := range routes { rt.insert(route.addr, route.len, route.val) rtZero.insert(route.addr, route.len, &zero) } // Order of insertion should not affect the final shape of the stride table. routes2 := append([]slowEntry[int](nil), routes...) // dup so we can print both slices on fail for i := 0; i < 100; i++ { rand.Shuffle(len(routes2), func(i, j int) { routes2[i], routes2[j] = routes2[j], routes2[i] }) rt2 := strideTable[int]{} for _, route := range routes2 { rt2.insert(route.addr, route.len, route.val) } if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" { t.Errorf("tables ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2)) } rtZero2 := strideTable[int]{} for _, route := range routes2 { rtZero2.insert(route.addr, route.len, &zero) } if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" { t.Errorf("tables with identical vals ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2)) } } } func TestStrideTableDelete(t *testing.T) { t.Parallel() // Compare route deletion to our reference slowTable. pfxs := shufflePrefixes(allPrefixes())[:100] slow := slowTable[int]{pfxs} fast := strideTable[int]{} t.Logf("slow table:\n%s", slow.String()) for _, pfx := range pfxs { fast.insert(pfx.addr, pfx.len, pfx.val) t.Logf("after insert %d/%d:\n%s", pfx.addr, pfx.len, fast.tableDebugString()) } toDelete := pfxs[:50] for _, pfx := range toDelete { slow.delete(pfx.addr, pfx.len) fast.delete(pfx.addr, pfx.len) } // Sanity check that slowTable seems to have done the right thing. if cnt := len(slow.prefixes); cnt != 50 { t.Fatalf("slowTable has %d entries after deletes, want 50", cnt) } for i := 0; i < 256; i++ { addr := uint8(i) slowVal := slow.get(addr) fastVal := fast.get(addr) if slowVal != fastVal { t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal) } } } func TestStrideTableDeleteShuffle(t *testing.T) { t.Parallel() // Same as TestStrideTableInsertShuffle, the order in which prefixes are // deleted should not impact the final shape of the route table. routes := shufflePrefixes(allPrefixes())[:100] toDelete := routes[:50] zero := 0 rt := strideTable[int]{} rtZero := strideTable[int]{} for _, route := range routes { rt.insert(route.addr, route.len, route.val) rtZero.insert(route.addr, route.len, &zero) } for _, route := range toDelete { rt.delete(route.addr, route.len) rtZero.delete(route.addr, route.len) } // Order of deletion should not affect the final shape of the stride table. toDelete2 := append([]slowEntry[int](nil), toDelete...) // dup so we can print both slices on fail for i := 0; i < 100; i++ { rand.Shuffle(len(toDelete2), func(i, j int) { toDelete2[i], toDelete2[j] = toDelete2[j], toDelete2[i] }) rt2 := strideTable[int]{} for _, route := range routes { rt2.insert(route.addr, route.len, route.val) } for _, route := range toDelete2 { rt2.delete(route.addr, route.len) } if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" { t.Errorf("tables ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2)) } rtZero2 := strideTable[int]{} for _, route := range routes { rtZero2.insert(route.addr, route.len, &zero) } for _, route := range toDelete2 { rtZero2.delete(route.addr, route.len) } if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" { t.Errorf("tables with identical vals ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2)) } } } var strideRouteCount = []int{10, 50, 100, 200} // forCountAndOrdering runs the benchmark fn with different sets of routes. // // fn is called once for each combination of {num_routes, order}, where // num_routes is the values in strideRouteCount, and order is the order of the // routes in the list: random, largest prefix first (/0 to /8), and smallest // prefix first (/8 to /0). func forStrideCountAndOrdering(b *testing.B, fn func(b *testing.B, routes []slowEntry[int])) { routes := shufflePrefixes(allPrefixes()) for _, nroutes := range strideRouteCount { b.Run(fmt.Sprint(nroutes), func(b *testing.B) { routes := append([]slowEntry[int](nil), routes[:nroutes]...) b.Run("random_order", func(b *testing.B) { b.ReportAllocs() fn(b, routes) }) sort.Slice(routes, func(i, j int) bool { if routes[i].len < routes[j].len { return true } return routes[i].addr < routes[j].addr }) b.Run("largest_first", func(b *testing.B) { b.ReportAllocs() fn(b, routes) }) sort.Slice(routes, func(i, j int) bool { if routes[j].len < routes[i].len { return true } return routes[j].addr < routes[i].addr }) b.Run("smallest_first", func(b *testing.B) { b.ReportAllocs() fn(b, routes) }) }) } } func BenchmarkStrideTableInsertion(b *testing.B) { forStrideCountAndOrdering(b, func(b *testing.B, routes []slowEntry[int]) { val := 0 for i := 0; i < b.N; i++ { var rt strideTable[int] for _, route := range routes { rt.insert(route.addr, route.len, &val) } } inserts := float64(b.N) * float64(len(routes)) elapsed := float64(b.Elapsed().Nanoseconds()) elapsedSec := b.Elapsed().Seconds() b.ReportMetric(elapsed/inserts, "ns/op") b.ReportMetric(inserts/elapsedSec, "routes/s") }) } func BenchmarkStrideTableDeletion(b *testing.B) { forStrideCountAndOrdering(b, func(b *testing.B, routes []slowEntry[int]) { val := 0 var rt strideTable[int] for _, route := range routes { rt.insert(route.addr, route.len, &val) } b.ResetTimer() for i := 0; i < b.N; i++ { rt2 := rt for _, route := range routes { rt2.delete(route.addr, route.len) } } deletes := float64(b.N) * float64(len(routes)) elapsed := float64(b.Elapsed().Nanoseconds()) elapsedSec := b.Elapsed().Seconds() b.ReportMetric(elapsed/deletes, "ns/op") b.ReportMetric(deletes/elapsedSec, "routes/s") }) } var writeSink *int func BenchmarkStrideTableGet(b *testing.B) { // No need to forCountAndOrdering here, route lookup time is independent of // the route count. routes := shufflePrefixes(allPrefixes())[:100] var rt strideTable[int] for _, route := range routes { rt.insert(route.addr, route.len, route.val) } b.ResetTimer() for i := 0; i < b.N; i++ { writeSink = rt.get(uint8(i)) } gets := float64(b.N) elapsedSec := b.Elapsed().Seconds() b.ReportMetric(gets/elapsedSec, "routes/s") } // slowTable is an 8-bit routing table implemented as a set of prefixes that are // explicitly scanned in full for every route lookup. It is very slow, but also // reasonably easy to verify by inspection, and so a good comparison target for // strideTable. type slowTable[T any] struct { prefixes []slowEntry[T] } type slowEntry[T any] struct { addr uint8 len int val *T } func (t *slowTable[T]) String() string { pfxs := append([]slowEntry[T](nil), t.prefixes...) sort.Slice(pfxs, func(i, j int) bool { if pfxs[i].len != pfxs[j].len { return pfxs[i].len < pfxs[j].len } return pfxs[i].addr < pfxs[j].addr }) var ret bytes.Buffer for _, pfx := range pfxs { fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), *pfx.val) } return ret.String() } func (t *slowTable[T]) insert(addr uint8, prefixLen int, val *T) { t.delete(addr, prefixLen) // no-op if prefix doesn't exist t.prefixes = append(t.prefixes, slowEntry[T]{addr, prefixLen, val}) } func (t *slowTable[T]) delete(addr uint8, prefixLen int) { pfx := make([]slowEntry[T], 0, len(t.prefixes)) for _, e := range t.prefixes { if e.addr == addr && e.len == prefixLen { continue } pfx = append(pfx, e) } t.prefixes = pfx } func (t *slowTable[T]) get(addr uint8) *T { var ( ret *T curLen = -1 ) for _, e := range t.prefixes { if addr&pfxMask(e.len) == e.addr && e.len >= curLen { ret = e.val curLen = e.len } } return ret } func pfxMask(pfxLen int) uint8 { return 0xFF << (8 - pfxLen) } func allPrefixes() []slowEntry[int] { ret := make([]slowEntry[int], 0, lastHostIndex) for i := 1; i < lastHostIndex+1; i++ { a, l := inversePrefixIndex(i) ret = append(ret, slowEntry[int]{a, l, ptr.To(i)}) } return ret } func shufflePrefixes(pfxs []slowEntry[int]) []slowEntry[int] { rand.Shuffle(len(pfxs), func(i, j int) { pfxs[i], pfxs[j] = pfxs[j], pfxs[i] }) return pfxs } func formatSlowEntriesShort[T any](ents []slowEntry[T]) string { var ret []string for _, ent := range ents { ret = append(ret, fmt.Sprintf("%d/%d", ent.addr, ent.len)) } return "[" + strings.Join(ret, " ") + "]" } var cmpDiffOpts = []cmp.Option{ cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{}), cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }), }