From d9df023e6f4063f3ed3d48ce82245dc237b3a444 Mon Sep 17 00:00:00 2001
From: Joe Tsai
Date: Fri, 16 Dec 2022 10:14:00 -0800
Subject: [PATCH] net/connstats: enforce maximum number of connections (#6760)

The Tailscale logging service has a hard limit on the maximum
log message size that can be accepted. We want to ensure that
netlog messages never exceed this limit; otherwise, a client
cannot transmit logs.

Move the goroutine that periodically dumps netlog messages from
wgengine/netlog to net/connstats. This allows net/connstats to
manage when it dumps messages, based either on time or on size.

Updates tailscale/corp#8427

Signed-off-by: Joe Tsai
---
(Two illustrative Go sketches follow the diff below: one showing how a
caller might drive the new connstats API, and one showing the size
budget behind maxConns. They are examples for review, not part of the
change.)

 net/connstats/stats.go               | 129 +++++++++++++++++++++++----
 net/connstats/stats_test.go          |  33 ++++---
 net/tstun/wrap_test.go               |  15 ++--
 types/netlogtype/netlogtype.go       |   9 ++
 wgengine/magicsock/magicsock_test.go |  12 ++-
 wgengine/netlog/logger.go            | 114 ++++++++++-------------
 6 files changed, 206 insertions(+), 106 deletions(-)

diff --git a/net/connstats/stats.go b/net/connstats/stats.go
index edd4707c4..a17afd4f8 100644
--- a/net/connstats/stats.go
+++ b/net/connstats/stats.go
@@ -7,9 +7,12 @@
 package connstats
 
 import (
+	"context"
 	"net/netip"
 	"sync"
+	"time"
 
+	"golang.org/x/sync/errgroup"
 	"tailscale.com/net/packet"
 	"tailscale.com/types/netlogtype"
 )
@@ -18,11 +21,64 @@ import (
 // All methods are safe for concurrent use.
 // The zero value is ready for use.
 type Statistics struct {
-	mu sync.Mutex
+	maxConns int // immutable once set
+
+	mu sync.Mutex
+	connCnts
+
+	connCntsCh  chan connCnts
+	shutdownCtx context.Context
+	shutdown    context.CancelFunc
+	group       errgroup.Group
+}
+
+type connCnts struct {
+	start    time.Time
+	end      time.Time
 	virtual  map[netlogtype.Connection]netlogtype.Counts
 	physical map[netlogtype.Connection]netlogtype.Counts
 }
 
+// NewStatistics creates a data structure for tracking connection statistics
+// that periodically dumps the virtual and physical connection counts
+// whenever either maxPeriod or maxConns is exceeded.
+// The dump function is called from a single goroutine.
+// Shutdown must be called to clean up resources.
+func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *Statistics {
+	s := &Statistics{maxConns: maxConns}
+	s.connCntsCh = make(chan connCnts, 256)
+	s.shutdownCtx, s.shutdown = context.WithCancel(context.Background())
+	s.group.Go(func() error {
+		// TODO(joetsai): Using a ticker is problematic on mobile platforms
+		// where waking up a process every maxPeriod when there is no activity
+		// is a drain on battery life. Switch this to instead use
+		// a time.Timer that is triggered upon network activity.
+		ticker := new(time.Ticker)
+		if maxPeriod > 0 {
+			ticker = time.NewTicker(maxPeriod)
+			defer ticker.Stop()
+		}
+
+		for {
+			var cc connCnts
+			select {
+			case cc = <-s.connCntsCh:
+			case <-ticker.C:
+				cc = s.extract()
+			case <-s.shutdownCtx.Done():
+				cc = s.extract()
+			}
+			if len(cc.virtual)+len(cc.physical) > 0 && dump != nil {
+				dump(cc.start, cc.end, cc.virtual, cc.physical)
+			}
+			if s.shutdownCtx.Err() != nil {
+				return nil
+			}
+		}
+	})
+	return s
+}
+
 // UpdateTxVirtual updates the counters for a transmitted IP packet
 // The source and destination of the packet directly correspond with
 // the source and destination in netlogtype.Connection.
@@ -47,10 +103,10 @@ func (s *Statistics) updateVirtual(b []byte, receive bool) {
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.virtual == nil {
-		s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
+	cnts, found := s.virtual[conn]
+	if !found && !s.preInsertConn() {
+		return
 	}
-	cnts := s.virtual[conn]
 	if receive {
 		cnts.RxPackets++
 		cnts.RxBytes += uint64(len(b))
@@ -82,10 +138,10 @@ func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, r
 
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.physical == nil {
-		s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
+	cnts, found := s.physical[conn]
+	if !found && !s.preInsertConn() {
+		return
 	}
-	cnts := s.physical[conn]
 	if receive {
 		cnts.RxPackets++
 		cnts.RxBytes += uint64(n)
@@ -96,14 +152,57 @@ func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, r
 	s.physical[conn] = cnts
 }
 
-// Extract extracts and resets the counters for all active connections.
-// It must be called periodically otherwise the memory used is unbounded.
-func (s *Statistics) Extract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
+// preInsertConn updates the maps to handle insertion of a new connection.
+// It reports false if insertion is not allowed (i.e., after shutdown).
+func (s *Statistics) preInsertConn() bool {
+	// Check whether insertion of a new connection will exceed maxConns.
+	if len(s.virtual)+len(s.physical) == s.maxConns && s.maxConns > 0 {
+		// Extract the current statistics and send it to the serializer.
+		// Avoid blocking the network packet handling path.
+		select {
+		case s.connCntsCh <- s.extractLocked():
+		default:
+			// TODO(joetsai): Log that we are dropping an entire connCnts.
+		}
+	}
+
+	// Initialize the maps if nil.
+	if s.virtual == nil && s.physical == nil {
+		s.start = time.Now().UTC()
+		s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
+		s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
+	}
+
+	return s.shutdownCtx.Err() == nil
+}
+
+func (s *Statistics) extract() connCnts {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	virtual = s.virtual
-	s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
-	physical = s.physical
-	s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
-	return virtual, physical
+	return s.extractLocked()
+}
+
+func (s *Statistics) extractLocked() connCnts {
+	if len(s.virtual)+len(s.physical) == 0 {
+		return connCnts{}
+	}
+	s.end = time.Now().UTC()
+	cc := s.connCnts
+	s.connCnts = connCnts{}
+	return cc
+}
+
+// TestExtract synchronously extracts the current network statistics maps
+// and resets the counters. This should only be used for testing purposes.
+func (s *Statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
+	cc := s.extract()
+	return cc.virtual, cc.physical
+}
+
+// Shutdown performs a final flush of statistics.
+// Statistics for any subsequent calls to Update will be dropped.
+// It is safe to call Shutdown concurrently and repeatedly.
+func (s *Statistics) Shutdown(context.Context) error {
+	s.shutdown()
+	return s.group.Wait()
 }
diff --git a/net/connstats/stats_test.go b/net/connstats/stats_test.go
index 7c212e8aa..7e02bc0f7 100644
--- a/net/connstats/stats_test.go
+++ b/net/connstats/stats_test.go
@@ -5,6 +5,7 @@
 package connstats
 
 import (
+	"context"
 	"encoding/binary"
 	"fmt"
 	"math/rand"
@@ -47,7 +48,20 @@ func testPacketV4(proto ipproto.Proto, srcAddr, dstAddr [4]byte, srcPort, dstPor
 
 func TestConcurrent(t *testing.T) {
 	c := qt.New(t)
-	var stats Statistics
+	const maxPeriod = 10 * time.Millisecond
+	const maxConns = 10
+	virtualAggregate := make(map[netlogtype.Connection]netlogtype.Counts)
+	stats := NewStatistics(maxPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
+		c.Assert(start.IsZero(), qt.IsFalse)
+		c.Assert(end.IsZero(), qt.IsFalse)
+		c.Assert(end.Before(start), qt.IsFalse)
+		c.Assert(len(virtual) > 0 && len(virtual) <= maxConns, qt.IsTrue)
+		c.Assert(len(physical) == 0, qt.IsTrue)
+		for conn, cnts := range virtual {
+			virtualAggregate[conn] = virtualAggregate[conn].Add(cnts)
+		}
+	})
+	defer stats.Shutdown(context.Background())
 	var wants []map[netlogtype.Connection]netlogtype.Counts
 	gots := make([]map[netlogtype.Connection]netlogtype.Counts, runtime.NumCPU())
 	var group sync.WaitGroup
@@ -95,14 +109,9 @@ func TestConcurrent(t *testing.T) {
 			}
 		}(i)
 	}
-	for range gots {
-		virtual, _ := stats.Extract()
-		wants = append(wants, virtual)
-		time.Sleep(time.Millisecond)
-	}
 	group.Wait()
-	virtual, _ := stats.Extract()
-	wants = append(wants, virtual)
+	c.Assert(stats.Shutdown(context.Background()), qt.IsNil)
+	wants = append(wants, virtualAggregate)
 
 	got := make(map[netlogtype.Connection]netlogtype.Counts)
 	want := make(map[netlogtype.Connection]netlogtype.Counts)
@@ -126,7 +135,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			for j := 0; j < 1e3; j++ {
 				s.UpdateTxVirtual(p)
 			}
@@ -137,7 +146,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			for j := 0; j < 1e3; j++ {
 				binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination
 				s.UpdateTxVirtual(p)
@@ -149,7 +158,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			var group sync.WaitGroup
 			for j := 0; j < runtime.NumCPU(); j++ {
 				group.Add(1)
@@ -171,7 +180,7 @@ func Benchmark(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			var s Statistics
+			s := NewStatistics(0, 0, nil)
 			var group sync.WaitGroup
 			for j := 0; j < runtime.NumCPU(); j++ {
 				group.Add(1)
diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go
index 54b6f3bba..1c5b005ce 100644
--- a/net/tstun/wrap_test.go
+++ b/net/tstun/wrap_test.go
@@ -6,15 +6,17 @@ package tstun
 
 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"fmt"
 	"net/netip"
-	"reflect"
 	"strconv"
 	"strings"
 	"testing"
 	"unsafe"
 
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
 	"github.com/tailscale/wireguard-go/tun/tuntest"
 	"go4.org/mem"
 	"go4.org/netipx"
@@ -337,7 +339,8 @@ func TestFilter(t *testing.T) {
 	}()
 
 	var buf [MaxPacketSize]byte
-	stats := new(connstats.Statistics)
+	stats := connstats.NewStatistics(0, 0, nil)
+	defer stats.Shutdown(context.Background())
 	tun.SetStatistics(stats)
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
@@ -346,7 +349,7 @@ func TestFilter(t *testing.T) {
 			var filtered bool
 			sizes := make([]int, 1)
 
-			tunStats, _ := stats.Extract()
+			tunStats, _ := stats.TestExtract()
 			if len(tunStats) > 0 {
-				t.Errorf("connstats.Statistics.Extract = %v, want {}", stats)
+				t.Errorf("connstats.Statistics.TestExtract = %v, want {}", tunStats)
 			}
@@ -381,7 +384,7 @@ func TestFilter(t *testing.T) {
 				}
 			}
 
-			got, _ := stats.Extract()
+			got, _ := stats.TestExtract()
 			want := map[netlogtype.Connection]netlogtype.Counts{}
 			if !tt.drop {
 				var p packet.Parsed
@@ -395,8 +398,8 @@ func TestFilter(t *testing.T) {
 					want[conn] = netlogtype.Counts{TxPackets: 1, TxBytes: uint64(len(tt.data))}
 				}
 			}
-			if !reflect.DeepEqual(got, want) {
-				t.Errorf("tun.ExtractStatistics = %v, want %v", got, want)
+			if diff := cmp.Diff(got, want, cmpopts.EquateEmpty()); diff != "" {
+				t.Errorf("stats.TestExtract (-got +want):\n%s", diff)
 			}
 		})
 	}
diff --git a/types/netlogtype/netlogtype.go b/types/netlogtype/netlogtype.go
index 1553806fb..d4b65d459 100644
--- a/types/netlogtype/netlogtype.go
+++ b/types/netlogtype/netlogtype.go
@@ -30,6 +30,15 @@ type Message struct {
 }
 
 const (
+	messageJSON      = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}`
+	maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339
+	maxJSONRFC3339   = `"0001-01-01T00:00:00.000000000Z"`
+	minJSONTraffic   = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}`
+
+	// MaxMessageJSONSize is the overhead size of Message when it is
+	// serialized as JSON assuming that each traffic map is populated.
+	MaxMessageJSONSize = len(messageJSON)
+
 	maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}`
 	maxJSONConn       = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort
 	maxJSONProto      = `255`
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index 2dc9c1513..c970a68c8 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -135,7 +135,7 @@ func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, st
 type magicStack struct {
 	privateKey key.NodePrivate
 	epCh       chan []tailcfg.Endpoint // endpoint updates produced by this peer
-	stats      connstats.Statistics    // per-connection statistics
+	stats      *connstats.Statistics   // per-connection statistics
 	conn       *Conn                   // the magicsock itself
 	tun        *tuntest.ChannelTUN     // TUN device to send/receive packets
 	tsTun      *tstun.Wrapper          // wrapped tun that implements filtering and wgengine hooks
@@ -1053,11 +1053,15 @@ func testTwoDevicePing(t *testing.T, d *devices) {
 		}
 	}
 
-	m1.conn.SetStatistics(&m1.stats)
-	m2.conn.SetStatistics(&m2.stats)
+	m1.stats = connstats.NewStatistics(0, 0, nil)
+	defer m1.stats.Shutdown(context.Background())
+	m1.conn.SetStatistics(m1.stats)
+	m2.stats = connstats.NewStatistics(0, 0, nil)
+	defer m2.stats.Shutdown(context.Background())
+	m2.conn.SetStatistics(m2.stats)
 
 	checkStats := func(t *testing.T, m *magicStack, wantConns []netlogtype.Connection) {
-		_, stats := m.stats.Extract()
+		_, stats := m.stats.TestExtract()
 		for _, conn := range wantConns {
 			if _, ok := stats[conn]; ok {
 				return
diff --git a/wgengine/netlog/logger.go b/wgengine/netlog/logger.go
index 17d482b84..19970d626 100644
--- a/wgengine/netlog/logger.go
+++ b/wgengine/netlog/logger.go
@@ -17,7 +17,6 @@ import (
 	"sync"
 	"time"
 
-	"golang.org/x/sync/errgroup"
 	"tailscale.com/logpolicy"
 	"tailscale.com/logtail"
 	"tailscale.com/net/connstats"
@@ -25,6 +24,7 @@ import (
 	"tailscale.com/smallzstd"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/netlogtype"
+	"tailscale.com/util/multierr"
"tailscale.com/util/multierr" "tailscale.com/wgengine/router" ) @@ -32,8 +32,7 @@ import ( const pollPeriod = 5 * time.Second // Device is an abstraction over a tunnel device or a magic socket. -// *tstun.Wrapper implements this interface. -// *magicsock.Conn implements this interface. +// Both *tstun.Wrapper and *magicsock.Conn implement this interface. type Device interface { SetStatistics(*connstats.Statistics) } @@ -47,15 +46,15 @@ func (noopDevice) SetStatistics(*connstats.Statistics) {} // Exit node traffic is not logged for privacy reasons. // The zero value is ready for use. type Logger struct { - mu sync.Mutex + mu sync.Mutex // protects all fields below logger *logtail.Logger + stats *connstats.Statistics + tun Device + sock Device addrs map[netip.Addr]bool prefixes map[netip.Prefix]bool - - group errgroup.Group - cancel context.CancelFunc } // Running reports whether the logger is running. @@ -97,18 +96,13 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo if nl.logger != nil { return fmt.Errorf("network logger already running for %v", nl.logger.PrivateID().Public()) } - if tun == nil { - tun = noopDevice{} - } - if sock == nil { - sock = noopDevice{} - } + // Startup a log stream to Tailscale's logging service. httpc := &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost)} if testClient != nil { httpc = testClient } - logger := logtail.NewLogger(logtail.Config{ + nl.logger = logtail.NewLogger(logtail.Config{ Collection: "tailtraffic.log.tailscale.io", PrivateID: nodeLogID, CopyPrivateID: domainLogID, @@ -127,47 +121,34 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo IncludeProcID: true, IncludeProcSequence: true, }, log.Printf) - nl.logger = logger - - stats := new(connstats.Statistics) - ctx, cancel := context.WithCancel(context.Background()) - nl.cancel = cancel - nl.group.Go(func() error { - tun.SetStatistics(stats) - defer tun.SetStatistics(nil) - sock.SetStatistics(stats) - defer sock.SetStatistics(nil) + // Startup a data structure to track per-connection statistics. + // There is a maximum size for individual log messages that logtail + // can upload to the Tailscale log service, so stay below this limit. + const maxLogSize = 256 << 10 + const maxConns = (maxLogSize - netlogtype.MaxMessageJSONSize) / netlogtype.MaxConnectionCountsJSONSize + nl.stats = connstats.NewStatistics(pollPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) { + nl.mu.Lock() + addrs := nl.addrs + prefixes := nl.prefixes + nl.mu.Unlock() + recordStatistics(nl.logger, nodeID, start, end, virtual, physical, addrs, prefixes) + }) - start := time.Now() - ticker := time.NewTicker(pollPeriod) - for { - var end time.Time - select { - case <-ctx.Done(): - end = time.Now() - case end = <-ticker.C: - } + // Register the connection tracker into the TUN device. + if tun == nil { + tun = noopDevice{} + } + nl.tun = tun + nl.tun.SetStatistics(nl.stats) - // NOTE: connstats and sockStats will always be slightly out-of-sync. - // It is impossible to have an atomic snapshot of statistics - // at both layers without a global mutex that spans all layers. - connstats, sockStats := stats.Extract() - if len(connstats)+len(sockStats) > 0 { - nl.mu.Lock() - addrs := nl.addrs - prefixes := nl.prefixes - nl.mu.Unlock() - recordStatistics(logger, nodeID, start, end, connstats, sockStats, addrs, prefixes) - } + // Register the connection tracker into magicsock. 
+	if sock == nil {
+		sock = noopDevice{}
+	}
+	nl.sock = sock
+	nl.sock.SetStatistics(nl.stats)
 
-			if ctx.Err() != nil {
-				break
-			}
-			start = end.Add(time.Nanosecond)
-		}
-		return nil
-	})
 	return nil
 }
 
@@ -222,21 +203,8 @@ func recordStatistics(logger *logtail.Logger, nodeID tailcfg.StableNodeID, start
 	}
 
 	if len(m.VirtualTraffic)+len(m.SubnetTraffic)+len(m.ExitTraffic)+len(m.PhysicalTraffic) > 0 {
-		// TODO(joetsai): Place a hard limit on the size of a network log message.
-		// The log server rejects any payloads above a certain size, so logging
-		// a message that large would cause logtail to be stuck forever trying
-		// and failing to upload the same excessively large payload.
-		//
-		// We should figure out the behavior for handling this. We could split
-		// the message apart so that there are multiple chunks with the same window,
-		// We could also consider reducing the granularity of the data
-		// by dropping port numbers.
-		const maxSize = 256 << 10
 		if b, err := json.Marshal(m); err != nil {
 			logger.Logf("json.Marshal error: %v", err)
-		} else if len(b) > maxSize {
-			logger.Logf("JSON body too large: %dB (virtual:%d subnet:%d exit:%d physical:%d)",
-				len(b), len(m.VirtualTraffic), len(m.SubnetTraffic), len(m.ExitTraffic), len(m.PhysicalTraffic))
 		} else {
 			logger.Logf("%s", b)
 		}
@@ -285,15 +253,23 @@ func (nl *Logger) Shutdown(ctx context.Context) error {
 	if nl.logger == nil {
 		return nil
 	}
-	nl.cancel()
+
+	// Shutdown in reverse order of Startup.
+	// Do not hold lock while shutting down since this may flush one last time.
 	nl.mu.Unlock()
-	nl.group.Wait() // do not hold lock while waiting
+	nl.sock.SetStatistics(nil)
+	nl.tun.SetStatistics(nil)
+	err1 := nl.stats.Shutdown(ctx)
+	err2 := nl.logger.Shutdown(ctx)
 	nl.mu.Lock()
-	err := nl.logger.Shutdown(ctx)
+
+	// Purge state.
 	nl.logger = nil
+	nl.stats = nil
+	nl.tun = nil
+	nl.sock = nil
 	nl.addrs = nil
 	nl.prefixes = nil
-	nl.cancel = nil
-	return err
+
+	return multierr.New(err1, err2)
 }
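
Appendix A (illustrative only, not part of the patch): a minimal sketch
of how a caller might drive the new connstats API, assuming only the
NewStatistics/Shutdown surface added above. The 5-second period, the
100-connection limit, and the log output are hypothetical values chosen
for this example.

package main

import (
	"context"
	"log"
	"time"

	"tailscale.com/net/connstats"
	"tailscale.com/types/netlogtype"
)

func main() {
	// dump is invoked from a single goroutine, either when maxPeriod
	// elapses or when tracking another connection would exceed maxConns.
	dump := func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
		log.Printf("window %v..%v: %d virtual, %d physical connections",
			start, end, len(virtual), len(physical))
	}
	stats := connstats.NewStatistics(5*time.Second, 100, dump) // hypothetical limits
	defer stats.Shutdown(context.Background())                 // flushes one last time

	// The data path would then feed packets into the tracker via
	// stats.UpdateTxVirtual, stats.UpdateRxVirtual, and the physical
	// variants, exactly as tstun.Wrapper and magicsock.Conn do.
}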
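
Appendix B (illustrative only, not part of the patch): a sketch of the
size budget behind maxConns in Logger.Startup. It reuses the exported
netlogtype constants introduced above; the 256 KiB ceiling mirrors the
maxLogSize constant in the patch, and the printed count depends on the
actual values of those constants.

package main

import (
	"fmt"

	"tailscale.com/types/netlogtype"
)

func main() {
	// logtail rejects messages above a hard ceiling, so reserve the
	// fixed JSON envelope overhead of a netlog Message and divide the
	// remainder by the worst-case JSON size of one connection entry.
	const maxLogSize = 256 << 10
	const maxConns = (maxLogSize - netlogtype.MaxMessageJSONSize) /
		netlogtype.MaxConnectionCountsJSONSize
	fmt.Printf("a single netlog message can carry up to %d connections\n", maxConns)
}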