util/deephash: add IncludeFields, ExcludeFields HasherForType Options

Updates tailscale/corp#6198

Change-Id: Iafc18c5b947522cf07a42a56f35c0319cc7b1c94
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/9083/head
Brad Fitzpatrick 1 year ago committed by Brad Fitzpatrick
parent e7d1538a2d
commit 4af22f3785

@ -23,11 +23,13 @@ import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"reflect"
"sync"
"time"
"tailscale.com/util/hashx"
"tailscale.com/util/set"
)
// There is much overlap between the theory of serialization and hashing.
@ -152,12 +154,90 @@ func Hash[T any](v *T) Sum {
return h.sum()
}
// Option is an optional argument to HasherForType.
type Option interface {
isOption()
}
type fieldFilterOpt struct {
t reflect.Type
fields set.Set[string]
includeOnMatch bool // true to include fields, false to exclude them
}
func (fieldFilterOpt) isOption() {}
func (f fieldFilterOpt) filterStructField(sf reflect.StructField) (include bool) {
if f.fields.Contains(sf.Name) {
return f.includeOnMatch
}
return !f.includeOnMatch
}
// IncludeFields returns an option that modifies the hashing for T to only
// include the named struct fields.
//
// T must be a struct type, and must match the type of the value passed to
// HasherForType.
func IncludeFields[T any](fields ...string) Option {
return newFieldFilter[T](true, fields)
}
// ExcludeFields returns an option that modifies the hashing for T to include
// all struct fields of T except those provided in fields.
//
// T must be a struct type, and must match the type of the value passed to
// HasherForType.
func ExcludeFields[T any](fields ...string) Option {
return newFieldFilter[T](false, fields)
}
func newFieldFilter[T any](include bool, fields []string) Option {
var zero T
t := reflect.TypeOf(&zero).Elem()
fieldSet := set.Set[string]{}
for _, f := range fields {
if _, ok := t.FieldByName(f); !ok {
panic(fmt.Sprintf("unknown field %q for type %v", f, t))
}
fieldSet.Add(f)
}
return fieldFilterOpt{t, fieldSet, include}
}
// HasherForType returns a hash that is specialized for the provided type.
func HasherForType[T any]() func(*T) Sum {
//
// HasherForType panics if the opts are invalid for the provided type.
//
// Currently, at most one option can be provided (IncludeFields or
// ExcludeFields) and its type must match the type of T. Those restrictions may
// be removed in the future, along with documentation about their precedence
// when combined.
func HasherForType[T any](opts ...Option) func(*T) Sum {
var v *T
seedOnce.Do(initSeed)
if len(opts) > 1 {
panic("HasherForType only accepts one optional argument") // for now
}
t := reflect.TypeOf(v).Elem()
hash := lookupTypeHasher(t)
var hash typeHasherFunc
for _, o := range opts {
switch o := o.(type) {
default:
panic(fmt.Sprintf("unknown HasherOpt %T", o))
case fieldFilterOpt:
if t.Kind() != reflect.Struct {
panic("HasherForStructTypeWithFieldFilter requires T of kind struct")
}
if t != o.t {
panic(fmt.Sprintf("field filter for type %v does not match HasherForType type %v", o.t, t))
}
hash = makeStructHasher(t, o.filterStructField)
}
}
if hash == nil {
hash = lookupTypeHasher(t)
}
return func(v *T) (s Sum) {
// This logic is identical to Hash, but pull out a few statements.
h := hasherPool.Get().(*hasher)
@ -225,7 +305,7 @@ func makeTypeHasher(t reflect.Type) typeHasherFunc {
case reflect.Slice:
return makeSliceHasher(t)
case reflect.Struct:
return makeStructHasher(t)
return makeStructHasher(t, keepAllStructFields)
case reflect.Map:
return makeMapHasher(t)
case reflect.Pointer:
@ -353,9 +433,12 @@ func makeSliceHasher(t reflect.Type) typeHasherFunc {
}
}
func makeStructHasher(t reflect.Type) typeHasherFunc {
func keepAllStructFields(keepField reflect.StructField) bool { return true }
func makeStructHasher(t reflect.Type, keepField func(reflect.StructField) bool) typeHasherFunc {
type fieldHasher struct {
idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable
keep bool
hash typeHasherFunc // only valid if idx is not negative
offset uintptr
size uintptr
@ -365,8 +448,8 @@ func makeStructHasher(t reflect.Type) typeHasherFunc {
init := func() {
for i, numField := 0, t.NumField(); i < numField; i++ {
sf := t.Field(i)
f := fieldHasher{i, nil, sf.Offset, sf.Type.Size()}
if typeIsMemHashable(sf.Type) {
f := fieldHasher{i, keepField(sf), nil, sf.Offset, sf.Type.Size()}
if f.keep && typeIsMemHashable(sf.Type) {
f.idx = -1
}
@ -390,6 +473,9 @@ func makeStructHasher(t reflect.Type) typeHasherFunc {
return func(h *hasher, p pointer) {
once.Do(init)
for _, field := range fields {
if !field.keep {
continue
}
pf := p.structField(field.idx, field.offset, field.size)
if field.idx < 0 {
h.HashBytes(pf.asMemory(field.size))

@ -1066,6 +1066,51 @@ func TestAppendTo(t *testing.T) {
}
}
func TestFilterFields(t *testing.T) {
type T struct {
A int
B int
C int
}
hashers := map[string]func(*T) Sum{
"all": HasherForType[T](),
"ac": HasherForType[T](IncludeFields[T]("A", "C")),
"b": HasherForType[T](ExcludeFields[T]("A", "C")),
}
tests := []struct {
hasher string
a, b T
wantEq bool
}{
{"all", T{1, 2, 3}, T{1, 2, 3}, true},
{"all", T{1, 2, 3}, T{0, 2, 3}, false},
{"all", T{1, 2, 3}, T{1, 0, 3}, false},
{"all", T{1, 2, 3}, T{1, 2, 0}, false},
{"ac", T{0, 0, 0}, T{0, 0, 0}, true},
{"ac", T{1, 0, 1}, T{1, 1, 1}, true},
{"ac", T{1, 1, 1}, T{1, 1, 0}, false},
{"b", T{0, 0, 0}, T{0, 0, 0}, true},
{"b", T{1, 0, 1}, T{1, 1, 1}, false},
{"b", T{1, 1, 1}, T{0, 1, 0}, true},
}
for _, tt := range tests {
f, ok := hashers[tt.hasher]
if !ok {
t.Fatalf("bad test: unknown hasher %q", tt.hasher)
}
sum1 := f(&tt.a)
sum2 := f(&tt.b)
got := sum1 == sum2
if got != tt.wantEq {
t.Errorf("hasher %q, for %+v and %v, got equal = %v; want %v", tt.hasher, tt.a, tt.b, got, tt.wantEq)
}
}
}
func BenchmarkAppendTo(b *testing.B) {
b.ReportAllocs()
v := getVal()

Loading…
Cancel
Save