diff --git a/util/limiter/limiter.go b/util/limiter/limiter.go
new file mode 100644
index 000000000..8896f8604
--- /dev/null
+++ b/util/limiter/limiter.go
@@ -0,0 +1,149 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package limiter
+
+import (
+	"sync"
+	"time"
+
+	"tailscale.com/util/lru"
+)
+
+// Limiter is a keyed token bucket rate limiter.
+//
+// Each key gets its own separate token bucket to pull from, enabling
+// enforcement on things like "requests per IP address". To avoid
+// unbounded memory growth, Limiter actually only tracks limits
+// precisely for the N most recently seen keys, and assumes that
+// untracked keys are well-behaved. This trades off absolute precision
+// for bounded memory use, while still enforcing well for outlier
+// keys.
+//
+// As such, Limiter should only be used in situations where "rough"
+// enforcement of outliers only is sufficient, such as throttling
+// egregious outlier keys (e.g. something sending 100 queries per
+// second, where everyone else is sending at most 5).
+//
+// Each key's token bucket behaves like a regular token bucket, with
+// the added feature that a bucket's token count can optionally go
+// negative. This implements a form of "cooldown" for keys that exceed
+// the rate limit: once a key starts getting denied, it must stop
+// requesting tokens long enough for the bucket to return to a
+// positive balance. If the key keeps hammering the limiter in excess
+// of the rate limit, the token count will remain negative, and the
+// key will not be allowed to proceed at all. This is in contrast to
+// the classic token bucket, where a key trying to use more than the
+// rate limit will get capped at the limit, but can still occasionally
+// consume a token as one becomes available.
+//
+// The zero value is a valid limiter that rejects all requests. A
+// useful limiter must specify a Size, Max and RefillInterval.
+type Limiter[K comparable] struct {
+	// Size is the number of keys to track. Only the Size most
+	// recently seen keys have their limits enforced precisely, older
+	// keys are assumed to not be querying frequently enough to bother
+	// tracking.
+	Size int
+
+	// Max is the number of tokens available for a key to consume
+	// before time-based rate limiting kicks in. An unused limiter
+	// regains available tokens over time, up to Max tokens. A newly
+	// tracked key initially receives Max tokens.
+	Max int64
+
+	// RefillInterval is the interval at which a key regains tokens for
+	// use, up to Max tokens.
+	RefillInterval time.Duration
+
+	// Overdraft is the amount of additional tokens a key can be
+	// charged for when it exceeds its rate limit. Each additional
+	// request issued for the key charges one unit of overdraft, up to
+	// this limit. Overdraft tokens are refilled at the normal rate,
+	// and must be fully repaid before any tokens become available for
+	// requests.
+	//
+	// A non-zero Overdraft results in "cooldown" behavior: with a
+	// normal token bucket that bottoms out at zero tokens, an abusive
+	// key can still consume one token every RefillInterval. With a
+	// non-zero overdraft, a throttled key must stop requesting tokens
+	// entirely for a cooldown period, otherwise they remain
+	// perpetually in debt and cannot proceed at all.
+	Overdraft int64
+
+	mu    sync.Mutex
+	cache *lru.Cache[K, *bucket]
+}
+
+// QPSInterval returns the interval between events corresponding to
+// the given queries/second rate.
+//
+// This is a helper to be used when populating Limiter.RefillInterval.
+func QPSInterval(qps float64) time.Duration {
+	return time.Duration(float64(time.Second) / qps)
+}
+
+type bucket struct {
+	cur        int64     // current available tokens
+	lastUpdate time.Time // last timestamp at which cur was updated
+}
+
+// Allow charges the key one token (up to the overdraft limit), and
+// reports whether the key can perform an action.
+func (l *Limiter[K]) Allow(key K) bool {
+	return l.allow(key, time.Now())
+}
+
+func (l *Limiter[K]) allow(key K, now time.Time) bool {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	return l.allowBucketLocked(l.getBucketLocked(key, now), now)
+}
+
+func (l *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket {
+	if l.cache == nil {
+		l.cache = &lru.Cache[K, *bucket]{MaxEntries: l.Size}
+	} else if b := l.cache.Get(key); b != nil {
+		return b
+	}
+	b := &bucket{
+		cur:        l.Max,
+		lastUpdate: now.Truncate(l.RefillInterval),
+	}
+	l.cache.Set(key, b)
+	return b
+}
+
+func (l *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool {
+	// Only update the bucket quota if needed to process request.
+	if b.cur <= 0 {
+		l.updateBucketLocked(b, now)
+	}
+	ret := b.cur > 0
+	if b.cur > -l.Overdraft {
+		b.cur--
+	}
+	return ret
+}
+
+func (l *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) {
+	now = now.Truncate(l.RefillInterval)
+	if now.Before(b.lastUpdate) {
+		return
+	}
+	timeDelta := max(now.Sub(b.lastUpdate), 0)
+	tokenDelta := int64(timeDelta / l.RefillInterval)
+	b.cur = min(b.cur+tokenDelta, l.Max)
+	b.lastUpdate = now
+}
+
+// tokensForTest returns the number of tokens for key, also reporting
+// whether key was present.
+func (l *Limiter[K]) tokensForTest(key K) (int64, bool) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	if b, ok := l.cache.PeekOk(key); ok {
+		return b.cur, true
+	}
+	return 0, false
+}
diff --git a/util/limiter/limiter_test.go b/util/limiter/limiter_test.go
new file mode 100644
index 000000000..fdf0d6b7d
--- /dev/null
+++ b/util/limiter/limiter_test.go
@@ -0,0 +1,151 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package limiter
+
+import (
+	"testing"
+	"time"
+)
+
+const testRefillInterval = time.Second
+
+func TestLimiter(t *testing.T) {
+	// 1qps, burst of 10, 2 keys tracked
+	l := &Limiter[string]{
+		Size:           2,
+		Max:            10,
+		RefillInterval: testRefillInterval,
+	}
+
+	// Consume entire burst
+	now := time.Now().Truncate(testRefillInterval)
+	allowed(t, l, "foo", 10, now)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", 0)
+
+	allowed(t, l, "bar", 10, now)
+	denied(t, l, "bar", 1, now)
+	hasTokens(t, l, "bar", 0)
+
+	// Refill 1 token for both foo and bar
+	now = now.Add(time.Second + time.Millisecond)
+	allowed(t, l, "foo", 1, now)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", 0)
+
+	allowed(t, l, "bar", 1, now)
+	denied(t, l, "bar", 1, now)
+	hasTokens(t, l, "bar", 0)
+
+	// Refill 2 tokens for foo and bar
+	now = now.Add(2*time.Second + time.Millisecond)
+	allowed(t, l, "foo", 2, now)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", 0)
+
+	allowed(t, l, "bar", 2, now)
+	denied(t, l, "bar", 1, now)
+	hasTokens(t, l, "bar", 0)
+
+	// qux can burst 10, evicts foo so it can immediately burst 10 again too
+	allowed(t, l, "qux", 10, now)
+	denied(t, l, "qux", 1, now)
+	notInLimiter(t, l, "foo")
+	denied(t, l, "bar", 1, now) // refresh bar so foo lookup doesn't evict it - still throttled
+
+	allowed(t, l, "foo", 10, now)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", 0)
+}
+
+func TestLimiterOverdraft(t *testing.T) {
+	// 1qps, burst of 10, overdraft of 2, 2 keys tracked
+	l := &Limiter[string]{
+		Size:           2,
+		Max:            10,
+		Overdraft:      2,
+		RefillInterval: testRefillInterval,
+	}
+
+	// Consume entire burst, go 1 into debt
+	now := time.Now().Truncate(testRefillInterval).Add(time.Millisecond)
+	allowed(t, l, "foo", 10, now)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", -1)
+
+	allowed(t, l, "bar", 10, now)
+	denied(t, l, "bar", 1, now)
+	hasTokens(t, l, "bar", -1)
+
+	// Refill 1 token for both foo and bar.
+	// Still denied, still in debt.
+	now = now.Add(time.Second)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", -1)
+	denied(t, l, "bar", 1, now)
+	hasTokens(t, l, "bar", -1)
+
+	// Refill 2 tokens for foo and bar (1 available after debt), try
+	// to consume 4. Overdraft is capped to 2.
+	now = now.Add(2 * time.Second)
+	allowed(t, l, "foo", 1, now)
+	denied(t, l, "foo", 3, now)
+	hasTokens(t, l, "foo", -2)
+
+	allowed(t, l, "bar", 1, now)
+	denied(t, l, "bar", 3, now)
+	hasTokens(t, l, "bar", -2)
+
+	// Refill 1, not enough to allow.
+	now = now.Add(time.Second)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", -2)
+	denied(t, l, "bar", 1, now)
+	hasTokens(t, l, "bar", -2)
+
+	// qux evicts foo, foo can immediately burst 10 again.
+	allowed(t, l, "qux", 1, now)
+	hasTokens(t, l, "qux", 9)
+	notInLimiter(t, l, "foo")
+	allowed(t, l, "foo", 10, now)
+	denied(t, l, "foo", 1, now)
+	hasTokens(t, l, "foo", -1)
+}
+
+func allowed(t *testing.T, l *Limiter[string], key string, count int, now time.Time) {
+	t.Helper()
+	for i := 0; i < count; i++ {
+		if !l.allow(key, now) {
+			toks, ok := l.tokensForTest(key)
+			t.Errorf("after %d times: allow(%q, %q) = false, want true (%d tokens available, in cache = %v)", i, key, now, toks, ok)
+		}
+	}
+}
+
+func denied(t *testing.T, l *Limiter[string], key string, count int, now time.Time) {
+	t.Helper()
+	for i := 0; i < count; i++ {
+		if l.allow(key, now) {
+			toks, ok := l.tokensForTest(key)
+			t.Errorf("after %d times: allow(%q, %q) = true, want false (%d tokens available, in cache = %v)", i, key, now, toks, ok)
+		}
+	}
+}
+
+func hasTokens(t *testing.T, l *Limiter[string], key string, want int64) {
+	t.Helper()
+	got, ok := l.tokensForTest(key)
+	if !ok {
+		t.Errorf("key %q missing from limiter", key)
+	} else if got != want {
+		t.Errorf("key %q has %d tokens, want %d", key, got, want)
+	}
+}
+
+func notInLimiter(t *testing.T, l *Limiter[string], key string) {
+	t.Helper()
+	if tokens, ok := l.tokensForTest(key); ok {
+		t.Errorf("key %q unexpectedly tracked by limiter, with %d tokens", key, tokens)
+	}
+}
diff --git a/util/lru/lru.go b/util/lru/lru.go
index 639e7c91b..a6790cd46 100644
--- a/util/lru/lru.go
+++ b/util/lru/lru.go
@@ -84,6 +84,20 @@ func (c *Cache[K, V]) GetOk(key K) (value V, ok bool) {
 	return zero, false
 }
 
+// PeekOk looks up the key's value from the cache, also reporting
+// whether it was present.
+//
+// Unlike GetOk, PeekOk does not move key to the front of the
+// LRU. This should mostly be used for non-intrusive debug inspection
+// of the cache.
+func (c *Cache[K, V]) PeekOk(key K) (value V, ok bool) {
+	if ele, hit := c.m[key]; hit {
+		return ele.Value.(*entry[K, V]).value, true
+	}
+	var zero V
+	return zero, false
+}
+
 // Delete removes the provided key from the cache if it was present.
 func (c *Cache[K, V]) Delete(key K) {
 	if e, ok := c.m[key]; ok {
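Usage sketch (not part of the diff above): the snippet below shows one way the new limiter might be wired into an HTTP server to throttle abusive client IPs. The handler, the RemoteAddr-based keying, and the tuning values (Size, Max, Overdraft, 5 QPS) are illustrative assumptions, not APIs or defaults introduced by this change.

// Hypothetical example program, assuming the limiter package lands as
// tailscale.com/util/limiter with the API defined in the diff above.
package main

import (
	"log"
	"net"
	"net/http"

	"tailscale.com/util/limiter"
)

func main() {
	// Allow a burst of 10 requests per client, refilling at 5 requests/sec,
	// tracking only the 1000 most recently seen clients. The illustrative
	// Overdraft of 10 means a hammering client must back off entirely before
	// it can proceed again.
	lim := &limiter.Limiter[string]{
		Size:           1000,
		Max:            10,
		RefillInterval: limiter.QPSInterval(5),
		Overdraft:      10,
	}

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// Key the bucket by client IP. A real deployment might instead key
		// by a trusted proxy header or an authenticated identity.
		host, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			host = r.RemoteAddr
		}
		if !lim.Allow(host) {
			http.Error(w, "too many requests", http.StatusTooManyRequests)
			return
		}
		w.Write([]byte("ok\n"))
	})
	log.Fatal(http.ListenAndServe(":8080", nil))
}

Because Overdraft is non-zero in this sketch, a client that keeps hammering the endpoint stays in token debt and is denied outright until it stops, rather than sneaking one request through per refill interval as with a classic token bucket.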