net/tstun: add tests for captureHook

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I630f852d9f16c951c721b34f2bc4128e68fe9475
pull/7941/head
Andrew Dunham 2 years ago
parent c791e64881
commit 04a3118d45

@ -93,6 +93,9 @@ type Wrapper struct {
destMACAtomic syncs.AtomicValue[[6]byte] destMACAtomic syncs.AtomicValue[[6]byte]
discoKey syncs.AtomicValue[key.DiscoPublic] discoKey syncs.AtomicValue[key.DiscoPublic]
// timeNow, if non-nil, will be used to obtain the current time.
timeNow func() time.Time
// natV4Config stores the current NAT configuration. // natV4Config stores the current NAT configuration.
natV4Config atomic.Pointer[natV4Config] natV4Config atomic.Pointer[natV4Config]
@ -258,6 +261,15 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper {
return w return w
} }
// now returns the current time, either by calling t.timeNow if set or time.Now
// if not.
func (t *Wrapper) now() time.Time {
if t.timeNow != nil {
return t.timeNow()
}
return time.Now()
}
// SetDestIPActivityFuncs sets a map of funcs to run per packet // SetDestIPActivityFuncs sets a map of funcs to run per packet
// destination (the map keys). // destination (the map keys).
// //
@ -724,7 +736,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
} }
} }
if captHook != nil { if captHook != nil {
captHook(capture.FromLocal, time.Now(), p.Buffer(), p.CaptureMeta) captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta)
} }
if !t.disableFilter { if !t.disableFilter {
response := t.filterPacketOutboundToWireGuard(p) response := t.filterPacketOutboundToWireGuard(p)
@ -791,7 +803,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback) filter.Response { func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback) filter.Response {
if captHook != nil { if captHook != nil {
captHook(capture.FromPeer, time.Now(), p.Buffer(), p.CaptureMeta) captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
} }
if p.IPProto == ipproto.TSMP { if p.IPProto == ipproto.TSMP {
@ -959,7 +971,7 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt stack.PacketBufferPtr) error {
p.Decode(buf[PacketStartOffset:]) p.Decode(buf[PacketStartOffset:])
captHook := t.captureHook.Load() captHook := t.captureHook.Load()
if captHook != nil { if captHook != nil {
captHook(capture.SynthesizedToLocal, time.Now(), p.Buffer(), p.CaptureMeta) captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta)
} }
t.dnatV4(p) t.dnatV4(p)
@ -1037,14 +1049,14 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
// It does not block, but takes ownership of the packet. // It does not block, but takes ownership of the packet.
// The injected packet will not pass through outbound filters. // The injected packet will not pass through outbound filters.
// Injecting an empty packet is a no-op. // Injecting an empty packet is a no-op.
func (t *Wrapper) InjectOutbound(packet []byte) error { func (t *Wrapper) InjectOutbound(pkt []byte) error {
if len(packet) > MaxPacketSize { if len(pkt) > MaxPacketSize {
return errPacketTooBig return errPacketTooBig
} }
if len(packet) == 0 { if len(pkt) == 0 {
return nil return nil
} }
t.injectOutbound(tunInjectedRead{data: packet}) t.injectOutbound(tunInjectedRead{data: pkt})
return nil return nil
} }
@ -1063,7 +1075,7 @@ func (t *Wrapper) InjectOutboundPacketBuffer(pkt stack.PacketBufferPtr) error {
} }
if capt := t.captureHook.Load(); capt != nil { if capt := t.captureHook.Load(); capt != nil {
b := pkt.ToBuffer() b := pkt.ToBuffer()
capt(capture.SynthesizedToPeer, time.Now(), b.Flatten(), packet.CaptureMeta{}) capt(capture.SynthesizedToPeer, t.now(), b.Flatten(), packet.CaptureMeta{})
} }
t.injectOutbound(tunInjectedRead{packet: pkt}) t.injectOutbound(tunInjectedRead{packet: pkt})

@ -10,9 +10,11 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/netip" "net/netip"
"reflect"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"unicode" "unicode"
"unsafe" "unsafe"
@ -21,6 +23,8 @@ import (
"github.com/tailscale/wireguard-go/tun/tuntest" "github.com/tailscale/wireguard-go/tun/tuntest"
"go4.org/mem" "go4.org/mem"
"go4.org/netipx" "go4.org/netipx"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/net/connstats" "tailscale.com/net/connstats"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
@ -33,6 +37,7 @@ import (
"tailscale.com/types/netlogtype" "tailscale.com/types/netlogtype"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
"tailscale.com/util/must" "tailscale.com/util/must"
"tailscale.com/wgengine/capture"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
) )
@ -766,3 +771,93 @@ func TestNATCfg(t *testing.T) {
}) })
} }
} }
// TestCaptureHook verifies that the Wrapper.captureHook callback is called
// with the correct parameters when various packet operations are performed.
func TestCaptureHook(t *testing.T) {
type captureRecord struct {
path capture.Path
now time.Time
pkt []byte
meta packet.CaptureMeta
}
var captured []captureRecord
hook := func(path capture.Path, now time.Time, pkt []byte, meta packet.CaptureMeta) {
captured = append(captured, captureRecord{
path: path,
now: now,
pkt: pkt,
meta: meta,
})
}
now := time.Unix(1682085856, 0)
_, w := newFakeTUN(t.Logf, true)
w.timeNow = func() time.Time {
return now
}
w.InstallCaptureHook(hook)
defer w.Close()
// Loop reading and discarding packets; this ensures that we don't have
// packets stuck in vectorOutbound
go func() {
var (
buf [MaxPacketSize]byte
sizes = make([]int, 1)
)
for {
_, err := w.Read([][]byte{buf[:]}, sizes, 0)
if err != nil {
return
}
}
}()
// Do operations that should result in a packet being captured.
w.Write([][]byte{
[]byte("Write1"),
[]byte("Write2"),
}, 0)
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData([]byte("InjectInboundPacketBuffer")),
})
w.InjectInboundPacketBuffer(packetBuf)
packetBuf = stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData([]byte("InjectOutboundPacketBuffer")),
})
w.InjectOutboundPacketBuffer(packetBuf)
// TODO: test Read
// TODO: determine if we want InjectOutbound to log
// Assert that the right packets are captured.
want := []captureRecord{
{
path: capture.FromPeer,
pkt: []byte("Write1"),
},
{
path: capture.FromPeer,
pkt: []byte("Write2"),
},
{
path: capture.SynthesizedToLocal,
pkt: []byte("InjectInboundPacketBuffer"),
},
{
path: capture.SynthesizedToPeer,
pkt: []byte("InjectOutboundPacketBuffer"),
},
}
for i := 0; i < len(want); i++ {
want[i].now = now
}
if !reflect.DeepEqual(captured, want) {
t.Errorf("mismatch between captured and expected packets\ngot: %+v\nwant: %+v",
captured, want)
}
}

Loading…
Cancel
Save