diff --git a/derp/derp_server.go b/derp/derp_server.go index 70957a138..42855e0b3 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -36,6 +36,7 @@ import ( "go4.org/mem" "golang.org/x/crypto/nacl/box" "golang.org/x/sync/errgroup" + "golang.org/x/time/rate" "inet.af/netaddr" "tailscale.com/client/tailscale" "tailscale.com/disco" @@ -118,6 +119,8 @@ type Server struct { curClients expvar.Int curHomeClients expvar.Int // ones with preferred clientsReplaced expvar.Int + clientsReplaceLimited expvar.Int + clientsReplaceSleeping expvar.Int unknownFrames expvar.Int homeMovesIn expvar.Int // established clients announce home server moves in homeMovesOut expvar.Int // established clients announce home server moves out @@ -346,14 +349,28 @@ func (s *Server) initMetacert() { func (s *Server) MetaCert() []byte { return s.metaCert } // registerClient notes that client c is now authenticated and ready for packets. -// If c's public key was already connected with a different connection, the prior one is closed. -func (s *Server) registerClient(c *sclient) { +// +// If c's public key was already connected with a different +// connection, the prior one is closed, unless it's fighting rapidly +// with another client with the same key, in which case the returned +// ok is false, and the caller should wait the provided duration +// before trying again. +func (s *Server) registerClient(c *sclient) (ok bool, d time.Duration) { s.mu.Lock() defer s.mu.Unlock() old := s.clients[c.key] if old == nil { c.logf("adding connection") } else { + // Take over the old rate limiter, discarding the one + // our caller just made. + c.replaceLimiter = old.replaceLimiter + if rr := c.replaceLimiter.ReserveN(timeNow(), 1); rr.OK() { + if d := rr.DelayFrom(timeNow()); d > 0 { + s.clientsReplaceLimited.Add(1) + return false, d + } + } s.clientsReplaced.Add(1) c.logf("adding connection, replacing %s", old.remoteAddr) go old.nc.Close() @@ -365,6 +382,7 @@ func (s *Server) registerClient(c *sclient) { s.keyOfAddr[c.remoteIPPort] = c.key s.curClients.Add(1) s.broadcastPeerStateChangeLocked(c.key, true) + return true, 0 } // broadcastPeerStateChangeLocked enqueues a message to all watchers @@ -490,7 +508,14 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connN discoSendQueue: make(chan pkt, perClientSendQueueDepth), peerGone: make(chan key.Public), canMesh: clientInfo.MeshKey != "" && clientInfo.MeshKey == s.meshKey, + + // Allow kicking out previous connections once a + // minute, with a very high burst of 100. Once a + // minute is less than the client's 2 minute + // inactivity timeout. + replaceLimiter: rate.NewLimiter(rate.Every(time.Minute), 100), } + if c.canMesh { c.meshUpdate = make(chan struct{}) } @@ -498,7 +523,15 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connN c.info = *clientInfo } - s.registerClient(c) + for { + ok, d := s.registerClient(c) + if ok { + break + } + s.clientsReplaceSleeping.Add(1) + timeSleep(d) + s.clientsReplaceSleeping.Add(-1) + } defer s.unregisterClient(c) err = s.sendServerInfo(bw, clientKey) @@ -509,6 +542,12 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connN return c.run(ctx) } +// for testing +var ( + timeSleep = time.Sleep + timeNow = time.Now +) + // run serves the client until there's an error. // If the client hangs up or the server is closed, run returns nil, otherwise run returns an error. func (c *sclient) run(ctx context.Context) error { @@ -952,6 +991,11 @@ type sclient struct { meshUpdate chan struct{} // write request to write peerStateChange canMesh bool // clientInfo had correct mesh token for inter-region routing + // replaceLimiter controls how quickly two connections with + // the same client key can kick each other off the server by + // taking over ownership of a key. + replaceLimiter *rate.Limiter + // Owned by run, not thread-safe. br *bufio.Reader connectedAt time.Time @@ -1351,6 +1395,8 @@ func (s *Server) ExpVar() expvar.Var { m.Set("gauge_clients_remote", expvar.Func(func() interface{} { return len(s.clientsMesh) - len(s.clients) })) m.Set("accepts", &s.accepts) m.Set("clients_replaced", &s.clientsReplaced) + m.Set("clients_replace_limited", &s.clientsReplaceLimited) + m.Set("gauge_clients_replace_sleeping", &s.clientsReplaceSleeping) m.Set("bytes_received", &s.bytesRecv) m.Set("bytes_sent", &s.bytesSent) m.Set("packets_dropped", &s.packetsDropped) diff --git a/derp/derp_test.go b/derp/derp_test.go index a37dd0390..d8dfa714b 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "golang.org/x/time/rate" "tailscale.com/net/nettest" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -849,6 +850,136 @@ func TestClientSendPong(t *testing.T) { } +func TestServerReplaceClients(t *testing.T) { + defer func() { + timeSleep = time.Sleep + timeNow = time.Now + }() + + var ( + mu sync.Mutex + now = time.Unix(123, 0) + sleeps int + slept time.Duration + ) + timeSleep = func(d time.Duration) { + mu.Lock() + defer mu.Unlock() + sleeps++ + slept += d + now = now.Add(d) + } + timeNow = func() time.Time { + mu.Lock() + defer mu.Unlock() + return now + } + + serverPrivateKey := newPrivateKey(t) + var logger logger.Logf = logger.Discard + const debug = false + if debug { + logger = t.Logf + } + + s := NewServer(serverPrivateKey, logger) + defer s.Close() + + priv := newPrivateKey(t) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + connNum := 0 + connect := func() *Client { + connNum++ + cout, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + cin, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + + brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin)) + go s.Accept(cin, brwServer, fmt.Sprintf("test-client-%d", connNum)) + + brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout)) + c, err := NewClient(priv, cout, brw, logger) + if err != nil { + t.Fatalf("client %d: %v", connNum, err) + } + return c + } + + wantVar := func(v *expvar.Int, want int64) { + t.Helper() + if got := v.Value(); got != want { + t.Errorf("got %d; want %d", got, want) + } + } + + wantClosed := func(c *Client) { + t.Helper() + for { + m, err := c.Recv() + if err != nil { + t.Logf("got expected error: %v", err) + return + } + switch m.(type) { + case ServerInfoMessage: + continue + default: + t.Fatalf("client got %T; wanted an error", m) + } + } + } + + c1 := connect() + waitConnect(t, c1) + c2 := connect() + waitConnect(t, c2) + wantVar(&s.clientsReplaced, 1) + wantClosed(c1) + + for i := 0; i < 100+5; i++ { + c := connect() + defer c.nc.Close() + if s.clientsReplaceLimited.Value() == 0 && i < 90 { + continue + } + t.Logf("for %d: replaced=%d, limited=%d, sleeping=%d", i, + s.clientsReplaced.Value(), + s.clientsReplaceLimited.Value(), + s.clientsReplaceSleeping.Value(), + ) + } + + mu.Lock() + defer mu.Unlock() + if sleeps == 0 { + t.Errorf("no sleeps") + } + if slept == 0 { + t.Errorf("total sleep duration was 0") + } +} + +func TestLimiter(t *testing.T) { + rl := rate.NewLimiter(rate.Every(time.Minute), 100) + for i := 0; i < 200; i++ { + r := rl.Reserve() + d := r.Delay() + t.Logf("i=%d, allow=%v, d=%v", i, r.OK(), d) + } +} + func BenchmarkSendRecv(b *testing.B) { for _, size := range []int{10, 100, 1000, 10000} { b.Run(fmt.Sprintf("msgsize=%d", size), func(b *testing.B) { benchmarkSendRecvSize(b, size) })