diff --git a/net/art/stride_table.go b/net/art/stride_table.go index ea261efac..f18f76515 100644 --- a/net/art/stride_table.go +++ b/net/art/stride_table.go @@ -18,19 +18,6 @@ const ( debugStrideDelete = false ) -// strideEntry is a strideTable entry. -type strideEntry[T any] struct { - // prefixIndex is the prefixIndex(...) value that caused this stride entry's - // value to be populated, or 0 if value is nil. - // - // We need to keep track of this because allot() uses it to determine - // whether an entry was propagated from a parent entry, or if it's a - // different independent route. - prefixIndex int - // value is the value associated with the strideEntry, if any. - value *T -} - // strideTable is a binary tree that implements an 8-bit routing table. // // The leaves of the binary tree are host routes (/8s). Each parent is a @@ -54,7 +41,9 @@ type strideTable[T any] struct { // paper, it's hijacked through sneaky C memory trickery to store // the refcount, but this is Go, where we don't store random bits // in pointers lest we confuse the GC) - entries [lastHostIndex + 1]strideEntry[T] + // + // A nil value means no route matches the queried route. + entries [lastHostIndex + 1]*T // children are the child tables of this table. Each child // represents the address space within one of this table's host // routes (/8). @@ -112,13 +101,6 @@ func (t *strideTable[T]) getOrCreateChild(addr uint8) (child *strideTable[T], cr return ret, false } -// getValAndChild returns both the prefix and child strideTable for -// addr. Both returned values can be nil if no entry of that type -// exists for addr. -func (t *strideTable[T]) getValAndChild(addr uint8) (*T, *strideTable[T]) { - return t.entries[hostIndex(addr)].value, t.children[addr] -} - // findFirstChild returns the first child strideTable in t, or nil if // t has no children. func (t *strideTable[T]) findFirstChild() *strideTable[T] { @@ -130,21 +112,41 @@ func (t *strideTable[T]) findFirstChild() *strideTable[T] { return nil } +// hasPrefixRootedAt reports whether t.entries[idx] is the root node of +// a prefix. +func (t *strideTable[T]) hasPrefixRootedAt(idx int) bool { + val := t.entries[idx] + if val == nil { + return false + } + + parentIdx := parentIndex(idx) + if parentIdx == 0 { + // idx is non-nil, and is at the 0/0 route position. + return true + } + if parent := t.entries[parentIdx]; val != parent { + // parent node in the tree isn't the same prefix, so idx must + // be a root. + return true + } + return false +} + // allot updates entries whose stored prefixIndex matches oldPrefixIndex, in the // subtree rooted at idx. Matching entries have their stored prefixIndex set to // newPrefixIndex, and their value set to val. // // allot is the core of the ART algorithm, enabling efficient insertion/deletion // while preserving very fast lookups. -func (t *strideTable[T]) allot(idx int, oldPrefixIndex, newPrefixIndex int, val *T) { - if t.entries[idx].prefixIndex != oldPrefixIndex { - // current prefixIndex isn't what we expect. This is a recursive call - // that found a child subtree that already has a more specific route - // installed. Don't touch it. +func (t *strideTable[T]) allot(idx int, old, new *T) { + if t.entries[idx] != old { + // current idx isn't what we expect. This is a recursive call + // that found a child subtree that already has a more specific + // route installed. Don't touch it. return } - t.entries[idx].value = val - t.entries[idx].prefixIndex = newPrefixIndex + t.entries[idx] = new if idx >= firstHostIndex { // The entry we just updated was a host route, we're at the bottom of // the binary tree. @@ -152,51 +154,73 @@ func (t *strideTable[T]) allot(idx int, oldPrefixIndex, newPrefixIndex int, val } // Propagate the allotment to this node's children. left := idx << 1 - t.allot(left, oldPrefixIndex, newPrefixIndex, val) + t.allot(left, old, new) right := left + 1 - t.allot(right, oldPrefixIndex, newPrefixIndex, val) + t.allot(right, old, new) } // insert adds the route addr/prefixLen to t, with value val. -func (t *strideTable[T]) insert(addr uint8, prefixLen int, val *T) { +func (t *strideTable[T]) insert(addr uint8, prefixLen int, val T) { idx := prefixIndex(addr, prefixLen) - old := t.entries[idx].value - oldIdx := t.entries[idx].prefixIndex - if oldIdx == idx && old == val { - // This exact prefix+value is already in the table. - return - } - t.allot(idx, oldIdx, idx, val) - if oldIdx != idx { - // This route entry was freshly created (not just updated), that's a new - // reference. + if !t.hasPrefixRootedAt(idx) { + // This route entry is being freshly created (not just + // updated), that's a new reference. t.routeRefs++ } + + old := t.entries[idx] + + // For allot to work correctly, each distinct prefix in the + // strideTable must have a different value pointer, even if val is + // identical. This new()+assignment guarantees that each inserted + // prefix gets a unique address. + p := new(T) + *p = val + + t.allot(idx, old, p) return } -// delete removes the route addr/prefixLen from t. Returns the value -// that was associated with the deleted prefix, or nil if the prefix -// wasn't in the strideTable. -func (t *strideTable[T]) delete(addr uint8, prefixLen int) *T { +// delete removes the route addr/prefixLen from t. Reports whether the +// prefix existed in the table prior to deletion. +func (t *strideTable[T]) delete(addr uint8, prefixLen int) (wasPresent bool) { idx := prefixIndex(addr, prefixLen) - recordedIdx := t.entries[idx].prefixIndex - if recordedIdx != idx { + if !t.hasPrefixRootedAt(idx) { // Route entry doesn't exist - return nil + return false } - val := t.entries[idx].value - parentIdx := idx >> 1 - t.allot(idx, idx, t.entries[parentIdx].prefixIndex, t.entries[parentIdx].value) + val := t.entries[idx] + var parentVal *T + if parentIdx := parentIndex(idx); parentIdx != 0 { + parentVal = t.entries[parentIdx] + } + + t.allot(idx, val, parentVal) t.routeRefs-- - return val + return true } -// get does a route lookup for addr and returns the associated value, or nil if -// no route matched. -func (t *strideTable[T]) get(addr uint8) *T { - return t.entries[hostIndex(addr)].value +// get does a route lookup for addr and (value, true) if a matching +// route exists, or (zero, false) otherwise. +func (t *strideTable[T]) get(addr uint8) (ret T, ok bool) { + if val := t.entries[hostIndex(addr)]; val != nil { + return *val, true + } + return ret, false +} + +// getValAndChild returns both the prefix value and child strideTable +// for addr. valOK reports whether a prefix value exists for addr, and +// child is non-nil if a child exists for addr. +func (t *strideTable[T]) getValAndChild(addr uint8) (val T, valOK bool, child *strideTable[T]) { + vp := t.entries[hostIndex(addr)] + if vp != nil { + val = *vp + valOK = true + } + child = t.children[addr] + return } // TableDebugString returns the contents of t, formatted as a table with one @@ -208,10 +232,10 @@ func (t *strideTable[T]) tableDebugString() string { continue } v := "(nil)" - if ent.value != nil { - v = fmt.Sprint(*ent.value) + if ent != nil { + v = fmt.Sprint(*ent) } - fmt.Fprintf(&ret, "idx=%3d (%s), parent=%3d (%s), val=%v\n", i, formatPrefixTable(inversePrefixIndex(i)), ent.prefixIndex, formatPrefixTable(inversePrefixIndex((ent.prefixIndex))), v) + fmt.Fprintf(&ret, "idx=%3d (%s), val=%v\n", i, formatPrefixTable(inversePrefixIndex(i)), v) } return ret.String() } @@ -227,8 +251,8 @@ func (t *strideTable[T]) treeDebugString() string { func (t *strideTable[T]) treeDebugStringRec(w io.Writer, idx, indent int) { addr, len := inversePrefixIndex(idx) - if t.entries[idx].prefixIndex != 0 && t.entries[idx].prefixIndex == idx { - fmt.Fprintf(w, "%s%d/%d (%02x/%d) = %v\n", strings.Repeat(" ", indent), addr, len, addr, len, *t.entries[idx].value) + if t.hasPrefixRootedAt(idx) { + fmt.Fprintf(w, "%s%d/%d (%02x/%d) = %v\n", strings.Repeat(" ", indent), addr, len, addr, len, *t.entries[idx]) indent += 2 } if idx >= firstHostIndex { @@ -251,6 +275,12 @@ func prefixIndex(addr uint8, prefixLen int) int { return (int(addr) >> (8 - prefixLen)) + (1 << prefixLen) } +// parentIndex returns the index of idx's parent prefix, or 0 if idx +// is the index of 0/0. +func parentIndex(idx int) int { + return idx >> 1 +} + // hostIndex returns the array index of the host route for addr. // It is equivalent to prefixIndex(addr, 8). func hostIndex(addr uint8) int { diff --git a/net/art/stride_table_test.go b/net/art/stride_table_test.go index a974479a5..82a7c915d 100644 --- a/net/art/stride_table_test.go +++ b/net/art/stride_table_test.go @@ -8,12 +8,12 @@ import ( "fmt" "math/rand" "net/netip" + "runtime" "sort" "strings" "testing" "github.com/google/go-cmp/cmp" - "tailscale.com/types/ptr" ) func TestInversePrefix(t *testing.T) { @@ -65,10 +65,10 @@ func TestStrideTableInsert(t *testing.T) { for i := 0; i < 256; i++ { addr := uint8(i) - slowVal := slow.get(addr) - fastVal := fast.get(addr) - if slowVal != fastVal { - t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal) + slowVal, slowOK := slow.get(addr) + fastVal, fastOK := fast.get(addr) + if !getsEqual(fastVal, fastOK, slowVal, slowOK) { + t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK) } } } @@ -91,10 +91,14 @@ func TestStrideTableInsertShuffled(t *testing.T) { zero := 0 rt := strideTable[int]{} + // strideTable has a value interface, but internally has to keep + // track of distinct routes even if they all have the same + // value. rtZero uses the same value for all routes, and expects + // correct behavior. rtZero := strideTable[int]{} for _, route := range routes { rt.insert(route.addr, route.len, route.val) - rtZero.insert(route.addr, route.len, &zero) + rtZero.insert(route.addr, route.len, zero) } // Order of insertion should not affect the final shape of the stride table. @@ -105,15 +109,15 @@ func TestStrideTableInsertShuffled(t *testing.T) { for _, route := range routes2 { rt2.insert(route.addr, route.len, route.val) } - if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" { + if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString()); diff != "" { t.Errorf("tables ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2)) } rtZero2 := strideTable[int]{} for _, route := range routes2 { - rtZero2.insert(route.addr, route.len, &zero) + rtZero2.insert(route.addr, route.len, zero) } - if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" { + if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" { t.Errorf("tables with identical vals ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2)) } } @@ -150,10 +154,10 @@ func TestStrideTableDelete(t *testing.T) { for i := 0; i < 256; i++ { addr := uint8(i) - slowVal := slow.get(addr) - fastVal := fast.get(addr) - if slowVal != fastVal { - t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal) + slowVal, slowOK := slow.get(addr) + fastVal, fastOK := fast.get(addr) + if !getsEqual(fastVal, fastOK, slowVal, slowOK) { + t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK) } } } @@ -168,10 +172,14 @@ func TestStrideTableDeleteShuffle(t *testing.T) { zero := 0 rt := strideTable[int]{} + // strideTable has a value interface, but internally has to keep + // track of distinct routes even if they all have the same + // value. rtZero uses the same value for all routes, and expects + // correct behavior. rtZero := strideTable[int]{} for _, route := range routes { rt.insert(route.addr, route.len, route.val) - rtZero.insert(route.addr, route.len, &zero) + rtZero.insert(route.addr, route.len, zero) } for _, route := range toDelete { rt.delete(route.addr, route.len) @@ -189,18 +197,18 @@ func TestStrideTableDeleteShuffle(t *testing.T) { for _, route := range toDelete2 { rt2.delete(route.addr, route.len) } - if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" { + if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString(), cmpDiffOpts...); diff != "" { t.Errorf("tables ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2)) } rtZero2 := strideTable[int]{} for _, route := range routes { - rtZero2.insert(route.addr, route.len, &zero) + rtZero2.insert(route.addr, route.len, zero) } for _, route := range toDelete2 { rtZero2.delete(route.addr, route.len) } - if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" { + if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" { t.Errorf("tables with identical vals ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2)) } } @@ -218,31 +226,35 @@ func forStrideCountAndOrdering(b *testing.B, fn func(b *testing.B, routes []slow routes := shufflePrefixes(allPrefixes()) for _, nroutes := range strideRouteCount { b.Run(fmt.Sprint(nroutes), func(b *testing.B) { - routes := append([]slowEntry[int](nil), routes[:nroutes]...) - b.Run("random_order", func(b *testing.B) { + runAndRecord := func(b *testing.B) { b.ReportAllocs() + var startMem, endMem runtime.MemStats + runtime.ReadMemStats(&startMem) fn(b, routes) - }) + runtime.ReadMemStats(&endMem) + ops := float64(b.N) * float64(len(routes)) + allocs := float64(endMem.Mallocs - startMem.Mallocs) + bytes := float64(endMem.TotalAlloc - startMem.TotalAlloc) + b.ReportMetric(roundFloat64(allocs/ops), "allocs/op") + b.ReportMetric(roundFloat64(bytes/ops), "B/op") + } + + routes := append([]slowEntry[int](nil), routes[:nroutes]...) + b.Run("random_order", runAndRecord) sort.Slice(routes, func(i, j int) bool { if routes[i].len < routes[j].len { return true } return routes[i].addr < routes[j].addr }) - b.Run("largest_first", func(b *testing.B) { - b.ReportAllocs() - fn(b, routes) - }) + b.Run("largest_first", runAndRecord) sort.Slice(routes, func(i, j int) bool { if routes[j].len < routes[i].len { return true } return routes[j].addr < routes[i].addr }) - b.Run("smallest_first", func(b *testing.B) { - b.ReportAllocs() - fn(b, routes) - }) + b.Run("smallest_first", runAndRecord) }) } } @@ -253,7 +265,7 @@ func BenchmarkStrideTableInsertion(b *testing.B) { for i := 0; i < b.N; i++ { var rt strideTable[int] for _, route := range routes { - rt.insert(route.addr, route.len, &val) + rt.insert(route.addr, route.len, val) } } inserts := float64(b.N) * float64(len(routes)) @@ -269,7 +281,7 @@ func BenchmarkStrideTableDeletion(b *testing.B) { val := 0 var rt strideTable[int] for _, route := range routes { - rt.insert(route.addr, route.len, &val) + rt.insert(route.addr, route.len, val) } b.ResetTimer() @@ -287,7 +299,7 @@ func BenchmarkStrideTableDeletion(b *testing.B) { }) } -var writeSink *int +var writeSink int func BenchmarkStrideTableGet(b *testing.B) { // No need to forCountAndOrdering here, route lookup time is independent of @@ -300,7 +312,7 @@ func BenchmarkStrideTableGet(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - writeSink = rt.get(uint8(i)) + writeSink, _ = rt.get(uint8(i)) } gets := float64(b.N) elapsedSec := b.Elapsed().Seconds() @@ -318,7 +330,7 @@ type slowTable[T any] struct { type slowEntry[T any] struct { addr uint8 len int - val *T + val T } func (t *slowTable[T]) String() string { @@ -331,13 +343,14 @@ func (t *slowTable[T]) String() string { }) var ret bytes.Buffer for _, pfx := range pfxs { - fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), *pfx.val) + fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), pfx.val) } return ret.String() } -func (t *slowTable[T]) insert(addr uint8, prefixLen int, val *T) { +func (t *slowTable[T]) insert(addr uint8, prefixLen int, val T) { t.delete(addr, prefixLen) // no-op if prefix doesn't exist + t.prefixes = append(t.prefixes, slowEntry[T]{addr, prefixLen, val}) } @@ -352,18 +365,15 @@ func (t *slowTable[T]) delete(addr uint8, prefixLen int) { t.prefixes = pfx } -func (t *slowTable[T]) get(addr uint8) *T { - var ( - ret *T - curLen = -1 - ) +func (t *slowTable[T]) get(addr uint8) (ret T, ok bool) { + var curLen = -1 for _, e := range t.prefixes { if addr&pfxMask(e.len) == e.addr && e.len >= curLen { ret = e.val curLen = e.len } } - return ret + return ret, curLen != -1 } func pfxMask(pfxLen int) uint8 { @@ -374,7 +384,7 @@ func allPrefixes() []slowEntry[int] { ret := make([]slowEntry[int], 0, lastHostIndex) for i := 1; i < lastHostIndex+1; i++ { a, l := inversePrefixIndex(i) - ret = append(ret, slowEntry[int]{a, l, ptr.To(i)}) + ret = append(ret, slowEntry[int]{a, l, i}) } return ret } @@ -393,6 +403,15 @@ func formatSlowEntriesShort[T any](ents []slowEntry[T]) string { } var cmpDiffOpts = []cmp.Option{ - cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{}), cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }), } + +func getsEqual[T comparable](a T, aOK bool, b T, bOK bool) bool { + if !aOK && !bOK { + return true + } + if aOK != bOK { + return false + } + return a == b +} diff --git a/net/art/table.go b/net/art/table.go index 4b12d4bd7..fa3975778 100644 --- a/net/art/table.go +++ b/net/art/table.go @@ -51,7 +51,7 @@ func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { // Get does a route lookup for addr and returns the associated value, or nil if // no route matched. -func (t *Table[T]) Get(addr netip.Addr) *T { +func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { t.init() // Ideally we would use addr.AsSlice here, but AsSlice is just @@ -84,13 +84,13 @@ func (t *Table[T]) Get(addr netip.Addr) *T { const maxDepth = 16 type prefixAndRoute struct { prefix netip.Prefix - route *T + route T } strideMatch := make([]prefixAndRoute, 0, maxDepth) findLeaf: for { - rt, child := st.getValAndChild(bs[i]) - if rt != nil { + rt, rtOK, child := st.getValAndChild(bs[i]) + if rtOK { // This strideTable contains a route that may be relevant to our // search, remember it. strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) @@ -115,7 +115,7 @@ findLeaf: // the correct most-specific route. for i := len(strideMatch) - 1; i >= 0; i-- { if m := strideMatch[i]; m.prefix.Contains(addr) { - return m.route + return m.route, true } } @@ -123,16 +123,13 @@ findLeaf: // immediately), or we went on a wild goose chase down a compressed path for // the wrong prefix, and also found no usable routes on the way back up to // the root. This is a miss. - return nil + return ret, false } // Insert adds pfx to the table, with value val. // If pfx is already present in the table, its value is set to val. -func (t *Table[T]) Insert(pfx netip.Prefix, val *T) { +func (t *Table[T]) Insert(pfx netip.Prefix, val T) { t.init() - if val == nil { - panic("Table.Insert called with nil value") - } // The standard library doesn't enforce normalized prefixes (where // the non-prefix bits are all zero). These algorithms require @@ -423,7 +420,7 @@ func (t *Table[T]) Delete(pfx netip.Prefix) { if debugDelete { fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) } - if st.delete(bs[byteIdx], numBits) == nil { + if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { // We're in the right strideTable, but pfx wasn't in // it. Refcounts haven't changed, so we can skip cleanup. if debugDelete { diff --git a/net/art/table_test.go b/net/art/table_test.go index 835828340..9166c00e5 100644 --- a/net/art/table_test.go +++ b/net/art/table_test.go @@ -12,8 +12,6 @@ import ( "strconv" "testing" "time" - - "tailscale.com/types/ptr" ) func TestRegression(t *testing.T) { @@ -30,17 +28,16 @@ func TestRegression(t *testing.T) { slow := slowPrefixTable[int]{} p := netip.MustParsePrefix - v := ptr.To(1) - tbl.Insert(p("226.205.197.0/24"), v) - slow.insert(p("226.205.197.0/24"), v) - v = ptr.To(2) - tbl.Insert(p("226.205.0.0/16"), v) - slow.insert(p("226.205.0.0/16"), v) + tbl.Insert(p("226.205.197.0/24"), 1) + slow.insert(p("226.205.197.0/24"), 1) + tbl.Insert(p("226.205.0.0/16"), 2) + slow.insert(p("226.205.0.0/16"), 2) probe := netip.MustParseAddr("226.205.121.152") - got, want := tbl.Get(probe), slow.get(probe) - if got != want { - t.Fatalf("got %v, want %v", got, want) + got, gotOK := tbl.Get(probe) + want, wantOK := slow.get(probe) + if !getsEqual(got, gotOK, want, wantOK) { + t.Fatalf("got (%v, %v), want (%v, %v)", got, gotOK, want, wantOK) } }) @@ -49,18 +46,18 @@ func TestRegression(t *testing.T) { // within computePrefixSplit. t1, t2 := &Table[int]{}, &Table[int]{} p := netip.MustParsePrefix - v1, v2 := ptr.To(1), ptr.To(2) - t1.Insert(p("136.20.0.0/16"), v1) - t1.Insert(p("136.20.201.62/32"), v2) + t1.Insert(p("136.20.0.0/16"), 1) + t1.Insert(p("136.20.201.62/32"), 2) - t2.Insert(p("136.20.201.62/32"), v2) - t2.Insert(p("136.20.0.0/16"), v1) + t2.Insert(p("136.20.201.62/32"), 2) + t2.Insert(p("136.20.0.0/16"), 1) a := netip.MustParseAddr("136.20.54.139") - got, want := t2.Get(a), t1.Get(a) - if got != want { - t.Errorf("Get(%q) is insertion order dependent (t1=%v, t2=%v)", a, want, got) + got1, ok1 := t1.Get(a) + got2, ok2 := t2.Get(a) + if !getsEqual(got1, ok1, got2, ok2) { + t.Errorf("Get(%q) is insertion order dependent: t1=(%v, %v), t2=(%v, %v)", a, got1, ok1, got2, ok2) } }) } @@ -99,7 +96,7 @@ func TestInsert(t *testing.T) { p := netip.MustParsePrefix // Create a new leaf strideTable, with compressed path - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) + tbl.Insert(p("192.168.0.1/32"), 1) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", -1}, @@ -114,7 +111,7 @@ func TestInsert(t *testing.T) { }) // Insert into previous leaf, no tree changes - tbl.Insert(p("192.168.0.2/32"), ptr.To(2)) + tbl.Insert(p("192.168.0.2/32"), 2) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -129,7 +126,7 @@ func TestInsert(t *testing.T) { }) // Insert into previous leaf, unaligned prefix covering the /32s - tbl.Insert(p("192.168.0.0/26"), ptr.To(7)) + tbl.Insert(p("192.168.0.0/26"), 7) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -144,7 +141,7 @@ func TestInsert(t *testing.T) { }) // Create a different leaf elsewhere - tbl.Insert(p("10.0.0.0/27"), ptr.To(3)) + tbl.Insert(p("10.0.0.0/27"), 3) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -159,7 +156,7 @@ func TestInsert(t *testing.T) { }) // Insert that creates a new intermediate table and a new child - tbl.Insert(p("192.168.1.1/32"), ptr.To(4)) + tbl.Insert(p("192.168.1.1/32"), 4) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -174,7 +171,7 @@ func TestInsert(t *testing.T) { }) // Insert that creates a new intermediate table but no new child - tbl.Insert(p("192.170.0.0/16"), ptr.To(5)) + tbl.Insert(p("192.170.0.0/16"), 5) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -190,7 +187,7 @@ func TestInsert(t *testing.T) { // New leaf in a different subtree, so the next insert can test a // variant of decompression. - tbl.Insert(p("192.180.0.1/32"), ptr.To(8)) + tbl.Insert(p("192.180.0.1/32"), 8) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -206,7 +203,7 @@ func TestInsert(t *testing.T) { // Insert that creates a new intermediate table but no new child, // with an unaligned intermediate - tbl.Insert(p("192.180.0.0/21"), ptr.To(9)) + tbl.Insert(p("192.180.0.0/21"), 9) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -221,7 +218,7 @@ func TestInsert(t *testing.T) { }) // Insert a default route, those have their own codepath. - tbl.Insert(p("0.0.0.0/0"), ptr.To(6)) + tbl.Insert(p("0.0.0.0/0"), 6) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -238,7 +235,7 @@ func TestInsert(t *testing.T) { // Now all of the above again, but for IPv6. // Create a new leaf strideTable, with compressed path - tbl.Insert(p("ff:aaaa::1/128"), ptr.To(1)) + tbl.Insert(p("ff:aaaa::1/128"), 1) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", -1}, @@ -253,7 +250,7 @@ func TestInsert(t *testing.T) { }) // Insert into previous leaf, no tree changes - tbl.Insert(p("ff:aaaa::2/128"), ptr.To(2)) + tbl.Insert(p("ff:aaaa::2/128"), 2) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -268,7 +265,7 @@ func TestInsert(t *testing.T) { }) // Insert into previous leaf, unaligned prefix covering the /128s - tbl.Insert(p("ff:aaaa::/125"), ptr.To(7)) + tbl.Insert(p("ff:aaaa::/125"), 7) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -283,7 +280,7 @@ func TestInsert(t *testing.T) { }) // Create a different leaf elsewhere - tbl.Insert(p("ffff:bbbb::/120"), ptr.To(3)) + tbl.Insert(p("ffff:bbbb::/120"), 3) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -298,7 +295,7 @@ func TestInsert(t *testing.T) { }) // Insert that creates a new intermediate table and a new child - tbl.Insert(p("ff:aaaa:aaaa::1/128"), ptr.To(4)) + tbl.Insert(p("ff:aaaa:aaaa::1/128"), 4) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -313,7 +310,7 @@ func TestInsert(t *testing.T) { }) // Insert that creates a new intermediate table but no new child - tbl.Insert(p("ff:aaaa:aaaa:bb00::/56"), ptr.To(5)) + tbl.Insert(p("ff:aaaa:aaaa:bb00::/56"), 5) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -329,7 +326,7 @@ func TestInsert(t *testing.T) { // New leaf in a different subtree, so the next insert can test a // variant of decompression. - tbl.Insert(p("ff:cccc::1/128"), ptr.To(8)) + tbl.Insert(p("ff:cccc::1/128"), 8) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -345,7 +342,7 @@ func TestInsert(t *testing.T) { // Insert that creates a new intermediate table but no new child, // with an unaligned intermediate - tbl.Insert(p("ff:cccc::/37"), ptr.To(9)) + tbl.Insert(p("ff:cccc::/37"), 9) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -360,7 +357,7 @@ func TestInsert(t *testing.T) { }) // Insert a default route, those have their own codepath. - tbl.Insert(p("::/0"), ptr.To(6)) + tbl.Insert(p("::/0"), 6) checkRoutes(t, tbl, []tableTest{ {"ff:aaaa::1", 1}, {"ff:aaaa::2", 2}, @@ -384,7 +381,7 @@ func TestDelete(t *testing.T) { tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("10.0.0.0/8"), ptr.To(1)) + tbl.Insert(p("10.0.0.0/8"), 1) checkRoutes(t, tbl, []tableTest{ {"10.0.0.1", 1}, {"255.255.255.255", -1}, @@ -403,7 +400,7 @@ func TestDelete(t *testing.T) { tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) + tbl.Insert(p("192.168.0.1/32"), 1) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"255.255.255.255", -1}, @@ -421,8 +418,8 @@ func TestDelete(t *testing.T) { // Create an intermediate with 2 children, then delete one leaf. tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) - tbl.Insert(p("192.180.0.1/32"), ptr.To(2)) + tbl.Insert(p("192.168.0.1/32"), 1) + tbl.Insert(p("192.180.0.1/32"), 2) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.180.0.1", 2}, @@ -442,9 +439,9 @@ func TestDelete(t *testing.T) { // Same, but the intermediate carries a route as well. tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) - tbl.Insert(p("192.180.0.1/32"), ptr.To(2)) - tbl.Insert(p("192.0.0.0/10"), ptr.To(3)) + tbl.Insert(p("192.168.0.1/32"), 1) + tbl.Insert(p("192.180.0.1/32"), 2) + tbl.Insert(p("192.0.0.0/10"), 3) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.180.0.1", 2}, @@ -466,9 +463,9 @@ func TestDelete(t *testing.T) { // Intermediate with 3 leaves, then delete one leaf. tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) - tbl.Insert(p("192.180.0.1/32"), ptr.To(2)) - tbl.Insert(p("192.200.0.1/32"), ptr.To(3)) + tbl.Insert(p("192.168.0.1/32"), 1) + tbl.Insert(p("192.180.0.1/32"), 2) + tbl.Insert(p("192.200.0.1/32"), 3) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.180.0.1", 2}, @@ -490,7 +487,7 @@ func TestDelete(t *testing.T) { // Delete non-existent prefix, missing strideTable path. tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) + tbl.Insert(p("192.168.0.1/32"), 1) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.255.0.1", -1}, @@ -509,7 +506,7 @@ func TestDelete(t *testing.T) { // with a wrong turn. tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) + tbl.Insert(p("192.168.0.1/32"), 1) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.255.0.1", -1}, @@ -528,7 +525,7 @@ func TestDelete(t *testing.T) { // leaf doesn't contain route. tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) + tbl.Insert(p("192.168.0.1/32"), 1) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.255.0.1", -1}, @@ -547,8 +544,8 @@ func TestDelete(t *testing.T) { // compactable. tbl := &Table[int]{} checkSize(t, tbl, 2) - tbl.Insert(p("192.168.0.1/32"), ptr.To(1)) - tbl.Insert(p("192.168.0.0/22"), ptr.To(2)) + tbl.Insert(p("192.168.0.1/32"), 1) + tbl.Insert(p("192.168.0.0/22"), 2) checkRoutes(t, tbl, []tableTest{ {"192.168.0.1", 1}, {"192.168.0.2", 2}, @@ -568,7 +565,7 @@ func TestDelete(t *testing.T) { // Default routes have a special case in the code. tbl := &Table[int]{} - tbl.Insert(p("0.0.0.0/0"), ptr.To(1)) + tbl.Insert(p("0.0.0.0/0"), 1) tbl.Delete(p("0.0.0.0/0")) checkRoutes(t, tbl, []tableTest{ @@ -595,20 +592,20 @@ func TestInsertCompare(t *testing.T) { t.Logf(fast.debugSummary()) } - seenVals4 := map[*int]bool{} - seenVals6 := map[*int]bool{} + seenVals4 := map[int]bool{} + seenVals6 := map[int]bool{} for i := 0; i < 10_000; i++ { a := randomAddr() - slowVal := slow.get(a) - fastVal := fast.Get(a) + slowVal, slowOK := slow.get(a) + fastVal, fastOK := fast.Get(a) + if !getsEqual(slowVal, slowOK, fastVal, fastOK) { + t.Fatalf("get(%q) = (%v, %v), want (%v, %v)", a, fastVal, fastOK, slowVal, slowOK) + } if a.Is6() { seenVals6[fastVal] = true } else { seenVals4[fastVal] = true } - if slowVal != fastVal { - t.Fatalf("get(%q) = %p, want %p", a, fastVal, slowVal) - } } // Empirically, 10k probes into 5k v4 prefixes and 5k v6 prefixes results in @@ -667,13 +664,10 @@ func TestInsertShuffled(t *testing.T) { } for _, a := range addrs { - val1 := rt.Get(a) - val2 := rt2.Get(a) - if val1 == nil && val2 == nil { - continue - } - if (val1 == nil && val2 != nil) || (val1 != nil && val2 == nil) || (*val1 != *val2) { - t.Fatalf("get(%q) = %s, want %s", a, printIntPtr(val2), printIntPtr(val1)) + val1, ok1 := rt.Get(a) + val2, ok2 := rt2.Get(a) + if !getsEqual(val1, ok1, val2, ok2) { + t.Fatalf("get(%q) = (%v, %v), want (%v, %v)", a, val2, ok2, val1, ok1) } } } @@ -727,20 +721,20 @@ func TestDeleteCompare(t *testing.T) { fast.Delete(pfx.pfx) } - seenVals4 := map[*int]bool{} - seenVals6 := map[*int]bool{} + seenVals4 := map[int]bool{} + seenVals6 := map[int]bool{} for i := 0; i < numProbes; i++ { a := randomAddr() - slowVal := slow.get(a) - fastVal := fast.Get(a) + slowVal, slowOK := slow.get(a) + fastVal, fastOK := fast.Get(a) + if !getsEqual(slowVal, slowOK, fastVal, fastOK) { + t.Fatalf("get(%q) = (%v, %v), want (%v, %v)", a, fastVal, fastOK, slowVal, slowOK) + } if a.Is6() { seenVals6[fastVal] = true } else { seenVals4[fastVal] = true } - if slowVal != fastVal { - t.Fatalf("get(%q) = %p, want %p", a, fastVal, slowVal) - } } // Empirically, 10k probes into 5k v4 prefixes and 5k v6 prefixes results in // ~1k distinct values for v4 and ~300 for v6. distinct routes. This sanity @@ -814,13 +808,10 @@ func TestDeleteShuffled(t *testing.T) { // test for equivalence statistically with random probes instead. for i := 0; i < numProbes; i++ { a := randomAddr() - val1 := rt.Get(a) - val2 := rt2.Get(a) - if val1 == nil && val2 == nil { - continue - } - if (val1 == nil && val2 != nil) || (val1 != nil && val2 == nil) || (*val1 != *val2) { - t.Errorf("get(%q) = %s, want %s", a, printIntPtr(val2), printIntPtr(val1)) + val1, ok1 := rt.Get(a) + val2, ok2 := rt2.Get(a) + if !getsEqual(val1, ok1, val2, ok2) { + t.Errorf("get(%q) = (%v, %v), want (%v, %v)", a, val2, ok2, val1, ok1) } } } @@ -868,12 +859,12 @@ type tableTest struct { func checkRoutes(t *testing.T, tbl *Table[int], tt []tableTest) { t.Helper() for _, tc := range tt { - v := tbl.Get(netip.MustParseAddr(tc.addr)) - if v == nil && tc.want != -1 { - t.Errorf("lookup %q got nil, want %d", tc.addr, tc.want) + v, ok := tbl.Get(netip.MustParseAddr(tc.addr)) + if !ok && tc.want != -1 { + t.Errorf("lookup %q got (%v, %v), want (_, false)", tc.addr, v, ok) } - if v != nil && *v != tc.want { - t.Errorf("lookup %q got %d, want %d", tc.addr, *v, tc.want) + if ok && v != tc.want { + t.Errorf("lookup %q got (%v, %v), want (%v, true)", tc.addr, v, ok, tc.want) } } } @@ -1005,7 +996,7 @@ func BenchmarkTableGet(b *testing.B) { for i := 0; i < b.N; i++ { addr := genAddr() t.Start() - writeSink = rt.Get(addr) + writeSink, _ = rt.Get(addr) t.Stop() } }) @@ -1112,7 +1103,7 @@ type slowPrefixTable[T any] struct { type slowPrefixEntry[T any] struct { pfx netip.Prefix - val *T + val T } func (t *slowPrefixTable[T]) delete(pfx netip.Prefix) { @@ -1127,7 +1118,7 @@ func (t *slowPrefixTable[T]) delete(pfx netip.Prefix) { t.prefixes = ret } -func (t *slowPrefixTable[T]) insert(pfx netip.Prefix, val *T) { +func (t *slowPrefixTable[T]) insert(pfx netip.Prefix, val T) { pfx = pfx.Masked() for i, ent := range t.prefixes { if ent.pfx == pfx { @@ -1138,11 +1129,8 @@ func (t *slowPrefixTable[T]) insert(pfx netip.Prefix, val *T) { t.prefixes = append(t.prefixes, slowPrefixEntry[T]{pfx, val}) } -func (t *slowPrefixTable[T]) get(addr netip.Addr) *T { - var ( - ret *T - bestLen = -1 - ) +func (t *slowPrefixTable[T]) get(addr netip.Addr) (ret T, ok bool) { + bestLen := -1 for _, pfx := range t.prefixes { if pfx.pfx.Contains(addr) && pfx.pfx.Bits() > bestLen { @@ -1150,7 +1138,7 @@ func (t *slowPrefixTable[T]) get(addr netip.Addr) *T { bestLen = pfx.pfx.Bits() } } - return ret + return ret, bestLen != -1 } // randomPrefixes returns n randomly generated prefixes and associated values, @@ -1176,7 +1164,7 @@ func randomPrefixes4(n int) []slowPrefixEntry[int] { ret := make([]slowPrefixEntry[int], 0, len(pfxs)) for pfx := range pfxs { - ret = append(ret, slowPrefixEntry[int]{pfx, ptr.To(rand.Int())}) + ret = append(ret, slowPrefixEntry[int]{pfx, rand.Int()}) } return ret @@ -1197,7 +1185,7 @@ func randomPrefixes6(n int) []slowPrefixEntry[int] { ret := make([]slowPrefixEntry[int], 0, len(pfxs)) for pfx := range pfxs { - ret = append(ret, slowPrefixEntry[int]{pfx, ptr.To(rand.Int())}) + ret = append(ret, slowPrefixEntry[int]{pfx, rand.Int()}) } return ret @@ -1230,14 +1218,6 @@ func randomAddr6() netip.Addr { return netip.AddrFrom16(b) } -// printIntPtr returns *v as a string, or the literal "" if v is nil. -func printIntPtr(v *int) string { - if v == nil { - return "" - } - return fmt.Sprint(*v) -} - // roundFloat64 rounds f to 2 decimal places, for display. // // It round-trips through a float->string->float conversion, so should not be