diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index ce8e58342..6766fac98 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -125,8 +125,7 @@ type LocalBackend struct { newDecompressor func() (controlclient.Decompressor, error) varRoot string // or empty if SetVarRoot never called sshAtomicBool syncs.AtomicBool - sshServer SSHServer // or nil - shutdownCalled bool // if Shutdown has been called + shutdownCalled bool // if Shutdown has been called filterAtomic atomic.Value // of *filter.Filter containsViaIPFuncAtomic atomic.Value // of func(netaddr.IP) bool @@ -136,6 +135,7 @@ type LocalBackend struct { filterHash deephash.Sum httpTestClient *http.Client // for controlclient. nil by default, used by tests. ccGen clientGen // function for producing controlclient; lazily populated + sshServer SSHServer // or nil, initialized lazily. notify func(ipn.Notify) cc controlclient.Client stateKey ipn.StateKey // computed in part from user-provided value @@ -228,12 +228,6 @@ func NewLocalBackend(logf logger.Logf, logid string, store ipn.StateStore, diale gotPortPollRes: make(chan struct{}), loginFlags: loginFlags, } - if newSSHServer != nil { - b.sshServer, err = newSSHServer(logf, b) - if err != nil { - return nil, fmt.Errorf("newSSHServer: %w", err) - } - } // Default filter blocks everything and logs nothing, until Start() is called. b.setFilter(filter.NewAllowNone(logf, &netaddr.IPSet{})) @@ -351,6 +345,7 @@ func (b *LocalBackend) Shutdown() { cc := b.cc if b.sshServer != nil { b.sshServer.Shutdown() + b.sshServer = nil } b.closePeerAPIListenersLocked() b.mu.Unlock() @@ -1932,6 +1927,12 @@ func (b *LocalBackend) setPrefsLockedOnEntry(caller string, newp *ipn.Prefs) { } b.updateFilterLocked(netMap, newp) + if oldp.ShouldSSHBeRunning() && !newp.ShouldSSHBeRunning() { + if b.sshServer != nil { + go b.sshServer.Shutdown() + b.sshServer = nil + } + } b.mu.Unlock() if stateKey != "" { @@ -1975,10 +1976,6 @@ func (b *LocalBackend) setPrefsLockedOnEntry(caller string, newp *ipn.Prefs) { b.authReconfig() } - if oldp.RunSSH && !newp.RunSSH && b.sshServer != nil { - go b.sshServer.OnPolicyChange() - } - b.send(ipn.Notify{Prefs: newp}) } @@ -3367,11 +3364,28 @@ func (b *LocalBackend) DoNoiseRequest(req *http.Request) (*http.Response, error) return cc.DoNoiseRequest(req) } -func (b *LocalBackend) HandleSSHConn(c net.Conn) error { - if b.sshServer == nil { - return errors.New("no SSH server") +func (b *LocalBackend) sshServerOrInit() (_ SSHServer, err error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.sshServer != nil { + return b.sshServer, nil + } + if newSSHServer == nil { + return nil, errors.New("no SSH server support") + } + b.sshServer, err = newSSHServer(b.logf, b) + if err != nil { + return nil, fmt.Errorf("newSSHServer: %w", err) + } + return b.sshServer, nil +} + +func (b *LocalBackend) HandleSSHConn(c net.Conn) (err error) { + s, err := b.sshServerOrInit() + if err != nil { + return err } - return b.sshServer.HandleSSHConn(c) + return s.HandleSSHConn(c) } // HandleQuad100Port80Conn serves http://100.100.100.100/ on port 80 (and diff --git a/ipn/prefs.go b/ipn/prefs.go index 47235eaca..46df9173b 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -584,6 +584,12 @@ func (p *Prefs) SetExitNodeIP(s string, st *ipnstate.Status) error { return err } +// ShouldSSHBeRunning reports whether the SSH server should be running based on +// the prefs. +func (p *Prefs) ShouldSSHBeRunning() bool { + return p.WantRunning && p.RunSSH +} + // PrefsFromBytes deserializes Prefs from a JSON blob. func PrefsFromBytes(b []byte) (*Prefs, error) { p := NewPrefs() diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 97d36e55d..2a14d84a2 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -110,7 +110,7 @@ func (srv *server) Shutdown() { srv.shutdownCalled = true for _, s := range srv.activeSessionByH { s.ctx.CloseWithError(userVisibleError{ - fmt.Sprintf("Tailscale shutting down.\r\n"), + fmt.Sprintf("Tailscale SSH is shutting down.\r\n"), context.Canceled, }) } @@ -876,7 +876,7 @@ func (ss *sshSession) run() { if srv.shutdownCalled { srv.mu.Unlock() // Do not start any new sessions. - fmt.Fprintf(ss, "Tailscale is shutting down\r\n") + fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n") ss.Exit(1) return }