net/socks5: optimize UDP relay

Key changes:
- No mutex for every udp package: replace syncs.Map with regular map for udpTargetConns
- Use socksAddr as map key for better type safety
- Add test for multi udp target

Updates #7581

Change-Id: Ic3d384a9eab62dcbf267d7d6d268bf242cc8ed3c
Signed-off-by: VimT <me@vimt.me>
pull/14009/head
VimT 2 months ago committed by Brad Fitzpatrick
parent b0626ff84c
commit 43138c7a5c

@ -22,7 +22,6 @@ import (
"log" "log"
"net" "net"
"strconv" "strconv"
"tailscale.com/syncs"
"time" "time"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -151,7 +150,7 @@ type Conn struct {
request *request request *request
udpClientAddr net.Addr udpClientAddr net.Addr
udpTargetConns syncs.Map[string, net.Conn] udpTargetConns map[socksAddr]net.Conn
} }
// Run starts the new connection. // Run starts the new connection.
@ -311,17 +310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
// client -> target
go func() {
defer cancel()
c.udpTargetConns = make(map[socksAddr]net.Conn)
// close all target udp connections when the client connection is closed // close all target udp connections when the client connection is closed
defer func() { defer func() {
c.udpTargetConns.Range(func(_ string, conn net.Conn) bool { for _, conn := range c.udpTargetConns {
_ = conn.Close() _ = conn.Close()
return true }
})
}() }()
// client -> target
go func() {
defer cancel()
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
select { select {
@ -354,33 +354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
func (c *Conn) getOrDialTargetConn( func (c *Conn) getOrDialTargetConn(
ctx context.Context, ctx context.Context,
clientConn net.PacketConn, clientConn net.PacketConn,
targetAddr string, targetAddr socksAddr,
) (net.Conn, error) { ) (net.Conn, error) {
host, port, err := splitHostPort(targetAddr) conn, exist := c.udpTargetConns[targetAddr]
if err != nil { if exist {
return nil, err
}
conn, loaded := c.udpTargetConns.Load(targetAddr)
if loaded {
return conn, nil return conn, nil
} }
conn, err = c.srv.dial(ctx, "udp", targetAddr) conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.udpTargetConns.Store(targetAddr, conn) c.udpTargetConns[targetAddr] = conn
// target -> client // target -> client
go func() { go func() {
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
default: default:
err := c.handleUDPResponse(clientConn, addr, conn, buf) err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
if err != nil { if err != nil {
if isTimeout(err) { if isTimeout(err) {
continue continue
@ -414,18 +408,17 @@ func (c *Conn) handleUDPRequest(
return fmt.Errorf("parse udp request: %w", err) return fmt.Errorf("parse udp request: %w", err)
} }
targetAddr := req.addr.hostPort() targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
if err != nil { if err != nil {
return fmt.Errorf("dial target %s fail: %w", targetAddr, err) return fmt.Errorf("dial target %s fail: %w", req.addr, err)
} }
nn, err := targetConn.Write(data) nn, err := targetConn.Write(data)
if err != nil { if err != nil {
return fmt.Errorf("write to target %s fail: %w", targetAddr, err) return fmt.Errorf("write to target %s fail: %w", req.addr, err)
} }
if nn != len(data) { if nn != len(data) {
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite) return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
} }
return nil return nil
} }
@ -652,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) {
pkt = binary.BigEndian.AppendUint16(pkt, s.port) pkt = binary.BigEndian.AppendUint16(pkt, s.port)
return pkt, nil return pkt, nil
} }
func (s socksAddr) hostPort() string { func (s socksAddr) hostPort() string {
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
} }
func (s socksAddr) String() string {
return s.hostPort()
}
// response contains the contents of // response contains the contents of
// a response packet sent from the proxy // a response packet sent from the proxy
// to the client. // to the client.

@ -169,12 +169,25 @@ func TestReadPassword(t *testing.T) {
func TestUDP(t *testing.T) { func TestUDP(t *testing.T) {
// backend UDP server which we'll use SOCKS5 to connect to // backend UDP server which we'll use SOCKS5 to connect to
newUDPEchoServer := func() net.PacketConn {
listener, err := net.ListenPacket("udp", ":0") listener, err := net.ListenPacket("udp", ":0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
go udpEchoServer(listener) go udpEchoServer(listener)
return listener
}
const echoServerNumber = 3
echoServerListener := make([]net.PacketConn, echoServerNumber)
for i := 0; i < echoServerNumber; i++ {
echoServerListener[i] = newUDPEchoServer()
}
defer func() {
for i := 0; i < echoServerNumber; i++ {
_ = echoServerListener[i].Close()
}
}()
// SOCKS5 server // SOCKS5 server
socks5, err := net.Listen("tcp", ":0") socks5, err := net.Listen("tcp", ":0")
@ -184,12 +197,14 @@ func TestUDP(t *testing.T) {
socks5Port := socks5.Addr().(*net.TCPAddr).Port socks5Port := socks5.Addr().(*net.TCPAddr).Port
go socks5Server(socks5) go socks5Server(socks5)
// make a socks5 udpAssociate conn
newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) {
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth _, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -198,20 +213,16 @@ func TestUDP(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 { if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired {
t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
} }
targetAddr := socksAddr{ targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
addrType: domainName,
addr: "localhost",
port: uint16(backendServerPort),
}
targetAddrPkt, err := targetAddr.marshal() targetAddrPkt, err := targetAddr.marshal()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -220,7 +231,7 @@ func TestUDP(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) { if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) {
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
} }
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
@ -228,40 +239,51 @@ func TestUDP(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort()) return conn, udpProxySocksAddr
if err != nil {
t.Fatal(err)
} }
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
if err != nil { conn, udpProxySocksAddr := newUdpAssociateConn()
t.Fatal(err) defer conn.Close()
}
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal() sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) {
udpPayload, err := (&udpRequest{addr: addr}).marshal()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
udpPayload = append(udpPayload, []byte("Test")...) udpPayload = append(udpPayload, body...)
_, err = udpConn.Write(udpPayload) // send udp package _, err = socks5UDPConn.Write(udpPayload)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
n, _, err = udpConn.ReadFrom(buf) buf := make([]byte, 1024)
n, err := socks5UDPConn.Read(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response _, responseBody, err = parseUDPRequest(buf[:n])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(responseBody) != "Test" { return responseBody
t.Fatalf("got: %q want: Test", responseBody)
} }
err = udpConn.Close()
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = conn.Close() socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer socks5UDPConn.Close()
for i := 0; i < echoServerNumber; i++ {
port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port
addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)}
requestBody := []byte(fmt.Sprintf("Test %d", i))
responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody)
if !bytes.Equal(requestBody, responseBody) {
t.Fatalf("got: %q want: %q", responseBody, requestBody)
}
}
} }

Loading…
Cancel
Save