From 02231e968ebb9a33df0301ebcd7e093ea83bc1cb Mon Sep 17 00:00:00 2001 From: Dmytro Shynkevych Date: Fri, 5 Jun 2020 11:19:03 -0400 Subject: [PATCH] wgengine/tstun: add tests and benchmarks (#436) Signed-off-by: Dmytro Shynkevych --- wgengine/tstun/faketun.go | 1 - wgengine/tstun/tun.go | 21 ++- wgengine/tstun/tun_test.go | 355 +++++++++++++++++++++++++++++++++++++ 3 files changed, 369 insertions(+), 8 deletions(-) create mode 100644 wgengine/tstun/tun_test.go diff --git a/wgengine/tstun/faketun.go b/wgengine/tstun/faketun.go index 4451762a8..6afc861e3 100644 --- a/wgengine/tstun/faketun.go +++ b/wgengine/tstun/faketun.go @@ -34,7 +34,6 @@ func (t *fakeTUN) File() *os.File { func (t *fakeTUN) Close() error { close(t.closechan) - close(t.datachan) close(t.evchan) return nil } diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go index a4ed8dd08..90457d96c 100644 --- a/wgengine/tstun/tun.go +++ b/wgengine/tstun/tun.go @@ -72,6 +72,9 @@ type TUN struct { filter atomic.Value // of *filter.Filter // filterFlags control the verbosity of logging packet drops/accepts. filterFlags filter.RunFlags + + // insecure disables all filtering when set. This is useful in tests. + insecure bool } func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { @@ -202,10 +205,12 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) { } } - response := t.filterOut(buf[offset : offset+n]) - if response != filter.Accept { - // Wireguard considers read errors fatal; pretend nothing was read - return 0, nil + if !t.insecure { + response := t.filterOut(buf[offset : offset+n]) + if response != filter.Accept { + // Wireguard considers read errors fatal; pretend nothing was read + return 0, nil + } } return n, nil @@ -240,9 +245,11 @@ func (t *TUN) filterIn(buf []byte) filter.Response { } func (t *TUN) Write(buf []byte, offset int) (int, error) { - response := t.filterIn(buf[offset:]) - if response != filter.Accept { - return 0, ErrFiltered + if !t.insecure { + response := t.filterIn(buf[offset:]) + if response != filter.Accept { + return 0, ErrFiltered + } } return t.tdev.Write(buf, offset) diff --git a/wgengine/tstun/tun_test.go b/wgengine/tstun/tun_test.go new file mode 100644 index 000000000..258ecc35d --- /dev/null +++ b/wgengine/tstun/tun_test.go @@ -0,0 +1,355 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tstun + +import ( + "bytes" + "testing" + + "github.com/tailscale/wireguard-go/tun/tuntest" + "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/packet" +) + +func udp(src, dst packet.IP, sport, dport uint16) []byte { + header := &packet.UDPHeader{ + IPHeader: packet.IPHeader{ + SrcIP: src, + DstIP: dst, + IPID: 0, + }, + SrcPort: sport, + DstPort: dport, + } + return packet.Generate(header, []byte("udp_payload")) +} + +func nets(ips []packet.IP) []filter.Net { + out := make([]filter.Net, 0, len(ips)) + for _, ip := range ips { + out = append(out, filter.Net{ip, filter.Netmask(32)}) + } + return out +} + +func ippr(ip packet.IP, start, end uint16) []filter.NetPortRange { + return []filter.NetPortRange{ + filter.NetPortRange{filter.Net{ip, filter.Netmask(32)}, filter.PortRange{start, end}}, + } +} + +func setfilter(logf logger.Logf, tun *TUN) { + matches := filter.Matches{ + {Srcs: nets([]packet.IP{0x05060708}), Dsts: ippr(0x01020304, 89, 90)}, + {Srcs: nets([]packet.IP{0x01020304}), Dsts: ippr(0x05060708, 98, 98)}, + } + localNets := []filter.Net{ + {packet.IP(0x01020304), filter.Netmask(16)}, + } + tun.SetFilter(filter.New(matches, localNets, nil, logf)) +} + +func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) { + chtun := tuntest.NewChannelTUN() + tun := WrapTUN(logf, chtun.TUN()) + if secure { + setfilter(logf, tun) + } else { + tun.insecure = true + } + return chtun, tun +} + +func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) { + ftun := NewFakeTUN() + tun := WrapTUN(logf, ftun) + if secure { + setfilter(logf, tun) + } else { + tun.insecure = true + } + return ftun.(*fakeTUN), tun +} + +func TestReadAndInject(t *testing.T) { + chtun, tun := newChannelTUN(t.Logf, false) + defer tun.Close() + + const size = 2 // all payloads have this size + written := []string{"w0", "w1"} + injected := []string{"i0", "i1"} + + go func() { + for _, packet := range written { + payload := []byte(packet) + chtun.Outbound <- payload + } + }() + + for _, packet := range injected { + go func(packet string) { + payload := []byte(packet) + err := tun.InjectOutbound(payload) + if err != nil { + t.Errorf("%s: error: %v", packet, err) + } + }(packet) + } + + var buf [MaxPacketSize]byte + var seen = make(map[string]bool) + // We expect the same packets back, in no particular order. + for i := 0; i < len(written)+len(injected); i++ { + n, err := tun.Read(buf[:], 0) + if err != nil { + t.Errorf("read %d: error: %v", i, err) + } + if n != size { + t.Errorf("read %d: got size %d; want %d", i, n, size) + } + got := string(buf[:n]) + t.Logf("read %d: got %s", i, got) + seen[got] = true + } + + for _, packet := range written { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } + for _, packet := range injected { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } +} + +func TestWriteAndInject(t *testing.T) { + chtun, tun := newChannelTUN(t.Logf, false) + defer tun.Close() + + const size = 2 // all payloads have this size + written := []string{"w0", "w1"} + injected := []string{"i0", "i1"} + + go func() { + for _, packet := range written { + payload := []byte(packet) + n, err := tun.Write(payload, 0) + if err != nil { + t.Errorf("%s: error: %v", packet, err) + } + if n != size { + t.Errorf("%s: got size %d; want %d", packet, n, size) + } + } + }() + + for _, packet := range injected { + go func(packet string) { + payload := []byte(packet) + err := tun.InjectInbound(payload) + if err != nil { + t.Errorf("%s: error: %v", packet, err) + } + }(packet) + } + + seen := make(map[string]bool) + // We expect the same packets back, in no particular order. + for i := 0; i < len(written)+len(injected); i++ { + packet := <-chtun.Inbound + got := string(packet) + t.Logf("read %d: got %s", i, got) + seen[got] = true + } + + for _, packet := range written { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } + for _, packet := range injected { + if !seen[packet] { + t.Errorf("%s not received", packet) + } + } +} + +func TestFilter(t *testing.T) { + chtun, tun := newChannelTUN(t.Logf, true) + defer tun.Close() + + type direction int + + const ( + in direction = iota + out + ) + + tests := []struct { + name string + dir direction + drop bool + data []byte + }{ + {"junk_in", in, true, []byte("\x45not a valid IPv4 packet")}, + {"junk_out", out, true, []byte("\x45not a valid IPv4 packet")}, + {"bad_port_in", in, true, udp(0x05060708, 0x01020304, 22, 22)}, + {"bad_port_out", out, false, udp(0x01020304, 0x05060708, 22, 22)}, + {"bad_ip_in", in, true, udp(0x08010101, 0x01020304, 89, 89)}, + {"bad_ip_out", out, false, udp(0x01020304, 0x08010101, 98, 98)}, + {"good_packet_in", in, false, udp(0x05060708, 0x01020304, 89, 89)}, + {"good_packet_out", out, false, udp(0x01020304, 0x05060708, 98, 98)}, + } + + // A reader on the other end of the TUN. + go func() { + var recvbuf []byte + for { + select { + case <-tun.closed: + return + case recvbuf = <-chtun.Inbound: + // continue + } + for _, tt := range tests { + if tt.drop && bytes.Equal(recvbuf, tt.data) { + t.Errorf("did not drop %s", tt.name) + } + } + } + }() + + var buf [MaxPacketSize]byte + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var n int + var err error + var filtered bool + + if tt.dir == in { + _, err = tun.Write(tt.data, 0) + if err == ErrFiltered { + filtered = true + err = nil + } + } else { + chtun.Outbound <- tt.data + n, err = tun.Read(buf[:], 0) + // In the read direction, errors are fatal, so we return n = 0 instead. + filtered = (n == 0) + } + + if err != nil { + t.Errorf("got err %v; want nil", err) + } + + if filtered { + if !tt.drop { + t.Errorf("got drop; want accept") + } + } else { + if tt.drop { + t.Errorf("got accept; want drop") + } + } + }) + } +} + +func TestAllocs(t *testing.T) { + ftun, tun := newFakeTUN(t.Logf, false) + defer tun.Close() + + go func() { + var buf []byte + for { + select { + case <-tun.closed: + return + case buf = <-ftun.datachan: + // continue + } + + select { + case <-tun.closed: + return + case ftun.datachan <- buf: + // continue + } + } + }() + + buf := []byte{0x00} + allocs := testing.AllocsPerRun(100, func() { + _, err := tun.Write(buf, 0) + if err != nil { + t.Errorf("write: error: %v", err) + return + } + + _, err = tun.Read(buf, 0) + if err != nil { + t.Errorf("read: error: %v", err) + return + } + + }) + + if allocs > 0 { + t.Errorf("read allocs = %v; want 0", allocs) + } +} + +func BenchmarkWrite(b *testing.B) { + ftun, tun := newFakeTUN(b.Logf, true) + defer tun.Close() + + go func() { + for { + select { + case <-tun.closed: + return + case <-ftun.datachan: + // continue + } + } + }() + + packet := udp(0x05060708, 0x01020304, 89, 89) + for i := 0; i < b.N; i++ { + _, err := tun.Write(packet, 0) + if err != nil { + b.Errorf("err = %v; want nil", err) + } + } +} + +func BenchmarkRead(b *testing.B) { + ftun, tun := newFakeTUN(b.Logf, true) + defer tun.Close() + + packet := udp(0x05060708, 0x01020304, 89, 89) + go func() { + for { + select { + case <-tun.closed: + return + case ftun.datachan <- packet: + // continue + } + } + }() + + var buf [128]byte + for i := 0; i < b.N; i++ { + _, err := tun.Read(buf[:], 0) + if err != nil { + b.Errorf("err = %v; want nil", err) + } + } +}