ssh/tailssh: close sessions on policy change if no longer allowed

Updates #3802

Change-Id: I98503c2505b77ac9d0cc792614fcdb691761a70c
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4422/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 4ec83fbad6
commit ade7bd8745

@ -76,6 +76,11 @@ func getControlDebugFlags() []string {
// SSHServer is the interface of the conditionally linked ssh/tailssh.server. // SSHServer is the interface of the conditionally linked ssh/tailssh.server.
type SSHServer interface { type SSHServer interface {
HandleSSHConn(net.Conn) error HandleSSHConn(net.Conn) error
// OnPolicyChange is called when the SSH access policy changes,
// so that existing sessions can be re-evaluated for validity
// and closed if they'd no longer be accepted.
OnPolicyChange()
} }
type newSSHServerFunc func(logger.Logf, *LocalBackend) (SSHServer, error) type newSSHServerFunc func(logger.Logf, *LocalBackend) (SSHServer, error)
@ -1148,6 +1153,10 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs *ipn.
b.logf("[v1] netmap packet filter: %v filters", len(packetFilter)) b.logf("[v1] netmap packet filter: %v filters", len(packetFilter))
b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf)) b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf))
} }
if b.sshServer != nil {
go b.sshServer.OnPolicyChange()
}
} }
func (b *LocalBackend) setFilter(f *filter.Filter) { func (b *LocalBackend) setFilter(f *filter.Filter) {

@ -75,8 +75,8 @@ func init() {
} }
// HandleSSHConn handles a Tailscale SSH connection from c. // HandleSSHConn handles a Tailscale SSH connection from c.
func (s *server) HandleSSHConn(c net.Conn) error { func (srv *server) HandleSSHConn(c net.Conn) error {
ss, err := s.newSSHServer() ss, err := srv.newSSHServer()
if err != nil { if err != nil {
return err return err
} }
@ -88,6 +88,16 @@ func (s *server) HandleSSHConn(c net.Conn) error {
return nil return nil
} }
// OnPolicyChange terminates any active sessions that no longer match
// the SSH access policy.
func (srv *server) OnPolicyChange() {
srv.mu.Lock()
defer srv.mu.Unlock()
for _, s := range srv.activeSessionByH {
go s.checkStillValid()
}
}
func (srv *server) newSSHServer() (*ssh.Server, error) { func (srv *server) newSSHServer() (*ssh.Server, error) {
ss := &ssh.Server{ ss := &ssh.Server{
Handler: srv.handleSSH, Handler: srv.handleSSH,
@ -102,13 +112,13 @@ func (srv *server) newSSHServer() (*ssh.Server, error) {
Version: "SSH-2.0-Tailscale", Version: "SSH-2.0-Tailscale",
LocalPortForwardingCallback: srv.mayForwardLocalPortTo, LocalPortForwardingCallback: srv.mayForwardLocalPortTo,
NoClientAuthCallback: func(m gossh.ConnMetadata) (*gossh.Permissions, error) { NoClientAuthCallback: func(m gossh.ConnMetadata) (*gossh.Permissions, error) {
if srv.requiresPubKey(m.User(), m.LocalAddr(), m.RemoteAddr()) { if srv.requiresPubKey(m.User(), toIPPort(m.LocalAddr()), toIPPort(m.RemoteAddr())) {
return nil, errors.New("public key required") // any non-nil error will do return nil, errors.New("public key required") // any non-nil error will do
} }
return nil, nil return nil, nil
}, },
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
if srv.acceptPubKey(ctx.User(), ctx.LocalAddr(), ctx.RemoteAddr(), key) { if srv.acceptPubKey(ctx.User(), toIPPort(ctx.LocalAddr()), toIPPort(ctx.RemoteAddr()), key) {
srv.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(key))) srv.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(key)))
return true return true
} }
@ -149,7 +159,7 @@ func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string
// requiresPubKey reports whether the SSH server, during the auth negotiation // requiresPubKey reports whether the SSH server, during the auth negotiation
// phase, should requires that the client send an SSH public key. (or, more // phase, should requires that the client send an SSH public key. (or, more
// specifically, that "none" auth isn't acceptable) // specifically, that "none" auth isn't acceptable)
func (srv *server) requiresPubKey(sshUser string, localAddr, remoteAddr net.Addr) bool { func (srv *server) requiresPubKey(sshUser string, localAddr, remoteAddr netaddr.IPPort) bool {
pol, ok := srv.sshPolicy() pol, ok := srv.sshPolicy()
if !ok { if !ok {
return false return false
@ -184,7 +194,7 @@ func (srv *server) requiresPubKey(sshUser string, localAddr, remoteAddr net.Addr
return false return false
} }
func (srv *server) acceptPubKey(sshUser string, localAddr, remoteAddr net.Addr, pubKey ssh.PublicKey) bool { func (srv *server) acceptPubKey(sshUser string, localAddr, remoteAddr netaddr.IPPort, pubKey ssh.PublicKey) bool {
a, _, _, err := srv.evaluatePolicy(sshUser, localAddr, remoteAddr, pubKey) a, _, _, err := srv.evaluatePolicy(sshUser, localAddr, remoteAddr, pubKey)
if err != nil { if err != nil {
return false return false
@ -220,19 +230,16 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
return nil, false return nil, false
} }
func asTailscaleIPPort(a net.Addr) (netaddr.IPPort, error) { func toIPPort(a net.Addr) (ipp netaddr.IPPort) {
ta, ok := a.(*net.TCPAddr) ta, ok := a.(*net.TCPAddr)
if !ok { if !ok {
return netaddr.IPPort{}, fmt.Errorf("non-TCP addr %T %v", a, a) return
} }
tanetaddr, ok := netaddr.FromStdIP(ta.IP) tanetaddr, ok := netaddr.FromStdIP(ta.IP)
if !ok { if !ok {
return netaddr.IPPort{}, fmt.Errorf("unparseable addr %v", ta.IP) return
}
if !tsaddr.IsTailscaleIP(tanetaddr) {
return netaddr.IPPort{}, fmt.Errorf("non-Tailscale addr %v", ta.IP)
} }
return netaddr.IPPortFrom(tanetaddr, uint16(ta.Port)), nil return netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
} }
// evaluatePolicy returns the SSHAction, sshConnInfo and localUser after // evaluatePolicy returns the SSHAction, sshConnInfo and localUser after
@ -241,29 +248,27 @@ func asTailscaleIPPort(a net.Addr) (netaddr.IPPort, error) {
// //
// The return sshConnInfo will be non-nil, even on some errors, if the // The return sshConnInfo will be non-nil, even on some errors, if the
// evaluation made it far enough to resolve the remoteAddr to a Tailscale IP. // evaluation made it far enough to resolve the remoteAddr to a Tailscale IP.
func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr net.Addr, pubKey ssh.PublicKey) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) { func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr.IPPort, pubKey ssh.PublicKey) (_ *tailcfg.SSHAction, _ *sshConnInfo, localUser string, _ error) {
pol, ok := srv.sshPolicy() pol, ok := srv.sshPolicy()
if !ok { if !ok {
return nil, nil, "", fmt.Errorf("tsshd: rejecting connection; no SSH policy") return nil, nil, "", fmt.Errorf("tailssh: rejecting connection; no SSH policy")
} }
srcIPP, err := asTailscaleIPPort(remoteAddr) if !tsaddr.IsTailscaleIP(remoteAddr.IP()) {
if err != nil { return nil, nil, "", fmt.Errorf("tailssh: rejecting non-Tailscale remote address %v", remoteAddr)
return nil, nil, "", fmt.Errorf("tsshd: rejecting: %w", err)
} }
dstIPP, err := asTailscaleIPPort(localAddr) if !tsaddr.IsTailscaleIP(localAddr.IP()) {
if err != nil { return nil, nil, "", fmt.Errorf("tailssh: rejecting non-Tailscale remote address %v", localAddr)
return nil, nil, "", err
} }
node, uprof, ok := srv.lb.WhoIs(srcIPP) node, uprof, ok := srv.lb.WhoIs(remoteAddr)
if !ok { if !ok {
return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", srcIPP) return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", remoteAddr)
} }
ci := &sshConnInfo{ ci := &sshConnInfo{
now: time.Now(), now: time.Now(),
fetchPublicKeysURL: srv.fetchPublicKeysURL, fetchPublicKeysURL: srv.fetchPublicKeysURL,
sshUser: sshUser, sshUser: sshUser,
src: srcIPP, src: remoteAddr,
dst: dstIPP, dst: localAddr,
node: node, node: node,
uprof: &uprof, uprof: &uprof,
pubKey: pubKey, pubKey: pubKey,
@ -304,7 +309,7 @@ func (srv *server) handleSSH(s ssh.Session) {
logf := srv.logf logf := srv.logf
sshUser := s.User() sshUser := s.User()
action, ci, localUser, err := srv.evaluatePolicy(sshUser, s.LocalAddr(), s.RemoteAddr(), s.PublicKey()) action, ci, localUser, err := srv.evaluatePolicy(sshUser, toIPPort(s.LocalAddr()), toIPPort(s.RemoteAddr()), s.PublicKey())
if err != nil { if err != nil {
logf(err.Error()) logf(err.Error())
s.Exit(1) s.Exit(1)
@ -433,6 +438,21 @@ func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User)
} }
} }
// checkStillValid checks that the session is still valid per the latest SSHPolicy.
// If not, it terminates the session.
func (ss *sshSession) checkStillValid() {
ci := ss.connInfo
a, _, _, err := ss.srv.evaluatePolicy(ci.sshUser, ci.src, ci.dst, ci.pubKey)
if err == nil && (a.Accept || a.HoldAndDelegate != "") {
return
}
ss.logf("session no longer valid per new SSH policy; closing")
ss.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Access revoked.\n"),
context.Canceled,
})
}
func (srv *server) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) { func (srv *server) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
defer cancel() defer cancel()

Loading…
Cancel
Save