diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go
index 87fde2819..c24735192 100644
--- a/util/deephash/deephash.go
+++ b/util/deephash/deephash.go
@@ -14,7 +14,7 @@
 //   - time.Time are compared based on whether they are the same instant in time
 //     and also in the same zone offset. Monotonic measurements and zone names
 //     are ignored as part of the hash.
-//   - netip.Addr are compared based on a shallow comparison of the struct.
+//   - netip.Addr are compared based on a shallow comparison of the struct.
 //
 // WARNING: This package, like most of the tailscale.com Go module,
 // should be considered Tailscale-internal; we make no API promises.
@@ -25,7 +25,6 @@ import (
 	"encoding/binary"
 	"encoding/hex"
 	"fmt"
-	"log"
 	"math"
 	"net/netip"
 	"reflect"
@@ -246,7 +245,7 @@ func (ti *typeInfo) hasher() typeHasherFunc {
 }
 
 func (ti *typeInfo) buildHashFuncOnce() {
-	ti.hashFuncLazy = genTypeHasher(ti.rtype)
+	ti.hashFuncLazy = genTypeHasher(ti)
 }
 
 func (h *hasher) hashBoolv(v addressableValue) bool {
@@ -380,13 +379,8 @@ func genHashPtrToMemoryRange(eleType reflect.Type) typeHasherFunc {
 	}
 }
 
-const debug = false
-
-func genTypeHasher(t reflect.Type) typeHasherFunc {
-	if debug {
-		log.Printf("generating func for %v", t)
-	}
-
+func genTypeHasher(ti *typeInfo) typeHasherFunc {
+	t := ti.rtype
 	switch t.Kind() {
 	case reflect.Bool:
 		return (*hasher).hashBoolv
@@ -436,30 +430,67 @@ func genTypeHasher(t reflect.Type) typeHasherFunc {
 		default:
 			return genHashStructFields(t)
 		}
+	case reflect.Map:
+		return func(h *hasher, v addressableValue) bool {
+			if v.IsNil() {
+				h.HashUint8(0) // indicates nil
+				return true
+			}
+			if ti.isRecursive {
+				ptr := pointerOf(v)
+				if idx, ok := h.visitStack.seen(ptr); ok {
+					h.HashUint8(2) // indicates cycle
+					h.HashUint64(uint64(idx))
+					return true
+				}
+				h.visitStack.push(ptr)
+				defer h.visitStack.pop(ptr)
+			}
+			h.HashUint8(1) // indicates visiting a map
+			h.hashMap(v, ti, ti.isRecursive)
+			return true
+		}
 	case reflect.Pointer:
 		et := t.Elem()
 		if typeIsMemHashable(et) {
 			return genHashPtrToMemoryRange(et)
 		}
-		if !typeIsRecursive(t) {
-			eti := getTypeInfo(et)
-			return func(h *hasher, v addressableValue) bool {
-				if v.IsNil() {
-					h.HashUint8(0) // indicates nil
+		eti := getTypeInfo(et)
+		return func(h *hasher, v addressableValue) bool {
+			if v.IsNil() {
+				h.HashUint8(0) // indicates nil
+				return true
+			}
+			if ti.isRecursive {
+				ptr := pointerOf(v)
+				if idx, ok := h.visitStack.seen(ptr); ok {
+					h.HashUint8(2) // indicates cycle
+					h.HashUint64(uint64(idx))
 					return true
 				}
-				h.HashUint8(1) // indicates visiting a pointer
-				va := addressableValue{v.Elem()} // dereferenced pointer is always addressable
-				return eti.hasher()(h, va)
+				h.visitStack.push(ptr)
+				defer h.visitStack.pop(ptr)
 			}
+			h.HashUint8(1) // indicates visiting a pointer
+			va := addressableValue{v.Elem()} // dereferenced pointer is always addressable
+			return eti.hasher()(h, va)
 		}
-	}
+	case reflect.Interface:
+		return func(h *hasher, v addressableValue) bool {
+			if v.IsNil() {
+				h.HashUint8(0) // indicates nil
+				return true
+			}
+			va := newAddressableValue(v.Elem().Type())
+			va.Set(v.Elem())
 
-	return func(h *hasher, v addressableValue) bool {
-		if debug {
-			log.Printf("unhandled type %v", v.Type())
+			h.HashUint8(1) // indicates visiting interface value
+			h.hashType(va.Type())
+			h.hashValue(va, true)
+			return true
 		}
-		return false
+	default: // Func, Chan, UnsafePointer
+		return noopHasherFunc
 	}
 }
 
@@ -646,11 +677,8 @@ func (h *hasher) hashValue(v addressableValue, forceCycleChecking bool) {
 func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleChecking bool) {
 	doCheckCycles := forceCycleChecking || ti.isRecursive
 
-	if !doCheckCycles {
-		hf := ti.hasher()
-		if hf(h, v) {
-			return
-		}
+	if ti.hasher()(h, v) {
+		return
 	}
 
 	// Generic handling.
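
The core of this change is the visit-stack cycle detection that genTypeHasher now applies to recursive map and pointer types. The standalone sketch below illustrates that technique in plain Go; it is not the deephash implementation (which generates type-specialized hasher funcs from *typeInfo), and every name in it (walker, node, seen) is illustrative. It walks a value with package reflect, keeps a stack of pointers currently being hashed, and, on revisiting a pointer, hashes a back-reference index instead of recursing, mirroring the HashUint8(0)/(1)/(2) markers used in the patch.

// Minimal, self-contained sketch of visit-stack cycle detection.
// Not the deephash implementation; names and structure are illustrative only.
package main

import (
	"crypto/sha256"
	"encoding/binary"
	"fmt"
	"hash"
	"reflect"
)

type walker struct {
	h     hash.Hash
	stack []uintptr // pointers currently being hashed, outermost first
}

func (w *walker) uint64(x uint64) {
	var b [8]byte
	binary.LittleEndian.PutUint64(b[:], x)
	w.h.Write(b[:])
}

// seen reports whether p is already on the visit stack and, if so, its index.
func (w *walker) seen(p uintptr) (int, bool) {
	for i, q := range w.stack {
		if q == p {
			return i, true
		}
	}
	return 0, false
}

func (w *walker) visit(v reflect.Value) {
	switch v.Kind() {
	case reflect.Pointer:
		if v.IsNil() {
			w.h.Write([]byte{0}) // indicates nil
			return
		}
		p := v.Pointer()
		if idx, ok := w.seen(p); ok {
			w.h.Write([]byte{2}) // indicates cycle
			w.uint64(uint64(idx))
			return
		}
		w.stack = append(w.stack, p)
		defer func() { w.stack = w.stack[:len(w.stack)-1] }()
		w.h.Write([]byte{1}) // indicates visiting a pointer
		w.visit(v.Elem())
	case reflect.Struct:
		for i := 0; i < v.NumField(); i++ {
			w.visit(v.Field(i))
		}
	case reflect.Int:
		w.uint64(uint64(v.Int()))
	case reflect.String:
		w.uint64(uint64(v.Len()))
		w.h.Write([]byte(v.String()))
	}
}

// node is a deliberately self-referential type.
type node struct {
	Name string
	Next *node
}

func main() {
	n := &node{Name: "a"}
	n.Next = n // cycle: would recurse forever without the visit stack

	w := &walker{h: sha256.New()}
	w.visit(reflect.ValueOf(n))
	fmt.Printf("%x\n", w.h.Sum(nil))
}

Encoding an explicit marker byte (0 for nil, 1 for a value being visited, 2 plus a stack index for a cycle) keeps the hashed byte stream self-describing, which is the same property the patch relies on so that structurally different values do not feed identical input to the hash.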