util/deephash: don't track cycles on non-recursive types

name              old time/op    new time/op    delta
Hash-8              67.3µs ±20%    76.5µs ±16%     ~     (p=0.143 n=10+10)
HashMapAcyclic-8    63.0µs ± 2%    56.3µs ± 1%  -10.65%  (p=0.000 n=10+8)
TailcfgNode-8       9.18µs ± 2%    6.52µs ± 3%  -28.96%  (p=0.000 n=9+10)
HashArray-8          732ns ± 3%     709ns ± 1%   -3.21%  (p=0.000 n=10+10)

name              old alloc/op   new alloc/op   delta
Hash-8               24.0B ± 0%     24.0B ± 0%     ~     (all equal)
HashMapAcyclic-8     0.00B          0.00B          ~     (all equal)
TailcfgNode-8        0.00B          0.00B          ~     (all equal)
HashArray-8          0.00B          0.00B          ~     (all equal)

name              old allocs/op  new allocs/op  delta
Hash-8                1.00 ± 0%      1.00 ± 0%     ~     (all equal)
HashMapAcyclic-8      0.00           0.00          ~     (all equal)
TailcfgNode-8         0.00           0.00          ~     (all equal)
HashArray-8           0.00           0.00          ~     (all equal)

Change-Id: I28642050d837dff66b2db54b2b0e6d272a930be8
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4873/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent 36ea837736
commit f31588786f

@ -125,7 +125,7 @@ func Hash(v any) (s Sum) {
seed = uint64(time.Now().UnixNano()) seed = uint64(time.Now().UnixNano())
}) })
h.hashUint64(seed) h.hashUint64(seed)
h.hashValue(reflect.ValueOf(v)) h.hashValue(reflect.ValueOf(v), false)
return h.sum() return h.sum()
} }
@ -164,26 +164,151 @@ func (h *hasher) hashUint64(i uint64) {
var uint8Type = reflect.TypeOf(byte(0)) var uint8Type = reflect.TypeOf(byte(0))
func (h *hasher) hashValue(v reflect.Value) { // typeInfo describes properties of a type.
if !v.IsValid() { type typeInfo struct {
return rtype reflect.Type
isRecursive bool
// elemTypeInfo is the element type's typeInfo.
// It's set when rtype is of Kind Ptr, Slice, Array, Map.
elemTypeInfo *typeInfo
// keyTypeInfo is the map key type's typeInfo.
// It's set when rtype is of Kind Map.
keyTypeInfo *typeInfo
} }
w := h.bw var typeInfoMap sync.Map // map[reflect.Type]*typeInfo
var typeInfoMapPopulate sync.Mutex // just for adding to typeInfoMap
func getTypeInfo(t reflect.Type) *typeInfo {
if f, ok := typeInfoMap.Load(t); ok {
return f.(*typeInfo)
}
typeInfoMapPopulate.Lock()
defer typeInfoMapPopulate.Unlock()
newTypes := map[reflect.Type]*typeInfo{}
ti := getTypeInfoLocked(t, newTypes)
for t, ti := range newTypes {
typeInfoMap.Store(t, ti)
}
return ti
}
func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *typeInfo {
if v, ok := typeInfoMap.Load(t); ok {
return v.(*typeInfo)
}
if ti, ok := incomplete[t]; ok {
return ti
}
ti := &typeInfo{
rtype: t,
isRecursive: typeIsRecursive(t),
}
incomplete[t] = ti
switch t.Kind() {
case reflect.Map:
ti.keyTypeInfo = getTypeInfoLocked(t.Key(), incomplete)
fallthrough
case reflect.Ptr, reflect.Slice, reflect.Array:
ti.elemTypeInfo = getTypeInfoLocked(t.Elem(), incomplete)
}
return ti
}
// typeIsRecursive reports whether t has a path back to itself.
//
// For interfaces, it currently always reports true.
func typeIsRecursive(t reflect.Type) bool {
inStack := map[reflect.Type]bool{}
var stack []reflect.Type
var visitType func(t reflect.Type) (isRecursiveSoFar bool)
visitType = func(t reflect.Type) (isRecursiveSoFar bool) {
switch t.Kind() {
case reflect.Bool,
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64,
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Uintptr,
reflect.Float32,
reflect.Float64,
reflect.Complex64,
reflect.Complex128,
reflect.String,
reflect.UnsafePointer,
reflect.Func:
return false
}
if t.Size() == 0 {
return false
}
if inStack[t] {
return true
}
stack = append(stack, t)
inStack[t] = true
defer func() {
delete(inStack, t)
stack = stack[:len(stack)-1]
}()
if v.CanInterface() { switch t.Kind() {
// Use AppendTo methods, if available and cheap. default:
if v.CanAddr() && v.Type().Implements(appenderToType) { panic("unhandled kind " + t.Kind().String())
a := v.Addr().Interface().(appenderTo) case reflect.Interface:
size := h.scratch[:8] // Assume the worst for now. TODO(bradfitz): in some cases
record := a.AppendTo(size) // we should be able to prove that it's not recursive. Not worth
binary.LittleEndian.PutUint64(record, uint64(len(record)-len(size))) // it for now.
w.Write(record) return true
case reflect.Array, reflect.Chan, reflect.Pointer, reflect.Slice:
return visitType(t.Elem())
case reflect.Map:
if visitType(t.Key()) {
return true
}
if visitType(t.Elem()) {
return true
}
case reflect.Struct:
if t.String() == "intern.Value" {
// Otherwise its interface{} makes this return true.
return false
}
for i, numField := 0, t.NumField(); i < numField; i++ {
if visitType(t.Field(i).Type) {
return true
}
}
return false
}
return false
}
return visitType(t)
}
func (h *hasher) hashValue(v reflect.Value, forceCycleChecking bool) {
if !v.IsValid() {
return return
} }
ti := getTypeInfo(v.Type())
h.hashValueWithType(v, ti, forceCycleChecking)
} }
// TODO(dsnet): Avoid cycle detection for types that cannot have cycles. func (h *hasher) hashValueWithType(v reflect.Value, ti *typeInfo, forceCycleChecking bool) {
w := h.bw
doCheckCycles := forceCycleChecking || ti.isRecursive
// Generic handling. // Generic handling.
switch v.Kind() { switch v.Kind() {
@ -195,7 +320,7 @@ func (h *hasher) hashValue(v reflect.Value) {
return return
} }
// Check for cycle. if doCheckCycles {
ptr := pointerOf(v) ptr := pointerOf(v)
if idx, ok := h.visitStack.seen(ptr); ok { if idx, ok := h.visitStack.seen(ptr); ok {
h.hashUint8(2) // indicates cycle h.hashUint8(2) // indicates cycle
@ -204,12 +329,13 @@ func (h *hasher) hashValue(v reflect.Value) {
} }
h.visitStack.push(ptr) h.visitStack.push(ptr)
defer h.visitStack.pop(ptr) defer h.visitStack.pop(ptr)
}
h.hashUint8(1) // indicates visiting a pointer h.hashUint8(1) // indicates visiting a pointer
h.hashValue(v.Elem()) h.hashValueWithType(v.Elem(), ti.elemTypeInfo, doCheckCycles)
case reflect.Struct: case reflect.Struct:
for i, n := 0, v.NumField(); i < n; i++ { for i, n := 0, v.NumField(); i < n; i++ {
h.hashValue(v.Field(i)) h.hashValue(v.Field(i), doCheckCycles)
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
vLen := v.Len() vLen := v.Len()
@ -233,7 +359,7 @@ func (h *hasher) hashValue(v reflect.Value) {
// TODO(dsnet): Perform cycle detection for slices, // TODO(dsnet): Perform cycle detection for slices,
// which is functionally a list of pointers. // which is functionally a list of pointers.
// See https://github.com/google/go-cmp/blob/402949e8139bb890c71a707b6faf6dd05c92f4e5/cmp/compare.go#L438-L450 // See https://github.com/google/go-cmp/blob/402949e8139bb890c71a707b6faf6dd05c92f4e5/cmp/compare.go#L438-L450
h.hashValue(v.Index(i)) h.hashValueWithType(v.Index(i), ti.elemTypeInfo, doCheckCycles)
} }
case reflect.Interface: case reflect.Interface:
if v.IsNil() { if v.IsNil() {
@ -244,9 +370,10 @@ func (h *hasher) hashValue(v reflect.Value) {
h.hashUint8(1) // indicates visiting interface value h.hashUint8(1) // indicates visiting interface value
h.hashType(v.Type()) h.hashType(v.Type())
h.hashValue(v) h.hashValue(v, doCheckCycles)
case reflect.Map: case reflect.Map:
// Check for cycle. // Check for cycle.
if doCheckCycles {
ptr := pointerOf(v) ptr := pointerOf(v)
if idx, ok := h.visitStack.seen(ptr); ok { if idx, ok := h.visitStack.seen(ptr); ok {
h.hashUint8(2) // indicates cycle h.hashUint8(2) // indicates cycle
@ -255,9 +382,9 @@ func (h *hasher) hashValue(v reflect.Value) {
} }
h.visitStack.push(ptr) h.visitStack.push(ptr)
defer h.visitStack.pop(ptr) defer h.visitStack.pop(ptr)
}
h.hashUint8(1) // indicates visiting a map h.hashUint8(1) // indicates visiting a map
h.hashMap(v) h.hashMap(v, ti, doCheckCycles)
case reflect.String: case reflect.String:
s := v.String() s := v.String()
h.hashUint64(uint64(len(s))) h.hashUint64(uint64(len(s)))
@ -325,7 +452,7 @@ func (c *valueCache) get(t reflect.Type) reflect.Value {
// It relies on a map being a functionally an unordered set of KV entries. // It relies on a map being a functionally an unordered set of KV entries.
// So long as we hash each KV entry together, we can XOR all // So long as we hash each KV entry together, we can XOR all
// of the individual hashes to produce a unique hash for the entire map. // of the individual hashes to produce a unique hash for the entire map.
func (h *hasher) hashMap(v reflect.Value) { func (h *hasher) hashMap(v reflect.Value, ti *typeInfo, checkCycles bool) {
mh := mapHasherPool.Get().(*mapHasher) mh := mapHasherPool.Get().(*mapHasher)
defer mapHasherPool.Put(mh) defer mapHasherPool.Put(mh)
@ -341,8 +468,8 @@ func (h *hasher) hashMap(v reflect.Value) {
k.SetIterKey(iter) k.SetIterKey(iter)
e.SetIterValue(iter) e.SetIterValue(iter)
mh.h.reset() mh.h.reset()
mh.h.hashValue(k) mh.h.hashValueWithType(k, ti.keyTypeInfo, checkCycles)
mh.h.hashValue(e) mh.h.hashValueWithType(e, ti.elemTypeInfo, checkCycles)
sum.xor(mh.h.sum()) sum.xor(mh.h.sum())
} }
h.bw.Write(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation h.bw.Write(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation

@ -14,6 +14,8 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"testing" "testing"
"time"
"unsafe"
"go4.org/mem" "go4.org/mem"
"inet.af/netaddr" "inet.af/netaddr"
@ -21,6 +23,7 @@ import (
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/structs"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
"tailscale.com/version" "tailscale.com/version"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
@ -235,6 +238,41 @@ func getVal() []any {
} }
} }
func TestTypeIsRecursive(t *testing.T) {
type RecursiveStruct struct {
v *RecursiveStruct
}
type RecursiveChan chan *RecursiveChan
tests := []struct {
val any
want bool
}{
{val: 42, want: false},
{val: "string", want: false},
{val: 1 + 2i, want: false},
{val: struct{}{}, want: false},
{val: (*RecursiveStruct)(nil), want: true},
{val: RecursiveStruct{}, want: true},
{val: time.Unix(0, 0), want: false},
{val: structs.Incomparable{}, want: false}, // ignore its [0]func()
{val: tailcfg.NetPortRange{}, want: false}, // uses structs.Incomparable
{val: (*tailcfg.Node)(nil), want: false},
{val: map[string]bool{}, want: false},
{val: func() {}, want: false},
{val: make(chan int), want: false},
{val: unsafe.Pointer(nil), want: false},
{val: make(RecursiveChan), want: true},
{val: make(chan int), want: false},
}
for _, tt := range tests {
got := typeIsRecursive(reflect.TypeOf(tt.val))
if got != tt.want {
t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want)
}
}
}
var sink = Hash("foo") var sink = Hash("foo")
func BenchmarkHash(b *testing.B) { func BenchmarkHash(b *testing.B) {
@ -255,12 +293,14 @@ func TestHashMapAcyclic(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
bw := bufio.NewWriter(&buf) bw := bufio.NewWriter(&buf)
ti := getTypeInfo(reflect.TypeOf(m))
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
v := reflect.ValueOf(m) v := reflect.ValueOf(m)
buf.Reset() buf.Reset()
bw.Reset(&buf) bw.Reset(&buf)
h := &hasher{bw: bw} h := &hasher{bw: bw}
h.hashMap(v) h.hashMap(v, ti, false)
if got[string(buf.Bytes())] { if got[string(buf.Bytes())] {
continue continue
} }
@ -279,7 +319,7 @@ func TestPrintArray(t *testing.T) {
var got bytes.Buffer var got bytes.Buffer
bw := bufio.NewWriter(&got) bw := bufio.NewWriter(&got)
h := &hasher{bw: bw} h := &hasher{bw: bw}
h.hashValue(reflect.ValueOf(x)) h.hashValue(reflect.ValueOf(x), false)
bw.Flush() bw.Flush()
const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f" const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f"
if got := got.Bytes(); string(got) != want { if got := got.Bytes(); string(got) != want {
@ -297,13 +337,14 @@ func BenchmarkHashMapAcyclic(b *testing.B) {
var buf bytes.Buffer var buf bytes.Buffer
bw := bufio.NewWriter(&buf) bw := bufio.NewWriter(&buf)
v := reflect.ValueOf(m) v := reflect.ValueOf(m)
ti := getTypeInfo(v.Type())
h := &hasher{bw: bw} h := &hasher{bw: bw}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
buf.Reset() buf.Reset()
bw.Reset(&buf) bw.Reset(&buf)
h.hashMap(v) h.hashMap(v, ti, false)
} }
} }

Loading…
Cancel
Save