mirror of https://github.com/tailscale/tailscale/
cmd/tailscale/cli: add lose-ssh risk
This makes it so that the user is notified that the action they are about to take may result in them getting disconnected from the machine. It then waits for 5s for the user to maybe Ctrl+C out of it. It also introduces a `--accept-risk=lose-ssh` flag for automation, which allows the caller to pre-acknowledge the risk. The two actions that cause this are: - updating `--ssh` from `true` to `false` - running `tailscale down` Updates #3802 Signed-off-by: Maisem Ali <maisem@tailscale.com>pull/4773/head
parent
1336fb740b
commit
67325d334e
@ -0,0 +1,78 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
riskTypes []string
|
||||||
|
acceptedRisks string
|
||||||
|
riskLoseSSH = registerRiskType("lose-ssh")
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerRiskType(riskType string) string {
|
||||||
|
riskTypes = append(riskTypes, riskType)
|
||||||
|
return riskType
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for
|
||||||
|
// in presentRiskToUser.
|
||||||
|
func registerAcceptRiskFlag(f *flag.FlagSet) {
|
||||||
|
f.StringVar(&acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
// riskAccepted reports whether riskType is in acceptedRisks.
|
||||||
|
func riskAccepted(riskType string) bool {
|
||||||
|
for _, r := range strings.Split(acceptedRisks, ",") {
|
||||||
|
if r == riskType {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var errAborted = errors.New("aborted, no changes made")
|
||||||
|
|
||||||
|
// riskAbortTimeSeconds is the number of seconds to wait after displaying the
|
||||||
|
// risk message before continuing with the operation.
|
||||||
|
// It is used by the presentRiskToUser function below.
|
||||||
|
const riskAbortTimeSeconds = 5
|
||||||
|
|
||||||
|
// presentRiskToUser displays the risk message and waits for the user to
|
||||||
|
// cancel. It returns errorAborted if the user aborts.
|
||||||
|
func presentRiskToUser(riskType, riskMessage string) error {
|
||||||
|
if riskAccepted(riskType) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fmt.Println(riskMessage)
|
||||||
|
fmt.Printf("To skip this warning, use --accept-risk=%s\n", riskType)
|
||||||
|
|
||||||
|
interrupt := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(interrupt, syscall.SIGINT)
|
||||||
|
var msgLen int
|
||||||
|
for left := riskAbortTimeSeconds; left > 0; left-- {
|
||||||
|
msg := fmt.Sprintf("\rContinuing in %d seconds...", left)
|
||||||
|
msgLen = len(msg)
|
||||||
|
fmt.Print(msg)
|
||||||
|
select {
|
||||||
|
case <-interrupt:
|
||||||
|
fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen+1))
|
||||||
|
return errAborted
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen))
|
||||||
|
return errAborted
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build !js && !windows
|
||||||
|
// +build !js,!windows
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
getSSHClientEnvVar = func() string {
|
||||||
|
if os.Getenv("SUDO_USER") == "" {
|
||||||
|
// No sudo, just check the env.
|
||||||
|
return os.Getenv("SSH_CLIENT")
|
||||||
|
}
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
// TODO(maisem): implement this for other platforms. It's not clear
|
||||||
|
// if there is a way to get the environment for a given process on
|
||||||
|
// darwin and bsd.
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// SID is the session ID of the user's login session.
|
||||||
|
// It is also the process ID of the original shell that the user logged in with.
|
||||||
|
// We only need to check the environment of that process.
|
||||||
|
sid, err := unix.Getsid(os.Getpid())
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ"))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
prefix := []byte("SSH_CLIENT=")
|
||||||
|
for _, env := range bytes.Split(b, []byte{0}) {
|
||||||
|
if bytes.HasPrefix(env, prefix) {
|
||||||
|
return string(env[len(prefix):])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue