diff --git a/net/art/stride_table.go b/net/art/stride_table.go index f8bdb20c5..4c19da94b 100644 --- a/net/art/stride_table.go +++ b/net/art/stride_table.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "math/bits" + "net/netip" "strconv" "strings" ) @@ -32,6 +33,9 @@ type strideEntry[T any] struct { // The leaves of the binary tree are host routes (/8s). Each parent is a // successively larger prefix that encompasses its children (/7 through /0). type strideTable[T any] struct { + // prefix is the prefix represented by the 0/0 route of this strideTable. It + // is used in multi-level tables to support path compression. + prefix netip.Prefix // entries is the nodes of the binary tree, laid out in a flattened array. // // The array indices are arranged by the prefixIndex function, such that the @@ -76,7 +80,9 @@ func (t *strideTable[T]) deleteChild(idx int) { func (t *strideTable[T]) getOrCreateChild(addr uint8) *strideTable[T] { idx := hostIndex(addr) if t.entries[idx].child == nil { - t.entries[idx].child = new(strideTable[T]) + t.entries[idx].child = &strideTable[T]{ + prefix: childPrefixOf(t.prefix, addr), + } t.refs++ } return t.entries[idx].child @@ -229,3 +235,29 @@ func formatPrefixTable(addr uint8, len int) string { } return fmt.Sprintf("%3d/%d", addr, len) } + +// childPrefixOf returns the child prefix of parent whose final byte +// is stride. The parent prefix must be byte-aligned +// (i.e. parent.Bits() must be a multiple of 8), and be no more +// specific than /24 for IPv4 or /120 for IPv6. +// +// For example, childPrefixOf("192.168.0.0/16", 8) == "192.168.8.0/24". +func childPrefixOf(parent netip.Prefix, stride uint8) netip.Prefix { + l := parent.Bits() + if l%8 != 0 { + panic("parent prefix is not 8-bit aligned") + } + if l >= parent.Addr().BitLen() { + panic("parent prefix cannot be extended further") + } + off := l / 8 + if parent.Addr().Is4() { + bs := parent.Addr().As4() + bs[off] = stride + return netip.PrefixFrom(netip.AddrFrom4(bs), l+8) + } else { + bs := parent.Addr().As16() + bs[off] = stride + return netip.PrefixFrom(netip.AddrFrom16(bs), l+8) + } +} diff --git a/net/art/stride_table_test.go b/net/art/stride_table_test.go index dec39cb7a..e706ad640 100644 --- a/net/art/stride_table_test.go +++ b/net/art/stride_table_test.go @@ -7,6 +7,7 @@ import ( "bytes" "fmt" "math/rand" + "net/netip" "sort" "strings" "testing" @@ -100,7 +101,7 @@ func TestStrideTableInsertShuffled(t *testing.T) { for _, route := range routes2 { rt2.insert(route.addr, route.len, route.val) } - if diff := cmp.Diff(rt, rt2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" { + if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" { t.Errorf("tables ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2)) } @@ -108,7 +109,7 @@ func TestStrideTableInsertShuffled(t *testing.T) { for _, route := range routes2 { rtZero2.insert(route.addr, route.len, &zero) } - if diff := cmp.Diff(rtZero, rtZero2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" { + if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" { t.Errorf("tables with identical vals ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2)) } } @@ -180,7 +181,7 @@ func TestStrideTableDeleteShuffle(t *testing.T) { for _, route := range toDelete2 { rt2.delete(route.addr, route.len) } - if diff := cmp.Diff(rt, rt2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" { + if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" { t.Errorf("tables ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2)) } @@ -191,7 +192,7 @@ func TestStrideTableDeleteShuffle(t *testing.T) { for _, route := range toDelete2 { rtZero2.delete(route.addr, route.len) } - if diff := cmp.Diff(rtZero, rtZero2, cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{})); diff != "" { + if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" { t.Errorf("tables with identical vals ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2)) } } @@ -382,3 +383,8 @@ func formatSlowEntriesShort[T any](ents []slowEntry[T]) string { } return "[" + strings.Join(ret, " ") + "]" } + +var cmpDiffOpts = []cmp.Option{ + cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{}), + cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }), +} diff --git a/net/art/table.go b/net/art/table.go index 90ae60f82..69f274b3f 100644 --- a/net/art/table.go +++ b/net/art/table.go @@ -18,17 +18,27 @@ import ( "io" "net/netip" "strings" + "sync" ) // Table is an IPv4 and IPv6 routing table. type Table[T any] struct { - v4 strideTable[T] - v6 strideTable[T] + v4 strideTable[T] + v6 strideTable[T] + initOnce sync.Once +} + +func (t *Table[T]) init() { + t.initOnce.Do(func() { + t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + }) } // Get does a route lookup for addr and returns the associated value, or nil if // no route matched. func (t *Table[T]) Get(addr netip.Addr) *T { + t.init() st := &t.v4 if addr.Is6() { st = &t.v6 @@ -58,6 +68,7 @@ func (t *Table[T]) Get(addr netip.Addr) *T { // Insert adds pfx to the table, with value val. // If pfx is already present in the table, its value is set to val. func (t *Table[T]) Insert(pfx netip.Prefix, val *T) { + t.init() if val == nil { panic("Table.Insert called with nil value") } @@ -85,6 +96,7 @@ func (t *Table[T]) Insert(pfx netip.Prefix, val *T) { // Delete removes pfx from the table, if it is present. func (t *Table[T]) Delete(pfx netip.Prefix) { + t.init() st := &t.v4 if pfx.Addr().Is6() { st = &t.v6 @@ -141,6 +153,7 @@ func (t *Table[T]) Delete(pfx netip.Prefix) { // debugSummary prints the tree of allocated strideTables in t, with each // strideTable's refcount. func (t *Table[T]) debugSummary() string { + t.init() var ret bytes.Buffer fmt.Fprintf(&ret, "v4: ") strideSummary(&ret, &t.v4, 0) @@ -150,7 +163,7 @@ func (t *Table[T]) debugSummary() string { } func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { - fmt.Fprintf(w, "%d refs\n", st.refs) + fmt.Fprintf(w, "%s: %d refs\n", st.prefix, st.refs) indent += 2 for i := firstHostIndex; i <= lastHostIndex; i++ { if child := st.entries[i].child; child != nil {