diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index febacb215..0d7e247c8 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -17,6 +17,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/wgcfg" "golang.org/x/oauth2" "tailscale.com/logtail/backoff" "tailscale.com/tailcfg" @@ -25,36 +26,37 @@ import ( "tailscale.com/types/structs" ) -// TODO(apenwarr): eliminate the 'state' variable, as it's now obsolete. -// It's used only by the unit tests. -type state int +// State is the high-level state of the client. It is used only in +// unit tests for proper sequencing, don't depend on it anywhere else. +// TODO(apenwarr): eliminate 'state', as it's now obsolete. +type State int const ( - stateNew = state(iota) - stateNotAuthenticated - stateAuthenticating - stateURLVisitRequired - stateAuthenticated - stateSynchronized // connected and received map update + StateNew = State(iota) + StateNotAuthenticated + StateAuthenticating + StateURLVisitRequired + StateAuthenticated + StateSynchronized // connected and received map update ) -func (s state) MarshalText() ([]byte, error) { +func (s State) MarshalText() ([]byte, error) { return []byte(s.String()), nil } -func (s state) String() string { +func (s State) String() string { switch s { - case stateNew: + case StateNew: return "state:new" - case stateNotAuthenticated: + case StateNotAuthenticated: return "state:not-authenticated" - case stateAuthenticating: + case StateAuthenticating: return "state:authenticating" - case stateURLVisitRequired: + case StateURLVisitRequired: return "state:url-visit-required" - case stateAuthenticated: + case StateAuthenticated: return "state:authenticated" - case stateSynchronized: + case StateSynchronized: return "state:synchronized" default: return fmt.Sprintf("state:unknown:%d", int(s)) @@ -69,7 +71,7 @@ type Status struct { Persist *Persist // locally persisted configuration NetMap *NetworkMap // server-pushed configuration Hostinfo *tailcfg.Hostinfo // current Hostinfo data - state state + State State } // Equal reports whether s and s2 are equal. @@ -84,7 +86,7 @@ func (s *Status) Equal(s2 *Status) bool { reflect.DeepEqual(s.Persist, s2.Persist) && reflect.DeepEqual(s.NetMap, s2.NetMap) && reflect.DeepEqual(s.Hostinfo, s2.Hostinfo) && - s.state == s2.state + s.State == s2.State } func (s Status) String() string { @@ -92,7 +94,7 @@ func (s Status) String() string { if err != nil { panic(err) } - return s.state.String() + " " + string(b) + return s.State.String() + " " + string(b) } type LoginGoal struct { @@ -121,7 +123,7 @@ type Client struct { hostinfo *tailcfg.Hostinfo inPollNetMap bool // true if currently running a PollNetMap inSendStatus int // number of sendStatus calls currently in progress - state state + state State authCtx context.Context // context used for auth requests mapCtx context.Context // context used for netmap requests @@ -319,7 +321,7 @@ func (c *Client) authRoutine() { c.mu.Lock() c.loggedIn = false c.loginGoal = nil - c.state = stateNotAuthenticated + c.state = StateNotAuthenticated c.synced = false c.mu.Unlock() @@ -328,9 +330,9 @@ func (c *Client) authRoutine() { } else { // ie. goal.wantLoggedIn c.mu.Lock() if goal.url != "" { - c.state = stateURLVisitRequired + c.state = StateURLVisitRequired } else { - c.state = stateAuthenticating + c.state = StateAuthenticating } c.mu.Unlock() @@ -359,7 +361,7 @@ func (c *Client) authRoutine() { c.mu.Lock() c.loginGoal = goal - c.state = stateURLVisitRequired + c.state = StateURLVisitRequired c.synced = false c.mu.Unlock() @@ -372,7 +374,7 @@ func (c *Client) authRoutine() { c.mu.Lock() c.loggedIn = true c.loginGoal = nil - c.state = stateAuthenticated + c.state = StateAuthenticated c.mu.Unlock() c.sendStatus("authRoutine4", nil, "", nil) @@ -382,6 +384,20 @@ func (c *Client) authRoutine() { } } +// Expiry returns the credential expiration time, or the zero time if +// the expiration time isn't known. Used in tests only. +func (c *Client) Expiry() *time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.expiry +} + +// Direct returns the underlying direct client object. Used in tests +// only. +func (c *Client) Direct() *Direct { + return c.direct +} + func (c *Client) mapRoutine() { defer close(c.mapDone) bo := backoff.NewBackoff("mapRoutine", c.logf) @@ -449,7 +465,7 @@ func (c *Client) mapRoutine() { c.synced = true c.inPollNetMap = true if c.loggedIn { - c.state = stateSynchronized + c.state = StateSynchronized } exp := nm.Expiry c.expiry = &exp @@ -467,8 +483,8 @@ func (c *Client) mapRoutine() { c.mu.Lock() c.synced = false c.inPollNetMap = false - if c.state == stateSynchronized { - c.state = stateAuthenticated + if c.state == StateSynchronized { + c.state = StateAuthenticated } c.mu.Unlock() @@ -537,7 +553,7 @@ func (c *Client) sendStatus(who string, err error, url string, nm *NetworkMap) { var p *Persist var fin *empty.Message - if state == stateAuthenticated { + if state == StateAuthenticated { fin = new(empty.Message) } if nm != nil && loggedIn && synced { @@ -554,7 +570,7 @@ func (c *Client) sendStatus(who string, err error, url string, nm *NetworkMap) { Persist: p, NetMap: nm, Hostinfo: hi, - state: state, + State: state, } if err != nil { new.Err = err.Error() @@ -623,3 +639,20 @@ func (c *Client) Shutdown() { c.logf("Client.Shutdown done.") } } + +// NodePublicKey returns the node public key currently in use. This is +// used exclusively in tests. +func (c *Client) TestOnlyNodePublicKey() wgcfg.Key { + priv := c.direct.GetPersist() + return priv.PrivateNodeKey.Public() +} + +func (c *Client) TestOnlySetAuthKey(authkey string) { + c.direct.mu.Lock() + defer c.direct.mu.Unlock() + c.direct.authKey = authkey +} + +func (c *Client) TestOnlyTimeNow() time.Time { + return c.timeNow() +} diff --git a/control/controlclient/auto_test.go b/control/controlclient/auto_test.go deleted file mode 100644 index 0c6a0145e..000000000 --- a/control/controlclient/auto_test.go +++ /dev/null @@ -1,1337 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build depends_on_currently_unreleased - -package controlclient - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/cookiejar" - "net/http/httptest" - "net/url" - "os" - "reflect" - "runtime/pprof" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/klauspost/compress/zstd" - "github.com/tailscale/wireguard-go/wgcfg" - "tailscale.com/tailcfg" - "tailscale.com/tstest" - "tailscale.com/types/logger" - "tailscale.io/control" // not yet released - "tailscale.io/control/cfgdb" -) - -func TestTest(t *testing.T) { - check := tstest.NewResourceCheck() - defer check.Assert(t) -} - -func TestServerStartStop(t *testing.T) { - s := newServer(t) - defer s.close() -} - -func TestControlBasics(t *testing.T) { - s := newServer(t) - defer s.close() - - c := s.newClient(t, "c") - c.Login(nil, 0) - status := c.waitStatus(t, stateURLVisitRequired) - c.postAuthURL(t, "foo@tailscale.com", status.New) -} - -// A function with the same semantics as t.Run(), but which doesn't rearrange -// the logs by creating a new sub-t.Logf, and doesn't support parallelism. -// This makes it possible to actually figure out what happened by looking -// at the logs. -func runSub(t *testing.T, name string, fn func(t *testing.T)) { - t.Helper() - t.Logf("\n") - t.Logf("\n\n--- Starting: %v\n\n", name) - defer func() { - if t.Failed() { - t.Logf("\n\n--- FAILED: %v\n\n", name) - } else { - t.Logf("\n\n--- PASS: %v\n\n", name) - } - }() - - fn(t) -} - -func fatal(t *testing.T, args ...interface{}) { - t.Helper() - t.Fatal("FAILED: ", fmt.Sprint(args...)) -} - -func fatalf(t *testing.T, s string, args ...interface{}) { - t.Helper() - t.Fatalf("FAILED: "+s, args...) -} - -func TestControl(t *testing.T) { - s := newServer(t) - defer s.close() - - c1 := s.newClient(t, "c1") - - runSub(t, "authorize first tailscale.com client", func(t *testing.T) { - const loginName = "testuser1@tailscale.com" - c1.checkNoStatus(t) - c1.loginAs(t, loginName) - c1.waitStatus(t, stateAuthenticated) - status := c1.waitStatus(t, stateSynchronized) - if got, want := status.New.NetMap.MachineStatus, tailcfg.MachineUnauthorized; got != want { - fatalf(t, "MachineStatus=%v, want %v", got, want) - } - c1.checkNoStatus(t) - affectedPeers, err := s.control.AuthorizeMachine(c1.mkey, c1.nkey) - if err != nil { - fatal(t, err) - } - status = c1.status(t) - if got := status.New.Persist.LoginName; got != loginName { - fatalf(t, "LoginName=%q, want %q", got, loginName) - } - if got := status.New.Persist.Provider; got != "google" { - fatalf(t, "Provider=%q, want google", got) - } - if len(affectedPeers) != 1 || affectedPeers[0] != c1.id { - fatalf(t, "authorization should notify the node being authorized (%v), got: %v", c1.id, affectedPeers) - } - if peers := status.New.NetMap.Peers; len(peers) != 0 { - fatalf(t, "peers=%v, want none", peers) - } - if userID := status.New.NetMap.User; userID == 0 { - fatalf(t, "NetMap.User is missing") - } else { - profile := status.New.NetMap.UserProfiles[userID] - if profile.LoginName != loginName { - fatalf(t, "NetMap user LoginName=%q, want %q", profile.LoginName, loginName) - } - } - c1.checkNoStatus(t) - }) - - c2 := s.newClient(t, "c2") - - runSub(t, "authorize second tailscale.io client", func(t *testing.T) { - c2.loginAs(t, "testuser2@tailscale.com") - c2.waitStatus(t, stateAuthenticated) - c2.waitStatus(t, stateSynchronized) - c2.checkNoStatus(t) - - // Make sure not to call operations like this on a client in a - // test until the initial map read is done. Otherwise the - // initial map read will trigger a map update to peers, and - // there will sometimes be a spurious map update. - affectedPeers, err := s.control.AuthorizeMachine(c2.mkey, c2.nkey) - if err != nil { - fatal(t, err) - } - status := c2.waitStatus(t, stateSynchronized) - c1Status := c1.waitStatus(t, stateSynchronized) - - if len(affectedPeers) != 2 { - fatalf(t, "affectedPeers=%v, want two entries", affectedPeers) - } - if want := []tailcfg.NodeID{c1.id, c2.id}; !nodeIDsEqual(affectedPeers, want) { - fatalf(t, "affectedPeers=%v, want %v", affectedPeers, want) - } - - c1NetMap := c1Status.New.NetMap - c2NetMap := status.New.NetMap - if len(c1NetMap.Peers) != 1 || len(c2NetMap.Peers) != 1 { - t.Error("wrong number of peers") - } else { - if c2NetMap.Peers[0].Key != c1.nkey { - fatalf(t, "c2 has wrong peer key %v, want %v", c2NetMap.Peers[0].Key, c1.nkey) - } - if c1NetMap.Peers[0].Key != c2.nkey { - fatalf(t, "c1 has wrong peer key %v, want %v", c1NetMap.Peers[0].Key, c2.nkey) - } - } - if t.Failed() { - fatalf(t, "client1 network map:\n%s", c1Status.New.NetMap) - fatalf(t, "client2 network map:\n%s", status.New.NetMap) - } - - c1.checkNoStatus(t) - c2.checkNoStatus(t) - }) - - // c3/c4 are on a different domain to c1/c2. - // The two domains should never affect one another. - c3 := s.newClient(t, "c3") - - runSub(t, "authorize first onmicrosoft client", func(t *testing.T) { - c3.loginAs(t, "testuser1@tailscale.onmicrosoft.com") - c3.waitStatus(t, stateAuthenticated) - c3Status := c3.waitStatus(t, stateSynchronized) - // no machine authorization for tailscale.onmicrosoft.com - c1.checkNoStatus(t) - c2.checkNoStatus(t) - - netMap := c3Status.New.NetMap - if netMap.NodeKey != c3.nkey { - fatalf(t, "netMap.NodeKey=%v, want %v", netMap.NodeKey, c3.nkey) - } - if len(netMap.Peers) != 0 { - fatalf(t, "netMap.Peers=%v, want none", netMap.Peers) - } - - c1.checkNoStatus(t) - c2.checkNoStatus(t) - c3.checkNoStatus(t) - }) - - c4 := s.newClient(t, "c4") - - runSub(t, "authorize second onmicrosoft client", func(t *testing.T) { - c4.loginAs(t, "testuser2@tailscale.onmicrosoft.com") - c4.waitStatus(t, stateAuthenticated) - c3Status := c3.waitStatus(t, stateSynchronized) - c4Status := c4.waitStatus(t, stateSynchronized) - c3NetMap := c3Status.New.NetMap - c4NetMap := c4Status.New.NetMap - - c1.checkNoStatus(t) - c2.checkNoStatus(t) - - if len(c3NetMap.Peers) != 1 { - fatalf(t, "wrong number of c3 peers: %d", len(c3NetMap.Peers)) - } else if len(c4NetMap.Peers) != 1 { - fatalf(t, "wrong number of c4 peers: %d", len(c4NetMap.Peers)) - } else { - if c3NetMap.Peers[0].Key != c4.nkey || c4NetMap.Peers[0].Key != c3.nkey { - t.Error("wrong peer key") - } - } - if t.Failed() { - fatalf(t, "client3 network map:\n%s", c3NetMap) - fatalf(t, "client4 network map:\n%s", c4NetMap) - } - }) - - var c1NetMap *NetworkMap - runSub(t, "update c1 and c2 endpoints", func(t *testing.T) { - c1Endpoints := []string{"172.16.1.5:12345", "4.4.4.4:4444"} - c1.checkNoStatus(t) - c1.UpdateEndpoints(1234, c1Endpoints) - c1NetMap = c1.status(t).New.NetMap - c2NetMap := c2.status(t).New.NetMap - c1.checkNoStatus(t) - c2.checkNoStatus(t) - - if c1NetMap.LocalPort != 1234 { - fatalf(t, "c1 netmap localport=%d, want 1234", c1NetMap.LocalPort) - } - if len(c2NetMap.Peers) != 1 { - fatalf(t, "wrong peer count: %d", len(c2NetMap.Peers)) - } - if got := c2NetMap.Peers[0].Endpoints; !hasStringsSuffix(got, c1Endpoints) { - fatalf(t, "c2 peer endpoints=%v, want %v", got, c1Endpoints) - } - c3.checkNoStatus(t) - c4.checkNoStatus(t) - - c2Endpoints := []string{"172.16.1.7:6543", "5.5.5.5.3333"} - c2.UpdateEndpoints(9876, c2Endpoints) - c1NetMap = c1.status(t).New.NetMap - c2NetMap = c2.status(t).New.NetMap - - if c1NetMap.LocalPort != 1234 { - fatalf(t, "c1 netmap localport=%d, want 1234", c1NetMap.LocalPort) - } - if c2NetMap.LocalPort != 9876 { - fatalf(t, "c2 netmap localport=%d, want 9876", c2NetMap.LocalPort) - } - if got := c2NetMap.Peers[0].Endpoints; !hasStringsSuffix(got, c1Endpoints) { - fatalf(t, "c2 peer endpoints=%v, want suffix %v", got, c1Endpoints) - } - if got := c1NetMap.Peers[0].Endpoints; !hasStringsSuffix(got, c2Endpoints) { - fatalf(t, "c1 peer endpoints=%v, want suffix %v", got, c2Endpoints) - } - - c1.checkNoStatus(t) - c2.checkNoStatus(t) - c3.checkNoStatus(t) - c4.checkNoStatus(t) - }) - - allZeros, err := wgcfg.ParseCIDR("0.0.0.0/0") - if err != nil { - fatal(t, err) - } - - runSub(t, "route all traffic via client 1", func(t *testing.T) { - aips := []wgcfg.CIDR{} - aips = append(aips, c1NetMap.Addresses...) - aips = append(aips, allZeros) - - affectedPeers, err := s.control.SetAllowedIPs(c1.nkey, aips) - if err != nil { - fatal(t, err) - } - c2Status := c2.status(t) - c2NetMap := c2Status.New.NetMap - - if want := []tailcfg.NodeID{c2.id}; !nodeIDsEqual(affectedPeers, want) { - fatalf(t, "affectedPeers=%v, want %v", affectedPeers, want) - } - - _ = c2NetMap - foundAllZeros := false - for _, cidr := range c2NetMap.Peers[0].AllowedIPs { - if cidr == allZeros { - foundAllZeros = true - } - } - if !foundAllZeros { - fatalf(t, "client2 peer does not contain %s: %v", allZeros, c2NetMap.Peers[0].AllowedIPs) - } - - c1.checkNoStatus(t) - c3.checkNoStatus(t) - c4.checkNoStatus(t) - }) - - runSub(t, "remove route all traffic", func(t *testing.T) { - affectedPeers, err := s.control.SetAllowedIPs(c1.nkey, c1NetMap.Addresses) - if err != nil { - fatal(t, err) - } - c2NetMap := c2.status(t).New.NetMap - - if want := []tailcfg.NodeID{c2.id}; !nodeIDsEqual(affectedPeers, want) { - fatalf(t, "affectedPeers=%v, want %v", affectedPeers, want) - } - - foundAllZeros := false - for _, cidr := range c2NetMap.Peers[0].AllowedIPs { - if cidr == allZeros { - foundAllZeros = true - } - } - if foundAllZeros { - fatalf(t, "client2 peer still contains %s: %v", allZeros, c2NetMap.Peers[0].AllowedIPs) - } - - c1.checkNoStatus(t) - c3.checkNoStatus(t) - c4.checkNoStatus(t) - }) - - runSub(t, "refresh client key", func(t *testing.T) { - oldKey := c1.nkey - - c1.Login(nil, LoginInteractive) - status := c1.waitStatus(t, stateURLVisitRequired) - c1.postAuthURL(t, "testuser1@tailscale.com", status.New) - c1.waitStatus(t, stateAuthenticated) - status = c1.waitStatus(t, stateSynchronized) - if status.New.Err != "" { - fatal(t, status.New.Err) - } - - c1NetMap := status.New.NetMap - c1.nkey = c1NetMap.NodeKey - if c1.nkey == oldKey { - fatalf(t, "new key is the same as the old key: %s", oldKey) - } - c2NetMap := c2.status(t).New.NetMap - if len(c2NetMap.Peers) != 1 || c2NetMap.Peers[0].Key != c1.nkey { - fatalf(t, "c2 peer: %v, want new node key %v", c1.nkey, c2NetMap.Peers[0].Key) - } - - c3.checkNoStatus(t) - c4.checkNoStatus(t) - }) - - runSub(t, "set hostinfo", func(t *testing.T) { - c3.Login(nil, LoginDefault) - c4.Login(nil, LoginDefault) - c3.waitStatus(t, stateAuthenticated) - c4.waitStatus(t, stateAuthenticated) - c3.waitStatus(t, stateSynchronized) - c4.waitStatus(t, stateSynchronized) - - c3.UpdateEndpoints(9876, []string{"1.2.3.4:3333"}) - c3.waitStatus(t, stateSynchronized) - c4.waitStatus(t, stateSynchronized) - - c4.UpdateEndpoints(9876, []string{"5.6.7.8:1111"}) - c3.waitStatus(t, stateSynchronized) - c4.waitStatus(t, stateSynchronized) - - c3.SetHostinfo(&tailcfg.Hostinfo{ - BackendLogID: "set-hostinfo-test", - OS: "linux", - }) - c3.waitStatus(t, stateSynchronized) - c4NetMap := c4.status(t).New.NetMap - if len(c4NetMap.Peers) != 1 { - fatalf(t, "wrong number of peers: %v", c4NetMap.Peers) - } - peer := c4NetMap.Peers[0] - if !peer.KeepAlive { - fatalf(t, "peer KeepAlive=false, want true") - } - if peer.Hostinfo.OS != "linux" { - fatalf(t, "peer OS is not linux: %v", peer.Hostinfo) - } - - c4.SetHostinfo(&tailcfg.Hostinfo{ - BackendLogID: "set-hostinfo-test", - OS: "iOS", - }) - c3NetMap := c3.status(t).New.NetMap - c4NetMap = c4.status(t).New.NetMap - if len(c3NetMap.Peers) != 1 { - fatalf(t, "wrong number of peers: %v", c3NetMap.Peers) - } - if len(c4NetMap.Peers) != 1 { - fatalf(t, "wrong number of peers: %v", c4NetMap.Peers) - } - peer = c3NetMap.Peers[0] - if peer.KeepAlive { - fatalf(t, "peer KeepAlive=true, want false") - } - if peer.Hostinfo.OS != "iOS" { - fatalf(t, "peer OS is not iOS: %v", peer.Hostinfo) - } - peer = c4NetMap.Peers[0] - if peer.KeepAlive { - fatalf(t, "peer KeepAlive=true, want false") - } - if peer.Hostinfo.OS != "linux" { - fatalf(t, "peer OS is not linux: %v", peer.Hostinfo) - } - - }) -} - -func hasStringsSuffix(list, suffix []string) bool { - if len(list) < len(suffix) { - return false - } - return reflect.DeepEqual(list[len(list)-len(suffix):], suffix) -} - -func TestLoginInterrupt(t *testing.T) { - s := newServer(t) - defer s.close() - - c := s.newClient(t, "c") - - const loginName = "testuser1@tailscale.com" - c.checkNoStatus(t) - c.loginAs(t, loginName) - c.waitStatus(t, stateAuthenticated) - c.waitStatus(t, stateSynchronized) - t.Logf("authorizing: %v %v %v\n", s, c.mkey, c.nkey) - if _, err := s.control.AuthorizeMachine(c.mkey, c.nkey); err != nil { - fatal(t, err) - } - status := c.waitStatus(t, stateSynchronized) - if got, want := status.New.NetMap.MachineStatus, tailcfg.MachineAuthorized; got != want { - fatalf(t, "MachineStatus=%v, want %v", got, want) - } - origAddrs := status.New.NetMap.Addresses - if len(origAddrs) == 0 { - fatalf(t, "Addresses empty, want something") - } - - c.Logout() - c.waitStatus(t, stateNotAuthenticated) - c.Login(nil, 0) - status = c.waitStatus(t, stateURLVisitRequired) - authURL := status.New.URL - - // Interrupt, and do login again. - c.Login(nil, 0) - status = c.waitStatus(t, stateURLVisitRequired) - authURL2 := status.New.URL - - if authURL == authURL2 { - fatalf(t, "auth URLs match for subsequent logins: %s", authURL) - } - - // Direct auth URL visit is not enough because our cookie is no longer fresh. - req, err := http.NewRequest("GET", authURL2, nil) - if err != nil { - fatal(t, err) - } - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - resp, err := c.httpc.Do(req.WithContext(c.ctx)) - if err != nil { - fatal(t, err) - } - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - fatal(t, err) - } - resp.Body.Close() - if i := bytes.Index(b, []byte("
header - b = b[i:] - } - if !bytes.Contains(b, []byte("