diff --git a/util/nocasemaps/nocase.go b/util/nocasemaps/nocase.go new file mode 100644 index 000000000..eaaaef559 --- /dev/null +++ b/util/nocasemaps/nocase.go @@ -0,0 +1,100 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// nocasemaps provides efficient functions to set and get entries in Go maps +// keyed by a string, where the string is always lower-case. +package nocasemaps + +import ( + "unicode" + "unicode/utf8" +) + +// TODO(https://github.com/golang/go/discussions/54245): +// Define a generic Map type instead. The main reason to avoid that is because +// there is currently no convenient API for iteration. +// An opaque Map type would force callers to interact with the map through +// the methods, preventing accidental interactions with the underlying map +// without using functions in this package. + +const stackArraySize = 32 + +// Get is equivalent to: +// +// v := m[strings.ToLower(k)] +func Get[K ~string, V any](m map[K]V, k K) V { + if isLowerASCII(string(k)) { + return m[k] + } + var a [stackArraySize]byte + return m[K(appendToLower(a[:0], string(k)))] +} + +// GetOk is equivalent to: +// +// v, ok := m[strings.ToLower(k)] +func GetOk[K ~string, V any](m map[K]V, k K) (V, bool) { + if isLowerASCII(string(k)) { + v, ok := m[k] + return v, ok + } + var a [stackArraySize]byte + v, ok := m[K(appendToLower(a[:0], string(k)))] + return v, ok +} + +// Set is equivalent to: +// +// m[strings.ToLower(k)] = v +func Set[K ~string, V any](m map[K]V, k K, v V) { + if isLowerASCII(string(k)) { + m[k] = v + return + } + // TODO(https://go.dev/issues/55930): This currently always allocates. + // An optimization to the compiler and runtime could make this allocate-free + // in the event that we are overwriting a map entry. + // + // Alternatively, we could use string interning. + // See an example intern data structure, see: + // https://github.com/go-json-experiment/json/blob/master/intern.go + var a [stackArraySize]byte + m[K(appendToLower(a[:0], string(k)))] = v +} + +// Delete is equivalent to: +// +// delete(m, strings.ToLower(k)) +func Delete[K ~string, V any](m map[K]V, k K) { + if isLowerASCII(string(k)) { + delete(m, k) + return + } + var a [stackArraySize]byte + delete(m, K(appendToLower(a[:0], string(k)))) +} + +func isLowerASCII(s string) bool { + for i := 0; i < len(s); i++ { + if c := s[i]; c >= utf8.RuneSelf || ('A' <= c && c <= 'Z') { + return false + } + } + return true +} + +func appendToLower(b []byte, s string) []byte { + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case 'A' <= c && c <= 'Z': + b = append(b, c+('a'-'A')) + case c < utf8.RuneSelf: + b = append(b, c) + default: + r, n := utf8.DecodeRuneInString(s[i:]) + b = utf8.AppendRune(b, unicode.ToLower(r)) + i += n - 1 // -1 to compensate for i++ in loop advancement + } + } + return b +} diff --git a/util/nocasemaps/nocase_test.go b/util/nocasemaps/nocase_test.go new file mode 100644 index 000000000..310f82889 --- /dev/null +++ b/util/nocasemaps/nocase_test.go @@ -0,0 +1,143 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package nocasemaps + +import ( + "strings" + "testing" + + qt "github.com/frankban/quicktest" + xmaps "golang.org/x/exp/maps" +) + +func pair[A, B any](a A, b B) (out struct { + A A + B B +}) { + out.A = a + out.B = b + return out +} + +func Test(t *testing.T) { + c := qt.New(t) + m := make(map[string]int) + Set(m, "hello", 1) + c.Assert(m, qt.DeepEquals, map[string]int{"hello": 1}) + Set(m, "HeLlO", 2) + c.Assert(m, qt.DeepEquals, map[string]int{"hello": 2}) + c.Assert(Get(m, "hello"), qt.Equals, 2) + c.Assert(pair(GetOk(m, "hello")), qt.Equals, pair(2, true)) + c.Assert(Get(m, "HeLlO"), qt.Equals, 2) + c.Assert(pair(GetOk(m, "HeLlO")), qt.Equals, pair(2, true)) + c.Assert(Get(m, "HELLO"), qt.Equals, 2) + c.Assert(pair(GetOk(m, "HELLO")), qt.Equals, pair(2, true)) + c.Assert(Get(m, "missing"), qt.Equals, 0) + c.Assert(pair(GetOk(m, "missing")), qt.Equals, pair(0, false)) + Set(m, "foo", 3) + Set(m, "BAR", 4) + Set(m, "bAz", 5) + c.Assert(m, qt.DeepEquals, map[string]int{"hello": 2, "foo": 3, "bar": 4, "baz": 5}) + Delete(m, "foo") + c.Assert(m, qt.DeepEquals, map[string]int{"hello": 2, "bar": 4, "baz": 5}) + Delete(m, "bar") + c.Assert(m, qt.DeepEquals, map[string]int{"hello": 2, "baz": 5}) + Delete(m, "BAZ") + c.Assert(m, qt.DeepEquals, map[string]int{"hello": 2}) +} + +var lowerTests = []struct{ in, want string }{ + {"", ""}, + {"abc", "abc"}, + {"AbC123", "abc123"}, + {"azAZ09_", "azaz09_"}, + {"longStrinGwitHmixofsmaLLandcAps", "longstringwithmixofsmallandcaps"}, + {"renan bastos 93 AOSDAJDJAIDJAIDAJIaidsjjaidijadsjiadjiOOKKO", "renan bastos 93 aosdajdjaidjaidajiaidsjjaidijadsjiadjiookko"}, + {"LONG\u2C6FSTRING\u2C6FWITH\u2C6FNONASCII\u2C6FCHARS", "long\u0250string\u0250with\u0250nonascii\u0250chars"}, + {"\u2C6D\u2C6D\u2C6D\u2C6D\u2C6D", "\u0251\u0251\u0251\u0251\u0251"}, // shrinks one byte per char + {"A\u0080\U0010FFFF", "a\u0080\U0010FFFF"}, // test utf8.RuneSelf and utf8.MaxRune +} + +func TestAppendToLower(t *testing.T) { + for _, tt := range lowerTests { + got := string(appendToLower(nil, tt.in)) + if got != tt.want { + t.Errorf("appendToLower(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func FuzzAppendToLower(f *testing.F) { + for _, tt := range lowerTests { + f.Add(tt.in) + } + f.Fuzz(func(t *testing.T, in string) { + got := string(appendToLower(nil, in)) + want := strings.ToLower(in) + if got != want { + t.Errorf("appendToLower(%q) = %q, want %q", in, got, want) + } + }) +} + +var ( + testLower = "production-server" + testUpper = "PRODUCTION-SERVER" + testMap = make(map[string]int) + testValue = 5 + testSink int +) + +func Benchmark(b *testing.B) { + for i, key := range []string{testLower, testUpper} { + b.Run([]string{"Lower", "Upper"}[i], func(b *testing.B) { + b.Run("Get", func(b *testing.B) { + b.Run("Naive", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + testSink = testMap[strings.ToLower(key)] + } + }) + b.Run("NoCase", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + testSink = Get(testMap, key) + } + }) + }) + b.Run("Set", func(b *testing.B) { + b.Run("Naive", func(b *testing.B) { + b.ReportAllocs() + testMap[strings.ToLower(key)] = testValue + for i := 0; i < b.N; i++ { + testMap[strings.ToLower(key)] = testValue + } + xmaps.Clear(testMap) + }) + b.Run("NoCase", func(b *testing.B) { + b.ReportAllocs() + Set(testMap, key, testValue) + for i := 0; i < b.N; i++ { + Set(testMap, key, testValue) + } + xmaps.Clear(testMap) + }) + }) + b.Run("Delete", func(b *testing.B) { + b.Run("Naive", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + delete(testMap, strings.ToLower(key)) + } + }) + b.Run("NoCase", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + Delete(testMap, key) + } + }) + }) + }) + } +}