diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 228345806..30daefbf0 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -63,7 +63,7 @@ type Direct struct { logf logger.Logf linkMon *monitor.Mon // or nil discoPubKey tailcfg.DiscoKey - machinePrivKey wgkey.Private + getMachinePrivKey func() (wgkey.Private, error) debugFlags []string keepSharerAndUserSplit bool @@ -81,19 +81,19 @@ type Direct struct { } type Options struct { - Persist persist.Persist // initial persistent data - MachinePrivateKey wgkey.Private // the machine key to use - ServerURL string // URL of the tailcontrol server - AuthKey string // optional node auth key for auto registration - TimeNow func() time.Time // time.Now implementation used by Client - Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc - DiscoPublicKey tailcfg.DiscoKey - NewDecompressor func() (Decompressor, error) - KeepAlive bool - Logf logger.Logf - HTTPTestClient *http.Client // optional HTTP client to use (for tests only) - DebugFlags []string // debug settings to send to control - LinkMonitor *monitor.Mon // optional link monitor + Persist persist.Persist // initial persistent data + GetMachinePrivateKey func() (wgkey.Private, error) // returns the machine key to use + ServerURL string // URL of the tailcontrol server + AuthKey string // optional node auth key for auto registration + TimeNow func() time.Time // time.Now implementation used by Client + Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc + DiscoPublicKey tailcfg.DiscoKey + NewDecompressor func() (Decompressor, error) + KeepAlive bool + Logf logger.Logf + HTTPTestClient *http.Client // optional HTTP client to use (for tests only) + DebugFlags []string // debug settings to send to control + LinkMonitor *monitor.Mon // optional link monitor // KeepSharerAndUserSplit controls whether the client // understands Node.Sharer. If false, the Sharer is mapped to the User. @@ -110,8 +110,8 @@ func NewDirect(opts Options) (*Direct, error) { if opts.ServerURL == "" { return nil, errors.New("controlclient.New: no server URL specified") } - if opts.MachinePrivateKey.IsZero() { - return nil, errors.New("controlclient.New: no MachinePrivateKey specified") + if opts.GetMachinePrivateKey == nil { + return nil, errors.New("controlclient.New: no GetMachinePrivateKey specified") } opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") serverURL, err := url.Parse(opts.ServerURL) @@ -147,7 +147,7 @@ func NewDirect(opts Options) (*Direct, error) { c := &Direct{ httpc: httpc, - machinePrivKey: opts.MachinePrivateKey, + getMachinePrivKey: opts.GetMachinePrivateKey, serverURL: opts.ServerURL, timeNow: opts.TimeNow, logf: opts.Logf, @@ -301,8 +301,12 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) c.mu.Unlock() - if c.machinePrivKey.IsZero() { - return false, "", errors.New("controlclient.Direct requires a machine private key") + machinePrivKey, err := c.getMachinePrivKey() + if err != nil { + return false, "", fmt.Errorf("getMachinePrivKey: %w", err) + } + if machinePrivKey.IsZero() { + return false, "", errors.New("getMachinePrivKey returned zero key") } if expired { @@ -370,7 +374,7 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi request.Auth.Provider = persist.Provider request.Auth.LoginName = persist.LoginName request.Auth.AuthKey = authKey - err = signRegisterRequest(&request, c.serverURL, c.serverKey, c.machinePrivKey.Public()) + err = signRegisterRequest(&request, c.serverURL, c.serverKey, machinePrivKey.Public()) if err != nil { // If signing failed, clear all related fields request.SignatureType = tailcfg.SignatureNone @@ -384,13 +388,13 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi c.logf("RegisterReq sign error: %v", err) } } - bodyData, err := encode(request, &serverKey, &c.machinePrivKey) + bodyData, err := encode(request, &serverKey, &machinePrivKey) if err != nil { return regen, url, err } body := bytes.NewReader(bodyData) - u := fmt.Sprintf("%s/machine/%s", c.serverURL, c.machinePrivKey.Public().HexString()) + u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString()) req, err := http.NewRequest("POST", u, body) if err != nil { return regen, url, err @@ -408,8 +412,8 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi res.StatusCode, strings.TrimSpace(string(msg))) } resp := tailcfg.RegisterResponse{} - if err := decode(res, &resp, &serverKey, &c.machinePrivKey); err != nil { - c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, c.machinePrivKey.Public(), err) + if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil { + c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err) return regen, url, fmt.Errorf("register request: %v", err) } // Log without PII: @@ -536,6 +540,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm everEndpoints := c.everEndpoints c.mu.Unlock() + machinePrivKey, err := c.getMachinePrivKey() + if err != nil { + return fmt.Errorf("getMachinePrivKey: %w", err) + } + if machinePrivKey.IsZero() { + return errors.New("getMachinePrivKey returned zero key") + } + if persist.PrivateNodeKey.IsZero() { return errors.New("privateNodeKey is zero") } @@ -593,7 +605,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm request.ReadOnly = true } - bodyData, err := encode(request, &serverKey, &c.machinePrivKey) + bodyData, err := encode(request, &serverKey, &machinePrivKey) if err != nil { vlogf("netmap: encode: %v", err) return err @@ -602,7 +614,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm ctx, cancel := context.WithCancel(ctx) defer cancel() - machinePubKey := tailcfg.MachineKey(c.machinePrivKey.Public()) + machinePubKey := tailcfg.MachineKey(machinePrivKey.Public()) t0 := time.Now() u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString()) @@ -695,7 +707,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond)) var resp tailcfg.MapResponse - if err := c.decodeMsg(msg, &resp); err != nil { + if err := c.decodeMsg(msg, &resp, &machinePrivKey); err != nil { vlogf("netmap: decode error: %v") return err } @@ -878,12 +890,12 @@ var debugMap, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_MAP")) var jsonEscapedZero = []byte(`\u0000`) -func (c *Direct) decodeMsg(msg []byte, v interface{}) error { +func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey *wgkey.Private) error { c.mu.Lock() serverKey := c.serverKey c.mu.Unlock() - decrypted, err := decryptMsg(msg, &serverKey, &c.machinePrivKey) + decrypted, err := decryptMsg(msg, &serverKey, machinePrivKey) if err != nil { return err } @@ -917,8 +929,8 @@ func (c *Direct) decodeMsg(msg []byte, v interface{}) error { } -func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) error { - decrypted, err := decryptMsg(msg, serverKey, mkey) +func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, machinePrivKey *wgkey.Private) error { + decrypted, err := decryptMsg(msg, serverKey, machinePrivKey) if err != nil { return err } diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index 3dab4d9ec..603614aba 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -103,7 +103,13 @@ func TestNewDirect(t *testing.T) { if err != nil { t.Error(err) } - opts := Options{ServerURL: "https://example.com", MachinePrivateKey: key, Hostinfo: hi} + opts := Options{ + ServerURL: "https://example.com", + Hostinfo: hi, + GetMachinePrivateKey: func() (wgkey.Private, error) { + return key, nil + }, + } c, err := NewDirect(opts) if err != nil { t.Fatal(err) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index b3ddc4856..202a53afa 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -623,18 +623,23 @@ func (b *LocalBackend) Start(opts ipn.Options) error { persistv = &persist.Persist{} } cli, err := controlclient.New(controlclient.Options{ - MachinePrivateKey: machinePrivKey, - Logf: logger.WithPrefix(b.logf, "control: "), - Persist: *persistv, - ServerURL: b.serverURL, - AuthKey: opts.AuthKey, - Hostinfo: hostinfo, - KeepAlive: true, - NewDecompressor: b.newDecompressor, - HTTPTestClient: opts.HTTPTestClient, - DiscoPublicKey: discoPublic, - DebugFlags: controlDebugFlags, - LinkMonitor: b.e.GetLinkMonitor(), + GetMachinePrivateKey: func() (wgkey.Private, error) { + // TODO(bradfitz): finish pushing this laziness further; see + // https://github.com/tailscale/tailscale/issues/1573 + // For now this is only lazy-ified in controlclient. + return machinePrivKey, nil + }, + Logf: logger.WithPrefix(b.logf, "control: "), + Persist: *persistv, + ServerURL: b.serverURL, + AuthKey: opts.AuthKey, + Hostinfo: hostinfo, + KeepAlive: true, + NewDecompressor: b.newDecompressor, + HTTPTestClient: opts.HTTPTestClient, + DiscoPublicKey: discoPublic, + DebugFlags: controlDebugFlags, + LinkMonitor: b.e.GetLinkMonitor(), }) if err != nil { return err