tsnet: enable use-cases with non-native IPs by setting ns.ProcessSubnets

Terminating traffic to IPs which are not the native IPs of the node requires
the netstack subsystem to intercept trafic to an IP it does not consider local.
This PR switches on such interception. In addition to supporting such termination,
this change will also enable exit nodes and subnet routers when running in
userspace mode.

DO NOT MERGE until 1.52 is cut.

Signed-off-by: Tom DNetto <tom@tailscale.com>
Updates: https://github.com/tailscale/corp/issues/15038
pull/9892/head
Tom DNetto 8 months ago committed by Tom
parent 452f900589
commit 3df305b764

@ -533,6 +533,7 @@ func (s *Server) start() (reterr error) {
sys.Tun.Get().Start() sys.Tun.Get().Start()
sys.Set(ns) sys.Set(ns)
ns.ProcessLocalIPs = true ns.ProcessLocalIPs = true
ns.ProcessSubnets = true
ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow
ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow
s.netstack = ns s.netstack = ns
@ -731,19 +732,39 @@ func networkForFamily(netBase string, is6 bool) string {
// - ("tcp", "", port) // - ("tcp", "", port)
// //
// The netBase is "tcp" or "udp" (without any '4' or '6' suffix). // The netBase is "tcp" or "udp" (without any '4' or '6' suffix).
//
// Listeners which do not specify an IP address will match for traffic
// for the local node (that is, a destination address of the IPv4 or
// IPv6 address of this node) only. To listen for traffic on other addresses
// such as those routed inbound via subnet routes, explicitly specify
// the listening address or use RegisterFallbackTCPHandler.
func (s *Server) listenerForDstAddr(netBase string, dst netip.AddrPort, funnel bool) (_ *listener, ok bool) { func (s *Server) listenerForDstAddr(netBase string, dst netip.AddrPort, funnel bool) (_ *listener, ok bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for _, a := range [2]netip.Addr{0: dst.Addr()} {
// Search for a listener with the specified IP
for _, net := range [2]string{
networkForFamily(netBase, dst.Addr().Is6()),
netBase,
} {
if ln, ok := s.listeners[listenKey{net, dst.Addr(), dst.Port(), funnel}]; ok {
return ln, true
}
}
// Search for a listener without an IP if the destination was
// one of the native IPs of the node.
if ip4, ip6 := s.TailscaleIPs(); dst.Addr() == ip4 || dst.Addr() == ip6 {
for _, net := range [2]string{ for _, net := range [2]string{
networkForFamily(netBase, dst.Addr().Is6()), networkForFamily(netBase, dst.Addr().Is6()),
netBase, netBase,
} { } {
if ln, ok := s.listeners[listenKey{net, a, dst.Port(), funnel}]; ok { if ln, ok := s.listeners[listenKey{net, netip.Addr{}, dst.Port(), funnel}]; ok {
return ln, true return ln, true
} }
} }
} }
return nil, false return nil, false
} }
@ -853,6 +874,12 @@ func (s *Server) APIClient() (*tailscale.Client, error) {
// Listen announces only on the Tailscale network. // Listen announces only on the Tailscale network.
// It will start the server if it has not been started yet. // It will start the server if it has not been started yet.
//
// Listeners which do not specify an IP address will match for traffic
// for the local node (that is, a destination address of the IPv4 or
// IPv6 address of this node) only. To listen for traffic on other addresses
// such as those routed inbound via subnet routes, explicitly specify
// the listening address or use RegisterFallbackTCPHandler.
func (s *Server) Listen(network, addr string) (net.Listener, error) { func (s *Server) Listen(network, addr string) (net.Listener, error) {
return s.listen(network, addr, listenOnTailnet) return s.listen(network, addr, listenOnTailnet)
} }

@ -39,6 +39,7 @@ import (
"tailscale.com/tstest" "tailscale.com/tstest"
"tailscale.com/tstest/integration" "tailscale.com/tstest/integration"
"tailscale.com/tstest/integration/testcontrol" "tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/must" "tailscale.com/util/must"
) )
@ -95,7 +96,7 @@ func TestListenerPort(t *testing.T) {
var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs") var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs")
var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs") var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs")
func startControl(t *testing.T) (controlURL string) { func startControl(t *testing.T) (controlURL string, control *testcontrol.Server) {
// Corp#4520: don't use netns for tests. // Corp#4520: don't use netns for tests.
netns.SetEnabled(false) netns.SetEnabled(false)
t.Cleanup(func() { t.Cleanup(func() {
@ -107,7 +108,7 @@ func startControl(t *testing.T) (controlURL string) {
derpLogf = t.Logf derpLogf = t.Logf
} }
derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1") derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1")
control := &testcontrol.Server{ control = &testcontrol.Server{
DERPMap: derpMap, DERPMap: derpMap,
DNSConfig: &tailcfg.DNSConfig{ DNSConfig: &tailcfg.DNSConfig{
Proxied: true, Proxied: true,
@ -119,7 +120,7 @@ func startControl(t *testing.T) (controlURL string) {
t.Cleanup(control.HTTPTestServer.Close) t.Cleanup(control.HTTPTestServer.Close)
controlURL = control.HTTPTestServer.URL controlURL = control.HTTPTestServer.URL
t.Logf("testcontrol listening on %s", controlURL) t.Logf("testcontrol listening on %s", controlURL)
return controlURL return controlURL, control
} }
type testCertIssuer struct { type testCertIssuer struct {
@ -200,7 +201,7 @@ func (tci *testCertIssuer) Pool() *x509.CertPool {
var testCertRoot = newCertIssuer() var testCertRoot = newCertIssuer()
func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) (*Server, netip.Addr) { func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) (*Server, netip.Addr, key.NodePublic) {
t.Helper() t.Helper()
tmp := filepath.Join(t.TempDir(), hostname) tmp := filepath.Join(t.TempDir(), hostname)
@ -222,7 +223,7 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return s, status.TailscaleIPs[0] return s, status.TailscaleIPs[0], status.Self.PublicKey
} }
func TestConn(t *testing.T) { func TestConn(t *testing.T) {
@ -230,9 +231,17 @@ func TestConn(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
controlURL := startControl(t) controlURL, c := startControl(t)
s1, s1ip := startServer(t, ctx, controlURL, "s1") s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1")
s2, _ := startServer(t, ctx, controlURL, "s2") s2, _, _ := startServer(t, ctx, controlURL, "s2")
s1.lb.EditPrefs(&ipn.MaskedPrefs{
Prefs: ipn.Prefs{
AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")},
},
AdvertiseRoutesSet: true,
})
c.SetSubnetRoutes(s1PubKey, []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")})
lc2, err := s2.LocalClient() lc2, err := s2.LocalClient()
if err != nil { if err != nil {
@ -281,6 +290,15 @@ func TestConn(t *testing.T) {
if err == nil { if err == nil {
t.Fatalf("unexpected success; should have seen a connection refused error") t.Fatalf("unexpected success; should have seen a connection refused error")
} }
// s1 is a subnet router for TEST-NET-1 (192.0.2.0/24). Lets dial to that
// subnet from s2 to ensure a listener without an IP address (i.e. ":8081")
// only matches destination IPs corresponding to the node's IP, and not
// to any random IP a subnet is routing.
_, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", "192.0.2.1"))
if err == nil {
t.Fatalf("unexpected success; should have seen a connection refused error")
}
} }
func TestLoopbackLocalAPI(t *testing.T) { func TestLoopbackLocalAPI(t *testing.T) {
@ -289,8 +307,8 @@ func TestLoopbackLocalAPI(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
controlURL := startControl(t) controlURL, _ := startControl(t)
s1, _ := startServer(t, ctx, controlURL, "s1") s1, _, _ := startServer(t, ctx, controlURL, "s1")
addr, proxyCred, localAPICred, err := s1.Loopback() addr, proxyCred, localAPICred, err := s1.Loopback()
if err != nil { if err != nil {
@ -363,9 +381,9 @@ func TestLoopbackSOCKS5(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
controlURL := startControl(t) controlURL, _ := startControl(t)
s1, s1ip := startServer(t, ctx, controlURL, "s1") s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
s2, _ := startServer(t, ctx, controlURL, "s2") s2, _, _ := startServer(t, ctx, controlURL, "s2")
addr, proxyCred, _, err := s2.Loopback() addr, proxyCred, _, err := s2.Loopback()
if err != nil { if err != nil {
@ -410,7 +428,7 @@ func TestLoopbackSOCKS5(t *testing.T) {
} }
func TestTailscaleIPs(t *testing.T) { func TestTailscaleIPs(t *testing.T) {
controlURL := startControl(t) controlURL, _ := startControl(t)
tmp := t.TempDir() tmp := t.TempDir()
tmps1 := filepath.Join(tmp, "s1") tmps1 := filepath.Join(tmp, "s1")
@ -455,8 +473,8 @@ func TestListenerCleanup(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
controlURL := startControl(t) controlURL, _ := startControl(t)
s1, _ := startServer(t, ctx, controlURL, "s1") s1, _, _ := startServer(t, ctx, controlURL, "s1")
ln, err := s1.Listen("tcp", ":8081") ln, err := s1.Listen("tcp", ":8081")
if err != nil { if err != nil {
@ -475,7 +493,7 @@ func TestListenerCleanup(t *testing.T) {
// tests https://github.com/tailscale/tailscale/issues/6973 -- that we can start a tsnet server, // tests https://github.com/tailscale/tailscale/issues/6973 -- that we can start a tsnet server,
// stop it, and restart it, even on Windows. // stop it, and restart it, even on Windows.
func TestStartStopStartGetsSameIP(t *testing.T) { func TestStartStopStartGetsSameIP(t *testing.T) {
controlURL := startControl(t) controlURL, _ := startControl(t)
tmp := t.TempDir() tmp := t.TempDir()
tmps1 := filepath.Join(tmp, "s1") tmps1 := filepath.Join(tmp, "s1")
@ -527,9 +545,9 @@ func TestFunnel(t *testing.T) {
ctx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer dialCancel() defer dialCancel()
controlURL := startControl(t) controlURL, _ := startControl(t)
s1, _ := startServer(t, ctx, controlURL, "s1") s1, _, _ := startServer(t, ctx, controlURL, "s1")
s2, _ := startServer(t, ctx, controlURL, "s2") s2, _, _ := startServer(t, ctx, controlURL, "s2")
ln := must.Get(s1.ListenFunnel("tcp", ":443")) ln := must.Get(s1.ListenFunnel("tcp", ":443"))
defer ln.Close() defer ln.Close()
@ -637,9 +655,9 @@ func TestFallbackTCPHandler(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
controlURL := startControl(t) controlURL, _ := startControl(t)
s1, s1ip := startServer(t, ctx, controlURL, "s1") s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
s2, _ := startServer(t, ctx, controlURL, "s2") s2, _, _ := startServer(t, ctx, controlURL, "s2")
lc2, err := s2.LocalClient() lc2, err := s2.LocalClient()
if err != nil { if err != nil {

@ -33,6 +33,7 @@ import (
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
"tailscale.com/util/mak"
"tailscale.com/util/must" "tailscale.com/util/must"
"tailscale.com/util/rands" "tailscale.com/util/rands"
"tailscale.com/util/set" "tailscale.com/util/set"
@ -64,6 +65,10 @@ type Server struct {
pubKey key.MachinePublic pubKey key.MachinePublic
privKey key.ControlPrivate // not strictly needed vs. MachinePrivate, but handy to test type interactions. privKey key.ControlPrivate // not strictly needed vs. MachinePrivate, but handy to test type interactions.
// nodeSubnetRoutes is a list of subnet routes that are served
// by the specified node.
nodeSubnetRoutes map[key.NodePublic][]netip.Prefix
// masquerades is the set of masquerades that should be applied to // masquerades is the set of masquerades that should be applied to
// MapResponses sent to clients. It is keyed by the requesting nodes // MapResponses sent to clients. It is keyed by the requesting nodes
// public key, and then the peer node's public key. The value is the // public key, and then the peer node's public key. The value is the
@ -328,6 +333,13 @@ func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) {
} }
} }
// SetSubnetRoutes sets the list of subnet routes which a node is routing.
func (s *Server) SetSubnetRoutes(nodeKey key.NodePublic, routes []netip.Prefix) {
s.mu.Lock()
defer s.mu.Unlock()
mak.Set(&s.nodeSubnetRoutes, nodeKey, routes)
}
// MasqueradePair is a pair of nodes and the IP address that the // MasqueradePair is a pair of nodes and the IP address that the
// Node masquerades as for the Peer. // Node masquerades as for the Peer.
// //
@ -908,6 +920,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
s.mu.Lock() s.mu.Lock()
peerAddress := s.masquerades[p.Key][node.Key] peerAddress := s.masquerades[p.Key][node.Key]
routes := s.nodeSubnetRoutes[p.Key]
s.mu.Unlock() s.mu.Unlock()
if peerAddress.IsValid() { if peerAddress.IsValid() {
if peerAddress.Is6() { if peerAddress.Is6() {
@ -918,6 +931,10 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
p.AllowedIPs[0] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) p.AllowedIPs[0] = netip.PrefixFrom(peerAddress, peerAddress.BitLen())
} }
} }
if len(routes) > 0 {
p.PrimaryRoutes = routes
p.AllowedIPs = append(p.AllowedIPs, routes...)
}
res.Peers = append(res.Peers, p) res.Peers = append(res.Peers, p)
} }
@ -939,11 +956,12 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
v4Prefix, v4Prefix,
v6Prefix, v6Prefix,
} }
res.Node.AllowedIPs = res.Node.Addresses
// Consume a PingRequest while protected by mutex if it exists
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
res.Node.AllowedIPs = append(res.Node.Addresses, s.nodeSubnetRoutes[nk]...)
// Consume a PingRequest while protected by mutex if it exists
switch m := s.msgToSend[nk].(type) { switch m := s.msgToSend[nk].(type) {
case *tailcfg.PingRequest: case *tailcfg.PingRequest:
res.PingRequest = m res.PingRequest = m

Loading…
Cancel
Save