diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index 7e8749456..b774ebe24 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -25,15 +25,20 @@ import ( "tailscale.com/types/logger" ) +// Authentication METHODs described in RFC 1928, section 3. const ( noAuthRequired byte = 0 + passwordAuth byte = 2 noAcceptableAuth byte = 255 - - // socks5Version is the byte that represents the SOCKS version - // in requests. - socks5Version byte = 5 ) +// passwordAuthVersion is the auth version byte described in RFC 1929. +const passwordAuthVersion = 1 + +// socks5Version is the byte that represents the SOCKS version +// in requests. +const socks5Version byte = 5 + // commandType are the bytes sent in SOCKS5 packets // that represent the kind of connection the client needs. type commandType byte @@ -83,6 +88,10 @@ type Server struct { // Dialer optionally specifies the dialer to use for outgoing connections. // If nil, the net package's standard dialer is used. Dialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // Username and Password, if set, are the credential clients must provide. + Username string + Password string } func (s *Server) dial(ctx context.Context, network, addr string) (net.Conn, error) { @@ -134,12 +143,29 @@ type Conn struct { // Run starts the new connection. func (c *Conn) Run() error { - err := parseClientGreeting(c.clientConn) + needAuth := c.srv.Username != "" || c.srv.Password != "" + authMethod := noAuthRequired + if needAuth { + authMethod = passwordAuth + } + + err := parseClientGreeting(c.clientConn, authMethod) if err != nil { c.clientConn.Write([]byte{socks5Version, noAcceptableAuth}) return err } - c.clientConn.Write([]byte{socks5Version, noAuthRequired}) + c.clientConn.Write([]byte{socks5Version, authMethod}) + if !needAuth { + return c.handleRequest() + } + + user, pwd, err := parseClientAuth(c.clientConn) + if err != nil || user != c.srv.Username || pwd != c.srv.Password { + c.clientConn.Write([]byte{1, 1}) // auth error + return err + } + c.clientConn.Write([]byte{1, 0}) // auth success + return c.handleRequest() } @@ -220,10 +246,8 @@ func (c *Conn) handleRequest() error { return <-errc } -// parseClientGreeting parses a request initiation packet -// and returns a slice that contains the acceptable auth methods -// for the client. -func parseClientGreeting(r io.Reader) error { +// parseClientGreeting parses a request initiation packet. +func parseClientGreeting(r io.Reader, authMethod byte) error { var hdr [2]byte _, err := io.ReadFull(r, hdr[:]) if err != nil { @@ -239,13 +263,38 @@ func parseClientGreeting(r io.Reader) error { return fmt.Errorf("could not read methods") } for _, m := range methods { - if m == noAuthRequired { + if m == authMethod { return nil } } return fmt.Errorf("no acceptable auth methods") } +func parseClientAuth(r io.Reader) (usr, pwd string, err error) { + var hdr [2]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return "", "", fmt.Errorf("could not read auth packet header") + } + if hdr[0] != passwordAuthVersion { + return "", "", fmt.Errorf("bad SOCKS auth version") + } + usrLen := int(hdr[1]) + usrBytes := make([]byte, usrLen) + if _, err := io.ReadFull(r, usrBytes); err != nil { + return "", "", fmt.Errorf("could not read auth packet username") + } + var hdrPwd [1]byte + if _, err := io.ReadFull(r, hdrPwd[:]); err != nil { + return "", "", fmt.Errorf("could not read auth packet password length") + } + pwdLen := int(hdrPwd[0]) + pwdBytes := make([]byte, pwdLen) + if _, err := io.ReadFull(r, pwdBytes); err != nil { + return "", "", fmt.Errorf("could not read auth packet password") + } + return string(usrBytes), string(pwdBytes), nil +} + // request represents data contained within a SOCKS5 // connection request packet. type request struct { diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go index cbeffbff5..201a66575 100644 --- a/net/socks5/socks5_test.go +++ b/net/socks5/socks5_test.go @@ -4,6 +4,7 @@ package socks5 import ( + "errors" "fmt" "io" "net" @@ -74,3 +75,80 @@ func TestRead(t *testing.T) { t.Fatal(err) } } + +func TestReadPassword(t *testing.T) { + // backend server which we'll use SOCKS5 to connect to + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + backendServerPort := ln.Addr().(*net.TCPAddr).Port + go backendServer(ln) + + socks5ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + socks5ln.Close() + }) + auth := &proxy.Auth{User: "foo", Password: "bar"} + go func() { + s := Server{Username: auth.User, Password: auth.Password} + err := s.Serve(socks5ln) + if err != nil && !errors.Is(err, net.ErrClosed) { + panic(err) + } + }() + + addr := fmt.Sprintf("localhost:%d", socks5ln.Addr().(*net.TCPAddr).Port) + + if d, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct); err != nil { + t.Fatal(err) + } else { + if _, err := d.Dial("tcp", addr); err == nil { + t.Fatal("expected no-auth dial error") + } + } + + badPwd := &proxy.Auth{User: "foo", Password: "not right"} + if d, err := proxy.SOCKS5("tcp", addr, badPwd, proxy.Direct); err != nil { + t.Fatal(err) + } else { + if _, err := d.Dial("tcp", addr); err == nil { + t.Fatal("expected bad password dial error") + } + } + + badUsr := &proxy.Auth{User: "not right", Password: "bar"} + if d, err := proxy.SOCKS5("tcp", addr, badUsr, proxy.Direct); err != nil { + t.Fatal(err) + } else { + if _, err := d.Dial("tcp", addr); err == nil { + t.Fatal("expected bad username dial error") + } + } + + socksDialer, err := proxy.SOCKS5("tcp", addr, auth, proxy.Direct) + if err != nil { + t.Fatal(err) + } + + addr = fmt.Sprintf("localhost:%d", backendServerPort) + conn, err := socksDialer.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Fatal(err) + } + if string(buf) != "Test" { + t.Fatalf("got: %q want: Test", buf) + } + + if err := conn.Close(); err != nil { + t.Fatal(err) + } +}