diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go index 36e3115dc..892ca2ad4 100644 --- a/util/deephash/deephash.go +++ b/util/deephash/deephash.go @@ -364,7 +364,7 @@ func genHashStructFields(t reflect.Type) typeHasherFunc { fields = append(fields, fieldInfo{ index: i, typeInfo: getTypeInfo(sf.Type), - canMemHash: canMemHash(sf.Type), + canMemHash: typeIsMemHashable(sf.Type), offset: sf.Offset, size: sf.Type.Size(), }) @@ -445,7 +445,7 @@ func genTypeHasher(t reflect.Type) typeHasherFunc { return (*hasher).hashString case reflect.Slice: et := t.Elem() - if canMemHash(et) { + if typeIsMemHashable(et) { return (*hasher).hashSliceMem } eti := getTypeInfo(et) @@ -464,7 +464,7 @@ func genTypeHasher(t reflect.Type) typeHasherFunc { return genHashStructFields(t) case reflect.Pointer: et := t.Elem() - if canMemHash(et) { + if typeIsMemHashable(et) { return genHashPtrToMemoryRange(et) } if t.Implements(appenderToType) { @@ -574,7 +574,7 @@ func genHashArray(t reflect.Type, eti *typeInfo) typeHasherFunc { return noopHasherFunc } et := t.Elem() - if canMemHash(et) { + if typeIsMemHashable(et) { return genHashArrayMem(t.Len(), t.Size(), eti) } n := t.Len() @@ -625,7 +625,7 @@ func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *t ti := &typeInfo{ rtype: t, isRecursive: typeIsRecursive(t), - canMemHash: canMemHash(t), + canMemHash: typeIsMemHashable(t), } incomplete[t] = ti @@ -640,89 +640,6 @@ func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *t return ti } -// typeIsRecursive reports whether t has a path back to itself. -// -// For interfaces, it currently always reports true. -func typeIsRecursive(t reflect.Type) bool { - inStack := map[reflect.Type]bool{} - - var visitType func(t reflect.Type) (isRecursiveSoFar bool) - visitType = func(t reflect.Type) (isRecursiveSoFar bool) { - // Check whether we have seen this type before. - if inStack[t] { - return true - } - inStack[t] = true - defer func() { - delete(inStack, t) - }() - - // Any type that is memory hashable must not be recursive since - // cycles can only occur if pointers are involved. - if canMemHash(t) { - return false - } - - // Recursively check types that may contain pointers. - switch t.Kind() { - default: - panic("unhandled kind " + t.Kind().String()) - case reflect.String, reflect.UnsafePointer, reflect.Func: - return false - case reflect.Interface: - // Assume the worst for now. TODO(bradfitz): in some cases - // we should be able to prove that it's not recursive. Not worth - // it for now. - return true - case reflect.Array, reflect.Chan, reflect.Pointer, reflect.Slice: - return visitType(t.Elem()) - case reflect.Map: - return visitType(t.Key()) || visitType(t.Elem()) - case reflect.Struct: - if t.String() == "intern.Value" { - // Otherwise its interface{} makes this return true. - return false - } - for i, numField := 0, t.NumField(); i < numField; i++ { - if visitType(t.Field(i).Type) { - return true - } - } - return false - } - } - return visitType(t) -} - -// canMemHash reports whether a slice of t can be hashed by looking at its -// contiguous bytes in memory alone. (e.g. structs with gaps aren't memhashable) -func canMemHash(t reflect.Type) bool { - if t.Size() == 0 { - return true - } - switch t.Kind() { - case reflect.Bool, - reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, - reflect.Float32, reflect.Float64, - reflect.Complex64, reflect.Complex128: - return true - case reflect.Array: - return canMemHash(t.Elem()) - case reflect.Struct: - var sumFieldSize uintptr - for i, numField := 0, t.NumField(); i < numField; i++ { - sf := t.Field(i) - if !canMemHash(sf.Type) { - return false - } - sumFieldSize += sf.Type.Size() - } - return sumFieldSize == t.Size() // ensure no gaps - } - return false -} - func (h *hasher) hashValue(v addressableValue, forceCycleChecking bool) { if !v.IsValid() { return diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index 7f0b5f419..2ed6dac10 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -10,7 +10,6 @@ import ( "encoding/binary" "fmt" "hash" - "io" "math" "math/rand" "net/netip" @@ -19,14 +18,12 @@ import ( "testing" "testing/quick" "time" - "unsafe" "go4.org/mem" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/ipproto" "tailscale.com/types/key" - "tailscale.com/types/structs" "tailscale.com/util/deephash/testtype" "tailscale.com/util/dnsname" "tailscale.com/version" @@ -289,41 +286,6 @@ func getVal() any { } } -func TestTypeIsRecursive(t *testing.T) { - type RecursiveStruct struct { - v *RecursiveStruct - } - type RecursiveChan chan *RecursiveChan - - tests := []struct { - val any - want bool - }{ - {val: 42, want: false}, - {val: "string", want: false}, - {val: 1 + 2i, want: false}, - {val: struct{}{}, want: false}, - {val: (*RecursiveStruct)(nil), want: true}, - {val: RecursiveStruct{}, want: true}, - {val: time.Unix(0, 0), want: false}, - {val: structs.Incomparable{}, want: false}, // ignore its [0]func() - {val: tailcfg.NetPortRange{}, want: false}, // uses structs.Incomparable - {val: (*tailcfg.Node)(nil), want: false}, - {val: map[string]bool{}, want: false}, - {val: func() {}, want: false}, - {val: make(chan int), want: false}, - {val: unsafe.Pointer(nil), want: false}, - {val: make(RecursiveChan), want: true}, - {val: make(chan int), want: false}, - } - for _, tt := range tests { - got := typeIsRecursive(reflect.TypeOf(tt.val)) - if got != tt.want { - t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want) - } - } -} - type IntThenByte struct { i int b byte @@ -337,68 +299,6 @@ type IntIntByteInt struct { i3 int32 } -func TestCanMemHash(t *testing.T) { - tests := []struct { - val any - want bool - }{ - {true, true}, - {uint(1), true}, - {uint8(1), true}, - {uint16(1), true}, - {uint32(1), true}, - {uint64(1), true}, - {uintptr(1), true}, - {int(1), true}, - {int8(1), true}, - {int16(1), true}, - {int32(1), true}, - {int64(1), true}, - {float32(1), true}, - {float64(1), true}, - {complex64(1), true}, - {complex128(1), true}, - {[32]byte{}, true}, - {func() {}, false}, - {make(chan int), false}, - {struct{ io.Writer }{nil}, false}, - {unsafe.Pointer(nil), false}, - {new(int), false}, - {TwoInts{}, true}, - {[4]TwoInts{}, true}, - {IntThenByte{}, false}, - {[4]IntThenByte{}, false}, - {tailcfg.PortRange{}, true}, - {int16(0), true}, - {struct { - _ int - _ int - }{}, true}, - {struct { - _ int - _ uint8 - _ int - }{}, false}, // gap - {struct { - _ structs.Incomparable // if not last, zero-width - x int - }{}, true}, - {struct { - x int - _ structs.Incomparable // zero-width last: has space, can't memhash - }{}, - false}, - {[0]chan bool{}, true}, - {struct{ f [0]func() }{}, true}, - } - for _, tt := range tests { - got := canMemHash(reflect.TypeOf(tt.val)) - if got != tt.want { - t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want) - } - } -} - func u8(n uint8) string { return string([]byte{n}) } func u16(n uint16) string { return string(binary.LittleEndian.AppendUint16(nil, n)) } func u32(n uint32) string { return string(binary.LittleEndian.AppendUint32(nil, n)) } diff --git a/util/deephash/types.go b/util/deephash/types.go new file mode 100644 index 000000000..897f05a46 --- /dev/null +++ b/util/deephash/types.go @@ -0,0 +1,88 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package deephash + +import "reflect" + +// typeIsMemHashable reports whether t can be hashed by directly hashing its +// contiguous bytes in memory (e.g. structs with gaps are not mem-hashable). +func typeIsMemHashable(t reflect.Type) bool { + if t.Size() == 0 { + return true + } + switch t.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, + reflect.Complex64, reflect.Complex128: + return true + case reflect.Array: + return typeIsMemHashable(t.Elem()) + case reflect.Struct: + var sumFieldSize uintptr + for i, numField := 0, t.NumField(); i < numField; i++ { + sf := t.Field(i) + if !typeIsMemHashable(sf.Type) { + return false + } + sumFieldSize += sf.Type.Size() + } + return sumFieldSize == t.Size() // ensure no gaps + } + return false +} + +// typeIsRecursive reports whether t has a path back to itself. +// For interfaces, it currently always reports true. +func typeIsRecursive(t reflect.Type) bool { + inStack := map[reflect.Type]bool{} + var visitType func(t reflect.Type) (isRecursiveSoFar bool) + visitType = func(t reflect.Type) (isRecursiveSoFar bool) { + // Check whether we have seen this type before. + if inStack[t] { + return true + } + inStack[t] = true + defer func() { + delete(inStack, t) + }() + + // Any type that is memory hashable must not be recursive since + // cycles can only occur if pointers are involved. + if typeIsMemHashable(t) { + return false + } + + // Recursively check types that may contain pointers. + switch t.Kind() { + default: + panic("unhandled kind " + t.Kind().String()) + case reflect.String, reflect.UnsafePointer, reflect.Func: + return false + case reflect.Interface: + // Assume the worst for now. TODO(bradfitz): in some cases + // we should be able to prove that it's not recursive. Not worth + // it for now. + return true + case reflect.Array, reflect.Chan, reflect.Pointer, reflect.Slice: + return visitType(t.Elem()) + case reflect.Map: + return visitType(t.Key()) || visitType(t.Elem()) + case reflect.Struct: + if t.String() == "intern.Value" { + // Otherwise its interface{} makes this return true. + return false + } + for i, numField := 0, t.NumField(); i < numField; i++ { + if visitType(t.Field(i).Type) { + return true + } + } + return false + } + } + return visitType(t) +} diff --git a/util/deephash/types_test.go b/util/deephash/types_test.go new file mode 100644 index 000000000..28ed3bddc --- /dev/null +++ b/util/deephash/types_test.go @@ -0,0 +1,113 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package deephash + +import ( + "io" + "reflect" + "testing" + "time" + "unsafe" + + "tailscale.com/tailcfg" + "tailscale.com/types/structs" +) + +func TestTypeIsMemHashable(t *testing.T) { + tests := []struct { + val any + want bool + }{ + {true, true}, + {uint(1), true}, + {uint8(1), true}, + {uint16(1), true}, + {uint32(1), true}, + {uint64(1), true}, + {uintptr(1), true}, + {int(1), true}, + {int8(1), true}, + {int16(1), true}, + {int32(1), true}, + {int64(1), true}, + {float32(1), true}, + {float64(1), true}, + {complex64(1), true}, + {complex128(1), true}, + {[32]byte{}, true}, + {func() {}, false}, + {make(chan int), false}, + {struct{ io.Writer }{nil}, false}, + {unsafe.Pointer(nil), false}, + {new(int), false}, + {TwoInts{}, true}, + {[4]TwoInts{}, true}, + {IntThenByte{}, false}, + {[4]IntThenByte{}, false}, + {tailcfg.PortRange{}, true}, + {int16(0), true}, + {struct { + _ int + _ int + }{}, true}, + {struct { + _ int + _ uint8 + _ int + }{}, false}, // gap + {struct { + _ structs.Incomparable // if not last, zero-width + x int + }{}, true}, + {struct { + x int + _ structs.Incomparable // zero-width last: has space, can't memhash + }{}, + false}, + {[0]chan bool{}, true}, + {struct{ f [0]func() }{}, true}, + } + for _, tt := range tests { + got := typeIsMemHashable(reflect.TypeOf(tt.val)) + if got != tt.want { + t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want) + } + } +} + +func TestTypeIsRecursive(t *testing.T) { + type RecursiveStruct struct { + v *RecursiveStruct + } + type RecursiveChan chan *RecursiveChan + + tests := []struct { + val any + want bool + }{ + {val: 42, want: false}, + {val: "string", want: false}, + {val: 1 + 2i, want: false}, + {val: struct{}{}, want: false}, + {val: (*RecursiveStruct)(nil), want: true}, + {val: RecursiveStruct{}, want: true}, + {val: time.Unix(0, 0), want: false}, + {val: structs.Incomparable{}, want: false}, // ignore its [0]func() + {val: tailcfg.NetPortRange{}, want: false}, // uses structs.Incomparable + {val: (*tailcfg.Node)(nil), want: false}, + {val: map[string]bool{}, want: false}, + {val: func() {}, want: false}, + {val: make(chan int), want: false}, + {val: unsafe.Pointer(nil), want: false}, + {val: make(RecursiveChan), want: true}, + {val: make(chan int), want: false}, + } + for _, tt := range tests { + got := typeIsRecursive(reflect.TypeOf(tt.val)) + if got != tt.want { + t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want) + } + } +}