diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 8ecd279a8..912438b58 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -86,8 +86,11 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd { // TODO(maisem): this doesn't work with sftp return exec.CommandContext(ss.ctx, name, args...) } + ss.conn.mu.Lock() lu := ss.conn.localUser ci := ss.conn.info + gids := strings.Join(ss.conn.userGroupIDs, ",") + ss.conn.mu.Unlock() remoteUser := ci.uprof.LoginName if len(ci.node.Tags) > 0 { remoteUser = strings.Join(ci.node.Tags, ",") @@ -98,7 +101,7 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd { "ssh", "--uid=" + lu.Uid, "--gid=" + lu.Gid, - "--groups=" + strings.Join(ss.conn.userGroupIDs, ","), + "--groups=" + gids, "--local-user=" + lu.Username, "--remote-user=" + remoteUser, "--remote-ip=" + ci.src.IP().String(), diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 67ad8acb9..3262ba0c5 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -141,6 +141,14 @@ func (srv *server) OnPolicyChange() { srv.mu.Lock() defer srv.mu.Unlock() for c := range srv.activeConns { + c.mu.Lock() + ci := c.info + c.mu.Unlock() + if ci == nil { + // c.info is nil when the connection hasn't been authenticated yet. + // In that case, the connection will be terminated when it is. + continue + } go c.checkStillValid() } } @@ -152,14 +160,14 @@ type conn struct { insecureSkipTailscaleAuth bool // used by tests. - connID string // ID that's shared with control - action0 *tailcfg.SSHAction // first matching action - srv *server - info *sshConnInfo // set by setInfo + connID string // ID that's shared with control + action0 *tailcfg.SSHAction // first matching action + srv *server + + mu sync.Mutex // protects the following localUser *user.User // set by checkAuth userGroupIDs []string // set by checkAuth - - mu sync.Mutex // protects the following + info *sshConnInfo // set by setInfo // idH is the RFC4253 sec8 hash H. It is used to identify the connection, // and is shared among all sessions. It should not be shared outside // process. It is confusingly referred to as SessionID by the gliderlabs/ssh @@ -179,9 +187,13 @@ func (c *conn) logf(format string, args ...any) { // PublicKeyHandler implements ssh.PublicKeyHandler is called by the the // ssh.Server when the client presents a public key. func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error { - if c.info == nil { + c.mu.Lock() + ci := c.info + c.mu.Unlock() + if ci == nil { return gossh.ErrDenied } + if err := c.checkAuth(pubKey); err != nil { // TODO(maisem/bradfitz): surface the error here. c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err) @@ -217,7 +229,7 @@ func (c *conn) NoClientAuthCallback(cm gossh.ConnMetadata) (*gossh.Permissions, func (c *conn) checkAuth(pubKey ssh.PublicKey) error { a, localUser, err := c.evaluatePolicy(pubKey) if err != nil { - if pubKey == nil && c.havePubKeyPolicy(c.info) { + if pubKey == nil && c.havePubKeyPolicy() { return errPubKeyRequired } return fmt.Errorf("%w: %v", gossh.ErrDenied, err) @@ -236,6 +248,8 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error { if err != nil { return err } + c.mu.Lock() + defer c.mu.Unlock() c.userGroupIDs = gids c.localUser = lu return nil @@ -329,7 +343,13 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de // havePubKeyPolicy reports whether any policy rule may provide access by means // of a ssh.PublicKey. -func (c *conn) havePubKeyPolicy(ci *sshConnInfo) bool { +func (c *conn) havePubKeyPolicy() bool { + c.mu.Lock() + ci := c.info + c.mu.Unlock() + if ci == nil { + panic("havePubKeyPolicy called before setInfo") + } // Is there any rule that looks like it'd require a public key for this // sshUser? pol, ok := c.sshPolicy() @@ -414,6 +434,8 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error { if !ok { return fmt.Errorf("unknown Tailscale identity from src %v", ci.src) } + c.mu.Lock() + defer c.mu.Unlock() ci.node = node ci.uprof = &uprof @@ -589,8 +611,10 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { } ss := c.newSSHSession(s) + c.mu.Lock() ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username) ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username) + c.mu.Unlock() ss.run() } @@ -688,7 +712,10 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac func (c *conn) expandDelegateURL(actionURL string) string { nm := c.srv.lb.NetMap() + c.mu.Lock() ci := c.info + lu := c.localUser + c.mu.Unlock() var dstNodeID string if nm != nil { dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID)) @@ -699,7 +726,7 @@ func (c *conn) expandDelegateURL(actionURL string) string { "$DST_NODE_IP", url.QueryEscape(ci.dst.IP().String()), "$DST_NODE_ID", dstNodeID, "$SSH_USER", url.QueryEscape(ci.sshUser), - "$LOCAL_USER", url.QueryEscape(c.localUser.Username), + "$LOCAL_USER", url.QueryEscape(lu.Username), ).Replace(actionURL) } @@ -709,10 +736,12 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string { } var localPart string var loginName string + c.mu.Lock() if c.info.uprof != nil { loginName = c.info.uprof.LoginName localPart, _, _ = strings.Cut(loginName, "@") } + c.mu.Unlock() return strings.NewReplacer( "$LOGINNAME_EMAIL", loginName, "$LOGINNAME_LOCALPART", localPart, @@ -768,6 +797,8 @@ func (c *conn) isStillValid() bool { if !a.Accept && a.HoldAndDelegate == "" { return false } + c.mu.Lock() + defer c.mu.Unlock() return c.localUser.Username == localUser } @@ -944,6 +975,8 @@ func (ss *sshSession) run() { return } ss.conn.startSessionLocked(ss) + lu := ss.conn.localUser + localUser := lu.Username srv.mu.Unlock() defer ss.conn.endSession(ss) @@ -959,8 +992,6 @@ func (ss *sshSession) run() { } logf := ss.logf - lu := ss.conn.localUser - localUser := lu.Username if euid := os.Geteuid(); euid != 0 { if lu.Uid != fmt.Sprint(euid) { @@ -1110,9 +1141,20 @@ var ( errRuleExpired = errors.New("rule expired") errPrincipalMatch = errors.New("principal didn't match") errUserMatch = errors.New("user didn't match") + errInvalidConn = errors.New("invalid connection state") ) func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) { + if c == nil { + return nil, "", errInvalidConn + } + c.mu.Lock() + ci := c.info + c.mu.Unlock() + if ci == nil { + c.logf("invalid connection state") + return nil, "", errInvalidConn + } if r == nil { return nil, "", errNilRule } @@ -1126,7 +1168,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg // For all but Reject rules, SSHUsers is required. // If SSHUsers is nil or empty, mapLocalUser will return an // empty string anyway. - localUser = mapLocalUser(r.SSHUsers, c.info.sshUser) + localUser = mapLocalUser(r.SSHUsers, ci.sshUser) if localUser == "" { return nil, "", errUserMatch } @@ -1175,7 +1217,9 @@ func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey) // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). // This function does not consider PubKeys. func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { + c.mu.Lock() ci := c.info + c.mu.Unlock() if p.Any { return true } diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index ce43be36c..cabb775c0 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -47,13 +47,26 @@ func TestMatchRule(t *testing.T) { wantErr error wantUser string }{ + { + name: "invalid-conn", + rule: &tailcfg.SSHRule{ + Action: someAction, + Principals: []*tailcfg.SSHPrincipal{{Any: true}}, + SSHUsers: map[string]string{ + "*": "ubuntu", + }, + }, + wantErr: errInvalidConn, + }, { name: "nil-rule", + ci: &sshConnInfo{}, rule: nil, wantErr: errNilRule, }, { name: "nil-action", + ci: &sshConnInfo{}, rule: &tailcfg.SSHRule{}, wantErr: errNilAction, }, @@ -180,6 +193,7 @@ func TestMatchRule(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := &conn{ info: tt.ci, + srv: &server{logf: t.Logf}, } got, gotUser, err := c.matchRule(tt.rule, nil) if err != tt.wantErr {