diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go index 10667b0e2..4defffca1 100644 --- a/util/deephash/deephash.go +++ b/util/deephash/deephash.go @@ -180,6 +180,7 @@ var uint8Type = reflect.TypeOf(byte(0)) // typeInfo describes properties of a type. type typeInfo struct { rtype reflect.Type + canMemHash bool isRecursive bool // elemTypeInfo is the element type's typeInfo. @@ -218,6 +219,7 @@ func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *t ti := &typeInfo{ rtype: t, isRecursive: typeIsRecursive(t), + canMemHash: canMemHash(t), } incomplete[t] = ti @@ -311,6 +313,34 @@ func typeIsRecursive(t reflect.Type) bool { return visitType(t) } +// canMemHash reports whether a slice of t can be hashed by looking at its +// contiguous bytes in memory alone. (e.g. structs with gaps aren't memhashable) +func canMemHash(t reflect.Type) bool { + switch t.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float64, reflect.Float32, reflect.Complex128, reflect.Complex64: + return true + case reflect.Array: + return canMemHash(t.Elem()) + case reflect.Struct: + var sumFieldSize uintptr + for i, numField := 0, t.NumField(); i < numField; i++ { + sf := t.Field(i) + if !canMemHash(sf.Type) { + // Special case for 0-width fields that aren't at the end. + if sf.Type.Size() == 0 && i < numField-1 { + continue + } + return false + } + sumFieldSize += sf.Type.Size() + } + return sumFieldSize == t.Size() // else there are gaps + } + return false +} + func (h *hasher) hashValue(v reflect.Value, forceCycleChecking bool) { if !v.IsValid() { return diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index 29c0b425e..643c42806 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -10,6 +10,7 @@ import ( "bytes" "crypto/sha256" "fmt" + "io" "math" "math/rand" "reflect" @@ -314,6 +315,83 @@ func TestTypeIsRecursive(t *testing.T) { } } +type IntThenByte struct { + i int + b byte +} + +type TwoInts struct{ a, b int } + +type IntIntByteInt struct { + i1, i2 int32 + b byte // padding after + i3 int32 +} + +func TestCanMemHash(t *testing.T) { + tests := []struct { + val any + want bool + }{ + {true, true}, + {uint(1), true}, + {uint8(1), true}, + {uint16(1), true}, + {uint32(1), true}, + {uint64(1), true}, + {uintptr(1), true}, + {int(1), true}, + {int8(1), true}, + {int16(1), true}, + {int32(1), true}, + {int64(1), true}, + {float32(1), true}, + {float64(1), true}, + {complex64(1), true}, + {complex128(1), true}, + {[32]byte{}, true}, + {func() {}, false}, + {make(chan int), false}, + {struct{ io.Writer }{nil}, false}, + {unsafe.Pointer(nil), false}, + {new(int), false}, + {TwoInts{}, true}, + {[4]TwoInts{}, true}, + {IntThenByte{}, false}, + {[4]IntThenByte{}, false}, + {tailcfg.PortRange{}, true}, + {int16(0), true}, + {struct { + _ int + _ int + }{}, true}, + {struct { + _ int + _ uint8 + _ int + }{}, false}, // gap + { + struct { + _ structs.Incomparable // if not last, zero-width + x int + }{}, + true, + }, + { + struct { + x int + _ structs.Incomparable // zero-width last: has space, can't memhash + }{}, + false, + }} + for _, tt := range tests { + got := canMemHash(reflect.TypeOf(tt.val)) + if got != tt.want { + t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want) + } + } +} + var sink = Hash("foo") func BenchmarkHash(b *testing.B) {