ssh: Add logic to set accepted environment variables in SSH session (#13559)

Add logic to set environment variables that match the SSH rule's
`acceptEnv` settings in the SSH session's environment.

Updates https://github.com/tailscale/corp/issues/22775

Signed-off-by: Mario Minardi <mario@tailscale.com>
pull/13639/head
Mario Minardi 3 weeks ago committed by GitHub
parent dd6b808acf
commit 8f44ba1cd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -4,6 +4,7 @@
package tailssh package tailssh
import ( import (
"fmt"
"slices" "slices"
"strings" "strings"
) )
@ -17,27 +18,35 @@ import (
// //
// acceptEnv values may contain * and ? wildcard characters which match against // acceptEnv values may contain * and ? wildcard characters which match against
// zero or one or more characters and a single character respectively. // zero or one or more characters and a single character respectively.
func filterEnv(acceptEnv []string, environ []string) []string { func filterEnv(acceptEnv []string, environ []string) ([]string, error) {
var acceptedPairs []string var acceptedPairs []string
// Quick return if we have an empty list.
if acceptEnv == nil || len(acceptEnv) == 0 {
return acceptedPairs, nil
}
for _, envPair := range environ { for _, envPair := range environ {
envVar := strings.Split(envPair, "=")[0] variableName, _, ok := strings.Cut(envPair, "=")
if !ok {
return nil, fmt.Errorf(`invalid environment variable: %q. Variables must be in "KEY=VALUE" format`, envPair)
}
// Short circuit if we have a direct match between the environment // Short circuit if we have a direct match between the environment
// variable and an AcceptEnv value. // variable and an AcceptEnv value.
if slices.Contains(acceptEnv, envVar) { if slices.Contains(acceptEnv, variableName) {
acceptedPairs = append(acceptedPairs, envPair) acceptedPairs = append(acceptedPairs, envPair)
continue continue
} }
// Otherwise check if we have a wildcard pattern that matches. // Otherwise check if we have a wildcard pattern that matches.
if matchAcceptEnv(acceptEnv, envVar) { if matchAcceptEnv(acceptEnv, variableName) {
acceptedPairs = append(acceptedPairs, envPair) acceptedPairs = append(acceptedPairs, envPair)
continue continue
} }
} }
return acceptedPairs return acceptedPairs, nil
} }
// matchAcceptEnv is a convenience function that wraps calling matchAcceptEnvPattern // matchAcceptEnv is a convenience function that wraps calling matchAcceptEnvPattern

@ -108,6 +108,7 @@ func TestFilterEnv(t *testing.T) {
acceptEnv []string acceptEnv []string
environ []string environ []string
expectedFiltered []string expectedFiltered []string
wantErrMessage string
}{ }{
{ {
name: "simple direct matches", name: "simple direct matches",
@ -127,11 +128,26 @@ func TestFilterEnv(t *testing.T) {
environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC", "FO57=BAR2"}, environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC", "FO57=BAR2"},
expectedFiltered: []string{"FOO=BAR", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC"}, expectedFiltered: []string{"FOO=BAR", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC"},
}, },
{
name: "environ format invalid",
acceptEnv: []string{"FO?", "FOOO*", "FO*5?7"},
environ: []string{"FOOBAR"},
expectedFiltered: nil,
wantErrMessage: `invalid environment variable: "FOOBAR". Variables must be in "KEY=VALUE" format`,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
filtered := filterEnv(tc.acceptEnv, tc.environ) filtered, err := filterEnv(tc.acceptEnv, tc.environ)
if err == nil && tc.wantErrMessage != "" {
t.Errorf("wanted error with message %q but error was nil", tc.wantErrMessage)
}
if err != nil && err.Error() != tc.wantErrMessage {
t.Errorf("err = %v; want %v", err, tc.wantErrMessage)
}
if diff := cmp.Diff(tc.expectedFiltered, filtered); diff != "" { if diff := cmp.Diff(tc.expectedFiltered, filtered); diff != "" {
t.Errorf("unexpected filter result (-got,+want): \n%s", diff) t.Errorf("unexpected filter result (-got,+want): \n%s", diff)
} }

@ -12,6 +12,7 @@
package tailssh package tailssh
import ( import (
"encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
@ -154,6 +155,22 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err
incubatorArgs = append(incubatorArgs, "--cmd="+ss.RawCommand()) incubatorArgs = append(incubatorArgs, "--cmd="+ss.RawCommand())
} }
allowSendEnv := nm.HasCap(tailcfg.NodeAttrSSHEnvironmentVariables)
if allowSendEnv {
env, err := filterEnv(ss.conn.acceptEnv, ss.Session.Environ())
if err != nil {
return nil, err
}
if len(env) > 0 {
encoded, err := json.Marshal(env)
if err != nil {
return nil, fmt.Errorf("failed to encode environment: %w", err)
}
incubatorArgs = append(incubatorArgs, fmt.Sprintf("--encoded-env=%q", encoded))
}
}
return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil
} }
@ -192,6 +209,9 @@ type incubatorArgs struct {
forceV1Behavior bool forceV1Behavior bool
debugTest bool debugTest bool
isSELinuxEnforcing bool isSELinuxEnforcing bool
encodedEnv string
allowListEnvKeys string
forwardedEnviron []string
} }
func parseIncubatorArgs(args []string) (incubatorArgs, error) { func parseIncubatorArgs(args []string) (incubatorArgs, error) {
@ -215,6 +235,7 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) {
flags.BoolVar(&ia.forceV1Behavior, "force-v1-behavior", false, "allow falling back to the su command if login is unavailable") flags.BoolVar(&ia.forceV1Behavior, "force-v1-behavior", false, "allow falling back to the su command if login is unavailable")
flags.BoolVar(&ia.debugTest, "debug-test", false, "should debug in test mode") flags.BoolVar(&ia.debugTest, "debug-test", false, "should debug in test mode")
flags.BoolVar(&ia.isSELinuxEnforcing, "is-selinux-enforcing", false, "whether SELinux is in enforcing mode") flags.BoolVar(&ia.isSELinuxEnforcing, "is-selinux-enforcing", false, "whether SELinux is in enforcing mode")
flags.StringVar(&ia.encodedEnv, "encoded-env", "", "JSON encoded array of environment variables in '['key=value']' format")
flags.Parse(args) flags.Parse(args)
for _, g := range strings.Split(groups, ",") { for _, g := range strings.Split(groups, ",") {
@ -225,6 +246,30 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) {
ia.gids = append(ia.gids, gid) ia.gids = append(ia.gids, gid)
} }
ia.forwardedEnviron = os.Environ()
// pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding
ia.allowListEnvKeys = "SSH_AUTH_SOCK"
if ia.encodedEnv != "" {
unquoted, err := strconv.Unquote(ia.encodedEnv)
if err != nil {
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
}
var extraEnviron []string
err = json.Unmarshal([]byte(unquoted), &extraEnviron)
if err != nil {
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
}
ia.forwardedEnviron = append(ia.forwardedEnviron, extraEnviron...)
for _, v := range extraEnviron {
ia.allowListEnvKeys = fmt.Sprintf("%s,%s", ia.allowListEnvKeys, strings.Split(v, "=")[0])
}
}
return ia, nil return ia, nil
} }
@ -406,7 +451,7 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error {
dlogf("logging in with %+v", loginArgs) dlogf("logging in with %+v", loginArgs)
// If Exec works, the Go code will not proceed past this: // If Exec works, the Go code will not proceed past this:
err = unix.Exec(loginCmdPath, loginArgs, os.Environ()) err = unix.Exec(loginCmdPath, loginArgs, ia.forwardedEnviron)
// If we made it here, Exec failed. // If we made it here, Exec failed.
return err return err
@ -441,7 +486,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
loginArgs := []string{ loginArgs := []string{
su, su,
"-w", "SSH_AUTH_SOCK", // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding "-w", ia.allowListEnvKeys,
"-l", "-l",
ia.localUser, ia.localUser,
} }
@ -453,7 +498,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
dlogf("logging in with %+v", loginArgs) dlogf("logging in with %+v", loginArgs)
// If Exec works, the Go code will not proceed past this: // If Exec works, the Go code will not proceed past this:
err = unix.Exec(su, loginArgs, os.Environ()) err = unix.Exec(su, loginArgs, ia.forwardedEnviron)
// If we made it here, Exec failed. // If we made it here, Exec failed.
return true, err return true, err
@ -482,11 +527,11 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string {
return "" return ""
} }
// First try to execute su -w SSH_AUTH_SOCK -l <user> -c true // First try to execute su -w <allow listed env> -l <user> -c true
// to make sure su supports the necessary arguments. // to make sure su supports the necessary arguments.
err = exec.Command( err = exec.Command(
su, su,
"-w", "SSH_AUTH_SOCK", "-w", ia.allowListEnvKeys,
"-l", "-l",
ia.localUser, ia.localUser,
"-c", "true", "-c", "true",
@ -515,7 +560,7 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error {
args := shellArgs(ia.isShell, ia.cmd) args := shellArgs(ia.isShell, ia.cmd)
dlogf("running %s %q", ia.loginShell, args) dlogf("running %s %q", ia.loginShell, args)
cmd := newCommand(ia.hasTTY, ia.loginShell, args) cmd := newCommand(ia.hasTTY, ia.loginShell, ia.forwardedEnviron, args)
err := cmd.Run() err := cmd.Run()
if ee, ok := err.(*exec.ExitError); ok { if ee, ok := err.(*exec.ExitError); ok {
ps := ee.ProcessState ps := ee.ProcessState
@ -532,12 +577,12 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error {
return err return err
} }
func newCommand(hasTTY bool, cmdPath string, cmdArgs []string) *exec.Cmd { func newCommand(hasTTY bool, cmdPath string, cmdEnviron []string, cmdArgs []string) *exec.Cmd {
cmd := exec.Command(cmdPath, cmdArgs...) cmd := exec.Command(cmdPath, cmdArgs...)
cmd.Stdin = os.Stdin cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
cmd.Env = os.Environ() cmd.Env = cmdEnviron
if hasTTY { if hasTTY {
// If we were launched with a tty then we should mark that as the ctty // If we were launched with a tty then we should mark that as the ctty

@ -238,6 +238,7 @@ type conn struct {
localUser *userMeta // set by doPolicyAuth localUser *userMeta // set by doPolicyAuth
userGroupIDs []string // set by doPolicyAuth userGroupIDs []string // set by doPolicyAuth
pubKey gossh.PublicKey // set by doPolicyAuth pubKey gossh.PublicKey // set by doPolicyAuth
acceptEnv []string
// mu protects the following fields. // mu protects the following fields.
// //
@ -377,7 +378,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
c.logf("failed to get conninfo: %v", err) c.logf("failed to get conninfo: %v", err)
return errDenied return errDenied
} }
a, localUser, err := c.evaluatePolicy(pubKey) a, localUser, acceptEnv, err := c.evaluatePolicy(pubKey)
if err != nil { if err != nil {
if pubKey == nil && c.havePubKeyPolicy() { if pubKey == nil && c.havePubKeyPolicy() {
return errPubKeyRequired return errPubKeyRequired
@ -387,6 +388,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
c.action0 = a c.action0 = a
c.currentAction = a c.currentAction = a
c.pubKey = pubKey c.pubKey = pubKey
c.acceptEnv = acceptEnv
if a.Message != "" { if a.Message != "" {
if err := ctx.SendAuthBanner(a.Message); err != nil { if err := ctx.SendAuthBanner(a.Message); err != nil {
return fmt.Errorf("SendBanner: %w", err) return fmt.Errorf("SendBanner: %w", err)
@ -619,16 +621,16 @@ func (c *conn) setInfo(ctx ssh.Context) error {
// evaluatePolicy returns the SSHAction and localUser after evaluating // evaluatePolicy returns the SSHAction and localUser after evaluating
// the SSHPolicy for this conn. The pubKey may be nil for "none" auth. // the SSHPolicy for this conn. The pubKey may be nil for "none" auth.
func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, _ error) { func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, _ error) {
pol, ok := c.sshPolicy() pol, ok := c.sshPolicy()
if !ok { if !ok {
return nil, "", fmt.Errorf("tailssh: rejecting connection; no SSH policy") return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no SSH policy")
} }
a, localUser, ok := c.evalSSHPolicy(pol, pubKey) a, localUser, acceptEnv, ok := c.evalSSHPolicy(pol, pubKey)
if !ok { if !ok {
return nil, "", fmt.Errorf("tailssh: rejecting connection; no matching policy") return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no matching policy")
} }
return a, localUser, nil return a, localUser, acceptEnv, nil
} }
// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like // pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like
@ -892,7 +894,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession {
// isStillValid reports whether the conn is still valid. // isStillValid reports whether the conn is still valid.
func (c *conn) isStillValid() bool { func (c *conn) isStillValid() bool {
a, localUser, err := c.evaluatePolicy(c.pubKey) a, localUser, _, err := c.evaluatePolicy(c.pubKey)
c.vlogf("stillValid: %+v %v %v", a, localUser, err) c.vlogf("stillValid: %+v %v %v", a, localUser, err)
if err != nil { if err != nil {
return false return false
@ -1275,13 +1277,13 @@ func (c *conn) ruleExpired(r *tailcfg.SSHRule) bool {
return r.RuleExpires.Before(c.srv.now()) return r.RuleExpires.Before(c.srv.now())
} }
func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, ok bool) { func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, ok bool) {
for _, r := range pol.Rules { for _, r := range pol.Rules {
if a, localUser, err := c.matchRule(r, pubKey); err == nil { if a, localUser, acceptEnv, err := c.matchRule(r, pubKey); err == nil {
return a, localUser, true return a, localUser, acceptEnv, true
} }
} }
return nil, "", false return nil, "", nil, false
} }
// internal errors for testing; they don't escape to callers or logs. // internal errors for testing; they don't escape to callers or logs.
@ -1294,26 +1296,26 @@ var (
errInvalidConn = errors.New("invalid connection state") errInvalidConn = errors.New("invalid connection state")
) )
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) { func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) {
defer func() { defer func() {
c.vlogf("matchRule(%+v): %v", r, err) c.vlogf("matchRule(%+v): %v", r, err)
}() }()
if c == nil { if c == nil {
return nil, "", errInvalidConn return nil, "", nil, errInvalidConn
} }
if c.info == nil { if c.info == nil {
c.logf("invalid connection state") c.logf("invalid connection state")
return nil, "", errInvalidConn return nil, "", nil, errInvalidConn
} }
if r == nil { if r == nil {
return nil, "", errNilRule return nil, "", nil, errNilRule
} }
if r.Action == nil { if r.Action == nil {
return nil, "", errNilAction return nil, "", nil, errNilAction
} }
if c.ruleExpired(r) { if c.ruleExpired(r) {
return nil, "", errRuleExpired return nil, "", nil, errRuleExpired
} }
if !r.Action.Reject { if !r.Action.Reject {
// For all but Reject rules, SSHUsers is required. // For all but Reject rules, SSHUsers is required.
@ -1321,15 +1323,15 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
// empty string anyway. // empty string anyway.
localUser = mapLocalUser(r.SSHUsers, c.info.sshUser) localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
if localUser == "" { if localUser == "" {
return nil, "", errUserMatch return nil, "", nil, errUserMatch
} }
} }
if ok, err := c.anyPrincipalMatches(r.Principals, pubKey); err != nil { if ok, err := c.anyPrincipalMatches(r.Principals, pubKey); err != nil {
return nil, "", err return nil, "", nil, err
} else if !ok { } else if !ok {
return nil, "", errPrincipalMatch return nil, "", nil, errPrincipalMatch
} }
return r.Action, localUser, nil return r.Action, localUser, r.AcceptEnv, nil
} }
func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser string) { func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser string) {

@ -108,6 +108,7 @@ func TestIntegrationSSH(t *testing.T) {
want []string want []string
forceV1Behavior bool forceV1Behavior bool
skip bool skip bool
allowSendEnv bool
}{ }{
{ {
cmd: "id", cmd: "id",
@ -131,6 +132,18 @@ func TestIntegrationSSH(t *testing.T) {
skip: os.Getenv("SKIP_FILE_OPS") == "1" || !fallbackToSUAvailable(), skip: os.Getenv("SKIP_FILE_OPS") == "1" || !fallbackToSUAvailable(),
forceV1Behavior: false, forceV1Behavior: false,
}, },
{
cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`,
want: []string{"working1 working2 working3 unset4"},
forceV1Behavior: false,
allowSendEnv: true,
},
{
cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`,
want: []string{"unset1 unset2 unset3 unset4"},
forceV1Behavior: false,
allowSendEnv: false,
},
} }
for _, test := range tests { for _, test := range tests {
@ -151,7 +164,13 @@ func TestIntegrationSSH(t *testing.T) {
} }
t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) { t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) {
s := testSession(t, test.forceV1Behavior) sendEnv := map[string]string{
"GIT_ENV_VAR": "working1",
"EXACT_MATCH": "working2",
"TESTING": "working3",
"NOT_ALLOWED": "working4",
}
s := testSession(t, test.forceV1Behavior, test.allowSendEnv, sendEnv)
if shell { if shell {
err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{ err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{
@ -201,7 +220,7 @@ func TestIntegrationSFTP(t *testing.T) {
} }
wantText := "hello world" wantText := "hello world"
cl := testClient(t, forceV1Behavior) cl := testClient(t, forceV1Behavior, false)
scl, err := sftp.NewClient(cl) scl, err := sftp.NewClient(cl)
if err != nil { if err != nil {
t.Fatalf("can't get sftp client: %s", err) t.Fatalf("can't get sftp client: %s", err)
@ -233,7 +252,7 @@ func TestIntegrationSFTP(t *testing.T) {
t.Fatalf("unexpected file contents (-got +want):\n%s", diff) t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
} }
s := testSessionFor(t, cl) s := testSessionFor(t, cl, nil)
got := s.run(t, "ls -l "+filePath, false) got := s.run(t, "ls -l "+filePath, false)
if !strings.Contains(got, "testuser") { if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner user: %s", got) t.Fatalf("unexpected file owner user: %s", got)
@ -262,7 +281,7 @@ func TestIntegrationSCP(t *testing.T) {
} }
wantText := "hello world" wantText := "hello world"
cl := testClient(t, forceV1Behavior) cl := testClient(t, forceV1Behavior, false)
scl, err := scp.NewClientBySSH(cl) scl, err := scp.NewClientBySSH(cl)
if err != nil { if err != nil {
t.Fatalf("can't get sftp client: %s", err) t.Fatalf("can't get sftp client: %s", err)
@ -291,7 +310,7 @@ func TestIntegrationSCP(t *testing.T) {
t.Fatalf("unexpected file contents (-got +want):\n%s", diff) t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
} }
s := testSessionFor(t, cl) s := testSessionFor(t, cl, nil)
got := s.run(t, "ls -l "+filePath, false) got := s.run(t, "ls -l "+filePath, false)
if !strings.Contains(got, "testuser") { if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner user: %s", got) t.Fatalf("unexpected file owner user: %s", got)
@ -349,7 +368,7 @@ func TestSSHAgentForwarding(t *testing.T) {
// Run tailscale SSH server and connect to it // Run tailscale SSH server and connect to it
username := "testuser" username := "testuser"
tailscaleAddr := testServer(t, username, false) tailscaleAddr := testServer(t, username, false, false)
tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{ tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}) })
@ -465,11 +484,11 @@ readLoop:
return string(_got) return string(_got)
} }
func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMethod) *ssh.Client { func testClient(t *testing.T, forceV1Behavior bool, allowSendEnv bool, authMethods ...ssh.AuthMethod) *ssh.Client {
t.Helper() t.Helper()
username := "testuser" username := "testuser"
addr := testServer(t, username, forceV1Behavior) addr := testServer(t, username, forceV1Behavior, allowSendEnv)
cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: ssh.InsecureIgnoreHostKey(),
@ -483,9 +502,9 @@ func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMetho
return cl return cl
} }
func testServer(t *testing.T, username string, forceV1Behavior bool) string { func testServer(t *testing.T, username string, forceV1Behavior bool, allowSendEnv bool) string {
srv := &server{ srv := &server{
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior}, lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior, allowSendEnv: allowSendEnv},
logf: log.Printf, logf: log.Printf,
tailscaledPath: os.Getenv("TAILSCALED_PATH"), tailscaledPath: os.Getenv("TAILSCALED_PATH"),
timeNow: time.Now, timeNow: time.Now,
@ -509,16 +528,20 @@ func testServer(t *testing.T, username string, forceV1Behavior bool) string {
return l.Addr().String() return l.Addr().String()
} }
func testSession(t *testing.T, forceV1Behavior bool) *session { func testSession(t *testing.T, forceV1Behavior bool, allowSendEnv bool, sendEnv map[string]string) *session {
cl := testClient(t, forceV1Behavior) cl := testClient(t, forceV1Behavior, allowSendEnv)
return testSessionFor(t, cl) return testSessionFor(t, cl, sendEnv)
} }
func testSessionFor(t *testing.T, cl *ssh.Client) *session { func testSessionFor(t *testing.T, cl *ssh.Client, sendEnv map[string]string) *session {
s, err := cl.NewSession() s, err := cl.NewSession()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
for k, v := range sendEnv {
s.Setenv(k, v)
}
t.Cleanup(func() { s.Close() }) t.Cleanup(func() { s.Close() })
stdinReader, stdinWriter := io.Pipe() stdinReader, stdinWriter := io.Pipe()
@ -564,6 +587,7 @@ func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.Pr
type testBackend struct { type testBackend struct {
localUser string localUser string
forceV1Behavior bool forceV1Behavior bool
allowSendEnv bool
} }
func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) { func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
@ -597,6 +621,9 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
if tb.forceV1Behavior { if tb.forceV1Behavior {
capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{} capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{}
} }
if tb.allowSendEnv {
capMap[tailcfg.NodeAttrSSHEnvironmentVariables] = struct{}{}
}
return &netmap.NetworkMap{ return &netmap.NetworkMap{
SSHPolicy: &tailcfg.SSHPolicy{ SSHPolicy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{ Rules: []*tailcfg.SSHRule{
@ -604,6 +631,7 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
Principals: []*tailcfg.SSHPrincipal{{Any: true}}, Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true}, Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true},
SSHUsers: map[string]string{"*": tb.localUser}, SSHUsers: map[string]string{"*": tb.localUser},
AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"},
}, },
}, },
}, },

@ -24,6 +24,7 @@ import (
"os/user" "os/user"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -56,11 +57,12 @@ import (
func TestMatchRule(t *testing.T) { func TestMatchRule(t *testing.T) {
someAction := new(tailcfg.SSHAction) someAction := new(tailcfg.SSHAction)
tests := []struct { tests := []struct {
name string name string
rule *tailcfg.SSHRule rule *tailcfg.SSHRule
ci *sshConnInfo ci *sshConnInfo
wantErr error wantErr error
wantUser string wantUser string
wantAcceptEnv []string
}{ }{
{ {
name: "invalid-conn", name: "invalid-conn",
@ -153,6 +155,21 @@ func TestMatchRule(t *testing.T) {
ci: &sshConnInfo{sshUser: "alice"}, ci: &sshConnInfo{sshUser: "alice"},
wantUser: "thealice", wantUser: "thealice",
}, },
{
name: "ok-with-accept-env",
rule: &tailcfg.SSHRule{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
"alice": "thealice",
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
ci: &sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
{ {
name: "no-users-for-reject", name: "no-users-for-reject",
rule: &tailcfg.SSHRule{ rule: &tailcfg.SSHRule{
@ -210,7 +227,7 @@ func TestMatchRule(t *testing.T) {
info: tt.ci, info: tt.ci,
srv: &server{logf: t.Logf}, srv: &server{logf: t.Logf},
} }
got, gotUser, err := c.matchRule(tt.rule, nil) got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule, nil)
if err != tt.wantErr { if err != tt.wantErr {
t.Errorf("err = %v; want %v", err, tt.wantErr) t.Errorf("err = %v; want %v", err, tt.wantErr)
} }
@ -220,6 +237,128 @@ func TestMatchRule(t *testing.T) {
if err == nil && got == nil { if err == nil && got == nil {
t.Errorf("expected non-nil action on success") t.Errorf("expected non-nil action on success")
} }
if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
}
})
}
}
func TestEvalSSHPolicy(t *testing.T) {
someAction := new(tailcfg.SSHAction)
tests := []struct {
name string
policy *tailcfg.SSHPolicy
ci *sshConnInfo
wantMatch bool
wantUser string
wantAcceptEnv []string
}{
{
name: "multiple-matches-picks-first-match",
policy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other": "other1",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
"alice": "thealice",
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other2": "other3",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
"alice": "thealice",
"mark": "markthe",
},
AcceptEnv: []string{"*"},
},
},
},
ci: &sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
wantMatch: true,
},
{
name: "no-matches-returns-failure",
policy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other": "other1",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"fedora": "ubuntu",
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other2": "other3",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"mark": "markthe",
},
AcceptEnv: []string{"*"},
},
},
},
ci: &sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &conn{
info: tt.ci,
srv: &server{logf: t.Logf},
}
got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy, nil)
if match != tt.wantMatch {
t.Errorf("match = %v; want %v", match, tt.wantMatch)
}
if gotUser != tt.wantUser {
t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
}
if tt.wantMatch == true && got == nil {
t.Errorf("expected non-nil action on success")
}
if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
}
}) })
} }
} }

Loading…
Cancel
Save