control/controlclient, ipn: store machine key separately from user prefs/persist

Updates #610 (fixes after some win/xcode changes in a separate repo)
reviewable/pr802/r3
Brad Fitzpatrick 4 years ago
parent e1596d655a
commit b5a3850d29

@ -45,8 +45,19 @@ import (
) )
type Persist struct { type Persist struct {
_ structs.Incomparable _ structs.Incomparable
PrivateMachineKey wgcfg.PrivateKey
// LegacyFrontendPrivateMachineKey is here temporarily
// (starting 2020-09-28) during migration of Windows users'
// machine keys from frontend storage to the backend. On the
// first LocalBackend.Start call, the backend will initialize
// the real (backend-owned) machine key from the frontend's
// provided value (if non-zero), picking a new random one if
// needed. This field should be considered read-only from GUI
// frontends. The real value should not be written back in
// this field, lest the frontend persist it to disk.
LegacyFrontendPrivateMachineKey wgcfg.PrivateKey `json:"PrivateMachineKey"`
PrivateNodeKey wgcfg.PrivateKey PrivateNodeKey wgcfg.PrivateKey
OldPrivateNodeKey wgcfg.PrivateKey // needed to request key rotation OldPrivateNodeKey wgcfg.PrivateKey // needed to request key rotation
Provider string Provider string
@ -61,7 +72,7 @@ func (p *Persist) Equals(p2 *Persist) bool {
return false return false
} }
return p.PrivateMachineKey.Equal(p2.PrivateMachineKey) && return p.LegacyFrontendPrivateMachineKey.Equal(p2.LegacyFrontendPrivateMachineKey) &&
p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && p.PrivateNodeKey.Equal(p2.PrivateNodeKey) &&
p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) && p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) &&
p.Provider == p2.Provider && p.Provider == p2.Provider &&
@ -70,8 +81,8 @@ func (p *Persist) Equals(p2 *Persist) bool {
func (p *Persist) Pretty() string { func (p *Persist) Pretty() string {
var mk, ok, nk wgcfg.Key var mk, ok, nk wgcfg.Key
if !p.PrivateMachineKey.IsZero() { if !p.LegacyFrontendPrivateMachineKey.IsZero() {
mk = p.PrivateMachineKey.Public() mk = p.LegacyFrontendPrivateMachineKey.Public()
} }
if !p.OldPrivateNodeKey.IsZero() { if !p.OldPrivateNodeKey.IsZero() {
ok = p.OldPrivateNodeKey.Public() ok = p.OldPrivateNodeKey.Public()
@ -79,7 +90,7 @@ func (p *Persist) Pretty() string {
if !p.PrivateNodeKey.IsZero() { if !p.PrivateNodeKey.IsZero() {
nk = p.PrivateNodeKey.Public() nk = p.PrivateNodeKey.Public()
} }
return fmt.Sprintf("Persist{m=%v, o=%v, n=%v u=%#v}", return fmt.Sprintf("Persist{lm=%v, o=%v, n=%v u=%#v}",
mk.ShortString(), ok.ShortString(), nk.ShortString(), mk.ShortString(), ok.ShortString(), nk.ShortString(),
p.LoginName) p.LoginName)
} }
@ -94,6 +105,7 @@ type Direct struct {
keepAlive bool keepAlive bool
logf logger.Logf logf logger.Logf
discoPubKey tailcfg.DiscoKey discoPubKey tailcfg.DiscoKey
machinePrivKey wgcfg.PrivateKey
mu sync.Mutex // mutex guards the following fields mu sync.Mutex // mutex guards the following fields
serverKey wgcfg.Key serverKey wgcfg.Key
@ -108,16 +120,17 @@ type Direct struct {
} }
type Options struct { type Options struct {
Persist Persist // initial persistent data Persist Persist // initial persistent data
ServerURL string // URL of the tailcontrol server MachinePrivateKey wgcfg.PrivateKey // the machine key to use
AuthKey string // optional node auth key for auto registration ServerURL string // URL of the tailcontrol server
TimeNow func() time.Time // time.Now implementation used by Client AuthKey string // optional node auth key for auto registration
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc TimeNow func() time.Time // time.Now implementation used by Client
DiscoPublicKey tailcfg.DiscoKey Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
NewDecompressor func() (Decompressor, error) DiscoPublicKey tailcfg.DiscoKey
KeepAlive bool NewDecompressor func() (Decompressor, error)
Logf logger.Logf KeepAlive bool
HTTPTestClient *http.Client // optional HTTP client to use (for tests only) Logf logger.Logf
HTTPTestClient *http.Client // optional HTTP client to use (for tests only)
} }
type Decompressor interface { type Decompressor interface {
@ -130,6 +143,9 @@ func NewDirect(opts Options) (*Direct, error) {
if opts.ServerURL == "" { if opts.ServerURL == "" {
return nil, errors.New("controlclient.New: no server URL specified") return nil, errors.New("controlclient.New: no server URL specified")
} }
if opts.MachinePrivateKey.IsZero() {
return nil, errors.New("controlclient.New: no MachinePrivateKey specified")
}
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
serverURL, err := url.Parse(opts.ServerURL) serverURL, err := url.Parse(opts.ServerURL)
if err != nil { if err != nil {
@ -158,6 +174,7 @@ func NewDirect(opts Options) (*Direct, error) {
c := &Direct{ c := &Direct{
httpc: httpc, httpc: httpc,
machinePrivKey: opts.MachinePrivateKey,
serverURL: opts.ServerURL, serverURL: opts.ServerURL,
timeNow: opts.TimeNow, timeNow: opts.TimeNow,
logf: opts.Logf, logf: opts.Logf,
@ -251,14 +268,12 @@ func (c *Direct) TryLogout(ctx context.Context) error {
// immediately invalidated. // immediately invalidated.
//if !c.persist.PrivateNodeKey.IsZero() { //if !c.persist.PrivateNodeKey.IsZero() {
//} //}
c.persist = Persist{ c.persist = Persist{}
PrivateMachineKey: c.persist.PrivateMachineKey,
}
return nil return nil
} }
func (c *Direct) TryLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags) (url string, err error) { func (c *Direct) TryLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags) (url string, err error) {
c.logf("direct.TryLogin(%v, %v)", t != nil, flags) c.logf("direct.TryLogin(token=%v, flags=%v)", t != nil, flags)
return c.doLoginOrRegen(ctx, t, flags, false, "") return c.doLoginOrRegen(ctx, t, flags, false, "")
} }
@ -289,13 +304,8 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
c.mu.Unlock() c.mu.Unlock()
if persist.PrivateMachineKey.IsZero() { if c.machinePrivKey.IsZero() {
c.logf("Generating a new machinekey.") return false, "", errors.New("controlclient.Direct requires a machine private key")
mkey, err := wgcfg.NewPrivateKey()
if err != nil {
log.Fatal(err)
}
persist.PrivateMachineKey = mkey
} }
if expired { if expired {
@ -360,13 +370,13 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
request.Auth.Provider = persist.Provider request.Auth.Provider = persist.Provider
request.Auth.LoginName = persist.LoginName request.Auth.LoginName = persist.LoginName
request.Auth.AuthKey = authKey request.Auth.AuthKey = authKey
bodyData, err := encode(request, &serverKey, &persist.PrivateMachineKey) bodyData, err := encode(request, &serverKey, &c.machinePrivKey)
if err != nil { if err != nil {
return regen, url, err return regen, url, err
} }
body := bytes.NewReader(bodyData) body := bytes.NewReader(bodyData)
u := fmt.Sprintf("%s/machine/%s", c.serverURL, persist.PrivateMachineKey.Public().HexString()) u := fmt.Sprintf("%s/machine/%s", c.serverURL, c.machinePrivKey.Public().HexString())
req, err := http.NewRequest("POST", u, body) req, err := http.NewRequest("POST", u, body)
if err != nil { if err != nil {
return regen, url, err return regen, url, err
@ -377,11 +387,14 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
if err != nil { if err != nil {
return regen, url, fmt.Errorf("register request: %v", err) return regen, url, fmt.Errorf("register request: %v", err)
} }
c.logf("RegisterReq: returned.")
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
if err := decode(res, &resp, &serverKey, &persist.PrivateMachineKey); err != nil { if err := decode(res, &resp, &serverKey, &c.machinePrivKey); err != nil {
c.logf("error decoding RegisterReq: %v", err)
return regen, url, fmt.Errorf("register request: %v", err) return regen, url, fmt.Errorf("register request: %v", err)
} }
// Log without PII:
c.logf("RegisterReq: got response; nodeKeyExpired=%v, machineAuthorized=%v; authURL=%v",
resp.NodeKeyExpired, resp.MachineAuthorized, resp.AuthURL != "")
if resp.NodeKeyExpired { if resp.NodeKeyExpired {
if regen { if regen {
@ -507,14 +520,15 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
request.Compress = "zstd" request.Compress = "zstd"
} }
bodyData, err := encode(request, &serverKey, &persist.PrivateMachineKey) bodyData, err := encode(request, &serverKey, &c.machinePrivKey)
if err != nil { if err != nil {
vlogf("netmap: encode: %v", err) vlogf("netmap: encode: %v", err)
return err return err
} }
machinePubKey := tailcfg.MachineKey(c.machinePrivKey.Public())
t0 := time.Now() t0 := time.Now()
u := fmt.Sprintf("%s/machine/%s/map", serverURL, persist.PrivateMachineKey.Public().HexString()) u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString())
req, err := http.NewRequest("POST", u, bytes.NewReader(bodyData)) req, err := http.NewRequest("POST", u, bytes.NewReader(bodyData))
if err != nil { if err != nil {
return err return err
@ -648,6 +662,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
nm := &NetworkMap{ nm := &NetworkMap{
NodeKey: tailcfg.NodeKey(persist.PrivateNodeKey.Public()), NodeKey: tailcfg.NodeKey(persist.PrivateNodeKey.Public()),
PrivateKey: persist.PrivateNodeKey, PrivateKey: persist.PrivateNodeKey,
MachineKey: machinePubKey,
Expiry: resp.Node.KeyExpiry, Expiry: resp.Node.KeyExpiry,
Name: resp.Node.Name, Name: resp.Node.Name,
Addresses: resp.Node.Addresses, Addresses: resp.Node.Addresses,
@ -719,11 +734,10 @@ var dumpMapResponse, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_MAPRESPONSE"))
func (c *Direct) decodeMsg(msg []byte, v interface{}) error { func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
c.mu.Lock() c.mu.Lock()
mkey := c.persist.PrivateMachineKey
serverKey := c.serverKey serverKey := c.serverKey
c.mu.Unlock() c.mu.Unlock()
decrypted, err := decryptMsg(msg, &serverKey, &mkey) decrypted, err := decryptMsg(msg, &serverKey, &c.machinePrivKey)
if err != nil { if err != nil {
return err return err
} }

@ -30,6 +30,7 @@ type NetworkMap struct {
Addresses []wgcfg.CIDR Addresses []wgcfg.CIDR
LocalPort uint16 // used for debugging LocalPort uint16 // used for debugging
MachineStatus tailcfg.MachineStatus MachineStatus tailcfg.MachineStatus
MachineKey tailcfg.MachineKey
Peers []*tailcfg.Node // sorted by Node.ID Peers []*tailcfg.Node // sorted by Node.ID
DNS tailcfg.DNSConfig DNS tailcfg.DNSConfig
Hostinfo tailcfg.Hostinfo Hostinfo tailcfg.Hostinfo

@ -12,7 +12,7 @@ import (
) )
func TestPersistEqual(t *testing.T) { func TestPersistEqual(t *testing.T) {
persistHandles := []string{"PrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName"} persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName"}
if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) { if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) {
t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n", t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
have, persistHandles) have, persistHandles)
@ -36,13 +36,13 @@ func TestPersistEqual(t *testing.T) {
{&Persist{}, &Persist{}, true}, {&Persist{}, &Persist{}, true},
{ {
&Persist{PrivateMachineKey: k1}, &Persist{LegacyFrontendPrivateMachineKey: k1},
&Persist{PrivateMachineKey: newPrivate()}, &Persist{LegacyFrontendPrivateMachineKey: newPrivate()},
false, false,
}, },
{ {
&Persist{PrivateMachineKey: k1}, &Persist{LegacyFrontendPrivateMachineKey: k1},
&Persist{PrivateMachineKey: k1}, &Persist{LegacyFrontendPrivateMachineKey: k1},
true, true,
}, },

@ -95,7 +95,8 @@ type Options struct {
// StateKey and Prefs together define the state the backend should // StateKey and Prefs together define the state the backend should
// use: // use:
// - StateKey=="" && Prefs!=nil: use Prefs for internal state, // - StateKey=="" && Prefs!=nil: use Prefs for internal state,
// don't persist changes in the backend. // don't persist changes in the backend, except for the machine key
// for migration purposes.
// - StateKey!="" && Prefs==nil: load the given backend-side // - StateKey!="" && Prefs==nil: load the given backend-side
// state and use/update that. // state and use/update that.
// - StateKey!="" && Prefs!=nil: like the previous case, but do // - StateKey!="" && Prefs!=nil: like the previous case, but do

@ -5,6 +5,7 @@
package ipn package ipn
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -62,12 +63,13 @@ type LocalBackend struct {
filterHash string filterHash string
// The mutex protects the following elements. // The mutex protects the following elements.
mu sync.Mutex mu sync.Mutex
notify func(Notify) notify func(Notify)
c *controlclient.Client c *controlclient.Client
stateKey StateKey stateKey StateKey
prefs *Prefs prefs *Prefs
state State machinePrivKey wgcfg.PrivateKey
state State
// hostinfo is mutated in-place while mu is held. // hostinfo is mutated in-place while mu is held.
hostinfo *tailcfg.Hostinfo hostinfo *tailcfg.Hostinfo
// netMap is not mutated in-place once set. // netMap is not mutated in-place once set.
@ -382,6 +384,7 @@ func (b *LocalBackend) Start(opts Options) error {
b.notify = opts.Notify b.notify = opts.Notify
b.netMap = nil b.netMap = nil
persist := b.prefs.Persist persist := b.prefs.Persist
machinePrivKey := b.machinePrivKey
b.mu.Unlock() b.mu.Unlock()
b.updateFilter(nil, nil) b.updateFilter(nil, nil)
@ -397,15 +400,16 @@ func (b *LocalBackend) Start(opts Options) error {
persist = &controlclient.Persist{} persist = &controlclient.Persist{}
} }
cli, err := controlclient.New(controlclient.Options{ cli, err := controlclient.New(controlclient.Options{
Logf: logger.WithPrefix(b.logf, "control: "), MachinePrivateKey: machinePrivKey,
Persist: *persist, Logf: logger.WithPrefix(b.logf, "control: "),
ServerURL: b.serverURL, Persist: *persist,
AuthKey: opts.AuthKey, ServerURL: b.serverURL,
Hostinfo: hostinfo, AuthKey: opts.AuthKey,
KeepAlive: true, Hostinfo: hostinfo,
NewDecompressor: b.newDecompressor, KeepAlive: true,
HTTPTestClient: opts.HTTPTestClient, NewDecompressor: b.newDecompressor,
DiscoPublicKey: discoPublic, HTTPTestClient: opts.HTTPTestClient,
DiscoPublicKey: discoPublic,
}) })
if err != nil { if err != nil {
return err return err
@ -631,6 +635,63 @@ func (b *LocalBackend) popBrowserAuthNow() {
} }
} }
// initMachineKeyLocked is called to initialize b.machinePrivKey.
//
// b.prefs must already be initialized.
// b.mu must be held.
func (b *LocalBackend) initMachineKeyLocked() error {
if !b.machinePrivKey.IsZero() {
// Already set.
return nil
}
var legacyMachineKey wgcfg.PrivateKey
if b.prefs.Persist != nil {
legacyMachineKey = b.prefs.Persist.LegacyFrontendPrivateMachineKey
}
keyText, err := b.store.ReadState(MachineKeyStateKey)
if err == nil {
if err := b.machinePrivKey.UnmarshalText(keyText); err != nil {
return fmt.Errorf("invalid key in %s key of %v: %w", MachineKeyStateKey, b.store, err)
}
if b.machinePrivKey.IsZero() {
return fmt.Errorf("invalid zero key stored in %v key of %v", MachineKeyStateKey, b.store)
}
if !legacyMachineKey.IsZero() && !bytes.Equal(legacyMachineKey[:], b.machinePrivKey[:]) {
b.logf("frontend-provided legacy machine key ignored; used value from server state")
}
return nil
}
if err != ErrStateNotExist {
return fmt.Errorf("error reading %v key of %v: %w", MachineKeyStateKey, b.store, err)
}
// If we didn't find one already on disk and the prefs already
// have a legacy machine key, use that. Otherwise generate a
// new one.
if !legacyMachineKey.IsZero() {
b.logf("using frontend-provided legacy machine key")
b.machinePrivKey = legacyMachineKey
} else {
b.logf("generating new machine key")
var err error
b.machinePrivKey, err = wgcfg.NewPrivateKey()
if err != nil {
return fmt.Errorf("initializing new machine key: %w", err)
}
}
keyText, _ = b.machinePrivKey.MarshalText()
if err := b.store.WriteState(MachineKeyStateKey, keyText); err != nil {
b.logf("error writing machine key to store: %v", err)
return err
}
b.logf("machine key written to store")
return nil
}
// loadStateLocked sets b.prefs and b.stateKey based on a complex // loadStateLocked sets b.prefs and b.stateKey based on a complex
// combination of key, prefs, and legacyPath. b.mu must be held when // combination of key, prefs, and legacyPath. b.mu must be held when
// calling. // calling.
@ -640,9 +701,16 @@ func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath st
} }
if key == "" { if key == "" {
// Frontend fully owns the state, we just need to obey it. // Frontend owns the state, we just need to obey it.
//
// If the frontend (e.g. on Windows) supplied the
// optional/legacy machine key then it's used as the
// value instead of making up a new one.
b.logf("Using frontend prefs") b.logf("Using frontend prefs")
b.prefs = prefs.Clone() b.prefs = prefs.Clone()
if err := b.initMachineKeyLocked(); err != nil {
return fmt.Errorf("initMachineKeyLocked: %w", err)
}
b.stateKey = "" b.stateKey = ""
return nil return nil
} }
@ -674,6 +742,9 @@ func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath st
b.prefs = NewPrefs() b.prefs = NewPrefs()
b.logf("Created empty state for %q", key) b.logf("Created empty state for %q", key)
} }
if err := b.initMachineKeyLocked(); err != nil {
return fmt.Errorf("initMachineKeyLocked: %w", err)
}
b.stateKey = key b.stateKey = key
return nil return nil
} }
@ -684,6 +755,9 @@ func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath st
return fmt.Errorf("PrefsFromBytes: %v", err) return fmt.Errorf("PrefsFromBytes: %v", err)
} }
b.stateKey = key b.stateKey = key
if err := b.initMachineKeyLocked(); err != nil {
return fmt.Errorf("initMachineKeyLocked: %w", err)
}
return nil return nil
} }
@ -1290,13 +1364,14 @@ func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) {
func (b *LocalBackend) TestOnlyPublicKeys() (machineKey tailcfg.MachineKey, nodeKey tailcfg.NodeKey) { func (b *LocalBackend) TestOnlyPublicKeys() (machineKey tailcfg.MachineKey, nodeKey tailcfg.NodeKey) {
b.mu.Lock() b.mu.Lock()
prefs := b.prefs prefs := b.prefs
machinePrivKey := b.machinePrivKey
b.mu.Unlock() b.mu.Unlock()
if prefs == nil { if prefs == nil || machinePrivKey.IsZero() {
return return
} }
mk := prefs.Persist.PrivateMachineKey.Public() mk := machinePrivKey.Public()
nk := prefs.Persist.PrivateNodeKey.Public() nk := prefs.Persist.PrivateNodeKey.Public()
return tailcfg.MachineKey(mk), tailcfg.NodeKey(nk) return tailcfg.MachineKey(mk), tailcfg.NodeKey(nk)
} }

@ -21,6 +21,10 @@ import (
var ErrStateNotExist = errors.New("no state with given ID") var ErrStateNotExist = errors.New("no state with given ID")
const ( const (
// MachineKeyStateKey is the key under which we store the machine key,
// in its wgcfg.PrivateKey.MarshalText representation.
MachineKeyStateKey = StateKey("_machinekey")
// GlobalDaemonStateKey is the ipn.StateKey that tailscaled // GlobalDaemonStateKey is the ipn.StateKey that tailscaled
// loads on startup. // loads on startup.
// //

@ -612,6 +612,7 @@ type Debug struct {
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) } func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
func (k MachineKey) MarshalText() ([]byte, error) { return keyMarshalText("mkey:", k), nil } func (k MachineKey) MarshalText() ([]byte, error) { return keyMarshalText("mkey:", k), nil }
func (k MachineKey) HexString() string { return fmt.Sprintf("%x", k[:]) }
func (k *MachineKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "mkey:", text) } func (k *MachineKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "mkey:", text) }
func keyMarshalText(prefix string, k [32]byte) []byte { func keyMarshalText(prefix string, k [32]byte) []byte {

Loading…
Cancel
Save