diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go index e9cfc7a46..8dbaad111 100644 --- a/tstest/integration/nat/nat_test.go +++ b/tstest/integration/nat/nat_test.go @@ -94,6 +94,13 @@ func easy(c *vnet.Config) *vnet.Node { fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT)) } +func easyPMP(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT, vnet.NATPMP)) +} + func hard(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -161,7 +168,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { cmd := exec.Command("qemu-system-x86_64", "-M", "microvm,isa-serial=off", - "-m", "1G", + "-m", "384M", "-nodefaults", "-no-user-config", "-nographic", "-kernel", nt.kernel, "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1", @@ -252,7 +259,7 @@ func streamDaemonLogs(ctx context.Context, t testing.TB, c *vnet.NodeAgentClient Text string `json:"text"` } if err := dec.Decode(&logEntry); err != nil { - if err == io.EOF { + if err == io.EOF || errors.Is(err, context.Canceled) { return } t.Errorf("log entry: %v", err) @@ -324,3 +331,8 @@ func TestEasyHardPMP(t *testing.T) { nt := newNatTest(t) nt.runTest(easy, hardPMP) } + +func TestEasyPMPHard(t *testing.T) { + nt := newNatTest(t) + nt.runTest(easyPMP, hard) +} diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index caed96770..f920f20a9 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -412,10 +412,11 @@ type network struct { ns *stack.Stack linkEP *channel.Endpoint - natStyle syncs.AtomicValue[NAT] - natMu sync.Mutex // held while using + changing natTable - natTable NATTable - portMap map[netip.AddrPort]portMapping // WAN ip:port -> LAN ip:port + natStyle syncs.AtomicValue[NAT] + natMu sync.Mutex // held while using + changing natTable + natTable NATTable + portMap map[netip.AddrPort]portMapping // WAN ip:port -> LAN ip:port + portMapFlow map[portmapFlowKey]netip.AddrPort // (lanAP, peerWANAP) -> portmapped wanAP // writeFunc is a map of MAC -> func to write to that MAC. // It contains entries for connected nodes only. @@ -1197,13 +1198,27 @@ func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) { // doNATOut performs NAT on an outgoing packet from src to dst, where // src is a LAN IP and dst is a WAN IP. // -// It returns the souce WAN ip:port to use. +// It returns the source WAN ip:port to use. func (n *network) doNATOut(src, dst netip.AddrPort) (newSrc netip.AddrPort) { n.natMu.Lock() defer n.natMu.Unlock() + + // First see if there's a port mapping, before doing NAT. + if wanAP, ok := n.portMapFlow[portmapFlowKey{ + peerWAN: dst, + lanAP: src, + }]; ok { + return wanAP + } + return n.natTable.PickOutgoingSrc(src, dst, time.Now()) } +type portmapFlowKey struct { + peerWAN netip.AddrPort // the peer's WAN ip:port + lanAP netip.AddrPort +} + // doNATIn performs NAT on an incoming packet from WAN src to WAN dst, returning // a new destination LAN ip:port to use. func (n *network) doNATIn(src, dst netip.AddrPort) (newDst netip.AddrPort) { @@ -1215,6 +1230,10 @@ func (n *network) doNATIn(src, dst netip.AddrPort) (newDst netip.AddrPort) { // First see if there's a port mapping, before doing NAT. if lanAP, ok := n.portMap[dst]; ok { if now.Before(lanAP.expiry) { + mak.Set(&n.portMapFlow, portmapFlowKey{ + peerWAN: src, + lanAP: lanAP.dst, + }, dst) n.logf("XXX NAT: doNatIn: port mapping %v=>%v", dst, lanAP.dst) return lanAP.dst }