control/controlclient: move auto_test back to corp repo.

It can't run without corp stuff anyway, and makes it harder to
refactor the control server.
reviewable/pr417/r1
David Anderson 5 years ago
parent 737124ef70
commit 557b310e67

@ -17,6 +17,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/tailscale/wireguard-go/wgcfg"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"tailscale.com/logtail/backoff" "tailscale.com/logtail/backoff"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -25,36 +26,37 @@ import (
"tailscale.com/types/structs" "tailscale.com/types/structs"
) )
// TODO(apenwarr): eliminate the 'state' variable, as it's now obsolete. // State is the high-level state of the client. It is used only in
// It's used only by the unit tests. // unit tests for proper sequencing, don't depend on it anywhere else.
type state int // TODO(apenwarr): eliminate 'state', as it's now obsolete.
type State int
const ( const (
stateNew = state(iota) StateNew = State(iota)
stateNotAuthenticated StateNotAuthenticated
stateAuthenticating StateAuthenticating
stateURLVisitRequired StateURLVisitRequired
stateAuthenticated StateAuthenticated
stateSynchronized // connected and received map update StateSynchronized // connected and received map update
) )
func (s state) MarshalText() ([]byte, error) { func (s State) MarshalText() ([]byte, error) {
return []byte(s.String()), nil return []byte(s.String()), nil
} }
func (s state) String() string { func (s State) String() string {
switch s { switch s {
case stateNew: case StateNew:
return "state:new" return "state:new"
case stateNotAuthenticated: case StateNotAuthenticated:
return "state:not-authenticated" return "state:not-authenticated"
case stateAuthenticating: case StateAuthenticating:
return "state:authenticating" return "state:authenticating"
case stateURLVisitRequired: case StateURLVisitRequired:
return "state:url-visit-required" return "state:url-visit-required"
case stateAuthenticated: case StateAuthenticated:
return "state:authenticated" return "state:authenticated"
case stateSynchronized: case StateSynchronized:
return "state:synchronized" return "state:synchronized"
default: default:
return fmt.Sprintf("state:unknown:%d", int(s)) return fmt.Sprintf("state:unknown:%d", int(s))
@ -69,7 +71,7 @@ type Status struct {
Persist *Persist // locally persisted configuration Persist *Persist // locally persisted configuration
NetMap *NetworkMap // server-pushed configuration NetMap *NetworkMap // server-pushed configuration
Hostinfo *tailcfg.Hostinfo // current Hostinfo data Hostinfo *tailcfg.Hostinfo // current Hostinfo data
state state State State
} }
// Equal reports whether s and s2 are equal. // Equal reports whether s and s2 are equal.
@ -84,7 +86,7 @@ func (s *Status) Equal(s2 *Status) bool {
reflect.DeepEqual(s.Persist, s2.Persist) && reflect.DeepEqual(s.Persist, s2.Persist) &&
reflect.DeepEqual(s.NetMap, s2.NetMap) && reflect.DeepEqual(s.NetMap, s2.NetMap) &&
reflect.DeepEqual(s.Hostinfo, s2.Hostinfo) && reflect.DeepEqual(s.Hostinfo, s2.Hostinfo) &&
s.state == s2.state s.State == s2.State
} }
func (s Status) String() string { func (s Status) String() string {
@ -92,7 +94,7 @@ func (s Status) String() string {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return s.state.String() + " " + string(b) return s.State.String() + " " + string(b)
} }
type LoginGoal struct { type LoginGoal struct {
@ -121,7 +123,7 @@ type Client struct {
hostinfo *tailcfg.Hostinfo hostinfo *tailcfg.Hostinfo
inPollNetMap bool // true if currently running a PollNetMap inPollNetMap bool // true if currently running a PollNetMap
inSendStatus int // number of sendStatus calls currently in progress inSendStatus int // number of sendStatus calls currently in progress
state state state State
authCtx context.Context // context used for auth requests authCtx context.Context // context used for auth requests
mapCtx context.Context // context used for netmap requests mapCtx context.Context // context used for netmap requests
@ -319,7 +321,7 @@ func (c *Client) authRoutine() {
c.mu.Lock() c.mu.Lock()
c.loggedIn = false c.loggedIn = false
c.loginGoal = nil c.loginGoal = nil
c.state = stateNotAuthenticated c.state = StateNotAuthenticated
c.synced = false c.synced = false
c.mu.Unlock() c.mu.Unlock()
@ -328,9 +330,9 @@ func (c *Client) authRoutine() {
} else { // ie. goal.wantLoggedIn } else { // ie. goal.wantLoggedIn
c.mu.Lock() c.mu.Lock()
if goal.url != "" { if goal.url != "" {
c.state = stateURLVisitRequired c.state = StateURLVisitRequired
} else { } else {
c.state = stateAuthenticating c.state = StateAuthenticating
} }
c.mu.Unlock() c.mu.Unlock()
@ -359,7 +361,7 @@ func (c *Client) authRoutine() {
c.mu.Lock() c.mu.Lock()
c.loginGoal = goal c.loginGoal = goal
c.state = stateURLVisitRequired c.state = StateURLVisitRequired
c.synced = false c.synced = false
c.mu.Unlock() c.mu.Unlock()
@ -372,7 +374,7 @@ func (c *Client) authRoutine() {
c.mu.Lock() c.mu.Lock()
c.loggedIn = true c.loggedIn = true
c.loginGoal = nil c.loginGoal = nil
c.state = stateAuthenticated c.state = StateAuthenticated
c.mu.Unlock() c.mu.Unlock()
c.sendStatus("authRoutine4", nil, "", nil) c.sendStatus("authRoutine4", nil, "", nil)
@ -382,6 +384,20 @@ func (c *Client) authRoutine() {
} }
} }
// Expiry returns the credential expiration time, or the zero time if
// the expiration time isn't known. Used in tests only.
func (c *Client) Expiry() *time.Time {
c.mu.Lock()
defer c.mu.Unlock()
return c.expiry
}
// Direct returns the underlying direct client object. Used in tests
// only.
func (c *Client) Direct() *Direct {
return c.direct
}
func (c *Client) mapRoutine() { func (c *Client) mapRoutine() {
defer close(c.mapDone) defer close(c.mapDone)
bo := backoff.NewBackoff("mapRoutine", c.logf) bo := backoff.NewBackoff("mapRoutine", c.logf)
@ -449,7 +465,7 @@ func (c *Client) mapRoutine() {
c.synced = true c.synced = true
c.inPollNetMap = true c.inPollNetMap = true
if c.loggedIn { if c.loggedIn {
c.state = stateSynchronized c.state = StateSynchronized
} }
exp := nm.Expiry exp := nm.Expiry
c.expiry = &exp c.expiry = &exp
@ -467,8 +483,8 @@ func (c *Client) mapRoutine() {
c.mu.Lock() c.mu.Lock()
c.synced = false c.synced = false
c.inPollNetMap = false c.inPollNetMap = false
if c.state == stateSynchronized { if c.state == StateSynchronized {
c.state = stateAuthenticated c.state = StateAuthenticated
} }
c.mu.Unlock() c.mu.Unlock()
@ -537,7 +553,7 @@ func (c *Client) sendStatus(who string, err error, url string, nm *NetworkMap) {
var p *Persist var p *Persist
var fin *empty.Message var fin *empty.Message
if state == stateAuthenticated { if state == StateAuthenticated {
fin = new(empty.Message) fin = new(empty.Message)
} }
if nm != nil && loggedIn && synced { if nm != nil && loggedIn && synced {
@ -554,7 +570,7 @@ func (c *Client) sendStatus(who string, err error, url string, nm *NetworkMap) {
Persist: p, Persist: p,
NetMap: nm, NetMap: nm,
Hostinfo: hi, Hostinfo: hi,
state: state, State: state,
} }
if err != nil { if err != nil {
new.Err = err.Error() new.Err = err.Error()
@ -623,3 +639,20 @@ func (c *Client) Shutdown() {
c.logf("Client.Shutdown done.") c.logf("Client.Shutdown done.")
} }
} }
// NodePublicKey returns the node public key currently in use. This is
// used exclusively in tests.
func (c *Client) TestOnlyNodePublicKey() wgcfg.Key {
priv := c.direct.GetPersist()
return priv.PrivateNodeKey.Public()
}
func (c *Client) TestOnlySetAuthKey(authkey string) {
c.direct.mu.Lock()
defer c.direct.mu.Unlock()
c.direct.authKey = authkey
}
func (c *Client) TestOnlyTimeNow() time.Time {
return c.timeNow()
}

File diff suppressed because it is too large Load Diff

@ -22,7 +22,7 @@ func fieldsOf(t reflect.Type) (fields []string) {
func TestStatusEqual(t *testing.T) { func TestStatusEqual(t *testing.T) {
// Verify that the Equal method stays in sync with reality // Verify that the Equal method stays in sync with reality
equalHandles := []string{"LoginFinished", "Err", "URL", "Persist", "NetMap", "Hostinfo", "state"} equalHandles := []string{"LoginFinished", "Err", "URL", "Persist", "NetMap", "Hostinfo", "State"}
if have := fieldsOf(reflect.TypeOf(Status{})); !reflect.DeepEqual(have, equalHandles) { if have := fieldsOf(reflect.TypeOf(Status{})); !reflect.DeepEqual(have, equalHandles) {
t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n", t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
have, equalHandles) have, equalHandles)
@ -48,13 +48,13 @@ func TestStatusEqual(t *testing.T) {
true, true,
}, },
{ {
&Status{state: stateNew}, &Status{State: StateNew},
&Status{state: stateNew}, &Status{State: StateNew},
true, true,
}, },
{ {
&Status{state: stateNew}, &Status{State: StateNew},
&Status{state: stateAuthenticated}, &Status{State: StateAuthenticated},
false, false,
}, },
{ {

Loading…
Cancel
Save