diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 868145ae0..6edae3030 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -289,10 +289,14 @@ func run() error { return err } - var ns *netstack.Impl - if useNetstack || wrapNetstack { - onlySubnets := wrapNetstack && !useNetstack - ns = mustStartNetstack(logf, e, onlySubnets) + ns, err := newNetstack(logf, e) + if err != nil { + return fmt.Errorf("newNetstack: %w", err) + } + ns.ProcessLocalIPs = useNetstack + ns.ProcessSubnets = useNetstack || wrapNetstack + if err := ns.Start(); err != nil { + log.Fatalf("failed to start netstack: %v", err) } if socksListener != nil || httpProxyListener != nil { @@ -453,19 +457,12 @@ func runDebugServer(mux *http.ServeMux, addr string) { } } -func mustStartNetstack(logf logger.Logf, e wgengine.Engine, onlySubnets bool) *netstack.Impl { +func newNetstack(logf logger.Logf, e wgengine.Engine) (*netstack.Impl, error) { tunDev, magicConn, ok := e.(wgengine.InternalsGetter).GetInternals() if !ok { - log.Fatalf("%T is not a wgengine.InternalsGetter", e) - } - ns, err := netstack.Create(logf, tunDev, e, magicConn, onlySubnets) - if err != nil { - log.Fatalf("netstack.Create: %v", err) - } - if err := ns.Start(); err != nil { - log.Fatalf("failed to start netstack: %v", err) + return nil, fmt.Errorf("%T is not a wgengine.InternalsGetter", e) } - return ns + return netstack.Create(logf, tunDev, e, magicConn) } func mustStartTCPListener(name, addr string) net.Listener { diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 303581636..1c6444edb 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -202,9 +202,14 @@ func startIPNServer(ctx context.Context, logid string) error { dev.Close() return nil, fmt.Errorf("engine: %w", err) } - onlySubnets := true - if wrapNetstack { - mustStartNetstack(logf, eng, onlySubnets) + ns, err := newNetstack(logf, eng) + if err != nil { + return nil, fmt.Errorf("newNetstack: %w", err) + } + ns.ProcessLocalIPs = false + ns.ProcessSubnets = wrapNetstack + if err := ns.Start(); err != nil { + return nil, fmt.Errorf("failed to start netstack: %w", err) } return wgengine.NewWatchdog(eng), nil } diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 3fd36b536..e1febeb2b 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -127,10 +127,11 @@ func (s *Server) start() error { return fmt.Errorf("%T is not a wgengine.InternalsGetter", eng) } - ns, err := netstack.Create(logf, tunDev, eng, magicConn, false) + ns, err := netstack.Create(logf, tunDev, eng, magicConn) if err != nil { return fmt.Errorf("netstack.Create: %w", err) } + ns.ProcessLocalIPs = true ns.ForwardTCPIn = s.forwardTCP if err := ns.Start(); err != nil { return fmt.Errorf("failed to start netstack: %w", err) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index fc098e7a0..b0c596a29 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -54,13 +54,23 @@ type Impl struct { // port other than accepting it and closing it. ForwardTCPIn func(c net.Conn, port uint16) - ipstack *stack.Stack - linkEP *channel.Endpoint - tundev *tstun.Wrapper - e wgengine.Engine - mc *magicsock.Conn - logf logger.Logf - onlySubnets bool // whether we only want to handle subnet relaying + // ProcessLocalIPs is whether netstack should handle incoming + // traffic directed at the Node.Addresses (local IPs). + // It can only be set before calling Start. + ProcessLocalIPs bool + + // ProcessSubnets is whether netstack should handle incoming + // traffic destined to non-local IPs (i.e. whether it should + // be a subnet router). + // It can only be set before calling Start. + ProcessSubnets bool + + ipstack *stack.Stack + linkEP *channel.Endpoint + tundev *tstun.Wrapper + e wgengine.Engine + mc *magicsock.Conn + logf logger.Logf // atomicIsLocalIPFunc holds a func that reports whether an IP // is a local (non-subnet) Tailscale IP address of this @@ -81,7 +91,7 @@ const nicID = 1 const mtu = 1500 // Create creates and populates a new Impl. -func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, onlySubnets bool) (*Impl, error) { +func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) { if mc == nil { return nil, errors.New("nil magicsock.Conn") } @@ -130,7 +140,6 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi e: e, mc: mc, connsOpenBySubnetIP: make(map[netaddr.IP]int), - onlySubnets: onlySubnets, } ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nil)) return ns, nil @@ -275,10 +284,10 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { isAddr[ipp] = true } for _, ipp := range nm.SelfNode.AllowedIPs { - if ns.onlySubnets && isAddr[ipp] { - continue + local := isAddr[ipp] + if local && ns.ProcessLocalIPs || !local && ns.ProcessSubnets { + newIPs[ipPrefixToAddressWithPrefix(ipp)] = true } - newIPs[ipPrefixToAddressWithPrefix(ipp)] = true } ipsToBeAdded := make(map[tcpip.AddressWithPrefix]bool) @@ -446,11 +455,27 @@ func (ns *Impl) isLocalIP(ip netaddr.IP) bool { return ns.atomicIsLocalIPFunc.Load().(func(netaddr.IP) bool)(ip) } +// shouldProcessInbound reports whether an inbound packet should be +// handled by netstack. +func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { + if !ns.ProcessLocalIPs && !ns.ProcessSubnets { + // Fast path for common case (e.g. Linux server in TUN mode) where + // netstack isn't used at all; don't even do an isLocalIP lookup. + return false + } + isLocal := ns.isLocalIP(p.Dst.IP()) + if ns.ProcessLocalIPs && isLocal { + return true + } + if ns.ProcessSubnets && !isLocal { + return true + } + return false +} + func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Response { - if ns.onlySubnets && ns.isLocalIP(p.Dst.IP()) { - // In hybrid ("only subnets") mode, bail out early if - // the traffic is destined for an actual Tailscale - // address. The real host OS interface will handle it. + if !ns.shouldProcessInbound(p, t) { + // Let the host network stack (if any) deal with it. return filter.Accept } var pn tcpip.NetworkProtocolNumber