From 79b7fa9ac30e4067b203dbd0f476ee7e4d347173 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 11 May 2021 13:17:12 -0700 Subject: [PATCH] internal/deephash: hash maps without sorting in the acyclic common case MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hash and xor each entry instead, then write final xor'ed result. name old time/op new time/op delta Hash-4 33.6µs ± 4% 34.6µs ± 3% +3.03% (p=0.013 n=10+9) name old alloc/op new alloc/op delta Hash-4 1.86kB ± 0% 1.77kB ± 0% -5.10% (p=0.000 n=10+9) name old allocs/op new allocs/op delta Hash-4 51.0 ± 0% 49.0 ± 0% -3.92% (p=0.000 n=10+10) Signed-off-by: Brad Fitzpatrick --- internal/deephash/deephash.go | 126 ++++++++++++++++++++++++----- internal/deephash/deephash_test.go | 53 ++++++++++++ 2 files changed, 158 insertions(+), 21 deletions(-) diff --git a/internal/deephash/deephash.go b/internal/deephash/deephash.go index a82d04ba9..c162fbdc4 100644 --- a/internal/deephash/deephash.go +++ b/internal/deephash/deephash.go @@ -10,7 +10,9 @@ import ( "bufio" "crypto/sha256" "fmt" + "hash" "reflect" + "sync" "inet.af/netaddr" "tailscale.com/tailcfg" @@ -48,9 +50,11 @@ var ( tailcfgDiscoKeyType = reflect.TypeOf(tailcfg.DiscoKey{}) ) -func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { +// print hashes v into w. +// It reports whether it was able to do so without hitting a cycle. +func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) (acyclic bool) { if !v.IsValid() { - return + return true } // Special case some common types. @@ -68,7 +72,7 @@ func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { } if err == nil { w.Write(b) - return + return true } case netaddrIPPrefix: var b []byte @@ -82,7 +86,7 @@ func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { } if err == nil { w.Write(b) - return + return true } case wgkeyKeyType: if v.CanAddr() { @@ -92,7 +96,7 @@ func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { x := v.Interface().(wgkey.Key) w.Write(x[:]) } - return + return true case wgkeyPrivateType: if v.CanAddr() { x := v.Addr().Interface().(*wgkey.Private) @@ -101,7 +105,7 @@ func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { x := v.Interface().(wgkey.Private) w.Write(x[:]) } - return + return true case tailcfgDiscoKeyType: if v.CanAddr() { x := v.Addr().Interface().(*tailcfg.DiscoKey) @@ -121,43 +125,45 @@ func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { case reflect.Ptr: ptr := v.Pointer() if visited[ptr] { - return + return false } visited[ptr] = true - print(w, v.Elem(), visited) - return + return print(w, v.Elem(), visited) case reflect.Struct: + acyclic = true w.WriteString("struct{\n") for i, n := 0, v.NumField(); i < n; i++ { fmt.Fprintf(w, " [%d]: ", i) - print(w, v.Field(i), visited) + if !print(w, v.Field(i), visited) { + acyclic = false + } w.WriteString("\n") } w.WriteString("}\n") + return acyclic case reflect.Slice, reflect.Array: if v.Type().Elem().Kind() == reflect.Uint8 && v.CanInterface() { fmt.Fprintf(w, "%q", v.Interface()) - return + return true } fmt.Fprintf(w, "[%d]{\n", v.Len()) + acyclic = true for i, ln := 0, v.Len(); i < ln; i++ { fmt.Fprintf(w, " [%d]: ", i) - print(w, v.Index(i), visited) + if !print(w, v.Index(i), visited) { + acyclic = false + } w.WriteString("\n") } w.WriteString("}\n") + return acyclic case reflect.Interface: - print(w, v.Elem(), visited) + return print(w, v.Elem(), visited) case reflect.Map: - sm := newSortedMap(v) - fmt.Fprintf(w, "map[%d]{\n", len(sm.Key)) - for i, k := range sm.Key { - print(w, k, visited) - w.WriteString(": ") - print(w, sm.Value[i], visited) - w.WriteString("\n") + if hashMapAcyclic(w, v, visited) { + return true } - w.WriteString("}\n") + return hashMapFallback(w, v, visited) case reflect.String: w.WriteString(v.String()) case reflect.Bool: @@ -171,4 +177,82 @@ func print(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) { case reflect.Complex64, reflect.Complex128: fmt.Fprintf(w, "%v", v.Complex()) } + return true +} + +type mapHasher struct { + xbuf [sha256.Size]byte // XOR'ed accumulated buffer + ebuf [sha256.Size]byte // scratch buffer + s256 hash.Hash // sha256 hash.Hash + bw *bufio.Writer // to hasher into ebuf +} + +func (mh *mapHasher) Reset() { + for i := range mh.xbuf { + mh.xbuf[i] = 0 + } +} + +func (mh *mapHasher) startEntry() { + for i := range mh.ebuf { + mh.ebuf[i] = 0 + } + mh.bw.Flush() + mh.s256.Reset() +} + +func (mh *mapHasher) endEntry() { + mh.bw.Flush() + for i, b := range mh.s256.Sum(mh.ebuf[:0]) { + mh.xbuf[i] ^= b + } +} + +var mapHasherPool = &sync.Pool{ + New: func() interface{} { + mh := new(mapHasher) + mh.s256 = sha256.New() + mh.bw = bufio.NewWriter(mh.s256) + return mh + }, +} + +// hashMapAcyclic is the faster sort-free version of map hashing. If +// it detects a cycle it returns false and guarantees that nothing was +// written to w. +func hashMapAcyclic(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) (acyclic bool) { + mh := mapHasherPool.Get().(*mapHasher) + defer mapHasherPool.Put(mh) + mh.Reset() + iter := v.MapRange() + for iter.Next() { + mh.startEntry() + if !print(mh.bw, iter.Key(), visited) { + return false + } + if !print(mh.bw, iter.Value(), visited) { + return false + } + mh.endEntry() + } + w.Write(mh.xbuf[:]) + return true +} + +func hashMapFallback(w *bufio.Writer, v reflect.Value, visited map[uintptr]bool) (acyclic bool) { + acyclic = true + sm := newSortedMap(v) + fmt.Fprintf(w, "map[%d]{\n", len(sm.Key)) + for i, k := range sm.Key { + if !print(w, k, visited) { + acyclic = false + } + w.WriteString(": ") + if !print(w, sm.Value[i], visited) { + acyclic = false + } + w.WriteString("\n") + } + w.WriteString("}\n") + return acyclic } diff --git a/internal/deephash/deephash_test.go b/internal/deephash/deephash_test.go index 3d33bcd4a..1442b1ba5 100644 --- a/internal/deephash/deephash_test.go +++ b/internal/deephash/deephash_test.go @@ -5,6 +5,10 @@ package deephash import ( + "bufio" + "bytes" + "fmt" + "reflect" "testing" "inet.af/netaddr" @@ -79,3 +83,52 @@ func BenchmarkHash(b *testing.B) { Hash(v) } } + +func TestHashMapAcyclic(t *testing.T) { + m := map[int]string{} + for i := 0; i < 100; i++ { + m[i] = fmt.Sprint(i) + } + got := map[string]bool{} + + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + + for i := 0; i < 20; i++ { + visited := map[uintptr]bool{} + v := reflect.ValueOf(m) + buf.Reset() + bw.Reset(&buf) + if !hashMapAcyclic(bw, v, visited) { + t.Fatal("returned false") + } + if got[string(buf.Bytes())] { + continue + } + got[string(buf.Bytes())] = true + } + if len(got) != 1 { + t.Errorf("got %d results; want 1", len(got)) + } +} + +func BenchmarkHashMapAcyclic(b *testing.B) { + b.ReportAllocs() + m := map[int]string{} + for i := 0; i < 100; i++ { + m[i] = fmt.Sprint(i) + } + + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + visited := map[uintptr]bool{} + v := reflect.ValueOf(m) + + for i := 0; i < b.N; i++ { + buf.Reset() + bw.Reset(&buf) + if !hashMapAcyclic(bw, v, visited) { + b.Fatal("returned false") + } + } +}