From 67325d334e67aebb67fba3c36aa073f2e8cd658f Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Thu, 2 Jun 2022 01:14:17 -0700 Subject: [PATCH] cmd/tailscale/cli: add lose-ssh risk This makes it so that the user is notified that the action they are about to take may result in them getting disconnected from the machine. It then waits for 5s for the user to maybe Ctrl+C out of it. It also introduces a `--accept-risk=lose-ssh` flag for automation, which allows the caller to pre-acknowledge the risk. The two actions that cause this are: - updating `--ssh` from `true` to `false` - running `tailscale down` Updates #3802 Signed-off-by: Maisem Ali --- cmd/tailscale/cli/down.go | 16 ++++++- cmd/tailscale/cli/risks.go | 78 +++++++++++++++++++++++++++++++++++ cmd/tailscale/cli/ssh.go | 26 ++++++++++++ cmd/tailscale/cli/ssh_unix.go | 51 +++++++++++++++++++++++ cmd/tailscale/cli/up.go | 16 ++++++- 5 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 cmd/tailscale/cli/risks.go create mode 100644 cmd/tailscale/cli/ssh_unix.go diff --git a/cmd/tailscale/cli/down.go b/cmd/tailscale/cli/down.go index 8feac33cc..9a139f315 100644 --- a/cmd/tailscale/cli/down.go +++ b/cmd/tailscale/cli/down.go @@ -6,6 +6,7 @@ package cli import ( "context" + "flag" "fmt" "github.com/peterbourgon/ff/v3/ffcli" @@ -17,7 +18,14 @@ var downCmd = &ffcli.Command{ ShortUsage: "down", ShortHelp: "Disconnect from Tailscale", - Exec: runDown, + Exec: runDown, + FlagSet: newDownFlagSet(), +} + +func newDownFlagSet() *flag.FlagSet { + downf := newFlagSet("down") + registerAcceptRiskFlag(downf) + return downf } func runDown(ctx context.Context, args []string) error { @@ -25,6 +33,12 @@ func runDown(ctx context.Context, args []string) error { return fmt.Errorf("too many non-flag arguments: %q", args) } + if isSSHOverTailscale() { + if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`); err != nil { + return err + } + } + st, err := localClient.Status(ctx) if err != nil { return fmt.Errorf("error fetching current status: %w", err) diff --git a/cmd/tailscale/cli/risks.go b/cmd/tailscale/cli/risks.go new file mode 100644 index 000000000..84bc188e2 --- /dev/null +++ b/cmd/tailscale/cli/risks.go @@ -0,0 +1,78 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cli + +import ( + "errors" + "flag" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" +) + +var ( + riskTypes []string + acceptedRisks string + riskLoseSSH = registerRiskType("lose-ssh") +) + +func registerRiskType(riskType string) string { + riskTypes = append(riskTypes, riskType) + return riskType +} + +// registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for +// in presentRiskToUser. +func registerAcceptRiskFlag(f *flag.FlagSet) { + f.StringVar(&acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ",")) +} + +// riskAccepted reports whether riskType is in acceptedRisks. +func riskAccepted(riskType string) bool { + for _, r := range strings.Split(acceptedRisks, ",") { + if r == riskType { + return true + } + } + return false +} + +var errAborted = errors.New("aborted, no changes made") + +// riskAbortTimeSeconds is the number of seconds to wait after displaying the +// risk message before continuing with the operation. +// It is used by the presentRiskToUser function below. +const riskAbortTimeSeconds = 5 + +// presentRiskToUser displays the risk message and waits for the user to +// cancel. It returns errorAborted if the user aborts. +func presentRiskToUser(riskType, riskMessage string) error { + if riskAccepted(riskType) { + return nil + } + fmt.Println(riskMessage) + fmt.Printf("To skip this warning, use --accept-risk=%s\n", riskType) + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGINT) + var msgLen int + for left := riskAbortTimeSeconds; left > 0; left-- { + msg := fmt.Sprintf("\rContinuing in %d seconds...", left) + msgLen = len(msg) + fmt.Print(msg) + select { + case <-interrupt: + fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen+1)) + return errAborted + case <-time.After(time.Second): + continue + } + } + fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen)) + return errAborted +} diff --git a/cmd/tailscale/cli/ssh.go b/cmd/tailscale/cli/ssh.go index 703385ab2..24fb337f2 100644 --- a/cmd/tailscale/cli/ssh.go +++ b/cmd/tailscale/cli/ssh.go @@ -21,6 +21,7 @@ import ( "inet.af/netaddr" "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/tsaddr" ) var sshCmd = &ffcli.Command{ @@ -179,3 +180,28 @@ func nodeDNSNameFromArg(st *ipnstate.Status, arg string) (dnsName string, ok boo } return "", false } + +// getSSHClientEnvVar returns the "SSH_CLIENT" environment variable +// for the current process group, if any. +var getSSHClientEnvVar = func() string { + return "" +} + +// isSSHOverTailscale checks if the invocation is in a SSH session over Tailscale. +// It is used to detect if the user is about to take an action that might result in them +// disconnecting from the machine (e.g. disabling SSH) +func isSSHOverTailscale() bool { + sshClient := getSSHClientEnvVar() + if sshClient == "" { + return false + } + ipStr, _, ok := strings.Cut(sshClient, " ") + if !ok { + return false + } + ip, err := netaddr.ParseIP(ipStr) + if err != nil { + return false + } + return tsaddr.IsTailscaleIP(ip) +} diff --git a/cmd/tailscale/cli/ssh_unix.go b/cmd/tailscale/cli/ssh_unix.go new file mode 100644 index 000000000..4d10fe617 --- /dev/null +++ b/cmd/tailscale/cli/ssh_unix.go @@ -0,0 +1,51 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !js && !windows +// +build !js,!windows + +package cli + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/sys/unix" +) + +func init() { + getSSHClientEnvVar = func() string { + if os.Getenv("SUDO_USER") == "" { + // No sudo, just check the env. + return os.Getenv("SSH_CLIENT") + } + if runtime.GOOS != "linux" { + // TODO(maisem): implement this for other platforms. It's not clear + // if there is a way to get the environment for a given process on + // darwin and bsd. + return "" + } + // SID is the session ID of the user's login session. + // It is also the process ID of the original shell that the user logged in with. + // We only need to check the environment of that process. + sid, err := unix.Getsid(os.Getpid()) + if err != nil { + return "" + } + b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) + if err != nil { + return "" + } + prefix := []byte("SSH_CLIENT=") + for _, env := range bytes.Split(b, []byte{0}) { + if bytes.HasPrefix(env, prefix) { + return string(env[len(prefix):]) + } + } + return "" + } +} diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 09df80b03..a233745d7 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -114,6 +114,8 @@ func newUpFlagSet(goos string, upArgs *upArgsT) *flag.FlagSet { case "windows": upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)") } + + registerAcceptRiskFlag(upf) return upf } @@ -465,6 +467,18 @@ func runUp(ctx context.Context, args []string) error { backendState: st.BackendState, curExitNodeIP: exitNodeIP(curPrefs, st), } + + if upArgs.runSSH != curPrefs.RunSSH && isSSHOverTailscale() { + if upArgs.runSSH { + err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`) + } else { + err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`) + } + if err != nil { + return err + } + } + simpleUp, justEditMP, err := updatePrefs(prefs, curPrefs, env) if err != nil { fatalf("%s", err) @@ -705,7 +719,7 @@ func addPrefFlagMapping(flagName string, prefNames ...string) { // correspond to an ipn.Pref. func preflessFlag(flagName string) bool { switch flagName { - case "auth-key", "force-reauth", "reset", "qr", "json": + case "auth-key", "force-reauth", "reset", "qr", "json", "accept-risk": return true } return false