ssh/tailssh: handle not-authenticated-yet connections in matchRule

Also make more fields in conn.info thread safe, there was previously a
data race here.

Fixes #5110

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/5117/head
Maisem Ali 2 years ago committed by Maisem Ali
parent 41e60dae80
commit 480fd6c797

@ -86,8 +86,11 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
// TODO(maisem): this doesn't work with sftp // TODO(maisem): this doesn't work with sftp
return exec.CommandContext(ss.ctx, name, args...) return exec.CommandContext(ss.ctx, name, args...)
} }
ss.conn.mu.Lock()
lu := ss.conn.localUser lu := ss.conn.localUser
ci := ss.conn.info ci := ss.conn.info
gids := strings.Join(ss.conn.userGroupIDs, ",")
ss.conn.mu.Unlock()
remoteUser := ci.uprof.LoginName remoteUser := ci.uprof.LoginName
if len(ci.node.Tags) > 0 { if len(ci.node.Tags) > 0 {
remoteUser = strings.Join(ci.node.Tags, ",") remoteUser = strings.Join(ci.node.Tags, ",")
@ -98,7 +101,7 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
"ssh", "ssh",
"--uid=" + lu.Uid, "--uid=" + lu.Uid,
"--gid=" + lu.Gid, "--gid=" + lu.Gid,
"--groups=" + strings.Join(ss.conn.userGroupIDs, ","), "--groups=" + gids,
"--local-user=" + lu.Username, "--local-user=" + lu.Username,
"--remote-user=" + remoteUser, "--remote-user=" + remoteUser,
"--remote-ip=" + ci.src.IP().String(), "--remote-ip=" + ci.src.IP().String(),

@ -141,6 +141,14 @@ func (srv *server) OnPolicyChange() {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
for c := range srv.activeConns { for c := range srv.activeConns {
c.mu.Lock()
ci := c.info
c.mu.Unlock()
if ci == nil {
// c.info is nil when the connection hasn't been authenticated yet.
// In that case, the connection will be terminated when it is.
continue
}
go c.checkStillValid() go c.checkStillValid()
} }
} }
@ -155,11 +163,11 @@ type conn struct {
connID string // ID that's shared with control connID string // ID that's shared with control
action0 *tailcfg.SSHAction // first matching action action0 *tailcfg.SSHAction // first matching action
srv *server srv *server
info *sshConnInfo // set by setInfo
localUser *user.User // set by checkAuth
userGroupIDs []string // set by checkAuth
mu sync.Mutex // protects the following mu sync.Mutex // protects the following
localUser *user.User // set by checkAuth
userGroupIDs []string // set by checkAuth
info *sshConnInfo // set by setInfo
// idH is the RFC4253 sec8 hash H. It is used to identify the connection, // idH is the RFC4253 sec8 hash H. It is used to identify the connection,
// and is shared among all sessions. It should not be shared outside // and is shared among all sessions. It should not be shared outside
// process. It is confusingly referred to as SessionID by the gliderlabs/ssh // process. It is confusingly referred to as SessionID by the gliderlabs/ssh
@ -179,9 +187,13 @@ func (c *conn) logf(format string, args ...any) {
// PublicKeyHandler implements ssh.PublicKeyHandler is called by the the // PublicKeyHandler implements ssh.PublicKeyHandler is called by the the
// ssh.Server when the client presents a public key. // ssh.Server when the client presents a public key.
func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error { func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error {
if c.info == nil { c.mu.Lock()
ci := c.info
c.mu.Unlock()
if ci == nil {
return gossh.ErrDenied return gossh.ErrDenied
} }
if err := c.checkAuth(pubKey); err != nil { if err := c.checkAuth(pubKey); err != nil {
// TODO(maisem/bradfitz): surface the error here. // TODO(maisem/bradfitz): surface the error here.
c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err) c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
@ -217,7 +229,7 @@ func (c *conn) NoClientAuthCallback(cm gossh.ConnMetadata) (*gossh.Permissions,
func (c *conn) checkAuth(pubKey ssh.PublicKey) error { func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
a, localUser, err := c.evaluatePolicy(pubKey) a, localUser, err := c.evaluatePolicy(pubKey)
if err != nil { if err != nil {
if pubKey == nil && c.havePubKeyPolicy(c.info) { if pubKey == nil && c.havePubKeyPolicy() {
return errPubKeyRequired return errPubKeyRequired
} }
return fmt.Errorf("%w: %v", gossh.ErrDenied, err) return fmt.Errorf("%w: %v", gossh.ErrDenied, err)
@ -236,6 +248,8 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
if err != nil { if err != nil {
return err return err
} }
c.mu.Lock()
defer c.mu.Unlock()
c.userGroupIDs = gids c.userGroupIDs = gids
c.localUser = lu c.localUser = lu
return nil return nil
@ -329,7 +343,13 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de
// havePubKeyPolicy reports whether any policy rule may provide access by means // havePubKeyPolicy reports whether any policy rule may provide access by means
// of a ssh.PublicKey. // of a ssh.PublicKey.
func (c *conn) havePubKeyPolicy(ci *sshConnInfo) bool { func (c *conn) havePubKeyPolicy() bool {
c.mu.Lock()
ci := c.info
c.mu.Unlock()
if ci == nil {
panic("havePubKeyPolicy called before setInfo")
}
// Is there any rule that looks like it'd require a public key for this // Is there any rule that looks like it'd require a public key for this
// sshUser? // sshUser?
pol, ok := c.sshPolicy() pol, ok := c.sshPolicy()
@ -414,6 +434,8 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
if !ok { if !ok {
return fmt.Errorf("unknown Tailscale identity from src %v", ci.src) return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
} }
c.mu.Lock()
defer c.mu.Unlock()
ci.node = node ci.node = node
ci.uprof = &uprof ci.uprof = &uprof
@ -589,8 +611,10 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
} }
ss := c.newSSHSession(s) ss := c.newSSHSession(s)
c.mu.Lock()
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username) ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username)
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username) ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
c.mu.Unlock()
ss.run() ss.run()
} }
@ -688,7 +712,10 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
func (c *conn) expandDelegateURL(actionURL string) string { func (c *conn) expandDelegateURL(actionURL string) string {
nm := c.srv.lb.NetMap() nm := c.srv.lb.NetMap()
c.mu.Lock()
ci := c.info ci := c.info
lu := c.localUser
c.mu.Unlock()
var dstNodeID string var dstNodeID string
if nm != nil { if nm != nil {
dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID)) dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID))
@ -699,7 +726,7 @@ func (c *conn) expandDelegateURL(actionURL string) string {
"$DST_NODE_IP", url.QueryEscape(ci.dst.IP().String()), "$DST_NODE_IP", url.QueryEscape(ci.dst.IP().String()),
"$DST_NODE_ID", dstNodeID, "$DST_NODE_ID", dstNodeID,
"$SSH_USER", url.QueryEscape(ci.sshUser), "$SSH_USER", url.QueryEscape(ci.sshUser),
"$LOCAL_USER", url.QueryEscape(c.localUser.Username), "$LOCAL_USER", url.QueryEscape(lu.Username),
).Replace(actionURL) ).Replace(actionURL)
} }
@ -709,10 +736,12 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
} }
var localPart string var localPart string
var loginName string var loginName string
c.mu.Lock()
if c.info.uprof != nil { if c.info.uprof != nil {
loginName = c.info.uprof.LoginName loginName = c.info.uprof.LoginName
localPart, _, _ = strings.Cut(loginName, "@") localPart, _, _ = strings.Cut(loginName, "@")
} }
c.mu.Unlock()
return strings.NewReplacer( return strings.NewReplacer(
"$LOGINNAME_EMAIL", loginName, "$LOGINNAME_EMAIL", loginName,
"$LOGINNAME_LOCALPART", localPart, "$LOGINNAME_LOCALPART", localPart,
@ -768,6 +797,8 @@ func (c *conn) isStillValid() bool {
if !a.Accept && a.HoldAndDelegate == "" { if !a.Accept && a.HoldAndDelegate == "" {
return false return false
} }
c.mu.Lock()
defer c.mu.Unlock()
return c.localUser.Username == localUser return c.localUser.Username == localUser
} }
@ -944,6 +975,8 @@ func (ss *sshSession) run() {
return return
} }
ss.conn.startSessionLocked(ss) ss.conn.startSessionLocked(ss)
lu := ss.conn.localUser
localUser := lu.Username
srv.mu.Unlock() srv.mu.Unlock()
defer ss.conn.endSession(ss) defer ss.conn.endSession(ss)
@ -959,8 +992,6 @@ func (ss *sshSession) run() {
} }
logf := ss.logf logf := ss.logf
lu := ss.conn.localUser
localUser := lu.Username
if euid := os.Geteuid(); euid != 0 { if euid := os.Geteuid(); euid != 0 {
if lu.Uid != fmt.Sprint(euid) { if lu.Uid != fmt.Sprint(euid) {
@ -1110,9 +1141,20 @@ var (
errRuleExpired = errors.New("rule expired") errRuleExpired = errors.New("rule expired")
errPrincipalMatch = errors.New("principal didn't match") errPrincipalMatch = errors.New("principal didn't match")
errUserMatch = errors.New("user didn't match") errUserMatch = errors.New("user didn't match")
errInvalidConn = errors.New("invalid connection state")
) )
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) { func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) {
if c == nil {
return nil, "", errInvalidConn
}
c.mu.Lock()
ci := c.info
c.mu.Unlock()
if ci == nil {
c.logf("invalid connection state")
return nil, "", errInvalidConn
}
if r == nil { if r == nil {
return nil, "", errNilRule return nil, "", errNilRule
} }
@ -1126,7 +1168,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
// For all but Reject rules, SSHUsers is required. // For all but Reject rules, SSHUsers is required.
// If SSHUsers is nil or empty, mapLocalUser will return an // If SSHUsers is nil or empty, mapLocalUser will return an
// empty string anyway. // empty string anyway.
localUser = mapLocalUser(r.SSHUsers, c.info.sshUser) localUser = mapLocalUser(r.SSHUsers, ci.sshUser)
if localUser == "" { if localUser == "" {
return nil, "", errUserMatch return nil, "", errUserMatch
} }
@ -1175,7 +1217,9 @@ func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey)
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
// This function does not consider PubKeys. // This function does not consider PubKeys.
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
c.mu.Lock()
ci := c.info ci := c.info
c.mu.Unlock()
if p.Any { if p.Any {
return true return true
} }

@ -47,13 +47,26 @@ func TestMatchRule(t *testing.T) {
wantErr error wantErr error
wantUser string wantUser string
}{ }{
{
name: "invalid-conn",
rule: &tailcfg.SSHRule{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
},
},
wantErr: errInvalidConn,
},
{ {
name: "nil-rule", name: "nil-rule",
ci: &sshConnInfo{},
rule: nil, rule: nil,
wantErr: errNilRule, wantErr: errNilRule,
}, },
{ {
name: "nil-action", name: "nil-action",
ci: &sshConnInfo{},
rule: &tailcfg.SSHRule{}, rule: &tailcfg.SSHRule{},
wantErr: errNilAction, wantErr: errNilAction,
}, },
@ -180,6 +193,7 @@ func TestMatchRule(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &conn{ c := &conn{
info: tt.ci, info: tt.ci,
srv: &server{logf: t.Logf},
} }
got, gotUser, err := c.matchRule(tt.rule, nil) got, gotUser, err := c.matchRule(tt.rule, nil)
if err != tt.wantErr { if err != tt.wantErr {

Loading…
Cancel
Save