mirror of https://github.com/tailscale/tailscale/
Replace our ratelimiter with standard rate package (#359)
* Replace our ratelimiter with standard rate package Signed-off-by: Wendi Yu <wendi.yu@yahoo.ca>pull/364/head
parent
b01db109f5
commit
499c8fcbb3
@ -1,81 +0,0 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/structs"
|
||||
)
|
||||
|
||||
type Bucket struct {
|
||||
_ structs.Incomparable
|
||||
mu sync.Mutex
|
||||
FillInterval time.Duration
|
||||
Burst int
|
||||
v int
|
||||
quitCh chan struct{}
|
||||
started bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (b *Bucket) startLocked() {
|
||||
b.v = b.Burst
|
||||
b.quitCh = make(chan struct{})
|
||||
b.started = true
|
||||
|
||||
t := time.NewTicker(b.FillInterval)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-b.quitCh:
|
||||
return
|
||||
case <-t.C:
|
||||
b.tick()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *Bucket) tick() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.v < b.Burst {
|
||||
b.v++
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bucket) Close() {
|
||||
b.mu.Lock()
|
||||
if !b.started {
|
||||
b.closed = true
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if b.closed {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
b.closed = true
|
||||
b.mu.Unlock()
|
||||
|
||||
b.quitCh <- struct{}{}
|
||||
}
|
||||
|
||||
func (b *Bucket) TryGet() int {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if !b.started {
|
||||
b.startLocked()
|
||||
}
|
||||
if b.v > 0 {
|
||||
b.v--
|
||||
return b.v + 1
|
||||
}
|
||||
return 0
|
||||
}
|
@ -1,28 +0,0 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBucket(t *testing.T) {
|
||||
b := Bucket{
|
||||
FillInterval: time.Second,
|
||||
Burst: 3,
|
||||
}
|
||||
expect := []int{3, 2, 1, 0, 0}
|
||||
for i, want := range expect {
|
||||
got := b.TryGet()
|
||||
if want != got {
|
||||
t.Errorf("#%d want=%d got=%d\n", i, want, got)
|
||||
}
|
||||
}
|
||||
b.tick()
|
||||
if want, got := 1, b.TryGet(); want != got {
|
||||
t.Errorf("after tick: want=%d got=%d\n", want, got)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue