control/controlclient: support lazy machine key generation

It's not done in the caller yet, but the controlclient does it now.

Updates #1573

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/1622/head
Brad Fitzpatrick 3 years ago
parent 8d57bce5ef
commit a998fe7c3d

@ -63,7 +63,7 @@ type Direct struct {
logf logger.Logf logf logger.Logf
linkMon *monitor.Mon // or nil linkMon *monitor.Mon // or nil
discoPubKey tailcfg.DiscoKey discoPubKey tailcfg.DiscoKey
machinePrivKey wgkey.Private getMachinePrivKey func() (wgkey.Private, error)
debugFlags []string debugFlags []string
keepSharerAndUserSplit bool keepSharerAndUserSplit bool
@ -81,19 +81,19 @@ type Direct struct {
} }
type Options struct { type Options struct {
Persist persist.Persist // initial persistent data Persist persist.Persist // initial persistent data
MachinePrivateKey wgkey.Private // the machine key to use GetMachinePrivateKey func() (wgkey.Private, error) // returns the machine key to use
ServerURL string // URL of the tailcontrol server ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration AuthKey string // optional node auth key for auto registration
TimeNow func() time.Time // time.Now implementation used by Client TimeNow func() time.Time // time.Now implementation used by Client
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
DiscoPublicKey tailcfg.DiscoKey DiscoPublicKey tailcfg.DiscoKey
NewDecompressor func() (Decompressor, error) NewDecompressor func() (Decompressor, error)
KeepAlive bool KeepAlive bool
Logf logger.Logf Logf logger.Logf
HTTPTestClient *http.Client // optional HTTP client to use (for tests only) HTTPTestClient *http.Client // optional HTTP client to use (for tests only)
DebugFlags []string // debug settings to send to control DebugFlags []string // debug settings to send to control
LinkMonitor *monitor.Mon // optional link monitor LinkMonitor *monitor.Mon // optional link monitor
// KeepSharerAndUserSplit controls whether the client // KeepSharerAndUserSplit controls whether the client
// understands Node.Sharer. If false, the Sharer is mapped to the User. // understands Node.Sharer. If false, the Sharer is mapped to the User.
@ -110,8 +110,8 @@ func NewDirect(opts Options) (*Direct, error) {
if opts.ServerURL == "" { if opts.ServerURL == "" {
return nil, errors.New("controlclient.New: no server URL specified") return nil, errors.New("controlclient.New: no server URL specified")
} }
if opts.MachinePrivateKey.IsZero() { if opts.GetMachinePrivateKey == nil {
return nil, errors.New("controlclient.New: no MachinePrivateKey specified") return nil, errors.New("controlclient.New: no GetMachinePrivateKey specified")
} }
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
serverURL, err := url.Parse(opts.ServerURL) serverURL, err := url.Parse(opts.ServerURL)
@ -147,7 +147,7 @@ func NewDirect(opts Options) (*Direct, error) {
c := &Direct{ c := &Direct{
httpc: httpc, httpc: httpc,
machinePrivKey: opts.MachinePrivateKey, getMachinePrivKey: opts.GetMachinePrivateKey,
serverURL: opts.ServerURL, serverURL: opts.ServerURL,
timeNow: opts.TimeNow, timeNow: opts.TimeNow,
logf: opts.Logf, logf: opts.Logf,
@ -301,8 +301,12 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
c.mu.Unlock() c.mu.Unlock()
if c.machinePrivKey.IsZero() { machinePrivKey, err := c.getMachinePrivKey()
return false, "", errors.New("controlclient.Direct requires a machine private key") if err != nil {
return false, "", fmt.Errorf("getMachinePrivKey: %w", err)
}
if machinePrivKey.IsZero() {
return false, "", errors.New("getMachinePrivKey returned zero key")
} }
if expired { if expired {
@ -370,7 +374,7 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
request.Auth.Provider = persist.Provider request.Auth.Provider = persist.Provider
request.Auth.LoginName = persist.LoginName request.Auth.LoginName = persist.LoginName
request.Auth.AuthKey = authKey request.Auth.AuthKey = authKey
err = signRegisterRequest(&request, c.serverURL, c.serverKey, c.machinePrivKey.Public()) err = signRegisterRequest(&request, c.serverURL, c.serverKey, machinePrivKey.Public())
if err != nil { if err != nil {
// If signing failed, clear all related fields // If signing failed, clear all related fields
request.SignatureType = tailcfg.SignatureNone request.SignatureType = tailcfg.SignatureNone
@ -384,13 +388,13 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
c.logf("RegisterReq sign error: %v", err) c.logf("RegisterReq sign error: %v", err)
} }
} }
bodyData, err := encode(request, &serverKey, &c.machinePrivKey) bodyData, err := encode(request, &serverKey, &machinePrivKey)
if err != nil { if err != nil {
return regen, url, err return regen, url, err
} }
body := bytes.NewReader(bodyData) body := bytes.NewReader(bodyData)
u := fmt.Sprintf("%s/machine/%s", c.serverURL, c.machinePrivKey.Public().HexString()) u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString())
req, err := http.NewRequest("POST", u, body) req, err := http.NewRequest("POST", u, body)
if err != nil { if err != nil {
return regen, url, err return regen, url, err
@ -408,8 +412,8 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
res.StatusCode, strings.TrimSpace(string(msg))) res.StatusCode, strings.TrimSpace(string(msg)))
} }
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
if err := decode(res, &resp, &serverKey, &c.machinePrivKey); err != nil { if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil {
c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, c.machinePrivKey.Public(), err) c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err)
return regen, url, fmt.Errorf("register request: %v", err) return regen, url, fmt.Errorf("register request: %v", err)
} }
// Log without PII: // Log without PII:
@ -536,6 +540,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
everEndpoints := c.everEndpoints everEndpoints := c.everEndpoints
c.mu.Unlock() c.mu.Unlock()
machinePrivKey, err := c.getMachinePrivKey()
if err != nil {
return fmt.Errorf("getMachinePrivKey: %w", err)
}
if machinePrivKey.IsZero() {
return errors.New("getMachinePrivKey returned zero key")
}
if persist.PrivateNodeKey.IsZero() { if persist.PrivateNodeKey.IsZero() {
return errors.New("privateNodeKey is zero") return errors.New("privateNodeKey is zero")
} }
@ -593,7 +605,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
request.ReadOnly = true request.ReadOnly = true
} }
bodyData, err := encode(request, &serverKey, &c.machinePrivKey) bodyData, err := encode(request, &serverKey, &machinePrivKey)
if err != nil { if err != nil {
vlogf("netmap: encode: %v", err) vlogf("netmap: encode: %v", err)
return err return err
@ -602,7 +614,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
machinePubKey := tailcfg.MachineKey(c.machinePrivKey.Public()) machinePubKey := tailcfg.MachineKey(machinePrivKey.Public())
t0 := time.Now() t0 := time.Now()
u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString()) u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString())
@ -695,7 +707,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond)) vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))
var resp tailcfg.MapResponse var resp tailcfg.MapResponse
if err := c.decodeMsg(msg, &resp); err != nil { if err := c.decodeMsg(msg, &resp, &machinePrivKey); err != nil {
vlogf("netmap: decode error: %v") vlogf("netmap: decode error: %v")
return err return err
} }
@ -878,12 +890,12 @@ var debugMap, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_MAP"))
var jsonEscapedZero = []byte(`\u0000`) var jsonEscapedZero = []byte(`\u0000`)
func (c *Direct) decodeMsg(msg []byte, v interface{}) error { func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey *wgkey.Private) error {
c.mu.Lock() c.mu.Lock()
serverKey := c.serverKey serverKey := c.serverKey
c.mu.Unlock() c.mu.Unlock()
decrypted, err := decryptMsg(msg, &serverKey, &c.machinePrivKey) decrypted, err := decryptMsg(msg, &serverKey, machinePrivKey)
if err != nil { if err != nil {
return err return err
} }
@ -917,8 +929,8 @@ func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
} }
func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) error { func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, machinePrivKey *wgkey.Private) error {
decrypted, err := decryptMsg(msg, serverKey, mkey) decrypted, err := decryptMsg(msg, serverKey, machinePrivKey)
if err != nil { if err != nil {
return err return err
} }

@ -103,7 +103,13 @@ func TestNewDirect(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
opts := Options{ServerURL: "https://example.com", MachinePrivateKey: key, Hostinfo: hi} opts := Options{
ServerURL: "https://example.com",
Hostinfo: hi,
GetMachinePrivateKey: func() (wgkey.Private, error) {
return key, nil
},
}
c, err := NewDirect(opts) c, err := NewDirect(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

@ -623,18 +623,23 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
persistv = &persist.Persist{} persistv = &persist.Persist{}
} }
cli, err := controlclient.New(controlclient.Options{ cli, err := controlclient.New(controlclient.Options{
MachinePrivateKey: machinePrivKey, GetMachinePrivateKey: func() (wgkey.Private, error) {
Logf: logger.WithPrefix(b.logf, "control: "), // TODO(bradfitz): finish pushing this laziness further; see
Persist: *persistv, // https://github.com/tailscale/tailscale/issues/1573
ServerURL: b.serverURL, // For now this is only lazy-ified in controlclient.
AuthKey: opts.AuthKey, return machinePrivKey, nil
Hostinfo: hostinfo, },
KeepAlive: true, Logf: logger.WithPrefix(b.logf, "control: "),
NewDecompressor: b.newDecompressor, Persist: *persistv,
HTTPTestClient: opts.HTTPTestClient, ServerURL: b.serverURL,
DiscoPublicKey: discoPublic, AuthKey: opts.AuthKey,
DebugFlags: controlDebugFlags, Hostinfo: hostinfo,
LinkMonitor: b.e.GetLinkMonitor(), KeepAlive: true,
NewDecompressor: b.newDecompressor,
HTTPTestClient: opts.HTTPTestClient,
DiscoPublicKey: discoPublic,
DebugFlags: controlDebugFlags,
LinkMonitor: b.e.GetLinkMonitor(),
}) })
if err != nil { if err != nil {
return err return err

Loading…
Cancel
Save