diff --git a/cmd/tailscale/cli/ssh.go b/cmd/tailscale/cli/ssh.go index 2f96fd95d..4ee56dd45 100644 --- a/cmd/tailscale/cli/ssh.go +++ b/cmd/tailscale/cli/ssh.go @@ -5,6 +5,7 @@ package cli import ( + "bytes" "context" "errors" "fmt" @@ -12,13 +13,16 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "runtime" "strings" "syscall" "github.com/alessio/shellescape" "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/tailscale" "tailscale.com/envknob" + "tailscale.com/ipn/ipnstate" ) var sshCmd = &ffcli.Command{ @@ -52,9 +56,21 @@ func runSSH(ctx context.Context, args []string) error { if err != nil { return err } + st, err := tailscale.Status(ctx) + if err != nil { + return err + } + knownHostsFile, err := writeKnownHosts(st) + if err != nil { + return err + } + argv := append([]string{ ssh, + "-o", fmt.Sprintf("UserKnownHostsFile %s", + shellescape.Quote(knownHostsFile), + ), "-o", fmt.Sprintf("ProxyCommand %s --socket=%s nc %%h %%p", shellescape.Quote(tailscaleBin), shellescape.Quote(rootArgs.socket), @@ -95,3 +111,52 @@ func runSSH(ctx context.Context, args []string) error { } return errors.New("unreachable") } + +func writeKnownHosts(st *ipnstate.Status) (knownHostsFile string, err error) { + confDir, err := os.UserConfigDir() + if err != nil { + return "", err + } + tsConfDir := filepath.Join(confDir, "tailscale") + if err := os.MkdirAll(tsConfDir, 0700); err != nil { + return "", err + } + knownHostsFile = filepath.Join(tsConfDir, "ssh_known_hosts") + want := genKnownHosts(st) + if cur, err := os.ReadFile(knownHostsFile); err != nil || !bytes.Equal(cur, want) { + if err := os.WriteFile(knownHostsFile, want, 0644); err != nil { + return "", err + } + } + return knownHostsFile, nil +} + +func genKnownHosts(st *ipnstate.Status) []byte { + var buf bytes.Buffer + for _, k := range st.Peers() { + ps := st.Peer[k] + if len(ps.SSH_HostKeys) == 0 { + continue + } + // addEntries adds one line per each of p's host keys. + addEntries := func(host string) { + for _, hk := range ps.SSH_HostKeys { + hostKey := strings.TrimSpace(hk) + if strings.ContainsAny(hostKey, "\n\r") { // invalid + continue + } + fmt.Fprintf(&buf, "%s %s\n", host, hostKey) + } + } + if ps.DNSName != "" { + addEntries(ps.DNSName) + } + if base, _, ok := strings.Cut(ps.DNSName, "."); ok { + addEntries(base) + } + for _, ip := range st.TailscaleIPs { + addEntries(ip.String()) + } + } + return buf.Bytes() +}