net/socks5: support UDP

Updates #7581

Signed-off-by: VimT <me@vimt.me>
pull/13042/head
VimT 5 months ago committed by Brad Fitzpatrick
parent 91d2e1772d
commit e3f047618b

@ -13,8 +13,10 @@
package socks5 package socks5
import ( import (
"bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -121,7 +123,7 @@ func (s *Server) Serve(l net.Listener) error {
} }
go func() { go func() {
defer c.Close() defer c.Close()
conn := &Conn{clientConn: c, srv: s} conn := &Conn{logf: s.Logf, clientConn: c, srv: s}
err := conn.Run() err := conn.Run()
if err != nil { if err != nil {
s.logf("client connection failed: %v", err) s.logf("client connection failed: %v", err)
@ -136,9 +138,12 @@ type Conn struct {
// The struct is filled by each of the internal // The struct is filled by each of the internal
// methods in turn as the transaction progresses. // methods in turn as the transaction progresses.
logf logger.Logf
srv *Server srv *Server
clientConn net.Conn clientConn net.Conn
request *request request *request
udpClientAddr net.Addr
} }
// Run starts the new connection. // Run starts the new connection.
@ -172,58 +177,59 @@ func (c *Conn) Run() error {
func (c *Conn) handleRequest() error { func (c *Conn) handleRequest() error {
req, err := parseClientRequest(c.clientConn) req, err := parseClientRequest(c.clientConn)
if err != nil { if err != nil {
res := &response{reply: generalFailure} res := errorResponse(generalFailure)
buf, _ := res.marshal() buf, _ := res.marshal()
c.clientConn.Write(buf) c.clientConn.Write(buf)
return err return err
} }
if req.command != connect {
res := &response{reply: commandNotSupported} c.request = req
switch req.command {
case connect:
return c.handleTCP()
case udpAssociate:
return c.handleUDP()
default:
res := errorResponse(commandNotSupported)
buf, _ := res.marshal() buf, _ := res.marshal()
c.clientConn.Write(buf) c.clientConn.Write(buf)
return fmt.Errorf("unsupported command %v", req.command) return fmt.Errorf("unsupported command %v", req.command)
} }
c.request = req }
func (c *Conn) handleTCP() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
srv, err := c.srv.dial( srv, err := c.srv.dial(
ctx, ctx,
"tcp", "tcp",
net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))), c.request.destination.hostPort(),
) )
if err != nil { if err != nil {
res := &response{reply: generalFailure} res := errorResponse(generalFailure)
buf, _ := res.marshal() buf, _ := res.marshal()
c.clientConn.Write(buf) c.clientConn.Write(buf)
return err return err
} }
defer srv.Close() defer srv.Close()
serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
localAddr := srv.LocalAddr().String()
serverAddr, serverPort, err := splitHostPort(localAddr)
if err != nil { if err != nil {
return err return err
} }
serverPort, _ := strconv.Atoi(serverPortStr)
var bindAddrType addrType
if ip := net.ParseIP(serverAddr); ip != nil {
if ip.To4() != nil {
bindAddrType = ipv4
} else {
bindAddrType = ipv6
}
} else {
bindAddrType = domainName
}
res := &response{ res := &response{
reply: success, reply: success,
bindAddrType: bindAddrType, bindAddr: socksAddr{
bindAddr: serverAddr, addrType: getAddrType(serverAddr),
bindPort: uint16(serverPort), addr: serverAddr,
port: serverPort,
},
} }
buf, err := res.marshal() buf, err := res.marshal()
if err != nil { if err != nil {
res = &response{reply: generalFailure} res = errorResponse(generalFailure)
buf, _ = res.marshal() buf, _ = res.marshal()
} }
c.clientConn.Write(buf) c.clientConn.Write(buf)
@ -246,6 +252,208 @@ func (c *Conn) handleRequest() error {
return <-errc return <-errc
} }
func (c *Conn) handleUDP() error {
// The DST.ADDR and DST.PORT fields contain the address and port that
// the client expects to use to send UDP datagrams on for the
// association. The server MAY use this information to limit access
// to the association.
// @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
//
// We do NOT limit the access from the client currently in this implementation.
_ = c.request.destination
addr := c.clientConn.LocalAddr()
host, _, err := net.SplitHostPort(addr.String())
if err != nil {
return err
}
clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0"))
if err != nil {
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer clientUDPConn.Close()
serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
if err != nil {
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer serverUDPConn.Close()
bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
if err != nil {
return err
}
res := &response{
reply: success,
bindAddr: socksAddr{
addrType: getAddrType(bindAddr),
addr: bindAddr,
port: bindPort,
},
}
buf, err := res.marshal()
if err != nil {
res = errorResponse(generalFailure)
buf, _ = res.marshal()
}
c.clientConn.Write(buf)
return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
}
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
const bufferSize = 8 * 1024
const readTimeout = 5 * time.Second
// client -> target
go func() {
defer cancel()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) {
return
}
c.logf("udp transfer: handle udp request fail: %v", err)
}
}
}
}()
// target -> client
go func() {
defer cancel()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) {
return
}
c.logf("udp transfer: handle udp response fail: %v", err)
}
}
}
}()
// A UDP association terminates when the TCP connection that the UDP
// ASSOCIATE request arrived on terminates. RFC1928
_, err := io.Copy(io.Discard, associatedTCP)
if err != nil {
err = fmt.Errorf("udp associated tcp conn: %w", err)
}
return err
}
func (c *Conn) handleUDPRequest(
clientConn net.PacketConn,
targetConn net.PacketConn,
buf []byte,
readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
n, addr, err := clientConn.ReadFrom(buf)
if err != nil {
return fmt.Errorf("read from client: %w", err)
}
c.udpClientAddr = addr
req, data, err := parseUDPRequest(buf[:n])
if err != nil {
return fmt.Errorf("parse udp request: %w", err)
}
targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
if err != nil {
c.logf("resolve target addr fail: %v", err)
}
nn, err := targetConn.WriteTo(data, targetAddr)
if err != nil {
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
}
if nn != len(data) {
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
}
return nil
}
func (c *Conn) handleUDPResponse(
targetConn net.PacketConn,
clientConn net.PacketConn,
buf []byte,
readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
n, addr, err := targetConn.ReadFrom(buf)
if err != nil {
return fmt.Errorf("read from target: %w", err)
}
host, port, err := splitHostPort(addr.String())
if err != nil {
return fmt.Errorf("split host port: %w", err)
}
hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
pkt, err := hdr.marshal()
if err != nil {
return fmt.Errorf("marshal udp request: %w", err)
}
data := append(pkt, buf[:n]...)
// use addr from client to send back
nn, err := clientConn.WriteTo(data, c.udpClientAddr)
if err != nil {
return fmt.Errorf("write to client: %w", err)
}
if nn != len(data) {
return fmt.Errorf("write to client: %w", io.ErrShortWrite)
}
return nil
}
func isTimeout(err error) bool {
terr, ok := errors.Unwrap(err).(interface{ Timeout() bool })
return ok && terr.Timeout()
}
func splitHostPort(hostport string) (host string, port uint16, err error) {
host, portStr, err := net.SplitHostPort(hostport)
if err != nil {
return "", 0, err
}
portInt, err := strconv.Atoi(portStr)
if err != nil {
return "", 0, err
}
if portInt < 0 || portInt > 65535 {
return "", 0, fmt.Errorf("invalid port number %d", portInt)
}
return host, uint16(portInt), nil
}
// parseClientGreeting parses a request initiation packet. // parseClientGreeting parses a request initiation packet.
func parseClientGreeting(r io.Reader, authMethod byte) error { func parseClientGreeting(r io.Reader, authMethod byte) error {
var hdr [2]byte var hdr [2]byte
@ -295,123 +503,205 @@ func parseClientAuth(r io.Reader) (usr, pwd string, err error) {
return string(usrBytes), string(pwdBytes), nil return string(usrBytes), string(pwdBytes), nil
} }
func getAddrType(addr string) addrType {
if ip := net.ParseIP(addr); ip != nil {
if ip.To4() != nil {
return ipv4
}
return ipv6
}
return domainName
}
// request represents data contained within a SOCKS5 // request represents data contained within a SOCKS5
// connection request packet. // connection request packet.
type request struct { type request struct {
command commandType command commandType
destination string destination socksAddr
port uint16
destAddrType addrType
} }
// parseClientRequest converts raw packet bytes into a // parseClientRequest converts raw packet bytes into a
// SOCKS5Request struct. // SOCKS5Request struct.
func parseClientRequest(r io.Reader) (*request, error) { func parseClientRequest(r io.Reader) (*request, error) {
var hdr [4]byte var hdr [3]byte
_, err := io.ReadFull(r, hdr[:]) _, err := io.ReadFull(r, hdr[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("could not read packet header") return nil, fmt.Errorf("could not read packet header")
} }
cmd := hdr[1] cmd := hdr[1]
destAddrType := addrType(hdr[3])
var destination string destination, err := parseSocksAddr(r)
var port uint16 return &request{
command: commandType(cmd),
destination: destination,
}, err
}
type socksAddr struct {
addrType addrType
addr string
port uint16
}
var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
func parseSocksAddr(r io.Reader) (addr socksAddr, err error) {
var addrTypeData [1]byte
_, err = io.ReadFull(r, addrTypeData[:])
if err != nil {
return socksAddr{}, fmt.Errorf("could not read address type")
}
if destAddrType == ipv4 { dstAddrType := addrType(addrTypeData[0])
var destination string
switch dstAddrType {
case ipv4:
var ip [4]byte var ip [4]byte
_, err = io.ReadFull(r, ip[:]) _, err = io.ReadFull(r, ip[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("could not read IPv4 address") return socksAddr{}, fmt.Errorf("could not read IPv4 address")
} }
destination = net.IP(ip[:]).String() destination = net.IP(ip[:]).String()
} else if destAddrType == domainName { case domainName:
var dstSizeByte [1]byte var dstSizeByte [1]byte
_, err = io.ReadFull(r, dstSizeByte[:]) _, err = io.ReadFull(r, dstSizeByte[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("could not read domain name size") return socksAddr{}, fmt.Errorf("could not read domain name size")
} }
dstSize := int(dstSizeByte[0]) dstSize := int(dstSizeByte[0])
domainName := make([]byte, dstSize) domainName := make([]byte, dstSize)
_, err = io.ReadFull(r, domainName) _, err = io.ReadFull(r, domainName)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not read domain name") return socksAddr{}, fmt.Errorf("could not read domain name")
} }
destination = string(domainName) destination = string(domainName)
} else if destAddrType == ipv6 { case ipv6:
var ip [16]byte var ip [16]byte
_, err = io.ReadFull(r, ip[:]) _, err = io.ReadFull(r, ip[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("could not read IPv6 address") return socksAddr{}, fmt.Errorf("could not read IPv6 address")
} }
destination = net.IP(ip[:]).String() destination = net.IP(ip[:]).String()
} else { default:
return nil, fmt.Errorf("unsupported address type") return socksAddr{}, fmt.Errorf("unsupported address type")
} }
var portBytes [2]byte var portBytes [2]byte
_, err = io.ReadFull(r, portBytes[:]) _, err = io.ReadFull(r, portBytes[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("could not read port") return socksAddr{}, fmt.Errorf("could not read port")
} }
port = binary.BigEndian.Uint16(portBytes[:]) port := binary.BigEndian.Uint16(portBytes[:])
return socksAddr{
return &request{ addrType: dstAddrType,
command: commandType(cmd), addr: destination,
destination: destination,
port: port, port: port,
destAddrType: destAddrType,
}, nil }, nil
} }
func (s socksAddr) marshal() ([]byte, error) {
var addr []byte
switch s.addrType {
case ipv4:
addr = net.ParseIP(s.addr).To4()
if addr == nil {
return nil, fmt.Errorf("invalid IPv4 address for binding")
}
case domainName:
if len(s.addr) > 255 {
return nil, fmt.Errorf("invalid domain name for binding")
}
addr = make([]byte, 0, len(s.addr)+1)
addr = append(addr, byte(len(s.addr)))
addr = append(addr, []byte(s.addr)...)
case ipv6:
addr = net.ParseIP(s.addr).To16()
if addr == nil {
return nil, fmt.Errorf("invalid IPv6 address for binding")
}
default:
return nil, fmt.Errorf("unsupported address type")
}
pkt := []byte{byte(s.addrType)}
pkt = append(pkt, addr...)
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
return pkt, nil
}
func (s socksAddr) hostPort() string {
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
}
// response contains the contents of // response contains the contents of
// a response packet sent from the proxy // a response packet sent from the proxy
// to the client. // to the client.
type response struct { type response struct {
reply replyCode reply replyCode
bindAddrType addrType bindAddr socksAddr
bindAddr string }
bindPort uint16
func errorResponse(code replyCode) *response {
return &response{reply: code, bindAddr: zeroSocksAddr}
} }
// marshal converts a SOCKS5Response struct into // marshal converts a SOCKS5Response struct into
// a packet. If res.reply == Success, it may throw an error on // a packet. If res.reply == Success, it may throw an error on
// receiving an invalid bind address. Otherwise, it will not throw. // receiving an invalid bind address. Otherwise, it will not throw.
func (res *response) marshal() ([]byte, error) { func (res *response) marshal() ([]byte, error) {
pkt := make([]byte, 4) pkt := make([]byte, 3)
pkt[0] = socks5Version pkt[0] = socks5Version
pkt[1] = byte(res.reply) pkt[1] = byte(res.reply)
pkt[2] = 0 // null reserved byte pkt[2] = 0 // null reserved byte
pkt[3] = byte(res.bindAddrType)
if res.reply != success { addrPkt, err := res.bindAddr.marshal()
return pkt, nil if err != nil {
return nil, err
} }
var addr []byte return append(pkt, addrPkt...), nil
switch res.bindAddrType { }
case ipv4:
addr = net.ParseIP(res.bindAddr).To4() type udpRequest struct {
if addr == nil { frag byte
return nil, fmt.Errorf("invalid IPv4 address for binding") addr socksAddr
} }
case domainName:
if len(res.bindAddr) > 255 { // +----+------+------+----------+----------+----------+
return nil, fmt.Errorf("invalid domain name for binding") // |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
} // +----+------+------+----------+----------+----------+
addr = make([]byte, 0, len(res.bindAddr)+1) // | 2 | 1 | 1 | Variable | 2 | Variable |
addr = append(addr, byte(len(res.bindAddr))) // +----+------+------+----------+----------+----------+
addr = append(addr, []byte(res.bindAddr)...) func parseUDPRequest(data []byte) (*udpRequest, []byte, error) {
case ipv6: if len(data) < 4 {
addr = net.ParseIP(res.bindAddr).To16() return nil, nil, fmt.Errorf("invalid packet length")
if addr == nil {
return nil, fmt.Errorf("invalid IPv6 address for binding")
} }
default:
return nil, fmt.Errorf("unsupported address type") // reserved bytes
if !(data[0] == 0 && data[1] == 0) {
return nil, nil, fmt.Errorf("invalid udp request header")
} }
pkt = append(pkt, addr...) frag := data[2]
pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort))
return pkt, nil reader := bytes.NewReader(data[3:])
addr, err := parseSocksAddr(reader)
bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length
body := data[len(data)-bodyLen:]
return &udpRequest{
frag: frag,
addr: addr,
}, body, err
}
func (u *udpRequest) marshal() ([]byte, error) {
pkt := make([]byte, 3)
pkt[0] = 0
pkt[1] = 0
pkt[2] = u.frag
addrPkt, err := u.addr.marshal()
if err != nil {
return nil, err
}
return append(pkt, addrPkt...), nil
} }

@ -4,6 +4,7 @@
package socks5 package socks5
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -32,6 +33,19 @@ func backendServer(listener net.Listener) {
listener.Close() listener.Close()
} }
func udpEchoServer(conn net.PacketConn) {
var buf [1024]byte
n, addr, err := conn.ReadFrom(buf[:])
if err != nil {
panic(err)
}
_, err = conn.WriteTo(buf[:n], addr)
if err != nil {
panic(err)
}
conn.Close()
}
func TestRead(t *testing.T) { func TestRead(t *testing.T) {
// backend server which we'll use SOCKS5 to connect to // backend server which we'll use SOCKS5 to connect to
listener, err := net.Listen("tcp", ":0") listener, err := net.Listen("tcp", ":0")
@ -152,3 +166,102 @@ func TestReadPassword(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
func TestUDP(t *testing.T) {
// backend UDP server which we'll use SOCKS5 to connect to
listener, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatal(err)
}
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
go udpEchoServer(listener)
// SOCKS5 server
socks5, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
socks5Port := socks5.Addr().(*net.TCPAddr).Port
go socks5Server(socks5)
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
if err != nil {
t.Fatal(err)
}
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 1024)
n, err := conn.Read(buf) // server hello
if err != nil {
t.Fatal(err)
}
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
}
targetAddr := socksAddr{
addrType: domainName,
addr: "localhost",
port: uint16(backendServerPort),
}
targetAddrPkt, err := targetAddr.marshal()
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
if err != nil {
t.Fatal(err)
}
n, err = conn.Read(buf) // server response
if err != nil {
t.Fatal(err)
}
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
}
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
if err != nil {
t.Fatal(err)
}
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
if err != nil {
t.Fatal(err)
}
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
if err != nil {
t.Fatal(err)
}
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
if err != nil {
t.Fatal(err)
}
udpPayload = append(udpPayload, []byte("Test")...)
_, err = udpConn.Write(udpPayload) // send udp package
if err != nil {
t.Fatal(err)
}
n, _, err = udpConn.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
if err != nil {
t.Fatal(err)
}
if string(responseBody) != "Test" {
t.Fatalf("got: %q want: Test", responseBody)
}
err = udpConn.Close()
if err != nil {
t.Fatal(err)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}

Loading…
Cancel
Save