util/deephash: move typeIsRecursive and canMemHash to types.go (#5386)

Also, rename canMemHash to typeIsMemHashable to be consistent.
There are zero changes to the semantics.

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
pull/5351/head
Joe Tsai 2 years ago committed by GitHub
parent d53eb6fa11
commit 44d62b65d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -364,7 +364,7 @@ func genHashStructFields(t reflect.Type) typeHasherFunc {
fields = append(fields, fieldInfo{ fields = append(fields, fieldInfo{
index: i, index: i,
typeInfo: getTypeInfo(sf.Type), typeInfo: getTypeInfo(sf.Type),
canMemHash: canMemHash(sf.Type), canMemHash: typeIsMemHashable(sf.Type),
offset: sf.Offset, offset: sf.Offset,
size: sf.Type.Size(), size: sf.Type.Size(),
}) })
@ -445,7 +445,7 @@ func genTypeHasher(t reflect.Type) typeHasherFunc {
return (*hasher).hashString return (*hasher).hashString
case reflect.Slice: case reflect.Slice:
et := t.Elem() et := t.Elem()
if canMemHash(et) { if typeIsMemHashable(et) {
return (*hasher).hashSliceMem return (*hasher).hashSliceMem
} }
eti := getTypeInfo(et) eti := getTypeInfo(et)
@ -464,7 +464,7 @@ func genTypeHasher(t reflect.Type) typeHasherFunc {
return genHashStructFields(t) return genHashStructFields(t)
case reflect.Pointer: case reflect.Pointer:
et := t.Elem() et := t.Elem()
if canMemHash(et) { if typeIsMemHashable(et) {
return genHashPtrToMemoryRange(et) return genHashPtrToMemoryRange(et)
} }
if t.Implements(appenderToType) { if t.Implements(appenderToType) {
@ -574,7 +574,7 @@ func genHashArray(t reflect.Type, eti *typeInfo) typeHasherFunc {
return noopHasherFunc return noopHasherFunc
} }
et := t.Elem() et := t.Elem()
if canMemHash(et) { if typeIsMemHashable(et) {
return genHashArrayMem(t.Len(), t.Size(), eti) return genHashArrayMem(t.Len(), t.Size(), eti)
} }
n := t.Len() n := t.Len()
@ -625,7 +625,7 @@ func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *t
ti := &typeInfo{ ti := &typeInfo{
rtype: t, rtype: t,
isRecursive: typeIsRecursive(t), isRecursive: typeIsRecursive(t),
canMemHash: canMemHash(t), canMemHash: typeIsMemHashable(t),
} }
incomplete[t] = ti incomplete[t] = ti
@ -640,89 +640,6 @@ func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *t
return ti return ti
} }
// typeIsRecursive reports whether t has a path back to itself.
//
// For interfaces, it currently always reports true.
func typeIsRecursive(t reflect.Type) bool {
inStack := map[reflect.Type]bool{}
var visitType func(t reflect.Type) (isRecursiveSoFar bool)
visitType = func(t reflect.Type) (isRecursiveSoFar bool) {
// Check whether we have seen this type before.
if inStack[t] {
return true
}
inStack[t] = true
defer func() {
delete(inStack, t)
}()
// Any type that is memory hashable must not be recursive since
// cycles can only occur if pointers are involved.
if canMemHash(t) {
return false
}
// Recursively check types that may contain pointers.
switch t.Kind() {
default:
panic("unhandled kind " + t.Kind().String())
case reflect.String, reflect.UnsafePointer, reflect.Func:
return false
case reflect.Interface:
// Assume the worst for now. TODO(bradfitz): in some cases
// we should be able to prove that it's not recursive. Not worth
// it for now.
return true
case reflect.Array, reflect.Chan, reflect.Pointer, reflect.Slice:
return visitType(t.Elem())
case reflect.Map:
return visitType(t.Key()) || visitType(t.Elem())
case reflect.Struct:
if t.String() == "intern.Value" {
// Otherwise its interface{} makes this return true.
return false
}
for i, numField := 0, t.NumField(); i < numField; i++ {
if visitType(t.Field(i).Type) {
return true
}
}
return false
}
}
return visitType(t)
}
// canMemHash reports whether a slice of t can be hashed by looking at its
// contiguous bytes in memory alone. (e.g. structs with gaps aren't memhashable)
func canMemHash(t reflect.Type) bool {
if t.Size() == 0 {
return true
}
switch t.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.Complex64, reflect.Complex128:
return true
case reflect.Array:
return canMemHash(t.Elem())
case reflect.Struct:
var sumFieldSize uintptr
for i, numField := 0, t.NumField(); i < numField; i++ {
sf := t.Field(i)
if !canMemHash(sf.Type) {
return false
}
sumFieldSize += sf.Type.Size()
}
return sumFieldSize == t.Size() // ensure no gaps
}
return false
}
func (h *hasher) hashValue(v addressableValue, forceCycleChecking bool) { func (h *hasher) hashValue(v addressableValue, forceCycleChecking bool) {
if !v.IsValid() { if !v.IsValid() {
return return

@ -10,7 +10,6 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"hash" "hash"
"io"
"math" "math"
"math/rand" "math/rand"
"net/netip" "net/netip"
@ -19,14 +18,12 @@ import (
"testing" "testing"
"testing/quick" "testing/quick"
"time" "time"
"unsafe"
"go4.org/mem" "go4.org/mem"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/structs"
"tailscale.com/util/deephash/testtype" "tailscale.com/util/deephash/testtype"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
"tailscale.com/version" "tailscale.com/version"
@ -289,41 +286,6 @@ func getVal() any {
} }
} }
func TestTypeIsRecursive(t *testing.T) {
type RecursiveStruct struct {
v *RecursiveStruct
}
type RecursiveChan chan *RecursiveChan
tests := []struct {
val any
want bool
}{
{val: 42, want: false},
{val: "string", want: false},
{val: 1 + 2i, want: false},
{val: struct{}{}, want: false},
{val: (*RecursiveStruct)(nil), want: true},
{val: RecursiveStruct{}, want: true},
{val: time.Unix(0, 0), want: false},
{val: structs.Incomparable{}, want: false}, // ignore its [0]func()
{val: tailcfg.NetPortRange{}, want: false}, // uses structs.Incomparable
{val: (*tailcfg.Node)(nil), want: false},
{val: map[string]bool{}, want: false},
{val: func() {}, want: false},
{val: make(chan int), want: false},
{val: unsafe.Pointer(nil), want: false},
{val: make(RecursiveChan), want: true},
{val: make(chan int), want: false},
}
for _, tt := range tests {
got := typeIsRecursive(reflect.TypeOf(tt.val))
if got != tt.want {
t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want)
}
}
}
type IntThenByte struct { type IntThenByte struct {
i int i int
b byte b byte
@ -337,68 +299,6 @@ type IntIntByteInt struct {
i3 int32 i3 int32
} }
func TestCanMemHash(t *testing.T) {
tests := []struct {
val any
want bool
}{
{true, true},
{uint(1), true},
{uint8(1), true},
{uint16(1), true},
{uint32(1), true},
{uint64(1), true},
{uintptr(1), true},
{int(1), true},
{int8(1), true},
{int16(1), true},
{int32(1), true},
{int64(1), true},
{float32(1), true},
{float64(1), true},
{complex64(1), true},
{complex128(1), true},
{[32]byte{}, true},
{func() {}, false},
{make(chan int), false},
{struct{ io.Writer }{nil}, false},
{unsafe.Pointer(nil), false},
{new(int), false},
{TwoInts{}, true},
{[4]TwoInts{}, true},
{IntThenByte{}, false},
{[4]IntThenByte{}, false},
{tailcfg.PortRange{}, true},
{int16(0), true},
{struct {
_ int
_ int
}{}, true},
{struct {
_ int
_ uint8
_ int
}{}, false}, // gap
{struct {
_ structs.Incomparable // if not last, zero-width
x int
}{}, true},
{struct {
x int
_ structs.Incomparable // zero-width last: has space, can't memhash
}{},
false},
{[0]chan bool{}, true},
{struct{ f [0]func() }{}, true},
}
for _, tt := range tests {
got := canMemHash(reflect.TypeOf(tt.val))
if got != tt.want {
t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want)
}
}
}
func u8(n uint8) string { return string([]byte{n}) } func u8(n uint8) string { return string([]byte{n}) }
func u16(n uint16) string { return string(binary.LittleEndian.AppendUint16(nil, n)) } func u16(n uint16) string { return string(binary.LittleEndian.AppendUint16(nil, n)) }
func u32(n uint32) string { return string(binary.LittleEndian.AppendUint32(nil, n)) } func u32(n uint32) string { return string(binary.LittleEndian.AppendUint32(nil, n)) }

@ -0,0 +1,88 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deephash
import "reflect"
// typeIsMemHashable reports whether t can be hashed by directly hashing its
// contiguous bytes in memory (e.g. structs with gaps are not mem-hashable).
func typeIsMemHashable(t reflect.Type) bool {
if t.Size() == 0 {
return true
}
switch t.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.Complex64, reflect.Complex128:
return true
case reflect.Array:
return typeIsMemHashable(t.Elem())
case reflect.Struct:
var sumFieldSize uintptr
for i, numField := 0, t.NumField(); i < numField; i++ {
sf := t.Field(i)
if !typeIsMemHashable(sf.Type) {
return false
}
sumFieldSize += sf.Type.Size()
}
return sumFieldSize == t.Size() // ensure no gaps
}
return false
}
// typeIsRecursive reports whether t has a path back to itself.
// For interfaces, it currently always reports true.
func typeIsRecursive(t reflect.Type) bool {
inStack := map[reflect.Type]bool{}
var visitType func(t reflect.Type) (isRecursiveSoFar bool)
visitType = func(t reflect.Type) (isRecursiveSoFar bool) {
// Check whether we have seen this type before.
if inStack[t] {
return true
}
inStack[t] = true
defer func() {
delete(inStack, t)
}()
// Any type that is memory hashable must not be recursive since
// cycles can only occur if pointers are involved.
if typeIsMemHashable(t) {
return false
}
// Recursively check types that may contain pointers.
switch t.Kind() {
default:
panic("unhandled kind " + t.Kind().String())
case reflect.String, reflect.UnsafePointer, reflect.Func:
return false
case reflect.Interface:
// Assume the worst for now. TODO(bradfitz): in some cases
// we should be able to prove that it's not recursive. Not worth
// it for now.
return true
case reflect.Array, reflect.Chan, reflect.Pointer, reflect.Slice:
return visitType(t.Elem())
case reflect.Map:
return visitType(t.Key()) || visitType(t.Elem())
case reflect.Struct:
if t.String() == "intern.Value" {
// Otherwise its interface{} makes this return true.
return false
}
for i, numField := 0, t.NumField(); i < numField; i++ {
if visitType(t.Field(i).Type) {
return true
}
}
return false
}
}
return visitType(t)
}

@ -0,0 +1,113 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deephash
import (
"io"
"reflect"
"testing"
"time"
"unsafe"
"tailscale.com/tailcfg"
"tailscale.com/types/structs"
)
func TestTypeIsMemHashable(t *testing.T) {
tests := []struct {
val any
want bool
}{
{true, true},
{uint(1), true},
{uint8(1), true},
{uint16(1), true},
{uint32(1), true},
{uint64(1), true},
{uintptr(1), true},
{int(1), true},
{int8(1), true},
{int16(1), true},
{int32(1), true},
{int64(1), true},
{float32(1), true},
{float64(1), true},
{complex64(1), true},
{complex128(1), true},
{[32]byte{}, true},
{func() {}, false},
{make(chan int), false},
{struct{ io.Writer }{nil}, false},
{unsafe.Pointer(nil), false},
{new(int), false},
{TwoInts{}, true},
{[4]TwoInts{}, true},
{IntThenByte{}, false},
{[4]IntThenByte{}, false},
{tailcfg.PortRange{}, true},
{int16(0), true},
{struct {
_ int
_ int
}{}, true},
{struct {
_ int
_ uint8
_ int
}{}, false}, // gap
{struct {
_ structs.Incomparable // if not last, zero-width
x int
}{}, true},
{struct {
x int
_ structs.Incomparable // zero-width last: has space, can't memhash
}{},
false},
{[0]chan bool{}, true},
{struct{ f [0]func() }{}, true},
}
for _, tt := range tests {
got := typeIsMemHashable(reflect.TypeOf(tt.val))
if got != tt.want {
t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want)
}
}
}
func TestTypeIsRecursive(t *testing.T) {
type RecursiveStruct struct {
v *RecursiveStruct
}
type RecursiveChan chan *RecursiveChan
tests := []struct {
val any
want bool
}{
{val: 42, want: false},
{val: "string", want: false},
{val: 1 + 2i, want: false},
{val: struct{}{}, want: false},
{val: (*RecursiveStruct)(nil), want: true},
{val: RecursiveStruct{}, want: true},
{val: time.Unix(0, 0), want: false},
{val: structs.Incomparable{}, want: false}, // ignore its [0]func()
{val: tailcfg.NetPortRange{}, want: false}, // uses structs.Incomparable
{val: (*tailcfg.Node)(nil), want: false},
{val: map[string]bool{}, want: false},
{val: func() {}, want: false},
{val: make(chan int), want: false},
{val: unsafe.Pointer(nil), want: false},
{val: make(RecursiveChan), want: true},
{val: make(chan int), want: false},
}
for _, tt := range tests {
got := typeIsRecursive(reflect.TypeOf(tt.val))
if got != tt.want {
t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want)
}
}
}
Loading…
Cancel
Save