@ -22,6 +22,7 @@ import (
"log"
"log"
"net"
"net"
"strconv"
"strconv"
"tailscale.com/syncs"
"time"
"time"
"tailscale.com/types/logger"
"tailscale.com/types/logger"
@ -81,6 +82,12 @@ const (
addrTypeNotSupported replyCode = 8
addrTypeNotSupported replyCode = 8
)
)
// UDP conn default buffer size and read timeout.
const (
bufferSize = 8 * 1024
readTimeout = 5 * time . Second
)
// Server is a SOCKS5 proxy server.
// Server is a SOCKS5 proxy server.
type Server struct {
type Server struct {
// Logf optionally specifies the logger to use.
// Logf optionally specifies the logger to use.
@ -143,7 +150,8 @@ type Conn struct {
clientConn net . Conn
clientConn net . Conn
request * request
request * request
udpClientAddr net . Addr
udpClientAddr net . Addr
udpTargetConns syncs . Map [ string , net . Conn ]
}
}
// Run starts the new connection.
// Run starts the new connection.
@ -276,15 +284,6 @@ func (c *Conn) handleUDP() error {
}
}
defer clientUDPConn . Close ( )
defer clientUDPConn . Close ( )
serverUDPConn , err := net . ListenPacket ( "udp" , "[::]:0" )
if err != nil {
res := errorResponse ( generalFailure )
buf , _ := res . marshal ( )
c . clientConn . Write ( buf )
return err
}
defer serverUDPConn . Close ( )
bindAddr , bindPort , err := splitHostPort ( clientUDPConn . LocalAddr ( ) . String ( ) )
bindAddr , bindPort , err := splitHostPort ( clientUDPConn . LocalAddr ( ) . String ( ) )
if err != nil {
if err != nil {
return err
return err
@ -305,14 +304,20 @@ func (c *Conn) handleUDP() error {
}
}
c . clientConn . Write ( buf )
c . clientConn . Write ( buf )
return c . transferUDP ( c . clientConn , clientUDPConn , serverUDPConn )
return c . transferUDP ( c . clientConn , clientUDPConn )
}
}
func ( c * Conn ) transferUDP ( associatedTCP net . Conn , clientConn net . PacketConn , targetConn net . PacketConn ) error {
func ( c * Conn ) transferUDP ( associatedTCP net . Conn , clientConn net . PacketConn ) error {
ctx , cancel := context . WithCancel ( context . Background ( ) )
ctx , cancel := context . WithCancel ( context . Background ( ) )
defer cancel ( )
defer cancel ( )
const bufferSize = 8 * 1024
const readTimeout = 5 * time . Second
// close all target udp connections when the client connection is closed
defer func ( ) {
c . udpTargetConns . Range ( func ( _ string , conn net . Conn ) bool {
_ = conn . Close ( )
return true
} )
} ( )
// client -> target
// client -> target
go func ( ) {
go func ( ) {
@ -323,7 +328,7 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
case <- ctx . Done ( ) :
case <- ctx . Done ( ) :
return
return
default :
default :
err := c . handleUDPRequest ( c lientConn, targetConn , buf , readTimeout )
err := c . handleUDPRequest ( c tx, clientConn , buf )
if err != nil {
if err != nil {
if isTimeout ( err ) {
if isTimeout ( err ) {
continue
continue
@ -337,21 +342,50 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
}
}
} ( )
} ( )
// A UDP association terminates when the TCP connection that the UDP
// ASSOCIATE request arrived on terminates. RFC1928
_ , err := io . Copy ( io . Discard , associatedTCP )
if err != nil {
err = fmt . Errorf ( "udp associated tcp conn: %w" , err )
}
return err
}
func ( c * Conn ) getOrDialTargetConn (
ctx context . Context ,
clientConn net . PacketConn ,
targetAddr string ,
) ( net . Conn , error ) {
host , port , err := splitHostPort ( targetAddr )
if err != nil {
return nil , err
}
conn , loaded := c . udpTargetConns . Load ( targetAddr )
if loaded {
return conn , nil
}
conn , err = c . srv . dial ( ctx , "udp" , targetAddr )
if err != nil {
return nil , err
}
c . udpTargetConns . Store ( targetAddr , conn )
// target -> client
// target -> client
go func ( ) {
go func ( ) {
defer cancel ( )
buf := make ( [ ] byte , bufferSize )
buf := make ( [ ] byte , bufferSize )
addr := socksAddr { addrType : getAddrType ( host ) , addr : host , port : port }
for {
for {
select {
select {
case <- ctx . Done ( ) :
case <- ctx . Done ( ) :
return
return
default :
default :
err := c . handleUDPResponse ( targetConn , clientConn , buf , readTimeout )
err := c . handleUDPResponse ( clientConn, addr , conn , buf )
if err != nil {
if err != nil {
if isTimeout ( err ) {
if isTimeout ( err ) {
continue
continue
}
}
if errors . Is ( err , net . ErrClosed ) {
if errors . Is ( err , net . ErrClosed ) || errors . Is ( err , io . EOF ) {
return
return
}
}
c . logf ( "udp transfer: handle udp response fail: %v" , err )
c . logf ( "udp transfer: handle udp response fail: %v" , err )
@ -360,20 +394,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
}
}
} ( )
} ( )
// A UDP association terminates when the TCP connection that the UDP
return conn , nil
// ASSOCIATE request arrived on terminates. RFC1928
_ , err := io . Copy ( io . Discard , associatedTCP )
if err != nil {
err = fmt . Errorf ( "udp associated tcp conn: %w" , err )
}
return err
}
}
func ( c * Conn ) handleUDPRequest (
func ( c * Conn ) handleUDPRequest (
ctx context . Context ,
clientConn net . PacketConn ,
clientConn net . PacketConn ,
targetConn net . PacketConn ,
buf [ ] byte ,
buf [ ] byte ,
readTimeout time . Duration ,
) error {
) error {
// add a deadline for the read to avoid blocking forever
// add a deadline for the read to avoid blocking forever
_ = clientConn . SetReadDeadline ( time . Now ( ) . Add ( readTimeout ) )
_ = clientConn . SetReadDeadline ( time . Now ( ) . Add ( readTimeout ) )
@ -386,12 +413,14 @@ func (c *Conn) handleUDPRequest(
if err != nil {
if err != nil {
return fmt . Errorf ( "parse udp request: %w" , err )
return fmt . Errorf ( "parse udp request: %w" , err )
}
}
targetAddr , err := net . ResolveUDPAddr ( "udp" , req . addr . hostPort ( ) )
targetAddr := req . addr . hostPort ( )
targetConn , err := c . getOrDialTargetConn ( ctx , clientConn , targetAddr )
if err != nil {
if err != nil {
c . logf ( "resolve target addr fail: %v" , err )
return fmt . Errorf ( "dial target %s fail: %w" , targetAddr , err )
}
}
nn , err := targetConn . Write To ( data , targetAddr )
nn , err := targetConn . Write ( data )
if err != nil {
if err != nil {
return fmt . Errorf ( "write to target %s fail: %w" , targetAddr , err )
return fmt . Errorf ( "write to target %s fail: %w" , targetAddr , err )
}
}
@ -402,22 +431,18 @@ func (c *Conn) handleUDPRequest(
}
}
func ( c * Conn ) handleUDPResponse (
func ( c * Conn ) handleUDPResponse (
targetConn net . PacketConn ,
clientConn net . PacketConn ,
clientConn net . PacketConn ,
targetAddr socksAddr ,
targetConn net . Conn ,
buf [ ] byte ,
buf [ ] byte ,
readTimeout time . Duration ,
) error {
) error {
// add a deadline for the read to avoid blocking forever
// add a deadline for the read to avoid blocking forever
_ = targetConn . SetReadDeadline ( time . Now ( ) . Add ( readTimeout ) )
_ = targetConn . SetReadDeadline ( time . Now ( ) . Add ( readTimeout ) )
n , addr, err := targetConn . Read From ( buf )
n , err := targetConn . Read ( buf )
if err != nil {
if err != nil {
return fmt . Errorf ( "read from target: %w" , err )
return fmt . Errorf ( "read from target: %w" , err )
}
}
host , port , err := splitHostPort ( addr . String ( ) )
hdr := udpRequest { addr : targetAddr }
if err != nil {
return fmt . Errorf ( "split host port: %w" , err )
}
hdr := udpRequest { addr : socksAddr { addrType : getAddrType ( host ) , addr : host , port : port } }
pkt , err := hdr . marshal ( )
pkt , err := hdr . marshal ( )
if err != nil {
if err != nil {
return fmt . Errorf ( "marshal udp request: %w" , err )
return fmt . Errorf ( "marshal udp request: %w" , err )