diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go
index e274ed53f..5add5a82e 100644
--- a/util/deephash/deephash.go
+++ b/util/deephash/deephash.go
@@ -61,6 +61,19 @@ import (
 // theoretically "parsable" by looking up the hash in a magical map that
 // returns the set of entries for that given hash.
 
+// addressableValue is a reflect.Value that is guaranteed to be addressable
+// such that calling the Addr and Set methods do not panic.
+//
+// There is no compile magic that enforces this property,
+// but rather the need to construct this type makes it easier to examine each
+// construction site to ensure that this property is upheld.
+type addressableValue struct{ reflect.Value }
+
+// newAddressableValue constructs a new addressable value of type t.
+func newAddressableValue(t reflect.Type) addressableValue {
+	return addressableValue{reflect.New(t).Elem()} // dereferenced pointer is always addressable
+}
+
 const scratchSize = 128
 
 // hasher is reusable state for hashing a value.
@@ -122,6 +135,7 @@ var hasherPool = &sync.Pool{
 }
 
 // Hash returns the hash of v.
+// For performance, this should be a non-nil pointer.
 func Hash(v any) (s Sum) {
 	h := hasherPool.Get().(*hasher)
 	defer hasherPool.Put(h)
@@ -131,14 +145,22 @@
 
 	rv := reflect.ValueOf(v)
 	if rv.IsValid() {
+		var va addressableValue
+		if rv.Kind() == reflect.Pointer && !rv.IsNil() {
+			va = addressableValue{rv.Elem()} // dereferenced pointer is always addressable
+		} else {
+			va = newAddressableValue(rv.Type())
+			va.Set(rv)
+		}
+
 		// Always treat the Hash input as an interface (it is), including hashing
 		// its type, otherwise two Hash calls of different types could hash to the
 		// same bytes off the different types and get equivalent Sum values. This is
 		// the same thing that we do for reflect.Kind Interface in hashValue, but
 		// the initial reflect.ValueOf from an interface value effectively strips
 		// the interface box off so we have to do it at the top level by hand.
-		h.hashType(rv.Type())
-		h.hashValue(rv, false)
+		h.hashType(va.Type())
+		h.hashValue(va, false)
 	}
 	return h.sum()
 }
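Two reflect facts drive the wrapper introduced above: unpacking an interface yields a non-addressable copy, while dereferencing a pointer (including one freshly made by reflect.New) always yields an addressable value. A standalone sketch of the distinction, separate from the patch itself:

```go
package main

import (
	"fmt"
	"reflect"
)

func main() {
	type T struct{ A int }

	// Unpacking an interface yields a non-addressable copy;
	// calling Addr or Set on it would panic.
	direct := reflect.ValueOf(T{A: 1})
	fmt.Println(direct.CanAddr()) // false

	// Dereferencing a non-nil pointer is addressable, which is the
	// fast path Hash now takes for pointer inputs.
	viaPtr := reflect.ValueOf(&T{A: 1}).Elem()
	fmt.Println(viaPtr.CanAddr()) // true

	// For everything else, newAddressableValue's pattern: allocate a
	// fresh T, dereference it, then copy the input in with Set.
	fresh := reflect.New(reflect.TypeOf(T{})).Elem()
	fresh.Set(direct)
	fmt.Println(fresh.CanAddr()) // true
}
```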
@@ -147,7 +169,12 @@
 // the provided reflect type, avoiding a map lookup per value.
 func HasherForType[T any]() func(T) Sum {
 	var zeroT T
-	ti := getTypeInfo(reflect.TypeOf(zeroT))
+	t := reflect.TypeOf(zeroT)
+	ti := getTypeInfo(t)
+	var tiElem *typeInfo
+	if t.Kind() == reflect.Pointer {
+		tiElem = getTypeInfo(t.Elem())
+	}
 	seedOnce.Do(initSeed)
 
 	return func(v T) Sum {
@@ -159,14 +186,16 @@
 
 		rv := reflect.ValueOf(v)
 		if rv.IsValid() {
-			// Always treat the Hash input as an interface (it is), including hashing
-			// its type, otherwise two Hash calls of different types could hash to the
-			// same bytes off the different types and get equivalent Sum values. This is
-			// the same thing that we do for reflect.Kind Interface in hashValue, but
-			// the initial reflect.ValueOf from an interface value effectively strips
-			// the interface box off so we have to do it at the top level by hand.
-			h.hashType(rv.Type())
-			h.hashValueWithType(rv, ti, false)
+			if rv.Kind() == reflect.Pointer && !rv.IsNil() {
+				va := addressableValue{rv.Elem()} // dereferenced pointer is always addressable
+				h.hashType(va.Type())
+				h.hashValueWithType(va, tiElem, false)
+			} else {
+				va := newAddressableValue(rv.Type())
+				va.Set(rv)
+				h.hashType(va.Type())
+				h.hashValueWithType(va, ti, false)
+			}
 		}
 		return h.sum()
 	}
@@ -238,7 +267,7 @@ type typeInfo struct {
 }
 
 // returns ok if it was handled; else slow path runs
-type typeHasherFunc func(h *hasher, v reflect.Value) (ok bool)
+type typeHasherFunc func(h *hasher, v addressableValue) (ok bool)
 
 var typeInfoMap sync.Map           // map[reflect.Type]*typeInfo
 var typeInfoMapPopulate sync.Mutex // just for adding to typeInfoMap
@@ -252,7 +281,7 @@ func (ti *typeInfo) buildHashFuncOnce() {
 	ti.hashFuncLazy = genTypeHasher(ti.rtype)
 }
 
-func (h *hasher) hashBoolv(v reflect.Value) bool {
+func (h *hasher) hashBoolv(v addressableValue) bool {
 	var b byte
 	if v.Bool() {
 		b = 1
@@ -261,56 +290,51 @@
 	return true
 }
 
-func (h *hasher) hashUint8v(v reflect.Value) bool {
+func (h *hasher) hashUint8v(v addressableValue) bool {
 	h.hashUint8(uint8(v.Uint()))
 	return true
 }
 
-func (h *hasher) hashInt8v(v reflect.Value) bool {
+func (h *hasher) hashInt8v(v addressableValue) bool {
 	h.hashUint8(uint8(v.Int()))
 	return true
 }
 
-func (h *hasher) hashUint16v(v reflect.Value) bool {
+func (h *hasher) hashUint16v(v addressableValue) bool {
 	h.hashUint16(uint16(v.Uint()))
 	return true
 }
 
-func (h *hasher) hashInt16v(v reflect.Value) bool {
+func (h *hasher) hashInt16v(v addressableValue) bool {
 	h.hashUint16(uint16(v.Int()))
 	return true
 }
 
-func (h *hasher) hashUint32v(v reflect.Value) bool {
+func (h *hasher) hashUint32v(v addressableValue) bool {
 	h.hashUint32(uint32(v.Uint()))
 	return true
 }
 
-func (h *hasher) hashInt32v(v reflect.Value) bool {
+func (h *hasher) hashInt32v(v addressableValue) bool {
 	h.hashUint32(uint32(v.Int()))
 	return true
 }
 
-func (h *hasher) hashUint64v(v reflect.Value) bool {
+func (h *hasher) hashUint64v(v addressableValue) bool {
 	h.hashUint64(v.Uint())
 	return true
 }
 
-func (h *hasher) hashInt64v(v reflect.Value) bool {
+func (h *hasher) hashInt64v(v addressableValue) bool {
 	h.hashUint64(uint64(v.Int()))
 	return true
 }
 
-func hashStructAppenderTo(h *hasher, v reflect.Value) bool {
+func hashStructAppenderTo(h *hasher, v addressableValue) bool {
 	if !v.CanInterface() {
 		return false // slow path
 	}
-	var a appenderTo
-	if v.CanAddr() {
-		a = v.Addr().Interface().(appenderTo)
-	} else {
-		a = v.Interface().(appenderTo)
-	}
+	a := v.Addr().Interface().(appenderTo)
 	size := h.scratch[:8]
 	record := a.AppendTo(size)
 	binary.LittleEndian.PutUint64(record, uint64(len(record)-len(size)))
@@ -319,7 +343,7 @@
 }
 
 // hashPointerAppenderTo hashes v, a reflect.Ptr, that implements appenderTo.
-func hashPointerAppenderTo(h *hasher, v reflect.Value) bool {
+func hashPointerAppenderTo(h *hasher, v addressableValue) bool {
 	if !v.CanInterface() {
 		return false // slow path
 	}
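With tiElem precomputed above, a pointer-typed instantiation skips both the per-call type-info lookup and the defensive copy. A hedged usage sketch, from a caller's point of view: Config and changed are hypothetical, and it assumes the package is imported as deephash:

```go
// Hypothetical caller-side type; any hashable struct works the same way.
type Config struct {
	Name  string
	Ports []int
}

// Instantiating with a pointer type hits the rv.Elem() fast path above.
var hashConfig = deephash.HasherForType[*Config]()

func changed(old, cur *Config) bool {
	return hashConfig(old) != hashConfig(cur) // Sum is comparable
}
```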
@@ -338,7 +362,7 @@
 
 // fieldInfo describes a struct field.
 type fieldInfo struct {
-	index      int // index of field for reflect.Value.Field(n)
+	index      int // index of field for reflect.Value.Field(n); -1 if invalid
 	typeInfo   *typeInfo
 	canMemHash bool
 	offset     uintptr // when we can memhash the field
@@ -380,30 +404,24 @@
 			size:       sf.Type.Size(),
 		})
 	}
-	fieldsIfCanAddr := mergeContiguousFieldsCopy(fields)
-	return structHasher{fields, fieldsIfCanAddr}.hash
+	fields = mergeContiguousFieldsCopy(fields)
+	return structHasher{fields}.hash
 }
 
 type structHasher struct {
-	fields, fieldsIfCanAddr []fieldInfo
+	fields []fieldInfo
 }
 
-func (sh structHasher) hash(h *hasher, v reflect.Value) bool {
-	var base unsafe.Pointer
-	if v.CanAddr() {
-		base = v.Addr().UnsafePointer()
-		for _, f := range sh.fieldsIfCanAddr {
-			if f.canMemHash {
-				h.bw.Write(unsafe.Slice((*byte)(unsafe.Pointer(uintptr(base)+f.offset)), f.size))
-			} else if !f.typeInfo.hasher()(h, v.Field(f.index)) {
-				return false
-			}
+func (sh structHasher) hash(h *hasher, v addressableValue) bool {
+	base := v.Addr().UnsafePointer()
+	for _, f := range sh.fields {
+		if f.canMemHash {
+			h.bw.Write(unsafe.Slice((*byte)(unsafe.Pointer(uintptr(base)+f.offset)), f.size))
+			continue
 		}
-	} else {
-		for _, f := range sh.fields {
-			if !f.typeInfo.hasher()(h, v.Field(f.index)) {
-				return false
-			}
+		va := addressableValue{v.Field(f.index)} // field is addressable if parent struct is addressable
+		if !f.typeInfo.hasher()(h, va) {
+			return false
 		}
 	}
 	return true
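Because v is now guaranteed addressable, structHasher.hash above keeps only the memhash branch: every struct has a stable base address, and runs of memhash-able fields (pre-merged by mergeContiguousFieldsCopy) are written as a single raw byte range. The idea in isolation, with a hard-coded layout standing in for the real fieldInfo bookkeeping, and little-endian output assumed:

```go
package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

type point struct{ X, Y int32 }

func main() {
	p := point{X: 1, Y: 2}
	v := reflect.ValueOf(&p).Elem() // addressable, so Addr cannot panic

	// X and Y are contiguous in memory, so they can be hashed as one
	// 8-byte write instead of two separate field visits.
	base := v.Addr().UnsafePointer()
	raw := unsafe.Slice((*byte)(base), unsafe.Sizeof(p))
	fmt.Printf("% x\n", raw) // 01 00 00 00 02 00 00 00
}
```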
@@ -413,7 +431,7 @@ func (sh structHasher) hash(h *hasher, v reflect.Value) bool {
 // the provided eleType.
 func genHashPtrToMemoryRange(eleType reflect.Type) typeHasherFunc {
 	size := eleType.Size()
-	return func(h *hasher, v reflect.Value) bool {
+	return func(h *hasher, v addressableValue) bool {
 		if v.IsNil() {
 			h.hashUint8(0) // indicates nil
 		} else {
@@ -489,18 +507,19 @@ func genTypeHasher(t reflect.Type) typeHasherFunc {
 		}
 		if !typeIsRecursive(t) {
 			eti := getTypeInfo(et)
-			return func(h *hasher, v reflect.Value) bool {
+			return func(h *hasher, v addressableValue) bool {
 				if v.IsNil() {
 					h.hashUint8(0) // indicates nil
 					return true
 				}
-				h.hashUint8(1) // indicates visiting a pointer
-				return eti.hasher()(h, v.Elem())
+				h.hashUint8(1)                   // indicates visiting a pointer
+				va := addressableValue{v.Elem()} // dereferenced pointer is always addressable
+				return eti.hasher()(h, va)
 			}
 		}
 	}
 
-	return func(h *hasher, v reflect.Value) bool {
+	return func(h *hasher, v addressableValue) bool {
 		if debug {
 			log.Printf("unhandled type %v", v.Type())
 		}
@@ -509,31 +528,31 @@
 }
 
 // hashString hashes v, of kind String.
-func (h *hasher) hashString(v reflect.Value) bool {
+func (h *hasher) hashString(v addressableValue) bool {
 	s := v.String()
 	h.hashLen(len(s))
 	h.bw.WriteString(s)
 	return true
 }
 
-func (h *hasher) hashFloat32v(v reflect.Value) bool {
+func (h *hasher) hashFloat32v(v addressableValue) bool {
 	h.hashUint32(math.Float32bits(float32(v.Float())))
 	return true
 }
 
-func (h *hasher) hashFloat64v(v reflect.Value) bool {
+func (h *hasher) hashFloat64v(v addressableValue) bool {
 	h.hashUint64(math.Float64bits(v.Float()))
 	return true
 }
 
-func (h *hasher) hashComplex64v(v reflect.Value) bool {
+func (h *hasher) hashComplex64v(v addressableValue) bool {
 	c := complex64(v.Complex())
 	h.hashUint32(math.Float32bits(real(c)))
 	h.hashUint32(math.Float32bits(imag(c)))
 	return true
 }
 
-func (h *hasher) hashComplex128v(v reflect.Value) bool {
+func (h *hasher) hashComplex128v(v addressableValue) bool {
 	c := v.Complex()
 	h.hashUint64(math.Float64bits(real(c)))
 	h.hashUint64(math.Float64bits(imag(c)))
@@ -541,15 +560,8 @@
 }
 
 // hashString hashes v, of kind time.Time.
-func (h *hasher) hashTimev(v reflect.Value) bool {
-	var t time.Time
-	if v.CanAddr() {
-		t = *(*time.Time)(v.Addr().UnsafePointer())
-	} else if v.CanInterface() {
-		t = v.Interface().(time.Time)
-	} else {
-		return false
-	}
+func (h *hasher) hashTimev(v addressableValue) bool {
+	t := *(*time.Time)(v.Addr().UnsafePointer())
 	b := t.AppendFormat(h.scratch[:1], time.RFC3339Nano)
 	b[0] = byte(len(b) - 1) // more than sufficient width; if not, good enough.
 	h.bw.Write(b)
@@ -557,7 +569,7 @@
 }
 
 // hashSliceMem hashes v, of kind Slice, with a memhash-able element type.
-func (h *hasher) hashSliceMem(v reflect.Value) bool {
+func (h *hasher) hashSliceMem(v addressableValue) bool {
 	vLen := v.Len()
 	h.hashUint64(uint64(vLen))
 	if vLen == 0 {
@@ -568,20 +580,17 @@
 }
 
 func genHashArrayMem(n int, arraySize uintptr, efu *typeInfo) typeHasherFunc {
-	byElement := genHashArrayElements(n, efu)
-	return func(h *hasher, v reflect.Value) bool {
-		if v.CanAddr() {
-			h.bw.Write(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), arraySize))
-			return true
-		}
-		return byElement(h, v)
+	return func(h *hasher, v addressableValue) bool {
+		h.bw.Write(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), arraySize))
+		return true
 	}
 }
 
 func genHashArrayElements(n int, eti *typeInfo) typeHasherFunc {
-	return func(h *hasher, v reflect.Value) bool {
+	return func(h *hasher, v addressableValue) bool {
 		for i := 0; i < n; i++ {
-			if !eti.hasher()(h, v.Index(i)) {
+			va := addressableValue{v.Index(i)} // element is addressable if parent array is addressable
+			if !eti.hasher()(h, va) {
 				return false
 			}
 		}
@@ -589,7 +598,7 @@
 	}
 }
 
-func noopHasherFunc(h *hasher, v reflect.Value) bool { return true }
+func noopHasherFunc(h *hasher, v addressableValue) bool { return true }
 
 func genHashArray(t reflect.Type, eti *typeInfo) typeHasherFunc {
 	if t.Size() == 0 {
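The hashTimev rewrite above is the direct payoff: with addressability guaranteed, the CanAddr and CanInterface fallbacks go away, and even a time.Time stored in an unexported field can be read through its address (which is why the time_ptr_via_unexported_value test later in this diff gains a real expected output). Roughly the trick in isolation; deephash reaches such fields via reflect.Value.Field rather than a struct literal:

```go
package main

import (
	"fmt"
	"reflect"
	"time"
	"unsafe"
)

type wrapper struct{ t time.Time } // unexported field

func main() {
	w := wrapper{t: time.Unix(0, 0).UTC()}
	v := reflect.ValueOf(&w).Elem().Field(0) // addressable, but read-only

	// v.Interface() would panic on the unexported field, yet the
	// field's memory is still readable through its address.
	tm := *(*time.Time)(v.Addr().UnsafePointer())
	fmt.Println(tm.Format(time.RFC3339Nano)) // 1970-01-01T00:00:00Z
}
```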
@@ -611,11 +620,12 @@ type sliceElementHasher struct {
 	eti *typeInfo
 }
 
-func (seh sliceElementHasher) hash(h *hasher, v reflect.Value) bool {
+func (seh sliceElementHasher) hash(h *hasher, v addressableValue) bool {
 	vLen := v.Len()
 	h.hashUint64(uint64(vLen))
 	for i := 0; i < vLen; i++ {
-		if !seh.eti.hasher()(h, v.Index(i)) {
+		va := addressableValue{v.Index(i)} // slice elements are always addressable
+		if !seh.eti.hasher()(h, va) {
 			return false
 		}
 	}
@@ -768,7 +778,7 @@ func canMemHash(t reflect.Type) bool {
 	return false
 }
 
-func (h *hasher) hashValue(v reflect.Value, forceCycleChecking bool) {
+func (h *hasher) hashValue(v addressableValue, forceCycleChecking bool) {
 	if !v.IsValid() {
 		return
 	}
@@ -776,7 +786,7 @@
 	h.hashValueWithType(v, ti, forceCycleChecking)
 }
 
-func (h *hasher) hashValueWithType(v reflect.Value, ti *typeInfo, forceCycleChecking bool) {
+func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleChecking bool) {
 	w := h.bw
 	doCheckCycles := forceCycleChecking || ti.isRecursive
 
@@ -808,11 +818,13 @@ func (h *hasher) hashValueWithType(v reflect.Value, ti *typeInfo, forceCycleChec
 			defer h.visitStack.pop(ptr)
 		}
 
-		h.hashUint8(1) // indicates visiting a pointer
-		h.hashValueWithType(v.Elem(), ti.elemTypeInfo, doCheckCycles)
+		h.hashUint8(1)                   // indicates visiting a pointer
+		va := addressableValue{v.Elem()} // dereferenced pointer is always addressable
+		h.hashValueWithType(va, ti.elemTypeInfo, doCheckCycles)
 	case reflect.Struct:
 		for i, n := 0, v.NumField(); i < n; i++ {
-			h.hashValue(v.Field(i), doCheckCycles)
+			va := addressableValue{v.Field(i)} // field is addressable if parent struct is addressable
+			h.hashValue(va, doCheckCycles)
 		}
 	case reflect.Slice, reflect.Array:
 		vLen := v.Len()
@@ -825,7 +837,7 @@ func (h *hasher) hashValueWithType(v reflect.Value, ti *typeInfo, forceCycleChec
 			// It seems tempting to do this for all sizes, doing
 			// scratchSize bytes at a time, but reflect.Slice seems
 			// to allocate, so it's not a win.
-			n := reflect.Copy(reflect.ValueOf(&h.scratch).Elem(), v)
+			n := reflect.Copy(reflect.ValueOf(&h.scratch).Elem(), v.Value)
 			w.Write(h.scratch[:n])
 			return
 		}
@@ -836,18 +848,21 @@ func (h *hasher) hashValueWithType(v reflect.Value, ti *typeInfo, forceCycleChec
 			// TODO(dsnet): Perform cycle detection for slices,
 			// which is functionally a list of pointers.
 			// See https://github.com/google/go-cmp/blob/402949e8139bb890c71a707b6faf6dd05c92f4e5/cmp/compare.go#L438-L450
-			h.hashValueWithType(v.Index(i), ti.elemTypeInfo, doCheckCycles)
+			va := addressableValue{v.Index(i)} // slice elements are always addressable
+			h.hashValueWithType(va, ti.elemTypeInfo, doCheckCycles)
 		}
 	case reflect.Interface:
 		if v.IsNil() {
 			h.hashUint8(0) // indicates nil
 			return
 		}
-		v = v.Elem()
+		// TODO: Use a valueCache here?
+		va := newAddressableValue(v.Elem().Type())
+		va.Set(v.Elem())
 
 		h.hashUint8(1) // indicates visiting interface value
-		h.hashType(v.Type())
-		h.hashValue(v, doCheckCycles)
+		h.hashType(va.Type())
+		h.hashValue(va, doCheckCycles)
 	case reflect.Map:
 		// Check for cycle.
 		if doCheckCycles {
@@ -911,12 +926,12 @@ var mapHasherPool = &sync.Pool{
 	New: func() any { return new(mapHasher) },
 }
 
-type valueCache map[reflect.Type]reflect.Value
+type valueCache map[reflect.Type]addressableValue
 
-func (c *valueCache) get(t reflect.Type) reflect.Value {
+func (c *valueCache) get(t reflect.Type) addressableValue {
 	v, ok := (*c)[t]
 	if !ok {
-		v = reflect.New(t).Elem()
+		v = newAddressableValue(t)
 		if *c == nil {
 			*c = make(valueCache)
 		}
@@ -929,12 +944,12 @@
 // It relies on a map being a functionally an unordered set of KV entries.
 // So long as we hash each KV entry together, we can XOR all
 // of the individual hashes to produce a unique hash for the entire map.
-func (h *hasher) hashMap(v reflect.Value, ti *typeInfo, checkCycles bool) {
+func (h *hasher) hashMap(v addressableValue, ti *typeInfo, checkCycles bool) {
 	mh := mapHasherPool.Get().(*mapHasher)
 	defer mapHasherPool.Put(mh)
 
 	iter := &mh.iter
-	iter.Reset(v)
+	iter.Reset(v.Value)
 	defer iter.Reset(reflect.Value{}) // avoid pinning v from mh.iter when we return
 
 	var sum Sum
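The XOR construction described in the comment above is what makes the map hash independent of Go's randomized iteration order: each key/value entry is hashed as a unit, and XOR commutes, so the per-entry order cancels out. In miniature, with a stand-in 64-bit hash instead of the real Sum, and with entry framing elided:

```go
package main

import (
	"fmt"
	"hash/maphash"
)

func hashMap(m map[string]string, seed maphash.Seed) uint64 {
	var h maphash.Hash
	h.SetSeed(seed)
	var sum uint64
	for k, v := range m { // order is randomized per iteration...
		h.Reset()
		h.WriteString(k) // ...but each KV entry is hashed together...
		h.WriteString(v)
		sum ^= h.Sum64() // ...and XOR commutes, so order cannot matter
	}
	return sum
}

func main() {
	seed := maphash.MakeSeed()
	m := map[string]string{"a": "1", "b": "2", "c": "3"}
	fmt.Println(hashMap(m, seed) == hashMap(m, seed)) // always true
}
```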
@@ -983,8 +998,8 @@ func (v visitStack) pop(p pointer) {
 // that would break if Go ever switched to a moving GC.
 type pointer struct{ p unsafe.Pointer }
 
-func pointerOf(v reflect.Value) pointer {
-	return pointer{unsafe.Pointer(v.Pointer())}
+func pointerOf(v addressableValue) pointer {
+	return pointer{unsafe.Pointer(v.Value.Pointer())}
 }
 
 // hashType hashes a reflect.Type.
diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go
index cc06e3945..b4abae568 100644
--- a/util/deephash/deephash_test.go
+++ b/util/deephash/deephash_test.go
@@ -187,8 +187,16 @@ func TestQuick(t *testing.T) {
 	}
 }
 
-func getVal() []any {
-	return []any{
+func getVal() any {
+	return &struct {
+		WGConfig         *wgcfg.Config
+		RouterConfig     *router.Config
+		MapFQDNAddrs     map[dnsname.FQDN][]netip.Addr
+		MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort
+		MapDiscoPublics  map[key.DiscoPublic]bool
+		MapResponse      *tailcfg.MapResponse
+		FilterMatch      filter.Match
+	}{
 		&wgcfg.Config{
 			Name:      "foo",
 			Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)},
@@ -467,7 +475,8 @@
 				a, b int
 				c    uint16
 			}{1, -1, 2},
-			out: "\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\x02\x00",
+			out:   "\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\x02\x00",
+			out32: "\x01\x00\x00\x00\xff\xff\xff\xff\x02\x00",
 		},
 		{
 			name: "nil_int_ptr",
@@ -529,7 +538,7 @@
 		{
 			name: "time_ptr_via_unexported_value",
 			val:  *testtype.NewUnexportedAddressableTime(time.Unix(0, 0).In(time.UTC)),
-			want: false, // neither addressable nor interface-able
+			out:  "\x141970-01-01T00:00:00Z",
 		},
 		{
 			name: "time_custom_zone",
@@ -614,12 +623,14 @@
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			rv := reflect.ValueOf(tt.val)
-			fn := getTypeInfo(rv.Type()).hasher()
+			va := newAddressableValue(rv.Type())
+			va.Set(rv)
+			fn := getTypeInfo(va.Type()).hasher()
 			var buf bytes.Buffer
 			h := &hasher{
 				bw: bufio.NewWriter(&buf),
 			}
-			got := fn(h, rv)
+			got := fn(h, va)
 			const ptrSize = 32 << uintptr(^uintptr(0)>>63)
 			if tt.out32 != "" && ptrSize == 32 {
 				tt.out = tt.out32
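One note on the out32 fields added above: the only case whose encoding differs by platform is the struct with int fields, since int tracks the platform word size. The test's ptrSize constant derives that width at compile time without build tags; the expression is worth unpacking once:

```go
package main

import "fmt"

func main() {
	// ^uintptr(0) is all one-bits. Shifting it right by 63 yields 1 on
	// a 64-bit platform and 0 on a 32-bit one, so ptrSize evaluates to
	// 32<<1 = 64 or 32<<0 = 32.
	const ptrSize = 32 << uintptr(^uintptr(0)>>63)
	fmt.Println(ptrSize) // 64 on amd64/arm64; 32 on 386/arm
}
```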
"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f" if got := got.Bytes(); string(got) != want { @@ -755,7 +766,7 @@ func BenchmarkHashMapAcyclic(b *testing.B) { var buf bytes.Buffer bw := bufio.NewWriter(&buf) - v := reflect.ValueOf(m) + v := addressableValue{reflect.ValueOf(&m).Elem()} ti := getTypeInfo(v.Type()) h := &hasher{bw: bw}