diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 2ffd78462..ba36cde1e 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -62,11 +62,10 @@ type server struct { sessionWaitGroup sync.WaitGroup // mu protects the following - mu sync.Mutex - activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session - activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session - fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL - shutdownCalled bool + mu sync.Mutex + activeConns map[*conn]bool // set; value is always true + fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL + shutdownCalled bool } func (srv *server) now() time.Time { @@ -91,14 +90,28 @@ func init() { }) } +func (srv *server) trackActiveConn(c *conn, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + if add { + mak.Set(&srv.activeConns, c, true) + return + } + delete(srv.activeConns, c) +} + // HandleSSHConn handles a Tailscale SSH connection from c. -func (srv *server) HandleSSHConn(c net.Conn) error { +// This is the entry point for all SSH connections. +// When this returns, the connection is closed. +func (srv *server) HandleSSHConn(nc net.Conn) error { metricIncomingConnections.Add(1) - ss, err := srv.newConn() + c, err := srv.newConn() if err != nil { return err } - ss.HandleConn(c) + srv.trackActiveConn(c, true) // add + defer srv.trackActiveConn(c, false) // remove + c.HandleConn(nc) // Return nil to signal to netstack's interception that it doesn't need to // log. If ss.HandleConn had problems, it can log itself (ideally on an @@ -110,11 +123,13 @@ func (srv *server) HandleSSHConn(c net.Conn) error { func (srv *server) Shutdown() { srv.mu.Lock() srv.shutdownCalled = true - for _, s := range srv.activeSessionByH { - s.ctx.CloseWithError(userVisibleError{ - fmt.Sprintf("Tailscale SSH is shutting down.\r\n"), - context.Canceled, - }) + for c := range srv.activeConns { + for _, s := range c.sessions { + s.ctx.CloseWithError(userVisibleError{ + fmt.Sprintf("Tailscale SSH is shutting down.\r\n"), + context.Canceled, + }) + } } srv.mu.Unlock() srv.sessionWaitGroup.Wait() @@ -125,8 +140,8 @@ func (srv *server) Shutdown() { func (srv *server) OnPolicyChange() { srv.mu.Lock() defer srv.mu.Unlock() - for _, s := range srv.activeSessionByH { - go s.checkStillValid() + for c := range srv.activeConns { + go c.checkStillValid() } } @@ -135,25 +150,33 @@ func (srv *server) OnPolicyChange() { type conn struct { *ssh.Server + insecureSkipTailscaleAuth bool // used by tests. + // now is the time to consider the present moment for the // purposes of rule evaluation. now time.Time + connID string // ID that's shared with control action0 *tailcfg.SSHAction // first matching action srv *server info *sshConnInfo // set by setInfo localUser *user.User // set by checkAuth userGroupIDs []string // set by checkAuth - insecureSkipTailscaleAuth bool // used by tests. + mu sync.Mutex // protects the following + // idH is the RFC4253 sec8 hash H. It is used to identify the connection, + // and is shared among all sessions. It should not be shared outside + // process. It is confusingly referred to as SessionID by the gliderlabs/ssh + // library. + idH string + pubKey gossh.PublicKey // set by authorizeSession + finalAction *tailcfg.SSHAction // set by authorizeSession + finalActionErr error // set by authorizeSession + sessions []*sshSession } func (c *conn) logf(format string, args ...any) { - if c.info == nil { - c.srv.logf(format, args...) - return - } - format = fmt.Sprintf("%v: %v", c.info.String(), format) + format = fmt.Sprintf("%v: %v", c.connID, format) c.srv.logf(format, args...) } @@ -247,21 +270,22 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { func (srv *server) newConn() (*conn, error) { srv.mu.Lock() - shutdownCalled := srv.shutdownCalled - srv.mu.Unlock() - if shutdownCalled { + if srv.shutdownCalled { + srv.mu.Unlock() // Stop accepting new connections. // Connections in the auth phase are handled in handleConnPostSSHAuth. // Existing sessions are terminated by Shutdown. return nil, gossh.ErrDenied } + srv.mu.Unlock() c := &conn{srv: srv, now: srv.now()} + c.connID = fmt.Sprintf("conn-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5)) c.Server = &ssh.Server{ Version: "Tailscale", - Handler: c.handleConnPostSSHAuth, + Handler: c.handleSessionPostSSHAuth, RequestHandlers: map[string]ssh.RequestHandler{}, SubsystemHandlers: map[string]ssh.SubsystemHandler{ - "sftp": c.handleConnPostSSHAuth, + "sftp": c.handleSessionPostSSHAuth, }, // Note: the direct-tcpip channel handler and LocalPortForwardingCallback @@ -270,7 +294,7 @@ func (srv *server) newConn() (*conn, error) { ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": ssh.DirectTCPIPHandler, }, - LocalPortForwardingCallback: srv.mayForwardLocalPortTo, + LocalPortForwardingCallback: c.mayForwardLocalPortTo, PublicKeyHandler: c.PublicKeyHandler, ServerConfigCallback: c.ServerConfig, @@ -298,16 +322,12 @@ func (srv *server) newConn() (*conn, error) { // mayForwardLocalPortTo reports whether the ctx should be allowed to port forward // to the specified host and port. // TODO(bradfitz/maisem): should we have more checks on host/port? -func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { - ss, ok := srv.getSessionForContext(ctx) - if !ok { - return false - } - if !ss.action.AllowLocalPortForwarding { - return false +func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + if c.finalAction != nil && c.finalAction.AllowLocalPortForwarding { + metricLocalPortForward.Add(1) + return true } - metricLocalPortForward.Add(1) - return true + return false } // havePubKeyPolicy reports whether any policy rule may provide access by means @@ -401,6 +421,7 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error { ci.uprof = &uprof c.info = ci + c.logf("handling conn: %v", ci.String()) return nil } @@ -516,32 +537,47 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { return lines, err } -// handleConnPostSSHAuth runs an SSH session after the SSH-level authentication, -// but not necessarily before all the Tailscale-level extra verification has -// completed. It also handles SFTP requests. -func (c *conn) handleConnPostSSHAuth(s ssh.Session) { - if s.PublicKey() != nil { - metricPublicKeyConnections.Add(1) +func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + idH := s.Context().(ssh.Context).SessionID() + if c.idH == "" { + c.idH = idH + } else if c.idH != idH { + c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH) + s.Exit(1) + return nil, false } - sshUser := s.User() cr := &contextReader{r: s} - action, err := c.resolveTerminalAction(s, cr) + action, err := c.resolveTerminalActionLocked(s, cr) if err != nil { c.logf("resolveTerminalAction: %v", err) io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n") s.Exit(1) - return + return nil, false } if action.Reject || !action.Accept { c.logf("access denied for %v", c.info.uprof.LoginName) s.Exit(1) - return - } - if s.PublicKey() != nil { - metricPublicKeyAccepts.Add(1) + return nil, false } + return cr, true +} +// handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication, +// but not necessarily before all the Tailscale-level extra verification has +// completed. It also handles SFTP requests. +func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { + // Now that we have passed the SSH-level authentication, we can start the + // Tailscale-level extra verification. This means that we are going to + // evaluate the policy provided by control against the incoming SSH session. + cr, ok := c.authorizeSession(s) + if !ok { + return + } if cr.HasOutstandingRead() { + // There was some buffered input while we were waiting for the policy + // decision. s = contextReaderSesssion{s, cr} } @@ -555,20 +591,37 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) { return } - ss := c.newSSHSession(s, action) - ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), sshUser) - ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, sshUser) + ss := c.newSSHSession(s) + ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username) + ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Name) ss.run() } -// resolveTerminalAction either returns action0 (if it's Accept or Reject) or +// resolveTerminalActionLocked either returns action0 (if it's Accept or Reject) or // else loops, fetching new SSHActions from the control plane. // // Any action with a Message in the chain will be printed to s. // // The returned SSHAction will be either Reject or Accept. -func (c *conn) resolveTerminalAction(s ssh.Session, cr *contextReader) (*tailcfg.SSHAction, error) { - action := c.action0 +// +// c.mu must be held. +func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) { + if c.finalAction != nil || c.finalActionErr != nil { + return c.finalAction, c.finalActionErr + } + + if s.PublicKey() != nil { + metricPublicKeyConnections.Add(1) + } + defer func() { + c.finalAction = action + c.finalActionErr = err + c.pubKey = s.PublicKey() + if c.pubKey != nil && action.Accept { + metricPublicKeyAccepts.Add(1) + } + }() + action = c.action0 var awaitReadOnce sync.Once // to start Reads on cr var sawInterrupt syncs.AtomicBool @@ -672,13 +725,11 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string { // sshSession is an accepted Tailscale SSH session. type sshSession struct { ssh.Session - idH string // the RFC4253 sec8 hash H; don't share outside process sharedID string // ID that's shared with control logf logger.Logf ctx *sshContext // implements context.Context conn *conn - action *tailcfg.SSHAction agentListener net.Listener // non-nil if agent-forwarding requested+allowed // initialized by launchProcess: @@ -699,22 +750,21 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) { } } -func (c *conn) newSSHSession(s ssh.Session, action *tailcfg.SSHAction) *sshSession { - sharedID := fmt.Sprintf("%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5)) +func (c *conn) newSSHSession(s ssh.Session) *sshSession { + sharedID := fmt.Sprintf("sess-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5)) c.logf("starting session: %v", sharedID) return &sshSession{ Session: s, - idH: s.Context().(ssh.Context).SessionID(), sharedID: sharedID, ctx: newSSHContext(), conn: c, logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), - action: action, } } -func (c *conn) isStillValid(pubKey ssh.PublicKey) bool { - a, localUser, err := c.evaluatePolicy(pubKey) +// isStillValid reports whether the conn is still valid. +func (c *conn) isStillValid() bool { + a, localUser, err := c.evaluatePolicy(c.pubKey) if err != nil { return false } @@ -724,18 +774,20 @@ func (c *conn) isStillValid(pubKey ssh.PublicKey) bool { return c.localUser.Username == localUser } -// checkStillValid checks that the session is still valid per the latest SSHPolicy. -// If not, it terminates the session. -func (ss *sshSession) checkStillValid() { - if ss.conn.isStillValid(ss.PublicKey()) { +// checkStillValid checks that the conn is still valid per the latest SSHPolicy. +// If not, it terminates all sessions associated with the conn. +func (c *conn) checkStillValid() { + if c.isStillValid() { return } metricPolicyChangeKick.Add(1) - ss.logf("session no longer valid per new SSH policy; closing") - ss.ctx.CloseWithError(userVisibleError{ - fmt.Sprintf("Access revoked.\r\n"), - context.Canceled, - }) + c.logf("session no longer valid per new SSH policy; closing") + for _, s := range c.sessions { + s.ctx.CloseWithError(userVisibleError{ + fmt.Sprintf("Access revoked.\r\n"), + context.Canceled, + }) + } } func (c *conn) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) { @@ -798,41 +850,27 @@ func (ss *sshSession) killProcessOnContextDone() { }) } -// sessionAction returns the SSHAction associated with the session. -func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - ss, ok = srv.activeSessionByH[sctx.SessionID()] - return -} - // startSessionLocked registers ss as an active session. // It must be called with srv.mu held. -func (srv *server) startSessionLocked(ss *sshSession) { - srv.sessionWaitGroup.Add(1) - if ss.idH == "" { - panic("empty idH") - } +func (c *conn) startSessionLocked(ss *sshSession) { + c.srv.sessionWaitGroup.Add(1) if ss.sharedID == "" { panic("empty sharedID") } - if _, dup := srv.activeSessionByH[ss.idH]; dup { - panic("dup idH") - } - if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup { - panic("dup sharedID") - } - mak.Set(&srv.activeSessionByH, ss.idH, ss) - mak.Set(&srv.activeSessionBySharedID, ss.sharedID, ss) + c.sessions = append(c.sessions, ss) } // endSession unregisters s from the list of active sessions. -func (srv *server) endSession(ss *sshSession) { - defer srv.sessionWaitGroup.Done() - srv.mu.Lock() - defer srv.mu.Unlock() - delete(srv.activeSessionByH, ss.idH) - delete(srv.activeSessionBySharedID, ss.sharedID) +func (c *conn) endSession(ss *sshSession) { + defer c.srv.sessionWaitGroup.Done() + c.srv.mu.Lock() + defer c.srv.mu.Unlock() + for i, s := range c.sessions { + if s == ss { + c.sessions = append(c.sessions[:i], c.sessions[i+1:]...) + break + } + } } var errSessionDone = errors.New("session is done") @@ -841,7 +879,7 @@ var errSessionDone = errors.New("session is done") // forwards agent connections between the listener and the ssh.Session. // On success, it assigns ss.agentListener. func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) error { - if !ssh.AgentRequested(ss) || !ss.action.AllowAgentForwarding { + if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding { return nil } ss.logf("ssh: agent forwarding requested") @@ -906,15 +944,15 @@ func (ss *sshSession) run() { ss.Exit(1) return } - srv.startSessionLocked(ss) + ss.conn.startSessionLocked(ss) srv.mu.Unlock() - defer srv.endSession(ss) + defer ss.conn.endSession(ss) - if ss.action.SessionDuration != 0 { - t := time.AfterFunc(ss.action.SessionDuration, func() { + if ss.conn.finalAction.SessionDuration != 0 { + t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() { ss.ctx.CloseWithError(userVisibleError{ - fmt.Sprintf("Session timeout of %v elapsed.", ss.action.SessionDuration), + fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration), context.DeadlineExceeded, }) }) diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 93ed67f4e..78bda83e2 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -238,9 +238,10 @@ func TestSSH(t *testing.T) { node: &tailcfg.Node{}, uprof: &tailcfg.UserProfile{}, } + sc.finalAction = &tailcfg.SSHAction{Accept: true} sc.Handler = func(s ssh.Session) { - sc.newSSHSession(s, &tailcfg.SSHAction{Accept: true}).run() + sc.newSSHSession(s).run() } ln, err := net.Listen("tcp4", "127.0.0.1:0")