diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 205f675f2..a2e9dae13 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -279,6 +279,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/osshare from tailscale.com/ipn/ipnlocal+ tailscale.com/util/pidowner from tailscale.com/ipn/ipnserver tailscale.com/util/racebuild from tailscale.com/logpolicy + 💣 tailscale.com/util/sha256x from tailscale.com/util/deephash tailscale.com/util/singleflight from tailscale.com/control/controlclient+ L tailscale.com/util/strs from tailscale.com/hostinfo tailscale.com/util/systemd from tailscale.com/control/controlclient+ diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go index a1b70906f..c1ad7ee3b 100644 --- a/util/deephash/deephash.go +++ b/util/deephash/deephash.go @@ -20,18 +20,18 @@ package deephash import ( - "bufio" "crypto/sha256" "encoding/binary" "encoding/hex" "fmt" - "hash" "log" "math" "reflect" "sync" "time" "unsafe" + + "tailscale.com/util/sha256x" ) // There is much overlap between the theory of serialization and hashing. @@ -79,23 +79,11 @@ const scratchSize = 128 // hasher is reusable state for hashing a value. // Get one via hasherPool. type hasher struct { - h hash.Hash - bw *bufio.Writer + sha256x.Hash scratch [scratchSize]byte visitStack visitStack } -func (h *hasher) reset() { - if h.h == nil { - h.h = sha256.New() - } - if h.bw == nil { - h.bw = bufio.NewWriterSize(h.h, h.h.BlockSize()) - } - h.bw.Flush() - h.h.Reset() -} - // Sum is an opaque checksum type that is comparable. type Sum struct { sum [sha256.Size]byte @@ -121,12 +109,7 @@ func initSeed() { } func (h *hasher) sum() (s Sum) { - h.bw.Flush() - // Sum into scratch & copy out, as hash.Hash is an interface - // so the slice necessarily escapes, and there's no sha256 - // concrete type exported and we don't want the 'hash' result - // parameter to escape to the heap: - copy(s.sum[:], h.h.Sum(h.scratch[:0])) + h.Sum(s.sum[:0]) return s } @@ -139,9 +122,9 @@ var hasherPool = &sync.Pool{ func Hash(v any) (s Sum) { h := hasherPool.Get().(*hasher) defer hasherPool.Put(h) - h.reset() + h.Reset() seedOnce.Do(initSeed) - h.hashUint64(seed) + h.HashUint64(seed) rv := reflect.ValueOf(v) if rv.IsValid() { @@ -177,11 +160,11 @@ func HasherForType[T any]() func(T) Sum { } seedOnce.Do(initSeed) - return func(v T) Sum { + return func(v T) (s Sum) { h := hasherPool.Get().(*hasher) defer hasherPool.Put(h) - h.reset() - h.hashUint64(seed) + h.Reset() + h.HashUint64(seed) rv := reflect.ValueOf(v) @@ -218,26 +201,6 @@ type appenderTo interface { AppendTo([]byte) []byte } -func (h *hasher) hashUint8(i uint8) { - h.bw.WriteByte(i) -} -func (h *hasher) hashUint16(i uint16) { - binary.LittleEndian.PutUint16(h.scratch[:2], i) - h.bw.Write(h.scratch[:2]) -} -func (h *hasher) hashUint32(i uint32) { - binary.LittleEndian.PutUint32(h.scratch[:4], i) - h.bw.Write(h.scratch[:4]) -} -func (h *hasher) hashLen(n int) { - binary.LittleEndian.PutUint64(h.scratch[:8], uint64(n)) - h.bw.Write(h.scratch[:8]) -} -func (h *hasher) hashUint64(i uint64) { - binary.LittleEndian.PutUint64(h.scratch[:8], i) - h.bw.Write(h.scratch[:8]) -} - var ( uint8Type = reflect.TypeOf(byte(0)) timeTimeType = reflect.TypeOf(time.Time{}) @@ -286,47 +249,47 @@ func (h *hasher) hashBoolv(v addressableValue) bool { if v.Bool() { b = 1 } - h.hashUint8(b) + h.HashUint8(b) return true } func (h *hasher) hashUint8v(v addressableValue) bool { - h.hashUint8(uint8(v.Uint())) + h.HashUint8(uint8(v.Uint())) return true } func (h *hasher) hashInt8v(v addressableValue) bool { - h.hashUint8(uint8(v.Int())) + h.HashUint8(uint8(v.Int())) return true } func (h *hasher) hashUint16v(v addressableValue) bool { - h.hashUint16(uint16(v.Uint())) + h.HashUint16(uint16(v.Uint())) return true } func (h *hasher) hashInt16v(v addressableValue) bool { - h.hashUint16(uint16(v.Int())) + h.HashUint16(uint16(v.Int())) return true } func (h *hasher) hashUint32v(v addressableValue) bool { - h.hashUint32(uint32(v.Uint())) + h.HashUint32(uint32(v.Uint())) return true } func (h *hasher) hashInt32v(v addressableValue) bool { - h.hashUint32(uint32(v.Int())) + h.HashUint32(uint32(v.Int())) return true } func (h *hasher) hashUint64v(v addressableValue) bool { - h.hashUint64(v.Uint()) + h.HashUint64(v.Uint()) return true } func (h *hasher) hashInt64v(v addressableValue) bool { - h.hashUint64(uint64(v.Int())) + h.HashUint64(uint64(v.Int())) return true } @@ -338,7 +301,7 @@ func hashStructAppenderTo(h *hasher, v addressableValue) bool { size := h.scratch[:8] record := a.AppendTo(size) binary.LittleEndian.PutUint64(record, uint64(len(record)-len(size))) - h.bw.Write(record) + h.HashBytes(record) return true } @@ -348,15 +311,15 @@ func hashPointerAppenderTo(h *hasher, v addressableValue) bool { return false // slow path } if v.IsNil() { - h.hashUint8(0) // indicates nil + h.HashUint8(0) // indicates nil return true } - h.hashUint8(1) // indicates visiting a pointer + h.HashUint8(1) // indicates visiting a pointer a := v.Interface().(appenderTo) size := h.scratch[:8] record := a.AppendTo(size) binary.LittleEndian.PutUint64(record, uint64(len(record)-len(size))) - h.bw.Write(record) + h.HashBytes(record) return true } @@ -416,7 +379,7 @@ func (sh structHasher) hash(h *hasher, v addressableValue) bool { base := v.Addr().UnsafePointer() for _, f := range sh.fields { if f.canMemHash { - h.bw.Write(unsafe.Slice((*byte)(unsafe.Pointer(uintptr(base)+f.offset)), f.size)) + h.HashBytes(unsafe.Slice((*byte)(unsafe.Pointer(uintptr(base)+f.offset)), f.size)) continue } va := addressableValue{v.Field(f.index)} // field is addressable if parent struct is addressable @@ -433,10 +396,10 @@ func genHashPtrToMemoryRange(eleType reflect.Type) typeHasherFunc { size := eleType.Size() return func(h *hasher, v addressableValue) bool { if v.IsNil() { - h.hashUint8(0) // indicates nil + h.HashUint8(0) // indicates nil } else { - h.hashUint8(1) // indicates visiting a pointer - h.bw.Write(unsafe.Slice((*byte)(v.UnsafePointer()), size)) + h.HashUint8(1) // indicates visiting a pointer + h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), size)) } return true } @@ -509,10 +472,10 @@ func genTypeHasher(t reflect.Type) typeHasherFunc { eti := getTypeInfo(et) return func(h *hasher, v addressableValue) bool { if v.IsNil() { - h.hashUint8(0) // indicates nil + h.HashUint8(0) // indicates nil return true } - h.hashUint8(1) // indicates visiting a pointer + h.HashUint8(1) // indicates visiting a pointer va := addressableValue{v.Elem()} // dereferenced pointer is always addressable return eti.hasher()(h, va) } @@ -530,32 +493,32 @@ func genTypeHasher(t reflect.Type) typeHasherFunc { // hashString hashes v, of kind String. func (h *hasher) hashString(v addressableValue) bool { s := v.String() - h.hashLen(len(s)) - h.bw.WriteString(s) + h.HashUint64(uint64(len(s))) + h.HashString(s) return true } func (h *hasher) hashFloat32v(v addressableValue) bool { - h.hashUint32(math.Float32bits(float32(v.Float()))) + h.HashUint32(math.Float32bits(float32(v.Float()))) return true } func (h *hasher) hashFloat64v(v addressableValue) bool { - h.hashUint64(math.Float64bits(v.Float())) + h.HashUint64(math.Float64bits(v.Float())) return true } func (h *hasher) hashComplex64v(v addressableValue) bool { c := complex64(v.Complex()) - h.hashUint32(math.Float32bits(real(c))) - h.hashUint32(math.Float32bits(imag(c))) + h.HashUint32(math.Float32bits(real(c))) + h.HashUint32(math.Float32bits(imag(c))) return true } func (h *hasher) hashComplex128v(v addressableValue) bool { c := v.Complex() - h.hashUint64(math.Float64bits(real(c))) - h.hashUint64(math.Float64bits(imag(c))) + h.HashUint64(math.Float64bits(real(c))) + h.HashUint64(math.Float64bits(imag(c))) return true } @@ -564,24 +527,24 @@ func (h *hasher) hashTimev(v addressableValue) bool { t := *(*time.Time)(v.Addr().UnsafePointer()) b := t.AppendFormat(h.scratch[:1], time.RFC3339Nano) b[0] = byte(len(b) - 1) // more than sufficient width; if not, good enough. - h.bw.Write(b) + h.HashBytes(b) return true } // hashSliceMem hashes v, of kind Slice, with a memhash-able element type. func (h *hasher) hashSliceMem(v addressableValue) bool { vLen := v.Len() - h.hashUint64(uint64(vLen)) + h.HashUint64(uint64(vLen)) if vLen == 0 { return true } - h.bw.Write(unsafe.Slice((*byte)(v.UnsafePointer()), v.Type().Elem().Size()*uintptr(vLen))) + h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), v.Type().Elem().Size()*uintptr(vLen))) return true } func genHashArrayMem(n int, arraySize uintptr, efu *typeInfo) typeHasherFunc { return func(h *hasher, v addressableValue) bool { - h.bw.Write(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), arraySize)) + h.HashBytes(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), arraySize)) return true } } @@ -622,7 +585,7 @@ type sliceElementHasher struct { func (seh sliceElementHasher) hash(h *hasher, v addressableValue) bool { vLen := v.Len() - h.hashUint64(uint64(vLen)) + h.HashUint64(uint64(vLen)) for i := 0; i < vLen; i++ { va := addressableValue{v.Index(i)} // slice elements are always addressable if !seh.eti.hasher()(h, va) { @@ -787,7 +750,6 @@ func (h *hasher) hashValue(v addressableValue, forceCycleChecking bool) { } func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleChecking bool) { - w := h.bw doCheckCycles := forceCycleChecking || ti.isRecursive if !doCheckCycles { @@ -803,22 +765,22 @@ func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleC panic(fmt.Sprintf("unhandled kind %v for type %v", v.Kind(), v.Type())) case reflect.Ptr: if v.IsNil() { - h.hashUint8(0) // indicates nil + h.HashUint8(0) // indicates nil return } if doCheckCycles { ptr := pointerOf(v) if idx, ok := h.visitStack.seen(ptr); ok { - h.hashUint8(2) // indicates cycle - h.hashUint64(uint64(idx)) + h.HashUint8(2) // indicates cycle + h.HashUint64(uint64(idx)) return } h.visitStack.push(ptr) defer h.visitStack.pop(ptr) } - h.hashUint8(1) // indicates visiting a pointer + h.HashUint8(1) // indicates visiting a pointer va := addressableValue{v.Elem()} // dereferenced pointer is always addressable h.hashValueWithType(va, ti.elemTypeInfo, doCheckCycles) case reflect.Struct: @@ -829,7 +791,7 @@ func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleC case reflect.Slice, reflect.Array: vLen := v.Len() if v.Kind() == reflect.Slice { - h.hashUint64(uint64(vLen)) + h.HashUint64(uint64(vLen)) } if v.Type().Elem() == uint8Type && v.CanInterface() { if vLen > 0 && vLen <= scratchSize { @@ -838,10 +800,10 @@ func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleC // scratchSize bytes at a time, but reflect.Slice seems // to allocate, so it's not a win. n := reflect.Copy(reflect.ValueOf(&h.scratch).Elem(), v.Value) - w.Write(h.scratch[:n]) + h.HashBytes(h.scratch[:n]) return } - fmt.Fprintf(w, "%s", v.Interface()) + fmt.Fprintf(h, "%s", v.Interface()) return } for i := 0; i < vLen; i++ { @@ -853,14 +815,14 @@ func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleC } case reflect.Interface: if v.IsNil() { - h.hashUint8(0) // indicates nil + h.HashUint8(0) // indicates nil return } // TODO: Use a valueCache here? va := newAddressableValue(v.Elem().Type()) va.Set(v.Elem()) - h.hashUint8(1) // indicates visiting interface value + h.HashUint8(1) // indicates visiting interface value h.hashType(va.Type()) h.hashValue(va, doCheckCycles) case reflect.Map: @@ -868,51 +830,51 @@ func (h *hasher) hashValueWithType(v addressableValue, ti *typeInfo, forceCycleC if doCheckCycles { ptr := pointerOf(v) if idx, ok := h.visitStack.seen(ptr); ok { - h.hashUint8(2) // indicates cycle - h.hashUint64(uint64(idx)) + h.HashUint8(2) // indicates cycle + h.HashUint64(uint64(idx)) return } h.visitStack.push(ptr) defer h.visitStack.pop(ptr) } - h.hashUint8(1) // indicates visiting a map + h.HashUint8(1) // indicates visiting a map h.hashMap(v, ti, doCheckCycles) case reflect.String: s := v.String() - h.hashUint64(uint64(len(s))) - w.WriteString(s) + h.HashUint64(uint64(len(s))) + h.HashString(s) case reflect.Bool: if v.Bool() { - h.hashUint8(1) + h.HashUint8(1) } else { - h.hashUint8(0) + h.HashUint8(0) } case reflect.Int8: - h.hashUint8(uint8(v.Int())) + h.HashUint8(uint8(v.Int())) case reflect.Int16: - h.hashUint16(uint16(v.Int())) + h.HashUint16(uint16(v.Int())) case reflect.Int32: - h.hashUint32(uint32(v.Int())) + h.HashUint32(uint32(v.Int())) case reflect.Int64, reflect.Int: - h.hashUint64(uint64(v.Int())) + h.HashUint64(uint64(v.Int())) case reflect.Uint8: - h.hashUint8(uint8(v.Uint())) + h.HashUint8(uint8(v.Uint())) case reflect.Uint16: - h.hashUint16(uint16(v.Uint())) + h.HashUint16(uint16(v.Uint())) case reflect.Uint32: - h.hashUint32(uint32(v.Uint())) + h.HashUint32(uint32(v.Uint())) case reflect.Uint64, reflect.Uint, reflect.Uintptr: - h.hashUint64(uint64(v.Uint())) + h.HashUint64(uint64(v.Uint())) case reflect.Float32: - h.hashUint32(math.Float32bits(float32(v.Float()))) + h.HashUint32(math.Float32bits(float32(v.Float()))) case reflect.Float64: - h.hashUint64(math.Float64bits(float64(v.Float()))) + h.HashUint64(math.Float64bits(float64(v.Float()))) case reflect.Complex64: - h.hashUint32(math.Float32bits(real(complex64(v.Complex())))) - h.hashUint32(math.Float32bits(imag(complex64(v.Complex())))) + h.HashUint32(math.Float32bits(real(complex64(v.Complex())))) + h.HashUint32(math.Float32bits(imag(complex64(v.Complex())))) case reflect.Complex128: - h.hashUint64(math.Float64bits(real(complex128(v.Complex())))) - h.hashUint64(math.Float64bits(imag(complex128(v.Complex())))) + h.HashUint64(math.Float64bits(real(complex128(v.Complex())))) + h.HashUint64(math.Float64bits(imag(complex128(v.Complex())))) } } @@ -958,12 +920,12 @@ func (h *hasher) hashMap(v addressableValue, ti *typeInfo, checkCycles bool) { for iter := v.MapRange(); iter.Next(); { k.SetIterKey(iter) e.SetIterValue(iter) - mh.h.reset() + mh.h.Reset() mh.h.hashValueWithType(k, ti.keyTypeInfo, checkCycles) mh.h.hashValueWithType(e, ti.elemTypeInfo, checkCycles) sum.xor(mh.h.sum()) } - h.bw.Write(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation + h.HashBytes(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation } // visitStack is a stack of pointers visited. @@ -1005,5 +967,5 @@ func (h *hasher) hashType(t reflect.Type) { // that maps reflect.Type to some arbitrary and unique index. // While safer, it requires global state with memory that can never be GC'd. rtypeAddr := reflect.ValueOf(t).Pointer() // address of *reflect.rtype - h.hashUint64(uint64(rtypeAddr)) + h.HashUint64(uint64(rtypeAddr)) } diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index b4abae568..27d077866 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -6,10 +6,9 @@ package deephash import ( "archive/tar" - "bufio" - "bytes" "crypto/sha256" "fmt" + "hash" "io" "math" "math/rand" @@ -626,10 +625,9 @@ func TestGetTypeHasher(t *testing.T) { va := newAddressableValue(rv.Type()) va.Set(rv) fn := getTypeInfo(va.Type()).hasher() - var buf bytes.Buffer - h := &hasher{ - bw: bufio.NewWriter(&buf), - } + hb := &hashBuffer{Hash: sha256.New()} + h := new(hasher) + h.Hash.H = hb got := fn(h, va) const ptrSize = 32 << uintptr(^uintptr(0)>>63) if tt.out32 != "" && ptrSize == 32 { @@ -641,10 +639,8 @@ func TestGetTypeHasher(t *testing.T) { if got != tt.want { t.Fatalf("func returned %v; want %v", got, tt.want) } - if err := h.bw.Flush(); err != nil { - t.Fatal(err) - } - if got := buf.String(); got != tt.out { + h.sum() + if got := string(hb.B); got != tt.out { t.Fatalf("got %q; want %q", got, tt.out) } }) @@ -720,21 +716,21 @@ func TestHashMapAcyclic(t *testing.T) { } got := map[string]bool{} - var buf bytes.Buffer - bw := bufio.NewWriter(&buf) + hb := &hashBuffer{Hash: sha256.New()} ti := getTypeInfo(reflect.TypeOf(m)) for i := 0; i < 20; i++ { v := addressableValue{reflect.ValueOf(&m).Elem()} - buf.Reset() - bw.Reset(&buf) - h := &hasher{bw: bw} + hb.Reset() + h := new(hasher) + h.Hash.H = hb h.hashMap(v, ti, false) - if got[string(buf.Bytes())] { + h.sum() + if got[string(hb.B)] { continue } - got[string(buf.Bytes())] = true + got[string(hb.B)] = true } if len(got) != 1 { t.Errorf("got %d results; want 1", len(got)) @@ -746,13 +742,13 @@ func TestPrintArray(t *testing.T) { X [32]byte } x := T{X: [32]byte{1: 1, 31: 31}} - var got bytes.Buffer - bw := bufio.NewWriter(&got) - h := &hasher{bw: bw} + hb := &hashBuffer{Hash: sha256.New()} + h := new(hasher) + h.Hash.H = hb h.hashValue(addressableValue{reflect.ValueOf(&x).Elem()}, false) - bw.Flush() + h.sum() const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f" - if got := got.Bytes(); string(got) != want { + if got := hb.B; string(got) != want { t.Errorf("wrong:\n got: %q\nwant: %q\n", got, want) } } @@ -764,16 +760,15 @@ func BenchmarkHashMapAcyclic(b *testing.B) { m[i] = fmt.Sprint(i) } - var buf bytes.Buffer - bw := bufio.NewWriter(&buf) + hb := &hashBuffer{Hash: sha256.New()} v := addressableValue{reflect.ValueOf(&m).Elem()} ti := getTypeInfo(v.Type()) - h := &hasher{bw: bw} + h := new(hasher) + h.Hash.H = hb for i := 0; i < b.N; i++ { - buf.Reset() - bw.Reset(&buf) + h.Reset() h.hashMap(v, ti, false) } } @@ -874,3 +869,19 @@ func BenchmarkHashArray(b *testing.B) { sink = Hash(x) } } + +// hashBuffer is a hash.Hash that buffers all written data. +type hashBuffer struct { + hash.Hash + B []byte +} + +func (h *hashBuffer) Write(b []byte) (int, error) { + n, err := h.Hash.Write(b) + h.B = append(h.B, b[:n]...) + return n, err +} +func (h *hashBuffer) Reset() { + h.Hash.Reset() + h.B = h.B[:0] +} diff --git a/util/sha256x/sha256.go b/util/sha256x/sha256.go index 50cffb574..0a20437de 100644 --- a/util/sha256x/sha256.go +++ b/util/sha256x/sha256.go @@ -11,6 +11,7 @@ import ( "crypto/sha256" "encoding/binary" "hash" + "unsafe" ) var _ hash.Hash = (*Hash)(nil) @@ -24,13 +25,16 @@ type Hash struct { // However, it does mean that sha256.digest.x goes unused, // which is a waste of 64B. - h hash.Hash // always *sha256.digest + // H is the underlying hash.Hash. + // The hash.Hash.BlockSize must be equal to sha256.BlockSize. + // It is exported only for testing purposes. + H hash.Hash // usually a *sha256.digest x [sha256.BlockSize]byte // equivalent to sha256.digest.x nx int // equivalent to sha256.digest.nx } func New() *Hash { - return &Hash{h: sha256.New()} + return &Hash{H: sha256.New()} } func (h *Hash) Write(b []byte) (int, error) { @@ -42,32 +46,32 @@ func (h *Hash) Sum(b []byte) []byte { if h.nx > 0 { // This causes block mis-alignment. Future operations will be correct, // but are less efficient until Reset is called. - h.h.Write(h.x[:h.nx]) + h.H.Write(h.x[:h.nx]) h.nx = 0 } // Unfortunately hash.Hash.Sum always causes the input to escape since // escape analysis cannot prove anything past an interface method call. // Assuming h already escapes, we call Sum with h.x first, - // and then the copy the result to b. - sum := h.h.Sum(h.x[:0]) + // and then copy the result to b. + sum := h.H.Sum(h.x[:0]) return append(b, sum...) } func (h *Hash) Reset() { - if h.h == nil { - h.h = sha256.New() + if h.H == nil { + h.H = sha256.New() } - h.h.Reset() + h.H.Reset() h.nx = 0 } func (h *Hash) Size() int { - return h.h.Size() + return h.H.Size() } func (h *Hash) BlockSize() int { - return h.h.BlockSize() + return h.H.BlockSize() } func (h *Hash) HashUint8(n uint8) { @@ -125,7 +129,7 @@ func (h *Hash) hashUint64Slow(n uint64) { h.hashUint(uint64(n), 8) } func (h *Hash) hashUint(n uint64, i int) { for ; i > 0; i-- { if h.nx == len(h.x) { - h.h.Write(h.x[:]) + h.H.Write(h.x[:]) h.nx = 0 } h.x[h.nx] = byte(n) @@ -140,14 +144,14 @@ func (h *Hash) HashBytes(b []byte) { n := copy(h.x[h.nx:], b) h.nx += n if h.nx == len(h.x) { - h.h.Write(h.x[:]) + h.H.Write(h.x[:]) h.nx = 0 } b = b[n:] } if len(b) >= len(h.x) { n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x) - h.h.Write(b[:n]) + h.H.Write(b[:n]) b = b[n:] } if len(b) > 0 { @@ -155,4 +159,14 @@ func (h *Hash) HashBytes(b []byte) { } } +func (h *Hash) HashString(s string) { + type stringHeader struct { + p unsafe.Pointer + n int + } + p := (*stringHeader)(unsafe.Pointer(&s)) + b := unsafe.Slice((*byte)(p.p), p.n) + h.HashBytes(b) +} + // TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary? diff --git a/util/sha256x/sha256_test.go b/util/sha256x/sha256_test.go index f9f631fbc..91fcc1acf 100644 --- a/util/sha256x/sha256_test.go +++ b/util/sha256x/sha256_test.go @@ -17,7 +17,7 @@ import ( // naiveHash is an obviously correct implementation of Hash. type naiveHash struct { hash.Hash - scratch [8]byte + scratch [256]byte } func newNaive() *naiveHash { return &naiveHash{Hash: sha256.New()} } @@ -26,6 +26,7 @@ func (h *naiveHash) HashUint16(n uint16) { h.Write(binary.LittleEndian.AppendUin func (h *naiveHash) HashUint32(n uint32) { h.Write(binary.LittleEndian.AppendUint32(h.scratch[:0], n)) } func (h *naiveHash) HashUint64(n uint64) { h.Write(binary.LittleEndian.AppendUint64(h.scratch[:0], n)) } func (h *naiveHash) HashBytes(b []byte) { h.Write(b) } +func (h *naiveHash) HashString(s string) { h.Write(append(h.scratch[:0], s...)) } var bytes = func() (out []byte) { out = make([]byte, 130) @@ -41,6 +42,7 @@ type hasher interface { HashUint32(uint32) HashUint64(uint64) HashBytes([]byte) + HashString(string) } func hashSuite(h hasher) { @@ -61,7 +63,12 @@ func hashSuite(h hasher) { h.HashUint16(0x89ab) h.HashUint8(0xcd) } - h.HashBytes(bytes[:(i+1)*13]) + b := bytes[:(i+1)*13] + if i%2 == 0 { + h.HashBytes(b) + } else { + h.HashString(string(b)) + } } } @@ -74,14 +81,51 @@ func Test(t *testing.T) { c.Assert(h1.Sum(nil), qt.DeepEquals, h2.Sum(nil)) } -func TestSumAllocations(t *testing.T) { +func TestAllocations(t *testing.T) { c := qt.New(t) - h := New() - n := testing.AllocsPerRun(100, func() { - var a [sha256.Size]byte - h.Sum(a[:0]) + c.Run("Sum", func(c *qt.C) { + h := New() + c.Assert(testing.AllocsPerRun(100, func() { + var a [sha256.Size]byte + h.Sum(a[:0]) + }), qt.Equals, 0.0) + }) + c.Run("HashUint8", func(c *qt.C) { + h := New() + c.Assert(testing.AllocsPerRun(100, func() { + h.HashUint8(0x01) + }), qt.Equals, 0.0) + }) + c.Run("HashUint16", func(c *qt.C) { + h := New() + c.Assert(testing.AllocsPerRun(100, func() { + h.HashUint16(0x0123) + }), qt.Equals, 0.0) + }) + c.Run("HashUint32", func(c *qt.C) { + h := New() + c.Assert(testing.AllocsPerRun(100, func() { + h.HashUint32(0x01234567) + }), qt.Equals, 0.0) + }) + c.Run("HashUint64", func(c *qt.C) { + h := New() + c.Assert(testing.AllocsPerRun(100, func() { + h.HashUint64(0x0123456789abcdef) + }), qt.Equals, 0.0) + }) + c.Run("HashBytes", func(c *qt.C) { + h := New() + c.Assert(testing.AllocsPerRun(100, func() { + h.HashBytes(bytes) + }), qt.Equals, 0.0) + }) + c.Run("HashString", func(c *qt.C) { + h := New() + c.Assert(testing.AllocsPerRun(100, func() { + h.HashString("abcdefghijklmnopqrstuvwxyz") + }), qt.Equals, 0.0) }) - c.Assert(n, qt.Equals, 0.0) } func Fuzz(f *testing.F) {