// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package socks5 is a SOCKS5 server implementation.
//
// This is used for userspace networking in Tailscale. Specifically,
// this is used for dialing out of the machine to other nodes, without
// the host kernel's involvement, so it doesn't need proper routing tables,
// TUN, IPv6, etc. This package is meant to only handle the SOCKS5 protocol
// details and not any integration with Tailscale internals itself.
//
// The glue between this package and Tailscale is in net/socks5/tssocks.
package socks5

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"strconv"
	"time"

	"tailscale.com/types/logger"
)

// Authentication METHODs described in RFC 1928, section 3.
// The server selects noAuthRequired when no credentials are configured and
// passwordAuth (RFC 1929) otherwise; noAcceptableAuth is sent when the
// client did not offer the selected method.
const (
	noAuthRequired   byte = 0
	passwordAuth     byte = 2
	noAcceptableAuth byte = 255
)

// passwordAuthVersion is the auth version byte described in RFC 1929.
// It is deliberately untyped so it can be compared with a byte and used
// in []byte literals.
const passwordAuthVersion = 1

// socks5Version is the byte that represents the SOCKS version
// in requests.
const socks5Version byte = 5

// commandType are the bytes sent in SOCKS5 packets
// that represent the kind of connection the client needs.
type commandType byte

// The set of valid SOCKS5 commands as described in RFC 1928.
// Only connect and udpAssociate are served; bind is answered with
// commandNotSupported by the request dispatcher.
const (
	connect      commandType = 1
	bind         commandType = 2
	udpAssociate commandType = 3
)

// addrType are the bytes sent in SOCKS5 packets
// that represent particular address types.
type addrType byte

// The set of valid SOCKS5 address types as defined in RFC 1928.
// The values are the ATYP byte on the wire (note there is no type 2).
const (
	ipv4       addrType = 1
	domainName addrType = 3
	ipv6       addrType = 4
)

// replyCode are the bytes sent in SOCKS5 packets
// that represent replies from the server to a client
// request.
type replyCode byte

// The set of SOCKS5 reply codes (the REP field) defined in RFC 1928, section 6.
const ( success replyCode = 0 generalFailure replyCode = 1 connectionNotAllowed replyCode = 2 networkUnreachable replyCode = 3 hostUnreachable replyCode = 4 connectionRefused replyCode = 5 ttlExpired replyCode = 6 commandNotSupported replyCode = 7 addrTypeNotSupported replyCode = 8 ) // UDP conn default buffer size and read timeout. const ( bufferSize = 8 * 1024 readTimeout = 5 * time.Second ) // Server is a SOCKS5 proxy server. type Server struct { // Logf optionally specifies the logger to use. // If nil, the standard logger is used. Logf logger.Logf // Dialer optionally specifies the dialer to use for outgoing connections. // If nil, the net package's standard dialer is used. Dialer func(ctx context.Context, network, addr string) (net.Conn, error) // Username and Password, if set, are the credential clients must provide. Username string Password string } func (s *Server) dial(ctx context.Context, network, addr string) (net.Conn, error) { dial := s.Dialer if dial == nil { dialer := &net.Dialer{} dial = dialer.DialContext } return dial(ctx, network, addr) } func (s *Server) logf(format string, args ...any) { logf := s.Logf if logf == nil { logf = log.Printf } logf(format, args...) } // Serve accepts and handles incoming connections on the given listener. func (s *Server) Serve(l net.Listener) error { defer l.Close() for { c, err := l.Accept() if err != nil { return err } go func() { defer c.Close() conn := &Conn{logf: s.Logf, clientConn: c, srv: s} err := conn.Run() if err != nil { s.logf("client connection failed: %v", err) } }() } } // Conn is a SOCKS5 connection for client to reach // server. type Conn struct { // The struct is filled by each of the internal // methods in turn as the transaction progresses. logf logger.Logf srv *Server clientConn net.Conn request *request udpClientAddr net.Addr udpTargetConns map[socksAddr]net.Conn } // Run starts the new connection. 
func (c *Conn) Run() error { needAuth := c.srv.Username != "" || c.srv.Password != "" authMethod := noAuthRequired if needAuth { authMethod = passwordAuth } err := parseClientGreeting(c.clientConn, authMethod) if err != nil { c.clientConn.Write([]byte{socks5Version, noAcceptableAuth}) return err } c.clientConn.Write([]byte{socks5Version, authMethod}) if !needAuth { return c.handleRequest() } user, pwd, err := parseClientAuth(c.clientConn) if err != nil || user != c.srv.Username || pwd != c.srv.Password { c.clientConn.Write([]byte{1, 1}) // auth error return err } c.clientConn.Write([]byte{1, 0}) // auth success return c.handleRequest() } func (c *Conn) handleRequest() error { req, err := parseClientRequest(c.clientConn) if err != nil { res := errorResponse(generalFailure) buf, _ := res.marshal() c.clientConn.Write(buf) return err } c.request = req switch req.command { case connect: return c.handleTCP() case udpAssociate: return c.handleUDP() default: res := errorResponse(commandNotSupported) buf, _ := res.marshal() c.clientConn.Write(buf) return fmt.Errorf("unsupported command %v", req.command) } } func (c *Conn) handleTCP() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() srv, err := c.srv.dial( ctx, "tcp", c.request.destination.hostPort(), ) if err != nil { res := errorResponse(generalFailure) buf, _ := res.marshal() c.clientConn.Write(buf) return err } defer srv.Close() localAddr := srv.LocalAddr().String() serverAddr, serverPort, err := splitHostPort(localAddr) if err != nil { return err } res := &response{ reply: success, bindAddr: socksAddr{ addrType: getAddrType(serverAddr), addr: serverAddr, port: serverPort, }, } buf, err := res.marshal() if err != nil { res = errorResponse(generalFailure) buf, _ = res.marshal() } c.clientConn.Write(buf) errc := make(chan error, 2) go func() { _, err := io.Copy(c.clientConn, srv) if err != nil { err = fmt.Errorf("from backend to client: %w", err) } errc <- err }() go func() { 
_, err := io.Copy(srv, c.clientConn) if err != nil { err = fmt.Errorf("from client to backend: %w", err) } errc <- err }() return <-errc } func (c *Conn) handleUDP() error { // The DST.ADDR and DST.PORT fields contain the address and port that // the client expects to use to send UDP datagrams on for the // association. The server MAY use this information to limit access // to the association. // @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928. // // We do NOT limit the access from the client currently in this implementation. _ = c.request.destination addr := c.clientConn.LocalAddr() host, _, err := net.SplitHostPort(addr.String()) if err != nil { return err } clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0")) if err != nil { res := errorResponse(generalFailure) buf, _ := res.marshal() c.clientConn.Write(buf) return err } defer clientUDPConn.Close() bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) if err != nil { return err } res := &response{ reply: success, bindAddr: socksAddr{ addrType: getAddrType(bindAddr), addr: bindAddr, port: bindPort, }, } buf, err := res.marshal() if err != nil { res = errorResponse(generalFailure) buf, _ = res.marshal() } c.clientConn.Write(buf) return c.transferUDP(c.clientConn, clientUDPConn) } func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() // client -> target go func() { defer cancel() c.udpTargetConns = make(map[socksAddr]net.Conn) // close all target udp connections when the client connection is closed defer func() { for _, conn := range c.udpTargetConns { _ = conn.Close() } }() buf := make([]byte, bufferSize) for { select { case <-ctx.Done(): return default: err := c.handleUDPRequest(ctx, clientConn, buf) if err != nil { if isTimeout(err) { continue } if errors.Is(err, net.ErrClosed) { return } c.logf("udp transfer: handle udp request fail: %v", err) } } } }() 
// A UDP association terminates when the TCP connection that the UDP // ASSOCIATE request arrived on terminates. RFC1928 _, err := io.Copy(io.Discard, associatedTCP) if err != nil { err = fmt.Errorf("udp associated tcp conn: %w", err) } return err } func (c *Conn) getOrDialTargetConn( ctx context.Context, clientConn net.PacketConn, targetAddr socksAddr, ) (net.Conn, error) { conn, exist := c.udpTargetConns[targetAddr] if exist { return conn, nil } conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort()) if err != nil { return nil, err } c.udpTargetConns[targetAddr] = conn // target -> client go func() { buf := make([]byte, bufferSize) for { select { case <-ctx.Done(): return default: err := c.handleUDPResponse(clientConn, targetAddr, conn, buf) if err != nil { if isTimeout(err) { continue } if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { return } c.logf("udp transfer: handle udp response fail: %v", err) } } } }() return conn, nil } func (c *Conn) handleUDPRequest( ctx context.Context, clientConn net.PacketConn, buf []byte, ) error { // add a deadline for the read to avoid blocking forever _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) n, addr, err := clientConn.ReadFrom(buf) if err != nil { return fmt.Errorf("read from client: %w", err) } c.udpClientAddr = addr req, data, err := parseUDPRequest(buf[:n]) if err != nil { return fmt.Errorf("parse udp request: %w", err) } targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr) if err != nil { return fmt.Errorf("dial target %s fail: %w", req.addr, err) } nn, err := targetConn.Write(data) if err != nil { return fmt.Errorf("write to target %s fail: %w", req.addr, err) } if nn != len(data) { return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite) } return nil } func (c *Conn) handleUDPResponse( clientConn net.PacketConn, targetAddr socksAddr, targetConn net.Conn, buf []byte, ) error { // add a deadline for the read to avoid blocking forever _ = 
targetConn.SetReadDeadline(time.Now().Add(readTimeout)) n, err := targetConn.Read(buf) if err != nil { return fmt.Errorf("read from target: %w", err) } hdr := udpRequest{addr: targetAddr} pkt, err := hdr.marshal() if err != nil { return fmt.Errorf("marshal udp request: %w", err) } data := append(pkt, buf[:n]...) // use addr from client to send back nn, err := clientConn.WriteTo(data, c.udpClientAddr) if err != nil { return fmt.Errorf("write to client: %w", err) } if nn != len(data) { return fmt.Errorf("write to client: %w", io.ErrShortWrite) } return nil } func isTimeout(err error) bool { terr, ok := errors.Unwrap(err).(interface{ Timeout() bool }) return ok && terr.Timeout() } func splitHostPort(hostport string) (host string, port uint16, err error) { host, portStr, err := net.SplitHostPort(hostport) if err != nil { return "", 0, err } portInt, err := strconv.Atoi(portStr) if err != nil { return "", 0, err } if portInt < 0 || portInt > 65535 { return "", 0, fmt.Errorf("invalid port number %d", portInt) } return host, uint16(portInt), nil } // parseClientGreeting parses a request initiation packet. 
func parseClientGreeting(r io.Reader, authMethod byte) error { var hdr [2]byte _, err := io.ReadFull(r, hdr[:]) if err != nil { return fmt.Errorf("could not read packet header") } if hdr[0] != socks5Version { return fmt.Errorf("incompatible SOCKS version") } count := int(hdr[1]) methods := make([]byte, count) _, err = io.ReadFull(r, methods) if err != nil { return fmt.Errorf("could not read methods") } for _, m := range methods { if m == authMethod { return nil } } return fmt.Errorf("no acceptable auth methods") } func parseClientAuth(r io.Reader) (usr, pwd string, err error) { var hdr [2]byte if _, err := io.ReadFull(r, hdr[:]); err != nil { return "", "", fmt.Errorf("could not read auth packet header") } if hdr[0] != passwordAuthVersion { return "", "", fmt.Errorf("bad SOCKS auth version") } usrLen := int(hdr[1]) usrBytes := make([]byte, usrLen) if _, err := io.ReadFull(r, usrBytes); err != nil { return "", "", fmt.Errorf("could not read auth packet username") } var hdrPwd [1]byte if _, err := io.ReadFull(r, hdrPwd[:]); err != nil { return "", "", fmt.Errorf("could not read auth packet password length") } pwdLen := int(hdrPwd[0]) pwdBytes := make([]byte, pwdLen) if _, err := io.ReadFull(r, pwdBytes); err != nil { return "", "", fmt.Errorf("could not read auth packet password") } return string(usrBytes), string(pwdBytes), nil } func getAddrType(addr string) addrType { if ip := net.ParseIP(addr); ip != nil { if ip.To4() != nil { return ipv4 } return ipv6 } return domainName } // request represents data contained within a SOCKS5 // connection request packet. type request struct { command commandType destination socksAddr } // parseClientRequest converts raw packet bytes into a // SOCKS5Request struct. 
func parseClientRequest(r io.Reader) (*request, error) { var hdr [3]byte _, err := io.ReadFull(r, hdr[:]) if err != nil { return nil, fmt.Errorf("could not read packet header") } cmd := hdr[1] destination, err := parseSocksAddr(r) return &request{ command: commandType(cmd), destination: destination, }, err } type socksAddr struct { addrType addrType addr string port uint16 } var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} func parseSocksAddr(r io.Reader) (addr socksAddr, err error) { var addrTypeData [1]byte _, err = io.ReadFull(r, addrTypeData[:]) if err != nil { return socksAddr{}, fmt.Errorf("could not read address type") } dstAddrType := addrType(addrTypeData[0]) var destination string switch dstAddrType { case ipv4: var ip [4]byte _, err = io.ReadFull(r, ip[:]) if err != nil { return socksAddr{}, fmt.Errorf("could not read IPv4 address") } destination = net.IP(ip[:]).String() case domainName: var dstSizeByte [1]byte _, err = io.ReadFull(r, dstSizeByte[:]) if err != nil { return socksAddr{}, fmt.Errorf("could not read domain name size") } dstSize := int(dstSizeByte[0]) domainName := make([]byte, dstSize) _, err = io.ReadFull(r, domainName) if err != nil { return socksAddr{}, fmt.Errorf("could not read domain name") } destination = string(domainName) case ipv6: var ip [16]byte _, err = io.ReadFull(r, ip[:]) if err != nil { return socksAddr{}, fmt.Errorf("could not read IPv6 address") } destination = net.IP(ip[:]).String() default: return socksAddr{}, fmt.Errorf("unsupported address type") } var portBytes [2]byte _, err = io.ReadFull(r, portBytes[:]) if err != nil { return socksAddr{}, fmt.Errorf("could not read port") } port := binary.BigEndian.Uint16(portBytes[:]) return socksAddr{ addrType: dstAddrType, addr: destination, port: port, }, nil } func (s socksAddr) marshal() ([]byte, error) { var addr []byte switch s.addrType { case ipv4: addr = net.ParseIP(s.addr).To4() if addr == nil { return nil, fmt.Errorf("invalid IPv4 address for 
binding") } case domainName: if len(s.addr) > 255 { return nil, fmt.Errorf("invalid domain name for binding") } addr = make([]byte, 0, len(s.addr)+1) addr = append(addr, byte(len(s.addr))) addr = append(addr, []byte(s.addr)...) case ipv6: addr = net.ParseIP(s.addr).To16() if addr == nil { return nil, fmt.Errorf("invalid IPv6 address for binding") } default: return nil, fmt.Errorf("unsupported address type") } pkt := []byte{byte(s.addrType)} pkt = append(pkt, addr...) pkt = binary.BigEndian.AppendUint16(pkt, s.port) return pkt, nil } func (s socksAddr) hostPort() string { return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) } func (s socksAddr) String() string { return s.hostPort() } // response contains the contents of // a response packet sent from the proxy // to the client. type response struct { reply replyCode bindAddr socksAddr } func errorResponse(code replyCode) *response { return &response{reply: code, bindAddr: zeroSocksAddr} } // marshal converts a SOCKS5Response struct into // a packet. If res.reply == Success, it may throw an error on // receiving an invalid bind address. Otherwise, it will not throw. 
func (res *response) marshal() ([]byte, error) { pkt := make([]byte, 3) pkt[0] = socks5Version pkt[1] = byte(res.reply) pkt[2] = 0 // null reserved byte addrPkt, err := res.bindAddr.marshal() if err != nil { return nil, err } return append(pkt, addrPkt...), nil } type udpRequest struct { frag byte addr socksAddr } // +----+------+------+----------+----------+----------+ // |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | // +----+------+------+----------+----------+----------+ // | 2 | 1 | 1 | Variable | 2 | Variable | // +----+------+------+----------+----------+----------+ func parseUDPRequest(data []byte) (*udpRequest, []byte, error) { if len(data) < 4 { return nil, nil, fmt.Errorf("invalid packet length") } // reserved bytes if !(data[0] == 0 && data[1] == 0) { return nil, nil, fmt.Errorf("invalid udp request header") } frag := data[2] reader := bytes.NewReader(data[3:]) addr, err := parseSocksAddr(reader) bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length body := data[len(data)-bodyLen:] return &udpRequest{ frag: frag, addr: addr, }, body, err } func (u *udpRequest) marshal() ([]byte, error) { pkt := make([]byte, 3) pkt[0] = 0 pkt[1] = 0 pkt[2] = u.frag addrPkt, err := u.addr.marshal() if err != nil { return nil, err } return append(pkt, addrPkt...), nil }