From 31bf3874d6f793943d63ffa2cf0440c57108b338 Mon Sep 17 00:00:00 2001
From: Joe Tsai
Date: Sat, 27 Aug 2022 12:30:35 -0700
Subject: [PATCH] util/deephash: use unsafe.Pointer instead of reflect.Value
 (#5459)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Use of reflect.Value.SetXXX panics if the provided argument was
obtained from an unexported struct field. Instead, pass an
unsafe.Pointer around and convert to a reflect.Value when necessary
(i.e., for maps and interfaces). Converting from unsafe.Pointer to
reflect.Value guarantees that none of the read-only bits will be
populated.

When running in race mode, we attach type information to the pointer
so that we can type check every pointer operation. This also
type-checks that direct memory hashing is within the valid range of a
struct value.

We add test cases that previously caused deephash to panic,
but now pass.

Performance:

name              old time/op  new time/op  delta
Hash              14.1µs ± 1%  14.1µs ± 1%    ~     (p=0.590 n=10+9)
HashPacketFilter  2.53µs ± 2%  2.44µs ± 1%  -3.79%  (p=0.000 n=9+10)
TailcfgNode       1.45µs ± 1%  1.43µs ± 0%  -1.36%  (p=0.000 n=9+9)
HashArray          318ns ± 2%   318ns ± 2%    ~     (p=0.541 n=10+10)
HashMapAcyclic    32.9µs ± 1%  31.6µs ± 1%  -4.16%  (p=0.000 n=10+9)

There is a slight performance gain due to the use of unsafe.Pointer
over reflect.Value methods. Also, passing an unsafe.Pointer (1 word)
on the stack is cheaper than passing a reflect.Value (3 words).
Performance gains are diminishing since SHA-256 hashing now dominates
the runtime.

Signed-off-by: Joe Tsai
---
 util/deephash/deephash.go       | 236 ++++++++++++--------------------
 util/deephash/deephash_test.go  | 105 ++++++++++++--
 util/deephash/pointer.go        | 115 ++++++++++++++++
 util/deephash/pointer_norace.go |  14 ++
 util/deephash/pointer_race.go   | 100 ++++++++++++++
 5 files changed, 408 insertions(+), 162 deletions(-)
 create mode 100644 util/deephash/pointer.go
 create mode 100644 util/deephash/pointer_norace.go
 create mode 100644 util/deephash/pointer_race.go

diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go
index 228473028..1e3cb79b2 100644
--- a/util/deephash/deephash.go
+++ b/util/deephash/deephash.go
@@ -24,11 +24,9 @@ import (
 	"crypto/sha256"
 	"encoding/binary"
 	"encoding/hex"
-	"net/netip"
 	"reflect"
 	"sync"
 	"time"
-	"unsafe"
 
 	"tailscale.com/util/hashx"
 )
@@ -60,19 +58,6 @@ import (
 // theoretically "parsable" by looking up the hash in a magical map that
 // returns the set of entries for that given hash.
 
-// addressableValue is a reflect.Value that is guaranteed to be addressable
-// such that calling the Addr and Set methods do not panic.
-//
-// There is no compile magic that enforces this property,
-// but rather the need to construct this type makes it easier to examine each
-// construction site to ensure that this property is upheld.
-type addressableValue struct{ reflect.Value }
-
-// newAddressableValue constructs a new addressable value of type t.
-func newAddressableValue(t reflect.Type) addressableValue {
-	return addressableValue{reflect.New(t).Elem()} // dereferenced pointer is always addressable
-}
-
 const scratchSize = 128
 
 // hasher is reusable state for hashing a value.
@@ -134,12 +119,16 @@ func Hash(v any) (s Sum) {
 
 	rv := reflect.ValueOf(v)
 	if rv.IsValid() {
-		var va addressableValue
+		var t reflect.Type
+		var p pointer
 		if rv.Kind() == reflect.Pointer && !rv.IsNil() {
-			va = addressableValue{rv.Elem()} // dereferenced pointer is always addressable
+			t = rv.Type().Elem()
+			p = pointerOf(rv)
 		} else {
-			va = newAddressableValue(rv.Type())
+			t = rv.Type()
+			va := reflect.New(t).Elem()
 			va.Set(rv)
+			p = pointerOf(va.Addr())
 		}
 
 		// Always treat the Hash input as an interface (it is), including hashing
@@ -148,9 +137,9 @@ func Hash(v any) (s Sum) {
 		// the same thing that we do for reflect.Kind Interface in hashValue, but
 		// the initial reflect.ValueOf from an interface value effectively strips
 		// the interface box off so we have to do it at the top level by hand.
-		h.hashType(va.Type())
-		ti := getTypeInfo(va.Type())
-		ti.hasher()(h, va)
+		h.hashType(t)
+		ti := getTypeInfo(t)
+		ti.hasher()(h, p)
 	}
 	return h.sum()
 }
@@ -177,14 +166,15 @@ func HasherForType[T any]() func(T) Sum {
 
 		if rv.IsValid() {
 			if rv.Kind() == reflect.Pointer && !rv.IsNil() {
-				va := addressableValue{rv.Elem()} // dereferenced pointer is always addressable
-				h.hashType(va.Type())
-				tiElem.hasher()(h, va)
+				p := pointerOf(rv)
+				h.hashType(t.Elem())
+				tiElem.hasher()(h, p)
 			} else {
-				va := newAddressableValue(rv.Type())
+				va := reflect.New(t).Elem()
 				va.Set(rv)
-				h.hashType(va.Type())
-				ti.hasher()(h, va)
+				p := pointerOf(va.Addr())
+				h.hashType(t)
+				ti.hasher()(h, p)
 			}
 		}
 		return h.sum()
@@ -223,7 +213,10 @@ type typeInfo struct {
 	hashFuncLazy typeHasherFunc // nil until created
 }
 
-type typeHasherFunc func(h *hasher, v addressableValue)
+// typeHasherFunc hashes the value pointed at by p for a given type.
+// For example, if t is a bool, then p is a *bool.
+// The provided pointer must always be non-nil.
+type typeHasherFunc func(h *hasher, p pointer)
 
 var typeInfoMap sync.Map           // map[reflect.Type]*typeInfo
 var typeInfoMapPopulate sync.Mutex // just for adding to typeInfoMap
@@ -289,28 +282,13 @@ type structHasher struct {
 	fields []fieldInfo
 }
 
-func (sh structHasher) hash(h *hasher, v addressableValue) {
-	base := v.Addr().UnsafePointer()
+func (sh structHasher) hash(h *hasher, p pointer) {
 	for _, f := range sh.fields {
+		pf := p.structField(f.index, f.offset, f.size)
 		if f.canMemHash {
-			h.HashBytes(unsafe.Slice((*byte)(unsafe.Pointer(uintptr(base)+f.offset)), f.size))
-			continue
-		}
-		va := addressableValue{v.Field(f.index)} // field is addressable if parent struct is addressable
-		f.typeInfo.hasher()(h, va)
-	}
-}
-
-// genHashPtrToMemoryRange returns a hasher where the reflect.Value is a Ptr to
-// the provided eleType.
-func genHashPtrToMemoryRange(eleType reflect.Type) typeHasherFunc {
-	size := eleType.Size()
-	return func(h *hasher, v addressableValue) {
-		if v.IsNil() {
-			h.HashUint8(0) // indicates nil
+			h.HashBytes(pf.asMemory(f.size))
 		} else {
-			h.HashUint8(1) // indicates visiting a pointer
-			h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), size))
+			f.typeInfo.hasher()(h, pf)
 		}
 	}
 }
@@ -337,7 +315,15 @@ func genTypeHasher(ti *typeInfo) typeHasherFunc {
 	case reflect.Slice:
 		et := t.Elem()
 		if typeIsMemHashable(et) {
-			return (*hasher).hashSliceMem
+			return func(h *hasher, p pointer) {
+				pa := p.sliceArray()
+				vLen := p.sliceLen()
+				h.HashUint64(uint64(vLen))
+				if vLen == 0 {
+					return
+				}
+				h.HashBytes(pa.asMemory(et.Size() * uintptr(vLen)))
+			}
 		}
 		eti := getTypeInfo(et)
 		return genHashSliceElements(eti)
@@ -348,80 +334,79 @@
 	case reflect.Struct:
 		return genHashStructFields(t)
 	case reflect.Map:
-		return func(h *hasher, v addressableValue) {
+		return func(h *hasher, p pointer) {
+			v := p.asValue(t).Elem() // reflect.Map kind
 			if v.IsNil() {
 				h.HashUint8(0) // indicates nil
 				return
 			}
 			if ti.isRecursive {
-				ptr := pointerOf(v)
-				if idx, ok := h.visitStack.seen(ptr); ok {
+				pm := v.UnsafePointer() // underlying pointer of map
+				if idx, ok := h.visitStack.seen(pm); ok {
 					h.HashUint8(2) // indicates cycle
 					h.HashUint64(uint64(idx))
 					return
 				}
-				h.visitStack.push(ptr)
-				defer h.visitStack.pop(ptr)
+				h.visitStack.push(pm)
+				defer h.visitStack.pop(pm)
 			}
 			h.HashUint8(1) // indicates visiting a map
-			h.hashMap(v, ti, ti.isRecursive)
+			h.hashMap(v, ti)
 		}
 	case reflect.Pointer:
 		et := t.Elem()
-		if typeIsMemHashable(et) {
-			return genHashPtrToMemoryRange(et)
-		}
 		eti := getTypeInfo(et)
-		return func(h *hasher, v addressableValue) {
-			if v.IsNil() {
+		return func(h *hasher, p pointer) {
+			pe := p.pointerElem()
+			if pe.isNil() {
 				h.HashUint8(0) // indicates nil
 				return
 			}
 			if ti.isRecursive {
-				ptr := pointerOf(v)
-				if idx, ok := h.visitStack.seen(ptr); ok {
+				if idx, ok := h.visitStack.seen(pe.p); ok {
 					h.HashUint8(2) // indicates cycle
 					h.HashUint64(uint64(idx))
 					return
 				}
-				h.visitStack.push(ptr)
-				defer h.visitStack.pop(ptr)
+				h.visitStack.push(pe.p)
+				defer h.visitStack.pop(pe.p)
 			}
-			h.HashUint8(1)                   // indicates visiting a pointer
-			va := addressableValue{v.Elem()} // dereferenced pointer is always addressable
-			eti.hasher()(h, va)
+			h.HashUint8(1) // indicates visiting a pointer
+			eti.hasher()(h, pe)
 		}
 	case reflect.Interface:
-		return func(h *hasher, v addressableValue) {
+		return func(h *hasher, p pointer) {
+			v := p.asValue(t).Elem() // reflect.Interface kind
 			if v.IsNil() {
 				h.HashUint8(0) // indicates nil
 				return
 			}
-			va := newAddressableValue(v.Elem().Type())
-			va.Set(v.Elem())
-
-			h.HashUint8(1) // indicates visiting interface value
-			h.hashType(va.Type())
-			ti := getTypeInfo(va.Type())
-			ti.hasher()(h, va)
+			h.HashUint8(1) // visiting interface
+			v = v.Elem()
+			t := v.Type()
+			h.hashType(t)
+			va := reflect.New(t).Elem()
+			va.Set(v)
+			ti := getTypeInfo(t)
+			ti.hasher()(h, pointerOf(va.Addr()))
 		}
 	default: // Func, Chan, UnsafePointer
-		return noopHasherFunc
+		return func(*hasher, pointer) {}
 	}
 }
 
-func (h *hasher) hashString(v addressableValue) {
-	s := v.String()
+func (h *hasher) hashString(p pointer) {
+	s := *p.asString()
 	h.HashUint64(uint64(len(s)))
 	h.HashString(s)
 }
 
 // hashTimev hashes v, of kind time.Time.
-func (h *hasher) hashTimev(v addressableValue) {
+func (h *hasher) hashTimev(p pointer) {
 	// Include the zone offset (but not the name) to keep
 	// Hash(t1) == Hash(t2) being semantically equivalent to
 	// t1.Format(time.RFC3339Nano) == t2.Format(time.RFC3339Nano).
-	t := *(*time.Time)(v.Addr().UnsafePointer())
+	t := *p.asTime()
 	_, offset := t.Zone()
 	h.HashUint64(uint64(t.Unix()))
 	h.HashUint32(uint32(t.Nanosecond()))
@@ -429,11 +414,11 @@
 }
 
 // hashAddrv hashes v, of type netip.Addr.
-func (h *hasher) hashAddrv(v addressableValue) {
+func (h *hasher) hashAddrv(p pointer) {
 	// The formatting of netip.Addr covers the
 	// IP version, the address, and the optional zone name (for v6).
 	// This is equivalent to a1.MarshalBinary() == a2.MarshalBinary().
-	ip := *(*netip.Addr)(v.Addr().UnsafePointer())
+	ip := *p.asAddr()
 	switch {
 	case !ip.IsValid():
 		h.HashUint64(0)
@@ -452,46 +437,22 @@
 }
 
 func makeMemHasher(n uintptr) typeHasherFunc {
-	return func(h *hasher, v addressableValue) {
-		h.HashBytes(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), n))
-	}
-}
-
-// hashSliceMem hashes v, of kind Slice, with a memhash-able element type.
-func (h *hasher) hashSliceMem(v addressableValue) {
-	vLen := v.Len()
-	h.HashUint64(uint64(vLen))
-	if vLen == 0 {
-		return
-	}
-	h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), v.Type().Elem().Size()*uintptr(vLen)))
-}
-
-func genHashArrayMem(n int, arraySize uintptr, efu *typeInfo) typeHasherFunc {
-	return func(h *hasher, v addressableValue) {
-		h.HashBytes(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), arraySize))
+	return func(h *hasher, p pointer) {
+		h.HashBytes(p.asMemory(n))
 	}
 }
 
 func genHashArrayElements(n int, eti *typeInfo) typeHasherFunc {
-	return func(h *hasher, v addressableValue) {
+	nb := eti.rtype.Size() // byte size of each array element
+	return func(h *hasher, p pointer) {
 		for i := 0; i < n; i++ {
-			va := addressableValue{v.Index(i)} // element is addressable if parent array is addressable
-			eti.hasher()(h, va)
+			pe := p.arrayIndex(i, nb)
+			eti.hasher()(h, pe)
 		}
 	}
 }
 
-func noopHasherFunc(h *hasher, v addressableValue) {}
-
 func genHashArray(t reflect.Type, eti *typeInfo) typeHasherFunc {
-	if t.Size() == 0 {
-		return noopHasherFunc
-	}
-	et := t.Elem()
-	if typeIsMemHashable(et) {
-		return genHashArrayMem(t.Len(), t.Size(), eti)
-	}
 	n := t.Len()
 	return genHashArrayElements(n, eti)
 }
@@ -504,12 +465,14 @@ type sliceElementHasher struct {
 	eti *typeInfo
 }
 
-func (seh sliceElementHasher) hash(h *hasher, v addressableValue) {
-	vLen := v.Len()
+func (seh sliceElementHasher) hash(h *hasher, p pointer) {
+	pa := p.sliceArray()
+	vLen := p.sliceLen()
 	h.HashUint64(uint64(vLen))
+	nb := seh.eti.rtype.Size()
 	for i := 0; i < vLen; i++ {
-		va := addressableValue{v.Index(i)} // slice elements are always addressable
-		seh.eti.hasher()(h, va)
+		pe := pa.arrayIndex(i, nb)
+		seh.eti.hasher()(h, pe)
 	}
 }
 
@@ -560,12 +523,12 @@ var mapHasherPool = &sync.Pool{
 	New: func() any { return new(mapHasher) },
 }
 
-type valueCache map[reflect.Type]addressableValue
+type valueCache map[reflect.Type]reflect.Value
 
-func (c *valueCache) get(t reflect.Type) addressableValue {
+func (c *valueCache) get(t reflect.Type) reflect.Value {
 	v, ok := (*c)[t]
 	if !ok {
-		v = newAddressableValue(t)
+		v = reflect.New(t).Elem()
 		if *c == nil {
 			*c = make(valueCache)
 		}
@@ -578,7 +541,7 @@
 // It relies on a map functionally being an unordered set of KV entries.
 // So long as we hash each KV entry together, we can XOR all
 // of the individual hashes to produce a unique hash for the entire map.
-func (h *hasher) hashMap(v addressableValue, ti *typeInfo, checkCycles bool) {
+func (h *hasher) hashMap(v reflect.Value, ti *typeInfo) {
 	mh := mapHasherPool.Get().(*mapHasher)
 	defer mapHasherPool.Put(mh)
 
@@ -594,44 +557,13 @@ func (h *hasher) hashMap(v addressableValue, ti *typeInfo, checkCycles bool) {
 		k.SetIterKey(iter)
 		e.SetIterValue(iter)
 		mh.h.Reset()
-		ti.keyTypeInfo.hasher()(&mh.h, k)
-		ti.elemTypeInfo.hasher()(&mh.h, e)
+		ti.keyTypeInfo.hasher()(&mh.h, pointerOf(k.Addr()))
+		ti.elemTypeInfo.hasher()(&mh.h, pointerOf(e.Addr()))
 		sum.xor(mh.h.sum())
 	}
 	h.HashBytes(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation
 }
 
-// visitStack is a stack of pointers visited.
-// Pointers are pushed onto the stack when visited, and popped when leaving.
-// The integer value is the depth at which the pointer was visited.
-// The length of this stack should be zero after every hashing operation.
-type visitStack map[pointer]int
-
-func (v visitStack) seen(p pointer) (int, bool) {
-	idx, ok := v[p]
-	return idx, ok
-}
-
-func (v *visitStack) push(p pointer) {
-	if *v == nil {
-		*v = make(map[pointer]int)
-	}
-	(*v)[p] = len(*v)
-}
-
-func (v visitStack) pop(p pointer) {
-	delete(v, p)
-}
-
-// pointer is a thin wrapper over unsafe.Pointer.
-// We only rely on comparability of pointers; we cannot rely on uintptr since
-// that would break if Go ever switched to a moving GC.
-type pointer struct{ p unsafe.Pointer }
-
-func pointerOf(v addressableValue) pointer {
-	return pointer{unsafe.Pointer(v.Value.Pointer())}
-}
-
 // hashType hashes a reflect.Type.
 // The hash is only consistent within the lifetime of a program.
 func (h *hasher) hashType(t reflect.Type) {
diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go
index a2b2e0492..235470914 100644
--- a/util/deephash/deephash_test.go
+++ b/util/deephash/deephash_test.go
@@ -20,6 +20,7 @@ import (
 	"testing/quick"
 	"time"
 
+	qt "github.com/frankban/quicktest"
 	"go4.org/mem"
 	"go4.org/netipx"
 	"tailscale.com/tailcfg"
@@ -572,13 +573,13 @@ func TestGetTypeHasher(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			rv := reflect.ValueOf(tt.val)
-			va := newAddressableValue(rv.Type())
+			va := reflect.New(rv.Type()).Elem()
 			va.Set(rv)
 			fn := getTypeInfo(va.Type()).hasher()
 			hb := &hashBuffer{Hash: sha256.New()}
 			h := new(hasher)
 			h.Block512.Hash = hb
-			fn(h, va)
+			fn(h, pointerOf(va.Addr()))
 			const ptrSize = 32 << uintptr(^uintptr(0)>>63)
 			if tt.out32 != "" && ptrSize == 32 {
 				tt.out = tt.out32
			}
@@ -591,6 +592,90 @@ func TestGetTypeHasher(t *testing.T) {
 	}
 }
 
+func TestMapCycle(t *testing.T) {
+	type M map[string]M
+	c := qt.New(t)
+
+	a := make(M) // cyclic graph of 1 node
+	a["self"] = a
+	b := make(M) // cyclic graph of 1 node
+	b["self"] = b
+	ha := Hash(a)
+	hb := Hash(b)
+	c.Assert(ha, qt.Equals, hb)
+
+	c1 := make(M) // cyclic graph of 2 nodes
+	c2 := make(M) // cyclic graph of 2 nodes
+	c1["peer"] = c2
+	c2["peer"] = c1
+	hc1 := Hash(c1)
+	hc2 := Hash(c2)
+	c.Assert(hc1, qt.Equals, hc2)
+	c.Assert(ha, qt.Not(qt.Equals), hc1)
+	c.Assert(hb, qt.Not(qt.Equals), hc2)
+
+	c3 := make(M) // graph of 1 node pointing to cyclic graph of 2 nodes
+	c3["child"] = c1
+	hc3 := Hash(c3)
+	c.Assert(hc1, qt.Not(qt.Equals), hc3)
+}
+
+func TestPointerCycle(t *testing.T) {
+	type P *P
+	c := qt.New(t)
+
+	a := new(P) // cyclic graph of 1 node
+	*a = a
+	b := new(P) // cyclic graph of 1 node
+	*b = b
+	ha := Hash(&a)
+	hb := Hash(&b)
+	c.Assert(ha, qt.Equals, hb)
+
+	c1 := new(P) // cyclic graph of 2 nodes
+	c2 := new(P) // cyclic graph of 2 nodes
+	*c1 = c2
+	*c2 = c1
+	hc1 := Hash(&c1)
+	hc2 := Hash(&c2)
+	c.Assert(hc1, qt.Equals, hc2)
+	c.Assert(ha, qt.Not(qt.Equals), hc1)
+	c.Assert(hb, qt.Not(qt.Equals), hc2)
+
+	c3 := new(P) // graph of 1 node pointing to cyclic graph of 2 nodes
+	*c3 = c1
+	hc3 := Hash(&c3)
+	c.Assert(hc1, qt.Not(qt.Equals), hc3)
+}
+
+func TestInterfaceCycle(t *testing.T) {
+	type I struct{ v any }
+	c := qt.New(t)
+
+	a := new(I) // cyclic graph of 1 node
+	a.v = a
+	b := new(I) // cyclic graph of 1 node
+	b.v = b
+	ha := Hash(&a)
+	hb := Hash(&b)
+	c.Assert(ha, qt.Equals, hb)
+
+	c1 := new(I) // cyclic graph of 2 nodes
+	c2 := new(I) // cyclic graph of 2 nodes
+	c1.v = c2
+	c2.v = c1
+	hc1 := Hash(&c1)
+	hc2 := Hash(&c2)
+	c.Assert(hc1, qt.Equals, hc2)
+	c.Assert(ha, qt.Not(qt.Equals), hc1)
+	c.Assert(hb, qt.Not(qt.Equals), hc2)
+
+	c3 := new(I) // graph of 1 node pointing to cyclic graph of 2 nodes
+	c3.v = c1
+	hc3 := Hash(&c3)
+	c.Assert(hc1, qt.Not(qt.Equals), hc3)
+}
+
 var sink Sum
 
 func BenchmarkHash(b *testing.B) {
@@ -665,11 +750,11 @@ func TestHashMapAcyclic(t *testing.T) {
 	ti := getTypeInfo(reflect.TypeOf(m))
 
 	for i := 0; i < 20; i++ {
-		v := addressableValue{reflect.ValueOf(&m).Elem()}
+		v := reflect.ValueOf(&m).Elem()
 		hb.Reset()
 		h := new(hasher)
 		h.Block512.Hash = hb
-		h.hashMap(v, ti, false)
+		h.hashMap(v, ti)
 		h.sum()
 		if got[string(hb.B)] {
 			continue
@@ -689,9 +774,9 @@ func TestPrintArray(t *testing.T) {
 	hb := &hashBuffer{Hash: sha256.New()}
 	h := new(hasher)
 	h.Block512.Hash = hb
-	v := addressableValue{reflect.ValueOf(&x).Elem()}
-	ti := getTypeInfo(v.Type())
-	ti.hasher()(h, v)
+	va := reflect.ValueOf(&x).Elem()
+	ti := getTypeInfo(va.Type())
+	ti.hasher()(h, pointerOf(va.Addr()))
 	h.sum()
 	const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f"
 	if got := hb.B; string(got) != want {
@@ -707,15 +792,15 @@ func BenchmarkHashMapAcyclic(b *testing.B) {
 	}
 
 	hb := &hashBuffer{Hash: sha256.New()}
-	v := addressableValue{reflect.ValueOf(&m).Elem()}
-	ti := getTypeInfo(v.Type())
+	va := reflect.ValueOf(&m).Elem()
+	ti := getTypeInfo(va.Type())
 
 	h := new(hasher)
 	h.Block512.Hash = hb
 
 	for i := 0; i < b.N; i++ {
 		h.Reset()
-		h.hashMap(v, ti, false)
+		h.hashMap(va, ti)
 	}
 }
diff --git a/util/deephash/pointer.go b/util/deephash/pointer.go
new file mode 100644
index 000000000..2fc5ab54a
--- /dev/null
+++ b/util/deephash/pointer.go
@@ -0,0 +1,115 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package deephash
+
+import (
+	"net/netip"
+	"reflect"
+	"time"
+	"unsafe"
+)
+
+// unsafePointer is an untyped pointer.
+// It is the caller's responsibility to call operations on the correct type.
+//
+// This pointer only ever points to a small set of kinds or types:
+// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface,
+// or a pointer to memory that is directly hashable.
+//
+// Arrays are represented as pointers to the first element.
+// Structs are represented as pointers to the first field.
+// Slices are represented as pointers to a slice header.
+// Pointers are represented as pointers to a pointer.
+//
+// We do not support direct operations on maps and interfaces, and instead
+// rely on pointer.asValue to convert the pointer back to a reflect.Value.
+// Conversion of an unsafe.Pointer to reflect.Value guarantees that the
+// read-only flag in the reflect.Value is unpopulated, avoiding panics that may
+// otherwise have occurred since the value was obtained from an unexported field.
+type unsafePointer struct{ p unsafe.Pointer }
+
+func unsafePointerOf(v reflect.Value) unsafePointer {
+	return unsafePointer{v.UnsafePointer()}
+}
+func (p unsafePointer) isNil() bool {
+	return p.p == nil
+}
+
+// pointerElem dereferences a pointer.
+// p must point to a pointer.
+func (p unsafePointer) pointerElem() unsafePointer {
+	return unsafePointer{*(*unsafe.Pointer)(p.p)}
+}
+
+// sliceLen returns the slice length.
+// p must point to a slice.
+func (p unsafePointer) sliceLen() int {
+	return (*reflect.SliceHeader)(p.p).Len
+}
+
+// sliceArray returns a pointer to the underlying slice array.
+// p must point to a slice.
+func (p unsafePointer) sliceArray() unsafePointer {
+	return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)}
+}
+
+// arrayIndex returns a pointer to an element in the array.
+// p must point to an array.
+func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer {
+	return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)}
+}
+
+// structField returns a pointer to a field in a struct.
+// p must point to a struct.
+func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer {
+	return unsafePointer{unsafe.Add(p.p, offset)}
+}
+
+// asString casts p as a *string.
+func (p unsafePointer) asString() *string {
+	return (*string)(p.p)
+}
+
+// asTime casts p as a *time.Time.
+func (p unsafePointer) asTime() *time.Time {
+	return (*time.Time)(p.p)
+}
+
+// asAddr casts p as a *netip.Addr.
+func (p unsafePointer) asAddr() *netip.Addr {
+	return (*netip.Addr)(p.p)
+}
+
+// asValue casts p as a reflect.Value containing a pointer to a value of typ.
+func (p unsafePointer) asValue(typ reflect.Type) reflect.Value {
+	return reflect.NewAt(typ, p.p)
+}
+
+// asMemory returns the memory pointed at by p, for a specified size.
+func (p unsafePointer) asMemory(size uintptr) []byte {
+	return unsafe.Slice((*byte)(p.p), size)
+}
+
+// visitStack is a stack of pointers visited.
+// Pointers are pushed onto the stack when visited, and popped when leaving.
+// The integer value is the depth at which the pointer was visited.
+// The length of this stack should be zero after every hashing operation.
+type visitStack map[unsafe.Pointer]int
+
+func (v visitStack) seen(p unsafe.Pointer) (int, bool) {
+	idx, ok := v[p]
+	return idx, ok
+}
+
+func (v *visitStack) push(p unsafe.Pointer) {
+	if *v == nil {
+		*v = make(map[unsafe.Pointer]int)
+	}
+	(*v)[p] = len(*v)
+}
+
+func (v visitStack) pop(p unsafe.Pointer) {
+	delete(v, p)
+}
diff --git a/util/deephash/pointer_norace.go b/util/deephash/pointer_norace.go
new file mode 100644
index 000000000..19d7f543a
--- /dev/null
+++ b/util/deephash/pointer_norace.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !race
+
+package deephash
+
+import "reflect"
+
+type pointer = unsafePointer
+
+// pointerOf returns a pointer from v, which must be a reflect.Pointer.
+func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) }
diff --git a/util/deephash/pointer_race.go b/util/deephash/pointer_race.go
new file mode 100644
index 000000000..477fad7be
--- /dev/null
+++ b/util/deephash/pointer_race.go
@@ -0,0 +1,100 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build race
+
+package deephash
+
+import (
+	"fmt"
+	"net/netip"
+	"reflect"
+	"time"
+)
+
+// pointer is a typed pointer that performs safety checks for every operation.
+type pointer struct {
+	unsafePointer
+	t reflect.Type // type of pointed-at value; may be nil
+	n uintptr      // size of valid memory after p
+}
+
+// pointerOf returns a pointer from v, which must be a reflect.Pointer.
+func pointerOf(v reflect.Value) pointer {
+	assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind())
+	te := v.Type().Elem()
+	return pointer{unsafePointerOf(v), te, te.Size()}
+}
+
+func (p pointer) pointerElem() pointer {
+	assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind())
+	te := p.t.Elem()
+	return pointer{p.unsafePointer.pointerElem(), te, te.Size()}
+}
+
+func (p pointer) sliceLen() int {
+	assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind())
+	return p.unsafePointer.sliceLen()
+}
+
+func (p pointer) sliceArray() pointer {
+	assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind())
+	n := p.sliceLen()
+	assert(n >= 0, "got negative slice length %d", n)
+	ta := reflect.ArrayOf(n, p.t.Elem())
+	return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()}
+}
+
+func (p pointer) arrayIndex(index int, size uintptr) pointer {
+	assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind())
+	assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index)
+	assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size)
+	te := p.t.Elem()
+	return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()}
+}
+
+func (p pointer) structField(index int, offset, size uintptr) pointer {
+	assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind())
+	assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset)
+	assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size)
+	if index < 0 {
+		return pointer{p.unsafePointer.structField(index, offset, size), nil, size}
+	}
+	sf := p.t.Field(index)
+	t := sf.Type
+	assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset)
+	assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size)
+	return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()}
+}
+
+func (p pointer) asString() *string {
+	assert(p.t.Kind() == reflect.String, "got %v, want string", p.t)
+	return p.unsafePointer.asString()
+}
+
+func (p pointer) asTime() *time.Time {
+	assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType)
+	return p.unsafePointer.asTime()
+}
+
+func (p pointer) asAddr() *netip.Addr {
+	assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType)
+	return p.unsafePointer.asAddr()
+}
+
+func (p pointer) asValue(typ reflect.Type) reflect.Value {
+	assert(p.t == typ, "got %v, want %v", p.t, typ)
+	return p.unsafePointer.asValue(typ)
+}
+
+func (p pointer) asMemory(size uintptr) []byte {
+	assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size)
+	return p.unsafePointer.asMemory(size)
+}
+
+func assert(b bool, f string, a ...any) {
+	if !b {
+		panic(fmt.Sprintf(f, a...))
+	}
+}
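
For readers unfamiliar with the reflect mechanics that motivate this change,
the panic can be reproduced in isolation. The following standalone sketch is
illustrative only (it is not part of the patch; the names wrapper and secret
are invented for the example). It shows the read-only flag that reflect
attaches to values reached through unexported struct fields, and the
reflect.NewAt conversion that pointer.asValue relies on to avoid it:

	// Sketch: why going through unsafe.Pointer and reflect.NewAt avoids
	// the read-only flag that makes reflect.Value.Set panic.
	package main

	import (
		"fmt"
		"reflect"
		"unsafe"
	)

	type wrapper struct{ secret string } // unexported field

	func main() {
		w := wrapper{secret: "hi"}

		// rv is addressable (it was reached through a pointer), but it
		// carries the read-only flag because it came from an unexported
		// struct field.
		rv := reflect.ValueOf(&w).Elem().Field(0)

		// The old reflect.Value-based approach tripped over that flag:
		//
		//	va := reflect.New(rv.Type()).Elem()
		//	va.Set(rv) // panics: value obtained using unexported field
		//
		// rv.Interface() panics for the same reason.

		// The route this patch takes: drop down to the raw address and
		// rebuild a reflect.Value with reflect.NewAt, which never
		// populates the read-only bits.
		p := unsafe.Pointer(rv.Addr().UnsafePointer())
		v := reflect.NewAt(rv.Type(), p).Elem()
		fmt.Println(v.Interface()) // prints "hi", no panic
	}

Direct memory hashing (asMemory) reads through the same unsafe.Pointer
without constructing a reflect.Value at all; the conversion back through
reflect.NewAt is only needed for maps and interfaces, as noted in pointer.go.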