// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause // Package vnet simulates a virtual Internet containing a set of networks with various // NAT behaviors. You can then plug VMs into the virtual internet at different points // to test Tailscale working end-to-end in various network conditions. // // See https://github.com/tailscale/tailscale/issues/13038 package vnet // TODO: // - [ ] port mapping actually working // - [ ] conf to let you firewall things // - [ ] tests for NAT tables import ( "bufio" "bytes" "context" "crypto/tls" "encoding/binary" "encoding/json" "errors" "fmt" "io" "log" "math/rand/v2" "net" "net/http" "net/http/httptest" "net/netip" "os/exec" "strconv" "sync" "sync/atomic" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" "go4.org/mem" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" "tailscale.com/client/tailscale" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netutil" "tailscale.com/net/stun" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/util/zstdframe" ) const nicID = 1 const ( stunPort = 3478 pcpPort = 5351 ssdpPort = 1900 ) func (s *Server) PopulateDERPMapIPs() error { out, err := exec.Command("tailscale", "debug", "derp-map").Output() if err != nil { return fmt.Errorf("tailscale debug derp-map: %v", err) } var dm tailcfg.DERPMap if err := json.Unmarshal(out, &dm); err != nil { return fmt.Errorf("unmarshal DERPMap: %v", err) } for _, r := range dm.Regions { for _, n := range r.Nodes { if n.IPv4 != "" { s.derpIPs.Add(netip.MustParseAddr(n.IPv4)) } } } return nil } func (n *network) InitNAT(natType NAT) error { ctor, ok := natTypes[natType] if !ok { return fmt.Errorf("unknown NAT type %q", natType) } t, err := ctor(n) if err != nil { return fmt.Errorf("error creating NAT type %q for network %v: %w", natType, n.wanIP, err) } n.setNATTable(t) n.natStyle.Store(natType) return nil } func (n *network) setNATTable(nt NATTable) { n.natMu.Lock() defer n.natMu.Unlock() n.natTable = nt } // SoleLANIP implements [IPPool]. func (n *network) SoleLANIP() (netip.Addr, bool) { if len(n.nodesByIP) != 1 { return netip.Addr{}, false } for ip := range n.nodesByIP { return ip, true } return netip.Addr{}, false } // WANIP implements [IPPool]. func (n *network) WANIP() netip.Addr { return n.wanIP } func (n *network) initStack() error { n.ns = stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, arp.NewProtocol, }, TransportProtocols: []stack.TransportProtocolFactory{ tcp.NewProtocol, icmp.NewProtocol4, }, }) sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default tcpipErr := n.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) if tcpipErr != nil { return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr) } n.linkEP = channel.New(512, 1500, tcpip.LinkAddress(n.mac.HWAddr())) if tcpipProblem := n.ns.CreateNIC(nicID, n.linkEP); tcpipProblem != nil { return fmt.Errorf("CreateNIC: %v", tcpipProblem) } n.ns.SetPromiscuousMode(nicID, true) n.ns.SetSpoofing(nicID, true) prefix := tcpip.AddrFrom4Slice(n.lanIP.Addr().AsSlice()).WithPrefix() prefix.PrefixLen = n.lanIP.Bits() if tcpProb := n.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: prefix, }, stack.AddressProperties{}); tcpProb != nil { return errors.New(tcpProb.String()) } ipv4Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4))) if err != nil { return fmt.Errorf("could not create IPv4 subnet: %v", err) } n.ns.SetRouteTable([]tcpip.Route{ { Destination: ipv4Subnet, NIC: nicID, }, }) const tcpReceiveBufferSize = 0 // default const maxInFlightConnectionAttempts = 8192 tcpFwd := tcp.NewForwarder(n.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, n.acceptTCP) n.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { return tcpFwd.HandlePacket(tei, pb) }) go func() { for { pkt := n.linkEP.ReadContext(n.s.shutdownCtx) if pkt == nil { if n.s.shutdownCtx.Err() != nil { // Return without logging. return } continue } ipRaw := pkt.ToView().AsSlice() goPkt := gopacket.NewPacket( ipRaw, layers.LayerTypeIPv4, gopacket.Lazy) layerV4 := goPkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) dstIP, _ := netip.AddrFromSlice(layerV4.DstIP) node, ok := n.nodesByIP[dstIP] if !ok { log.Printf("no MAC for dest IP %v", dstIP) continue } eth := &layers.Ethernet{ SrcMAC: n.mac.HWAddr(), DstMAC: node.mac.HWAddr(), EthernetType: layers.EthernetTypeIPv4, } buffer := gopacket.NewSerializeBuffer() options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} sls := []gopacket.SerializableLayer{ eth, } for _, layer := range goPkt.Layers() { sl, ok := layer.(gopacket.SerializableLayer) if !ok { log.Fatalf("layer %s is not serializable", layer.LayerType().String()) } switch gl := layer.(type) { case *layers.TCP: gl.SetNetworkLayerForChecksum(layerV4) case *layers.UDP: gl.SetNetworkLayerForChecksum(layerV4) } sls = append(sls, sl) } if err := gopacket.SerializeLayers(buffer, options, sls...); err != nil { log.Printf("Serialize error: %v", err) continue } if nw, ok := n.writers.Load(node.mac); ok { nw.write(buffer.Bytes()) } else { log.Printf("No writeFunc for %v", node.mac) } } }() return nil } func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr { switch s.Len() { case 4: return netip.AddrFrom4(s.As4()) case 16: return netip.AddrFrom16(s.As16()).Unmap() } return netip.Addr{} } func stringifyTEI(tei stack.TransportEndpointID) string { localHostPort := net.JoinHostPort(tei.LocalAddress.String(), strconv.Itoa(int(tei.LocalPort))) remoteHostPort := net.JoinHostPort(tei.RemoteAddress.String(), strconv.Itoa(int(tei.RemotePort))) return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort) } func (n *network) acceptTCP(r *tcp.ForwarderRequest) { reqDetails := r.ID() clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) destPort := reqDetails.LocalPort if !clientRemoteIP.IsValid() { r.Complete(true) // sends a RST return } var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { log.Printf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err) r.Complete(true) // sends a RST return } ep.SocketOptions().SetKeepAlive(true) if destPort == 123 { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) io.WriteString(tc, "Hello from Go\nGoodbye.\n") tc.Close() return } if destPort == 8008 && fakeTestAgent.Match(destIP) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) node := n.nodesByIP[clientRemoteIP] ac := &agentConn{node, tc} n.s.addIdleAgentConn(ac) return } if destPort == 80 && fakeControl.Match(destIP) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) hs := &http.Server{Handler: n.s.control} go hs.Serve(netutil.NewOneConnListener(tc, nil)) return } if fakeDERP1.Match(destIP) || fakeDERP2.Match(destIP) { if destPort == 443 { ds := n.s.derps[0] if fakeDERP2.Match(destIP) { ds = n.s.derps[1] } r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) tlsConn := tls.Server(tc, ds.tlsConfig) hs := &http.Server{Handler: ds.handler} go hs.Serve(netutil.NewOneConnListener(tlsConn, nil)) return } if destPort == 80 { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) hs := &http.Server{Handler: n.s.derps[0].handler} go hs.Serve(netutil.NewOneConnListener(tc, nil)) return } } if destPort == 443 && fakeLogCatcher.Match(destIP) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) go n.serveLogCatcherConn(clientRemoteIP, tc) return } log.Printf("vnet-AcceptTCP: %v", stringifyTEI(reqDetails)) var targetDial string if n.s.derpIPs.Contains(destIP) { targetDial = destIP.String() + ":" + strconv.Itoa(int(destPort)) } else if fakeProxyControlplane.Match(destIP) { targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(destPort)) } if targetDial != "" { c, err := net.Dial("tcp", targetDial) if err != nil { r.Complete(true) log.Printf("Dial controlplane: %v", err) return } defer c.Close() tc := gonet.NewTCPConn(&wq, ep) defer tc.Close() r.Complete(false) errc := make(chan error, 2) go func() { _, err := io.Copy(tc, c); errc <- err }() go func() { _, err := io.Copy(c, tc); errc <- err }() <-errc } else { r.Complete(true) // sends a RST } } // serveLogCatchConn serves a TCP connection to "log.tailscale.io", speaking the // logtail/logcatcher protocol. // // We terminate TLS with an arbitrary cert; the client is configured to not // validate TLS certs for this hostname when running under these integration // tests. func (n *network) serveLogCatcherConn(clientRemoteIP netip.Addr, c net.Conn) { tlsConfig := n.s.derps[0].tlsConfig // self-signed (stealing DERP's); test client configure to not check tlsConn := tls.Server(c, tlsConfig) var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { all, _ := io.ReadAll(r.Body) if r.Header.Get("Content-Encoding") == "zstd" { var err error all, err = zstdframe.AppendDecode(nil, all) if err != nil { log.Printf("LOGS DECODE ERROR zstd decode: %v", err) http.Error(w, "zstd decode error", http.StatusBadRequest) return } } var logs []struct { Logtail struct { Client_Time time.Time } Text string } if err := json.Unmarshal(all, &logs); err != nil { log.Printf("Logs decode error: %v", err) return } node := n.nodesByIP[clientRemoteIP] if node != nil { node.logMu.Lock() defer node.logMu.Unlock() node.logCatcherWrites++ for _, lg := range logs { tStr := lg.Logtail.Client_Time.Round(time.Millisecond).Format(time.RFC3339Nano) fmt.Fprintf(&node.logBuf, "[%v] %s\n", tStr, lg.Text) } } }) hs := &http.Server{Handler: handler} hs.Serve(netutil.NewOneConnListener(tlsConn, nil)) } type EthernetPacket struct { le *layers.Ethernet gp gopacket.Packet } func (ep EthernetPacket) SrcMAC() MAC { return MAC(ep.le.SrcMAC) } func (ep EthernetPacket) DstMAC() MAC { return MAC(ep.le.DstMAC) } type MAC [6]byte func (m MAC) IsBroadcast() bool { return m == MAC{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} } func macOf(hwa net.HardwareAddr) (_ MAC, ok bool) { if len(hwa) != 6 { return MAC{}, false } return MAC(hwa), true } func (m MAC) HWAddr() net.HardwareAddr { return net.HardwareAddr(m[:]) } func (m MAC) String() string { return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", m[0], m[1], m[2], m[3], m[4], m[5]) } type portMapping struct { dst netip.AddrPort // LAN IP:port expiry time.Time } type writerFunc func([]byte, *net.UnixAddr, int) // Encapsulates both a write function, an optional outbound socket address // for dgram mode and an interfaceID for packet captures. type networkWriter struct { writer writerFunc // Function to write packets to the network addr *net.UnixAddr // Outbound socket address for dgram mode interfaceID int // The interface ID of the src node (for writing pcaps) } func (nw *networkWriter) write(b []byte) { nw.writer(b, nw.addr, nw.interfaceID) } type network struct { s *Server mac MAC portmap bool lanInterfaceID int wanInterfaceID int wanIP netip.Addr lanIP netip.Prefix // with host bits set (e.g. 192.168.2.1/24) nodesByIP map[netip.Addr]*node logf func(format string, args ...any) ns *stack.Stack linkEP *channel.Endpoint natStyle syncs.AtomicValue[NAT] natMu sync.Mutex // held while using + changing natTable natTable NATTable portMap map[netip.AddrPort]portMapping // WAN ip:port -> LAN ip:port portMapFlow map[portmapFlowKey]netip.AddrPort // (lanAP, peerWANAP) -> portmapped wanAP // writers is a map of MAC -> networkWriters to write packets to that MAC. // It contains entries for connected nodes only. writers syncs.Map[MAC, networkWriter] // MAC -> to networkWriter for that MAC } // Regsiters a writerFunc for a MAC address. // raddr is and optional outbound socket address of the client interface for dgram mode. // Pass nil for the writerFunc to deregister the writer. func (n *network) registerWriter(mac MAC, raddr *net.UnixAddr, interfaceID int, wf writerFunc) { if wf != nil { n.writers.Store(mac, networkWriter{ writer: wf, addr: raddr, interfaceID: interfaceID, }) } else { n.writers.Delete(mac) } } func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) { if n.lanIP.Addr() == ip { return n.mac, true } if n, ok := n.nodesByIP[ip]; ok { return n.mac, true } return MAC{}, false } type node struct { mac MAC num int // 1-based node number interfaceID int net *network lanIP netip.Addr // must be in net.lanIP prefix + unique in net verboseSyslog bool // logMu guards logBuf. // TODO(bradfitz): conditionally write these out to separate files at the end? // Currently they only hold logcatcher logs. logMu sync.Mutex logBuf bytes.Buffer logCatcherWrites int } // String returns the string "nodeN" where N is the 1-based node number. func (n *node) String() string { return fmt.Sprintf("node%d", n.num) } type derpServer struct { srv *derp.Server handler http.Handler tlsConfig *tls.Config } func newDERPServer() *derpServer { // Just to get a self-signed TLS cert: ts := httptest.NewTLSServer(nil) ts.Close() ds := &derpServer{ srv: derp.NewServer(key.NewNode(), logger.Discard), tlsConfig: ts.TLS, // self-signed; test client configure to not check } var mux http.ServeMux mux.Handle("/derp", derphttp.Handler(ds.srv)) mux.HandleFunc("/generate_204", derphttp.ServeNoContent) ds.handler = &mux return ds } type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc shuttingDown atomic.Bool wg sync.WaitGroup blendReality bool derpIPs set.Set[netip.Addr] nodes []*node nodeByMAC map[MAC]*node networks set.Set[*network] networkByWAN map[netip.Addr]*network control *testcontrol.Server derps []*derpServer pcapWriter *pcapWriter mu sync.Mutex agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all agentDialer map[*node]DialFunc } type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) var derpMap = &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { RegionID: 1, RegionCode: "atlantis", RegionName: "Atlantis", Nodes: []*tailcfg.DERPNode{ { Name: "1a", RegionID: 1, HostName: "derp1.tailscale", IPv4: fakeDERP1.v4.String(), IPv6: fakeDERP1.v6.String(), InsecureForTests: true, CanPort80: true, }, }, }, 2: { RegionID: 2, RegionCode: "northpole", RegionName: "North Pole", Nodes: []*tailcfg.DERPNode{ { Name: "2a", RegionID: 2, HostName: "derp2.tailscale", IPv4: fakeDERP2.v4.String(), IPv6: fakeDERP2.v6.String(), InsecureForTests: true, CanPort80: true, }, }, }, }, } func New(c *Config) (*Server, error) { ctx, cancel := context.WithCancel(context.Background()) s := &Server{ shutdownCtx: ctx, shutdownCancel: cancel, control: &testcontrol.Server{ DERPMap: derpMap, ExplicitBaseURL: "http://control.tailscale", }, derpIPs: set.Of[netip.Addr](), nodeByMAC: map[MAC]*node{}, networkByWAN: map[netip.Addr]*network{}, networks: set.Of[*network](), } for range 2 { s.derps = append(s.derps, newDERPServer()) } if err := s.initFromConfig(c); err != nil { return nil, err } for n := range s.networks { if err := n.initStack(); err != nil { return nil, fmt.Errorf("newServer: initStack: %v", err) } } return s, nil } func (s *Server) Close() { if shutdown := s.shuttingDown.Swap(true); !shutdown { s.shutdownCancel() s.pcapWriter.Close() } s.wg.Wait() } func (s *Server) HWAddr(mac MAC) net.HardwareAddr { // TODO: cache return net.HardwareAddr(mac[:]) } // IPv4ForDNS returns the IP address for the given DNS query name (for IPv4 A // queries only). func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) { if v, ok := vips[qname]; ok { return v.v4, v.v4.IsValid() } return netip.Addr{}, false } type Protocol int const ( ProtocolQEMU = Protocol(iota + 1) ProtocolUnixDGRAM // for macOS Virtualization.Framework and VZFileHandleNetworkDeviceAttachment ) // Handles a single connection from a QEMU-style client or muxd connections for dgram mode func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { if s.shuttingDown.Load() { return } s.wg.Add(1) defer s.wg.Done() context.AfterFunc(s.shutdownCtx, func() { uc.SetDeadline(time.Now()) }) log.Printf("Got conn %T %p", uc, uc) defer uc.Close() bw := bufio.NewWriterSize(uc, 2<<10) var writeMu sync.Mutex writePkt := func(pkt []byte, raddr *net.UnixAddr, interfaceID int) { if pkt == nil { return } writeMu.Lock() defer writeMu.Unlock() switch proto { case ProtocolQEMU: hdr := binary.BigEndian.AppendUint32(bw.AvailableBuffer()[:0], uint32(len(pkt))) if _, err := bw.Write(hdr); err != nil { log.Printf("Write hdr: %v", err) return } if _, err := bw.Write(pkt); err != nil { log.Printf("Write pkt: %v", err) return } case ProtocolUnixDGRAM: if raddr == nil { log.Printf("Write pkt: dgram mode write failure, no outbound socket address") return } if _, err := uc.WriteToUnix(pkt, raddr); err != nil { log.Printf("Write pkt : %v", err) return } } if err := bw.Flush(); err != nil { log.Printf("Flush: %v", err) } must.Do(s.pcapWriter.WritePacket(gopacket.CaptureInfo{ Timestamp: time.Now(), CaptureLength: len(pkt), Length: len(pkt), InterfaceIndex: interfaceID, }, pkt)) } buf := make([]byte, 16<<10) for { var packetRaw []byte var raddr *net.UnixAddr switch proto { case ProtocolUnixDGRAM: n, addr, err := uc.ReadFromUnix(buf) raddr = addr if err != nil { log.Printf("ReadFromUnix: %v", err) continue } packetRaw = buf[:n] case ProtocolQEMU: if _, err := io.ReadFull(uc, buf[:4]); err != nil { if s.shutdownCtx.Err() != nil { // Return without logging. return } log.Printf("ReadFull header: %v", err) return } n := binary.BigEndian.Uint32(buf[:4]) if _, err := io.ReadFull(uc, buf[4:4+n]); err != nil { if s.shutdownCtx.Err() != nil { // Return without logging. return } log.Printf("ReadFull pkt: %v", err) return } packetRaw = buf[4 : 4+n] // raw ethernet frame } packet := gopacket.NewPacket(packetRaw, layers.LayerTypeEthernet, gopacket.Lazy) le, ok := packet.LinkLayer().(*layers.Ethernet) if !ok || len(le.SrcMAC) != 6 || len(le.DstMAC) != 6 { continue } ep := EthernetPacket{le, packet} srcMAC := ep.SrcMAC() srcNode, ok := s.nodeByMAC[srcMAC] if !ok { log.Printf("[conn %p] got frame from unknown MAC %v", uc, srcMAC) continue } // Register a writer for the source MAC address if one doesn't exist. if _, ok := srcNode.net.writers.Load(srcMAC); !ok { log.Printf("[conn %p] Registering writer for MAC %v is node %v", uc, srcMAC, srcNode.lanIP) srcNode.net.registerWriter(srcMAC, raddr, srcNode.interfaceID, writePkt) defer func() { srcNode.net.registerWriter(srcMAC, nil, 0, nil) }() continue } must.Do(s.pcapWriter.WritePacket(gopacket.CaptureInfo{ Timestamp: time.Now(), CaptureLength: len(packetRaw), Length: len(packetRaw), InterfaceIndex: srcNode.interfaceID, }, packetRaw)) srcNode.net.HandleEthernetPacket(ep) } } func (s *Server) routeUDPPacket(up UDPPacket) { // Find which network owns this based on the destination IP // and all the known networks' wan IPs. // But certain things (like STUN) we do in-process. if up.Dst.Port() == stunPort { // TODO(bradfitz): fake latency; time.AfterFunc the response if res, ok := makeSTUNReply(up); ok { //log.Printf("STUN reply: %+v", res) s.routeUDPPacket(res) } else { log.Printf("weird: STUN packet not handled") } return } dstIP := up.Dst.Addr() netw, ok := s.networkByWAN[dstIP] if !ok { if dstIP.IsPrivate() { // Not worth spamming logs. RFC 1918 space doesn't route. return } log.Printf("no network to route UDP packet for %v", up.Dst) return } netw.HandleUDPPacket(up) } // writeEth writes a raw Ethernet frame to all (0, 1, or multiple) connected // clients on the network. // // This only delivers to client devices and not the virtual router/gateway // device. func (n *network) writeEth(res []byte) { if len(res) < 12 { return } dstMAC := MAC(res[0:6]) srcMAC := MAC(res[6:12]) if dstMAC.IsBroadcast() { n.writers.Range(func(mac MAC, nw networkWriter) bool { nw.write(res) return true }) return } if srcMAC == dstMAC { n.logf("dropping write of packet from %v to itself", srcMAC) return } if nw, ok := n.writers.Load(dstMAC); ok { nw.write(res) return } } func (n *network) HandleEthernetPacket(ep EthernetPacket) { packet := ep.gp dstMAC := ep.DstMAC() isBroadcast := dstMAC.IsBroadcast() forRouter := dstMAC == n.mac || isBroadcast switch ep.le.EthernetType { default: n.logf("Dropping non-IP packet: %v", ep.le.EthernetType) return case layers.EthernetTypeARP: res, err := n.createARPResponse(packet) if err != nil { n.logf("createARPResponse: %v", err) } else { n.writeEth(res) } return case layers.EthernetTypeIPv6: // One day. Low value for now. IPv4 NAT modes is the main thing // this project wants to test. return case layers.EthernetTypeIPv4: // Below } // Send ethernet broadcasts and unicast ethernet frames to peers // on the same network. This is all LAN traffic that isn't meant // for the router/gw itself: n.writeEth(ep.gp.Data()) if forRouter { n.HandleEthernetIPv4PacketForRouter(ep) } } // HandleUDPPacket handles a UDP packet arriving from the internet, // addressed to the router's WAN IP. It is then NATed back to a // LAN IP here and wrapped in an ethernet layer and delivered // to the network. func (n *network) HandleUDPPacket(p UDPPacket) { buf, err := n.serializedUDPPacket(p.Src, p.Dst, p.Payload, nil) if err != nil { n.logf("serializing UDP packet: %v", err) return } n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ Timestamp: time.Now(), CaptureLength: len(buf), Length: len(buf), InterfaceIndex: n.wanInterfaceID, }, buf) dst := n.doNATIn(p.Src, p.Dst) if !dst.IsValid() { n.logf("Warning: NAT dropped packet; no mapping for %v=>%v", p.Src, p.Dst) return } p.Dst = dst buf, err = n.serializedUDPPacket(p.Src, p.Dst, p.Payload, nil) if err != nil { n.logf("serializing UDP packet: %v", err) return } n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ Timestamp: time.Now(), CaptureLength: len(buf), Length: len(buf), InterfaceIndex: n.lanInterfaceID, }, buf) n.WriteUDPPacketNoNAT(p) } // WriteUDPPacketNoNAT writes a UDP packet to the network, without // doing any NAT translation. // // The packet will always have the ethernet src MAC of the router // so this should not be used for packets between clients on the // same ethernet segment. func (n *network) WriteUDPPacketNoNAT(p UDPPacket) { src, dst := p.Src, p.Dst node, ok := n.nodesByIP[dst.Addr()] if !ok { n.logf("no node for dest IP %v in UDP packet %v=>%v", dst.Addr(), p.Src, p.Dst) return } eth := &layers.Ethernet{ SrcMAC: n.mac.HWAddr(), // of gateway DstMAC: node.mac.HWAddr(), EthernetType: layers.EthernetTypeIPv4, } ethRaw, err := n.serializedUDPPacket(src, dst, p.Payload, eth) if err != nil { n.logf("serializing UDP packet: %v", err) return } n.writeEth(ethRaw) } // serializedUDPPacket serializes a UDP packet with the given source and // destination IP:port pairs, and payload. // // If eth is non-nil, it will be used as the Ethernet layer, otherwise the // Ethernet layer will be omitted from the serialization. func (n *network) serializedUDPPacket(src, dst netip.AddrPort, payload []byte, eth *layers.Ethernet) ([]byte, error) { ip := &layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, SrcIP: src.Addr().AsSlice(), DstIP: dst.Addr().AsSlice(), } udp := &layers.UDP{ SrcPort: layers.UDPPort(src.Port()), DstPort: layers.UDPPort(dst.Port()), } udp.SetNetworkLayerForChecksum(ip) buffer := gopacket.NewSerializeBuffer() options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} layers := []gopacket.SerializableLayer{eth, ip, udp, gopacket.Payload(payload)} if eth == nil { layers = layers[1:] } if err := gopacket.SerializeLayers(buffer, options, layers...); err != nil { return nil, fmt.Errorf("serializing UDP: %v", err) } return buffer.Bytes(), nil } // HandleEthernetIPv4PacketForRouter handles an IPv4 packet that is // directed to the router/gateway itself. The packet may be to the // broadcast MAC address, or to the router's MAC address. The target // IP may be the router's IP, or an internet (routed) IP. func (n *network) HandleEthernetIPv4PacketForRouter(ep EthernetPacket) { packet := ep.gp v4, ok := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if !ok { return } srcIP, _ := netip.AddrFromSlice(v4.SrcIP) dstIP, _ := netip.AddrFromSlice(v4.DstIP) toForward := dstIP != n.lanIP.Addr() && dstIP != netip.IPv4Unspecified() udp, isUDP := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) if isDHCPRequest(packet) { res, err := n.s.createDHCPResponse(packet) if err != nil { n.logf("createDHCPResponse: %v", err) return } n.writeEth(res) return } if isMDNSQuery(packet) || isIGMP(packet) { // Don't log. Spammy for now. return } if isDNSRequest(packet) { // TODO(bradfitz): restrict this to 4.11.4.11? add DNS // on gateway instead? res, err := n.s.createDNSResponse(packet) if err != nil { n.logf("createDNSResponse: %v", err) return } n.writeEth(res) return } if isUDP && fakeSyslog.Match(dstIP) { node, ok := n.nodesByIP[srcIP] if !ok { return } if node.verboseSyslog { // TODO(bradfitz): parse this and capture it, structured, into // node's log buffer. log.Printf("syslog from %v: %s", node, udp.Payload) } return } if !toForward && isNATPMP(packet) { n.handleNATPMPRequest(UDPPacket{ Src: netip.AddrPortFrom(srcIP, uint16(udp.SrcPort)), Dst: netip.AddrPortFrom(dstIP, uint16(udp.DstPort)), Payload: udp.Payload, }) return } if toForward && isUDP { src := netip.AddrPortFrom(srcIP, uint16(udp.SrcPort)) dst := netip.AddrPortFrom(dstIP, uint16(udp.DstPort)) buf, err := n.serializedUDPPacket(src, dst, udp.Payload, nil) if err != nil { n.logf("serializing UDP packet: %v", err) return } n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ Timestamp: time.Now(), CaptureLength: len(buf), Length: len(buf), InterfaceIndex: n.lanInterfaceID, }, buf) lanSrc := src // the original src, before NAT (for logging only) src = n.doNATOut(src, dst) if !src.IsValid() { n.logf("warning: NAT dropped packet; no NAT out mapping for %v=>%v", lanSrc, dst) return } buf, err = n.serializedUDPPacket(src, dst, udp.Payload, nil) if err != nil { n.logf("serializing UDP packet: %v", err) return } n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ Timestamp: time.Now(), CaptureLength: len(buf), Length: len(buf), InterfaceIndex: n.wanInterfaceID, }, buf) n.s.routeUDPPacket(UDPPacket{ Src: src, Dst: dst, Payload: udp.Payload, }) return } if toForward && n.s.shouldInterceptTCP(packet) { ipp := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) pktCopy := make([]byte, 0, len(ipp.Contents)+len(ipp.Payload)) pktCopy = append(pktCopy, ipp.Contents...) pktCopy = append(pktCopy, ipp.Payload...) packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(pktCopy), }) n.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) packetBuf.DecRef() return } if isUDP && (udp.DstPort == pcpPort || udp.DstPort == ssdpPort) { // We handle NAT-PMP, but not these yet. // TODO(bradfitz): handle? marginal utility so far. // Don't log about them being unknown. return } //log.Printf("Got packet: %v", packet) } func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { ethLayer := request.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) srcMAC, ok := macOf(ethLayer.SrcMAC) if !ok { return nil, nil } node, ok := s.nodeByMAC[srcMAC] if !ok { log.Printf("DHCP request from unknown node %v; ignoring", srcMAC) return nil, nil } gwIP := node.net.lanIP.Addr() ipLayer := request.Layer(layers.LayerTypeIPv4).(*layers.IPv4) udpLayer := request.Layer(layers.LayerTypeUDP).(*layers.UDP) dhcpLayer := request.Layer(layers.LayerTypeDHCPv4).(*layers.DHCPv4) response := &layers.DHCPv4{ Operation: layers.DHCPOpReply, HardwareType: layers.LinkTypeEthernet, HardwareLen: 6, Xid: dhcpLayer.Xid, ClientHWAddr: dhcpLayer.ClientHWAddr, Flags: dhcpLayer.Flags, YourClientIP: node.lanIP.AsSlice(), Options: []layers.DHCPOption{ { Type: layers.DHCPOptServerID, Data: gwIP.AsSlice(), // DHCP server's IP Length: 4, }, }, } var msgType layers.DHCPMsgType for _, opt := range dhcpLayer.Options { if opt.Type == layers.DHCPOptMessageType && opt.Length > 0 { msgType = layers.DHCPMsgType(opt.Data[0]) } } switch msgType { case layers.DHCPMsgTypeDiscover: response.Options = append(response.Options, layers.DHCPOption{ Type: layers.DHCPOptMessageType, Data: []byte{byte(layers.DHCPMsgTypeOffer)}, Length: 1, }) case layers.DHCPMsgTypeRequest: response.Options = append(response.Options, layers.DHCPOption{ Type: layers.DHCPOptMessageType, Data: []byte{byte(layers.DHCPMsgTypeAck)}, Length: 1, }, layers.DHCPOption{ Type: layers.DHCPOptLeaseTime, Data: binary.BigEndian.AppendUint32(nil, 3600), // hour? sure. Length: 4, }, layers.DHCPOption{ Type: layers.DHCPOptRouter, Data: gwIP.AsSlice(), Length: 4, }, layers.DHCPOption{ Type: layers.DHCPOptDNS, Data: fakeDNS.v4.AsSlice(), Length: 4, }, layers.DHCPOption{ Type: layers.DHCPOptSubnetMask, Data: net.CIDRMask(node.net.lanIP.Bits(), 32), Length: 4, }, ) } eth := &layers.Ethernet{ SrcMAC: node.net.mac.HWAddr(), DstMAC: ethLayer.SrcMAC, EthernetType: layers.EthernetTypeIPv4, } ip := &layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, SrcIP: ipLayer.DstIP, DstIP: ipLayer.SrcIP, } udp := &layers.UDP{ SrcPort: udpLayer.DstPort, DstPort: udpLayer.SrcPort, } udp.SetNetworkLayerForChecksum(ip) buffer := gopacket.NewSerializeBuffer() options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} if err := gopacket.SerializeLayers(buffer, options, eth, ip, udp, response, ); err != nil { return nil, err } return buffer.Bytes(), nil } func isDHCPRequest(pkt gopacket.Packet) bool { v4, ok := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if !ok || v4.Protocol != layers.IPProtocolUDP { return false } udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) return ok && udp.DstPort == 67 && udp.SrcPort == 68 } func isIGMP(pkt gopacket.Packet) bool { return pkt.Layer(layers.LayerTypeIGMP) != nil } func isMDNSQuery(pkt gopacket.Packet) bool { udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) // TODO(bradfitz): also check IPv4 DstIP=224.0.0.251 (or whatever) return ok && udp.SrcPort == 5353 && udp.DstPort == 5353 } func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool { tcp, ok := pkt.Layer(layers.LayerTypeTCP).(*layers.TCP) if !ok { return false } ipv4, ok := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if !ok { return false } if tcp.DstPort == 123 { // Test port for TCP interception. Not really useful, but cute for // demos. return true } dstIP, _ := netip.AddrFromSlice(ipv4.DstIP.To4()) if tcp.DstPort == 80 || tcp.DstPort == 443 { for _, v := range []virtualIP{fakeControl, fakeDERP1, fakeDERP2, fakeLogCatcher} { if v.Match(dstIP) { return true } } if fakeProxyControlplane.Match(dstIP) { return s.blendReality } if s.derpIPs.Contains(dstIP) { return true } } if tcp.DstPort == 8008 && fakeTestAgent.Match(dstIP) { // Connection from cmd/tta. return true } return false } // isDNSRequest reports whether pkt is a DNS request to the fake DNS server. func isDNSRequest(pkt gopacket.Packet) bool { udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) if !ok || udp.DstPort != 53 { return false } ip, ok := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if !ok { return false } dstIP, ok := netip.AddrFromSlice(ip.DstIP) if !ok || !fakeDNS.Match(dstIP) { return false } dns, ok := pkt.Layer(layers.LayerTypeDNS).(*layers.DNS) return ok && dns.QR == false && len(dns.Questions) > 0 } func isNATPMP(pkt gopacket.Packet) bool { udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) return ok && udp.DstPort == 5351 && len(udp.Payload) > 0 && udp.Payload[0] == 0 // version 0, not 2 for PCP } func makeSTUNReply(req UDPPacket) (res UDPPacket, ok bool) { txid, err := stun.ParseBindingRequest(req.Payload) if err != nil { log.Printf("invalid STUN request: %v", err) return res, false } return UDPPacket{ Src: req.Dst, Dst: req.Src, Payload: stun.Response(txid, req.Src), }, true } func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) { ethLayer := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) ipLayer := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) udpLayer := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) dnsLayer := pkt.Layer(layers.LayerTypeDNS).(*layers.DNS) if dnsLayer.OpCode != layers.DNSOpCodeQuery || dnsLayer.QR || len(dnsLayer.Questions) == 0 { return nil, nil } response := &layers.DNS{ ID: dnsLayer.ID, QR: true, AA: true, TC: false, RD: dnsLayer.RD, RA: true, OpCode: layers.DNSOpCodeQuery, ResponseCode: layers.DNSResponseCodeNoErr, } var names []string for _, q := range dnsLayer.Questions { response.QDCount++ response.Questions = append(response.Questions, q) if mem.HasSuffix(mem.B(q.Name), mem.S(".pool.ntp.org")) { // Just drop DNS queries for NTP servers. For Debian/etc guests used // during development. Not needed. Assume VM guests get correct time // via their hypervisor. return nil, nil } names = append(names, q.Type.String()+"/"+string(q.Name)) if q.Class != layers.DNSClassIN || q.Type != layers.DNSTypeA { continue } if ip, ok := s.IPv4ForDNS(string(q.Name)); ok { response.ANCount++ response.Answers = append(response.Answers, layers.DNSResourceRecord{ Name: q.Name, Type: q.Type, Class: q.Class, IP: ip.AsSlice(), TTL: 60, }) } } eth2 := &layers.Ethernet{ SrcMAC: ethLayer.DstMAC, DstMAC: ethLayer.SrcMAC, EthernetType: layers.EthernetTypeIPv4, } ip2 := &layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, SrcIP: ipLayer.DstIP, DstIP: ipLayer.SrcIP, } udp2 := &layers.UDP{ SrcPort: udpLayer.DstPort, DstPort: udpLayer.SrcPort, } udp2.SetNetworkLayerForChecksum(ip2) buffer := gopacket.NewSerializeBuffer() options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} if err := gopacket.SerializeLayers(buffer, options, eth2, ip2, udp2, response); err != nil { return nil, err } const debugDNS = false if debugDNS { if len(response.Answers) > 0 { back := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Lazy) log.Printf("Generated: %v", back) } else { log.Printf("made empty response for %q", names) } } return buffer.Bytes(), nil } // doNATOut performs NAT on an outgoing packet from src to dst, where // src is a LAN IP and dst is a WAN IP. // // It returns the source WAN ip:port to use. // // If newSrc is invalid, the packet should be dropped. func (n *network) doNATOut(src, dst netip.AddrPort) (newSrc netip.AddrPort) { n.natMu.Lock() defer n.natMu.Unlock() // First see if there's a port mapping, before doing NAT. if wanAP, ok := n.portMapFlow[portmapFlowKey{ peerWAN: dst, lanAP: src, }]; ok { return wanAP } return n.natTable.PickOutgoingSrc(src, dst, time.Now()) } type portmapFlowKey struct { peerWAN netip.AddrPort // the peer's WAN ip:port lanAP netip.AddrPort } // doNATIn performs NAT on an incoming packet from WAN src to WAN dst, returning // a new destination LAN ip:port to use. // // If newDst is invalid, the packet should be dropped. func (n *network) doNATIn(src, dst netip.AddrPort) (newDst netip.AddrPort) { n.natMu.Lock() defer n.natMu.Unlock() now := time.Now() // First see if there's a port mapping, before doing NAT. if lanAP, ok := n.portMap[dst]; ok { if now.Before(lanAP.expiry) { mak.Set(&n.portMapFlow, portmapFlowKey{ peerWAN: src, lanAP: lanAP.dst, }, dst) //n.logf("NAT: doNatIn: port mapping %v=>%v", dst, lanAP.dst) return lanAP.dst } n.logf("NAT: doNatIn: port mapping EXPIRED for %v=>%v", dst, lanAP.dst) delete(n.portMap, dst) return netip.AddrPort{} } return n.natTable.PickIncomingDst(src, dst, now) } // IsPublicPortUsed reports whether the given public port is currently in use. // // n.natMu must be held by the caller. (It's only called by nat implementations // which are always called with natMu held)) func (n *network) IsPublicPortUsed(ap netip.AddrPort) bool { _, ok := n.portMap[ap] return ok } func (n *network) doPortMap(src netip.Addr, dstLANPort, wantExtPort uint16, sec int) (gotPort uint16, ok bool) { n.natMu.Lock() defer n.natMu.Unlock() if !n.portmap { return 0, false } wanAP := netip.AddrPortFrom(n.wanIP, wantExtPort) dst := netip.AddrPortFrom(src, dstLANPort) if sec == 0 { lanAP, ok := n.portMap[wanAP] if ok && lanAP.dst.Addr() == src { delete(n.portMap, wanAP) } return 0, false } // See if they already have a mapping and extend expiry if so. for k, v := range n.portMap { if v.dst == dst { n.portMap[k] = portMapping{ dst: dst, expiry: time.Now().Add(time.Duration(sec) * time.Second), } return k.Port(), true } } for try := 0; try < 20_000; try++ { if wanAP.Port() > 0 && !n.natTable.IsPublicPortUsed(wanAP) { mak.Set(&n.portMap, wanAP, portMapping{ dst: dst, expiry: time.Now().Add(time.Duration(sec) * time.Second), }) n.logf("vnet: allocated NAT mapping from %v to %v", wanAP, dst) return wanAP.Port(), true } wantExtPort = rand.N(uint16(32<<10)) + 32<<10 wanAP = netip.AddrPortFrom(n.wanIP, wantExtPort) } return 0, false } func (n *network) createARPResponse(pkt gopacket.Packet) ([]byte, error) { ethLayer, ok := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) if !ok { return nil, nil } arpLayer, ok := pkt.Layer(layers.LayerTypeARP).(*layers.ARP) if !ok || arpLayer.Operation != layers.ARPRequest || arpLayer.AddrType != layers.LinkTypeEthernet || arpLayer.Protocol != layers.EthernetTypeIPv4 || arpLayer.HwAddressSize != 6 || arpLayer.ProtAddressSize != 4 || len(arpLayer.DstProtAddress) != 4 { return nil, nil } wantIP := netip.AddrFrom4([4]byte(arpLayer.DstProtAddress)) foundMAC, ok := n.MACOfIP(wantIP) if !ok { return nil, nil } eth := &layers.Ethernet{ SrcMAC: foundMAC.HWAddr(), DstMAC: ethLayer.SrcMAC, EthernetType: layers.EthernetTypeARP, } a2 := &layers.ARP{ AddrType: layers.LinkTypeEthernet, Protocol: layers.EthernetTypeIPv4, HwAddressSize: 6, ProtAddressSize: 4, Operation: layers.ARPReply, SourceHwAddress: foundMAC.HWAddr(), SourceProtAddress: arpLayer.DstProtAddress, DstHwAddress: ethLayer.SrcMAC, DstProtAddress: arpLayer.SourceProtAddress, } buffer := gopacket.NewSerializeBuffer() options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} if err := gopacket.SerializeLayers(buffer, options, eth, a2); err != nil { return nil, err } return buffer.Bytes(), nil } func (n *network) handleNATPMPRequest(req UDPPacket) { if !n.portmap { return } if string(req.Payload) == "\x00\x00" { // https://www.rfc-editor.org/rfc/rfc6886#section-3.2 res := make([]byte, 0, 12) res = append(res, 0, // version 0 (NAT-PMP) 128, // response to op 0 (128+0) 0, 0, // result code success ) res = binary.BigEndian.AppendUint32(res, uint32(time.Now().Unix())) wan4 := n.wanIP.As4() res = append(res, wan4[:]...) n.WriteUDPPacketNoNAT(UDPPacket{ Src: req.Dst, Dst: req.Src, Payload: res, }) return } // Map UDP request if len(req.Payload) == 12 && req.Payload[0] == 0 && req.Payload[1] == 1 { // https://www.rfc-editor.org/rfc/rfc6886#section-3.3 // "00 01 00 00 ed 40 00 00 00 00 1c 20" => // 00 ver // 01 op=map UDP // 00 00 reserved (0 in request; in response, this is the result code) // ed 40 internal port 60736 // 00 00 suggested external port // 00 00 1c 20 suggested lifetime in seconds (7200 sec = 2 hours) internalPort := binary.BigEndian.Uint16(req.Payload[4:6]) wantExtPort := binary.BigEndian.Uint16(req.Payload[6:8]) lifetimeSec := binary.BigEndian.Uint32(req.Payload[8:12]) gotPort, ok := n.doPortMap(req.Src.Addr(), internalPort, wantExtPort, int(lifetimeSec)) if !ok { n.logf("NAT-PMP map request for %v:%d failed", req.Src.Addr(), internalPort) return } res := make([]byte, 0, 16) res = append(res, 0, // version 0 (NAT-PMP) 1+128, // response to op 1 0, 0, // result code success ) res = binary.BigEndian.AppendUint32(res, uint32(time.Now().Unix())) res = binary.BigEndian.AppendUint16(res, internalPort) res = binary.BigEndian.AppendUint16(res, gotPort) res = binary.BigEndian.AppendUint32(res, lifetimeSec) n.WriteUDPPacketNoNAT(UDPPacket{ Src: req.Dst, Dst: req.Src, Payload: res, }) return } n.logf("TODO: handle NAT-PMP packet % 02x", req.Payload) } // UDPPacket is a UDP packet. // // For the purposes of this project, a UDP packet // (not a general IP packet) is the unit to be NAT'ed, // as that's all that Tailscale uses. type UDPPacket struct { Src netip.AddrPort Dst netip.AddrPort Payload []byte // everything after UDP header } func (s *Server) WriteStartingBanner(w io.Writer) { fmt.Fprintf(w, "vnet serving clients:\n") for _, n := range s.nodes { fmt.Fprintf(w, " %v %15v (%v, %v)\n", n.mac, n.lanIP, n.net.wanIP, n.net.natStyle.Load()) } } type agentConn struct { node *node tc *gonet.TCPConn } func (s *Server) addIdleAgentConn(ac *agentConn) { //log.Printf("got agent conn from %v", ac.node.mac) s.mu.Lock() defer s.mu.Unlock() s.agentConns.Make() s.agentConns.Add(ac) if waiter, ok := s.agentConnWaiter[ac.node]; ok { select { case waiter <- struct{}{}: default: } } } func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok bool) { for { ac, ok := s.takeAgentConnOne(n) if ok { //log.Printf("got agent conn for %v", n.mac) return ac, true } s.mu.Lock() ready := make(chan struct{}) mak.Set(&s.agentConnWaiter, n, ready) s.mu.Unlock() //log.Printf("waiting for agent conn for %v", n.mac) select { case <-ctx.Done(): return nil, false case <-ready: case <-time.After(time.Second): // Try again regularly anyway, in case we have multiple clients // trying to hit the same node, or if a race means we weren't in the // select by the time addIdleAgentConn tried to signal us. } } } func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { s.mu.Lock() defer s.mu.Unlock() for ac := range s.agentConns { if ac.node == n { s.agentConns.Delete(ac) return ac, true } } return nil, false } type NodeAgentClient struct { *tailscale.LocalClient HTTPClient *http.Client } func (s *Server) NodeAgentDialer(n *Node) DialFunc { s.mu.Lock() defer s.mu.Unlock() if d, ok := s.agentDialer[n.n]; ok { return d } d := func(ctx context.Context, network, addr string) (net.Conn, error) { ac, ok := s.takeAgentConn(ctx, n.n) if !ok { return nil, ctx.Err() } return ac.tc, nil } mak.Set(&s.agentDialer, n.n, d) return d } func (s *Server) NodeAgentClient(n *Node) *NodeAgentClient { d := s.NodeAgentDialer(n) return &NodeAgentClient{ LocalClient: &tailscale.LocalClient{ UseSocketOnly: true, OmitAuth: true, Dial: d, }, HTTPClient: &http.Client{ Transport: &http.Transport{ DialContext: d, }, }, } } // EnableHostFirewall enables the host's stateful firewall. func (c *NodeAgentClient) EnableHostFirewall(ctx context.Context) error { req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/fw", nil) if err != nil { return err } res, err := c.HTTPClient.Do(req) if err != nil { return err } defer res.Body.Close() all, _ := io.ReadAll(res.Body) if res.StatusCode != 200 { return fmt.Errorf("unexpected status code %v: %s", res.Status, all) } return nil }