diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 8e24980f3..fdac1a037 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -136,6 +136,24 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi return ns, nil } +// wrapProtoHandler returns protocol handler h wrapped in a version +// that dynamically reconfigures ns's subnet addresses as needed for +// outbound traffic. +func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, *stack.PacketBuffer) bool) func(stack.TransportEndpointID, *stack.PacketBuffer) bool { + return func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool { + addr := tei.LocalAddress + ip, ok := netaddr.FromStdIP(net.IP(addr)) + if !ok { + ns.logf("netstack: could not parse local address for incoming connection") + return false + } + if !ns.isLocalIP(ip) { + ns.addSubnetAddress(ip) + } + return h(tei, pb) + } +} + // Start sets up all the handlers so netstack can start working. Implements // wgengine.FakeImpl. func (ns *Impl) Start() error { @@ -145,25 +163,8 @@ func (ns *Impl) Start() error { const maxInFlightConnectionAttempts = 16 tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP) udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP) - ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool { - addr := tei.LocalAddress - var pn tcpip.NetworkProtocolNumber - if addr.To4() != "" { - pn = ipv4.ProtocolNumber - } else { - pn = ipv6.ProtocolNumber - } - ip, ok := netaddr.FromStdIP(net.IP(addr)) - if !ok { - ns.logf("netstack: could not parse local address %s for incoming TCP connection", ip) - return false - } - if !ns.isLocalIP(ip) { - ns.addSubnetAddress(pn, ip) - } - return tcpFwd.HandlePacket(tei, pb) - }) - ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) + ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapProtoHandler(tcpFwd.HandlePacket)) + ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapProtoHandler(udpFwd.HandlePacket)) go ns.injectOutbound() ns.tundev.PostFilterIn = ns.injectInbound return nil @@ -214,13 +215,19 @@ func (ns *Impl) updateDNS(nm *netmap.NetworkMap) { ns.dns = DNSMapFromNetworkMap(nm) } -func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, ip netaddr.IP) { +func (ns *Impl) addSubnetAddress(ip netaddr.IP) { ns.mu.Lock() ns.connsOpenBySubnetIP[ip]++ needAdd := ns.connsOpenBySubnetIP[ip] == 1 ns.mu.Unlock() // Only register address into netstack for first concurrent connection. if needAdd { + var pn tcpip.NetworkProtocolNumber + if ip.Is4() { + pn = ipv4.ProtocolNumber + } else if ip.Is6() { + pn = ipv6.ProtocolNumber + } ns.ipstack.AddAddress(nicID, pn, tcpip.Address(ip.IPAddr().IP)) } } @@ -543,9 +550,9 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcp } func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { - reqDetails := r.ID() + sess := r.ID() if debugNetstack { - ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(reqDetails)) + ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(sess)) } var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) @@ -553,30 +560,50 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { ns.logf("acceptUDP: could not create endpoint: %v", err) return } - localAddr, err := ep.GetLocalAddress() - if err != nil { + dstAddr, ok := ipPortOfNetstackAddr(sess.LocalAddress, sess.LocalPort) + if !ok { return } - remoteAddr, err := ep.GetRemoteAddress() - if err != nil { + srcAddr, ok := ipPortOfNetstackAddr(sess.RemoteAddress, sess.RemotePort) + if !ok { return } + c := gonet.NewUDPConn(ns.ipstack, &wq, ep) - go ns.forwardUDP(c, &wq, localAddr, remoteAddr) + go ns.forwardUDP(c, &wq, srcAddr, dstAddr) } -func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) { - port := clientLocalAddr.Port +// forwardUDP proxies between client (with addr clientAddr) and dstAddr. +// +// dstAddr may be either a local Tailscale IP, in which we case we proxy to +// 127.0.0.1, or any other IP (from an advertised subnet), in which case we +// proxy to it directly. +func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientAddr, dstAddr netaddr.IPPort) { + port, srcPort := dstAddr.Port(), clientAddr.Port() ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port) - backendListenAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(clientRemoteAddr.Port)} - backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)} - backendConn, err := net.ListenUDP("udp4", backendListenAddr) + + var backendListenAddr *net.UDPAddr + var backendRemoteAddr *net.UDPAddr + isLocal := ns.isLocalIP(dstAddr.IP()) + if isLocal { + backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)} + backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)} + } else { + backendRemoteAddr = dstAddr.UDPAddr() + if dstAddr.IP().Is4() { + backendListenAddr = &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: int(srcPort)} + } else { + backendListenAddr = &net.UDPAddr{IP: net.ParseIP("::"), Port: int(srcPort)} + } + } + + backendConn, err := net.ListenUDP("udp", backendListenAddr) if err != nil { - ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err) + ns.logf("netstack: could not bind local port %v: %v, trying again with random port", backendListenAddr.Port, err) backendListenAddr.Port = 0 - backendConn, err = net.ListenUDP("udp4", backendListenAddr) + backendConn, err = net.ListenUDP("udp", backendListenAddr) if err != nil { - ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err) + ns.logf("netstack: could not create UDP socket, preventing forwarding to %v: %v", dstAddr, err) return } } @@ -585,28 +612,47 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA if !ok { ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port) } - clientRemoteIP, _ := netaddr.FromStdIP(net.ParseIP(clientRemoteAddr.Addr.String())) - ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) + if isLocal { + ns.e.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.IP()) + } ctx, cancel := context.WithCancel(context.Background()) - timer := time.AfterFunc(2*time.Minute, func() { - ns.e.UnregisterIPPortIdentity(backendLocalIPPort) - ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr) + + idleTimeout := 2 * time.Minute + if port == 53 { + // Make DNS packet copies time out much sooner. + // + // TODO(bradfitz): make DNS queries over UDP forwarding even + // cheaper by adding an additional idleTimeout post-DNS-reply. + // For instance, after the DNS response goes back out, then only + // wait a few seconds (or zero, really) + idleTimeout = 30 * time.Second + } + timer := time.AfterFunc(idleTimeout, func() { + if isLocal { + ns.e.UnregisterIPPortIdentity(backendLocalIPPort) + } + ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr) cancel() client.Close() backendConn.Close() }) extend := func() { - timer.Reset(2 * time.Minute) + timer.Reset(idleTimeout) } - startPacketCopy(ctx, cancel, client, &net.UDPAddr{ - IP: net.ParseIP(clientRemoteAddr.Addr.String()), - Port: int(clientRemoteAddr.Port), - }, backendConn, ns.logf, extend) + startPacketCopy(ctx, cancel, client, clientAddr.UDPAddr(), backendConn, ns.logf, extend) startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend) - + if isLocal { + // Wait for the copies to be done before decrementing the + // subnet address count to potentially remove the route. + <-ctx.Done() + ns.removeSubnetAddress(dstAddr.IP()) + } } func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) { + if debugNetstack { + logf("[v2] netstack: startPacketCopy to %v (%T) from %T", dstAddr, dst, src) + } go func() { defer cancel() // tear down the other direction's copy pkt := make([]byte, mtu) @@ -643,3 +689,7 @@ func stringifyTEI(tei stack.TransportEndpointID) string { remoteHostPort := net.JoinHostPort(tei.RemoteAddress.String(), strconv.Itoa(int(tei.RemotePort))) return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort) } + +func ipPortOfNetstackAddr(a tcpip.Address, port uint16) (ipp netaddr.IPPort, ok bool) { + return netaddr.FromStdAddr(net.IP(a), int(port), "") // TODO(bradfitz): can do without allocs +}