From 76389d8baf942b10a8f0f4201b7c4b0737a0172c Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 8 Dec 2022 17:58:14 -0800 Subject: [PATCH] net/tstun, wgengine/magicsock: enable vectorized I/O on Linux (#6663) This commit updates the wireguard-go dependency and implements the necessary changes to the tun.Device and conn.Bind implementations to support passing vectors of packets in tailscaled. This significantly improves throughput performance on Linux. Updates #414 Signed-off-by: Jordan Whited Signed-off-by: James Tucker Co-authored-by: James Tucker --- cmd/tailscaled/depaware.txt | 1 + go.mod | 10 +- go.sum | 20 +- net/tstun/fake.go | 15 +- net/tstun/tap_linux.go | 135 ++++++++-- net/tstun/tap_unsupported.go | 3 +- net/tstun/tun.go | 6 + net/tstun/wrap.go | 352 ++++++++++++++----------- net/tstun/wrap_test.go | 32 ++- tstest/natlab/natlab.go | 12 + wgengine/bench/wg.go | 61 +++-- wgengine/magicsock/magicsock.go | 368 +++++++++++++++++++++++---- wgengine/magicsock/magicsock_test.go | 31 ++- wgengine/wgcfg/device_test.go | 23 +- 14 files changed, 774 insertions(+), 295 deletions(-) diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 3a79598f3..af4b373f8 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -130,6 +130,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 golang.zx2c4.com/wireguard/conn from golang.zx2c4.com/wireguard/device+ W 💣 golang.zx2c4.com/wireguard/conn/winrio from golang.zx2c4.com/wireguard/conn 💣 golang.zx2c4.com/wireguard/device from tailscale.com/net/tstun+ + L golang.zx2c4.com/wireguard/endian from golang.zx2c4.com/wireguard/tun 💣 golang.zx2c4.com/wireguard/ipc from golang.zx2c4.com/wireguard/device W 💣 golang.zx2c4.com/wireguard/ipc/namedpipe from golang.zx2c4.com/wireguard/ipc golang.zx2c4.com/wireguard/ratelimiter from golang.zx2c4.com/wireguard/device diff --git a/go.mod b/go.mod index e3db9d053..53aef54b8 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module tailscale.com go 1.19 +replace golang.zx2c4.com/wireguard => github.com/tailscale/wireguard-go v0.0.0-20221207223341-6be4ed075788 + require ( filippo.io/mkcert v1.4.3 github.com/Microsoft/go-winio v0.6.0 @@ -64,12 +66,12 @@ require ( github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 go4.org/mem v0.0.0-20210711025021-927187094b94 go4.org/netipx v0.0.0-20220725152314-7e7bdc8411bf - golang.org/x/crypto v0.1.0 + golang.org/x/crypto v0.3.0 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db - golang.org/x/net v0.1.0 + golang.org/x/net v0.2.0 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 - golang.org/x/sys v0.1.0 - golang.org/x/term v0.1.0 + golang.org/x/sys v0.2.0 + golang.org/x/term v0.2.0 golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 golang.org/x/tools v0.2.0 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 diff --git a/go.sum b/go.sum index ad13ab7c5..07809f541 100644 --- a/go.sum +++ b/go.sum @@ -1108,6 +1108,8 @@ github.com/tailscale/mkctr v0.0.0-20220601142259-c0b937af2e89 h1:7xU7AFQE83h0wz/ github.com/tailscale/mkctr v0.0.0-20220601142259-c0b937af2e89/go.mod h1:OGMqrTzDqmJkGumUTtOv44Rp3/4xS+QFbE8Rn0AGlaU= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQSPhaUPjUQwozcRdDdSxxqhNgNZ3drZFk= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= +github.com/tailscale/wireguard-go v0.0.0-20221207223341-6be4ed075788 h1:HRBKNhAqG+3NGtudGB8QzpaKlvf4MoBCMEnjdF+D+nA= +github.com/tailscale/wireguard-go v0.0.0-20221207223341-6be4ed075788/go.mod h1:wzWjYPfptTrgXwkAZmjd7sXHf7RYnz5PrPr6GN1eb2Y= github.com/tc-hib/winres v0.1.6 h1:qgsYHze+BxQPEYilxIz/KCQGaClvI2+yLBAZs+3+0B8= github.com/tc-hib/winres v0.1.6/go.mod h1:pe6dOR40VOrGz8PkzreVKNvEKnlE8t4yR8A8naL+t7A= github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0= @@ -1267,8 +1269,8 @@ golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= +golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1375,8 +1377,8 @@ golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1512,13 +1514,13 @@ golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.2.0 h1:z85xZCsEl7bi/KwbNADeBYoOP0++7W1ipu+aGnpwzRM= +golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1656,8 +1658,6 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20220920152132-bb719d3a6e2c h1:Okh6a1xpnJslG9Mn84pId1Mn+Q8cvpo4HCeeFWHo0cA= -golang.zx2c4.com/wireguard v0.0.0-20220920152132-bb719d3a6e2c/go.mod h1:enML0deDxY1ux+B6ANGiwtg0yAJi1rctkTpcHNAVPyg= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/net/tstun/fake.go b/net/tstun/fake.go index 09d68b6ba..75ff7a384 100644 --- a/net/tstun/fake.go +++ b/net/tstun/fake.go @@ -34,21 +34,22 @@ func (t *fakeTUN) Close() error { return nil } -func (t *fakeTUN) Read(out []byte, offset int) (int, error) { +func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { <-t.closechan return 0, io.EOF } -func (t *fakeTUN) Write(b []byte, n int) (int, error) { +func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { select { case <-t.closechan: return 0, ErrClosed default: } - return len(b), nil + return 1, nil } -func (t *fakeTUN) Flush() error { return nil } -func (t *fakeTUN) MTU() (int, error) { return 1500, nil } -func (t *fakeTUN) Name() (string, error) { return "FakeTUN", nil } -func (t *fakeTUN) Events() chan tun.Event { return t.evchan } +func (t *fakeTUN) Flush() error { return nil } +func (t *fakeTUN) MTU() (int, error) { return 1500, nil } +func (t *fakeTUN) Name() (string, error) { return "FakeTUN", nil } +func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } +func (t *fakeTUN) BatchSize() int { return 1 } diff --git a/net/tstun/tap_linux.go b/net/tstun/tap_linux.go index 15995b1a6..caf7b6573 100644 --- a/net/tstun/tap_linux.go +++ b/net/tstun/tap_linux.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" "os/exec" + "sync" "github.com/insomniacslk/dhcp/dhcpv4" "golang.org/x/sys/unix" @@ -23,6 +24,7 @@ import ( "tailscale.com/net/netaddr" "tailscale.com/net/packet" "tailscale.com/types/ipproto" + "tailscale.com/util/multierr" ) // TODO: this was randomly generated once. Maybe do it per process start? But @@ -69,13 +71,7 @@ func openDevice(fd int, tapName, bridgeName string) (tun.Device, error) { } } - // Also sets non-blocking I/O on fd when creating tun.Device. - dev, _, err := tun.CreateUnmonitoredTUNFromFD(fd) // TODO: MTU - if err != nil { - return nil, err - } - - return dev, nil + return newTAPDevice(fd, tapName) } type etherType [2]byte @@ -168,7 +164,8 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { copy(res.HardwareAddressTarget(), req.HardwareAddressSender()) copy(res.ProtocolAddressTarget(), req.ProtocolAddressSender()) - n, err := t.tdev.Write(buf, 0) + // TODO(raggi): reduce allocs! + n, err := t.tdev.Write([][]byte{buf}, 0) if tapDebug { t.logf("tap: wrote ARP reply %v, %v", n, err) } @@ -252,7 +249,9 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(100, 100, 100, 100), 67), // src netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - n, err := t.tdev.Write(pkt, 0) + + // TODO(raggi): reduce allocs! + n, err := t.tdev.Write([][]byte{pkt}, 0) if tapDebug { t.logf("tap: wrote DHCP OFFER %v, %v", n, err) } @@ -279,7 +278,8 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(100, 100, 100, 100), 67), // src netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - n, err := t.tdev.Write(pkt, 0) + // TODO(raggi): reduce allocs! + n, err := t.tdev.Write([][]byte{pkt}, 0) if tapDebug { t.logf("tap: wrote DHCP ACK %v, %v", n, err) } @@ -346,21 +346,108 @@ func (t *Wrapper) destMAC() [6]byte { return t.destMACAtomic.Load() } -func (t *Wrapper) tapWrite(buf []byte, offset int) (int, error) { - if offset < ethernetFrameSize { - return 0, fmt.Errorf("[unexpected] weird offset %d for TAP write", offset) +func newTAPDevice(fd int, tapName string) (tun.Device, error) { + err := unix.SetNonblock(fd, true) + if err != nil { + return nil, err } - eth := buf[offset-ethernetFrameSize:] - dst := t.destMAC() - copy(eth[:6], dst[:]) - copy(eth[6:12], ourMAC[:]) - et := etherTypeIPv4 - if buf[offset]>>4 == 6 { - et = etherTypeIPv6 + file := os.NewFile(uintptr(fd), "/dev/tap") + d := &tapDevice{ + file: file, + events: make(chan tun.Event), + name: tapName, } - eth[12], eth[13] = et[0], et[1] - if tapDebug { - t.logf("tap: tapWrite off=%v % x", offset, buf) + return d, nil +} + +var ( + _ setWrapperer = &tapDevice{} +) + +type tapDevice struct { + file *os.File + events chan tun.Event + name string + wrapper *Wrapper + closeOnce sync.Once +} + +func (t *tapDevice) setWrapper(wrapper *Wrapper) { + t.wrapper = wrapper +} + +func (t *tapDevice) File() *os.File { + return t.file +} + +func (t *tapDevice) Name() (string, error) { + return t.name, nil +} + +func (t *tapDevice) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + n, err := t.file.Read(buffs[0][offset:]) + if err != nil { + return 0, err } - return t.tdev.Write(buf, offset-ethernetFrameSize) + sizes[0] = n + return 1, nil +} + +func (t *tapDevice) Write(buffs [][]byte, offset int) (int, error) { + errs := make([]error, 0) + wrote := 0 + for _, buff := range buffs { + if offset < ethernetFrameSize { + errs = append(errs, fmt.Errorf("[unexpected] weird offset %d for TAP write", offset)) + return 0, multierr.New(errs...) + } + eth := buff[offset-ethernetFrameSize:] + dst := t.wrapper.destMAC() + copy(eth[:6], dst[:]) + copy(eth[6:12], ourMAC[:]) + et := etherTypeIPv4 + if buff[offset]>>4 == 6 { + et = etherTypeIPv6 + } + eth[12], eth[13] = et[0], et[1] + if tapDebug { + t.wrapper.logf("tap: tapWrite off=%v % x", offset, buff) + } + _, err := t.file.Write(buff[offset-ethernetFrameSize:]) + if err != nil { + errs = append(errs, err) + } else { + wrote++ + } + } + return wrote, multierr.New(errs...) +} + +func (t *tapDevice) MTU() (int, error) { + ifr, err := unix.NewIfreq(t.name) + if err != nil { + return 0, err + } + err = unix.IoctlIfreq(int(t.file.Fd()), unix.SIOCGIFMTU, ifr) + if err != nil { + return 0, err + } + return int(ifr.Uint32()), nil +} + +func (t *tapDevice) Events() <-chan tun.Event { + return t.events +} + +func (t *tapDevice) Close() error { + var err error + t.closeOnce.Do(func() { + close(t.events) + err = t.file.Close() + }) + return err +} + +func (t *tapDevice) BatchSize() int { + return 1 } diff --git a/net/tstun/tap_unsupported.go b/net/tstun/tap_unsupported.go index 2659e051e..186813349 100644 --- a/net/tstun/tap_unsupported.go +++ b/net/tstun/tap_unsupported.go @@ -6,5 +6,4 @@ package tstun -func (*Wrapper) handleTAPFrame([]byte) bool { panic("unreachable") } -func (*Wrapper) tapWrite([]byte, int) (int, error) { panic("unreachable") } +func (*Wrapper) handleTAPFrame([]byte) bool { panic("unreachable") } diff --git a/net/tstun/tun.go b/net/tstun/tun.go index 543f3ef67..ce766acb6 100644 --- a/net/tstun/tun.go +++ b/net/tstun/tun.go @@ -25,6 +25,7 @@ var createTAP func(tapName, bridgeName string) (tun.Device, error) // New returns a tun.Device for the requested device name, along with // the OS-dependent name that was allocated to the device. func New(logf logger.Logf, tunName string) (tun.Device, string, error) { + var disableTUNOffload = envknob.Bool("TS_DISABLE_TUN_OFFLOAD") var dev tun.Device var err error if strings.HasPrefix(tunName, "tap:") { @@ -51,6 +52,11 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { tunMTU = mtu } dev, err = tun.CreateTUN(tunName, tunMTU) + if err == nil && disableTUNOffload { + if do, ok := dev.(tun.DisableOffloader); ok { + do.DisableOffload() + } + } } if err != nil { return nil, "", err diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 64ee83065..f57b5525e 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -88,25 +88,31 @@ type Wrapper struct { destMACAtomic syncs.AtomicValue[[6]byte] discoKey syncs.AtomicValue[key.DiscoPublic] - // buffer stores the oldest unconsumed packet from tdev. - // It is made a static buffer in order to avoid allocations. - buffer [maxBufferSize]byte - // bufferConsumedMu protects bufferConsumed from concurrent sends and closes. - // It does not prevent send-after-close, only data races. + // vectorBuffer stores the oldest unconsumed packet vector from tdev. It is + // allocated in wrap() and the underlying arrays should never grow. + vectorBuffer [][]byte + // bufferConsumedMu protects bufferConsumed from concurrent sends, closes, + // and send-after-close (by way of bufferConsumedClosed). bufferConsumedMu sync.Mutex - // bufferConsumed synchronizes access to buffer (shared by Read and poll). + // bufferConsumedClosed is true when bufferConsumed has been closed. This is + // read by bufferConsumed writers to prevent send-after-close. + bufferConsumedClosed bool + // bufferConsumed synchronizes access to vectorBuffer (shared by Read() and + // pollVector()). // - // Close closes bufferConsumed. There may be outstanding sends to bufferConsumed - // when that happens; we catch any resulting panics. - // This lets us avoid expensive multi-case selects. + // Close closes bufferConsumed and sets bufferConsumedClosed to true. bufferConsumed chan struct{} // closed signals poll (by closing) when the device is closed. closed chan struct{} - // outboundMu protects outbound from concurrent sends and closes. - // It does not prevent send-after-close, only data races. + // outboundMu protects outbound and vectorOutbound from concurrent sends, + // closes, and send-after-close (by way of outboundClosed). outboundMu sync.Mutex - // outbound is the queue by which packets leave the TUN device. + // outboundClosed is true when outbound or vectorOutbound have been closed. + // This is read by outbound and vectorOutbound writers to prevent + // send-after-close. + outboundClosed bool + // vectorOutbound is the queue by which packets leave the TUN device. // // The directions are relative to the network, not the device: // inbound packets arrive via UDP and are written into the TUN device; @@ -115,12 +121,10 @@ type Wrapper struct { // the other direction must wait on a WireGuard goroutine to poll it. // // Empty reads are skipped by WireGuard, so it is always legal - // to discard an empty packet instead of sending it through t.outbound. + // to discard an empty packet instead of sending it through vectorOutbound. // - // Close closes outbound. There may be outstanding sends to outbound - // when that happens; we catch any resulting panics. - // This lets us avoid expensive multi-case selects. - outbound chan tunReadResult + // Close closes vectorOutbound and sets outboundClosed to true. + vectorOutbound chan tunVectorReadResult // eventsUpDown yields up and down tun.Events that arrive on a Wrapper's events channel. eventsUpDown chan tun.Event @@ -172,19 +176,30 @@ type Wrapper struct { stats atomic.Pointer[connstats.Statistics] } -// tunReadResult is the result of a TUN read, or an injected result pretending to be a TUN read. -// The data is not interpreted in the usual way for a Read method. -// See the comment in the middle of Wrap.Read. -type tunReadResult struct { - // Only one of err, packet or data should be set, and are read in that order - // of precedence. - err error +// tunInjectedRead is an injected packet pretending to be a tun.Read(). +type tunInjectedRead struct { + // Only one of packet or data should be set, and are read in that order of + // precedence. packet *stack.PacketBuffer data []byte +} + +// tunVectorReadResult is the result of a tun.Read(), or an injected packet +// pretending to be a tun.Read(). +type tunVectorReadResult struct { + // Only one of err, data, or injected should be set, and are read in that + // order of precedence. + err error + data [][]byte + injected tunInjectedRead - // injected is set if the read result was generated internally, and contained packets should not - // pass through filters. - injected bool + dataOffset int +} + +type setWrapperer interface { + // setWrapper enables the underlying TUN/TAP to have access to the Wrapper. + // It MUST be called only once during initialization, other usage is unsafe. + setWrapper(*Wrapper) } func WrapTAP(logf logger.Logf, tdev tun.Device) *Wrapper { @@ -197,7 +212,7 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper { logf = logger.WithPrefix(logf, "tstun: ") - tun := &Wrapper{ + w := &Wrapper{ logf: logf, limitedLogf: logger.RateLimitedFn(logf, 1*time.Minute, 2, 10), isTAP: isTAP, @@ -206,21 +221,30 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper { // a goroutine should not block when setting it, even with no listeners. bufferConsumed: make(chan struct{}, 1), closed: make(chan struct{}), - // outbound can be unbuffered; the buffer is an optimization. - outbound: make(chan tunReadResult, 1), - eventsUpDown: make(chan tun.Event), - eventsOther: make(chan tun.Event), + // vectorOutbound can be unbuffered; the buffer is an optimization. + vectorOutbound: make(chan tunVectorReadResult, 1), + eventsUpDown: make(chan tun.Event), + eventsOther: make(chan tun.Event), // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. filterFlags: filter.LogAccepts | filter.LogDrops, } - go tun.poll() - go tun.pumpEvents() + w.vectorBuffer = make([][]byte, tdev.BatchSize()) + for i := range w.vectorBuffer { + w.vectorBuffer[i] = make([]byte, maxBufferSize) + } + go w.pollVector() + + go w.pumpEvents() // The buffer starts out consumed. - tun.bufferConsumed <- struct{}{} - tun.noteActivity() + w.bufferConsumed <- struct{}{} + w.noteActivity() + + if sw, ok := w.tdev.(setWrapperer); ok { + sw.setWrapper(w) + } - return tun + return w } // SetDestIPActivityFuncs sets a map of funcs to run per packet @@ -261,10 +285,12 @@ func (t *Wrapper) Close() error { t.closeOnce.Do(func() { close(t.closed) t.bufferConsumedMu.Lock() + t.bufferConsumedClosed = true close(t.bufferConsumed) t.bufferConsumedMu.Unlock() t.outboundMu.Lock() - close(t.outbound) + t.outboundClosed = true + close(t.vectorOutbound) t.outboundMu.Unlock() err = t.tdev.Close() }) @@ -323,7 +349,7 @@ func (t *Wrapper) EventsUpDown() chan tun.Event { // Events returns a TUN event channel that contains all non-Up, non-Down events. // It is named Events because it is the set of events that we want to expose to wireguard-go, // and Events is the name specified by the wireguard-go tun.Device interface. -func (t *Wrapper) Events() chan tun.Event { +func (t *Wrapper) Events() <-chan tun.Event { return t.eventsOther } @@ -331,10 +357,6 @@ func (t *Wrapper) File() *os.File { return t.tdev.File() } -func (t *Wrapper) Flush() error { - return t.tdev.Flush() -} - func (t *Wrapper) MTU() (int, error) { return t.tdev.MTU() } @@ -343,94 +365,95 @@ func (t *Wrapper) Name() (string, error) { return t.tdev.Name() } -// allowSendOnClosedChannel suppresses panics due to sending on a closed channel. -// This allows us to avoid synchronization between poll and Close. -// Such synchronization (particularly multi-case selects) is too expensive -// for code like poll or Read that is on the hot path of every packet. -// If this makes you sad or angry, you may want to join our -// weekly Go Performance Delinquents Anonymous meetings on Monday nights. -func allowSendOnClosedChannel() { - r := recover() - if r == nil { - return - } - e, _ := r.(error) - if e != nil && e.Error() == "send on closed channel" { - return - } - panic(r) -} - const ethernetFrameSize = 14 // 2 six byte MACs, 2 bytes ethertype -// poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer. -// This is needed because t.tdev.Read in general may block (it does on Windows), -// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. -func (t *Wrapper) poll() { +// pollVector polls t.tdev.Read(), placing the oldest unconsumed packet vector +// into t.vectorBuffer. This is needed because t.tdev.Read() in general may +// block (it does on Windows), so packets may be stuck in t.vectorOutbound if +// t.Read() called t.tdev.Read() directly. +func (t *Wrapper) pollVector() { + sizes := make([]int, len(t.vectorBuffer)) + readOffset := PacketStartOffset + if t.isTAP { + readOffset = PacketStartOffset - ethernetFrameSize + } + for range t.bufferConsumed { DoRead: + for i := range t.vectorBuffer { + t.vectorBuffer[i] = t.vectorBuffer[i][:cap(t.vectorBuffer[i])] + } var n int var err error - // Read may use memory in t.buffer before PacketStartOffset for mandatory headers. - // This is the rationale behind the tun.Wrapper.{Read,Write} interfaces - // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. - // In principle, read errors are not fatal (but wireguard-go disagrees). - // We loop here until we get a non-empty (or failed) read. - // We don't need this loop for correctness, - // but wireguard-go will skip an empty read, - // so we might as well avoid the send through t.outbound. for n == 0 && err == nil { if t.isClosed() { return } - if t.isTAP { - n, err = t.tdev.Read(t.buffer[:], PacketStartOffset-ethernetFrameSize) - if tapDebug { - s := fmt.Sprintf("% x", t.buffer[:]) - for strings.HasSuffix(s, " 00") { - s = strings.TrimSuffix(s, " 00") - } - t.logf("TAP read %v, %v: %s", n, err, s) + n, err = t.tdev.Read(t.vectorBuffer[:], sizes, readOffset) + if t.isTAP && tapDebug { + s := fmt.Sprintf("% x", t.vectorBuffer[0][:]) + for strings.HasSuffix(s, " 00") { + s = strings.TrimSuffix(s, " 00") } - } else { - n, err = t.tdev.Read(t.buffer[:], PacketStartOffset) + t.logf("TAP read %v, %v: %s", n, err, s) } } + for i := range sizes[:n] { + t.vectorBuffer[i] = t.vectorBuffer[i][:readOffset+sizes[i]] + } if t.isTAP { if err == nil { - ethernetFrame := t.buffer[PacketStartOffset-ethernetFrameSize:][:n] + ethernetFrame := t.vectorBuffer[0][readOffset:] if t.handleTAPFrame(ethernetFrame) { goto DoRead } } // Fall through. We got an IP packet. - if n >= ethernetFrameSize { - n -= ethernetFrameSize + if sizes[0] >= ethernetFrameSize { + t.vectorBuffer[0] = t.vectorBuffer[0][:readOffset+sizes[0]-ethernetFrameSize] } if tapDebug { - t.logf("tap regular frame: %x", t.buffer[PacketStartOffset:PacketStartOffset+n]) + t.logf("tap regular frame: %x", t.vectorBuffer[0][PacketStartOffset:PacketStartOffset+sizes[0]]) } } - t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err}) + t.sendVectorOutbound(tunVectorReadResult{ + data: t.vectorBuffer[:n], + dataOffset: PacketStartOffset, + err: err, + }) } } // sendBufferConsumed does t.bufferConsumed <- struct{}{}. -// It protects against any panics or data races that that send could cause. func (t *Wrapper) sendBufferConsumed() { - defer allowSendOnClosedChannel() t.bufferConsumedMu.Lock() defer t.bufferConsumedMu.Unlock() + if t.bufferConsumedClosed { + return + } t.bufferConsumed <- struct{}{} } -// sendOutbound does t.outboundMu <- r. -// It protects against any panics or data races that that send could cause. -func (t *Wrapper) sendOutbound(r tunReadResult) { - defer allowSendOnClosedChannel() +// injectOutbound does t.vectorOutbound <- r +func (t *Wrapper) injectOutbound(r tunInjectedRead) { t.outboundMu.Lock() defer t.outboundMu.Unlock() - t.outbound <- r + if t.outboundClosed { + return + } + t.vectorOutbound <- tunVectorReadResult{ + injected: r, + } +} + +// sendVectorOutbound does t.vectorOutbound <- r. +func (t *Wrapper) sendVectorOutbound(r tunVectorReadResult) { + t.outboundMu.Lock() + defer t.outboundMu.Unlock() + if t.outboundClosed { + return + } + t.vectorOutbound <- r } var ( @@ -514,34 +537,79 @@ func (t *Wrapper) IdleDuration() time.Duration { return mono.Since(t.lastActivityAtomic.LoadAtomic()) } -func (t *Wrapper) Read(buf []byte, offset int) (int, error) { - res, ok := <-t.outbound +func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + res, ok := <-t.vectorOutbound + if !ok { - // Wrapper is closed. return 0, io.EOF } if res.err != nil { return 0, res.err } + if res.data == nil { + n, err := t.injectedRead(res.injected, buffs[0], offset) + sizes[0] = n + if err != nil && n == 0 { + return 0, err + } + + return 1, err + } + + metricPacketOut.Add(int64(len(res.data))) + + var buffsPos int + for _, data := range res.data { + p := parsedPacketPool.Get().(*packet.Parsed) + defer parsedPacketPool.Put(p) + p.Decode(data[res.dataOffset:]) + if m := t.destIPActivity.Load(); m != nil { + if fn := m[p.Dst.Addr()]; fn != nil { + fn() + } + } + if !t.disableFilter { + response := t.filterOut(p) + if response != filter.Accept { + metricPacketOutDrop.Add(1) + continue + } + } + n := copy(buffs[buffsPos][offset:], data[res.dataOffset:]) + if n != len(data)-res.dataOffset { + panic(fmt.Sprintf("short copy: %d != %d", n, len(data)-res.dataOffset)) + } + sizes[buffsPos] = n + if stats := t.stats.Load(); stats != nil { + stats.UpdateTxVirtual(data[res.dataOffset:]) + } + buffsPos++ + } + + // t.vectorBuffer has a fixed location in memory. + // TODO(raggi): add an explicit field and possibly method to the tunVectorReadResult + // to signal when sendBufferConsumed should be called. + if &res.data[0] == &t.vectorBuffer[0] { + // We are done with t.buffer. Let poll() re-use it. + t.sendBufferConsumed() + } + + t.noteActivity() + return buffsPos, nil +} +// injectedRead handles injected reads, which bypass filters. +func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int, error) { metricPacketOut.Add(1) var n int if res.packet != nil { - n = copy(buf[offset:], res.packet.NetworkHeader().Slice()) n += copy(buf[offset+n:], res.packet.TransportHeader().Slice()) n += copy(buf[offset+n:], res.packet.Data().AsRange().ToSlice()) - res.packet.DecRef() } else { n = copy(buf[offset:], res.data) - - // t.buffer has a fixed location in memory. - if &res.data[0] == &t.buffer[PacketStartOffset] { - // We are done with t.buffer. Let poll re-use it. - t.sendBufferConsumed() - } } p := parsedPacketPool.Get().(*packet.Parsed) @@ -554,16 +622,6 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { } } - // Do not filter injected packets. - if !res.injected && !t.disableFilter { - response := t.filterOut(p) - if response != filter.Accept { - metricPacketOutDrop.Add(1) - // WireGuard considers read errors fatal; pretend nothing was read - return 0, nil - } - } - if stats := t.stats.Load(); stats != nil { stats.UpdateTxVirtual(buf[offset:][:n]) } @@ -668,42 +726,40 @@ func (t *Wrapper) filterIn(buf []byte) filter.Response { return filter.Accept } -// Write accepts an incoming packet. The packet begins at buf[offset:], +// Write accepts incoming packets. The packets begins at buffs[:][offset:], // like wireguard-go/tun.Device.Write. -func (t *Wrapper) Write(buf []byte, offset int) (int, error) { - metricPacketIn.Add(1) +func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { + metricPacketIn.Add(int64(len(buffs))) + i := 0 if !t.disableFilter { - if t.filterIn(buf[offset:]) != filter.Accept { - metricPacketInDrop.Add(1) - // If we're not accepting the packet, lie to wireguard-go and pretend - // that everything is okay with a nil error, so wireguard-go - // doesn't log about this Write "failure". - // - // We return len(buf), but the ill-defined wireguard-go/tun.Device.Write - // method doesn't specify how the offset affects the return value. - // In fact, the Linux implementation does one of two different things depending - // on how the /dev/net/tun was created. But fortunately the wireguard-go - // code ignores the int return and only looks at the error: - // - // device/receive.go: _, err = device.tun.device.Write(....) - // - // TODO(bradfitz): fix upstream interface docs, implementation. - return len(buf), nil + for _, buff := range buffs { + if t.filterIn(buff[offset:]) != filter.Accept { + metricPacketInDrop.Add(1) + } else { + buffs[i] = buff + i++ + } } + } else { + i = len(buffs) } + buffs = buffs[:i] - t.noteActivity() - return t.tdevWrite(buf, offset) + if len(buffs) > 0 { + t.noteActivity() + _, err := t.tdevWrite(buffs, offset) + return len(buffs), err + } + return 0, nil } -func (t *Wrapper) tdevWrite(buf []byte, offset int) (int, error) { +func (t *Wrapper) tdevWrite(buffs [][]byte, offset int) (int, error) { if stats := t.stats.Load(); stats != nil { - stats.UpdateRxVirtual(buf[offset:]) - } - if t.isTAP { - return t.tapWrite(buf, offset) + for i := range buffs { + stats.UpdateRxVirtual((buffs)[i][offset:]) + } } - return t.tdev.Write(buf, offset) + return t.tdev.Write(buffs, offset) } func (t *Wrapper) GetFilter() *filter.Filter { @@ -755,7 +811,7 @@ func (t *Wrapper) InjectInboundDirect(buf []byte, offset int) error { } // Write to the underlying device to skip filters. - _, err := t.tdevWrite(buf, offset) + _, err := t.tdevWrite([][]byte{buf}, offset) // TODO(jwhited): alloc? return err } @@ -813,7 +869,7 @@ func (t *Wrapper) InjectOutbound(packet []byte) error { if len(packet) == 0 { return nil } - t.sendOutbound(tunReadResult{data: packet, injected: true}) + t.injectOutbound(tunInjectedRead{data: packet}) return nil } @@ -830,10 +886,14 @@ func (t *Wrapper) InjectOutboundPacketBuffer(packet *stack.PacketBuffer) error { packet.DecRef() return nil } - t.sendOutbound(tunReadResult{packet: packet, injected: true}) + t.injectOutbound(tunInjectedRead{packet: packet}) return nil } +func (t *Wrapper) BatchSize() int { + return t.tdev.BatchSize() +} + // Unwrap returns the underlying tun.Device. func (t *Wrapper) Unwrap() tun.Device { return t.tdev diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 1a6b33cff..a3cb80adf 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -208,16 +208,24 @@ func TestReadAndInject(t *testing.T) { var buf [MaxPacketSize]byte var seen = make(map[string]bool) + sizes := make([]int, 1) // We expect the same packets back, in no particular order. for i := 0; i < len(written)+len(injected); i++ { - n, err := tun.Read(buf[:], 0) + packet := buf[:] + buffs := [][]byte{packet} + numPackets, err := tun.Read(buffs, sizes, 0) if err != nil { t.Errorf("read %d: error: %v", i, err) } - if n != size { - t.Errorf("read %d: got size %d; want %d", i, n, size) + if numPackets != 1 { + t.Fatalf("read %d packets, expected %d", numPackets, 1) } - got := string(buf[:n]) + packet = packet[:sizes[0]] + packetLen := len(packet) + if packetLen != size { + t.Errorf("read %d: got size %d; want %d", i, packetLen, size) + } + got := string(packet) t.Logf("read %d: got %s", i, got) seen[got] = true } @@ -245,13 +253,10 @@ func TestWriteAndInject(t *testing.T) { go func() { for _, packet := range written { payload := []byte(packet) - n, err := tun.Write(payload, 0) + _, err := tun.Write([][]byte{payload}, 0) if err != nil { t.Errorf("%s: error: %v", packet, err) } - if n != size { - t.Errorf("%s: got size %d; want %d", packet, n, size) - } } }() @@ -339,6 +344,7 @@ func TestFilter(t *testing.T) { var n int var err error var filtered bool + sizes := make([]int, 1) tunStats, _ := stats.Extract() if len(tunStats) > 0 { @@ -352,11 +358,11 @@ func TestFilter(t *testing.T) { // If it stays zero, nothing made it through // to the wrapped TUN. tun.lastActivityAtomic.StoreAtomic(0) - _, err = tun.Write(tt.data, 0) + _, err = tun.Write([][]byte{tt.data}, 0) filtered = tun.lastActivityAtomic.LoadAtomic() == 0 } else { chtun.Outbound <- tt.data - n, err = tun.Read(buf[:], 0) + n, err = tun.Read([][]byte{buf[:]}, sizes, 0) // In the read direction, errors are fatal, so we return n = 0 instead. filtered = (n == 0) } @@ -400,7 +406,7 @@ func TestAllocs(t *testing.T) { ftun, tun := newFakeTUN(t.Logf, false) defer tun.Close() - buf := []byte{0x00} + buf := [][]byte{[]byte{0x00}} err := tstest.MinAllocsPerRun(t, 0, func() { _, err := ftun.Write(buf, 0) if err != nil { @@ -417,7 +423,7 @@ func TestAllocs(t *testing.T) { func TestClose(t *testing.T) { ftun, tun := newFakeTUN(t.Logf, false) - data := udp4("1.2.3.4", "5.6.7.8", 98, 98) + data := [][]byte{udp4("1.2.3.4", "5.6.7.8", 98, 98)} _, err := ftun.Write(data, 0) if err != nil { t.Error(err) @@ -435,7 +441,7 @@ func BenchmarkWrite(b *testing.B) { ftun, tun := newFakeTUN(b.Logf, true) defer tun.Close() - packet := udp4("5.6.7.8", "1.2.3.4", 89, 89) + packet := [][]byte{udp4("5.6.7.8", "1.2.3.4", 89, 89)} for i := 0; i < b.N; i++ { _, err := ftun.Write(packet, 0) if err != nil { diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index 91a6462cf..5095cfcd6 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -812,6 +812,18 @@ func (c *conn) LocalAddr() net.Addr { } } +func (c *conn) Read(buf []byte) (int, error) { + panic("unimplemented stub") +} + +func (c *conn) RemoteAddr() net.Addr { + panic("unimplemented stub") +} + +func (c *conn) Write(buf []byte) (int, error) { + panic("unimplemented stub") +} + func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/wgengine/bench/wg.go b/wgengine/bench/wg.go index d5e8df1d1..f6a215896 100644 --- a/wgengine/bench/wg.go +++ b/wgengine/bench/wg.go @@ -168,25 +168,31 @@ type sourceTun struct { traf *TrafficGen } -func (t *sourceTun) Close() error { return nil } -func (t *sourceTun) Events() chan tun.Event { return nil } -func (t *sourceTun) File() *os.File { return nil } -func (t *sourceTun) Flush() error { return nil } -func (t *sourceTun) MTU() (int, error) { return 1500, nil } -func (t *sourceTun) Name() (string, error) { return "source", nil } - -func (t *sourceTun) Write(b []byte, ofs int) (int, error) { +func (t *sourceTun) Close() error { return nil } +func (t *sourceTun) Events() <-chan tun.Event { return nil } +func (t *sourceTun) File() *os.File { return nil } +func (t *sourceTun) Flush() error { return nil } +func (t *sourceTun) MTU() (int, error) { return 1500, nil } +func (t *sourceTun) Name() (string, error) { return "source", nil } + +// TODO(raggi): could be optimized for linux style batch sizes +func (t *sourceTun) BatchSize() int { return 1 } + +func (t *sourceTun) Write(b [][]byte, ofs int) (int, error) { // Discard all writes - return len(b) - ofs, nil + return len(b), nil } -func (t *sourceTun) Read(b []byte, ofs int) (int, error) { - // Continually generate "input" packets - n := t.traf.Generate(b, ofs) - if n == 0 { - return 0, io.EOF +func (t *sourceTun) Read(b [][]byte, sizes []int, ofs int) (int, error) { + for i, b := range b { + // Continually generate "input" packets + n := t.traf.Generate(b, ofs) + sizes[i] = n + if n == 0 { + return 0, io.EOF + } } - return n, nil + return len(b), nil } type sinkTun struct { @@ -194,20 +200,25 @@ type sinkTun struct { traf *TrafficGen } -func (t *sinkTun) Close() error { return nil } -func (t *sinkTun) Events() chan tun.Event { return nil } -func (t *sinkTun) File() *os.File { return nil } -func (t *sinkTun) Flush() error { return nil } -func (t *sinkTun) MTU() (int, error) { return 1500, nil } -func (t *sinkTun) Name() (string, error) { return "sink", nil } +func (t *sinkTun) Close() error { return nil } +func (t *sinkTun) Events() <-chan tun.Event { return nil } +func (t *sinkTun) File() *os.File { return nil } +func (t *sinkTun) Flush() error { return nil } +func (t *sinkTun) MTU() (int, error) { return 1500, nil } +func (t *sinkTun) Name() (string, error) { return "sink", nil } -func (t *sinkTun) Read(b []byte, ofs int) (int, error) { +func (t *sinkTun) Read(b [][]byte, sizes []int, ofs int) (int, error) { // Never returns select {} } -func (t *sinkTun) Write(b []byte, ofs int) (int, error) { +func (t *sinkTun) Write(b [][]byte, ofs int) (int, error) { // Count packets, but discard them - t.traf.GotPacket(b, ofs) - return len(b) - ofs, nil + for _, b := range b { + t.traf.GotPacket(b, ofs) + } + return len(b), nil } + +// TODO(raggi): could be optimized for linux style batch sizes +func (t *sinkTun) BatchSize() int { return 1 } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 34ad7c8ae..e5b782c9d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -29,6 +29,8 @@ import ( "time" "go4.org/mem" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/conn" "tailscale.com/control/controlclient" "tailscale.com/derp" @@ -269,6 +271,9 @@ type Conn struct { pconn4 RebindingUDPConn pconn6 RebindingUDPConn + receiveBatchPool sync.Pool + sendBatchPool sync.Pool + // closeDisco4 and closeDisco6 are io.Closers to shut down the raw // disco packet receivers. If nil, no raw disco receiver is // running for the given family. @@ -575,6 +580,30 @@ func newConn() *Conn { discoInfo: make(map[key.DiscoPublic]*discoInfo), } c.bind = &connBind{Conn: c, closed: true} + c.receiveBatchPool = sync.Pool{New: func() any { + msgs := make([]ipv6.Message, c.bind.BatchSize()) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + } + batch := &receiveBatch{ + msgs: msgs, + } + return batch + }} + c.sendBatchPool = sync.Pool{New: func() any { + ua := &net.UDPAddr{ + IP: make([]byte, 16), + } + msgs := make([]ipv6.Message, c.bind.BatchSize()) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + } + return &sendBatch{ + ua: ua, + msgs: msgs, + } + }} c.muCond = sync.NewCond(&c.mu) c.networkUp.Store(true) // assume up until told otherwise return c @@ -1214,13 +1243,14 @@ var errNetworkDown = errors.New("magicsock: network down") func (c *Conn) networkDown() bool { return !c.networkUp.Load() } -func (c *Conn) Send(b []byte, ep conn.Endpoint) error { - metricSendData.Add(1) +func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) error { + n := int64(len(buffs)) + metricSendData.Add(n) if c.networkDown() { - metricSendDataNetworkDown.Add(1) + metricSendDataNetworkDown.Add(n) return errNetworkDown } - return ep.(*endpoint).send(b) + return ep.(*endpoint).send(buffs) } var errConnClosed = errors.New("Conn closed") @@ -1229,6 +1259,46 @@ var errDropDerpPacket = errors.New("too many DERP packets queued; dropping") var errNoUDP = errors.New("no UDP available on platform") +var ( + // This acts as a compile-time check for our usage of ipv6.Message in + // udpConnWithBatchOps for both IPv6 and IPv4 operations. + _ ipv6.Message = ipv4.Message{} +) + +type sendBatch struct { + ua *net.UDPAddr + msgs []ipv6.Message // ipv4.Message and ipv6.Message are the same underlying type +} + +func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) { + batch := c.sendBatchPool.Get().(*sendBatch) + defer c.sendBatchPool.Put(batch) + + isIPv6 := false + switch { + case addr.Addr().Is4(): + case addr.Addr().Is6(): + isIPv6 = true + default: + panic("bogus sendUDPBatch addr type") + } + + as16 := addr.Addr().As16() + copy(batch.ua.IP, as16[:]) + batch.ua.Port = int(addr.Port()) + for i, buff := range buffs { + batch.msgs[i].Buffers[0] = buff + batch.msgs[i].Addr = batch.ua + } + + if isIPv6 { + _, err = c.pconn6.WriteBatch(batch.msgs[:len(buffs)], 0) + } else { + _, err = c.pconn4.WriteBatch(batch.msgs[:len(buffs)], 0) + } + return err == nil, err +} + // sendUDP sends UDP packet b to ipp. // See sendAddr's docs on the return value meanings. func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { @@ -1671,34 +1741,93 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan } } -// receiveIPv6 receives a UDP IPv6 packet. It is called by wireguard-go. -func (c *Conn) receiveIPv6(b []byte) (int, conn.Endpoint, error) { +type receiveBatch struct { + msgs []ipv6.Message +} + +func (c *Conn) getReceiveBatch() *receiveBatch { + batch := c.receiveBatchPool.Get().(*receiveBatch) + return batch +} + +func (c *Conn) putReceiveBatch(batch *receiveBatch) { + for i := range batch.msgs { + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers} + } + c.receiveBatchPool.Put(batch) +} + +func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { health.ReceiveIPv6.Enter() defer health.ReceiveIPv6.Exit() + + batch := c.getReceiveBatch() + defer c.putReceiveBatch(batch) for { - n, ipp, err := c.pconn6.ReadFromNetaddr(b) + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + } + numMsgs, err := c.pconn6.ReadBatch(batch.msgs, 0) if err != nil { - return 0, nil, err + if neterror.PacketWasTruncated(err) { + // TODO(raggi): discuss whether to log? + continue + } + return 0, err } - if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint6, c.closeDisco6 == nil); ok { - metricRecvDataIPv6.Add(1) - return n, ep, nil + + reportToCaller := false + for i, msg := range batch.msgs[:numMsgs] { + ipp := msg.Addr.(*net.UDPAddr).AddrPort() + if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6, c.closeDisco6 == nil); ok { + metricRecvDataIPv6.Add(1) + eps[i] = ep + sizes[i] = msg.N + reportToCaller = true + } else { + sizes[i] = 0 + } + } + + if reportToCaller { + return numMsgs, nil } } } -// receiveIPv4 receives a UDP IPv4 packet. It is called by wireguard-go. -func (c *Conn) receiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { +func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { health.ReceiveIPv4.Enter() defer health.ReceiveIPv4.Exit() + + batch := c.getReceiveBatch() + defer c.putReceiveBatch(batch) for { - n, ipp, err := c.pconn4.ReadFromNetaddr(b) + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + } + numMsgs, err := c.pconn4.ReadBatch(batch.msgs, 0) if err != nil { - return 0, nil, err + if neterror.PacketWasTruncated(err) { + // TODO(raggi): discuss whether to log? + continue + } + return 0, err } - if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4, c.closeDisco4 == nil); ok { - metricRecvDataIPv4.Add(1) - return n, ep, nil + + reportToCaller := false + for i, msg := range batch.msgs[:numMsgs] { + ipp := msg.Addr.(*net.UDPAddr).AddrPort() + if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4, c.closeDisco4 == nil); ok { + metricRecvDataIPv4.Add(1) + eps[i] = ep + sizes[i] = msg.N + reportToCaller = true + } else { + sizes[i] = 0 + } + } + if reportToCaller { + return numMsgs, nil } } } @@ -1748,27 +1877,25 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache, return ep, true } -// receiveDERP reads a packet from c.derpRecvCh into b and returns the associated endpoint. -// It is called by wireguard-go. -// -// If the packet was a disco message or the peer endpoint wasn't -// found, the returned error is errLoopAgain. -func (c *connBind) receiveDERP(b []byte) (n int, ep conn.Endpoint, err error) { +func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { health.ReceiveDERP.Enter() defer health.ReceiveDERP.Exit() + for dm := range c.derpRecvCh { if c.Closed() { break } - n, ep := c.processDERPReadResult(dm, b) + n, ep := c.processDERPReadResult(dm, buffs[0]) if n == 0 { // No data read occurred. Wait for another packet. continue } metricRecvDataDERP.Add(1) - return n, ep, nil + sizes[0] = n + eps[0] = ep + return 1, nil } - return 0, nil, net.ErrClosed + return 0, net.ErrClosed } func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *endpoint) { @@ -2645,6 +2772,16 @@ type connBind struct { closed bool } +func (c *connBind) BatchSize() int { + // TODO(raggi): determine by properties rather than hardcoding platform behavior + switch runtime.GOOS { + case "linux": + return conn.DefaultBatchSize + default: + return 1 + } +} + // Open is called by WireGuard to create a UDP binding. // The ignoredPort comes from wireguard-go, via the wgcfg config. // We ignore that port value here, since we have the local port available easily. @@ -2856,13 +2993,13 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur defer ruc.mu.Unlock() if runtime.GOOS == "js" { - ruc.setConnLocked(newBlockForeverConn()) + ruc.setConnLocked(newBlockForeverConn(), "") return nil } if debugAlwaysDERP() { c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network) - ruc.setConnLocked(newBlockForeverConn()) + ruc.setConnLocked(newBlockForeverConn(), "") return nil } @@ -2897,7 +3034,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur } trySetSocketBuffer(pconn, c.logf) // Success. - ruc.setConnLocked(pconn) + ruc.setConnLocked(pconn, network) if network == "udp4" { health.SetUDP4Unbound(false) } @@ -2908,7 +3045,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur // Set pconn to a dummy conn whose reads block until closed. // This keeps the receive funcs alive for a future in which // we get a link change and we can try binding again. - ruc.setConnLocked(newBlockForeverConn()) + ruc.setConnLocked(newBlockForeverConn(), "") if network == "udp4" { health.SetUDP4Unbound(true) } @@ -3005,6 +3142,51 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) { return ep, nil } +type batchReaderWriter interface { + batchReader + batchWriter +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +// udpConnWithBatchOps wraps a *net.UDPConn in order to extend it to support +// batch operations. +// +// TODO(jwhited): This wrapping is temporary. https://github.com/golang/go/issues/45886 +type udpConnWithBatchOps struct { + *net.UDPConn + xpc batchReaderWriter +} + +func newUDPConnWithBatchOps(conn *net.UDPConn, network string) udpConnWithBatchOps { + ucbo := udpConnWithBatchOps{ + UDPConn: conn, + } + switch network { + case "udp4": + ucbo.xpc = ipv4.NewPacketConn(conn) + case "udp6": + ucbo.xpc = ipv6.NewPacketConn(conn) + default: + panic("bogus network") + } + return ucbo +} + +func (u udpConnWithBatchOps) WriteBatch(ms []ipv6.Message, flags int) (int, error) { + return u.xpc.WriteBatch(ms, flags) +} + +func (u udpConnWithBatchOps) ReadBatch(ms []ipv6.Message, flags int) (int, error) { + return u.xpc.ReadBatch(ms, flags) +} + // RebindingUDPConn is a UDP socket that can be re-bound. // Unix has no notion of re-binding a socket, so we swap it out for a new one. type RebindingUDPConn struct { @@ -3022,9 +3204,28 @@ type RebindingUDPConn struct { port uint16 } -func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn) { - c.pconn = p - c.pconnAtomic.Store(&p) +// upgradePacketConn may upgrade a nettype.PacketConn to a udpConnWithBatchOps. +func upgradePacketConn(p nettype.PacketConn, network string) nettype.PacketConn { + uc, ok := p.(*net.UDPConn) + if ok && runtime.GOOS == "linux" && (network == "udp4" || network == "udp6") { + // Non-Linux does not support batch operations. x/net will fall back to + // recv/sendmsg, but not all platforms have recv/sendmsg support. Keep + // this simple for now. + return newUDPConnWithBatchOps(uc, network) + } + return p +} + +// setConnLocked sets the provided nettype.PacketConn. It should be called only +// after acquiring RebindingUDPConn.mu. It upgrades the provided +// nettype.PacketConn to a udpConnWithBatchOps when appropriate. This upgrade +// is intentionally pushed closest to where read/write ops occur in order to +// avoid disrupting surrounding code that assumes nettype.PacketConn is a +// *net.UDPConn. +func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string) { + upc := upgradePacketConn(p, network) + c.pconn = upc + c.pconnAtomic.Store(&upc) c.port = uint16(c.localAddrLocked().Port) } @@ -3087,6 +3288,60 @@ func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, } } +func (c *RebindingUDPConn) WriteBatch(msgs []ipv6.Message, flags int) (int, error) { + var ( + n int + err error + start int + ) + for { + pconn := *c.pconnAtomic.Load() + bw, ok := pconn.(batchWriter) + if !ok { + for _, msg := range msgs { + _, err = pconn.WriteTo(msg.Buffers[0], msg.Addr) + if err != nil { + return n, err + } + n++ + } + return n, nil + } + + n, err = bw.WriteBatch(msgs[start:], flags) + if err != nil { + if pconn != c.currentConn() { + continue + } + return n, err + } else if n == len(msgs[start:]) { + return len(msgs), nil + } else { + start += n + } + } +} + +func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) { + for { + pconn := *c.pconnAtomic.Load() + br, ok := pconn.(batchReader) + if !ok { + var err error + msgs[0].N, msgs[0].Addr, err = c.ReadFrom(msgs[0].Buffers[0]) + if err == nil { + return 1, nil + } + return 0, err + } + n, err := br.ReadBatch(msgs, flags) + if err != nil && pconn != c.currentConn() { + continue + } + return n, err + } +} + func (c *RebindingUDPConn) Port() uint16 { c.mu.Lock() defer c.mu.Unlock() @@ -3175,6 +3430,20 @@ func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (in return len(p), nil } +func (c *blockForeverConn) ReadBatch(p []ipv6.Message, flags int) (int, error) { + c.mu.Lock() + for !c.closed { + c.cond.Wait() + } + c.mu.Unlock() + return 0, net.ErrClosed +} + +func (c *blockForeverConn) WriteBatch(p []ipv6.Message, flags int) (int, error) { + // Silently drop writes. + return len(p), nil +} + func (c *blockForeverConn) LocalAddr() net.Addr { // Return a *net.UDPAddr because lots of code assumes that it will. return new(net.UDPAddr) @@ -3302,7 +3571,7 @@ func ippDebugString(ua netip.AddrPort) string { return ua.String() } -// endpointSendFunc is a func that writes an encrypted Wireguard payload from +// endpointSendFunc is a func that writes encrypted Wireguard payloads from // WireGuard to a peer. It might write via UDP, DERP, both, or neither. // // What these funcs should NOT do is too much work. Minimize use of mutexes, map @@ -3313,7 +3582,7 @@ func ippDebugString(ua netip.AddrPort) string { // // A nil value means the current fast path has expired and needs to be // recalculated. -type endpointSendFunc func([]byte) error +type endpointSendFunc func([][]byte) error // discoEndpoint is a wireguard/conn.Endpoint that picks the best // available path to communicate with a peer, based on network @@ -3629,9 +3898,9 @@ func (de *endpoint) cliPing(res *ipnstate.PingResult, cb func(*ipnstate.PingResu de.noteActiveLocked() } -func (de *endpoint) send(b []byte) error { +func (de *endpoint) send(buffs [][]byte) error { if fn := de.sendFunc.Load(); fn != nil { - return fn(b) + return fn(buffs) } de.mu.Lock() @@ -3656,21 +3925,30 @@ func (de *endpoint) send(b []byte) error { } var err error if udpAddr.IsValid() { - _, err = de.c.sendAddr(udpAddr, de.publicKey, b) + _, err = de.c.sendUDPBatch(udpAddr, buffs) + // TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends. if stats := de.c.stats.Load(); err == nil && stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, udpAddr, len(b)) + var txBytes int + for _, b := range buffs { + txBytes += len(b) + } + stats.UpdateTxPhysical(de.nodeAddr, udpAddr, txBytes) } } if derpAddr.IsValid() { - if ok, _ := de.c.sendAddr(derpAddr, de.publicKey, b); ok { + allOk := true + for _, buff := range buffs { + ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) if stats := de.c.stats.Load(); stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(b)) + stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buff)) } - if err != nil { - // UDP failed but DERP worked, so good enough: - return nil + if !ok { + allOk = false } } + if allOk { + return nil + } } return err } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 0f1196a54..7e8f623c8 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -29,6 +29,7 @@ import ( "go4.org/mem" "golang.org/x/exp/maps" + wgconn "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/tuntest" "tailscale.com/derp" @@ -364,9 +365,12 @@ func TestNewConn(t *testing.T) { conn.SetPrivateKey(key.NewNode()) go func() { - var pkt [64 << 10]byte + pkts := make([][]byte, 1) + sizes := make([]int, 1) + eps := make([]wgconn.Endpoint, 1) + pkts[0] = make([]byte, 64<<10) for { - _, _, err := conn.receiveIPv4(pkt[:]) + _, err := conn.receiveIPv4(pkts, sizes, eps) if err != nil { return } @@ -1262,17 +1266,20 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) { for i := range sendBuf { sendBuf[i] = 'x' } - buf := make([]byte, 2<<10) + buffs := make([][]byte, 1) + buffs[0] = make([]byte, 2<<10) + sizes := make([]int, 1) + eps := make([]wgconn.Endpoint, 1) return func() { if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil { tb.Fatalf("WriteTo: %v", err) } - n, ep, err := conn.receiveIPv4(buf) + n, err := conn.receiveIPv4(buffs, sizes, eps) if err != nil { tb.Fatal(err) } _ = n - _ = ep + _ = eps } } @@ -1330,6 +1337,9 @@ func TestGoMajorVersion(t *testing.T) { } func TestReceiveFromAllocs(t *testing.T) { + // TODO(jwhited): we are back to nonzero alloc due to our use of x/net until + // https://github.com/golang/go/issues/45886 is implemented. + t.Skip("alloc tests are skipped until https://github.com/golang/go/issues/45886 is implemented and plumbed.") if racebuild.On { t.Skip("alloc tests are unreliable with -race") } @@ -1481,9 +1491,12 @@ func TestRebindStress(t *testing.T) { errc := make(chan error, 1) go func() { - buf := make([]byte, 1500) + buffs := make([][]byte, 1) + sizes := make([]int, 1) + eps := make([]wgconn.Endpoint, 1) + buffs[0] = make([]byte, 1500) for { - _, _, err := conn.receiveIPv4(buf) + _, err := conn.receiveIPv4(buffs, sizes, eps) if ctx.Err() != nil { errc <- nil return @@ -1813,6 +1826,6 @@ func TestRebindingUDPConn(t *testing.T) { t.Fatal(err) } defer realConn.Close() - c.setConnLocked(realConn.(nettype.PacketConn)) - c.setConnLocked(newBlockForeverConn()) + c.setConnLocked(realConn.(nettype.PacketConn), "udp4") + c.setConnLocked(newBlockForeverConn(), "") } diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index 8868ad785..c68c0d01d 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -213,18 +213,18 @@ func newNilTun() tun.Device { } } -func (t *nilTun) File() *os.File { return nil } -func (t *nilTun) Flush() error { return nil } -func (t *nilTun) MTU() (int, error) { return 1420, nil } -func (t *nilTun) Name() (string, error) { return "niltun", nil } -func (t *nilTun) Events() chan tun.Event { return t.events } +func (t *nilTun) File() *os.File { return nil } +func (t *nilTun) Flush() error { return nil } +func (t *nilTun) MTU() (int, error) { return 1420, nil } +func (t *nilTun) Name() (string, error) { return "niltun", nil } +func (t *nilTun) Events() <-chan tun.Event { return t.events } -func (t *nilTun) Read(data []byte, offset int) (int, error) { +func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { <-t.closed return 0, io.EOF } -func (t *nilTun) Write(data []byte, offset int) (int, error) { +func (t *nilTun) Write(data [][]byte, offset int) (int, error) { <-t.closed return 0, io.EOF } @@ -235,18 +235,21 @@ func (t *nilTun) Close() error { return nil } +func (t *nilTun) BatchSize() int { return 1 } + // A noopBind is a conn.Bind that does no actual binding work. type noopBind struct{} func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 1, nil } -func (noopBind) Close() error { return nil } -func (noopBind) SetMark(mark uint32) error { return nil } -func (noopBind) Send(b []byte, ep conn.Endpoint) error { return nil } +func (noopBind) Close() error { return nil } +func (noopBind) SetMark(mark uint32) error { return nil } +func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { return dummyEndpoint(s), nil } +func (noopBind) BatchSize() int { return 1 } // A dummyEndpoint is a string holding the endpoint destination. type dummyEndpoint string