diff --git a/net/art/stride_table.go b/net/art/stride_table.go
new file mode 100644
index 000000000..99a5731ea
--- /dev/null
+++ b/net/art/stride_table.go
@@ -0,0 +1,226 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package art
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"math/bits"
+	"strconv"
+	"strings"
+)
+
+// strideEntry is a strideTable entry.
+type strideEntry[T any] struct {
+	// prefixIndex is the prefixIndex(...) value that caused this stride
+	// entry's value to be populated, or 0 if value is nil.
+	//
+	// We need to keep track of this because allot() uses it to determine
+	// whether an entry was propagated from a parent entry, or if it's a
+	// different independent route.
+	prefixIndex int
+	// value is the value associated with the strideEntry, if any.
+	value *T
+	// child is the child strideTable associated with the strideEntry, if any.
+	child *strideTable[T]
+}
+
+// strideTable is a binary tree that implements an 8-bit routing table.
+//
+// The leaves of the binary tree are host routes (/8s). Each parent is a
+// successively larger prefix that encompasses its children (/7 through /0).
+type strideTable[T any] struct {
+	// entries is the nodes of the binary tree, laid out in a flattened array.
+	//
+	// The array indices are arranged by the prefixIndex function, such that
+	// the parent of the node at index i is located at index i>>1, and its
+	// children at indices i<<1 and (i<<1)+1.
+	//
+	// A few consequences of this arrangement: host routes (/8) occupy the
+	// last 256 entries in the table; the single default route /0 is at index
+	// 1, and index 0 is unused (in the original paper, it's hijacked through
+	// sneaky C memory trickery to store the refcount, but this is Go, where
+	// we don't store random bits in pointers lest we confuse the GC).
+	entries [lastHostIndex + 1]strideEntry[T]
+	// refs is the number of route entries and child strideTables referenced
+	// by this table. It is used in the multi-layered logic to determine when
+	// this table is empty and can be deleted.
+	refs int
+}
+
+const (
+	// firstHostIndex is the array index of the first host route. This is
+	// hostIndex(0/8).
+	firstHostIndex = 0b1_0000_0000
+	// lastHostIndex is the array index of the last host route. This is
+	// hostIndex(0xFF/8).
+	lastHostIndex = 0b1_1111_1111
+)
+
+// getChild returns the child strideTable pointer for addr (if any), and an
+// internal array index that can be used with deleteChild.
+func (t *strideTable[T]) getChild(addr uint8) (child *strideTable[T], idx int) {
+	idx = hostIndex(addr)
+	return t.entries[idx].child, idx
+}
+
+// deleteChild deletes the child strideTable at idx (if any). idx should be
+// obtained via a call to getChild.
+func (t *strideTable[T]) deleteChild(idx int) {
+	t.entries[idx].child = nil
+	t.refs--
+}
+
+// getOrCreateChild returns the child strideTable for addr, creating it if
+// necessary.
+func (t *strideTable[T]) getOrCreateChild(addr uint8) *strideTable[T] {
+	idx := hostIndex(addr)
+	if t.entries[idx].child == nil {
+		t.entries[idx].child = new(strideTable[T])
+		t.refs++
+	}
+	return t.entries[idx].child
+}
+
+// allot updates entries whose stored prefixIndex matches oldPrefixIndex, in
+// the subtree rooted at idx. Matching entries have their stored prefixIndex
+// set to newPrefixIndex, and their value set to val.
+//
+// allot is the core of the ART algorithm, enabling efficient
+// insertion/deletion while preserving very fast lookups.
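+//
+// For example, inserting the first route 32/3 into an empty table computes
+// idx = prefixIndex(32, 3) = 9 and calls allot(9, 0, 9, val): every entry in
+// the subtree rooted at index 9 (covering addresses 32-63) still has
+// prefixIndex 0, so all of them receive the new value. A later, more
+// specific insert such as 48/4 (index 19) then reclaims only the entries in
+// its own subtree that still carry prefixIndex 9.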
+func (t *strideTable[T]) allot(idx int, oldPrefixIndex, newPrefixIndex int, val *T) {
+	if t.entries[idx].prefixIndex != oldPrefixIndex {
+		// The current prefixIndex isn't what we expect. This is a recursive
+		// call that found a child subtree that already has a more specific
+		// route installed. Don't touch it.
+		return
+	}
+	t.entries[idx].value = val
+	t.entries[idx].prefixIndex = newPrefixIndex
+	if idx >= firstHostIndex {
+		// The entry we just updated was a host route; we're at the bottom of
+		// the binary tree.
+		return
+	}
+	// Propagate the allotment to this node's children.
+	left := idx << 1
+	t.allot(left, oldPrefixIndex, newPrefixIndex, val)
+	right := left + 1
+	t.allot(right, oldPrefixIndex, newPrefixIndex, val)
+}
+
+// insert adds the route addr/prefixLen to t, with value val.
+func (t *strideTable[T]) insert(addr uint8, prefixLen int, val *T) {
+	idx := prefixIndex(addr, prefixLen)
+	old := t.entries[idx].value
+	oldIdx := t.entries[idx].prefixIndex
+	if oldIdx == idx && old == val {
+		// This exact prefix+value is already in the table.
+		return
+	}
+	t.allot(idx, oldIdx, idx, val)
+	if oldIdx != idx {
+		// This route entry was freshly created (not just updated); that's a
+		// new reference.
+		t.refs++
+	}
+}
+
+// delete removes the route addr/prefixLen from t. It returns the value that
+// was associated with the deleted prefix, or nil if the prefix wasn't in t.
+func (t *strideTable[T]) delete(addr uint8, prefixLen int) *T {
+	idx := prefixIndex(addr, prefixLen)
+	recordedIdx := t.entries[idx].prefixIndex
+	if recordedIdx != idx {
+		// Route entry doesn't exist.
+		return nil
+	}
+	val := t.entries[idx].value
+
+	parentIdx := idx >> 1
+	t.allot(idx, idx, t.entries[parentIdx].prefixIndex, t.entries[parentIdx].value)
+	t.refs--
+	return val
+}
+
+// get does a route lookup for addr and returns the associated value, or nil
+// if no route matched.
+func (t *strideTable[T]) get(addr uint8) *T {
+	return t.entries[hostIndex(addr)].value
+}
+
+// tableDebugString returns the contents of t, formatted as a table with one
+// line per entry.
+func (t *strideTable[T]) tableDebugString() string {
+	var ret bytes.Buffer
+	for i, ent := range t.entries {
+		if i == 0 {
+			continue
+		}
+		v := "(nil)"
+		if ent.value != nil {
+			v = fmt.Sprint(*ent.value)
+		}
+		fmt.Fprintf(&ret, "idx=%3d (%s), parent=%3d (%s), val=%v\n", i, formatPrefixTable(inversePrefixIndex(i)), ent.prefixIndex, formatPrefixTable(inversePrefixIndex(ent.prefixIndex)), v)
+	}
+	return ret.String()
+}
+
+// treeDebugString returns the contents of t, formatted as a sparse tree. Each
+// line is one entry, indented such that it is contained by all its parents,
+// and non-overlapping with any of its siblings.
+func (t *strideTable[T]) treeDebugString() string {
+	var ret bytes.Buffer
+	t.treeDebugStringRec(&ret, 1, 0) // index of 0/0, and 0 indent
+	return ret.String()
+}
+
+func (t *strideTable[T]) treeDebugStringRec(w io.Writer, idx, indent int) {
+	addr, len := inversePrefixIndex(idx)
+	if t.entries[idx].prefixIndex != 0 && t.entries[idx].prefixIndex == idx {
+		fmt.Fprintf(w, "%s%d/%d (%d/%d) = %v\n", strings.Repeat(" ", indent), addr, len, addr, len, *t.entries[idx].value)
+		indent += 2
+	}
+	if idx >= firstHostIndex {
+		return
+	}
+	left := idx << 1
+	t.treeDebugStringRec(w, left, indent)
+	right := left + 1
+	t.treeDebugStringRec(w, right, indent)
+}
+
+// prefixIndex returns the array index of the tree node for addr/prefixLen.
+func prefixIndex(addr uint8, prefixLen int) int {
+	// The prefixIndex of addr/prefixLen is the prefixLen most significant
+	// bits of addr, with a 1 tacked onto the left-hand side. For example:
+	//
+	//   - 0/0 is 1: 0 bits of the addr, with a 1 tacked on
+	//   - 42/8 is 1_00101010 (298): all bits of 42, with a 1 tacked on
+	//   - 48/4 is 1_0011 (19): 4 most-significant bits of 48, with a 1 tacked on
+	return (int(addr) >> (8 - prefixLen)) + (1 << prefixLen)
+}
+
+// hostIndex returns the array index of the host route for addr.
+// It is equivalent to prefixIndex(addr, 8).
+func hostIndex(addr uint8) int {
+	return int(addr) + 1<<8
+}
+
+// inversePrefixIndex returns the address and prefix length of idx. It is the
+// inverse of prefixIndex. Only used for debugging and in tests.
+func inversePrefixIndex(idx int) (addr uint8, len int) {
+	lz := bits.LeadingZeros(uint(idx))
+	len = strconv.IntSize - lz - 1
+	addr = uint8(idx&(0xFF>>(8-len))) << (8 - len)
+	return addr, len
+}
+
+// formatPrefixTable formats addr and len as addr/len, with a constant width
+// suitable for use in table formatting.
+func formatPrefixTable(addr uint8, len int) string {
+	if len < 0 { // this happens for inversePrefixIndex(0)
+		return ""
+	}
+	return fmt.Sprintf("%3d/%d", addr, len)
+}
diff --git a/net/art/stride_table_test.go b/net/art/stride_table_test.go
new file mode 100644
index 000000000..03fb518ac
--- /dev/null
+++ b/net/art/stride_table_test.go
@@ -0,0 +1,378 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package art
+
+import (
+	"bytes"
+	"fmt"
+	"math/rand"
+	"sort"
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"tailscale.com/types/ptr"
+)
+
+func TestInversePrefix(t *testing.T) {
+	for i := 0; i < 256; i++ {
+		for len := 0; len < 9; len++ {
+			addr := i & (0xFF << (8 - len))
+			idx := prefixIndex(uint8(addr), len)
+			addr2, len2 := inversePrefixIndex(idx)
+			if addr2 != uint8(addr) || len2 != len {
+				t.Errorf("inverse(index(%d/%d)) != %d/%d", addr, len, addr2, len2)
+			}
+		}
+	}
+}
+
+func TestHostIndex(t *testing.T) {
+	for i := 0; i < 256; i++ {
+		got := hostIndex(uint8(i))
+		want := prefixIndex(uint8(i), 8)
+		if got != want {
+			t.Errorf("hostIndex(%d) = %d, want %d", i, got, want)
+		}
+	}
+}
+
+func TestStrideTableInsert(t *testing.T) {
+	// Verify that strideTable's lookup results after a bunch of inserts
+	// exactly match those of a naive implementation that just scans all
+	// prefixes on every lookup. The naive implementation is very slow, but
+	// its behavior is easy to verify by inspection.
+
+	pfxs := shufflePrefixes(allPrefixes())[:100]
+	slow := slowTable[int]{pfxs}
+	fast := strideTable[int]{}
+
+	t.Logf("slow table:\n%s", slow.String())
+
+	for _, pfx := range pfxs {
+		fast.insert(pfx.addr, pfx.len, pfx.val)
+		t.Logf("after insert %d/%d:\n%s", pfx.addr, pfx.len, fast.tableDebugString())
+	}
+
+	for i := 0; i < 256; i++ {
+		addr := uint8(i)
+		slowVal := slow.get(addr)
+		fastVal := fast.get(addr)
+		if slowVal != fastVal {
+			t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal)
+		}
+	}
+}
+
+func TestStrideTableInsertShuffled(t *testing.T) {
+	// The order in which routes are inserted into a route table does not
+	// influence the final shape of the table, as long as the same set of
+	// prefixes is being inserted. This test verifies that strideTable behaves
+	// this way.
+	//
+	// In addition to the basic shuffle test, we also check that this behavior
+	// is maintained if all inserted routes have the same value pointer. This
+	// shouldn't matter (the strideTable still needs to correctly account for
+	// each inserted route, regardless of associated value), but during initial
+	// development a subtle bug made the table corrupt itself in that setup, so
+	// this also serves as a regression test for that bug.
+
+	routes := shufflePrefixes(allPrefixes())[:100]
+
+	zero := 0
+	rt := strideTable[int]{}
+	rtZero := strideTable[int]{}
+	for _, route := range routes {
+		rt.insert(route.addr, route.len, route.val)
+		rtZero.insert(route.addr, route.len, &zero)
+	}
+
+	// Order of insertion should not affect the final shape of the stride table.
+	routes2 := append([]slowEntry[int](nil), routes...) // dup so we can print both slices on fail
+	for i := 0; i < 100; i++ {
+		rand.Shuffle(len(routes2), func(i, j int) { routes2[i], routes2[j] = routes2[j], routes2[i] })
+		rt2 := strideTable[int]{}
+		for _, route := range routes2 {
+			rt2.insert(route.addr, route.len, route.val)
+		}
+		if diff := cmp.Diff(rt, rt2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" {
+			t.Errorf("tables ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
+		}
+
+		rtZero2 := strideTable[int]{}
+		for _, route := range routes2 {
+			rtZero2.insert(route.addr, route.len, &zero)
+		}
+		if diff := cmp.Diff(rtZero, rtZero2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" {
+			t.Errorf("tables with identical vals ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
+		}
+	}
+}
+
+func TestStrideTableDelete(t *testing.T) {
+	// Compare route deletion to our reference slowTable.
+	pfxs := shufflePrefixes(allPrefixes())[:100]
+	slow := slowTable[int]{pfxs}
+	fast := strideTable[int]{}
+
+	t.Logf("slow table:\n%s", slow.String())
+
+	for _, pfx := range pfxs {
+		fast.insert(pfx.addr, pfx.len, pfx.val)
+		t.Logf("after insert %d/%d:\n%s", pfx.addr, pfx.len, fast.tableDebugString())
+	}
+
+	toDelete := pfxs[:50]
+	for _, pfx := range toDelete {
+		slow.delete(pfx.addr, pfx.len)
+		fast.delete(pfx.addr, pfx.len)
+	}
+
+	// Sanity check that slowTable seems to have done the right thing.
+	if cnt := len(slow.prefixes); cnt != 50 {
+		t.Fatalf("slowTable has %d entries after deletes, want 50", cnt)
+	}
+
+	for i := 0; i < 256; i++ {
+		addr := uint8(i)
+		slowVal := slow.get(addr)
+		fastVal := fast.get(addr)
+		if slowVal != fastVal {
+			t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal)
+		}
+	}
+}
+
+func TestStrideTableDeleteShuffle(t *testing.T) {
+	// As in TestStrideTableInsertShuffled, the order in which prefixes are
+	// deleted should not affect the final shape of the route table.
+
+	routes := shufflePrefixes(allPrefixes())[:100]
+	toDelete := routes[:50]
+
+	zero := 0
+	rt := strideTable[int]{}
+	rtZero := strideTable[int]{}
+	for _, route := range routes {
+		rt.insert(route.addr, route.len, route.val)
+		rtZero.insert(route.addr, route.len, &zero)
+	}
+	for _, route := range toDelete {
+		rt.delete(route.addr, route.len)
+		rtZero.delete(route.addr, route.len)
+	}
+
+	// Order of deletion should not affect the final shape of the stride table.
+	toDelete2 := append([]slowEntry[int](nil), toDelete...) // dup so we can print both slices on fail
+	for i := 0; i < 100; i++ {
+		rand.Shuffle(len(toDelete2), func(i, j int) { toDelete2[i], toDelete2[j] = toDelete2[j], toDelete2[i] })
+		rt2 := strideTable[int]{}
+		for _, route := range routes {
+			rt2.insert(route.addr, route.len, route.val)
+		}
+		for _, route := range toDelete2 {
+			rt2.delete(route.addr, route.len)
+		}
+		if diff := cmp.Diff(rt, rt2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" {
+			t.Errorf("tables ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
+		}
+
+		rtZero2 := strideTable[int]{}
+		for _, route := range routes {
+			rtZero2.insert(route.addr, route.len, &zero)
+		}
+		for _, route := range toDelete2 {
+			rtZero2.delete(route.addr, route.len)
+		}
+		if diff := cmp.Diff(rtZero, rtZero2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" {
+			t.Errorf("tables with identical vals ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
+		}
+	}
+}
+
+var benchRouteCount = []int{10, 50, 100, 200}
+
+// forCountAndOrdering runs the benchmark fn with different sets of routes.
+//
+// fn is called once for each combination of {num_routes, order}, where
+// num_routes ranges over the values in benchRouteCount, and order is the
+// order of the routes in the list: random, largest prefix first (/0 to /8),
+// and smallest prefix first (/8 to /0).
+func forCountAndOrdering(b *testing.B, fn func(b *testing.B, routes []slowEntry[int])) {
+	routes := shufflePrefixes(allPrefixes())
+	for _, nroutes := range benchRouteCount {
+		b.Run(fmt.Sprint(nroutes), func(b *testing.B) {
+			routes := append([]slowEntry[int](nil), routes[:nroutes]...)
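+			// The sort.Slice calls below reorder this local copy in place,
+			// leaving the outer shuffled slice untouched for the remaining
+			// route counts.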
+ b.Run("random_order", func(b *testing.B) { + b.ReportAllocs() + fn(b, routes) + }) + sort.Slice(routes, func(i, j int) bool { + if routes[i].len < routes[j].len { + return true + } + return routes[i].addr < routes[j].addr + }) + b.Run("largest_first", func(b *testing.B) { + b.ReportAllocs() + fn(b, routes) + }) + sort.Slice(routes, func(i, j int) bool { + if routes[j].len < routes[i].len { + return true + } + return routes[j].addr < routes[i].addr + }) + b.Run("smallest_first", func(b *testing.B) { + b.ReportAllocs() + fn(b, routes) + }) + }) + } +} + +func BenchmarkStrideTableInsertion(b *testing.B) { + forCountAndOrdering(b, func(b *testing.B, routes []slowEntry[int]) { + val := 0 + for i := 0; i < b.N; i++ { + var rt strideTable[int] + for _, route := range routes { + rt.insert(route.addr, route.len, &val) + } + } + inserts := float64(b.N) * float64(len(routes)) + elapsed := float64(b.Elapsed().Nanoseconds()) + elapsedSec := b.Elapsed().Seconds() + b.ReportMetric(elapsed/inserts, "ns/op") + b.ReportMetric(inserts/elapsedSec, "routes/s") + }) +} + +func BenchmarkStrideTableDeletion(b *testing.B) { + forCountAndOrdering(b, func(b *testing.B, routes []slowEntry[int]) { + val := 0 + var rt strideTable[int] + for _, route := range routes { + rt.insert(route.addr, route.len, &val) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rt2 := rt + for _, route := range routes { + rt2.delete(route.addr, route.len) + } + } + deletes := float64(b.N) * float64(len(routes)) + elapsed := float64(b.Elapsed().Nanoseconds()) + elapsedSec := b.Elapsed().Seconds() + b.ReportMetric(elapsed/deletes, "ns/op") + b.ReportMetric(deletes/elapsedSec, "routes/s") + }) +} + +var writeSink *int + +func BenchmarkStrideTableGet(b *testing.B) { + // No need to forCountAndOrdering here, route lookup time is independent of + // the route count. + routes := shufflePrefixes(allPrefixes())[:100] + var rt strideTable[int] + for _, route := range routes { + rt.insert(route.addr, route.len, route.val) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + writeSink = rt.get(uint8(i)) + } + gets := float64(b.N) + elapsedSec := b.Elapsed().Seconds() + b.ReportMetric(gets/elapsedSec, "routes/s") +} + +// slowTable is an 8-bit routing table implemented as a set of prefixes that are +// explicitly scanned in full for every route lookup. It is very slow, but also +// reasonably easy to verify by inspection, and so a good comparison target for +// strideTable. +type slowTable[T any] struct { + prefixes []slowEntry[T] +} + +type slowEntry[T any] struct { + addr uint8 + len int + val *T +} + +func (t *slowTable[T]) String() string { + pfxs := append([]slowEntry[T](nil), t.prefixes...) 
+	sort.Slice(pfxs, func(i, j int) bool {
+		if pfxs[i].len != pfxs[j].len {
+			return pfxs[i].len < pfxs[j].len
+		}
+		return pfxs[i].addr < pfxs[j].addr
+	})
+	var ret bytes.Buffer
+	for _, pfx := range pfxs {
+		fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), *pfx.val)
+	}
+	return ret.String()
+}
+
+func (t *slowTable[T]) insert(addr uint8, prefixLen int, val *T) {
+	t.delete(addr, prefixLen) // no-op if prefix doesn't exist
+	t.prefixes = append(t.prefixes, slowEntry[T]{addr, prefixLen, val})
+}
+
+func (t *slowTable[T]) delete(addr uint8, prefixLen int) {
+	pfx := make([]slowEntry[T], 0, len(t.prefixes))
+	for _, e := range t.prefixes {
+		if e.addr == addr && e.len == prefixLen {
+			continue
+		}
+		pfx = append(pfx, e)
+	}
+	t.prefixes = pfx
+}
+
+func (t *slowTable[T]) get(addr uint8) *T {
+	var (
+		ret    *T
+		curLen = -1
+	)
+	for _, e := range t.prefixes {
+		if addr&pfxMask(e.len) == e.addr && e.len >= curLen {
+			ret = e.val
+			curLen = e.len
+		}
+	}
+	return ret
+}
+
+func pfxMask(pfxLen int) uint8 {
+	return 0xFF << (8 - pfxLen)
+}
+
+func allPrefixes() []slowEntry[int] {
+	ret := make([]slowEntry[int], 0, lastHostIndex)
+	for i := 1; i < lastHostIndex+1; i++ {
+		a, l := inversePrefixIndex(i)
+		ret = append(ret, slowEntry[int]{a, l, ptr.To(i)})
+	}
+	return ret
+}
+
+func shufflePrefixes(pfxs []slowEntry[int]) []slowEntry[int] {
+	rand.Shuffle(len(pfxs), func(i, j int) { pfxs[i], pfxs[j] = pfxs[j], pfxs[i] })
+	return pfxs
+}
+
+func formatSlowEntriesShort[T any](ents []slowEntry[T]) string {
+	var ret []string
+	for _, ent := range ents {
+		ret = append(ret, fmt.Sprintf("%d/%d", ent.addr, ent.len))
+	}
+	return "[" + strings.Join(ret, " ") + "]"
+}
diff --git a/net/art/table.go b/net/art/table.go
new file mode 100644
index 000000000..1d49f1566
--- /dev/null
+++ b/net/art/table.go
@@ -0,0 +1,13 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package art provides a routing table that implements the Allotment Routing
+// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi
+// Hariguchi.
+//
+// ART outperforms traditional radix tree implementations for route lookups,
+// insertions, and deletions.
+//
+// For more information, see Yoichi Hariguchi's paper:
+// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf
+package art