// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause // Package limiter provides a keyed token bucket rate limiter. package limiter import ( "fmt" "html" "io" "sync" "time" "tailscale.com/util/lru" ) // Limiter is a keyed token bucket rate limiter. // // Each key gets its own separate token bucket to pull from, enabling // enforcement on things like "requests per IP address". To avoid // unbounded memory growth, Limiter actually only tracks limits // precisely for the N most recently seen keys, and assumes that // untracked keys are well-behaved. This trades off absolute precision // for bounded memory use, while still enforcing well for outlier // keys. // // As such, Limiter should only be used in situations where "rough" // enforcement of outliers only is sufficient, such as throttling // egregious outlier keys (e.g. something sending 100 queries per // second, where everyone else is sending at most 5). // // Each key's token bucket behaves like a regular token bucket, with // the added feature that a bucket's token count can optionally go // negative. This implements a form of "cooldown" for keys that exceed // the rate limit: once a key starts getting denied, it must stop // requesting tokens long enough for the bucket to return to a // positive balance. If the key keeps hammering the limiter in excess // of the rate limit, the token count will remain negative, and the // key will not be allowed to proceed at all. This is in contrast to // the classic token bucket, where a key trying to use more than the // rate limit will get capped at the limit, but can still occasionally // consume a token as one becomes available. // // The zero value is a valid limiter that rejects all requests. A // useful limiter must specify a Size, Max and RefillInterval. type Limiter[K comparable] struct { // Size is the number of keys to track. Only the Size most // recently seen keys have their limits enforced precisely, older // keys are assumed to not be querying frequently enough to bother // tracking. Size int // Max is the number of tokens available for a key to consume // before time-based rate limiting kicks in. An unused limiter // regains available tokens over time, up to Max tokens. A newly // tracked key initially receives Max tokens. Max int64 // RefillInterval is the interval at which a key regains tokens for // use, up to Max tokens. RefillInterval time.Duration // Overdraft is the amount of additional tokens a key can be // charged for when it exceeds its rate limit. Each additional // request issued for the key charges one unit of overdraft, up to // this limit. Overdraft tokens are refilled at the normal rate, // and must be fully repaid before any tokens become available for // requests. // // A non-zero Overdraft results in "cooldown" behavior: with a // normal token bucket that bottoms out at zero tokens, an abusive // key can still consume one token every RefillInterval. With a // non-zero overdraft, a throttled key must stop requesting tokens // entirely for a cooldown period, otherwise they remain // perpetually in debt and cannot proceed at all. Overdraft int64 mu sync.Mutex cache *lru.Cache[K, *bucket] } // QPSInterval returns the interval between events corresponding to // the given queries/second rate. // // This is a helper to be used when populating Limiter.RefillInterval. func QPSInterval(qps float64) time.Duration { return time.Duration(float64(time.Second) / qps) } type bucket struct { cur int64 // current available tokens lastUpdate time.Time // last timestamp at which cur was updated } // Allow charges the key one token (up to the overdraft limit), and // reports whether the key can perform an action. func (l *Limiter[K]) Allow(key K) bool { return l.allow(key, time.Now()) } func (l *Limiter[K]) allow(key K, now time.Time) bool { l.mu.Lock() defer l.mu.Unlock() return l.allowBucketLocked(l.getBucketLocked(key, now), now) } func (l *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket { if l.cache == nil { l.cache = &lru.Cache[K, *bucket]{MaxEntries: l.Size} } else if b := l.cache.Get(key); b != nil { return b } b := &bucket{ cur: l.Max, lastUpdate: now.Truncate(l.RefillInterval), } l.cache.Set(key, b) return b } func (l *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool { // Only update the bucket quota if needed to process request. if b.cur <= 0 { l.updateBucketLocked(b, now) } ret := b.cur > 0 if b.cur > -l.Overdraft { b.cur-- } return ret } func (l *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) { now = now.Truncate(l.RefillInterval) if now.Before(b.lastUpdate) { return } timeDelta := max(now.Sub(b.lastUpdate), 0) tokenDelta := int64(timeDelta / l.RefillInterval) b.cur = min(b.cur+tokenDelta, l.Max) b.lastUpdate = now } // peekForTest returns the number of tokens for key, also reporting // whether key was present. func (l *Limiter[K]) tokensForTest(key K) (int64, bool) { l.mu.Lock() defer l.mu.Unlock() if b, ok := l.cache.PeekOk(key); ok { return b.cur, true } return 0, false } // DumpHTML writes the state of the limiter to the given writer, // formatted as an HTML table. If onlyLimited is true, the output only // lists keys that are currently being limited. // // DumpHTML blocks other callers of the limiter while it collects the // state for dumping. It should not be called on large limiters // involved in hot codepaths. func (l *Limiter[K]) DumpHTML(w io.Writer, onlyLimited bool) { l.dumpHTML(w, onlyLimited, time.Now()) } func (l *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) { dump := l.collectDump(now) io.WriteString(w, "") for _, line := range dump { if onlyLimited && line.Tokens > 0 { continue } kStr := html.EscapeString(fmt.Sprint(line.Key)) format := "" if !onlyLimited && line.Tokens <= 0 { // Make limited entries stand out when showing // limited+non-limited together format = "" } fmt.Fprintf(w, format, kStr, line.Tokens) } io.WriteString(w, "
KeyTokens
%s%d
%s%d
") } // collectDump grabs a copy of the limiter state needed by DumpHTML. func (l *Limiter[K]) collectDump(now time.Time) []dumpEntry[K] { l.mu.Lock() defer l.mu.Unlock() ret := make([]dumpEntry[K], 0, l.cache.Len()) l.cache.ForEach(func(k K, v *bucket) { l.updateBucketLocked(v, now) // so stats are accurate ret = append(ret, dumpEntry[K]{k, v.cur}) }) return ret } // dumpEntry is the per-key information that DumpHTML needs to print // limiter state. type dumpEntry[K comparable] struct { Key K Tokens int64 }