net/connstats: invert network logging data flow (#6272)

Previously, tstun.Wrapper and magicsock.Conn managed their
own statistics data structure and relied on an external call to
Extract to extract (and reset) the statistics.
This makes it difficult to ensure a maximum size on the statistics
as the caller has no introspection into whether the number
of unique connections is getting too large.

Invert the control flow such that a *connstats.Statistics
is registered with tstun.Wrapper and magicsock.Conn.
Methods on non-nil *connstats.Statistics are called for every packet.
This allows the implementation of connstats.Statistics (in the future)
to better control when it needs to flush to ensure
bounds on maximum sizes.

The value registered into tstun.Wrapper and magicsock.Conn could
be an interface, but that has two performance detriments:

1. Method calls on interface values are more expensive since
they must go through a virtual method dispatch.

2. The implementation would need a sync.Mutex to protect the
statistics value instead of using an atomic.Pointer.

Given that methods on constats.Statistics are called for every packet,
we want reduce the CPU cost on this hot path.

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
pull/6548/head
Joe Tsai 2 years ago committed by GitHub
parent 35c10373b5
commit 2e5d08ec4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -218,6 +218,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/logtail/backoff from tailscale.com/control/controlclient+ tailscale.com/logtail/backoff from tailscale.com/control/controlclient+
tailscale.com/logtail/filch from tailscale.com/logpolicy tailscale.com/logtail/filch from tailscale.com/logpolicy
tailscale.com/metrics from tailscale.com/derp+ tailscale.com/metrics from tailscale.com/derp+
tailscale.com/net/connstats from tailscale.com/net/tstun+
tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+ tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+
tailscale.com/net/dns/publicdns from tailscale.com/net/dns/resolver+ tailscale.com/net/dns/publicdns from tailscale.com/net/dns/resolver+
tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+
@ -245,7 +246,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/net/tsdial from tailscale.com/control/controlclient+ tailscale.com/net/tsdial from tailscale.com/control/controlclient+
💣 tailscale.com/net/tshttpproxy from tailscale.com/control/controlclient+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/control/controlclient+
tailscale.com/net/tstun from tailscale.com/net/dns+ tailscale.com/net/tstun from tailscale.com/net/dns+
tailscale.com/net/tunstats from tailscale.com/net/tstun
tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ tailscale.com/net/wsconn from tailscale.com/control/controlhttp+
tailscale.com/paths from tailscale.com/ipn/ipnlocal+ tailscale.com/paths from tailscale.com/ipn/ipnlocal+
💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal
@ -268,7 +268,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/types/key from tailscale.com/control/controlbase+ tailscale.com/types/key from tailscale.com/control/controlbase+
tailscale.com/types/logger from tailscale.com/control/controlclient+ tailscale.com/types/logger from tailscale.com/control/controlclient+
tailscale.com/types/logid from tailscale.com/logtail+ tailscale.com/types/logid from tailscale.com/logtail+
tailscale.com/types/netlogtype from tailscale.com/net/tstun+ tailscale.com/types/netlogtype from tailscale.com/net/connstats+
tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/netmap from tailscale.com/control/controlclient+
tailscale.com/types/nettype from tailscale.com/wgengine/magicsock+ tailscale.com/types/nettype from tailscale.com/wgengine/magicsock+
tailscale.com/types/opt from tailscale.com/control/controlclient+ tailscale.com/types/opt from tailscale.com/control/controlclient+

@ -0,0 +1,109 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package connstats maintains statistics about connections
// flowing through a TUN device (which operate at the IP layer).
package connstats
import (
"net/netip"
"sync"
"tailscale.com/net/packet"
"tailscale.com/types/netlogtype"
)
// Statistics maintains counters for every connection.
// All methods are safe for concurrent use.
// The zero value is ready for use.
type Statistics struct {
mu sync.Mutex
virtual map[netlogtype.Connection]netlogtype.Counts
physical map[netlogtype.Connection]netlogtype.Counts
}
// UpdateTxVirtual updates the counters for a transmitted IP packet
// The source and destination of the packet directly correspond with
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateTxVirtual(b []byte) {
s.updateVirtual(b, false)
}
// UpdateRxVirtual updates the counters for a received IP packet.
// The source and destination of the packet are inverted with respect to
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateRxVirtual(b []byte) {
s.updateVirtual(b, true)
}
func (s *Statistics) updateVirtual(b []byte, receive bool) {
var p packet.Parsed
p.Decode(b)
conn := netlogtype.Connection{Proto: p.IPProto, Src: p.Src, Dst: p.Dst}
if receive {
conn.Src, conn.Dst = conn.Dst, conn.Src
}
s.mu.Lock()
defer s.mu.Unlock()
if s.virtual == nil {
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
}
cnts := s.virtual[conn]
if receive {
cnts.RxPackets++
cnts.RxBytes += uint64(len(b))
} else {
cnts.TxPackets++
cnts.TxBytes += uint64(len(b))
}
s.virtual[conn] = cnts
}
// UpdateTxPhysical updates the counters for a transmitted wireguard packet
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, n int) {
s.updatePhysical(src, dst, n, false)
}
// UpdateRxPhysical updates the counters for a received wireguard packet.
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, n int) {
s.updatePhysical(src, dst, n, true)
}
func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, receive bool) {
conn := netlogtype.Connection{Src: netip.AddrPortFrom(src, 0), Dst: dst}
s.mu.Lock()
defer s.mu.Unlock()
if s.physical == nil {
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
}
cnts := s.physical[conn]
if receive {
cnts.RxPackets++
cnts.RxBytes += uint64(n)
} else {
cnts.TxPackets++
cnts.TxBytes += uint64(n)
}
s.physical[conn] = cnts
}
// Extract extracts and resets the counters for all active connections.
// It must be called periodically otherwise the memory used is unbounded.
func (s *Statistics) Extract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
s.mu.Lock()
defer s.mu.Unlock()
virtual = s.virtual
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
physical = s.physical
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
return virtual, physical
}

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package tunstats package connstats
import ( import (
"encoding/binary" "encoding/binary"
@ -82,13 +82,13 @@ func TestConcurrent(t *testing.T) {
cnts := gots[i][t2] cnts := gots[i][t2]
if receive { if receive {
stats.UpdateRx(p) stats.UpdateRxVirtual(p)
cnts.RxPackets++ cnts.RxPackets++
cnts.RxBytes += uint64(len(p)) cnts.RxBytes += uint64(len(p))
} else { } else {
cnts.TxPackets++ cnts.TxPackets++
cnts.TxBytes += uint64(len(p)) cnts.TxBytes += uint64(len(p))
stats.UpdateTx(p) stats.UpdateTxVirtual(p)
} }
gots[i][t2] = cnts gots[i][t2] = cnts
time.Sleep(time.Duration(rn.Intn(1 + delay))) time.Sleep(time.Duration(rn.Intn(1 + delay)))
@ -96,11 +96,13 @@ func TestConcurrent(t *testing.T) {
}(i) }(i)
} }
for range gots { for range gots {
wants = append(wants, stats.Extract()) virtual, _ := stats.Extract()
wants = append(wants, virtual)
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
} }
group.Wait() group.Wait()
wants = append(wants, stats.Extract()) virtual, _ := stats.Extract()
wants = append(wants, virtual)
got := make(map[netlogtype.Connection]netlogtype.Counts) got := make(map[netlogtype.Connection]netlogtype.Counts)
want := make(map[netlogtype.Connection]netlogtype.Counts) want := make(map[netlogtype.Connection]netlogtype.Counts)
@ -126,7 +128,7 @@ func Benchmark(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var s Statistics var s Statistics
for j := 0; j < 1e3; j++ { for j := 0; j < 1e3; j++ {
s.UpdateTx(p) s.UpdateTxVirtual(p)
} }
} }
}) })
@ -138,7 +140,7 @@ func Benchmark(b *testing.B) {
var s Statistics var s Statistics
for j := 0; j < 1e3; j++ { for j := 0; j < 1e3; j++ {
binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination
s.UpdateTx(p) s.UpdateTxVirtual(p)
} }
} }
}) })
@ -154,7 +156,7 @@ func Benchmark(b *testing.B) {
go func() { go func() {
defer group.Done() defer group.Done()
for k := 0; k < 1e3; k++ { for k := 0; k < 1e3; k++ {
s.UpdateTx(p) s.UpdateTxVirtual(p)
} }
}() }()
} }
@ -179,7 +181,7 @@ func Benchmark(b *testing.B) {
j *= 1e3 j *= 1e3
for k := 0; k < 1e3; k++ { for k := 0; k < 1e3; k++ {
binary.BigEndian.PutUint32(p[20:], uint32(j+k)) // unique port combination binary.BigEndian.PutUint32(p[20:], uint32(j+k)) // unique port combination
s.UpdateTx(p) s.UpdateTxVirtual(p)
} }
}(j) }(j)
} }

@ -22,15 +22,14 @@ import (
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/net/connstats"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/net/tunstats"
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/tstime/mono" "tailscale.com/tstime/mono"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netlogtype"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
) )
@ -170,10 +169,7 @@ type Wrapper struct {
disableTSMPRejected bool disableTSMPRejected bool
// stats maintains per-connection counters. // stats maintains per-connection counters.
stats struct { stats atomic.Pointer[connstats.Statistics]
enabled atomic.Bool
tunstats.Statistics
}
} }
// tunReadResult is the result of a TUN read, or an injected result pretending to be a TUN read. // tunReadResult is the result of a TUN read, or an injected result pretending to be a TUN read.
@ -568,8 +564,8 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
} }
} }
if t.stats.enabled.Load() { if stats := t.stats.Load(); stats != nil {
t.stats.UpdateTx(buf[offset:][:n]) stats.UpdateTxVirtual(buf[offset:][:n])
} }
t.noteActivity() t.noteActivity()
return n, nil return n, nil
@ -701,8 +697,8 @@ func (t *Wrapper) Write(buf []byte, offset int) (int, error) {
} }
func (t *Wrapper) tdevWrite(buf []byte, offset int) (int, error) { func (t *Wrapper) tdevWrite(buf []byte, offset int) (int, error) {
if t.stats.enabled.Load() { if stats := t.stats.Load(); stats != nil {
t.stats.UpdateRx(buf[offset:]) stats.UpdateRxVirtual(buf[offset:])
} }
if t.isTAP { if t.isTAP {
return t.tapWrite(buf, offset) return t.tapWrite(buf, offset)
@ -843,18 +839,10 @@ func (t *Wrapper) Unwrap() tun.Device {
return t.tdev return t.tdev
} }
// SetStatisticsEnabled enables per-connections packet counters. // SetStatistics specifies a per-connection statistics aggregator.
// Disabling statistics gathering does not reset the counters. // Nil may be specified to disable statistics gathering.
// ExtractStatistics must be called to reset the counters and func (t *Wrapper) SetStatistics(stats *connstats.Statistics) {
// be periodically called while enabled to avoid unbounded memory use. t.stats.Store(stats)
func (t *Wrapper) SetStatisticsEnabled(enable bool) {
t.stats.enabled.Store(enable)
}
// ExtractStatistics extracts and resets the counters for all active connections.
// It must be called periodically otherwise the memory used is unbounded.
func (t *Wrapper) ExtractStatistics() map[netlogtype.Connection]netlogtype.Counts {
return t.stats.Extract()
} }
var ( var (

@ -19,6 +19,7 @@ import (
"go4.org/netipx" "go4.org/netipx"
"golang.zx2c4.com/wireguard/tun/tuntest" "golang.zx2c4.com/wireguard/tun/tuntest"
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/net/connstats"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/tstest" "tailscale.com/tstest"
@ -283,11 +284,6 @@ func TestWriteAndInject(t *testing.T) {
t.Errorf("%s not received", packet) t.Errorf("%s not received", packet)
} }
} }
// Statistics gathering is disabled by default.
if stats := tun.ExtractStatistics(); len(stats) > 0 {
t.Errorf("tun.ExtractStatistics = %v, want {}", stats)
}
} }
func TestFilter(t *testing.T) { func TestFilter(t *testing.T) {
@ -336,15 +332,17 @@ func TestFilter(t *testing.T) {
}() }()
var buf [MaxPacketSize]byte var buf [MaxPacketSize]byte
tun.SetStatisticsEnabled(true) stats := new(connstats.Statistics)
tun.SetStatistics(stats)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
var n int var n int
var err error var err error
var filtered bool var filtered bool
if stats := tun.ExtractStatistics(); len(stats) > 0 { tunStats, _ := stats.Extract()
t.Errorf("tun.ExtractStatistics = %v, want {}", stats) if len(tunStats) > 0 {
t.Errorf("connstats.Statistics.Extract = %v, want {}", stats)
} }
if tt.dir == in { if tt.dir == in {
@ -377,7 +375,7 @@ func TestFilter(t *testing.T) {
} }
} }
got := tun.ExtractStatistics() got, _ := stats.Extract()
want := map[netlogtype.Connection]netlogtype.Counts{} want := map[netlogtype.Connection]netlogtype.Counts{}
if !tt.drop { if !tt.drop {
var p packet.Parsed var p packet.Parsed

@ -1,70 +0,0 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tunstats maintains statistics about connections
// flowing through a TUN device (which operate at the IP layer).
package tunstats
import (
"sync"
"tailscale.com/net/packet"
"tailscale.com/types/netlogtype"
)
// Statistics maintains counters for every connection.
// All methods are safe for concurrent use.
// The zero value is ready for use.
type Statistics struct {
mu sync.Mutex
m map[netlogtype.Connection]netlogtype.Counts
}
// UpdateTx updates the counters for a transmitted IP packet
// The source and destination of the packet directly correspond with
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateTx(b []byte) {
s.update(b, false)
}
// UpdateRx updates the counters for a received IP packet.
// The source and destination of the packet are inverted with respect to
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateRx(b []byte) {
s.update(b, true)
}
func (s *Statistics) update(b []byte, receive bool) {
var p packet.Parsed
p.Decode(b)
conn := netlogtype.Connection{Proto: p.IPProto, Src: p.Src, Dst: p.Dst}
if receive {
conn.Src, conn.Dst = conn.Dst, conn.Src
}
s.mu.Lock()
defer s.mu.Unlock()
if s.m == nil {
s.m = make(map[netlogtype.Connection]netlogtype.Counts)
}
cnts := s.m[conn]
if receive {
cnts.RxPackets++
cnts.RxBytes += uint64(len(b))
} else {
cnts.TxPackets++
cnts.TxBytes += uint64(len(b))
}
s.m[conn] = cnts
}
// Extract extracts and resets the counters for all active connections.
// It must be called periodically otherwise the memory used is unbounded.
func (s *Statistics) Extract() map[netlogtype.Connection]netlogtype.Counts {
s.mu.Lock()
defer s.mu.Unlock()
m := s.m
s.m = make(map[netlogtype.Connection]netlogtype.Counts)
return m
}

@ -37,6 +37,7 @@ import (
"tailscale.com/health" "tailscale.com/health"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/logtail/backoff" "tailscale.com/logtail/backoff"
"tailscale.com/net/connstats"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
@ -52,7 +53,6 @@ import (
"tailscale.com/tstime/mono" "tailscale.com/tstime/mono"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netlogtype"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
@ -337,19 +337,7 @@ type Conn struct {
port atomic.Uint32 port atomic.Uint32
// stats maintains per-connection counters. // stats maintains per-connection counters.
// See SetStatisticsEnabled and ExtractStatistics for details. stats atomic.Pointer[connstats.Statistics]
stats struct {
enabled atomic.Bool
// TODO(joetsai): A per-Conn map of connections is easiest to implement.
// Since every packet occurs within the context of an endpoint,
// we could track the counts within the endpoint itself,
// and then merge the results when ExtractStatistics is called.
// That would avoid a map lookup for every packet.
mu sync.Mutex
m map[netlogtype.Connection]netlogtype.Counts
}
// ============================================================ // ============================================================
// mu guards all following fields; see userspaceEngine lock // mu guards all following fields; see userspaceEngine lock
@ -1754,8 +1742,8 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache,
ep = de ep = de
} }
ep.noteRecvActivity() ep.noteRecvActivity()
if c.stats.enabled.Load() { if stats := c.stats.Load(); stats != nil {
c.updateStats(ep.nodeAddr, ipp, netlogtype.Counts{RxPackets: 1, RxBytes: uint64(len(b))}) stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b))
} }
return ep, true return ep, true
} }
@ -1812,8 +1800,8 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
} }
ep.noteRecvActivity() ep.noteRecvActivity()
if c.stats.enabled.Load() { if stats := c.stats.Load(); stats != nil {
c.updateStats(ep.nodeAddr, ipp, netlogtype.Counts{RxPackets: 1, RxBytes: uint64(dm.n)}) stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n)
} }
return n, ep return n, ep
} }
@ -3306,37 +3294,10 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) {
}) })
} }
// updateStats updates the statistics counters with the src, dst, and cnts. // SetStatistics specifies a per-connection statistics aggregator.
// It is the caller's responsibility to check whether logging is enabled. // Nil may be specified to disable statistics gathering.
func (c *Conn) updateStats(src netip.Addr, dst netip.AddrPort, cnts netlogtype.Counts) { func (c *Conn) SetStatistics(stats *connstats.Statistics) {
conn := netlogtype.Connection{Src: netip.AddrPortFrom(src, 0), Dst: dst} c.stats.Store(stats)
c.stats.mu.Lock()
defer c.stats.mu.Unlock()
mak.Set(&c.stats.m, conn, c.stats.m[conn].Add(cnts))
}
// SetStatisticsEnabled enables per-connection packet counters.
// Disabling statistics gathering does not reset the counters.
// ExtractStatistics must be called to reset the counters and
// be periodically called while enabled to avoid unbounded memory use.
func (c *Conn) SetStatisticsEnabled(enable bool) {
c.stats.enabled.Store(enable)
}
// ExtractStatistics extracts and resets the counters for all active connections.
// It must be called periodically otherwise the memory used is unbounded.
//
// The source is always a peer's tailscale IP address,
// while the destination is the peer's physical IP address and port.
// As a special case, packets routed through DERP use a destination address
// of 127.3.3.40 with the port being the DERP region.
// This node's tailscale IP address never appears in the returned map.
func (c *Conn) ExtractStatistics() map[netlogtype.Connection]netlogtype.Counts {
c.stats.mu.Lock()
defer c.stats.mu.Unlock()
m := c.stats.m
c.stats.m = nil
return m
} }
func ippDebugString(ua netip.AddrPort) string { func ippDebugString(ua netip.AddrPort) string {
@ -3701,14 +3662,14 @@ func (de *endpoint) send(b []byte) error {
var err error var err error
if udpAddr.IsValid() { if udpAddr.IsValid() {
_, err = de.c.sendAddr(udpAddr, de.publicKey, b) _, err = de.c.sendAddr(udpAddr, de.publicKey, b)
if err == nil && de.c.stats.enabled.Load() { if stats := de.c.stats.Load(); err == nil && stats != nil {
de.c.updateStats(de.nodeAddr, udpAddr, netlogtype.Counts{TxPackets: 1, TxBytes: uint64(len(b))}) stats.UpdateTxPhysical(de.nodeAddr, udpAddr, len(b))
} }
} }
if derpAddr.IsValid() { if derpAddr.IsValid() {
if ok, _ := de.c.sendAddr(derpAddr, de.publicKey, b); ok { if ok, _ := de.c.sendAddr(derpAddr, de.publicKey, b); ok {
if de.c.stats.enabled.Load() { if stats := de.c.stats.Load(); stats != nil {
de.c.updateStats(de.nodeAddr, derpAddr, netlogtype.Counts{TxPackets: 1, TxBytes: uint64(len(b))}) stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(b))
} }
if err != nil { if err != nil {
// UDP failed but DERP worked, so good enough: // UDP failed but DERP worked, so good enough:

@ -35,6 +35,7 @@ import (
"tailscale.com/derp/derphttp" "tailscale.com/derp/derphttp"
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/connstats"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
"tailscale.com/net/stun/stuntest" "tailscale.com/net/stun/stuntest"
"tailscale.com/net/tstun" "tailscale.com/net/tstun"
@ -133,6 +134,7 @@ func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, st
type magicStack struct { type magicStack struct {
privateKey key.NodePrivate privateKey key.NodePrivate
epCh chan []tailcfg.Endpoint // endpoint updates produced by this peer epCh chan []tailcfg.Endpoint // endpoint updates produced by this peer
stats connstats.Statistics // per-connection statistics
conn *Conn // the magicsock itself conn *Conn // the magicsock itself
tun *tuntest.ChannelTUN // TUN device to send/receive packets tun *tuntest.ChannelTUN // TUN device to send/receive packets
tsTun *tstun.Wrapper // wrapped tun that implements filtering and wgengine hooks tsTun *tstun.Wrapper // wrapped tun that implements filtering and wgengine hooks
@ -1047,11 +1049,11 @@ func testTwoDevicePing(t *testing.T, d *devices) {
} }
} }
m1.conn.SetStatisticsEnabled(true) m1.conn.SetStatistics(&m1.stats)
m2.conn.SetStatisticsEnabled(true) m2.conn.SetStatistics(&m2.stats)
checkStats := func(t *testing.T, m *magicStack, wantConns []netlogtype.Connection) { checkStats := func(t *testing.T, m *magicStack, wantConns []netlogtype.Connection) {
stats := m.conn.ExtractStatistics() _, stats := m.stats.Extract()
for _, conn := range wantConns { for _, conn := range wantConns {
if _, ok := stats[conn]; ok { if _, ok := stats[conn]; ok {
return return

@ -20,6 +20,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/logpolicy" "tailscale.com/logpolicy"
"tailscale.com/logtail" "tailscale.com/logtail"
"tailscale.com/net/connstats"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/smallzstd" "tailscale.com/smallzstd"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -34,14 +35,12 @@ const pollPeriod = 5 * time.Second
// *tstun.Wrapper implements this interface. // *tstun.Wrapper implements this interface.
// *magicsock.Conn implements this interface. // *magicsock.Conn implements this interface.
type Device interface { type Device interface {
SetStatisticsEnabled(bool) SetStatistics(*connstats.Statistics)
ExtractStatistics() map[netlogtype.Connection]netlogtype.Counts
} }
type noopDevice struct{} type noopDevice struct{}
func (noopDevice) SetStatisticsEnabled(bool) {} func (noopDevice) SetStatistics(*connstats.Statistics) {}
func (noopDevice) ExtractStatistics() map[netlogtype.Connection]netlogtype.Counts { return nil }
// Logger logs statistics about every connection. // Logger logs statistics about every connection.
// At present, it only logs connections within a tailscale network. // At present, it only logs connections within a tailscale network.
@ -130,16 +129,15 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo
}, log.Printf) }, log.Printf)
nl.logger = logger nl.logger = logger
stats := new(connstats.Statistics)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
nl.cancel = cancel nl.cancel = cancel
nl.group.Go(func() error { nl.group.Go(func() error {
tun.SetStatisticsEnabled(true) tun.SetStatistics(stats)
defer tun.SetStatisticsEnabled(false) defer tun.SetStatistics(nil)
tun.ExtractStatistics() // clear out any stale statistics
sock.SetStatisticsEnabled(true) sock.SetStatistics(stats)
defer sock.SetStatisticsEnabled(false) defer sock.SetStatistics(nil)
sock.ExtractStatistics() // clear out any stale statistics
start := time.Now() start := time.Now()
ticker := time.NewTicker(pollPeriod) ticker := time.NewTicker(pollPeriod)
@ -147,22 +145,20 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo
var end time.Time var end time.Time
select { select {
case <-ctx.Done(): case <-ctx.Done():
tun.SetStatisticsEnabled(false)
end = time.Now() end = time.Now()
case end = <-ticker.C: case end = <-ticker.C:
} }
// NOTE: tunStats and sockStats will always be slightly out-of-sync. // NOTE: connstats and sockStats will always be slightly out-of-sync.
// It is impossible to have an atomic snapshot of statistics // It is impossible to have an atomic snapshot of statistics
// at both layers without a global mutex that spans all layers. // at both layers without a global mutex that spans all layers.
tunStats := tun.ExtractStatistics() connstats, sockStats := stats.Extract()
sockStats := sock.ExtractStatistics() if len(connstats)+len(sockStats) > 0 {
if len(tunStats)+len(sockStats) > 0 {
nl.mu.Lock() nl.mu.Lock()
addrs := nl.addrs addrs := nl.addrs
prefixes := nl.prefixes prefixes := nl.prefixes
nl.mu.Unlock() nl.mu.Unlock()
recordStatistics(logger, nodeID, start, end, tunStats, sockStats, addrs, prefixes) recordStatistics(logger, nodeID, start, end, connstats, sockStats, addrs, prefixes)
} }
if ctx.Err() != nil { if ctx.Err() != nil {
@ -175,7 +171,7 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo
return nil return nil
} }
func recordStatistics(logger *logtail.Logger, nodeID tailcfg.StableNodeID, start, end time.Time, tunStats, sockStats map[netlogtype.Connection]netlogtype.Counts, addrs map[netip.Addr]bool, prefixes map[netip.Prefix]bool) { func recordStatistics(logger *logtail.Logger, nodeID tailcfg.StableNodeID, start, end time.Time, connstats, sockStats map[netlogtype.Connection]netlogtype.Counts, addrs map[netip.Addr]bool, prefixes map[netip.Prefix]bool) {
m := netlogtype.Message{NodeID: nodeID, Start: start.UTC(), End: end.UTC()} m := netlogtype.Message{NodeID: nodeID, Start: start.UTC(), End: end.UTC()}
classifyAddr := func(a netip.Addr) (isTailscale, withinRoute bool) { classifyAddr := func(a netip.Addr) (isTailscale, withinRoute bool) {
@ -194,7 +190,7 @@ func recordStatistics(logger *logtail.Logger, nodeID tailcfg.StableNodeID, start
} }
exitTraffic := make(map[netlogtype.Connection]netlogtype.Counts) exitTraffic := make(map[netlogtype.Connection]netlogtype.Counts)
for conn, cnts := range tunStats { for conn, cnts := range connstats {
srcIsTailscaleIP, srcWithinSubnet := classifyAddr(conn.Src.Addr()) srcIsTailscaleIP, srcWithinSubnet := classifyAddr(conn.Src.Addr())
dstIsTailscaleIP, dstWithinSubnet := classifyAddr(conn.Dst.Addr()) dstIsTailscaleIP, dstWithinSubnet := classifyAddr(conn.Dst.Addr())
switch { switch {

@ -1,66 +0,0 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netlog
import (
"context"
"net/http"
"testing"
qt "github.com/frankban/quicktest"
"tailscale.com/logtail"
"tailscale.com/tstest"
"tailscale.com/types/netlogtype"
"tailscale.com/util/must"
"tailscale.com/wgengine/router"
)
func init() {
testClient = &http.Client{Transport: &roundTripper}
}
var roundTripper roundTripperFunc
type roundTripperFunc struct {
F func(*http.Request) (*http.Response, error)
}
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f.F(r)
}
type fakeDevice struct {
toggled int // even => disabled, odd => enabled
}
func (d *fakeDevice) SetStatisticsEnabled(enable bool) {
if enabled := d.toggled%2 == 1; enabled != enable {
d.toggled++
}
}
func (fakeDevice) ExtractStatistics() map[netlogtype.Connection]netlogtype.Counts {
// TODO(dsnet): Add a test that verifies that statistics are correctly
// extracted from the device and uploaded. Unfortunately,
// we can't reliably run this test until we fix http://go/oss/5856.
return nil
}
func TestResourceCheck(t *testing.T) {
roundTripper.F = func(r *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 200}, nil
}
c := qt.New(t)
tstest.ResourceCheck(t)
var l Logger
var d fakeDevice
for i := 0; i < 10; i++ {
must.Do(l.Startup("", logtail.PrivateID{}, logtail.PrivateID{}, &d, nil))
l.ReconfigRoutes(&router.Config{})
must.Do(l.Shutdown(context.Background()))
c.Assert(d.toggled, qt.Equals, 2*(i+1))
}
}
Loading…
Cancel
Save