// Source: util/sha256x/sha256.go (package sha256x), added as a new file by this change.
//
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package sha256x is like crypto/sha256 with extra methods.
// It exports a concrete Hash type
// rather than only returning an interface implementation.

import (
	"crypto/sha256"
	"encoding/binary"
	"hash"
)

var _ hash.Hash = (*Hash)(nil)

// Hash is a hash.Hash for SHA-256,
// but has efficient methods for hashing fixed-width integers.
//
// The zero value is an empty hash that is ready to use.
// New returns the same thing with the underlying SHA-256 state
// already allocated.
type Hash struct {
	// The optimization is to maintain our own block and
	// only call h.Write with entire blocks.
	// This avoids double-copying of buffers within sha256.digest itself.
	// However, it does mean that sha256.digest.x goes unused,
	// which is a waste of 64B.

	h  hash.Hash              // always *sha256.digest; allocated lazily for a zero Hash
	x  [sha256.BlockSize]byte // equivalent to sha256.digest.x
	nx int                    // equivalent to sha256.digest.nx
}

// New returns a new, empty Hash.
func New() *Hash {
	return &Hash{h: sha256.New()}
}

// digest returns the underlying SHA-256 state, allocating it on first use
// so that a zero Hash behaves exactly like one returned by New.
// It is called only from slow paths (block flushes, Sum), which keeps
// the HashUintN fast paths free of the extra nil check.
func (h *Hash) digest() hash.Hash {
	if h.h == nil {
		h.h = sha256.New()
	}
	return h.h
}

// Write hashes b and reports that all of b was consumed.
// It is equivalent to HashBytes and never returns an error.
func (h *Hash) Write(b []byte) (int, error) {
	h.HashBytes(b)
	return len(b), nil
}

// Sum appends the SHA-256 checksum of everything hashed so far to b
// and returns the resulting slice.
// More data may be hashed afterwards; the result of later Sum calls
// covers the earlier data as well.
func (h *Hash) Sum(b []byte) []byte {
	d := h.digest()
	if h.nx > 0 {
		// This causes block mis-alignment. Future operations will be correct,
		// but are less efficient until Reset is called.
		d.Write(h.x[:h.nx]) // hash.Hash documents that Write never returns an error
		h.nx = 0
	}
	return d.Sum(b)
}

// Reset returns the Hash to its initial, empty state.
func (h *Hash) Reset() {
	// A zero Hash that has never flushed a block has no underlying state
	// to reset; discarding the pending bytes is enough.
	if h.h != nil {
		h.h.Reset()
	}
	h.nx = 0
}

// Size returns the length in bytes of a checksum produced by Sum.
func (h *Hash) Size() int {
	return sha256.Size
}

// BlockSize returns the block size of SHA-256 in bytes.
func (h *Hash) BlockSize() int {
	return sha256.BlockSize
}

// HashUint8 hashes n as a single byte.
func (h *Hash) HashUint8(n uint8) {
	// NOTE: This method is carefully written to be inlineable.
	if h.nx <= len(h.x)-1 {
		h.x[h.nx] = n
		h.nx += 1
	} else {
		h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget
	}
}

//go:noinline
func (h *Hash) hashUint8Slow(n uint8) { h.hashUint(uint64(n), 1) }

// HashUint16 hashes n as two bytes in little-endian order.
func (h *Hash) HashUint16(n uint16) {
	// NOTE: This method is carefully written to be inlineable.
	if h.nx <= len(h.x)-2 {
		binary.LittleEndian.PutUint16(h.x[h.nx:], n)
		h.nx += 2
	} else {
		h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget
	}
}

//go:noinline
func (h *Hash) hashUint16Slow(n uint16) { h.hashUint(uint64(n), 2) }

// HashUint32 hashes n as four bytes in little-endian order.
func (h *Hash) HashUint32(n uint32) {
	// NOTE: This method is carefully written to be inlineable.
	if h.nx <= len(h.x)-4 {
		binary.LittleEndian.PutUint32(h.x[h.nx:], n)
		h.nx += 4
	} else {
		h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget
	}
}

//go:noinline
func (h *Hash) hashUint32Slow(n uint32) { h.hashUint(uint64(n), 4) }

// HashUint64 hashes n as eight bytes in little-endian order.
func (h *Hash) HashUint64(n uint64) {
	// NOTE: This method is carefully written to be inlineable.
	if h.nx <= len(h.x)-8 {
		binary.LittleEndian.PutUint64(h.x[h.nx:], n)
		h.nx += 8
	} else {
		h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget
	}
}

//go:noinline
func (h *Hash) hashUint64Slow(n uint64) { h.hashUint(n, 8) }

// hashUint is the slow path shared by the HashUintN methods.
// It hashes the low i bytes of n in little-endian order one byte at a
// time, flushing the local block to the underlying hash whenever it is full.
func (h *Hash) hashUint(n uint64, i int) {
	for ; i > 0; i-- {
		if h.nx == len(h.x) {
			h.digest().Write(h.x[:])
			h.nx = 0
		}
		h.x[h.nx] = byte(n)
		h.nx += 1
		n >>= 8
	}
}

// HashBytes hashes the contents of b.
func (h *Hash) HashBytes(b []byte) {
	// Nearly identical to sha256.digest.Write.
	if h.nx > 0 {
		// Top up the partially filled local block first.
		n := copy(h.x[h.nx:], b)
		h.nx += n
		if h.nx == len(h.x) {
			h.digest().Write(h.x[:])
			h.nx = 0
		}
		b = b[n:]
	}
	if len(b) >= len(h.x) {
		// Feed all remaining whole blocks straight from b, skipping the copy.
		n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x)
		h.digest().Write(b[:n])
		b = b[n:]
	}
	if len(b) > 0 {
		// Keep the trailing partial block for later.
		h.nx = copy(h.x[:], b)
	}
}

// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary?
diff --git a/util/sha256x/sha256_test.go b/util/sha256x/sha256_test.go new file mode 100644 index 000000000..226ca8281 --- /dev/null +++ b/util/sha256x/sha256_test.go @@ -0,0 +1,149 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sha256x + +import ( + "crypto/sha256" + "encoding/binary" + "hash" + "math/rand" + "testing" + + qt "github.com/frankban/quicktest" +) + +// naiveHash is an obviously correct implementation of Hash. +type naiveHash struct { + hash.Hash + scratch [8]byte +} + +func newNaive() *naiveHash { return &naiveHash{Hash: sha256.New()} } +func (h *naiveHash) HashUint8(n uint8) { h.Write(append(h.scratch[:0], n)) } +func (h *naiveHash) HashUint16(n uint16) { h.Write(binary.LittleEndian.AppendUint16(h.scratch[:0], n)) } +func (h *naiveHash) HashUint32(n uint32) { h.Write(binary.LittleEndian.AppendUint32(h.scratch[:0], n)) } +func (h *naiveHash) HashUint64(n uint64) { h.Write(binary.LittleEndian.AppendUint64(h.scratch[:0], n)) } +func (h *naiveHash) HashBytes(b []byte) { h.Write(b) } + +var bytes = func() (out []byte) { + out = make([]byte, 130) + for i := range out { + out[i] = byte(i) + } + return out +}() + +type hasher interface { + HashUint8(uint8) + HashUint16(uint16) + HashUint32(uint32) + HashUint64(uint64) + HashBytes([]byte) +} + +func hashSuite(h hasher) { + for i := 0; i < 10; i++ { + for j := 0; j < 10; j++ { + h.HashUint8(0x01) + h.HashUint8(0x23) + h.HashUint32(0x456789ab) + h.HashUint8(0xcd) + h.HashUint8(0xef) + h.HashUint16(0x0123) + h.HashUint32(0x456789ab) + h.HashUint16(0xcdef) + h.HashUint8(0x01) + h.HashUint64(0x23456789abcdef01) + h.HashUint16(0x2345) + h.HashUint8(0x67) + h.HashUint16(0x89ab) + h.HashUint8(0xcd) + } + h.HashBytes(bytes[:(i+1)*13]) + } +} +func Test(t *testing.T) { + c := qt.New(t) + h1 := New() + h2 := newNaive() + hashSuite(h1) + hashSuite(h2) + c.Assert(h1.Sum(nil), 
qt.DeepEquals, h2.Sum(nil)) +} + +func Fuzz(f *testing.F) { + f.Fuzz(func(t *testing.T, seed int64) { + c := qt.New(t) + + execute := func(h hasher, r *rand.Rand) { + for i := 0; i < r.Intn(256); i++ { + switch r.Intn(5) { + case 0: + n := uint8(r.Uint64()) + h.HashUint8(n) + case 1: + n := uint16(r.Uint64()) + h.HashUint16(n) + case 2: + n := uint32(r.Uint64()) + h.HashUint32(n) + case 3: + n := uint64(r.Uint64()) + h.HashUint64(n) + case 4: + b := make([]byte, r.Intn(256)) + r.Read(b) + h.HashBytes(b) + } + } + } + + r1 := rand.New(rand.NewSource(seed)) + r2 := rand.New(rand.NewSource(seed)) + + h1 := New() + h2 := newNaive() + + execute(h1, r1) + execute(h2, r2) + + c.Assert(h1.Sum(nil), qt.DeepEquals, h2.Sum(nil)) + + execute(h1, r1) + execute(h2, r2) + + c.Assert(h1.Sum(nil), qt.DeepEquals, h2.Sum(nil)) + + h1.Reset() + h2.Reset() + + execute(h1, r1) + execute(h2, r2) + + c.Assert(h1.Sum(nil), qt.DeepEquals, h2.Sum(nil)) + }) +} + +func Benchmark(b *testing.B) { + var sum [sha256.Size]byte + b.Run("Hash", func(b *testing.B) { + b.ReportAllocs() + h := New() + for i := 0; i < b.N; i++ { + h.Reset() + hashSuite(h) + h.Sum(sum[:0]) + } + }) + b.Run("Naive", func(b *testing.B) { + b.ReportAllocs() + h := newNaive() + for i := 0; i < b.N; i++ { + h.Reset() + hashSuite(h) + h.Sum(sum[:0]) + } + }) +}