diff --git a/control/controlclient/auto_test.go b/control/controlclient/auto_test.go index 974b3e55f..affed3b0a 100644 --- a/control/controlclient/auto_test.go +++ b/control/controlclient/auto_test.go @@ -1167,6 +1167,12 @@ func authURLForPOST(authURL string) string { return authURL[:i] + "/login?refresh=true&next_url=" + url.PathEscape(authURL[i:]) } +// postAuthURL manually executes the OAuth login flow, starting at +// authURL and claiming to be user. This flow will only work correctly +// if the control server is configured with the "None" auth provider, +// which blindly accepts the provided user and produces a cookie for +// them. postAuthURL returns the auth cookie produced by the control +// server. func postAuthURL(t *testing.T, ctx context.Context, httpc *http.Client, user string, authURL string) *http.Cookie { t.Helper() diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index a5ea9c431..04d39c80a 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -22,6 +22,8 @@ import ( "tailscale.io/control" // not yet released ) +// Test that when there are two controlclient connections using the +// same credentials, the later one disconnects the earlier one. func TestClientsReusingKeys(t *testing.T) { tmpdir, err := ioutil.TempDir("", "control-test-") if err != nil { @@ -31,21 +33,23 @@ func TestClientsReusingKeys(t *testing.T) { httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server.ServeHTTP(w, r) })) + defer func() { + httpsrv.CloseClientConnections() + httpsrv.Close() + os.RemoveAll(tmpdir) + }() + httpc := httpsrv.Client() httpc.Jar, err = cookiejar.New(nil) if err != nil { t.Fatal(err) } + server, err = control.New(tmpdir, tmpdir, httpsrv.URL, true) if err != nil { t.Fatal(err) } server.QuietLogging = true - defer func() { - httpsrv.CloseClientConnections() - httpsrv.Close() - os.RemoveAll(tmpdir) - }() hi := NewHostinfo() hi.FrontendLogID = "go-test-only" @@ -63,13 +67,20 @@ func TestClientsReusingKeys(t *testing.T) { if err != nil { t.Fatal(err) } + + // Use a cancelable context so that goroutines blocking in + // PollNetMap shut down when the test exits. ctx, cancel := context.WithCancel(context.Background()) defer cancel() + + // Execute c1's login flow: TryLogin to get an auth URL, + // postAuthURL to execute the (faked) OAuth segment of the flow, + // and WaitLoginURL to complete the login on the client end. + const user = "testuser1@tailscale.onmicrosoft.com" authURL, err := c1.TryLogin(ctx, nil, 0) if err != nil { t.Fatal(err) } - const user = "testuser1@tailscale.onmicrosoft.com" postAuthURL(t, ctx, httpc, user, authURL) newURL, err := c1.WaitLoginURL(ctx, authURL) if err != nil { @@ -79,10 +90,13 @@ func TestClientsReusingKeys(t *testing.T) { t.Fatalf("unexpected newURL: %s", newURL) } + // Start c1's netmap poll in parallel with the rest of the + // test. We're expecting it to block happily, invoking the no-op + // update function periodically, then exit once c2 starts its own + // poll below. pollErrCh := make(chan error) go func() { - err := c1.PollNetMap(ctx, -1, func(netMap *NetworkMap) {}) - pollErrCh <- err + pollErrCh <- c1.PollNetMap(ctx, -1, func(netMap *NetworkMap) {}) }() select { @@ -91,6 +105,8 @@ func TestClientsReusingKeys(t *testing.T) { default: } + // Connect c2, reusing c1's credentials. In other words, c2 *is* + // c1 from the server's perspective. c2, err := NewDirect(Options{ ServerURL: httpsrv.URL, HTTPTestClient: httpsrv.Client(), @@ -112,18 +128,24 @@ func TestClientsReusingKeys(t *testing.T) { if err != nil { t.Fatal(err) } + // We don't expect to be given an authURL, our credentials from c1 + // should still be good. if authURL != "" { t.Errorf("unexpected authURL %s", authURL) } + // Request a single netmap, so this function returns promptly + // instead of blocking like c1's PollNetMap. err = c2.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}) if err != nil { t.Fatal(err) } + // Now that c2 connected and got a netmap, we expect c1's poll to + // have exited. select { case err := <-pollErrCh: - t.Logf("expected poll error: %v", err) + t.Logf("c1: netmap poll aborted as expected (%v)", err) case <-time.After(5 * time.Second): t.Fatal("first client poll failed to close") }