wgengine/netstack: implement UDP relaying to advertised subnets

TCP was done in 662fbd4a09.

This does the same for UDP.

Tested by hand. Integration tests will have to come later. I'd wanted
to do it in this commit, but the SOCKS5 server needed for interop
testing between two userspace nodes doesn't yet support UDP and I
didn't want to invent some whole new userspace packet injection
interface at this point, as SOCKS seems like a better route, but
that's its own bug.

Fixes #2302

RELNOTE=netstack mode can now UDP relay to subnets

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/2490/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 3daf27eaad
commit 95a9adbb97

@ -136,6 +136,24 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
return ns, nil
}
// wrapProtoHandler returns protocol handler h wrapped in a version
// that dynamically reconfigures ns's subnet addresses as needed for
// outbound traffic.
func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, *stack.PacketBuffer) bool) func(stack.TransportEndpointID, *stack.PacketBuffer) bool {
return func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool {
addr := tei.LocalAddress
ip, ok := netaddr.FromStdIP(net.IP(addr))
if !ok {
ns.logf("netstack: could not parse local address for incoming connection")
return false
}
if !ns.isLocalIP(ip) {
ns.addSubnetAddress(ip)
}
return h(tei, pb)
}
}
// Start sets up all the handlers so netstack can start working. Implements
// wgengine.FakeImpl.
func (ns *Impl) Start() error {
@ -145,25 +163,8 @@ func (ns *Impl) Start() error {
const maxInFlightConnectionAttempts = 16
tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool {
addr := tei.LocalAddress
var pn tcpip.NetworkProtocolNumber
if addr.To4() != "" {
pn = ipv4.ProtocolNumber
} else {
pn = ipv6.ProtocolNumber
}
ip, ok := netaddr.FromStdIP(net.IP(addr))
if !ok {
ns.logf("netstack: could not parse local address %s for incoming TCP connection", ip)
return false
}
if !ns.isLocalIP(ip) {
ns.addSubnetAddress(pn, ip)
}
return tcpFwd.HandlePacket(tei, pb)
})
ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapProtoHandler(tcpFwd.HandlePacket))
ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapProtoHandler(udpFwd.HandlePacket))
go ns.injectOutbound()
ns.tundev.PostFilterIn = ns.injectInbound
return nil
@ -214,13 +215,19 @@ func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
ns.dns = DNSMapFromNetworkMap(nm)
}
func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, ip netaddr.IP) {
func (ns *Impl) addSubnetAddress(ip netaddr.IP) {
ns.mu.Lock()
ns.connsOpenBySubnetIP[ip]++
needAdd := ns.connsOpenBySubnetIP[ip] == 1
ns.mu.Unlock()
// Only register address into netstack for first concurrent connection.
if needAdd {
var pn tcpip.NetworkProtocolNumber
if ip.Is4() {
pn = ipv4.ProtocolNumber
} else if ip.Is6() {
pn = ipv6.ProtocolNumber
}
ns.ipstack.AddAddress(nicID, pn, tcpip.Address(ip.IPAddr().IP))
}
}
@ -543,9 +550,9 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcp
}
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
reqDetails := r.ID()
sess := r.ID()
if debugNetstack {
ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(reqDetails))
ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(sess))
}
var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq)
@ -553,30 +560,50 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
ns.logf("acceptUDP: could not create endpoint: %v", err)
return
}
localAddr, err := ep.GetLocalAddress()
if err != nil {
dstAddr, ok := ipPortOfNetstackAddr(sess.LocalAddress, sess.LocalPort)
if !ok {
return
}
remoteAddr, err := ep.GetRemoteAddress()
if err != nil {
srcAddr, ok := ipPortOfNetstackAddr(sess.RemoteAddress, sess.RemotePort)
if !ok {
return
}
c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
go ns.forwardUDP(c, &wq, localAddr, remoteAddr)
go ns.forwardUDP(c, &wq, srcAddr, dstAddr)
}
func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) {
port := clientLocalAddr.Port
// forwardUDP proxies between client (with addr clientAddr) and dstAddr.
//
// dstAddr may be either a local Tailscale IP, in which we case we proxy to
// 127.0.0.1, or any other IP (from an advertised subnet), in which case we
// proxy to it directly.
func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientAddr, dstAddr netaddr.IPPort) {
port, srcPort := dstAddr.Port(), clientAddr.Port()
ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port)
backendListenAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(clientRemoteAddr.Port)}
backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
backendConn, err := net.ListenUDP("udp4", backendListenAddr)
var backendListenAddr *net.UDPAddr
var backendRemoteAddr *net.UDPAddr
isLocal := ns.isLocalIP(dstAddr.IP())
if isLocal {
backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)}
} else {
backendRemoteAddr = dstAddr.UDPAddr()
if dstAddr.IP().Is4() {
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: int(srcPort)}
} else {
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("::"), Port: int(srcPort)}
}
}
backendConn, err := net.ListenUDP("udp", backendListenAddr)
if err != nil {
ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err)
ns.logf("netstack: could not bind local port %v: %v, trying again with random port", backendListenAddr.Port, err)
backendListenAddr.Port = 0
backendConn, err = net.ListenUDP("udp4", backendListenAddr)
backendConn, err = net.ListenUDP("udp", backendListenAddr)
if err != nil {
ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err)
ns.logf("netstack: could not create UDP socket, preventing forwarding to %v: %v", dstAddr, err)
return
}
}
@ -585,28 +612,47 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA
if !ok {
ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port)
}
clientRemoteIP, _ := netaddr.FromStdIP(net.ParseIP(clientRemoteAddr.Addr.String()))
ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
if isLocal {
ns.e.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.IP())
}
ctx, cancel := context.WithCancel(context.Background())
timer := time.AfterFunc(2*time.Minute, func() {
idleTimeout := 2 * time.Minute
if port == 53 {
// Make DNS packet copies time out much sooner.
//
// TODO(bradfitz): make DNS queries over UDP forwarding even
// cheaper by adding an additional idleTimeout post-DNS-reply.
// For instance, after the DNS response goes back out, then only
// wait a few seconds (or zero, really)
idleTimeout = 30 * time.Second
}
timer := time.AfterFunc(idleTimeout, func() {
if isLocal {
ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr)
}
ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr)
cancel()
client.Close()
backendConn.Close()
})
extend := func() {
timer.Reset(2 * time.Minute)
timer.Reset(idleTimeout)
}
startPacketCopy(ctx, cancel, client, &net.UDPAddr{
IP: net.ParseIP(clientRemoteAddr.Addr.String()),
Port: int(clientRemoteAddr.Port),
}, backendConn, ns.logf, extend)
startPacketCopy(ctx, cancel, client, clientAddr.UDPAddr(), backendConn, ns.logf, extend)
startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend)
if isLocal {
// Wait for the copies to be done before decrementing the
// subnet address count to potentially remove the route.
<-ctx.Done()
ns.removeSubnetAddress(dstAddr.IP())
}
}
func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) {
if debugNetstack {
logf("[v2] netstack: startPacketCopy to %v (%T) from %T", dstAddr, dst, src)
}
go func() {
defer cancel() // tear down the other direction's copy
pkt := make([]byte, mtu)
@ -643,3 +689,7 @@ func stringifyTEI(tei stack.TransportEndpointID) string {
remoteHostPort := net.JoinHostPort(tei.RemoteAddress.String(), strconv.Itoa(int(tei.RemotePort)))
return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort)
}
func ipPortOfNetstackAddr(a tcpip.Address, port uint16) (ipp netaddr.IPPort, ok bool) {
return netaddr.FromStdAddr(net.IP(a), int(port), "") // TODO(bradfitz): can do without allocs
}

Loading…
Cancel
Save