diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index 0d651537f..db315d949 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -22,6 +22,7 @@ import ( "log" "net" "strconv" + "tailscale.com/syncs" "time" "tailscale.com/types/logger" @@ -81,6 +82,12 @@ const ( addrTypeNotSupported replyCode = 8 ) +// UDP conn default buffer size and read timeout. +const ( + bufferSize = 8 * 1024 + readTimeout = 5 * time.Second +) + // Server is a SOCKS5 proxy server. type Server struct { // Logf optionally specifies the logger to use. @@ -143,7 +150,8 @@ type Conn struct { clientConn net.Conn request *request - udpClientAddr net.Addr + udpClientAddr net.Addr + udpTargetConns syncs.Map[string, net.Conn] } // Run starts the new connection. @@ -276,15 +284,6 @@ func (c *Conn) handleUDP() error { } defer clientUDPConn.Close() - serverUDPConn, err := net.ListenPacket("udp", "[::]:0") - if err != nil { - res := errorResponse(generalFailure) - buf, _ := res.marshal() - c.clientConn.Write(buf) - return err - } - defer serverUDPConn.Close() - bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) if err != nil { return err @@ -305,14 +304,20 @@ func (c *Conn) handleUDP() error { } c.clientConn.Write(buf) - return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn) + return c.transferUDP(c.clientConn, clientUDPConn) } -func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error { +func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - const bufferSize = 8 * 1024 - const readTimeout = 5 * time.Second + + // close all target udp connections when the client connection is closed + defer func() { + c.udpTargetConns.Range(func(_ string, conn net.Conn) bool { + _ = conn.Close() + return true + }) + }() // client -> target go func() { @@ -323,7 +328,7 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta case <-ctx.Done(): return default: - err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout) + err := c.handleUDPRequest(ctx, clientConn, buf) if err != nil { if isTimeout(err) { continue @@ -337,21 +342,50 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() + // A UDP association terminates when the TCP connection that the UDP + // ASSOCIATE request arrived on terminates. RFC1928 + _, err := io.Copy(io.Discard, associatedTCP) + if err != nil { + err = fmt.Errorf("udp associated tcp conn: %w", err) + } + return err +} + +func (c *Conn) getOrDialTargetConn( + ctx context.Context, + clientConn net.PacketConn, + targetAddr string, +) (net.Conn, error) { + host, port, err := splitHostPort(targetAddr) + if err != nil { + return nil, err + } + + conn, loaded := c.udpTargetConns.Load(targetAddr) + if loaded { + return conn, nil + } + conn, err = c.srv.dial(ctx, "udp", targetAddr) + if err != nil { + return nil, err + } + c.udpTargetConns.Store(targetAddr, conn) + // target -> client go func() { - defer cancel() buf := make([]byte, bufferSize) + addr := socksAddr{addrType: getAddrType(host), addr: host, port: port} for { select { case <-ctx.Done(): return default: - err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout) + err := c.handleUDPResponse(clientConn, addr, conn, buf) if err != nil { if isTimeout(err) { continue } - if errors.Is(err, net.ErrClosed) { + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { return } c.logf("udp transfer: handle udp response fail: %v", err) @@ -360,20 +394,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() - // A UDP association terminates when the TCP connection that the UDP - // ASSOCIATE request arrived on terminates. RFC1928 - _, err := io.Copy(io.Discard, associatedTCP) - if err != nil { - err = fmt.Errorf("udp associated tcp conn: %w", err) - } - return err + return conn, nil } func (c *Conn) handleUDPRequest( + ctx context.Context, clientConn net.PacketConn, - targetConn net.PacketConn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) @@ -386,12 +413,14 @@ func (c *Conn) handleUDPRequest( if err != nil { return fmt.Errorf("parse udp request: %w", err) } - targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort()) + + targetAddr := req.addr.hostPort() + targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr) if err != nil { - c.logf("resolve target addr fail: %v", err) + return fmt.Errorf("dial target %s fail: %w", targetAddr, err) } - nn, err := targetConn.WriteTo(data, targetAddr) + nn, err := targetConn.Write(data) if err != nil { return fmt.Errorf("write to target %s fail: %w", targetAddr, err) } @@ -402,22 +431,18 @@ func (c *Conn) handleUDPRequest( } func (c *Conn) handleUDPResponse( - targetConn net.PacketConn, clientConn net.PacketConn, + targetAddr socksAddr, + targetConn net.Conn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) - n, addr, err := targetConn.ReadFrom(buf) + n, err := targetConn.Read(buf) if err != nil { return fmt.Errorf("read from target: %w", err) } - host, port, err := splitHostPort(addr.String()) - if err != nil { - return fmt.Errorf("split host port: %w", err) - } - hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}} + hdr := udpRequest{addr: targetAddr} pkt, err := hdr.marshal() if err != nil { return fmt.Errorf("marshal udp request: %w", err)