From 517c90d7e583ff446db3891a521c859060e891ae Mon Sep 17 00:00:00 2001 From: Naman Sood Date: Thu, 25 Feb 2021 14:18:16 -0500 Subject: [PATCH] wgengine, cmd/tailscaled: refactor netstack, forward TCP to hello as demo (#1301) Updates #707 Updates #504 Signed-off-by: Naman Sood --- cmd/tailscaled/depaware.txt | 9 +- cmd/tailscaled/tailscaled.go | 4 +- wgengine/netstack/netstack.go | 338 ++++++++++++++++++---------- wgengine/netstack/netstack_32bit.go | 4 +- wgengine/userspace.go | 45 ++-- 5 files changed, 261 insertions(+), 139 deletions(-) diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index bf0ec8359..1ba1e1c4a 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -49,10 +49,11 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/link/channel+ gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/link/channel from tailscale.com/wgengine/netstack - gvisor.dev/gvisor/pkg/tcpip/network/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 - gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 - gvisor.dev/gvisor/pkg/tcpip/network/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 + gvisor.dev/gvisor/pkg/tcpip/network/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ @@ -225,7 +226,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de hash from compress/zlib+ hash/adler32 from compress/zlib hash/crc32 from compress/gzip+ - hash/fnv from tailscale.com/wgengine/magicsock + hash/fnv from tailscale.com/wgengine/magicsock+ hash/maphash from go4.org/mem html from net/http/pprof+ io from bufio+ diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 0110c8a75..4ff26fbb8 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -192,9 +192,9 @@ func run() error { var e wgengine.Engine if args.fake { - var impl wgengine.FakeImplFunc + var impl wgengine.FakeImplFactory if args.tunname == "userspace-networking" { - impl = netstack.Impl + impl = netstack.Create } e, err = wgengine.NewFakeUserspaceEngine(logf, 0, impl) } else { diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index b2b21fcba..5906f00bf 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -12,8 +12,8 @@ import ( "context" "errors" "fmt" + "io" "log" - "net" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -37,161 +38,270 @@ import ( "tailscale.com/wgengine/tstun" ) -func Impl(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) error { +// Impl contains the state for the netstack implementation, +// and implements wgengine.FakeImpl to act as a userspace network +// stack when Tailscale is running in fake mode. +type Impl struct { + ipstack *stack.Stack + linkEP *channel.Endpoint + tundev *tstun.TUN + e wgengine.Engine + mc *magicsock.Conn + logf logger.Logf +} + +const nicID = 1 + +// Create creates and populates a new Impl. +func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (wgengine.FakeImpl, error) { if mc == nil { - return errors.New("nil magicsock.Conn") + return nil, errors.New("nil magicsock.Conn") } if tundev == nil { - return errors.New("nil tundev") + return nil, errors.New("nil tundev") } if logf == nil { - return errors.New("nil logger") + return nil, errors.New("nil logger") } if e == nil { - return errors.New("nil Engine") + return nil, errors.New("nil Engine") } ipstack := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, }) - const mtu = 1500 linkEP := channel.New(512, mtu, "") + if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { + return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) + } + // Add IPv4 and IPv6 default routes, so all incoming packets from the Tailscale side + // are handled by the one fake NIC we use. + ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) + ipv6Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 16)), tcpip.AddressMask(strings.Repeat("\x00", 16))) + ipstack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID, + }, + { + Destination: ipv6Subnet, + NIC: nicID, + }, + }) + ns := &Impl{ + logf: logf, + ipstack: ipstack, + linkEP: linkEP, + tundev: tundev, + e: e, + mc: mc, + } + return ns, nil +} + +// Start sets up all the handlers so netstack can start working. Implements +// wgengine.FakeImpl. +func (ns *Impl) Start() error { + ns.e.AddNetworkMapCallback(ns.updateIPs) + // size = 0 means use default buffer size + const tcpReceiveBufferSize = 0 + const maxInFlightConnectionAttempts = 16 + tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP) + udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP) + ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) + go ns.injectOutbound() + ns.tundev.PostFilterIn = ns.injectInbound - const nicID = 1 - if err := ipstack.CreateNIC(nicID, linkEP); err != nil { - log.Fatal(err) + return nil +} + +func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { + oldIPs := make(map[tcpip.Address]bool) + for _, ip := range ns.ipstack.AllAddresses()[nicID] { + oldIPs[ip.AddressWithPrefix.Address] = true + } + newIPs := make(map[tcpip.Address]bool) + for _, ip := range nm.Addresses { + newIPs[tcpip.Address(ip.IP.IPAddr().IP)] = true } - e.AddNetworkMapCallback(func(nm *netmap.NetworkMap) { - oldIPs := make(map[tcpip.Address]bool) - for _, ip := range ipstack.AllAddresses()[nicID] { - oldIPs[ip.AddressWithPrefix.Address] = true + ipsToBeAdded := make(map[tcpip.Address]bool) + for ip := range newIPs { + if !oldIPs[ip] { + ipsToBeAdded[ip] = true } - newIPs := make(map[tcpip.Address]bool) - for _, ip := range nm.Addresses { - newIPs[tcpip.Address(ip.IPNet().IP)] = true + } + ipsToBeRemoved := make(map[tcpip.Address]bool) + for ip := range oldIPs { + if !newIPs[ip] { + ipsToBeRemoved[ip] = true } + } - ipsToBeAdded := make(map[tcpip.Address]bool) - for ip := range newIPs { - if !oldIPs[ip] { - ipsToBeAdded[ip] = true - } + for ip := range ipsToBeRemoved { + err := ns.ipstack.RemoveAddress(nicID, ip) + if err != nil { + ns.logf("netstack: could not deregister IP %s: %v", ip, err) + } else { + ns.logf("[v2] netstack: deregistered IP %s", ip) } - ipsToBeRemoved := make(map[tcpip.Address]bool) - for ip := range oldIPs { - if !newIPs[ip] { - ipsToBeRemoved[ip] = true - } + } + for ip := range ipsToBeAdded { + var err *tcpip.Error + if ip.To4() == "" { + err = ns.ipstack.AddAddress(nicID, ipv6.ProtocolNumber, ip) + } else { + err = ns.ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip) } - - for ip := range ipsToBeRemoved { - err := ipstack.RemoveAddress(nicID, ip) - if err != nil { - logf("netstack: could not deregister IP %s: %v", ip, err) - } else { - logf("netstack: deregistered IP %s", ip) - } + if err != nil { + ns.logf("netstack: could not register IP %s: %v", ip, err) + } else { + ns.logf("[v2] netstack: registered IP %s", ip) } - for ip := range ipsToBeAdded { - err := ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip) - if err != nil { - logf("netstack: could not register IP %s: %v", ip, err) - } else { - logf("netstack: registered IP %s", ip) - } + } +} + +func (ns *Impl) dialContextTCP(ctx context.Context, address string) (*gonet.TCPConn, error) { + remoteIPPort, err := netaddr.ParseIPPort(address) + if err != nil { + return nil, fmt.Errorf("could not parse IP:port: %w", err) + } + remoteAddress := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(remoteIPPort.IP.IPAddr().IP), + Port: remoteIPPort.Port, + } + var ipType tcpip.NetworkProtocolNumber + if remoteIPPort.IP.Is4() { + ipType = ipv4.ProtocolNumber + } else { + ipType = ipv6.ProtocolNumber + } + + return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType) +} + +func (ns *Impl) injectOutbound() { + for { + packetInfo, ok := ns.linkEP.ReadContext(context.Background()) + if !ok { + ns.logf("[v2] ReadContext-for-write = ok=false") + continue } - }) + pkt := packetInfo.Pkt + hdrNetwork := pkt.NetworkHeader() + hdrTransport := pkt.TransportHeader() - // Add 0.0.0.0/0 default route. - subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) - ipstack.SetRouteTable([]tcpip.Route{ - { - Destination: subnet, - NIC: nicID, - }, - }) + full := make([]byte, 0, pkt.Size()) + full = append(full, hdrNetwork.View()...) + full = append(full, hdrTransport.View()...) + full = append(full, pkt.Data.ToView()...) - // use Forwarder to accept any connection from stack - fwd := tcp.NewForwarder(ipstack, 0, 16, func(r *tcp.ForwarderRequest) { - logf("XXX ForwarderRequest: %v", r) - var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - r.Complete(true) + ns.logf("[v2] packet Write out: % x", full) + if err := ns.tundev.InjectOutbound(full); err != nil { + log.Printf("netstack inject outbound: %v", err) return } - r.Complete(false) - c := gonet.NewTCPConn(&wq, ep) - // TCP echo - go echo(c, e, mc) + } +} + +func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.TUN) filter.Response { + var pn tcpip.NetworkProtocolNumber + switch p.IPVersion { + case 4: + pn = header.IPv4ProtocolNumber + case 6: + pn = header.IPv6ProtocolNumber + } + ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer()) + vv := buffer.View(append([]byte(nil), p.Buffer()...)).ToVectorisedView() + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, }) - ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) + ns.linkEP.InjectInbound(pn, packetBuf) + return filter.Accept +} - go func() { - for { - packetInfo, ok := linkEP.ReadContext(context.Background()) - if !ok { - logf("XXX ReadContext-for-write = ok=false") - continue - } - pkt := packetInfo.Pkt - hdrNetwork := pkt.NetworkHeader() - hdrTransport := pkt.TransportHeader() - - full := make([]byte, 0, pkt.Size()) - full = append(full, hdrNetwork.View()...) - full = append(full, hdrTransport.View()...) - full = append(full, pkt.Data.ToView()...) - - logf("XXX packet Write out: % x", full) - if err := tundev.InjectOutbound(full); err != nil { - log.Printf("netstack inject outbound: %v", err) - return - } +func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { + ns.logf("[v2] ForwarderRequest: %v", r) + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + r.Complete(true) + return + } + localAddr, err := ep.GetLocalAddress() + ns.logf("[v2] forwarding port %v to 100.101.102.103:80", localAddr.Port) + if err != nil { + r.Complete(true) + return + } + r.Complete(false) + c := gonet.NewTCPConn(&wq, ep) + go ns.forwardTCP(c, &wq, "100.101.102.103:80") +} +func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, address string) { + defer client.Close() + ns.logf("[v2] netstack: forwarding to address %s", address) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventHUp) + defer wq.EventUnregister(&waitEntry) + done := make(chan bool) + // netstack doesn't close the notification channel automatically if there was no + // hup signal, so we close done after we're done to not leak the goroutine below. + defer close(done) + go func() { + select { + case <-notifyCh: + case <-done: } + cancel() }() - - tundev.PostFilterIn = func(p *packet.Parsed, t *tstun.TUN) filter.Response { - var pn tcpip.NetworkProtocolNumber - switch p.IPVersion { - case 4: - pn = header.IPv4ProtocolNumber - case 6: - pn = header.IPv6ProtocolNumber - } - logf("XXX packet in (from %v): % x", p.Src, p.Buffer()) - vv := buffer.View(append([]byte(nil), p.Buffer()...)).ToVectorisedView() - packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - linkEP.InjectInbound(pn, packetBuf) - return filter.Accept + server, err := ns.dialContextTCP(ctx, address) + if err != nil { + ns.logf("netstack: could not connect to server %s: %s", address, err) + return } - return nil + defer server.Close() + connClosed := make(chan bool, 2) + go func() { + io.Copy(server, client) + connClosed <- true + }() + go func() { + io.Copy(client, server) + connClosed <- true + }() + <-connClosed + ns.logf("[v2] netstack: forwarder connection to %s closed", address) } -func echo(c *gonet.TCPConn, e wgengine.Engine, mc *magicsock.Conn) { - defer c.Close() - src, _ := netaddr.FromStdIP(c.RemoteAddr().(*net.TCPAddr).IP) - who := "" - if n, u, ok := mc.WhoIs(src); ok { - who = fmt.Sprintf("%v from %v", u.DisplayName, n.Name) +func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { + ns.logf("[v2] UDP ForwarderRequest: %v", r) + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + ns.logf("Could not create endpoint, exiting") + return } - fmt.Fprintf(c, "Hello, %s! Thanks for connecting to me on port %v (Try other ports too!)\nEchoing...\n", - who, - c.LocalAddr().(*net.TCPAddr).Port) + c := gonet.NewUDPConn(ns.ipstack, &wq, ep) + go echoUDP(c) +} + +func echoUDP(c *gonet.UDPConn) { buf := make([]byte, 1500) for { n, err := c.Read(buf) if err != nil { - log.Printf("Err: %v", err) break } c.Write(buf[:n]) } - log.Print("Connection closed") + c.Close() } diff --git a/wgengine/netstack/netstack_32bit.go b/wgengine/netstack/netstack_32bit.go index 22b34e51e..dfb433d3d 100644 --- a/wgengine/netstack/netstack_32bit.go +++ b/wgengine/netstack/netstack_32bit.go @@ -16,6 +16,6 @@ import ( "tailscale.com/wgengine/tstun" ) -func Impl(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) error { - return errors.New("netstack is not supported on 32-bit platforms for now; see https://github.com/google/gvisor/issues/5241") +func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (wgengine.FakeImpl, error) { + return nil, errors.New("netstack is not supported on 32-bit platforms for now; see https://github.com/google/gvisor/issues/5241") } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 952abfc09..fcbdd1eeb 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -146,26 +146,33 @@ type EngineConfig struct { // which disables such features as DNS configuration and unrestricted ICMP Echo responses. Fake bool - // FakeImpl, if non-nil, specifies which type of fake implementation to - // use. Two values are typical: nil, for a basic ping-only fake - // implementation, and netstack.Impl, which brings in gvisor's netstack - // to the binary. The desire to keep that out of some binaries is why - // this func exists, so wgengine need not depend on gvisor. - FakeImpl FakeImplFunc + // FakeImplFactory, if non-nil, creates a FakeImpl to use as a fake engine + // implementation. Two values are typical: nil, for a basic ping-only fake + // implementation, and netstack.Create, which creates a userspace network + // stack using gvisor's netstack. The desire to keep netstack out of some + // binaries is why the FakeImpl interface exists, so wgengine need not + // depend on gvisor. + FakeImplFactory FakeImplFactory } -// FakeImplFunc is the type used by EngineConfig.FakeImpl. See docs there. -type FakeImplFunc func(logger.Logf, *tstun.TUN, Engine, *magicsock.Conn) error +// FakeImpl is a fake or alternate version of Engine that can be started. See +// EngineConfig.FakeImplFactory for details. +type FakeImpl interface { + Start() error +} + +// FakeImplFactory is the type of a function used to create FakeImpls. +type FakeImplFactory func(logger.Logf, *tstun.TUN, Engine, *magicsock.Conn) (FakeImpl, error) -func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16, impl FakeImplFunc) (Engine, error) { +func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16, impl FakeImplFactory) (Engine, error) { logf("Starting userspace wireguard engine (with fake TUN device)") conf := EngineConfig{ - Logf: logf, - TUN: tstun.NewFakeTUN(), - RouterGen: router.NewFake, - ListenPort: listenPort, - Fake: true, - FakeImpl: impl, + Logf: logf, + TUN: tstun.NewFakeTUN(), + RouterGen: router.NewFake, + ListenPort: listenPort, + Fake: true, + FakeImplFactory: impl, } return NewUserspaceEngineAdvanced(conf) } @@ -282,8 +289,12 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { // Respond to all pings only in fake mode. if conf.Fake { - if impl := conf.FakeImpl; impl != nil { - if err := impl(logf, e.tundev, e, e.magicConn); err != nil { + if f := conf.FakeImplFactory; f != nil { + impl, err := f(logf, e.tundev, e, e.magicConn) + if err != nil { + return nil, err + } + if err := impl.Start(); err != nil { return nil, err } } else {