From 8f44ba1cd626a389c451c85651d3cc33c0df4c6f Mon Sep 17 00:00:00 2001 From: Mario Minardi Date: Mon, 30 Sep 2024 21:47:45 -0600 Subject: [PATCH] ssh: Add logic to set accepted environment variables in SSH session (#13559) Add logic to set environment variables that match the SSH rule's `acceptEnv` settings in the SSH session's environment. Updates https://github.com/tailscale/corp/issues/22775 Signed-off-by: Mario Minardi --- ssh/tailssh/accept_env.go | 19 ++- ssh/tailssh/accept_env_test.go | 18 ++- ssh/tailssh/incubator.go | 61 ++++++++-- ssh/tailssh/tailssh.go | 44 +++---- ssh/tailssh/tailssh_integration_test.go | 56 ++++++--- ssh/tailssh/tailssh_test.go | 151 +++++++++++++++++++++++- 6 files changed, 294 insertions(+), 55 deletions(-) diff --git a/ssh/tailssh/accept_env.go b/ssh/tailssh/accept_env.go index df4f1d010..6461a79a3 100644 --- a/ssh/tailssh/accept_env.go +++ b/ssh/tailssh/accept_env.go @@ -4,6 +4,7 @@ package tailssh import ( + "fmt" "slices" "strings" ) @@ -17,27 +18,35 @@ import ( // // acceptEnv values may contain * and ? wildcard characters which match against // zero or one or more characters and a single character respectively. -func filterEnv(acceptEnv []string, environ []string) []string { +func filterEnv(acceptEnv []string, environ []string) ([]string, error) { var acceptedPairs []string + // Quick return if we have an empty list. + if acceptEnv == nil || len(acceptEnv) == 0 { + return acceptedPairs, nil + } + for _, envPair := range environ { - envVar := strings.Split(envPair, "=")[0] + variableName, _, ok := strings.Cut(envPair, "=") + if !ok { + return nil, fmt.Errorf(`invalid environment variable: %q. Variables must be in "KEY=VALUE" format`, envPair) + } // Short circuit if we have a direct match between the environment // variable and an AcceptEnv value. - if slices.Contains(acceptEnv, envVar) { + if slices.Contains(acceptEnv, variableName) { acceptedPairs = append(acceptedPairs, envPair) continue } // Otherwise check if we have a wildcard pattern that matches. - if matchAcceptEnv(acceptEnv, envVar) { + if matchAcceptEnv(acceptEnv, variableName) { acceptedPairs = append(acceptedPairs, envPair) continue } } - return acceptedPairs + return acceptedPairs, nil } // matchAcceptEnv is a convenience function that wraps calling matchAcceptEnvPattern diff --git a/ssh/tailssh/accept_env_test.go b/ssh/tailssh/accept_env_test.go index c67774447..b54c98097 100644 --- a/ssh/tailssh/accept_env_test.go +++ b/ssh/tailssh/accept_env_test.go @@ -108,6 +108,7 @@ func TestFilterEnv(t *testing.T) { acceptEnv []string environ []string expectedFiltered []string + wantErrMessage string }{ { name: "simple direct matches", @@ -127,11 +128,26 @@ func TestFilterEnv(t *testing.T) { environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC", "FO57=BAR2"}, expectedFiltered: []string{"FOO=BAR", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC"}, }, + { + name: "environ format invalid", + acceptEnv: []string{"FO?", "FOOO*", "FO*5?7"}, + environ: []string{"FOOBAR"}, + expectedFiltered: nil, + wantErrMessage: `invalid environment variable: "FOOBAR". Variables must be in "KEY=VALUE" format`, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - filtered := filterEnv(tc.acceptEnv, tc.environ) + filtered, err := filterEnv(tc.acceptEnv, tc.environ) + if err == nil && tc.wantErrMessage != "" { + t.Errorf("wanted error with message %q but error was nil", tc.wantErrMessage) + } + + if err != nil && err.Error() != tc.wantErrMessage { + t.Errorf("err = %v; want %v", err, tc.wantErrMessage) + } + if diff := cmp.Diff(tc.expectedFiltered, filtered); diff != "" { t.Errorf("unexpected filter result (-got,+want): \n%s", diff) } diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 37f2a5434..f47492082 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -12,6 +12,7 @@ package tailssh import ( + "encoding/json" "errors" "flag" "fmt" @@ -154,6 +155,22 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err incubatorArgs = append(incubatorArgs, "--cmd="+ss.RawCommand()) } + allowSendEnv := nm.HasCap(tailcfg.NodeAttrSSHEnvironmentVariables) + if allowSendEnv { + env, err := filterEnv(ss.conn.acceptEnv, ss.Session.Environ()) + if err != nil { + return nil, err + } + + if len(env) > 0 { + encoded, err := json.Marshal(env) + if err != nil { + return nil, fmt.Errorf("failed to encode environment: %w", err) + } + incubatorArgs = append(incubatorArgs, fmt.Sprintf("--encoded-env=%q", encoded)) + } + } + return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil } @@ -192,6 +209,9 @@ type incubatorArgs struct { forceV1Behavior bool debugTest bool isSELinuxEnforcing bool + encodedEnv string + allowListEnvKeys string + forwardedEnviron []string } func parseIncubatorArgs(args []string) (incubatorArgs, error) { @@ -215,6 +235,7 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) { flags.BoolVar(&ia.forceV1Behavior, "force-v1-behavior", false, "allow falling back to the su command if login is unavailable") flags.BoolVar(&ia.debugTest, "debug-test", false, "should debug in test mode") flags.BoolVar(&ia.isSELinuxEnforcing, "is-selinux-enforcing", false, "whether SELinux is in enforcing mode") + flags.StringVar(&ia.encodedEnv, "encoded-env", "", "JSON encoded array of environment variables in '['key=value']' format") flags.Parse(args) for _, g := range strings.Split(groups, ",") { @@ -225,6 +246,30 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) { ia.gids = append(ia.gids, gid) } + ia.forwardedEnviron = os.Environ() + // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding + ia.allowListEnvKeys = "SSH_AUTH_SOCK" + + if ia.encodedEnv != "" { + unquoted, err := strconv.Unquote(ia.encodedEnv) + if err != nil { + return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + } + + var extraEnviron []string + + err = json.Unmarshal([]byte(unquoted), &extraEnviron) + if err != nil { + return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + } + + ia.forwardedEnviron = append(ia.forwardedEnviron, extraEnviron...) + + for _, v := range extraEnviron { + ia.allowListEnvKeys = fmt.Sprintf("%s,%s", ia.allowListEnvKeys, strings.Split(v, "=")[0]) + } + } + return ia, nil } @@ -406,7 +451,7 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error { dlogf("logging in with %+v", loginArgs) // If Exec works, the Go code will not proceed past this: - err = unix.Exec(loginCmdPath, loginArgs, os.Environ()) + err = unix.Exec(loginCmdPath, loginArgs, ia.forwardedEnviron) // If we made it here, Exec failed. return err @@ -441,7 +486,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { loginArgs := []string{ su, - "-w", "SSH_AUTH_SOCK", // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding + "-w", ia.allowListEnvKeys, "-l", ia.localUser, } @@ -453,7 +498,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { dlogf("logging in with %+v", loginArgs) // If Exec works, the Go code will not proceed past this: - err = unix.Exec(su, loginArgs, os.Environ()) + err = unix.Exec(su, loginArgs, ia.forwardedEnviron) // If we made it here, Exec failed. return true, err @@ -482,11 +527,11 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string { return "" } - // First try to execute su -w SSH_AUTH_SOCK -l -c true + // First try to execute su -w -l -c true // to make sure su supports the necessary arguments. err = exec.Command( su, - "-w", "SSH_AUTH_SOCK", + "-w", ia.allowListEnvKeys, "-l", ia.localUser, "-c", "true", @@ -515,7 +560,7 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error { args := shellArgs(ia.isShell, ia.cmd) dlogf("running %s %q", ia.loginShell, args) - cmd := newCommand(ia.hasTTY, ia.loginShell, args) + cmd := newCommand(ia.hasTTY, ia.loginShell, ia.forwardedEnviron, args) err := cmd.Run() if ee, ok := err.(*exec.ExitError); ok { ps := ee.ProcessState @@ -532,12 +577,12 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error { return err } -func newCommand(hasTTY bool, cmdPath string, cmdArgs []string) *exec.Cmd { +func newCommand(hasTTY bool, cmdPath string, cmdEnviron []string, cmdArgs []string) *exec.Cmd { cmd := exec.Command(cmdPath, cmdArgs...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - cmd.Env = os.Environ() + cmd.Env = cmdEnviron if hasTTY { // If we were launched with a tty then we should mark that as the ctty diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 7187b5b59..9ade1847e 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -238,6 +238,7 @@ type conn struct { localUser *userMeta // set by doPolicyAuth userGroupIDs []string // set by doPolicyAuth pubKey gossh.PublicKey // set by doPolicyAuth + acceptEnv []string // mu protects the following fields. // @@ -377,7 +378,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error { c.logf("failed to get conninfo: %v", err) return errDenied } - a, localUser, err := c.evaluatePolicy(pubKey) + a, localUser, acceptEnv, err := c.evaluatePolicy(pubKey) if err != nil { if pubKey == nil && c.havePubKeyPolicy() { return errPubKeyRequired @@ -387,6 +388,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error { c.action0 = a c.currentAction = a c.pubKey = pubKey + c.acceptEnv = acceptEnv if a.Message != "" { if err := ctx.SendAuthBanner(a.Message); err != nil { return fmt.Errorf("SendBanner: %w", err) @@ -619,16 +621,16 @@ func (c *conn) setInfo(ctx ssh.Context) error { // evaluatePolicy returns the SSHAction and localUser after evaluating // the SSHPolicy for this conn. The pubKey may be nil for "none" auth. -func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, _ error) { +func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, _ error) { pol, ok := c.sshPolicy() if !ok { - return nil, "", fmt.Errorf("tailssh: rejecting connection; no SSH policy") + return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no SSH policy") } - a, localUser, ok := c.evalSSHPolicy(pol, pubKey) + a, localUser, acceptEnv, ok := c.evalSSHPolicy(pol, pubKey) if !ok { - return nil, "", fmt.Errorf("tailssh: rejecting connection; no matching policy") + return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no matching policy") } - return a, localUser, nil + return a, localUser, acceptEnv, nil } // pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like @@ -892,7 +894,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession { // isStillValid reports whether the conn is still valid. func (c *conn) isStillValid() bool { - a, localUser, err := c.evaluatePolicy(c.pubKey) + a, localUser, _, err := c.evaluatePolicy(c.pubKey) c.vlogf("stillValid: %+v %v %v", a, localUser, err) if err != nil { return false @@ -1275,13 +1277,13 @@ func (c *conn) ruleExpired(r *tailcfg.SSHRule) bool { return r.RuleExpires.Before(c.srv.now()) } -func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, ok bool) { +func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, ok bool) { for _, r := range pol.Rules { - if a, localUser, err := c.matchRule(r, pubKey); err == nil { - return a, localUser, true + if a, localUser, acceptEnv, err := c.matchRule(r, pubKey); err == nil { + return a, localUser, acceptEnv, true } } - return nil, "", false + return nil, "", nil, false } // internal errors for testing; they don't escape to callers or logs. @@ -1294,26 +1296,26 @@ var ( errInvalidConn = errors.New("invalid connection state") ) -func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) { +func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) { defer func() { c.vlogf("matchRule(%+v): %v", r, err) }() if c == nil { - return nil, "", errInvalidConn + return nil, "", nil, errInvalidConn } if c.info == nil { c.logf("invalid connection state") - return nil, "", errInvalidConn + return nil, "", nil, errInvalidConn } if r == nil { - return nil, "", errNilRule + return nil, "", nil, errNilRule } if r.Action == nil { - return nil, "", errNilAction + return nil, "", nil, errNilAction } if c.ruleExpired(r) { - return nil, "", errRuleExpired + return nil, "", nil, errRuleExpired } if !r.Action.Reject { // For all but Reject rules, SSHUsers is required. @@ -1321,15 +1323,15 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg // empty string anyway. localUser = mapLocalUser(r.SSHUsers, c.info.sshUser) if localUser == "" { - return nil, "", errUserMatch + return nil, "", nil, errUserMatch } } if ok, err := c.anyPrincipalMatches(r.Principals, pubKey); err != nil { - return nil, "", err + return nil, "", nil, err } else if !ok { - return nil, "", errPrincipalMatch + return nil, "", nil, errPrincipalMatch } - return r.Action, localUser, nil + return r.Action, localUser, r.AcceptEnv, nil } func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser string) { diff --git a/ssh/tailssh/tailssh_integration_test.go b/ssh/tailssh/tailssh_integration_test.go index 485c13fdb..1799d3400 100644 --- a/ssh/tailssh/tailssh_integration_test.go +++ b/ssh/tailssh/tailssh_integration_test.go @@ -108,6 +108,7 @@ func TestIntegrationSSH(t *testing.T) { want []string forceV1Behavior bool skip bool + allowSendEnv bool }{ { cmd: "id", @@ -131,6 +132,18 @@ func TestIntegrationSSH(t *testing.T) { skip: os.Getenv("SKIP_FILE_OPS") == "1" || !fallbackToSUAvailable(), forceV1Behavior: false, }, + { + cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`, + want: []string{"working1 working2 working3 unset4"}, + forceV1Behavior: false, + allowSendEnv: true, + }, + { + cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`, + want: []string{"unset1 unset2 unset3 unset4"}, + forceV1Behavior: false, + allowSendEnv: false, + }, } for _, test := range tests { @@ -151,7 +164,13 @@ func TestIntegrationSSH(t *testing.T) { } t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) { - s := testSession(t, test.forceV1Behavior) + sendEnv := map[string]string{ + "GIT_ENV_VAR": "working1", + "EXACT_MATCH": "working2", + "TESTING": "working3", + "NOT_ALLOWED": "working4", + } + s := testSession(t, test.forceV1Behavior, test.allowSendEnv, sendEnv) if shell { err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{ @@ -201,7 +220,7 @@ func TestIntegrationSFTP(t *testing.T) { } wantText := "hello world" - cl := testClient(t, forceV1Behavior) + cl := testClient(t, forceV1Behavior, false) scl, err := sftp.NewClient(cl) if err != nil { t.Fatalf("can't get sftp client: %s", err) @@ -233,7 +252,7 @@ func TestIntegrationSFTP(t *testing.T) { t.Fatalf("unexpected file contents (-got +want):\n%s", diff) } - s := testSessionFor(t, cl) + s := testSessionFor(t, cl, nil) got := s.run(t, "ls -l "+filePath, false) if !strings.Contains(got, "testuser") { t.Fatalf("unexpected file owner user: %s", got) @@ -262,7 +281,7 @@ func TestIntegrationSCP(t *testing.T) { } wantText := "hello world" - cl := testClient(t, forceV1Behavior) + cl := testClient(t, forceV1Behavior, false) scl, err := scp.NewClientBySSH(cl) if err != nil { t.Fatalf("can't get sftp client: %s", err) @@ -291,7 +310,7 @@ func TestIntegrationSCP(t *testing.T) { t.Fatalf("unexpected file contents (-got +want):\n%s", diff) } - s := testSessionFor(t, cl) + s := testSessionFor(t, cl, nil) got := s.run(t, "ls -l "+filePath, false) if !strings.Contains(got, "testuser") { t.Fatalf("unexpected file owner user: %s", got) @@ -349,7 +368,7 @@ func TestSSHAgentForwarding(t *testing.T) { // Run tailscale SSH server and connect to it username := "testuser" - tailscaleAddr := testServer(t, username, false) + tailscaleAddr := testServer(t, username, false, false) tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{ HostKeyCallback: ssh.InsecureIgnoreHostKey(), }) @@ -465,11 +484,11 @@ readLoop: return string(_got) } -func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMethod) *ssh.Client { +func testClient(t *testing.T, forceV1Behavior bool, allowSendEnv bool, authMethods ...ssh.AuthMethod) *ssh.Client { t.Helper() username := "testuser" - addr := testServer(t, username, forceV1Behavior) + addr := testServer(t, username, forceV1Behavior, allowSendEnv) cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ HostKeyCallback: ssh.InsecureIgnoreHostKey(), @@ -483,9 +502,9 @@ func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMetho return cl } -func testServer(t *testing.T, username string, forceV1Behavior bool) string { +func testServer(t *testing.T, username string, forceV1Behavior bool, allowSendEnv bool) string { srv := &server{ - lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior}, + lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior, allowSendEnv: allowSendEnv}, logf: log.Printf, tailscaledPath: os.Getenv("TAILSCALED_PATH"), timeNow: time.Now, @@ -509,16 +528,20 @@ func testServer(t *testing.T, username string, forceV1Behavior bool) string { return l.Addr().String() } -func testSession(t *testing.T, forceV1Behavior bool) *session { - cl := testClient(t, forceV1Behavior) - return testSessionFor(t, cl) +func testSession(t *testing.T, forceV1Behavior bool, allowSendEnv bool, sendEnv map[string]string) *session { + cl := testClient(t, forceV1Behavior, allowSendEnv) + return testSessionFor(t, cl, sendEnv) } -func testSessionFor(t *testing.T, cl *ssh.Client) *session { +func testSessionFor(t *testing.T, cl *ssh.Client, sendEnv map[string]string) *session { s, err := cl.NewSession() if err != nil { t.Fatal(err) } + for k, v := range sendEnv { + s.Setenv(k, v) + } + t.Cleanup(func() { s.Close() }) stdinReader, stdinWriter := io.Pipe() @@ -564,6 +587,7 @@ func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.Pr type testBackend struct { localUser string forceV1Behavior bool + allowSendEnv bool } func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) { @@ -597,6 +621,9 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap { if tb.forceV1Behavior { capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{} } + if tb.allowSendEnv { + capMap[tailcfg.NodeAttrSSHEnvironmentVariables] = struct{}{} + } return &netmap.NetworkMap{ SSHPolicy: &tailcfg.SSHPolicy{ Rules: []*tailcfg.SSHRule{ @@ -604,6 +631,7 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap { Principals: []*tailcfg.SSHPrincipal{{Any: true}}, Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true}, SSHUsers: map[string]string{"*": tb.localUser}, + AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"}, }, }, }, diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index cdeaa4a05..9e4f5ffd3 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -24,6 +24,7 @@ import ( "os/user" "reflect" "runtime" + "slices" "strconv" "strings" "sync" @@ -56,11 +57,12 @@ import ( func TestMatchRule(t *testing.T) { someAction := new(tailcfg.SSHAction) tests := []struct { - name string - rule *tailcfg.SSHRule - ci *sshConnInfo - wantErr error - wantUser string + name string + rule *tailcfg.SSHRule + ci *sshConnInfo + wantErr error + wantUser string + wantAcceptEnv []string }{ { name: "invalid-conn", @@ -153,6 +155,21 @@ func TestMatchRule(t *testing.T) { ci: &sshConnInfo{sshUser: "alice"}, wantUser: "thealice", }, + { + name: "ok-with-accept-env", + rule: &tailcfg.SSHRule{ + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "*": "ubuntu", + "alice": "thealice", + }, + AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, + }, + ci: &sshConnInfo{sshUser: "alice"}, + wantUser: "thealice", + wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, + }, { name: "no-users-for-reject", rule: &tailcfg.SSHRule{ @@ -210,7 +227,7 @@ func TestMatchRule(t *testing.T) { info: tt.ci, srv: &server{logf: t.Logf}, } - got, gotUser, err := c.matchRule(tt.rule, nil) + got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule, nil) if err != tt.wantErr { t.Errorf("err = %v; want %v", err, tt.wantErr) } @@ -220,6 +237,128 @@ func TestMatchRule(t *testing.T) { if err == nil && got == nil { t.Errorf("expected non-nil action on success") } + if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) { + t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv) + } + }) + } +} + +func TestEvalSSHPolicy(t *testing.T) { + someAction := new(tailcfg.SSHAction) + tests := []struct { + name string + policy *tailcfg.SSHPolicy + ci *sshConnInfo + wantMatch bool + wantUser string + wantAcceptEnv []string + }{ + { + name: "multiple-matches-picks-first-match", + policy: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "other": "other1", + }, + }, + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "*": "ubuntu", + "alice": "thealice", + }, + AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, + }, + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "other2": "other3", + }, + }, + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "*": "ubuntu", + "alice": "thealice", + "mark": "markthe", + }, + AcceptEnv: []string{"*"}, + }, + }, + }, + ci: &sshConnInfo{sshUser: "alice"}, + wantUser: "thealice", + wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, + wantMatch: true, + }, + { + name: "no-matches-returns-failure", + policy: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "other": "other1", + }, + }, + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "fedora": "ubuntu", + }, + AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, + }, + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "other2": "other3", + }, + }, + { + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "mark": "markthe", + }, + AcceptEnv: []string{"*"}, + }, + }, + }, + ci: &sshConnInfo{sshUser: "alice"}, + wantUser: "", + wantAcceptEnv: nil, + wantMatch: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &conn{ + info: tt.ci, + srv: &server{logf: t.Logf}, + } + got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy, nil) + if match != tt.wantMatch { + t.Errorf("match = %v; want %v", match, tt.wantMatch) + } + if gotUser != tt.wantUser { + t.Errorf("user = %q; want %q", gotUser, tt.wantUser) + } + if tt.wantMatch == true && got == nil { + t.Errorf("expected non-nil action on success") + } + if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) { + t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv) + } }) } }