diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 8a567d63a..a2bb59606 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -16,6 +16,7 @@ import ( "io" "io/ioutil" "log" + "math" "net" "net/http" "os" @@ -24,6 +25,7 @@ import ( "strings" "time" + "golang.org/x/time/rate" "tailscale.com/atomicfile" "tailscale.com/derp" "tailscale.com/derp/derphttp" @@ -49,6 +51,9 @@ var ( meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") + + acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") + acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") ) var ( @@ -296,7 +301,7 @@ func main() { } }() } - err = httpsrv.ListenAndServeTLS("", "") + err = rateLimitedListenAndServeTLS(httpsrv) } else { log.Printf("derper: serving on %s", *addr) err = httpsrv.ListenAndServe() @@ -390,3 +395,63 @@ func defaultMeshPSKFile() string { } return "" } + +func rateLimitedListenAndServeTLS(srv *http.Server) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + rln := newRateLimitedListener(ln, rate.Limit(*acceptConnLimit), *acceptConnBurst) + expvar.Publish("tls_listener", rln.ExpVar()) + defer rln.Close() + return srv.ServeTLS(rln, "", "") +} + +type rateLimitedListener struct { + // These are at the start of the struct to ensure 64-bit alignment + // on 32-bit architecture regardless of what other fields may exist + // in this package. + numAccepts expvar.Int // does not include number of rejects + numRejects expvar.Int + + net.Listener + + lim *rate.Limiter +} + +func newRateLimitedListener(ln net.Listener, limit rate.Limit, burst int) *rateLimitedListener { + return &rateLimitedListener{Listener: ln, lim: rate.NewLimiter(limit, burst)} +} + +func (l *rateLimitedListener) ExpVar() expvar.Var { + m := new(metrics.Set) + m.Set("counter_accepted_connections", &l.numAccepts) + m.Set("counter_rejected_connections", &l.numRejects) + return m +} + +var errLimitedConn = errors.New("cannot accept connection; rate limited") + +func (l *rateLimitedListener) Accept() (net.Conn, error) { + // Even under a rate limited situation, we accept the connection immediately + // and close it, rather than being slow at accepting new connections. + // This provides two benefits: 1) it signals to the client that something + // is going on on the server, and 2) it prevents new connections from + // piling up and occupying resources in the OS kernel. + // The client will retry as needing (with backoffs in place). + cn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + if !l.lim.Allow() { + l.numRejects.Add(1) + cn.Close() + return nil, errLimitedConn + } + l.numAccepts.Add(1) + return cn, nil +}