control: use tstime instead of time (#8595)

Updates #8587
Signed-off-by: Claire Wang <claire@tailscale.com>
pull/8797/head
Claire Wang 10 months ago committed by GitHub
parent a8e32f1a4b
commit a17c45fd6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -110,7 +110,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/tailcfg from tailscale.com/cmd/tailscale/cli+ tailscale.com/tailcfg from tailscale.com/cmd/tailscale/cli+
tailscale.com/tka from tailscale.com/client/tailscale+ tailscale.com/tka from tailscale.com/client/tailscale+
W tailscale.com/tsconst from tailscale.com/net/interfaces W tailscale.com/tsconst from tailscale.com/net/interfaces
tailscale.com/tstime from tailscale.com/derp+ tailscale.com/tstime from tailscale.com/control/controlhttp+
💣 tailscale.com/tstime/mono from tailscale.com/tstime/rate 💣 tailscale.com/tstime/mono from tailscale.com/tstime/rate
tailscale.com/tstime/rate from tailscale.com/wgengine/filter+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter+
tailscale.com/types/dnstype from tailscale.com/tailcfg tailscale.com/types/dnstype from tailscale.com/tailcfg

@ -15,6 +15,7 @@ import (
"tailscale.com/logtail/backoff" "tailscale.com/logtail/backoff"
"tailscale.com/net/sockstats" "tailscale.com/net/sockstats"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/empty" "tailscale.com/types/empty"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -48,7 +49,7 @@ var _ Client = (*Auto)(nil)
// It's a concrete implementation of the Client interface. // It's a concrete implementation of the Client interface.
type Auto struct { type Auto struct {
direct *Direct // our interface to the server APIs direct *Direct // our interface to the server APIs
timeNow func() time.Time clock tstime.Clock
logf logger.Logf logf logger.Logf
expiry *time.Time expiry *time.Time
closed bool closed bool
@ -107,12 +108,12 @@ func NewNoStart(opts Options) (_ *Auto, err error) {
if opts.Logf == nil { if opts.Logf == nil {
opts.Logf = func(fmt string, args ...any) {} opts.Logf = func(fmt string, args ...any) {}
} }
if opts.TimeNow == nil { if opts.Clock == nil {
opts.TimeNow = time.Now opts.Clock = tstime.StdClock{}
} }
c := &Auto{ c := &Auto{
direct: direct, direct: direct,
timeNow: opts.TimeNow, clock: opts.Clock,
logf: opts.Logf, logf: opts.Logf,
newMapCh: make(chan struct{}, 1), newMapCh: make(chan struct{}, 1),
quit: make(chan struct{}), quit: make(chan struct{}),
@ -208,7 +209,7 @@ func (c *Auto) sendNewMapRequest() {
c.liteMapUpdateCancel = cancel c.liteMapUpdateCancel = cancel
go func() { go func() {
defer cancel() defer cancel()
t0 := time.Now() t0 := c.clock.Now()
err := c.direct.SendLiteMapUpdate(ctx) err := c.direct.SendLiteMapUpdate(ctx)
d := time.Since(t0).Round(time.Millisecond) d := time.Since(t0).Round(time.Millisecond)
@ -704,14 +705,14 @@ func (c *Auto) Logout(ctx context.Context) error {
c.mu.Unlock() c.mu.Unlock()
c.cancelAuth() c.cancelAuth()
timer := time.NewTimer(10 * time.Second) timer, timerChannel := c.clock.NewTimer(10 * time.Second)
defer timer.Stop() defer timer.Stop()
select { select {
case err := <-errc: case err := <-errc:
return err return err
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-timer.C: case <-timerChannel:
return context.DeadlineExceeded return context.DeadlineExceeded
} }
} }
@ -772,7 +773,7 @@ func (c *Auto) TestOnlySetAuthKey(authkey string) {
} }
func (c *Auto) TestOnlyTimeNow() time.Time { func (c *Auto) TestOnlyTimeNow() time.Time {
return c.timeNow() return c.clock.Now()
} }
// SetDNS sends the SetDNSRequest request to the control plane server, // SetDNS sends the SetDNSRequest request to the control plane server,

@ -45,6 +45,7 @@ import (
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tka" "tailscale.com/tka"
"tailscale.com/tstime"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
@ -63,7 +64,7 @@ type Direct struct {
dialer *tsdial.Dialer dialer *tsdial.Dialer
dnsCache *dnscache.Resolver dnsCache *dnscache.Resolver
serverURL string // URL of the tailcontrol server serverURL string // URL of the tailcontrol server
timeNow func() time.Time clock tstime.Clock
lastPrintMap time.Time lastPrintMap time.Time
newDecompressor func() (Decompressor, error) newDecompressor func() (Decompressor, error)
keepAlive bool keepAlive bool
@ -105,8 +106,8 @@ type Options struct {
GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use
ServerURL string // URL of the tailcontrol server ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration AuthKey string // optional node auth key for auto registration
TimeNow func() time.Time // time.Now implementation used by Client Clock tstime.Clock
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
DiscoPublicKey key.DiscoPublic DiscoPublicKey key.DiscoPublic
NewDecompressor func() (Decompressor, error) NewDecompressor func() (Decompressor, error)
KeepAlive bool KeepAlive bool
@ -191,8 +192,8 @@ func NewDirect(opts Options) (*Direct, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if opts.TimeNow == nil { if opts.Clock == nil {
opts.TimeNow = time.Now opts.Clock = tstime.StdClock{}
} }
if opts.Logf == nil { if opts.Logf == nil {
// TODO(apenwarr): remove this default and fail instead. // TODO(apenwarr): remove this default and fail instead.
@ -235,7 +236,7 @@ func NewDirect(opts Options) (*Direct, error) {
httpc: httpc, httpc: httpc,
getMachinePrivKey: opts.GetMachinePrivateKey, getMachinePrivKey: opts.GetMachinePrivateKey,
serverURL: opts.ServerURL, serverURL: opts.ServerURL,
timeNow: opts.TimeNow, clock: opts.Clock,
logf: opts.Logf, logf: opts.Logf,
newDecompressor: opts.NewDecompressor, newDecompressor: opts.NewDecompressor,
keepAlive: opts.KeepAlive, keepAlive: opts.KeepAlive,
@ -432,7 +433,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
authKey, isWrapped, wrappedSig, wrappedKey := decodeWrappedAuthkey(c.authKey, c.logf) authKey, isWrapped, wrappedSig, wrappedKey := decodeWrappedAuthkey(c.authKey, c.logf)
hi := c.hostInfoLocked() hi := c.hostInfoLocked()
backendLogID := hi.BackendLogID backendLogID := hi.BackendLogID
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.clock.Now())
c.mu.Unlock() c.mu.Unlock()
machinePrivKey, err := c.getMachinePrivKey() machinePrivKey, err := c.getMachinePrivKey()
@ -537,7 +538,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
err = errors.New("hostinfo: BackendLogID missing") err = errors.New("hostinfo: BackendLogID missing")
return regen, opt.URL, nil, err return regen, opt.URL, nil, err
} }
now := time.Now().Round(time.Second) now := c.clock.Now().Round(time.Second)
request := tailcfg.RegisterRequest{ request := tailcfg.RegisterRequest{
Version: 1, Version: 1,
OldNodeKey: oldNodeKey, OldNodeKey: oldNodeKey,
@ -911,7 +912,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
defer cancel() defer cancel()
machinePubKey := machinePrivKey.Public() machinePubKey := machinePrivKey.Public()
t0 := time.Now() t0 := c.clock.Now()
// Url and httpc are protocol specific. // Url and httpc are protocol specific.
var url string var url string
@ -954,7 +955,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
return nil return nil
} }
timeout := time.NewTimer(pollTimeout) timeout, timeoutChannel := c.clock.NewTimer(pollTimeout)
timeoutReset := make(chan struct{}) timeoutReset := make(chan struct{})
pollDone := make(chan struct{}) pollDone := make(chan struct{})
defer close(pollDone) defer close(pollDone)
@ -964,14 +965,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
case <-pollDone: case <-pollDone:
vlogf("netmap: ending timeout goroutine") vlogf("netmap: ending timeout goroutine")
return return
case <-timeout.C: case <-timeoutChannel:
c.logf("map response long-poll timed out!") c.logf("map response long-poll timed out!")
cancel() cancel()
return return
case <-timeoutReset: case <-timeoutReset:
if !timeout.Stop() { if !timeout.Stop() {
select { select {
case <-timeout.C: case <-timeoutChannel:
case <-pollDone: case <-pollDone:
vlogf("netmap: ending timeout goroutine") vlogf("netmap: ending timeout goroutine")
return return
@ -1096,7 +1097,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
go dumpGoroutinesToURL(c.httpc, resp.Debug.GoroutineDumpURL) go dumpGoroutinesToURL(c.httpc, resp.Debug.GoroutineDumpURL)
} }
if sleep := time.Duration(resp.Debug.SleepSeconds * float64(time.Second)); sleep > 0 { if sleep := time.Duration(resp.Debug.SleepSeconds * float64(time.Second)); sleep > 0 {
if err := sleepAsRequested(ctx, c.logf, timeoutReset, sleep); err != nil { if err := sleepAsRequested(ctx, c.logf, timeoutReset, sleep, c.clock); err != nil {
return err return err
} }
} }
@ -1126,7 +1127,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
// This is handy for debugging, and our logs processing // This is handy for debugging, and our logs processing
// pipeline depends on it. (TODO: Remove this dependency.) // pipeline depends on it. (TODO: Remove this dependency.)
// Code elsewhere prints netmap diffs every time they are received. // Code elsewhere prints netmap diffs every time they are received.
now := c.timeNow() now := c.clock.Now()
if now.Sub(c.lastPrintMap) >= 5*time.Minute { if now.Sub(c.lastPrintMap) >= 5*time.Minute {
c.lastPrintMap = now c.lastPrintMap = now
c.logf("[v1] new network map[%d]:\n%s", i, nm.VeryConcise()) c.logf("[v1] new network map[%d]:\n%s", i, nm.VeryConcise())
@ -1304,7 +1305,7 @@ func initDevKnob() devKnobs {
} }
} }
var clockNow = time.Now var clock tstime.Clock = tstime.StdClock{}
// opt.Bool configs from control. // opt.Bool configs from control.
var ( var (
@ -1408,9 +1409,9 @@ func answerHeadPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest) {
if pr.Log { if pr.Log {
logf("answerHeadPing: sending HEAD ping to %v ...", pr.URL) logf("answerHeadPing: sending HEAD ping to %v ...", pr.URL)
} }
t0 := time.Now() t0 := clock.Now()
_, err = c.Do(req) _, err = c.Do(req)
d := time.Since(t0).Round(time.Millisecond) d := clock.Since(t0).Round(time.Millisecond)
if err != nil { if err != nil {
logf("answerHeadPing error: %v to %v (after %v)", err, pr.URL, d) logf("answerHeadPing error: %v to %v (after %v)", err, pr.URL, d)
} else if pr.Log { } else if pr.Log {
@ -1456,7 +1457,7 @@ func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr
if pr.Log { if pr.Log {
logf("answerC2NPing: sending POST ping to %v ...", pr.URL) logf("answerC2NPing: sending POST ping to %v ...", pr.URL)
} }
t0 := time.Now() t0 := clock.Now()
_, err = c.Do(req) _, err = c.Do(req)
d := time.Since(t0).Round(time.Millisecond) d := time.Since(t0).Round(time.Millisecond)
if err != nil { if err != nil {
@ -1466,7 +1467,7 @@ func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr
} }
} }
func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- struct{}, d time.Duration) error { func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- struct{}, d time.Duration, clock tstime.Clock) error {
const maxSleep = 5 * time.Minute const maxSleep = 5 * time.Minute
if d > maxSleep { if d > maxSleep {
logf("sleeping for %v, capped from server-requested %v ...", maxSleep, d) logf("sleeping for %v, capped from server-requested %v ...", maxSleep, d)
@ -1475,20 +1476,20 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<-
logf("sleeping for server-requested %v ...", d) logf("sleeping for server-requested %v ...", d)
} }
ticker := time.NewTicker(pollTimeout / 2) ticker, tickerChannel := clock.NewTicker(pollTimeout / 2)
defer ticker.Stop() defer ticker.Stop()
timer := time.NewTimer(d) timer, timerChannel := clock.NewTimer(d)
defer timer.Stop() defer timer.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-timer.C: case <-timerChannel:
return nil return nil
case <-ticker.C: case <-tickerChannel:
select { select {
case timeoutReset <- struct{}{}: case timeoutReset <- struct{}{}:
case <-timer.C: case <-timerChannel:
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -1665,7 +1666,7 @@ func doPingerPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest, pin
logf("invalid ping request: missing url, ip or pinger") logf("invalid ping request: missing url, ip or pinger")
return return
} }
start := time.Now() start := clock.Now()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
@ -1703,7 +1704,7 @@ func postPingResult(start time.Time, logf logger.Logf, c *http.Client, pr *tailc
if pr.Log { if pr.Log {
logf("postPingResult: sending ping results to %v ...", pr.URL) logf("postPingResult: sending ping results to %v ...", pr.URL)
} }
t0 := time.Now() t0 := clock.Now()
_, err = c.Do(req) _, err = c.Do(req)
d := time.Since(t0).Round(time.Millisecond) d := time.Since(t0).Round(time.Millisecond)
if err != nil { if err != nil {

@ -307,7 +307,7 @@ func undeltaPeers(mapRes *tailcfg.MapResponse, prev []*tailcfg.Node) {
for _, n := range newFull { for _, n := range newFull {
peerByID[n.ID] = n peerByID[n.ID] = n
} }
now := clockNow() now := clock.Now()
for nodeID, seen := range mapRes.PeerSeenChange { for nodeID, seen := range mapRes.PeerSeenChange {
if n, ok := peerByID[nodeID]; ok { if n, ok := peerByID[nodeID]; ok {
if seen { if seen {

@ -14,6 +14,7 @@ import (
"go4.org/mem" "go4.org/mem"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstest" "tailscale.com/tstest"
"tailscale.com/tstime"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/types/opt" "tailscale.com/types/opt"
@ -23,9 +24,6 @@ import (
func TestUndeltaPeers(t *testing.T) { func TestUndeltaPeers(t *testing.T) {
var curTime time.Time var curTime time.Time
tstest.Replace(t, &clockNow, func() time.Time {
return curTime
})
online := func(v bool) func(*tailcfg.Node) { online := func(v bool) func(*tailcfg.Node) {
return func(n *tailcfg.Node) { return func(n *tailcfg.Node) {
@ -298,6 +296,7 @@ func TestUndeltaPeers(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if !tt.curTime.IsZero() { if !tt.curTime.IsZero() {
curTime = tt.curTime curTime = tt.curTime
tstest.Replace(t, &clock, tstime.Clock(tstest.NewClock(tstest.ClockOpts{Start: curTime})))
} }
undeltaPeers(tt.mapRes, tt.prev) undeltaPeers(tt.mapRes, tt.prev)
if !reflect.DeepEqual(tt.mapRes.Peers, tt.want) { if !reflect.DeepEqual(tt.mapRes.Peers, tt.want) {

@ -23,6 +23,7 @@ import (
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/mak" "tailscale.com/util/mak"
@ -450,6 +451,7 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
DialPlan: dialPlan, DialPlan: dialPlan,
Logf: nc.logf, Logf: nc.logf,
NetMon: nc.netMon, NetMon: nc.netMon,
Clock: tstime.StdClock{},
}).Dial(ctx) }).Dial(ctx)
if err != nil { if err != nil {
return nil, err return nil, err

@ -127,7 +127,7 @@ func findIdentity(subject string, st certstore.Store) (certstore.Identity, []*x5
return nil, nil, err return nil, nil, err
} }
selected, chain := selectIdentityFromSlice(subject, ids, time.Now()) selected, chain := selectIdentityFromSlice(subject, ids, clock.Now())
for _, id := range ids { for _, id := range ids {
if id != selected { if id != selected {

@ -45,6 +45,7 @@ import (
"tailscale.com/net/tlsdial" "tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy" "tailscale.com/net/tshttpproxy"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/util/multierr" "tailscale.com/util/multierr"
) )
@ -147,13 +148,16 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
// before we do anything. // before we do anything.
if c.DialStartDelaySec > 0 { if c.DialStartDelaySec > 0 {
a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP) a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second))) if a.Clock == nil {
a.Clock = tstime.StdClock{}
}
tmr, tmrChannel := a.Clock.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
defer tmr.Stop() defer tmr.Stop()
select { select {
case <-ctx.Done(): case <-ctx.Done():
err = ctx.Err() err = ctx.Err()
return return
case <-tmr.C: case <-tmrChannel:
} }
} }
@ -319,7 +323,10 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, er
// In case outbound port 80 blocked or MITM'ed poorly, start a backup timer // In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
// to dial port 443 if port 80 doesn't either succeed or fail quickly. // to dial port 443 if port 80 doesn't either succeed or fail quickly.
try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) }) if a.Clock == nil {
a.Clock = tstime.StdClock{}
}
try443Timer := a.Clock.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) })
defer try443Timer.Stop() defer try443Timer.Stop()
var err80, err443 error var err80, err443 error

@ -11,6 +11,7 @@ import (
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -89,6 +90,10 @@ type Dialer struct {
drainFinished chan struct{} drainFinished chan struct{}
omitCertErrorLogging bool omitCertErrorLogging bool
testFallbackDelay time.Duration testFallbackDelay time.Duration
// tstime.Clock is used instead of time package for methods such as time.Now.
// If not specified, will default to tstime.StdClock{}.
Clock tstime.Clock
} }
func strDef(v1, v2 string) string { func strDef(v1, v2 string) string {

@ -25,6 +25,7 @@ import (
"tailscale.com/net/socks5" "tailscale.com/net/socks5"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -204,6 +205,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
Logf: t.Logf, Logf: t.Logf,
omitCertErrorLogging: true, omitCertErrorLogging: true,
testFallbackDelay: 50 * time.Millisecond, testFallbackDelay: 50 * time.Millisecond,
Clock: &tstest.Clock{},
} }
if proxy != nil { if proxy != nil {
@ -660,6 +662,7 @@ func TestDialPlan(t *testing.T) {
drainFinished: drained, drainFinished: drained,
omitCertErrorLogging: true, omitCertErrorLogging: true,
testFallbackDelay: 50 * time.Millisecond, testFallbackDelay: 50 * time.Millisecond,
Clock: &tstest.Clock{},
} }
conn, err := a.dial(ctx) conn, err := a.dial(ctx)

Loading…
Cancel
Save