ssh/tailssh: do the full auth flow during ssh auth

Fixes #5091

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/5883/head
Maisem Ali 2 years ago committed by Maisem Ali
parent c8a3d02989
commit f16b77de5d

@ -1,112 +0,0 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tailssh
import (
"context"
"io"
"sync"
"tailscale.com/tempfork/gliderlabs/ssh"
)
// readResult is a result from a io.Reader.Read call,
// as used by contextReader.
type readResult struct {
buf []byte // ownership passed on chan send
err error
}
// contextReader wraps an io.Reader, providing a ReadContext method
// that can be aborted before yielding bytes. If it's aborted, subsequent
// reads can get those byte(s) later.
type contextReader struct {
r io.Reader
// buffered is leftover data from a previous read call that wasn't entirely
// consumed.
buffered []byte
// readErr is a previous read error that was seen while filling buffered. It
// should be returned to the caller after buffered is consumed.
readErr error
mu sync.Mutex // guards ch only
// ch is non-nil if a goroutine had been started and has a result to be
// read. The goroutine may be either still running or done and has
// send to the channel.
ch chan readResult
}
// HasOutstandingRead reports whether there's an outstanding Read call that's
// either currently blocked in a Read or whose result hasn't been consumed.
func (w *contextReader) HasOutstandingRead() bool {
w.mu.Lock()
defer w.mu.Unlock()
return w.ch != nil
}
func (w *contextReader) setChan(c chan readResult) {
w.mu.Lock()
defer w.mu.Unlock()
w.ch = c
}
// ReadContext is like Read, but takes a context permitting the read to be canceled.
//
// If the context becomes done, the underlying Read call continues and its result
// will be given to the next caller to ReadContext.
func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
n = copy(p, w.buffered)
if n > 0 {
w.buffered = w.buffered[n:]
if len(w.buffered) == 0 {
err = w.readErr
}
return n, err
}
if w.ch == nil {
ch := make(chan readResult, 1)
w.setChan(ch)
go func() {
rbuf := make([]byte, len(p))
n, err := w.r.Read(rbuf)
ch <- readResult{rbuf[:n], err}
}()
}
select {
case <-ctx.Done():
return 0, ctx.Err()
case rr := <-w.ch:
w.setChan(nil)
n = copy(p, rr.buf)
w.buffered = rr.buf[n:]
w.readErr = rr.err
if len(w.buffered) == 0 {
err = rr.err
}
return n, err
}
}
// contextReaderSession implements ssh.Session, wrapping another
// ssh.Session but changing its Read method to use contextReader.
type contextReaderSession struct {
ssh.Session
cr *contextReader
}
func (a contextReaderSession) Read(p []byte) (n int, err error) {
if a.cr.HasOutstandingRead() {
return a.cr.ReadContext(context.Background(), p)
}
return a.Session.Read(p)
}

@ -86,11 +86,9 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd {
// TODO(maisem): this doesn't work with sftp // TODO(maisem): this doesn't work with sftp
return exec.CommandContext(ss.ctx, name, args...) return exec.CommandContext(ss.ctx, name, args...)
} }
ss.conn.mu.Lock()
lu := ss.conn.localUser lu := ss.conn.localUser
ci := ss.conn.info ci := ss.conn.info
gids := strings.Join(ss.conn.userGroupIDs, ",") gids := strings.Join(ss.conn.userGroupIDs, ",")
ss.conn.mu.Unlock()
remoteUser := ci.uprof.LoginName remoteUser := ci.uprof.LoginName
if len(ci.node.Tags) > 0 { if len(ci.node.Tags) > 0 {
remoteUser = strings.Join(ci.node.Tags, ",") remoteUser = strings.Join(ci.node.Tags, ",")

@ -29,7 +29,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
gossh "github.com/tailscale/golang-x-crypto/ssh" gossh "github.com/tailscale/golang-x-crypto/ssh"
@ -87,6 +86,21 @@ func init() {
}) })
} }
// attachSessionToConnIfNotShutdown ensures that srv is not shutdown before
// attaching the session to the conn. This ensures that once Shutdown is called,
// new sessions are not allowed and existing ones are cleaned up.
// It reports whether ss was attached to the conn.
func (srv *server) attachSessionToConnIfNotShutdown(ss *sshSession) bool {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.shutdownCalled {
// Do not start any new sessions.
return false
}
ss.conn.attachSession(ss)
return true
}
func (srv *server) trackActiveConn(c *conn, add bool) { func (srv *server) trackActiveConn(c *conn, add bool) {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
@ -121,12 +135,7 @@ func (srv *server) Shutdown() {
srv.mu.Lock() srv.mu.Lock()
srv.shutdownCalled = true srv.shutdownCalled = true
for c := range srv.activeConns { for c := range srv.activeConns {
for _, s := range c.sessions { c.Close()
s.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Tailscale SSH is shutting down.\r\n"),
context.Canceled,
})
}
} }
srv.mu.Unlock() srv.mu.Unlock()
srv.sessionWaitGroup.Wait() srv.sessionWaitGroup.Wait()
@ -138,10 +147,7 @@ func (srv *server) OnPolicyChange() {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
for c := range srv.activeConns { for c := range srv.activeConns {
c.mu.Lock() if c.info == nil {
ci := c.info
c.mu.Unlock()
if ci == nil {
// c.info is nil when the connection hasn't been authenticated yet. // c.info is nil when the connection hasn't been authenticated yet.
// In that case, the connection will be terminated when it is. // In that case, the connection will be terminated when it is.
continue continue
@ -152,27 +158,52 @@ func (srv *server) OnPolicyChange() {
// conn represents a single SSH connection and its associated // conn represents a single SSH connection and its associated
// ssh.Server. // ssh.Server.
//
// During the lifecycle of a connection, the following are called in order:
// Setup and discover server info
// - ServerConfigCallback
//
// Do the user auth
// - BannerHandler
// - NoClientAuthHandler
// - PublicKeyHandler (only if NoClientAuthHandler returns errPubKeyRequired)
//
// Once auth is done, the conn can be multiplexed with multiple sessions and
// channels concurrently. At which point any of the following can be called
// in any order.
// - c.handleSessionPostSSHAuth
// - c.mayForwardLocalPortTo followed by ssh.DirectTCPIPHandler
type conn struct { type conn struct {
*ssh.Server *ssh.Server
srv *server
insecureSkipTailscaleAuth bool // used by tests. insecureSkipTailscaleAuth bool // used by tests.
connID string // ID that's shared with control
action0 *tailcfg.SSHAction // first matching action
srv *server
mu sync.Mutex // protects the following
localUser *user.User // set by checkAuth
userGroupIDs []string // set by checkAuth
info *sshConnInfo // set by setInfo
// idH is the RFC4253 sec8 hash H. It is used to identify the connection, // idH is the RFC4253 sec8 hash H. It is used to identify the connection,
// and is shared among all sessions. It should not be shared outside // and is shared among all sessions. It should not be shared outside
// process. It is confusingly referred to as SessionID by the gliderlabs/ssh // process. It is confusingly referred to as SessionID by the gliderlabs/ssh
// library. // library.
idH string idH string
pubKey gossh.PublicKey // set by authorizeSession connID string // ID that's shared with control
finalAction *tailcfg.SSHAction // set by authorizeSession
finalActionErr error // set by authorizeSession noPubKeyPolicyAuthError error // set by BannerCallback
action0 *tailcfg.SSHAction // set by doPolicyAuth; first matching action
currentAction *tailcfg.SSHAction // set by doPolicyAuth, updated by resolveNextAction
finalAction *tailcfg.SSHAction // set by doPolicyAuth or resolveNextAction
finalActionErr error // set by doPolicyAuth or resolveNextAction
info *sshConnInfo // set by setInfo
localUser *user.User // set by doPolicyAuth
userGroupIDs []string // set by doPolicyAuth
pubKey gossh.PublicKey // set by doPolicyAuth
// mu protects the following fields.
//
// srv.mu should be acquired prior to mu.
// It is safe to just acquire mu, but unsafe to
// acquire mu and then srv.mu.
mu sync.Mutex // protects the following
sessions []*sshSession sessions []*sshSession
} }
@ -181,49 +212,108 @@ func (c *conn) logf(format string, args ...any) {
c.srv.logf(format, args...) c.srv.logf(format, args...)
} }
// PublicKeyHandler implements ssh.PublicKeyHandler is called by the // isAuthorized returns nil if the connection is authorized to proceed.
// ssh.Server when the client presents a public key. func (c *conn) isAuthorized(ctx ssh.Context) error {
func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error { action := c.currentAction
c.mu.Lock() for {
ci := c.info if action.Accept {
c.mu.Unlock() if c.pubKey != nil {
if ci == nil { metricPublicKeyAccepts.Add(1)
}
return nil
}
if action.Reject || action.HoldAndDelegate == "" {
return gossh.ErrDenied return gossh.ErrDenied
} }
var err error
if err := c.checkAuth(pubKey); err != nil { action, err = c.resolveNextAction(ctx)
// TODO(maisem/bradfitz): surface the error here. if err != nil {
c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
return err return err
} }
c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey))) }
return nil
} }
// errPubKeyRequired is returned by NoClientAuthCallback to make the client // errPubKeyRequired is returned by NoClientAuthCallback to make the client
// resort to public-key auth; not user visible. // resort to public-key auth; not user visible.
var errPubKeyRequired = errors.New("ssh publickey required") var errPubKeyRequired = errors.New("ssh publickey required")
// BannerCallback implements ssh.BannerCallback.
// It is responsible for starting the policy evaluation, and returns
// the first message found in the action chain. It stops the evaluation
// on the first "accept" or "reject" action, and returns the message
// associated with that action (if any).
func (c *conn) BannerCallback(ctx ssh.Context) string {
if err := c.setInfo(ctx); err != nil {
c.logf("failed to get conninfo: %v", err)
return gossh.ErrDenied.Error()
}
if err := c.doPolicyAuth(ctx, nil /* no pub key */); err != nil {
// Stash the error for NoClientAuthCallback to return it.
c.noPubKeyPolicyAuthError = err
return ""
}
action := c.currentAction
for {
if action.Reject || action.Accept || action.Message != "" {
return action.Message
}
if action.HoldAndDelegate == "" {
// Do not send user-visible messages to the user.
// Let the SSH level authentication fail instead.
return ""
}
var err error
action, err = c.resolveNextAction(ctx)
if err != nil {
return ""
}
}
}
// NoClientAuthCallback implements gossh.NoClientAuthCallback and is called by // NoClientAuthCallback implements gossh.NoClientAuthCallback and is called by
// the ssh.Server when the client first connects with the "none" // the ssh.Server when the client first connects with the "none"
// authentication method. // authentication method.
func (c *conn) NoClientAuthCallback(cm gossh.ConnMetadata) (*gossh.Permissions, error) { //
// It is responsible for continuing policy evaluation from BannerCallback (or
// starting it afresh). It returns an error if the policy evaluation fails, or
// if the decision is "reject"
//
// It either returns nil (accept) or errPubKeyRequired or gossh.ErrDenied
// (reject). The errors may be wrapped.
func (c *conn) NoClientAuthCallback(ctx ssh.Context) error {
if c.insecureSkipTailscaleAuth { if c.insecureSkipTailscaleAuth {
return nil, nil return nil
} }
if err := c.setInfo(cm); err != nil { if c.noPubKeyPolicyAuthError != nil {
c.logf("failed to get conninfo: %v", err) return c.noPubKeyPolicyAuthError
return nil, gossh.ErrDenied } else if c.currentAction == nil {
// This should never happen, but if it does, we want to know.
panic("no current action")
} }
return nil, c.checkAuth(nil /* no pub key */) return c.isAuthorized(ctx)
} }
// checkAuth verifies that conn can proceed with the specified (optional) // PublicKeyHandler implements ssh.PublicKeyHandler is called by the
// ssh.Server when the client presents a public key.
func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error {
if err := c.doPolicyAuth(ctx, pubKey); err != nil {
// TODO(maisem/bradfitz): surface the error here.
c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err)
return err
}
if err := c.isAuthorized(ctx); err != nil {
return err
}
c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)))
return nil
}
// doPolicyAuth verifies that conn can proceed with the specified (optional)
// pubKey. It returns nil if the matching policy action is Accept or // pubKey. It returns nil if the matching policy action is Accept or
// HoldAndDelegate. If pubKey is nil, there was no policy match but there is a // HoldAndDelegate. If pubKey is nil, there was no policy match but there is a
// policy that might match a public key it returns errPubKeyRequired. Otherwise, // policy that might match a public key it returns errPubKeyRequired. Otherwise,
// it returns gossh.ErrDenied possibly wrapped in gossh.WithBannerError. // it returns gossh.ErrDenied possibly wrapped in gossh.WithBannerError.
func (c *conn) checkAuth(pubKey ssh.PublicKey) error { func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
a, localUser, err := c.evaluatePolicy(pubKey) a, localUser, err := c.evaluatePolicy(pubKey)
if err != nil { if err != nil {
if pubKey == nil && c.havePubKeyPolicy() { if pubKey == nil && c.havePubKeyPolicy() {
@ -232,7 +322,12 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
return fmt.Errorf("%w: %v", gossh.ErrDenied, err) return fmt.Errorf("%w: %v", gossh.ErrDenied, err)
} }
c.action0 = a c.action0 = a
c.currentAction = a
c.pubKey = pubKey
if a.Accept || a.HoldAndDelegate != "" { if a.Accept || a.HoldAndDelegate != "" {
if a.Accept {
c.finalAction = a
}
lu, err := user.Lookup(localUser) lu, err := user.Lookup(localUser)
if err != nil { if err != nil {
c.logf("failed to lookup %v: %v", localUser, err) c.logf("failed to lookup %v: %v", localUser, err)
@ -245,13 +340,12 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error {
if err != nil { if err != nil {
return err return err
} }
c.mu.Lock()
defer c.mu.Unlock()
c.userGroupIDs = gids c.userGroupIDs = gids
c.localUser = lu c.localUser = lu
return nil return nil
} }
if a.Reject { if a.Reject {
c.finalAction = a
err := gossh.ErrDenied err := gossh.ErrDenied
if a.Message != "" { if a.Message != "" {
err = gossh.WithBannerError{ err = gossh.WithBannerError{
@ -271,7 +365,6 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig {
// OpenSSH presents this on failure as `Permission denied (tailscale).` // OpenSSH presents this on failure as `Permission denied (tailscale).`
ImplicitAuthMethod: "tailscale", ImplicitAuthMethod: "tailscale",
NoClientAuth: true, // required for the NoClientAuthCallback to run NoClientAuth: true, // required for the NoClientAuthCallback to run
NoClientAuthCallback: c.NoClientAuthCallback,
} }
} }
@ -290,22 +383,24 @@ func (srv *server) newConn() (*conn, error) {
c.connID = fmt.Sprintf("ssh-conn-%s-%02x", now.UTC().Format("20060102T150405"), randBytes(5)) c.connID = fmt.Sprintf("ssh-conn-%s-%02x", now.UTC().Format("20060102T150405"), randBytes(5))
c.Server = &ssh.Server{ c.Server = &ssh.Server{
Version: "Tailscale", Version: "Tailscale",
ServerConfigCallback: c.ServerConfig,
BannerHandler: c.BannerCallback,
NoClientAuthHandler: c.NoClientAuthCallback,
PublicKeyHandler: c.PublicKeyHandler,
Handler: c.handleSessionPostSSHAuth, Handler: c.handleSessionPostSSHAuth,
RequestHandlers: map[string]ssh.RequestHandler{}, LocalPortForwardingCallback: c.mayForwardLocalPortTo,
SubsystemHandlers: map[string]ssh.SubsystemHandler{ SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": c.handleSessionPostSSHAuth, "sftp": c.handleSessionPostSSHAuth,
}, },
// Note: the direct-tcpip channel handler and LocalPortForwardingCallback // Note: the direct-tcpip channel handler and LocalPortForwardingCallback
// only adds support for forwarding ports from the local machine. // only adds support for forwarding ports from the local machine.
// TODO(maisem/bradfitz): add remote port forwarding support. // TODO(maisem/bradfitz): add remote port forwarding support.
ChannelHandlers: map[string]ssh.ChannelHandler{ ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": ssh.DirectTCPIPHandler, "direct-tcpip": ssh.DirectTCPIPHandler,
}, },
LocalPortForwardingCallback: c.mayForwardLocalPortTo, RequestHandlers: map[string]ssh.RequestHandler{},
PublicKeyHandler: c.PublicKeyHandler,
ServerConfigCallback: c.ServerConfig,
} }
ss := c.Server ss := c.Server
for k, v := range ssh.DefaultRequestHandlers { for k, v := range ssh.DefaultRequestHandlers {
@ -341,10 +436,7 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de
// havePubKeyPolicy reports whether any policy rule may provide access by means // havePubKeyPolicy reports whether any policy rule may provide access by means
// of a ssh.PublicKey. // of a ssh.PublicKey.
func (c *conn) havePubKeyPolicy() bool { func (c *conn) havePubKeyPolicy() bool {
c.mu.Lock() if c.info == nil {
ci := c.info
c.mu.Unlock()
if ci == nil {
panic("havePubKeyPolicy called before setInfo") panic("havePubKeyPolicy called before setInfo")
} }
// Is there any rule that looks like it'd require a public key for this // Is there any rule that looks like it'd require a public key for this
@ -357,7 +449,7 @@ func (c *conn) havePubKeyPolicy() bool {
if c.ruleExpired(r) { if c.ruleExpired(r) {
continue continue
} }
if mapLocalUser(r.SSHUsers, ci.sshUser) == "" { if mapLocalUser(r.SSHUsers, c.info.sshUser) == "" {
continue continue
} }
for _, p := range r.Principals { for _, p := range r.Principals {
@ -416,11 +508,11 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) {
// connInfo returns a populated sshConnInfo from the provided arguments, // connInfo returns a populated sshConnInfo from the provided arguments,
// validating only that they represent a known Tailscale identity. // validating only that they represent a known Tailscale identity.
func (c *conn) setInfo(cm gossh.ConnMetadata) error { func (c *conn) setInfo(ctx ssh.Context) error {
ci := &sshConnInfo{ ci := &sshConnInfo{
sshUser: cm.User(), sshUser: ctx.User(),
src: toIPPort(cm.RemoteAddr()), src: toIPPort(ctx.RemoteAddr()),
dst: toIPPort(cm.LocalAddr()), dst: toIPPort(ctx.LocalAddr()),
} }
if !tsaddr.IsTailscaleIP(ci.dst.Addr()) { if !tsaddr.IsTailscaleIP(ci.dst.Addr()) {
return fmt.Errorf("tailssh: rejecting non-Tailscale local address %v", ci.dst) return fmt.Errorf("tailssh: rejecting non-Tailscale local address %v", ci.dst)
@ -432,11 +524,10 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
if !ok { if !ok {
return fmt.Errorf("unknown Tailscale identity from src %v", ci.src) return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
} }
c.mu.Lock()
defer c.mu.Unlock()
ci.node = node ci.node = node
ci.uprof = &uprof ci.uprof = &uprof
c.idH = ctx.SessionID()
c.info = ci c.info = ci
c.logf("handling conn: %v", ci.String()) c.logf("handling conn: %v", ci.String())
return nil return nil
@ -554,50 +645,10 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
return lines, err return lines, err
} }
func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
idH := s.Context().(ssh.Context).SessionID()
if c.idH == "" {
c.idH = idH
} else if c.idH != idH {
c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH)
s.Exit(1)
return nil, false
}
cr := &contextReader{r: s}
action, err := c.resolveTerminalActionLocked(s, cr)
if err != nil {
c.logf("resolveTerminalAction: %v", err)
io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
s.Exit(1)
return nil, false
}
if action.Reject || !action.Accept {
c.logf("access denied for %v", c.info.uprof.LoginName)
s.Exit(1)
return nil, false
}
return cr, true
}
// handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication, // handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication,
// but not necessarily before all the Tailscale-level extra verification has // but not necessarily before all the Tailscale-level extra verification has
// completed. It also handles SFTP requests. // completed. It also handles SFTP requests.
func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
// Now that we have passed the SSH-level authentication, we can start the
// Tailscale-level extra verification. This means that we are going to
// evaluate the policy provided by control against the incoming SSH session.
cr, ok := c.authorizeSession(s)
if !ok {
return
}
if cr.HasOutstandingRead() {
// There was some buffered input while we were waiting for the policy
// decision.
s = contextReaderSession{s, cr}
}
// Do this check after auth, but before starting the session. // Do this check after auth, but before starting the session.
switch s.Subsystem() { switch s.Subsystem() {
case "sftp", "": case "sftp", "":
@ -609,45 +660,35 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
} }
ss := c.newSSHSession(s) ss := c.newSSHSession(s)
c.mu.Lock()
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username) ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username)
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username) ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
c.mu.Unlock()
ss.run() ss.run()
} }
// resolveTerminalActionLocked either returns action0 (if it's Accept or Reject) or // resolveNextAction starts at c.currentAction and makes it way through the
// else loops, fetching new SSHActions from the control plane. // action chain one step at a time. An action without a HoldAndDelegate is
// // considered the final action. Once a final action is reached, this function
// Any action with a Message in the chain will be printed to s. // will keep returning that action. It updates c.currentAction to the next
// // action in the chain. When the final action is reached, it also sets
// The returned SSHAction will be either Reject or Accept. // c.finalAction to the final action.
// func (c *conn) resolveNextAction(sctx ssh.Context) (action *tailcfg.SSHAction, err error) {
// c.mu must be held.
func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) {
if c.finalAction != nil || c.finalActionErr != nil { if c.finalAction != nil || c.finalActionErr != nil {
return c.finalAction, c.finalActionErr return c.finalAction, c.finalActionErr
} }
if s.PublicKey() != nil {
metricPublicKeyConnections.Add(1)
}
defer func() { defer func() {
if action != nil {
c.currentAction = action
if action.Accept || action.Reject {
c.finalAction = action c.finalAction = action
}
}
if err != nil {
c.finalActionErr = err c.finalActionErr = err
c.pubKey = s.PublicKey()
if c.pubKey != nil && action.Accept {
metricPublicKeyAccepts.Add(1)
} }
}() }()
action = c.action0
var awaitReadOnce sync.Once // to start Reads on cr
var sawInterrupt atomic.Bool
var wg sync.WaitGroup
defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit
ctx, cancel := context.WithCancel(s.Context()) ctx, cancel := context.WithCancel(sctx)
defer cancel() defer cancel()
// Loop processing/fetching Actions until one reaches a // Loop processing/fetching Actions until one reaches a
@ -656,10 +697,7 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
// done (client disconnect) or its 30 minute timeout passes. // done (client disconnect) or its 30 minute timeout passes.
// (Which is a long time for somebody to see login // (Which is a long time for somebody to see login
// instructions and go to a URL to do something.) // instructions and go to a URL to do something.)
for { action = c.currentAction
if action.Message != "" {
io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1))
}
if action.Accept || action.Reject { if action.Accept || action.Reject {
if action.Reject { if action.Reject {
metricTerminalReject.Add(1) metricTerminalReject.Add(1)
@ -674,38 +712,13 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate") return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate")
} }
metricHolds.Add(1) metricHolds.Add(1)
awaitReadOnce.Do(func() {
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 1)
for {
n, err := cr.ReadContext(ctx, buf)
if err != nil {
return
}
if n > 0 && buf[0] == 0x03 { // Ctrl-C
sawInterrupt.Store(true)
s.Stderr().Write([]byte("Canceled.\r\n"))
s.Exit(1)
return
}
}
}()
})
url = c.expandDelegateURLLocked(url) url = c.expandDelegateURLLocked(url)
var err error nextAction, err := c.fetchSSHAction(ctx, url)
action, err = c.fetchSSHAction(ctx, url)
if err != nil { if err != nil {
if sawInterrupt.Load() {
metricTerminalInterrupt.Add(1)
return nil, fmt.Errorf("aborted by user")
} else {
metricTerminalFetchError.Add(1) metricTerminalFetchError.Add(1)
}
return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err) return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err)
} }
} return nextAction, nil
} }
func (c *conn) expandDelegateURLLocked(actionURL string) string { func (c *conn) expandDelegateURLLocked(actionURL string) string {
@ -732,12 +745,10 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
} }
var localPart string var localPart string
var loginName string var loginName string
c.mu.Lock()
if c.info.uprof != nil { if c.info.uprof != nil {
loginName = c.info.uprof.LoginName loginName = c.info.uprof.LoginName
localPart, _, _ = strings.Cut(loginName, "@") localPart, _, _ = strings.Cut(loginName, "@")
} }
c.mu.Unlock()
return strings.NewReplacer( return strings.NewReplacer(
"$LOGINNAME_EMAIL", loginName, "$LOGINNAME_EMAIL", loginName,
"$LOGINNAME_LOCALPART", localPart, "$LOGINNAME_LOCALPART", localPart,
@ -793,8 +804,6 @@ func (c *conn) isStillValid() bool {
if !a.Accept && a.HoldAndDelegate == "" { if !a.Accept && a.HoldAndDelegate == "" {
return false return false
} }
c.mu.Lock()
defer c.mu.Unlock()
return c.localUser.Username == localUser return c.localUser.Username == localUser
} }
@ -806,6 +815,8 @@ func (c *conn) checkStillValid() {
} }
metricPolicyChangeKick.Add(1) metricPolicyChangeKick.Add(1)
c.logf("session no longer valid per new SSH policy; closing") c.logf("session no longer valid per new SSH policy; closing")
c.mu.Lock()
defer c.mu.Unlock()
for _, s := range c.sessions { for _, s := range c.sessions {
s.ctx.CloseWithError(userVisibleError{ s.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Access revoked.\r\n"), fmt.Sprintf("Access revoked.\r\n"),
@ -876,21 +887,22 @@ func (ss *sshSession) killProcessOnContextDone() {
}) })
} }
// startSessionLocked registers ss as an active session. // attachSession registers ss as an active session.
// It must be called with srv.mu held. func (c *conn) attachSession(ss *sshSession) {
func (c *conn) startSessionLocked(ss *sshSession) {
c.srv.sessionWaitGroup.Add(1) c.srv.sessionWaitGroup.Add(1)
if ss.sharedID == "" { if ss.sharedID == "" {
panic("empty sharedID") panic("empty sharedID")
} }
c.mu.Lock()
defer c.mu.Unlock()
c.sessions = append(c.sessions, ss) c.sessions = append(c.sessions, ss)
} }
// endSession unregisters s from the list of active sessions. // detachSession unregisters s from the list of active sessions.
func (c *conn) endSession(ss *sshSession) { func (c *conn) detachSession(ss *sshSession) {
defer c.srv.sessionWaitGroup.Done() defer c.srv.sessionWaitGroup.Done()
c.srv.mu.Lock() c.mu.Lock()
defer c.srv.mu.Unlock() defer c.mu.Unlock()
for i, s := range c.sessions { for i, s := range c.sessions {
if s == ss { if s == ss {
c.sessions = append(c.sessions[:i], c.sessions[i+1:]...) c.sessions = append(c.sessions[:i], c.sessions[i+1:]...)
@ -960,22 +972,16 @@ func (ss *sshSession) run() {
metricActiveSessions.Add(1) metricActiveSessions.Add(1)
defer metricActiveSessions.Add(-1) defer metricActiveSessions.Add(-1)
defer ss.ctx.CloseWithError(errSessionDone) defer ss.ctx.CloseWithError(errSessionDone)
srv := ss.conn.srv
srv.mu.Lock() if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
if srv.shutdownCalled {
srv.mu.Unlock()
// Do not start any new sessions.
fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n") fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
ss.Exit(1) ss.Exit(1)
return return
} }
ss.conn.startSessionLocked(ss) defer ss.conn.detachSession(ss)
lu := ss.conn.localUser
localUser := lu.Username
srv.mu.Unlock()
defer ss.conn.endSession(ss) lu := ss.conn.localUser
logf := ss.logf
if ss.conn.finalAction.SessionDuration != 0 { if ss.conn.finalAction.SessionDuration != 0 {
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() { t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
@ -987,11 +993,9 @@ func (ss *sshSession) run() {
defer t.Stop() defer t.Stop()
} }
logf := ss.logf
if euid := os.Geteuid(); euid != 0 { if euid := os.Geteuid(); euid != 0 {
if lu.Uid != fmt.Sprint(euid) { if lu.Uid != fmt.Sprint(euid) {
ss.logf("can't switch to user %q from process euid %v", localUser, euid) ss.logf("can't switch to user %q from process euid %v", lu.Username, euid)
fmt.Fprintf(ss, "can't switch user\r\n") fmt.Fprintf(ss, "can't switch user\r\n")
ss.Exit(1) ss.Exit(1)
return return
@ -1141,10 +1145,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
if c == nil { if c == nil {
return nil, "", errInvalidConn return nil, "", errInvalidConn
} }
c.mu.Lock() if c.info == nil {
ci := c.info
c.mu.Unlock()
if ci == nil {
c.logf("invalid connection state") c.logf("invalid connection state")
return nil, "", errInvalidConn return nil, "", errInvalidConn
} }
@ -1161,7 +1162,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
// For all but Reject rules, SSHUsers is required. // For all but Reject rules, SSHUsers is required.
// If SSHUsers is nil or empty, mapLocalUser will return an // If SSHUsers is nil or empty, mapLocalUser will return an
// empty string anyway. // empty string anyway.
localUser = mapLocalUser(r.SSHUsers, ci.sshUser) localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
if localUser == "" { if localUser == "" {
return nil, "", errUserMatch return nil, "", errUserMatch
} }
@ -1210,9 +1211,7 @@ func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey)
// that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any).
// This function does not consider PubKeys. // This function does not consider PubKeys.
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool {
c.mu.Lock()
ci := c.info ci := c.info
c.mu.Unlock()
if p.Any { if p.Any {
return true return true
} }

@ -39,8 +39,10 @@ type Server struct {
Version string // server version to be sent before the initial handshake Version string // server version to be sent before the initial handshake
KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler
BannerHandler BannerHandler
PasswordHandler PasswordHandler // password authentication handler PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler PublicKeyHandler PublicKeyHandler // public key authentication handler
NoClientAuthHandler NoClientAuthHandler // no client authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
@ -160,6 +162,21 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {
return ctx.Permissions().Permissions, nil return ctx.Permissions().Permissions, nil
} }
} }
if srv.NoClientAuthHandler != nil {
config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) {
applyConnMetadata(ctx, conn)
if err := srv.NoClientAuthHandler(ctx); err != nil {
return ctx.Permissions().Permissions, err
}
return ctx.Permissions().Permissions, nil
}
}
if srv.BannerHandler != nil {
config.BannerCallback = func(conn gossh.ConnMetadata) string {
applyConnMetadata(ctx, conn)
return srv.BannerHandler(ctx)
}
}
return config return config
} }

@ -38,6 +38,10 @@ type Handler func(Session)
// PublicKeyHandler is a callback for performing public key authentication. // PublicKeyHandler is a callback for performing public key authentication.
type PublicKeyHandler func(ctx Context, key PublicKey) error type PublicKeyHandler func(ctx Context, key PublicKey) error
type NoClientAuthHandler func(ctx Context) error
type BannerHandler func(ctx Context) string
// PasswordHandler is a callback for performing password authentication. // PasswordHandler is a callback for performing password authentication.
type PasswordHandler func(ctx Context, password string) bool type PasswordHandler func(ctx Context, password string) bool

Loading…
Cancel
Save