From 4c80344e27d0d44121bff0980f2df98a2bf1244c Mon Sep 17 00:00:00 2001 From: Naman Sood Date: Mon, 8 Mar 2021 13:43:01 -0500 Subject: [PATCH] wgengine/netstack: stop UDP forwarding when one side dies Updates #504 Updates #707 Signed-off-by: Naman Sood --- wgengine/netstack/netstack.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 01fe8b094..61b1cdb9b 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -419,7 +419,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA } ctx, cancel := context.WithCancel(context.Background()) timer := time.AfterFunc(2*time.Minute, func() { - ns.logf("netstack: forwarder UDP connection on port %v closed", port) + ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr) cancel() client.Close() backendConn.Close() @@ -427,16 +427,17 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA extend := func() { timer.Reset(2 * time.Minute) } - startPacketCopy(ctx, client, &net.UDPAddr{ + startPacketCopy(ctx, cancel, client, &net.UDPAddr{ IP: net.ParseIP(clientRemoteAddr.Addr.String()), Port: int(clientRemoteAddr.Port), }, backendConn, ns.logf, extend) - startPacketCopy(ctx, backendConn, backendRemoteAddr, client, ns.logf, extend) + startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend) } -func startPacketCopy(ctx context.Context, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) { +func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) { go func() { + defer cancel() // tear down the other direction's copy pkt := make([]byte, mtu) for { select { @@ -457,7 +458,9 @@ func startPacketCopy(ctx context.Context, dst net.PacketConn, dstAddr net.Addr, } return } - logf("[v2] wrote UDP packet %s -> %s", srcAddr, dstAddr) + if debugNetstack { + logf("[v2] wrote UDP packet %s -> %s", srcAddr, dstAddr) + } extend() } }