diff --git a/ipn/e2e_test.go b/ipn/e2e_test.go index 6accd8785..596c6cd02 100644 --- a/ipn/e2e_test.go +++ b/ipn/e2e_test.go @@ -27,6 +27,7 @@ import ( "tailscale.com/wgengine" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/router" + "tailscale.com/wgengine/tstun" "tailscale.io/control" // not yet released ) @@ -205,7 +206,8 @@ func newNode(t *testing.T, prefix string, https *httptest.Server, weirdPrefs boo } tun := tuntest.NewChannelTUN() - e1, err := wgengine.NewUserspaceEngineAdvanced(logfe, tun.TUN(), router.NewFake, 0) + tundev := tstun.WrapTUN(logfe, tun.TUN()) + e1, err := wgengine.NewUserspaceEngineAdvanced(logfe, tundev, router.NewFake, 0) if err != nil { t.Fatalf("NewFakeEngine: %v\n", err) } diff --git a/types/logger/logger.go b/types/logger/logger.go index f5a37c714..9e4bc72d9 100644 --- a/types/logger/logger.go +++ b/types/logger/logger.go @@ -8,9 +8,11 @@ package logger import ( + "bufio" "container/list" "fmt" "io" + "io/ioutil" "log" "sync" @@ -111,3 +113,17 @@ func RateLimitedFn(logf Logf, f float64, burst int, maxCache int) Logf { } } } + +// ArgWriter is a fmt.Formatter that can be passed to any Logf func to +// efficiently write to a %v argument without allocations. +type ArgWriter func(*bufio.Writer) + +func (fn ArgWriter) Format(f fmt.State, _ rune) { + bw := argBufioPool.Get().(*bufio.Writer) + bw.Reset(f) + fn(bw) + bw.Flush() + argBufioPool.Put(bw) +} + +var argBufioPool = &sync.Pool{New: func() interface{} { return bufio.NewWriterSize(ioutil.Discard, 1024) }} diff --git a/types/logger/logger_test.go b/types/logger/logger_test.go index 1f7ab4f07..bd4470875 100644 --- a/types/logger/logger_test.go +++ b/types/logger/logger_test.go @@ -5,6 +5,8 @@ package logger import ( + "bufio" + "bytes" "fmt" "log" "testing" @@ -62,3 +64,16 @@ func TestRateLimiter(t *testing.T) { } } + +func TestArgWriter(t *testing.T) { + got := new(bytes.Buffer) + fmt.Fprintf(got, "Greeting: %v", ArgWriter(func(bw *bufio.Writer) { + bw.WriteString("Hello, ") + bw.WriteString("world") + bw.WriteByte('!') + })) + const want = "Greeting: Hello, world!" + if got.String() != want { + t.Errorf("got %q; want %q", got, want) + } +} diff --git a/wgengine/faketun.go b/wgengine/faketun.go deleted file mode 100644 index c063ed7aa..000000000 --- a/wgengine/faketun.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package wgengine - -import ( - "io" - "os" - - "github.com/tailscale/wireguard-go/tun" -) - -type fakeTun struct { - datachan chan []byte - evchan chan tun.Event - closechan chan struct{} -} - -// NewFakeTun returns a fake TUN device that does not depend on the -// operating system or any special permissions. -// It primarily exists for testing. -func NewFakeTun() tun.Device { - return &fakeTun{ - datachan: make(chan []byte), - evchan: make(chan tun.Event), - closechan: make(chan struct{}), - } -} - -func (t *fakeTun) File() *os.File { - panic("fakeTun.File() called, which makes no sense") -} - -func (t *fakeTun) Close() error { - close(t.closechan) - close(t.datachan) - close(t.evchan) - return nil -} - -func (t *fakeTun) InsertRead(b []byte) { - t.datachan <- b -} - -func (t *fakeTun) Read(out []byte, offset int) (int, error) { - select { - case <-t.closechan: - return 0, io.EOF - case b := <-t.datachan: - copy(out[offset:offset+len(b)], b) - return len(b), nil - } -} - -func (t *fakeTun) Write(b []byte, n int) (int, error) { return len(b), nil } -func (t *fakeTun) Flush() error { return nil } -func (t *fakeTun) MTU() (int, error) { return 1500, nil } -func (t *fakeTun) Name() (string, error) { return "FakeTun", nil } -func (t *fakeTun) Events() chan tun.Event { return t.evchan } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 25dff1d24..ba9935343 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -28,6 +28,8 @@ import ( "tailscale.com/stun/stuntest" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/tstun" ) func TestListen(t *testing.T) { @@ -326,7 +328,9 @@ func TestTwoDevicePing(t *testing.T) { //t.Logf("cfg1: %v", uapi2) tun1 := tuntest.NewChannelTUN() - dev1 := device.NewDevice(tun1.TUN(), &device.DeviceOptions{ + tstun1 := tstun.WrapTUN(t.Logf, tun1.TUN()) + tstun1.SetFilter(filter.NewAllowAll()) + dev1 := device.NewDevice(tstun1, &device.DeviceOptions{ Logger: devLogger(t, "dev1"), CreateEndpoint: conn1.CreateEndpoint, CreateBind: conn1.CreateBind, @@ -339,7 +343,9 @@ func TestTwoDevicePing(t *testing.T) { defer dev1.Close() tun2 := tuntest.NewChannelTUN() - dev2 := device.NewDevice(tun2.TUN(), &device.DeviceOptions{ + tstun2 := tstun.WrapTUN(t.Logf, tun2.TUN()) + tstun2.SetFilter(filter.NewAllowAll()) + dev2 := device.NewDevice(tstun2, &device.DeviceOptions{ Logger: devLogger(t, "dev2"), CreateEndpoint: conn2.CreateEndpoint, CreateBind: conn2.CreateBind, @@ -385,7 +391,7 @@ func TestTwoDevicePing(t *testing.T) { t.Run("ping 1.0.0.2", func(t *testing.T) { ping2(t) }) t.Run("ping 1.0.0.2 via SendPacket", func(t *testing.T) { msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - if err := dev1.SendPacket(msg1to2); err != nil { + if err := tstun1.InjectOutbound(msg1to2); err != nil { t.Fatal(err) } select { diff --git a/wgengine/tstun/faketun.go b/wgengine/tstun/faketun.go new file mode 100644 index 000000000..4451762a8 --- /dev/null +++ b/wgengine/tstun/faketun.go @@ -0,0 +1,64 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tstun + +import ( + "io" + "os" + + "github.com/tailscale/wireguard-go/tun" +) + +type fakeTUN struct { + datachan chan []byte + evchan chan tun.Event + closechan chan struct{} +} + +// NewFakeTUN returns a fake TUN device that does not depend on the +// operating system or any special permissions. +// It primarily exists for testing. +func NewFakeTUN() tun.Device { + return &fakeTUN{ + datachan: make(chan []byte), + evchan: make(chan tun.Event), + closechan: make(chan struct{}), + } +} + +func (t *fakeTUN) File() *os.File { + panic("fakeTUN.File() called, which makes no sense") +} + +func (t *fakeTUN) Close() error { + close(t.closechan) + close(t.datachan) + close(t.evchan) + return nil +} + +func (t *fakeTUN) Read(out []byte, offset int) (int, error) { + select { + case <-t.closechan: + return 0, io.EOF + case b := <-t.datachan: + copy(out[offset:offset+len(b)], b) + return len(b), nil + } +} + +func (t *fakeTUN) Write(b []byte, n int) (int, error) { + select { + case <-t.closechan: + return 0, ErrClosed + case t.datachan <- b[n:]: + return len(b), nil + } +} + +func (t *fakeTUN) Flush() error { return nil } +func (t *fakeTUN) MTU() (int, error) { return 1500, nil } +func (t *fakeTUN) Name() (string, error) { return "FakeTUN", nil } +func (t *fakeTUN) Events() chan tun.Event { return t.evchan } diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go new file mode 100644 index 000000000..8cf42017e --- /dev/null +++ b/wgengine/tstun/tun.go @@ -0,0 +1,264 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tstun provides a TUN struct implementing the tun.Device interface +// with additional features as required by wgengine. +package tstun + +import ( + "errors" + "io" + "os" + "sync/atomic" + + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/packet" +) + +const ( + readMaxSize = device.MaxMessageSize + readOffset = device.MessageTransportHeaderSize +) + +// MaxPacketSize is the maximum size (in bytes) +// of a packet that can be injected into a tstun.TUN. +const MaxPacketSize = device.MaxContentSize + +var ( + ErrClosed = errors.New("device closed") + ErrFiltered = errors.New("packet dropped by filter") + ErrPacketTooBig = errors.New("packet too big") +) + +// TUN wraps a tun.Device from wireguard-go, +// augmenting it with filtering and packet injection. +// All the added work happens in Read and Write: +// the other methods delegate to the underlying tdev. +type TUN struct { + logf logger.Logf + // tdev is the underlying TUN device. + tdev tun.Device + + // buffer stores the oldest unconsumed packet from tdev. + // It is made a static buffer in order to avoid graticious allocation. + buffer [readMaxSize]byte + // bufferConsumed synchronizes access to buffer (shared by Read and poll). + bufferConsumed chan struct{} + + // closed signals poll (by closing) when the device is closed. + closed chan struct{} + // errors is the error queue populated by poll. + errors chan error + // outbound is the queue by which packets leave the TUN device. + // The directions are relative to the network, not the device: + // inbound packets arrive via UDP and are written into the TUN device; + // outbound packets are read from the TUN device and sent out via UDP. + // This queue is needed because although inbound writes are synchronous, + // the other direction must wait on a Wireguard goroutine to poll it. + outbound chan []byte + + // fitler stores the currently active package filter + filter atomic.Value // of *filter.Filter + // filterFlags control the verbosity of logging packet drops/accepts. + filterFlags filter.RunFlags +} + +func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { + tun := &TUN{ + logf: logf, + tdev: tdev, + // bufferConsumed is conceptually a condition variable: + // a goroutine should not block when setting it, even with no listeners. + bufferConsumed: make(chan struct{}, 1), + closed: make(chan struct{}), + errors: make(chan error), + outbound: make(chan []byte), + filterFlags: filter.LogAccepts | filter.LogDrops, + } + go tun.poll() + // The buffer starts out consumed. + tun.bufferConsumed <- struct{}{} + + return tun +} + +func (t *TUN) Close() error { + select { + case <-t.closed: + // continue + default: + // Other channels need not be closed: poll will exit gracefully after this. + close(t.closed) + } + + return t.tdev.Close() +} + +func (t *TUN) Events() chan tun.Event { + return t.tdev.Events() +} + +func (t *TUN) File() *os.File { + return t.tdev.File() +} + +func (t *TUN) Flush() error { + return t.tdev.Flush() +} + +func (t *TUN) MTU() (int, error) { + return t.tdev.MTU() +} + +func (t *TUN) Name() (string, error) { + return t.tdev.Name() +} + +// poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer. +// This is needed because t.tdev.Read in general may block (it does on Windows), +// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. +func (t *TUN) poll() { + for { + select { + case <-t.closed: + return + case <-t.bufferConsumed: + // continue + } + + // Read may use memory in t.buffer before readOffset for mandatory headers. + // This is the rationale behind the tun.TUN.{Read,Write} interfaces + // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. + n, err := t.tdev.Read(t.buffer[:], readOffset) + if err != nil { + select { + case <-t.closed: + return + case t.errors <- err: + // In principle, read errors are not fatal (but wireguard-go disagrees). + t.bufferConsumed <- struct{}{} + } + } else { + select { + case <-t.closed: + return + case t.outbound <- t.buffer[readOffset : readOffset+n]: + // continue + } + } + } +} + +func (t *TUN) filterOut(buf []byte) filter.Response { + filt, _ := t.filter.Load().(*filter.Filter) + + if filt == nil { + t.logf("Warning: you forgot to use SetFilter()! Packet dropped.") + return filter.Drop + } + + var q packet.QDecode + if filt.RunOut(buf, &q, t.filterFlags) == filter.Accept { + return filter.Accept + } + return filter.Drop +} + +func (t *TUN) Read(buf []byte, offset int) (int, error) { + var n int + + select { + case <-t.closed: + return 0, io.EOF + case err := <-t.errors: + return 0, err + case packet := <-t.outbound: + n = copy(buf[offset:], packet) + // t.buffer has a fixed location in memory, + // so this is the easiest way to tell when it has been consumed. + if &packet[0] == &t.buffer[readOffset] { + t.bufferConsumed <- struct{}{} + } + } + + response := t.filterOut(buf[offset : offset+n]) + if response != filter.Accept { + // Wireguard considers read errors fatal; pretend nothing was read + return 0, nil + } + + return n, nil +} + +func (t *TUN) filterIn(buf []byte) filter.Response { + filt, _ := t.filter.Load().(*filter.Filter) + + if filt == nil { + t.logf("Warning: you forgot to use SetFilter()! Packet dropped.") + return filter.Drop + } + + var q packet.QDecode + if filt.RunIn(buf, &q, t.filterFlags) == filter.Accept { + // Only in fake mode, answer any incoming pings. + if q.IsEchoRequest() { + ft, ok := t.tdev.(*fakeTUN) + if ok { + packet := q.EchoRespond() + ft.Write(packet, 0) + // We already handled it, stop. + return filter.Drop + } + } + return filter.Accept + } + return filter.Drop +} + +func (t *TUN) Write(buf []byte, offset int) (int, error) { + response := t.filterIn(buf[offset:]) + if response != filter.Accept { + return 0, ErrFiltered + } + + return t.tdev.Write(buf, offset) +} + +func (t *TUN) GetFilter() *filter.Filter { + filt, _ := t.filter.Load().(*filter.Filter) + return filt +} + +func (t *TUN) SetFilter(filt *filter.Filter) { + t.filter.Store(filt) +} + +// InjectInbound makes the TUN device behave as if a packet +// with the given contents was received from the network. +// It blocks and does not take ownership of the packet. +func (t *TUN) InjectInbound(packet []byte) error { + if len(packet) > MaxPacketSize { + return ErrPacketTooBig + } + _, err := t.Write(packet, 0) + return err +} + +// InjectOutbound makes the TUN device behave as if a packet +// with the given contents was sent to the network. +// It does not block, but takes ownership of the packet. +func (t *TUN) InjectOutbound(packet []byte) error { + if len(packet) > MaxPacketSize { + return ErrPacketTooBig + } + select { + case <-t.closed: + return ErrClosed + case t.outbound <- packet: + return nil + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index d803f5a8b..9f65dbf3b 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -33,6 +33,7 @@ import ( "tailscale.com/wgengine/monitor" "tailscale.com/wgengine/packet" "tailscale.com/wgengine/router" + "tailscale.com/wgengine/tstun" ) // minimalMTU is the MTU we set on tailscale's tuntap @@ -49,7 +50,7 @@ type userspaceEngine struct { logf logger.Logf reqCh chan struct{} waitCh chan struct{} - tundev tun.Device + tundev *tstun.TUN wgdev *device.Device router router.Router magicConn *magicsock.Conn @@ -60,7 +61,6 @@ type userspaceEngine struct { lastCfg wgcfg.Config mu sync.Mutex // guards following; see lock order comment below - filt *filter.Filter statusCallback StatusCallback peerSequence []wgcfg.Key endpoints []string @@ -81,8 +81,8 @@ func (l *Loggify) Write(b []byte) (int, error) { func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) { logf("Starting userspace wireguard engine (FAKE tuntap device).") - tun := NewFakeTun() - return NewUserspaceEngineAdvanced(logf, tun, router.NewFake, listenPort) + tundev := tstun.WrapTUN(logf, tstun.NewFakeTUN()) + return NewUserspaceEngineAdvanced(logf, tundev, router.NewFake, listenPort) } // NewUserspaceEngine creates the named tun device and returns a @@ -94,13 +94,14 @@ func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16) (En logf("Starting userspace wireguard engine with tun device %q", tunname) - tundev, err := tun.CreateTUN(tunname, minimalMTU) + tun, err := tun.CreateTUN(tunname, minimalMTU) if err != nil { diagnoseTUNFailure(logf) logf("CreateTUN: %v", err) return nil, err } logf("CreateTUN ok.") + tundev := tstun.WrapTUN(logf, tun) e, err := NewUserspaceEngineAdvanced(logf, tundev, router.New, listenPort) if err != nil { @@ -115,11 +116,11 @@ type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) ( // NewUserspaceEngineAdvanced is like NewUserspaceEngine but takes a pre-created TUN device and allows specifing // a custom router constructor and listening port. -func NewUserspaceEngineAdvanced(logf logger.Logf, tundev tun.Device, routerGen RouterGen, listenPort uint16) (Engine, error) { +func NewUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (Engine, error) { return newUserspaceEngineAdvanced(logf, tundev, routerGen, listenPort) } -func newUserspaceEngineAdvanced(logf logger.Logf, tundev tun.Device, routerGen RouterGen, listenPort uint16) (_ Engine, reterr error) { +func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (_ Engine, reterr error) { e := &userspaceEngine{ logf: logf, reqCh: make(chan struct{}, 1), @@ -161,16 +162,9 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev tun.Device, routerGen R Info: dlog, Error: dlog, } - nofilter := func(b []byte) device.FilterResult { - // for safety, default to dropping all packets - logf("Warning: you forgot to use wgengine.SetFilterInOut()! Packet dropped.") - return device.FilterDrop - } opts := &device.DeviceOptions{ - Logger: &logger, - FilterIn: nofilter, - FilterOut: nofilter, + Logger: &logger, HandshakeDone: func(peerKey wgcfg.Key, allowedIPs []net.IPNet) { // Send an unsolicited status event every time a // handshake completes. This makes sure our UI can @@ -320,7 +314,7 @@ func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) { } for _, dstIP := range dstIPs { b := packet.GenICMP(srcIP, dstIP, ipid, packet.EchoRequest, 0, payload) - e.wgdev.SendPacket(b) + e.tundev.InjectOutbound(b) } ipid++ } @@ -385,57 +379,11 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) } func (e *userspaceEngine) GetFilter() *filter.Filter { - e.mu.Lock() - defer e.mu.Unlock() - return e.filt + return e.tundev.GetFilter() } func (e *userspaceEngine) SetFilter(filt *filter.Filter) { - var filtin, filtout func(b []byte) device.FilterResult - if filt == nil { - e.logf("wgengine: nil filter provided; no access restrictions.") - } else { - ft, ft_ok := e.tundev.(*fakeTun) - filtin = func(b []byte) device.FilterResult { - runf := filter.LogDrops - //runf |= filter.HexdumpDrops - runf |= filter.LogAccepts - //runf |= filter.HexdumpAccepts - q := &packet.QDecode{} - if filt.RunIn(b, q, runf) == filter.Accept { - // Only in fake mode, answer any incoming pings - if ft_ok && q.IsEchoRequest() { - pb := q.EchoRespond() - ft.InsertRead(pb) - // We already handled it, stop. - return device.FilterDrop - } - return device.FilterAccept - } - return device.FilterDrop - } - - filtout = func(b []byte) device.FilterResult { - runf := filter.LogDrops - //runf |= filter.HexdumpDrops - runf |= filter.LogAccepts - //runf |= filter.HexdumpAccepts - q := &packet.QDecode{} - if filt.RunOut(b, q, runf) == filter.Accept { - return device.FilterAccept - } - return device.FilterDrop - } - } - - e.wgLock.Lock() - defer e.wgLock.Unlock() - - e.wgdev.SetFilterInOut(filtin, filtout) - - e.mu.Lock() - e.filt = filt - e.mu.Unlock() + e.tundev.SetFilter(filt) } func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) { diff --git a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go index a359b5ee7..0e45dd641 100644 --- a/wgengine/watchdog_test.go +++ b/wgengine/watchdog_test.go @@ -12,6 +12,7 @@ import ( "time" "tailscale.com/wgengine/router" + "tailscale.com/wgengine/tstun" ) func TestWatchdog(t *testing.T) { @@ -19,7 +20,7 @@ func TestWatchdog(t *testing.T) { t.Run("default watchdog does not fire", func(t *testing.T) { t.Parallel() - tun := NewFakeTun() + tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN()) e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0) if err != nil { t.Fatal(err) @@ -36,7 +37,7 @@ func TestWatchdog(t *testing.T) { t.Run("watchdog fires on blocked getStatus", func(t *testing.T) { t.Parallel() - tun := NewFakeTun() + tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN()) e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0) if err != nil { t.Fatal(err)