diff --git a/ipn/ipnlocal/ssh.go b/ipn/ipnlocal/ssh.go index 221a80a7c..010f8f259 100644 --- a/ipn/ipnlocal/ssh.go +++ b/ipn/ipnlocal/ssh.go @@ -26,6 +26,7 @@ import ( "github.com/tailscale/golang-x-crypto/ssh" "tailscale.com/envknob" + "tailscale.com/util/mak" ) var useHostKeys = envknob.Bool("TS_USE_SYSTEM_SSH_HOST_KEYS") @@ -36,34 +37,39 @@ var useHostKeys = envknob.Bool("TS_USE_SYSTEM_SSH_HOST_KEYS") var keyTypes = []string{"rsa", "ecdsa", "ed25519"} func (b *LocalBackend) GetSSH_HostKeys() (keys []ssh.Signer, err error) { + var existing map[string]ssh.Signer if os.Geteuid() == 0 { - keys, err = b.getSystemSSH_HostKeys() - if err != nil || len(keys) > 0 { - return keys, err - } - // Otherwise, perhaps they don't have OpenSSH etc installed. - // Generate our own keys... + existing = b.getSystemSSH_HostKeys() } - return b.getTailscaleSSH_HostKeys() + return b.getTailscaleSSH_HostKeys(existing) } -func (b *LocalBackend) getTailscaleSSH_HostKeys() (keys []ssh.Signer, err error) { - root := b.TailscaleVarRoot() - if root == "" { - return nil, errors.New("no var root for ssh keys") - } - keyDir := filepath.Join(root, "ssh") - if err := os.MkdirAll(keyDir, 0700); err != nil { - return nil, err - } +// getTailscaleSSH_HostKeys returns the three (rsa, ecdsa, ed25519) SSH host +// keys, reusing the provided ones in existing if present in the map. +func (b *LocalBackend) getTailscaleSSH_HostKeys(existing map[string]ssh.Signer) (keys []ssh.Signer, err error) { + var keyDir string // lazily initialized $TAILSCALE_VAR/ssh dir. for _, typ := range keyTypes { + if s, ok := existing[typ]; ok { + keys = append(keys, s) + continue + } + if keyDir == "" { + root := b.TailscaleVarRoot() + if root == "" { + return nil, errors.New("no var root for ssh keys") + } + keyDir = filepath.Join(root, "ssh") + if err := os.MkdirAll(keyDir, 0700); err != nil { + return nil, err + } + } hostKey, err := b.hostKeyFileOrCreate(keyDir, typ) if err != nil { - return nil, err + return nil, fmt.Errorf("error creating SSH host key type %q in %q: %w", typ, keyDir, err) } signer, err := ssh.ParsePrivateKey(hostKey) if err != nil { - return nil, err + return nil, fmt.Errorf("error parsing SSH host key type %q from %q: %w", typ, keyDir, err) } keys = append(keys, signer) } @@ -115,24 +121,21 @@ func (b *LocalBackend) hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { return pemGen, err } -func (b *LocalBackend) getSystemSSH_HostKeys() (ret []ssh.Signer, err error) { - // TODO(bradfitz): cache this? +func (b *LocalBackend) getSystemSSH_HostKeys() (ret map[string]ssh.Signer) { for _, typ := range keyTypes { filename := "/etc/ssh/ssh_host_" + typ + "_key" hostKey, err := ioutil.ReadFile(filename) - if os.IsNotExist(err) || len(bytes.TrimSpace(hostKey)) == 0 { + if err != nil || len(bytes.TrimSpace(hostKey)) == 0 { continue } - if err != nil { - return nil, err - } signer, err := ssh.ParsePrivateKey(hostKey) if err != nil { - return nil, fmt.Errorf("error reading private key %s: %w", filename, err) + b.logf("warning: error reading host key %s: %v (generating one instead)", filename, err) + continue } - ret = append(ret, signer) + mak.Set(&ret, typ, signer) } - return ret, nil + return ret } func (b *LocalBackend) getSSHHostKeyPublicStrings() (ret []string) { diff --git a/ipn/ipnlocal/ssh_test.go b/ipn/ipnlocal/ssh_test.go index db0ad86b6..42c99755a 100644 --- a/ipn/ipnlocal/ssh_test.go +++ b/ipn/ipnlocal/ssh_test.go @@ -15,7 +15,7 @@ import ( func TestSSHKeyGen(t *testing.T) { dir := t.TempDir() lb := &LocalBackend{varRoot: dir} - keys, err := lb.getTailscaleSSH_HostKeys() + keys, err := lb.getTailscaleSSH_HostKeys(nil) if err != nil { t.Fatal(err) } @@ -32,7 +32,7 @@ func TestSSHKeyGen(t *testing.T) { t.Fatalf("keys = %v; want %v", got, want) } - keys2, err := lb.getTailscaleSSH_HostKeys() + keys2, err := lb.getTailscaleSSH_HostKeys(nil) if err != nil { t.Fatal(err) }