diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index db315d949..4a5befa1d 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -22,7 +22,6 @@ import ( "log" "net" "strconv" - "tailscale.com/syncs" "time" "tailscale.com/types/logger" @@ -151,7 +150,7 @@ type Conn struct { request *request udpClientAddr net.Addr - udpTargetConns syncs.Map[string, net.Conn] + udpTargetConns map[socksAddr]net.Conn } // Run starts the new connection. @@ -311,17 +310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // close all target udp connections when the client connection is closed - defer func() { - c.udpTargetConns.Range(func(_ string, conn net.Conn) bool { - _ = conn.Close() - return true - }) - }() - // client -> target go func() { defer cancel() + + c.udpTargetConns = make(map[socksAddr]net.Conn) + // close all target udp connections when the client connection is closed + defer func() { + for _, conn := range c.udpTargetConns { + _ = conn.Close() + } + }() + buf := make([]byte, bufferSize) for { select { @@ -354,33 +354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er func (c *Conn) getOrDialTargetConn( ctx context.Context, clientConn net.PacketConn, - targetAddr string, + targetAddr socksAddr, ) (net.Conn, error) { - host, port, err := splitHostPort(targetAddr) - if err != nil { - return nil, err - } - - conn, loaded := c.udpTargetConns.Load(targetAddr) - if loaded { + conn, exist := c.udpTargetConns[targetAddr] + if exist { return conn, nil } - conn, err = c.srv.dial(ctx, "udp", targetAddr) + conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort()) if err != nil { return nil, err } - c.udpTargetConns.Store(targetAddr, conn) + c.udpTargetConns[targetAddr] = conn // target -> client go func() { buf := make([]byte, bufferSize) - addr := socksAddr{addrType: getAddrType(host), addr: host, port: port} for { select { case <-ctx.Done(): return default: - err := c.handleUDPResponse(clientConn, addr, conn, buf) + err := c.handleUDPResponse(clientConn, targetAddr, conn, buf) if err != nil { if isTimeout(err) { continue @@ -414,18 +408,17 @@ func (c *Conn) handleUDPRequest( return fmt.Errorf("parse udp request: %w", err) } - targetAddr := req.addr.hostPort() - targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr) + targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr) if err != nil { - return fmt.Errorf("dial target %s fail: %w", targetAddr, err) + return fmt.Errorf("dial target %s fail: %w", req.addr, err) } nn, err := targetConn.Write(data) if err != nil { - return fmt.Errorf("write to target %s fail: %w", targetAddr, err) + return fmt.Errorf("write to target %s fail: %w", req.addr, err) } if nn != len(data) { - return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite) + return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite) } return nil } @@ -652,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) { pkt = binary.BigEndian.AppendUint16(pkt, s.port) return pkt, nil } + func (s socksAddr) hostPort() string { return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) } +func (s socksAddr) String() string { + return s.hostPort() +} + // response contains the contents of // a response packet sent from the proxy // to the client. diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go index 11ea59d4b..bc6fac79f 100644 --- a/net/socks5/socks5_test.go +++ b/net/socks5/socks5_test.go @@ -169,12 +169,25 @@ func TestReadPassword(t *testing.T) { func TestUDP(t *testing.T) { // backend UDP server which we'll use SOCKS5 to connect to - listener, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) + newUDPEchoServer := func() net.PacketConn { + listener, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + go udpEchoServer(listener) + return listener } - backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port - go udpEchoServer(listener) + + const echoServerNumber = 3 + echoServerListener := make([]net.PacketConn, echoServerNumber) + for i := 0; i < echoServerNumber; i++ { + echoServerListener[i] = newUDPEchoServer() + } + defer func() { + for i := 0; i < echoServerNumber; i++ { + _ = echoServerListener[i].Close() + } + }() // SOCKS5 server socks5, err := net.Listen("tcp", ":0") @@ -184,84 +197,93 @@ func TestUDP(t *testing.T) { socks5Port := socks5.Addr().(*net.TCPAddr).Port go socks5Server(socks5) - // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) - if err != nil { - t.Fatal(err) - } - _, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth - if err != nil { - t.Fatal(err) - } - buf := make([]byte, 1024) - n, err := conn.Read(buf) // server hello - if err != nil { - t.Fatal(err) - } - if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 { - t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) - } + // make a socks5 udpAssociate conn + newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) { + // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) + if err != nil { + t.Fatal(err) + } + _, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) // server hello + if err != nil { + t.Fatal(err) + } + if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired { + t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) + } - targetAddr := socksAddr{ - addrType: domainName, - addr: "localhost", - port: uint16(backendServerPort), - } - targetAddrPkt, err := targetAddr.marshal() - if err != nil { - t.Fatal(err) - } - _, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust - if err != nil { - t.Fatal(err) - } + targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} + targetAddrPkt, err := targetAddr.marshal() + if err != nil { + t.Fatal(err) + } + _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust + if err != nil { + t.Fatal(err) + } - n, err = conn.Read(buf) // server response - if err != nil { - t.Fatal(err) - } - if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) { - t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + n, err = conn.Read(buf) // server response + if err != nil { + t.Fatal(err) + } + if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) { + t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + } + udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) + if err != nil { + t.Fatal(err) + } + + return conn, udpProxySocksAddr } - udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) - if err != nil { - t.Fatal(err) + + conn, udpProxySocksAddr := newUdpAssociateConn() + defer conn.Close() + + sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) { + udpPayload, err := (&udpRequest{addr: addr}).marshal() + if err != nil { + t.Fatal(err) + } + udpPayload = append(udpPayload, body...) + _, err = socks5UDPConn.Write(udpPayload) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := socks5UDPConn.Read(buf) + if err != nil { + t.Fatal(err) + } + _, responseBody, err = parseUDPRequest(buf[:n]) + if err != nil { + t.Fatal(err) + } + return responseBody } udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort()) if err != nil { t.Fatal(err) } - udpConn, err := net.DialUDP("udp", nil, udpProxyAddr) - if err != nil { - t.Fatal(err) - } - udpPayload, err := (&udpRequest{addr: targetAddr}).marshal() - if err != nil { - t.Fatal(err) - } - udpPayload = append(udpPayload, []byte("Test")...) - _, err = udpConn.Write(udpPayload) // send udp package - if err != nil { - t.Fatal(err) - } - n, _, err = udpConn.ReadFrom(buf) - if err != nil { - t.Fatal(err) - } - _, responseBody, err := parseUDPRequest(buf[:n]) // read udp response - if err != nil { - t.Fatal(err) - } - if string(responseBody) != "Test" { - t.Fatalf("got: %q want: Test", responseBody) - } - err = udpConn.Close() + socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr) if err != nil { t.Fatal(err) } - err = conn.Close() - if err != nil { - t.Fatal(err) + defer socks5UDPConn.Close() + + for i := 0; i < echoServerNumber; i++ { + port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port + addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)} + requestBody := []byte(fmt.Sprintf("Test %d", i)) + responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody) + if !bytes.Equal(requestBody, responseBody) { + t.Fatalf("got: %q want: %q", responseBody, requestBody) + } } }