From 597c19ff4e0f8c40df300f30ae88296bf5d580cc Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 7 Apr 2021 21:06:31 -0700 Subject: [PATCH] control/controlclient: refactor some internals Signed-off-by: Brad Fitzpatrick --- control/controlclient/direct.go | 52 +++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index e9554f7ee..7587b8e11 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -274,7 +274,7 @@ func (c *Direct) TryLogout(ctx context.Context) error { func (c *Direct) TryLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags) (url string, err error) { c.logf("direct.TryLogin(token=%v, flags=%v)", t != nil, flags) - return c.doLoginOrRegen(ctx, t, flags, false, "") + return c.doLoginOrRegen(ctx, loginOpt{Token: t, Flags: flags}) } // WaitLoginURL sits in a long poll waiting for the user to authenticate at url. @@ -282,22 +282,29 @@ func (c *Direct) TryLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Log // On success, newURL and err will both be nil. func (c *Direct) WaitLoginURL(ctx context.Context, url string) (newURL string, err error) { c.logf("direct.WaitLoginURL") - return c.doLoginOrRegen(ctx, nil, LoginDefault, false, url) + return c.doLoginOrRegen(ctx, loginOpt{URL: url}) } -func (c *Direct) doLoginOrRegen(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags, regen bool, url string) (newURL string, err error) { - mustregen, url, err := c.doLogin(ctx, t, flags, regen, url) +func (c *Direct) doLoginOrRegen(ctx context.Context, opt loginOpt) (newURL string, err error) { + mustRegen, url, err := c.doLogin(ctx, opt) if err != nil { return url, err } - if mustregen { - _, url, err = c.doLogin(ctx, t, flags, true, url) + if mustRegen { + opt.Regen = true + _, url, err = c.doLogin(ctx, opt) } - return url, err } -func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags, regen bool, url string) (mustregen bool, newurl string, err error) { +type loginOpt struct { + Token *tailcfg.Oauth2Token + Flags LoginFlags + Regen bool + URL string +} + +func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, newURL string, err error) { c.mu.Lock() persist := c.persist tryingNewKey := c.tryingNewKey @@ -316,22 +323,23 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi return false, "", errors.New("getMachinePrivKey returned zero key") } + regen := opt.Regen if expired { c.logf("Old key expired -> regen=true") systemd.Status("key expired; run 'tailscale up' to authenticate") regen = true } - if (flags & LoginInteractive) != 0 { + if (opt.Flags & LoginInteractive) != 0 { c.logf("LoginInteractive -> regen=true") regen = true } - c.logf("doLogin(regen=%v, hasUrl=%v)", regen, url != "") + c.logf("doLogin(regen=%v, hasUrl=%v)", regen, opt.URL != "") if serverKey.IsZero() { var err error serverKey, err = loadServerKey(ctx, c.httpc, c.serverURL) if err != nil { - return regen, url, err + return regen, opt.URL, err } c.mu.Lock() @@ -340,14 +348,14 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi } var oldNodeKey wgkey.Key - if url != "" { + if opt.URL != "" { } else if regen || persist.PrivateNodeKey.IsZero() { c.logf("Generating a new nodekey.") persist.OldPrivateNodeKey = persist.PrivateNodeKey key, err := wgkey.NewPrivate() if err != nil { c.logf("login keygen: %v", err) - return regen, url, err + return regen, opt.URL, err } tryingNewKey = key } else { @@ -363,7 +371,7 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi } if backendLogID == "" { err = errors.New("hostinfo: BackendLogID missing") - return regen, url, err + return regen, opt.URL, err } now := time.Now().Round(time.Second) request := tailcfg.RegisterRequest{ @@ -371,13 +379,13 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi OldNodeKey: tailcfg.NodeKey(oldNodeKey), NodeKey: tailcfg.NodeKey(tryingNewKey.Public()), Hostinfo: hostinfo, - Followup: url, + Followup: opt.URL, Timestamp: &now, } c.logf("RegisterReq: onode=%v node=%v fup=%v", request.OldNodeKey.ShortString(), - request.NodeKey.ShortString(), url != "") - request.Auth.Oauth2Token = t + request.NodeKey.ShortString(), opt.URL != "") + request.Auth.Oauth2Token = opt.Token request.Auth.Provider = persist.Provider request.Auth.LoginName = persist.LoginName request.Auth.AuthKey = authKey @@ -397,31 +405,31 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi } bodyData, err := encode(request, &serverKey, &machinePrivKey) if err != nil { - return regen, url, err + return regen, opt.URL, err } body := bytes.NewReader(bodyData) u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString()) req, err := http.NewRequest("POST", u, body) if err != nil { - return regen, url, err + return regen, opt.URL, err } req = req.WithContext(ctx) res, err := c.httpc.Do(req) if err != nil { - return regen, url, fmt.Errorf("register request: %v", err) + return regen, opt.URL, fmt.Errorf("register request: %v", err) } if res.StatusCode != 200 { msg, _ := ioutil.ReadAll(res.Body) res.Body.Close() - return regen, url, fmt.Errorf("register request: http %d: %.200s", + return regen, opt.URL, fmt.Errorf("register request: http %d: %.200s", res.StatusCode, strings.TrimSpace(string(msg))) } resp := tailcfg.RegisterResponse{} if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil { c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err) - return regen, url, fmt.Errorf("register request: %v", err) + return regen, opt.URL, fmt.Errorf("register request: %v", err) } // Log without PII: c.logf("RegisterReq: got response; nodeKeyExpired=%v, machineAuthorized=%v; authURL=%v",