diff --git a/cmd/tailscale/cli/down.go b/cmd/tailscale/cli/down.go index 8feac33cc..9a139f315 100644 --- a/cmd/tailscale/cli/down.go +++ b/cmd/tailscale/cli/down.go @@ -6,6 +6,7 @@ package cli import ( "context" + "flag" "fmt" "github.com/peterbourgon/ff/v3/ffcli" @@ -17,7 +18,14 @@ var downCmd = &ffcli.Command{ ShortUsage: "down", ShortHelp: "Disconnect from Tailscale", - Exec: runDown, + Exec: runDown, + FlagSet: newDownFlagSet(), +} + +func newDownFlagSet() *flag.FlagSet { + downf := newFlagSet("down") + registerAcceptRiskFlag(downf) + return downf } func runDown(ctx context.Context, args []string) error { @@ -25,6 +33,12 @@ func runDown(ctx context.Context, args []string) error { return fmt.Errorf("too many non-flag arguments: %q", args) } + if isSSHOverTailscale() { + if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`); err != nil { + return err + } + } + st, err := localClient.Status(ctx) if err != nil { return fmt.Errorf("error fetching current status: %w", err) diff --git a/cmd/tailscale/cli/risks.go b/cmd/tailscale/cli/risks.go new file mode 100644 index 000000000..84bc188e2 --- /dev/null +++ b/cmd/tailscale/cli/risks.go @@ -0,0 +1,78 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cli + +import ( + "errors" + "flag" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" +) + +var ( + riskTypes []string + acceptedRisks string + riskLoseSSH = registerRiskType("lose-ssh") +) + +func registerRiskType(riskType string) string { + riskTypes = append(riskTypes, riskType) + return riskType +} + +// registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for +// in presentRiskToUser. +func registerAcceptRiskFlag(f *flag.FlagSet) { + f.StringVar(&acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ",")) +} + +// riskAccepted reports whether riskType is in acceptedRisks. +func riskAccepted(riskType string) bool { + for _, r := range strings.Split(acceptedRisks, ",") { + if r == riskType { + return true + } + } + return false +} + +var errAborted = errors.New("aborted, no changes made") + +// riskAbortTimeSeconds is the number of seconds to wait after displaying the +// risk message before continuing with the operation. +// It is used by the presentRiskToUser function below. +const riskAbortTimeSeconds = 5 + +// presentRiskToUser displays the risk message and waits for the user to +// cancel. It returns errorAborted if the user aborts. +func presentRiskToUser(riskType, riskMessage string) error { + if riskAccepted(riskType) { + return nil + } + fmt.Println(riskMessage) + fmt.Printf("To skip this warning, use --accept-risk=%s\n", riskType) + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGINT) + var msgLen int + for left := riskAbortTimeSeconds; left > 0; left-- { + msg := fmt.Sprintf("\rContinuing in %d seconds...", left) + msgLen = len(msg) + fmt.Print(msg) + select { + case <-interrupt: + fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen+1)) + return errAborted + case <-time.After(time.Second): + continue + } + } + fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen)) + return errAborted +} diff --git a/cmd/tailscale/cli/ssh.go b/cmd/tailscale/cli/ssh.go index 703385ab2..24fb337f2 100644 --- a/cmd/tailscale/cli/ssh.go +++ b/cmd/tailscale/cli/ssh.go @@ -21,6 +21,7 @@ import ( "inet.af/netaddr" "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/tsaddr" ) var sshCmd = &ffcli.Command{ @@ -179,3 +180,28 @@ func nodeDNSNameFromArg(st *ipnstate.Status, arg string) (dnsName string, ok boo } return "", false } + +// getSSHClientEnvVar returns the "SSH_CLIENT" environment variable +// for the current process group, if any. +var getSSHClientEnvVar = func() string { + return "" +} + +// isSSHOverTailscale checks if the invocation is in a SSH session over Tailscale. +// It is used to detect if the user is about to take an action that might result in them +// disconnecting from the machine (e.g. disabling SSH) +func isSSHOverTailscale() bool { + sshClient := getSSHClientEnvVar() + if sshClient == "" { + return false + } + ipStr, _, ok := strings.Cut(sshClient, " ") + if !ok { + return false + } + ip, err := netaddr.ParseIP(ipStr) + if err != nil { + return false + } + return tsaddr.IsTailscaleIP(ip) +} diff --git a/cmd/tailscale/cli/ssh_unix.go b/cmd/tailscale/cli/ssh_unix.go new file mode 100644 index 000000000..4d10fe617 --- /dev/null +++ b/cmd/tailscale/cli/ssh_unix.go @@ -0,0 +1,51 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !js && !windows +// +build !js,!windows + +package cli + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/sys/unix" +) + +func init() { + getSSHClientEnvVar = func() string { + if os.Getenv("SUDO_USER") == "" { + // No sudo, just check the env. + return os.Getenv("SSH_CLIENT") + } + if runtime.GOOS != "linux" { + // TODO(maisem): implement this for other platforms. It's not clear + // if there is a way to get the environment for a given process on + // darwin and bsd. + return "" + } + // SID is the session ID of the user's login session. + // It is also the process ID of the original shell that the user logged in with. + // We only need to check the environment of that process. + sid, err := unix.Getsid(os.Getpid()) + if err != nil { + return "" + } + b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) + if err != nil { + return "" + } + prefix := []byte("SSH_CLIENT=") + for _, env := range bytes.Split(b, []byte{0}) { + if bytes.HasPrefix(env, prefix) { + return string(env[len(prefix):]) + } + } + return "" + } +} diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 09df80b03..a233745d7 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -114,6 +114,8 @@ func newUpFlagSet(goos string, upArgs *upArgsT) *flag.FlagSet { case "windows": upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)") } + + registerAcceptRiskFlag(upf) return upf } @@ -465,6 +467,18 @@ func runUp(ctx context.Context, args []string) error { backendState: st.BackendState, curExitNodeIP: exitNodeIP(curPrefs, st), } + + if upArgs.runSSH != curPrefs.RunSSH && isSSHOverTailscale() { + if upArgs.runSSH { + err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`) + } else { + err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`) + } + if err != nil { + return err + } + } + simpleUp, justEditMP, err := updatePrefs(prefs, curPrefs, env) if err != nil { fatalf("%s", err) @@ -705,7 +719,7 @@ func addPrefFlagMapping(flagName string, prefNames ...string) { // correspond to an ipn.Pref. func preflessFlag(flagName string) bool { switch flagName { - case "auth-key", "force-reauth", "reset", "qr", "json": + case "auth-key", "force-reauth", "reset", "qr", "json", "accept-risk": return true } return false