diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index c8003659c..efcb632bd 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -90,7 +90,7 @@ type Wrapper struct { // to discard an empty packet instead of sending it through t.outbound. outbound chan []byte - // fitler stores the currently active package filter + // filter atomically stores the currently active packet filter filter atomic.Value // of *filter.Filter // filterFlags control the verbosity of logging packet drops/accepts. filterFlags filter.RunFlags diff --git a/wgengine/bench/bench.go b/wgengine/bench/bench.go new file mode 100644 index 000000000..1d6d0eb26 --- /dev/null +++ b/wgengine/bench/bench.go @@ -0,0 +1,398 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "bufio" + "io" + "log" + "net" + "net/http" + "net/http/pprof" + "os" + "strconv" + "time" + + "inet.af/netaddr" + "tailscale.com/types/logger" +) + +const PayloadSize = 1000 +const ICMPMinSize = 24 + +var Addr1 = netaddr.MustParseIPPrefix("100.64.1.1/32") +var Addr2 = netaddr.MustParseIPPrefix("100.64.1.2/32") + +func main() { + var logf logger.Logf = log.Printf + log.SetFlags(0) + + debugMux := newDebugMux() + go runDebugServer(debugMux, "0.0.0.0:8999") + + mode, err := strconv.Atoi(os.Args[1]) + if err != nil { + log.Fatalf("%q: %v", os.Args[1], err) + } + + traf := NewTrafficGen(nil) + + // Sample test results below are using GOMAXPROCS=2 (for some + // tests, including wireguard-go, higher GOMAXPROCS goes slower) + // on apenwarr's old Linux box: + // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz + // My 2019 Mac Mini is about 20% faster on most tests. 
+ + switch mode { + // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) + case 1: + setupTrivialNoAllocTest(logf, traf) + + // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) + case 2: + setupTrivialTest(logf, traf) + + // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) + case 11: + setupBlockingChannelTest(logf, traf) + + // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) + // (much faster on macOS??) + case 12: + setupNonblockingChannelTest(logf, traf) + + // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) + // (much faster on macOS??) + case 13: + setupDoubleChannelTest(logf, traf) + + // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) + case 21: + setupUDPTest(logf, traf) + + // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) + case 31: + setupBatchTCPTest(logf, traf) + + // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) + case 101: + setupWGTest(logf, traf, Addr1, Addr2) + + default: + log.Fatalf("provide a valid test number (0..n)") + } + + logf("initialized ok.") + traf.Start(Addr1.IP, Addr2.IP, PayloadSize+ICMPMinSize, 0) + + var cur, prev Snapshot + var pps float64 + i := 0 + for { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec == 0 { + logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) + } else { + logf("%v @%7.0f pkt/sec", d, pps) + } + } + + pps = traf.Adjust() + } +} + +func newDebugMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + return mux +} + +func runDebugServer(mux *http.ServeMux, addr string) { + srv := &http.Server{ + Addr: addr, + Handler: mux, + } + if err := srv.ListenAndServe(); err != nil { + 
log.Fatal(err) + } +} + +// The absolute minimal test of the traffic generator: have it fill +// a packet buffer, then absorb it again. Zero packet loss. +func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { + go func() { + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Almost the same, but this time allocate a fresh buffer each time +// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. +func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { + go func() { + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Pass packets through a blocking channel between sender and receiver. +// Still zero packet loss since the sender stops when the channel is full. +// Max speed depends on channel length (I'm not sure why). +func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + ch <- b[0 : n+16] + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as setupBlockingChannelTest, but now we drop packets whenever the +// channel is full. Max speed is about the same as the above test, but +// now with nonzero packet loss. +func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as above, but at an intermediate blocking channel and goroutine +// to make things a little more like wireguard-go. 
Roughly 20% slower than +// the single-channel version. +func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + ch2 := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // intermediary + for b := range ch { + ch2 <- b + } + close(ch2) + }() + + go func() { + // receiver + for b := range ch2 { + traf.GotPacket(b, 16) + } + }() +} + +// Instead of a channel, pass packets through a UDP socket. +func setupUDPTest(logf logger.Logf, traf *TrafficGen) { + la, err := net.ResolveUDPAddr("udp", ":0") + if err != nil { + log.Fatalf("resolve: %v", err) + } + + s1, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen1: %v", err) + } + s2, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen2: %v", err) + } + + a2 := s2.LocalAddr() + + // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, + // which is what returns from .LocalAddr() above. We have to + // force it to localhost instead. + a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") + + s1.SetWriteBuffer(1024 * 1024) + s2.SetReadBuffer(1024 * 1024) + + go func() { + // transmitter + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + s1.WriteTo(b[16:n+16], a2) + } + }() + + go func() { + // receiver + b := make([]byte, 1600) + for traf.Running() { + // Use ReadFrom instead of Read, to be more like + // how wireguard-go does it, even though we're not + // going to actually look at the address. + n, _, err := s2.ReadFrom(b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} + +// Instead of a channel, pass packets through a TCP socket. +// TCP is a single stream, so we can amortize one syscall across +// multiple packets.
10x amortization seems to make it go ~10x faster, +// as expected, getting us close to the speed of the channel tests above. +// There's also zero packet loss. +func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { + sl, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatalf("listen: %v", err) + } + + s1, err := net.Dial("tcp", sl.Addr().String()) + if err != nil { + log.Fatalf("dial: %v", err) + } + + s2, err := sl.Accept() + if err != nil { + log.Fatalf("accept: %v", err) + } + + s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) + s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) + + ch := make(chan int) + + go func() { + // transmitter + + bs1 := bufio.NewWriterSize(s1, 1024*1024) + + b := make([]byte, 1600) + i := 0 + for { + i += 1 + n := traf.Generate(b, 16) + if n == 0 { + break + } + if i == 1 { + ch <- n + } + bs1.Write(b[16 : n+16]) + + // TODO: this is a pretty half-baked batching + // function, which we'd never want to employ in + // a real-life program. + // + // In real life, we'd probably want to flush + // immediately when there are no more packets to + // generate, and queue up only if we fall behind. + // + // In our case however, we just want to see the + // technical benefits of batching 10 syscalls + // into 1, so a fixed ratio makes more sense. + if (i % 10) == 0 { + bs1.Flush() + } + } + }() + + go func() { + // receiver + + bs2 := bufio.NewReaderSize(s2, 1024*1024) + + // Find out the packet size (we happen to know they're + // all the same size) + packetSize := <-ch + + b := make([]byte, packetSize) + for traf.Running() { + // TODO: can't use ReadFrom() here, which is + // unfair compared to UDP. (ReadFrom for UDP + // apparently allocates memory per packet, which + // this test does not.) 
+ n, err := io.ReadFull(bs2, b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} diff --git a/wgengine/bench/bench_test.go b/wgengine/bench/bench_test.go new file mode 100644 index 000000000..55a9f3d7f --- /dev/null +++ b/wgengine/bench/bench_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "fmt" + "testing" + "time" + + "tailscale.com/types/logger" +) + +func BenchmarkTrivialNoAlloc(b *testing.B) { + run(b, setupTrivialNoAllocTest) +} +func BenchmarkTrivial(b *testing.B) { + run(b, setupTrivialTest) +} + +func BenchmarkBlockingChannel(b *testing.B) { + run(b, setupBlockingChannelTest) +} + +func BenchmarkNonblockingChannel(b *testing.B) { + run(b, setupNonblockingChannelTest) +} + +func BenchmarkDoubleChannel(b *testing.B) { + run(b, setupDoubleChannelTest) +} + +func BenchmarkUDP(b *testing.B) { + run(b, setupUDPTest) +} + +func BenchmarkBatchTCP(b *testing.B) { + run(b, setupBatchTCPTest) +} + +func BenchmarkWireGuardTest(b *testing.B) { + run(b, func(logf logger.Logf, traf *TrafficGen) { + setupWGTest(logf, traf, Addr1, Addr2) + }) +} + +type SetupFunc func(logger.Logf, *TrafficGen) + +func run(b *testing.B, setup SetupFunc) { + sizes := []int{ + ICMPMinSize + 8, + ICMPMinSize + 100, + ICMPMinSize + 1000, + } + + for _, size := range sizes { + b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { + runOnce(b, setup, size) + }) + } +} + +func runOnce(b *testing.B, setup SetupFunc, payload int) { + b.StopTimer() + b.ReportAllocs() + + var logf logger.Logf = b.Logf + if !testing.Verbose() { + logf = logger.Discard + } + + traf := NewTrafficGen(b.StartTimer) + setup(logf, traf) + + logf("initialized. 
(n=%v)", b.N) + b.SetBytes(int64(payload)) + + traf.Start(Addr1.IP, Addr2.IP, payload, int64(b.N)) + + var cur, prev Snapshot + var pps float64 + i := 0 + for traf.Running() { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec != 0 { + logf("%v @%7.0f pkt/sec", d, pps) + } + } + + pps = traf.Adjust() + } + + cur = traf.Snap() + d := cur.Sub(prev) + loss := float64(d.LostPackets) / float64(d.RxPackets) + + b.ReportMetric(loss*100, "%lost") +} diff --git a/wgengine/bench/trafficgen.go b/wgengine/bench/trafficgen.go new file mode 100644 index 000000000..93e63594d --- /dev/null +++ b/wgengine/bench/trafficgen.go @@ -0,0 +1,248 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "encoding/binary" + "fmt" + "log" + "sync" + "time" + + "inet.af/netaddr" + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +type Snapshot struct { + WhenNsec int64 // current time + timeAcc int64 // accumulated time (+NSecPerTx per transmit) + + LastSeqTx int64 // last sequence number sent + LastSeqRx int64 // last sequence number received + TotalLost int64 // packets out-of-order or lost so far + TotalOOO int64 // packets out-of-order so far + TotalBytesRx int64 // total bytes received so far +} + +type Delta struct { + DurationNsec int64 + TxPackets int64 + RxPackets int64 + LostPackets int64 + OOOPackets int64 + Bytes int64 +} + +func (b Snapshot) Sub(a Snapshot) Delta { + return Delta{ + DurationNsec: b.WhenNsec - a.WhenNsec, + TxPackets: b.LastSeqTx - a.LastSeqTx, + RxPackets: (b.LastSeqRx - a.LastSeqRx) - + (b.TotalLost - a.TotalLost) + + (b.TotalOOO - a.TotalOOO), + LostPackets: b.TotalLost - a.TotalLost, + OOOPackets: b.TotalOOO - a.TotalOOO, + Bytes: b.TotalBytesRx - a.TotalBytesRx, + } +} + +func (d Delta) String() string { + 
return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", + d.TxPackets, d.RxPackets, d.LostPackets, + float64(d.LostPackets)*100/float64(d.TxPackets), + d.OOOPackets, + float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) +} + +type TrafficGen struct { + mu sync.Mutex + cur, prev Snapshot // snapshots used for rate control + buf []byte // pre-generated packet buffer + done bool // true if the test has completed + + onFirstPacket func() // function to call on first received packet + + // maxPackets is the max packets to receive (not send) before + // ending the test. If it's zero, the test runs forever. + maxPackets int64 + + // nsPerPacket is the target average nanoseconds between packets. + // It's initially zero, which means transmit as fast as the + // caller wants to go. + nsPerPacket int64 + + // bestPPS is the "best observed packets-per-second" in recent + // memory. + bestPPS float64 +} + +// NewTrafficGen creates a new, initially locked, TrafficGen. +// Until Start() is called, Generate() will block forever. +func NewTrafficGen(onFirstPacket func()) *TrafficGen { + t := TrafficGen{ + onFirstPacket: onFirstPacket, + } + + // initially locked, until first Start() + t.mu.Lock() + + return &t +} + +// Start starts the traffic generator. It assumes mu is already locked, +// and unlocks it. 
+func (t *TrafficGen) Start(src, dst netaddr.IP, bytesPerPacket int, maxPackets int64) {
+	h12 := packet.ICMP4Header{
+		IP4Header: packet.IP4Header{
+			IPProto: ipproto.ICMPv4,
+			IPID:    0,
+			Src:     src,
+			Dst:     dst,
+		},
+		Type: packet.ICMP4EchoRequest,
+		Code: packet.ICMP4NoCode,
+	}
+
+	// ensure there's room for ICMP header plus the 8-byte sequence number
+	if bytesPerPacket < ICMPMinSize+8 {
+		log.Fatalf("bytesPerPacket must be >= %d", ICMPMinSize+8)
+	}
+
+	t.maxPackets = maxPackets
+	t.buf = packet.Generate(h12, make([]byte, bytesPerPacket-ICMPMinSize))
+
+	// mu is held on entry (NewTrafficGen locks it); release it here.
+	t.mu.Unlock()
+}
+
+// Snap timestamps the current snapshot and returns a copy of it.
+func (t *TrafficGen) Snap() Snapshot {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	t.cur.WhenNsec = time.Now().UnixNano()
+	return t.cur
+}
+
+// Running reports whether the test has not yet been marked done.
+func (t *TrafficGen) Running() bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	return !t.done
+}
+
+// Generate produces the next packet in the sequence. It sleeps if
+// it's too soon for the next packet to be sent.
+//
+// The generated packet is placed into b at offset ofs, for compatibility
+// with the wireguard-go conventions.
+//
+// The return value is the number of bytes generated in the packet, or 0
+// if the test has finished running.
+func (t *TrafficGen) Generate(b []byte, ofs int) int {
+	t.mu.Lock()
+
+	now := time.Now().UnixNano()
+	if t.nsPerPacket == 0 || t.cur.timeAcc == 0 {
+		t.cur.timeAcc = now - 1
+	}
+	if t.cur.timeAcc >= now {
+		// too soon: read t.cur under mu, then sleep with mu released
+		delay := time.Duration(t.cur.timeAcc - now)
+		t.mu.Unlock()
+		time.Sleep(delay)
+		t.mu.Lock()
+		now = t.cur.timeAcc
+	}
+	if t.done {
+		t.mu.Unlock()
+		return 0
+	}
+
+	t.cur.timeAcc += t.nsPerPacket
+	t.cur.LastSeqTx++
+	t.cur.WhenNsec = now
+	seq := t.cur.LastSeqTx
+
+	t.mu.Unlock()
+
+	// Copy the pre-built packet, then stamp seq at offset ICMPMinSize.
+	copy(b[ofs:], t.buf)
+	binary.BigEndian.PutUint64(b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], uint64(seq))
+
+	return len(t.buf)
+}
+
+// GotPacket processes a packet that came back on the receive side.
+func (t *TrafficGen) GotPacket(b []byte, ofs int) {
+	t.mu.Lock()
+
+	s := &t.cur
+	seq := int64(binary.BigEndian.Uint64(
+		b[ofs+ICMPMinSize : ofs+ICMPMinSize+8]))
+	if seq > s.LastSeqRx {
+		if s.LastSeqRx > 0 {
+			// only count lost packets after the very first
+			// successful one.
+			s.TotalLost += seq - s.LastSeqRx - 1
+		}
+		s.LastSeqRx = seq
+	} else {
+		s.TotalOOO += 1
+	}
+
+	// +1 packet since we only start counting after the first one
+	if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 {
+		t.done = true
+	}
+	s.TotalBytesRx += int64(len(b) - ofs)
+
+	f := t.onFirstPacket
+	t.onFirstPacket = nil
+
+	t.mu.Unlock()
+
+	if f != nil {
+		f()
+	}
+}
+
+// Adjust tunes the transmit rate based on the received packets.
+// The goal is to converge on the fastest transmit rate that still has
+// minimal packet loss. Returns the new target rate in packets/sec.
+//
+// We need to play this guessing game in order to balance out tx and rx
+// rates when there's a lossy network between them. Otherwise we can end
+// up using 99% of the CPU to blast out transmitted packets and leaving only
+// 1% to receive them, leading to a misleading throughput calculation.
+//
+// Call this function multiple times per second.
+func (t *TrafficGen) Adjust() (pps float64) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	// don't adjust rate until the first full period *after* receiving
+	// the first packet. This skips any handshake time in the underlying
+	// transport.
+	if t.prev.LastSeqRx == 0 {
+		t.prev = t.cur
+		return 0 // no estimate yet, continue at max speed
+	}
+
+	d := t.cur.Sub(t.prev)
+	t.bestPPS *= 0.97
+	pps = float64(d.RxPackets) * 1e9 / float64(d.DurationNsec)
+	if pps > 0 && t.prev.WhenNsec > 0 {
+		if pps > t.bestPPS {
+			t.bestPPS = pps
+		}
+		t.nsPerPacket = int64(1e9 / t.bestPPS)
+	}
+	t.prev = t.cur
+
+	return t.bestPPS
+}
diff --git a/wgengine/bench/wg.go b/wgengine/bench/wg.go
new file mode 100644
index 000000000..acf04f6bd
--- /dev/null
+++ b/wgengine/bench/wg.go
@@ -0,0 +1,218 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	"io"
+	"log"
+	"os"
+	"strings"
+	"sync"
+
+	"github.com/tailscale/wireguard-go/tun"
+	"inet.af/netaddr"
+
+	"tailscale.com/net/dns"
+	"tailscale.com/tailcfg"
+	"tailscale.com/types/logger"
+	"tailscale.com/types/netmap"
+	"tailscale.com/types/wgkey"
+	"tailscale.com/wgengine"
+	"tailscale.com/wgengine/filter"
+	"tailscale.com/wgengine/router"
+	"tailscale.com/wgengine/wgcfg"
+)
+
+// setupWGTest creates two userspace wgengine instances, e1 (address a1)
+// and e2 (address a2), and configures each one as the other's only peer.
+// e1 reads its outbound packets from traf via a sourceTun; e2 reports the
+// packets it receives to traf via a sinkTun. It returns after both status
+// callbacks have run at least once.
+func setupWGTest(logf logger.Logf, traf *TrafficGen, a1, a2 netaddr.IPPrefix) {
+	l1 := logger.WithPrefix(logf, "e1: ")
+	k1, err := wgcfg.NewPrivateKey()
+	if err != nil {
+		log.Fatalf("e1 NewPrivateKey: %v", err)
+	}
+	c1 := wgcfg.Config{
+		Name:       "e1",
+		PrivateKey: k1,
+		Addresses:  []netaddr.IPPrefix{a1},
+	}
+	t1 := &sourceTun{
+		logf: logger.WithPrefix(logf, "tun1: "),
+		traf: traf,
+	}
+	e1, err := wgengine.NewUserspaceEngine(l1, wgengine.Config{
+		Router:      router.NewFake(l1),
+		LinkMonitor: nil,
+		ListenPort:  0,
+		Tun:         t1,
+	})
+	if err != nil {
+		log.Fatalf("e1 init: %v", err)
+	}
+
+	l2 := logger.WithPrefix(logf, "e2: ")
+	k2, err := wgcfg.NewPrivateKey()
+	if err != nil {
+		log.Fatalf("e2 NewPrivateKey: %v", err)
+	}
+	c2 := wgcfg.Config{
+		Name:       "e2",
+		PrivateKey: k2,
+		Addresses:  []netaddr.IPPrefix{a2},
+	}
+	t2 := &sinkTun{
+		logf: logger.WithPrefix(logf, "tun2: "),
+		traf: traf,
+	}
+	e2, err := wgengine.NewUserspaceEngine(l2, wgengine.Config{
+		Router:      router.NewFake(l2),
+		LinkMonitor: nil,
+		ListenPort:  0,
+		Tun:         t2,
+	})
+	if err != nil {
+		log.Fatalf("e2 init: %v", err)
+	}
+
+	e1.SetFilter(filter.NewAllowAllForTest(l1))
+	e2.SetFilter(filter.NewAllowAllForTest(l2))
+
+	// Each status callback signals wait exactly once. The callbacks may be
+	// invoked again on later status updates, and a second wait.Done()
+	// would drive the WaitGroup counter negative and panic.
+	var wait sync.WaitGroup
+	wait.Add(2)
+	var e1Done, e2Done sync.Once
+
+	e1.SetStatusCallback(func(st *wgengine.Status, err error) {
+		if err != nil {
+			log.Fatalf("e1 status err: %v", err)
+		}
+		logf("e1 status: %v", *st)
+
+		var eps []string
+		for _, ep := range st.LocalAddrs {
+			eps = append(eps, ep.Addr.String())
+		}
+
+		n := tailcfg.Node{
+			ID:         tailcfg.NodeID(0),
+			Name:       "n1",
+			Addresses:  []netaddr.IPPrefix{a1},
+			AllowedIPs: []netaddr.IPPrefix{a1},
+			Endpoints:  eps,
+		}
+		e2.SetNetworkMap(&netmap.NetworkMap{
+			NodeKey:    tailcfg.NodeKey(k2),
+			PrivateKey: wgkey.Private(k2),
+			Peers:      []*tailcfg.Node{&n},
+		})
+
+		p := wgcfg.Peer{
+			PublicKey:  c1.PrivateKey.Public(),
+			AllowedIPs: []netaddr.IPPrefix{a1},
+			Endpoints:  strings.Join(eps, ","),
+		}
+		c2.Peers = []wgcfg.Peer{p}
+		e2.Reconfig(&c2, &router.Config{}, new(dns.Config))
+		e1Done.Do(wait.Done)
+	})
+
+	e2.SetStatusCallback(func(st *wgengine.Status, err error) {
+		if err != nil {
+			log.Fatalf("e2 status err: %v", err)
+		}
+		logf("e2 status: %v", *st)
+
+		var eps []string
+		for _, ep := range st.LocalAddrs {
+			eps = append(eps, ep.Addr.String())
+		}
+
+		n := tailcfg.Node{
+			ID:         tailcfg.NodeID(0),
+			Name:       "n2",
+			Addresses:  []netaddr.IPPrefix{a2},
+			AllowedIPs: []netaddr.IPPrefix{a2},
+			Endpoints:  eps,
+		}
+		e1.SetNetworkMap(&netmap.NetworkMap{
+			NodeKey:    tailcfg.NodeKey(k1),
+			PrivateKey: wgkey.Private(k1),
+			Peers:      []*tailcfg.Node{&n},
+		})
+
+		p := wgcfg.Peer{
+			PublicKey:  c2.PrivateKey.Public(),
+			AllowedIPs: []netaddr.IPPrefix{a2},
+			Endpoints:  strings.Join(eps, ","),
+		}
+		c1.Peers = []wgcfg.Peer{p}
+		e1.Reconfig(&c1, &router.Config{}, new(dns.Config))
+		e2Done.Do(wait.Done)
+	})
+
+	// Not using DERP in this test (for now?).
+	e1.SetDERPMap(&tailcfg.DERPMap{})
+	e2.SetDERPMap(&tailcfg.DERPMap{})
+
+	wait.Wait()
+}
+
+// sourceTun is a fake tun device whose Read returns packets produced by
+// traf.Generate and whose Write discards everything.
+type sourceTun struct {
+	logf logger.Logf
+	traf *TrafficGen
+}
+
+func (t *sourceTun) Close() error           { return nil }
+func (t *sourceTun) Events() chan tun.Event { return nil }
+func (t *sourceTun) File() *os.File         { return nil }
+func (t *sourceTun) Flush() error           { return nil }
+func (t *sourceTun) MTU() (int, error)      { return 1500, nil }
+func (t *sourceTun) Name() (string, error)  { return "source", nil }
+
+func (t *sourceTun) Write(b []byte, ofs int) (int, error) {
+	// Discard all writes
+	return len(b) - ofs, nil
+}
+
+func (t *sourceTun) Read(b []byte, ofs int) (int, error) {
+	// Continually generate "input" packets
+	n := t.traf.Generate(b, ofs)
+	if n == 0 {
+		return 0, io.EOF
+	}
+	return n, nil
+}
+
+// sinkTun is a fake tun device whose Write reports each packet to
+// traf.GotPacket and whose Read blocks forever.
+type sinkTun struct {
+	logf logger.Logf
+	traf *TrafficGen
+}
+
+func (t *sinkTun) Close() error           { return nil }
+func (t *sinkTun) Events() chan tun.Event { return nil }
+func (t *sinkTun) File() *os.File         { return nil }
+func (t *sinkTun) Flush() error           { return nil }
+func (t *sinkTun) MTU() (int, error)      { return 1500, nil }
+func (t *sinkTun) Name() (string, error)  { return "sink", nil }
+
+func (t *sinkTun) Read(b []byte, ofs int) (int, error) {
+	// Never returns
+	select {}
+}
+
+func (t *sinkTun) Write(b []byte, ofs int) (int, error) {
+	// Count packets, but discard them
+	t.traf.GotPacket(b, ofs)
+	return len(b) - ofs, nil
+}