diff --git a/log/sockstatlog/logger.go b/log/sockstatlog/logger.go index 8e4209709..4d44d7d47 100644 --- a/log/sockstatlog/logger.go +++ b/log/sockstatlog/logger.go @@ -5,6 +5,7 @@ package sockstatlog import ( + "context" "encoding/json" "io" "os" @@ -22,6 +23,9 @@ const pollPeriod = time.Second / 10 // Logger logs statistics about network sockets. type Logger struct { + ctx context.Context + cancelFn context.CancelFunc + ticker time.Ticker logf logger.Logf logbuffer *filch.Filch @@ -63,7 +67,10 @@ func NewLogger(logdir string, logf logger.Logf) (*Logger, error) { return nil, err } + ctx, cancel := context.WithCancel(context.Background()) return &Logger{ + ctx: ctx, + cancelFn: cancel, ticker: *time.NewTicker(pollPeriod), logf: logf, logbuffer: filch, @@ -83,32 +90,38 @@ func (l *Logger) poll() { var lastTime time.Time enc := json.NewEncoder(l.logbuffer) - for t := range l.ticker.C { - stats := sockstats.Get() - if lastStats != nil { - diffstats := delta(lastStats, stats) - if len(diffstats) > 0 { - e := event{ - Time: lastTime.UnixMilli(), - Duration: t.Sub(lastTime).Milliseconds(), - Stats: diffstats, - } - if stats.CurrentInterfaceCellular { - e.IsCellularInterface = 1 - } - if err := enc.Encode(e); err != nil { - l.logf("sockstatlog: error encoding log: %v", err) + for { + select { + case <-l.ctx.Done(): + return + case t := <-l.ticker.C: + stats := sockstats.Get() + if lastStats != nil { + diffstats := delta(lastStats, stats) + if len(diffstats) > 0 { + e := event{ + Time: lastTime.UnixMilli(), + Duration: t.Sub(lastTime).Milliseconds(), + Stats: diffstats, + } + if stats.CurrentInterfaceCellular { + e.IsCellularInterface = 1 + } + if err := enc.Encode(e); err != nil { + l.logf("sockstatlog: error encoding log: %v", err) + } } } + lastTime = t + lastStats = stats } - lastTime = t - lastStats = stats } } func (l *Logger) Shutdown() { l.ticker.Stop() l.logbuffer.Close() + l.cancelFn() } // WriteLogs reads local logs, combining logs into events, and writes them to w.