From 662fbd4a09664e849f0b898d1e8df13325d36efa Mon Sep 17 00:00:00 2001 From: Naman Sood Date: Mon, 29 Mar 2021 14:33:05 -0400 Subject: [PATCH] wgengine/netstack: Allow userspace networking mode to expose subnets (#1588) wgengine/netstack: Allow userspace networking mode to expose subnets Updates #504 Updates #707 Signed-off-by: Naman Sood --- wgengine/netstack/netstack.go | 152 +++++++++++++++++++++++++--------- 1 file changed, 114 insertions(+), 38 deletions(-) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index ef3c24b1d..92881decd 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -31,6 +31,7 @@ import ( "inet.af/netstack/tcpip/transport/udp" "inet.af/netstack/waiter" "tailscale.com/net/packet" + "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -55,6 +56,11 @@ type Impl struct { mu sync.Mutex dns DNSMap + // connsOpenBySubnetIP keeps track of number of connections open + // for each subnet IP temporarily registered on netstack for active + // TCP connections, so they can be unregistered when connections are + // closed. + connsOpenBySubnetIP map[netaddr.IP]int } const nicID = 1 @@ -82,6 +88,12 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) } + // By default the netstack NIC will only accept packets for the IPs + // registered to it. Since in some cases we dynamically register IPs + // based on the packets that arrive, the NIC needs to accept all + // incoming packets. The NIC won't receive anything it isn't meant to + // since Wireguard will only send us packets that are meant for us. + ipstack.SetPromiscuousMode(nicID, true) // Add IPv4 and IPv6 default routes, so all incoming packets from the Tailscale side // are handled by the one fake NIC we use. ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) @@ -97,12 +109,13 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi }, }) ns := &Impl{ - logf: logf, - ipstack: ipstack, - linkEP: linkEP, - tundev: tundev, - e: e, - mc: mc, + logf: logf, + ipstack: ipstack, + linkEP: linkEP, + tundev: tundev, + e: e, + mc: mc, + connsOpenBySubnetIP: make(map[netaddr.IP]int), } return ns, nil } @@ -116,7 +129,24 @@ func (ns *Impl) Start() error { const maxInFlightConnectionAttempts = 16 tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP) udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP) - ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool { + addr := tei.LocalAddress + var pn tcpip.NetworkProtocolNumber + if addr.To4() != "" { + pn = ipv4.ProtocolNumber + } else { + pn = ipv6.ProtocolNumber + } + ip, ok := netaddr.FromStdIP(net.IP(addr)) + if !ok { + ns.logf("netstack: could not parse local address %s for incoming TCP connection", ip) + return false + } + if !tsaddr.IsTailscaleIP(ip) { + ns.addSubnetAddress(pn, ip) + } + return tcpFwd.HandlePacket(tei, pb) + }) ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) go ns.injectOutbound() ns.tundev.PostFilterIn = ns.injectInbound @@ -156,50 +186,86 @@ func (ns *Impl) updateDNS(nm *netmap.NetworkMap) { ns.dns = DNSMapFromNetworkMap(nm) } +func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, ip netaddr.IP) { + ns.mu.Lock() + ns.connsOpenBySubnetIP[ip]++ + needAdd := ns.connsOpenBySubnetIP[ip] == 1 + ns.mu.Unlock() + // Only register address into netstack for first concurrent connection. + if needAdd { + ns.ipstack.AddAddress(nicID, pn, tcpip.Address(ip.IPAddr().IP)) + } +} + +func (ns *Impl) removeSubnetAddress(ip netaddr.IP) { + ns.mu.Lock() + defer ns.mu.Unlock() + ns.connsOpenBySubnetIP[ip]-- + // Only unregister address from netstack after last concurrent connection. + if ns.connsOpenBySubnetIP[ip] == 0 { + ns.ipstack.RemoveAddress(nicID, tcpip.Address(ip.IPAddr().IP)) + delete(ns.connsOpenBySubnetIP, ip) + } +} + +func ipPrefixToAddressWithPrefix(ipp netaddr.IPPrefix) tcpip.AddressWithPrefix { + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(ipp.IP.IPAddr().IP), + PrefixLen: int(ipp.Bits), + } +} + func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { ns.updateDNS(nm) - oldIPs := make(map[tcpip.Address]bool) - for _, ip := range ns.ipstack.AllAddresses()[nicID] { - oldIPs[ip.AddressWithPrefix.Address] = true + oldIPs := make(map[tcpip.AddressWithPrefix]bool) + for _, protocolAddr := range ns.ipstack.AllAddresses()[nicID] { + oldIPs[protocolAddr.AddressWithPrefix] = true } - newIPs := make(map[tcpip.Address]bool) - for _, ip := range nm.Addresses { - newIPs[tcpip.Address(ip.IP.IPAddr().IP)] = true + newIPs := make(map[tcpip.AddressWithPrefix]bool) + for _, ipp := range nm.SelfNode.AllowedIPs { + newIPs[ipPrefixToAddressWithPrefix(ipp)] = true } - ipsToBeAdded := make(map[tcpip.Address]bool) - for ip := range newIPs { - if !oldIPs[ip] { - ipsToBeAdded[ip] = true + ipsToBeAdded := make(map[tcpip.AddressWithPrefix]bool) + for ipp := range newIPs { + if !oldIPs[ipp] { + ipsToBeAdded[ipp] = true } } - ipsToBeRemoved := make(map[tcpip.Address]bool) + ipsToBeRemoved := make(map[tcpip.AddressWithPrefix]bool) for ip := range oldIPs { if !newIPs[ip] { ipsToBeRemoved[ip] = true } } + ns.mu.Lock() + for ip := range ns.connsOpenBySubnetIP { + ipp := tcpip.Address(ip.IPAddr().IP).WithPrefix() + ipsToBeAdded[ipp] = true + delete(ipsToBeRemoved, ipp) + } + ns.mu.Unlock() - for ip := range ipsToBeRemoved { - err := ns.ipstack.RemoveAddress(nicID, ip) + for ipp := range ipsToBeRemoved { + err := ns.ipstack.RemoveAddress(nicID, ipp.Address) if err != nil { - ns.logf("netstack: could not deregister IP %s: %v", ip, err) + ns.logf("netstack: could not deregister IP %s: %v", ipp, err) } else { - ns.logf("[v2] netstack: deregistered IP %s", ip) + ns.logf("[v2] netstack: deregistered IP %s", ipp) } } - for ip := range ipsToBeAdded { + for ipp := range ipsToBeAdded { var err tcpip.Error - if ip.To4() == "" { - err = ns.ipstack.AddAddress(nicID, ipv6.ProtocolNumber, ip) + if ipp.Address.To4() == "" { + err = ns.ipstack.AddAddressWithPrefix(nicID, ipv6.ProtocolNumber, ipp) } else { - err = ns.ipstack.AddAddress(nicID, ipv4.ProtocolNumber, ip) + err = ns.ipstack.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, ipp) } if err != nil { - ns.logf("netstack: could not register IP %s: %v", ip, err) + ns.logf("netstack: could not register IP %s: %v", ipp, err) } else { - ns.logf("[v2] netstack: registered IP %s", ip) + ns.logf("[v2] netstack: registered IP %s", ipp) } } } @@ -322,25 +388,35 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { // ForwarderRequest: &{{{{0 0}}} 0xc0001c30b0 0xc0004c3d40 {1240 6 true 826109390 0 true} ns.logf("[v2] ForwarderRequest: %v", r) } + reqDetails := r.ID() + dialAddr := reqDetails.LocalAddress + dialNetAddr, _ := netaddr.FromStdIP(net.IP(dialAddr)) + isTailscaleIP := tsaddr.IsTailscaleIP(dialNetAddr) + defer func() { + if !isTailscaleIP { + // if this is a subnet IP, we added this in before the TCP handshake + // so netstack is happy TCP-handshaking as a subnet IP + ns.removeSubnetAddress(dialNetAddr) + } + }() var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { r.Complete(true) return } - localAddr, err := ep.GetLocalAddress() - if err != nil { - r.Complete(true) - return + if isTailscaleIP { + dialAddr = tcpip.Address(net.ParseIP("127.0.0.1")).To4() } r.Complete(false) c := gonet.NewTCPConn(&wq, ep) - go ns.forwardTCP(c, &wq, localAddr.Port) + ns.forwardTCP(c, &wq, dialAddr, reqDetails.LocalPort) } -func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, port uint16) { +func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcpip.Address, dialPort uint16) { defer client.Close() - ns.logf("[v2] netstack: forwarding incoming connection on port %v", port) + dialAddrStr := net.JoinHostPort(dialAddr.String(), strconv.Itoa(int(dialPort))) + ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) ctx, cancel := context.WithCancel(context.Background()) defer cancel() waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -358,9 +434,9 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, port uint16) cancel() }() var stdDialer net.Dialer - server, err := stdDialer.DialContext(ctx, "tcp", net.JoinHostPort("localhost", strconv.Itoa(int(port)))) + server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) if err != nil { - ns.logf("netstack: could not connect to local server on port %v: %v", port, err) + ns.logf("netstack: could not connect to local server at %s: %v", dialAddrStr, err) return } defer server.Close() @@ -382,7 +458,7 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, port uint16) if err != nil { ns.logf("proxy connection closed with error: %v", err) } - ns.logf("[v2] netstack: forwarder connection on port %v closed", port) + ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) } func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {