derp: throttle client sends if server advertises rate limits

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/2865/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent d43fcd2f02
commit 73f177e4d5

@ -16,6 +16,7 @@ import (
"time"
"golang.org/x/crypto/nacl/box"
"golang.org/x/time/rate"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
@ -32,8 +33,9 @@ type Client struct {
canAckPings bool
isProber bool
wmu sync.Mutex // hold while writing to bw
bw *bufio.Writer
wmu sync.Mutex // hold while writing to bw
bw *bufio.Writer
rate *rate.Limiter // if non-nil, rate limiter to use
// Owned by Recv:
peeked int // bytes to discard on next Recv
@ -217,7 +219,12 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
c.wmu.Lock()
defer c.wmu.Unlock()
if c.rate != nil {
pktLen := frameHeaderLen + len(dstKey) + len(pkt)
if !c.rate.AllowN(time.Now(), pktLen) {
return nil // drop
}
}
if err := writeFrameHeader(c.bw, frameSendPacket, uint32(len(dstKey)+len(pkt))); err != nil {
return err
}
@ -353,7 +360,22 @@ type PeerPresentMessage key.Public
func (PeerPresentMessage) msg() {}
// ServerInfoMessage is sent by the server upon first connect.
type ServerInfoMessage struct{}
type ServerInfoMessage struct {
// TokenBucketBytesPerSecond is how many bytes per second the
// server says it will accept, including all framing bytes.
//
// Zero means unspecified. There might be a limit, but the
// client need not try to respect it.
TokenBucketBytesPerSecond int
// TokenBucketBytesBurst is how many bytes the server will
// allow to burst, temporarily violating
// TokenBucketBytesPerSecond.
//
// Zero means unspecified. There might be a limit, but the
// client need not try to respect it.
TokenBucketBytesBurst int
}
func (ServerInfoMessage) msg() {}
@ -475,12 +497,16 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
// needing to wait an RTT to discover the version at startup.
// We'd prefer to give the connection to the client (magicsock)
// to start writing as soon as possible.
_, err := c.parseServerInfo(b)
si, err := c.parseServerInfo(b)
if err != nil {
return nil, fmt.Errorf("invalid server info frame: %v", err)
}
// TODO: add the results of parseServerInfo to ServerInfoMessage if we ever need it.
return ServerInfoMessage{}, nil
sm := ServerInfoMessage{
TokenBucketBytesPerSecond: si.TokenBucketBytesPerSecond,
TokenBucketBytesBurst: si.TokenBucketBytesBurst,
}
c.setSendRateLimiter(sm)
return sm, nil
case frameKeepAlive:
// A one-way keep-alive message that doesn't require an acknowledgement.
// This predated framePing/framePong.
@ -537,3 +563,16 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
}
}
}
func (c *Client) setSendRateLimiter(sm ServerInfoMessage) {
c.wmu.Lock()
defer c.wmu.Unlock()
if sm.TokenBucketBytesPerSecond == 0 {
c.rate = nil
} else {
c.rate = rate.NewLimiter(
rate.Limit(sm.TokenBucketBytesPerSecond),
sm.TokenBucketBytesBurst)
}
}

@ -1079,6 +1079,9 @@ func (s *Server) noteClientActivity(c *sclient) {
type serverInfo struct {
Version int `json:"version,omitempty"`
TokenBucketBytesPerSecond int `json:",omitempty"`
TokenBucketBytesBurst int `json:",omitempty"`
}
func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.Public) error {

@ -1244,3 +1244,80 @@ func TestParseSSOutput(t *testing.T) {
t.Errorf("parseSSOutput expected non-empty map")
}
}
type countWriter struct {
mu sync.Mutex
writes int
bytes int64
}
func (w *countWriter) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
w.writes++
w.bytes += int64(len(p))
return len(p), nil
}
func (w *countWriter) Stats() (writes int, bytes int64) {
w.mu.Lock()
defer w.mu.Unlock()
return w.writes, w.bytes
}
func (w *countWriter) ResetStats() {
w.mu.Lock()
defer w.mu.Unlock()
w.writes, w.bytes = 0, 0
}
func TestClientSendRateLimiting(t *testing.T) {
cw := new(countWriter)
c := &Client{
bw: bufio.NewWriter(cw),
}
c.setSendRateLimiter(ServerInfoMessage{})
pkt := make([]byte, 1000)
if err := c.send(key.Public{}, pkt); err != nil {
t.Fatal(err)
}
writes1, bytes1 := cw.Stats()
if writes1 != 1 {
t.Errorf("writes = %v, want 1", writes1)
}
// Flood should all succeed.
cw.ResetStats()
for i := 0; i < 1000; i++ {
if err := c.send(key.Public{}, pkt); err != nil {
t.Fatal(err)
}
}
writes1K, bytes1K := cw.Stats()
if writes1K != 1000 {
t.Logf("writes = %v; want 1000", writes1K)
}
if got, want := bytes1K, bytes1*1000; got != want {
t.Logf("bytes = %v; want %v", got, want)
}
// Set a rate limiter
cw.ResetStats()
c.setSendRateLimiter(ServerInfoMessage{
TokenBucketBytesPerSecond: 1,
TokenBucketBytesBurst: int(bytes1 * 2),
})
for i := 0; i < 1000; i++ {
if err := c.send(key.Public{}, pkt); err != nil {
t.Fatal(err)
}
}
writesLimited, bytesLimited := cw.Stats()
if writesLimited == 0 || writesLimited == writes1K {
t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited)
}
if bytesLimited < bytes1*2 || bytesLimited >= bytes1K {
t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K)
}
}

Loading…
Cancel
Save