diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index bc89efdf3..b6cc66939 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -92,9 +92,10 @@ type Direct struct { authKey string tryingNewKey wgcfg.PrivateKey expiry *time.Time - hostinfo *tailcfg.Hostinfo // always non-nil - endpoints []string - localPort uint16 // or zero to mean auto + // hostinfo is mutated in-place while mu is held. + hostinfo *tailcfg.Hostinfo // always non-nil + endpoints []string + localPort uint16 // or zero to mean auto } type Options struct { @@ -262,6 +263,8 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags, tryingNewKey := c.tryingNewKey serverKey := c.serverKey authKey := c.authKey + hostinfo := c.hostinfo + backendLogID := hostinfo.BackendLogID expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) c.mu.Unlock() @@ -318,7 +321,7 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags, if tryingNewKey == (wgcfg.PrivateKey{}) { log.Fatalf("tryingNewKey is empty, give up") } - if c.hostinfo.BackendLogID == "" { + if backendLogID == "" { err = errors.New("hostinfo: BackendLogID missing") return regen, url, err } @@ -326,7 +329,7 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags, Version: 1, OldNodeKey: tailcfg.NodeKey(oldNodeKey), NodeKey: tailcfg.NodeKey(tryingNewKey.Public()), - Hostinfo: c.hostinfo, + Hostinfo: hostinfo, Followup: url, } c.logf("RegisterReq: onode=%v node=%v fup=%v", @@ -453,11 +456,12 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM serverURL := c.serverURL serverKey := c.serverKey hostinfo := c.hostinfo + backendLogID := hostinfo.BackendLogID localPort := c.localPort ep := append([]string(nil), c.endpoints...) c.mu.Unlock() - if hostinfo.BackendLogID == "" { + if backendLogID == "" { return errors.New("hostinfo: BackendLogID missing") } diff --git a/ipn/local.go b/ipn/local.go index 3c80ae188..44cb43089 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -57,14 +57,16 @@ type LocalBackend struct { lastFilterPrint time.Time // The mutex protects the following elements. - mu sync.Mutex - notify func(Notify) - c *controlclient.Client - stateKey StateKey - prefs *Prefs - state State - hiCache *tailcfg.Hostinfo - netMapCache *controlclient.NetworkMap + mu sync.Mutex + notify func(Notify) + c *controlclient.Client + stateKey StateKey + prefs *Prefs + state State + // hostinfo is mutated in-place while mu is held. + hostinfo *tailcfg.Hostinfo + // netMap is not mutated in-place once set. + netMap *controlclient.NetworkMap engineStatus EngineStatus endpoints []string blocked bool @@ -106,11 +108,6 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin } b.statusChanged = sync.NewCond(&b.statusLock) - if b.portpoll != nil { - go b.portpoll.Run(ctx) - go b.readPoller() - } - return b, nil } @@ -146,11 +143,11 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { // TODO: hostinfo, and its networkinfo // TODO: EngineStatus copy (and deprecate it?) - if b.netMapCache != nil { - for id, up := range b.netMapCache.UserProfiles { + if b.netMap != nil { + for id, up := range b.netMap.UserProfiles { sb.AddUser(id, up) } - for _, p := range b.netMapCache.Peers { + for _, p := range b.netMap.Peers { var lastSeen time.Time if p.LastSeen != nil { lastSeen = *p.LastSeen @@ -184,6 +181,122 @@ func (b *LocalBackend) SetDecompressor(fn func() (controlclient.Decompressor, er b.newDecompressor = fn } +// setClientStatus is the callback invoked by the control client whenever it posts a new status. +// Among other things, this is where we update the netmap, packet filters, DNS and DERP maps. +func (b *LocalBackend) setClientStatus(st controlclient.Status) { + if st.LoginFinished != nil { + // Auth completed, unblock the engine + b.blockEngineUpdates(false) + b.authReconfig() + b.send(Notify{LoginFinished: &empty.Message{}}) + } + if st.Persist != nil { + persist := *st.Persist // copy + + b.mu.Lock() + b.prefs.Persist = &persist + prefs := b.prefs.Clone() + stateKey := b.stateKey + b.mu.Unlock() + + if stateKey != "" { + if err := b.store.WriteState(stateKey, prefs.ToBytes()); err != nil { + b.logf("Failed to save new controlclient state: %v", err) + } + } + b.send(Notify{Prefs: prefs}) + } + if st.NetMap != nil { + // Netmap is unchanged only when the diff is empty. + changed := true + b.mu.Lock() + if b.netMap != nil { + diff := st.NetMap.ConciseDiffFrom(b.netMap) + if strings.TrimSpace(diff) == "" { + changed = false + b.logf("netmap diff: (none)") + } else { + b.logf("netmap diff:\n%v", diff) + } + } + disableDERP := b.prefs != nil && b.prefs.DisableDERP + b.netMap = st.NetMap + b.mu.Unlock() + + b.send(Notify{NetMap: st.NetMap}) + // There is nothing to update if the map hasn't changed. + if changed { + b.updateFilter(st.NetMap) + b.updateDNSMap(st.NetMap) + } + if disableDERP { + b.e.SetDERPMap(nil) + } else { + b.e.SetDERPMap(st.NetMap.DERPMap) + } + } + if st.URL != "" { + b.logf("Received auth URL: %.20v...", st.URL) + + b.mu.Lock() + interact := b.interact + b.authURL = st.URL + b.mu.Unlock() + + if interact > 0 { + b.popBrowserAuthNow() + } + } + if st.Err != "" { + // TODO(crawshaw): display in the UI. + b.logf("Received error: %v", st.Err) + return + } + if st.NetMap != nil { + b.mu.Lock() + if b.state == NeedsLogin { + b.prefs.WantRunning = true + } + prefs := b.prefs + b.mu.Unlock() + + b.SetPrefs(prefs) + } + b.stateMachine() +} + +// setWgengineStatus is the callback by the wireguard engine whenever it posts a new status. +// This updates the endpoints both in the backend and in the control client. +func (b *LocalBackend) setWgengineStatus(s *wgengine.Status, err error) { + if err != nil { + b.logf("wgengine status error: %#v", err) + return + } + if s == nil { + b.logf("[unexpected] non-error wgengine update with status=nil: %v", s) + return + } + + es := b.parseWgStatus(s) + + b.mu.Lock() + c := b.c + b.engineStatus = es + b.endpoints = append([]string{}, s.LocalAddrs...) + b.mu.Unlock() + + if c != nil { + c.UpdateEndpoints(0, s.LocalAddrs) + } + b.stateMachine() + + b.statusLock.Lock() + b.statusChanged.Broadcast() + b.statusLock.Unlock() + + b.send(Notify{Engine: &es}) +} + // Start applies the configuration specified in opts, and starts the // state machine. // @@ -205,9 +318,9 @@ func (b *LocalBackend) Start(opts Options) error { b.logf("Start") } - hi := controlclient.NewHostinfo() - hi.BackendLogID = b.backendLogID - hi.FrontendLogID = opts.FrontendLogID + hostinfo := controlclient.NewHostinfo() + hostinfo.BackendLogID = b.backendLogID + hostinfo.FrontendLogID = opts.FrontendLogID b.mu.Lock() @@ -222,11 +335,11 @@ func (b *LocalBackend) Start(opts Options) error { b.c.Shutdown() } - if b.hiCache != nil { - hi.Services = b.hiCache.Services // keep any previous session and netinfo - hi.NetInfo = b.hiCache.NetInfo + if b.hostinfo != nil { + hostinfo.Services = b.hostinfo.Services // keep any previous session and netinfo + hostinfo.NetInfo = b.hostinfo.NetInfo } - b.hiCache = hi + b.hostinfo = hostinfo b.state = NoState if err := b.loadStateLocked(opts.StateKey, opts.Prefs, opts.LegacyConfigPath); err != nil { @@ -235,11 +348,11 @@ func (b *LocalBackend) Start(opts Options) error { } b.serverURL = b.prefs.ControlURL - hi.RoutableIPs = append(hi.RoutableIPs, b.prefs.AdvertiseRoutes...) - hi.RequestTags = append(hi.RequestTags, b.prefs.AdvertiseTags...) + hostinfo.RoutableIPs = append(hostinfo.RoutableIPs, b.prefs.AdvertiseRoutes...) + hostinfo.RequestTags = append(hostinfo.RequestTags, b.prefs.AdvertiseTags...) b.notify = opts.Notify - b.netMapCache = nil + b.netMap = nil persist := b.prefs.Persist b.mu.Unlock() @@ -255,7 +368,7 @@ func (b *LocalBackend) Start(opts Options) error { Persist: *persist, ServerURL: b.serverURL, AuthKey: opts.AuthKey, - Hostinfo: hi, + Hostinfo: hostinfo, KeepAlive: true, NewDecompressor: b.newDecompressor, HTTPTestClient: opts.HTTPTestClient, @@ -264,6 +377,13 @@ func (b *LocalBackend) Start(opts Options) error { return err } + // At this point, we have finished using hostinfo without synchronization, + // so it is safe to start readPoller which concurrently writes to it. + if b.portpoll != nil { + go b.portpoll.Run(b.ctx) + go b.readPoller() + } + b.mu.Lock() b.c = cli endpoints := b.endpoints @@ -273,118 +393,8 @@ func (b *LocalBackend) Start(opts Options) error { cli.UpdateEndpoints(0, endpoints) } - cli.SetStatusFunc(func(newSt controlclient.Status) { - if newSt.LoginFinished != nil { - // Auth completed, unblock the engine - b.blockEngineUpdates(false) - b.authReconfig() - b.send(Notify{LoginFinished: &empty.Message{}}) - } - if newSt.Persist != nil { - persist := *newSt.Persist // copy - - b.mu.Lock() - b.prefs.Persist = &persist - prefs := b.prefs.Clone() - stateKey := b.stateKey - b.mu.Unlock() - - if stateKey != "" { - if err := b.store.WriteState(stateKey, prefs.ToBytes()); err != nil { - b.logf("Failed to save new controlclient state: %v", err) - } - } - b.send(Notify{Prefs: prefs}) - } - if newSt.NetMap != nil { - // Netmap is unchanged only when the diff is empty. - changed := true - b.mu.Lock() - if b.netMapCache != nil { - diff := newSt.NetMap.ConciseDiffFrom(b.netMapCache) - if strings.TrimSpace(diff) == "" { - changed = false - b.logf("netmap diff: (none)") - } else { - b.logf("netmap diff:\n%v", diff) - } - } - disableDERP := b.prefs != nil && b.prefs.DisableDERP - b.netMapCache = newSt.NetMap - b.mu.Unlock() - - b.send(Notify{NetMap: newSt.NetMap}) - // There is nothing to update if the map hasn't changed. - if changed { - b.updateFilter(newSt.NetMap) - b.updateDNSMap(newSt.NetMap) - } - if disableDERP { - b.e.SetDERPMap(nil) - } else { - b.e.SetDERPMap(newSt.NetMap.DERPMap) - } - } - if newSt.URL != "" { - b.logf("Received auth URL: %.20v...", newSt.URL) - - b.mu.Lock() - interact := b.interact - b.authURL = newSt.URL - b.mu.Unlock() - - if interact > 0 { - b.popBrowserAuthNow() - } - } - if newSt.Err != "" { - // TODO(crawshaw): display in the UI. - b.logf("Received error: %v", newSt.Err) - return - } - if newSt.NetMap != nil { - b.mu.Lock() - if b.state == NeedsLogin { - b.prefs.WantRunning = true - } - prefs := b.prefs - b.mu.Unlock() - - b.SetPrefs(prefs) - } - b.stateMachine() - }) - - b.e.SetStatusCallback(func(s *wgengine.Status, err error) { - if err != nil { - b.logf("wgengine status error: %#v", err) - return - } - if s == nil { - b.logf("weird: non-error wgengine update with status=nil: %v", s) - return - } - - es := b.parseWgStatus(s) - - b.mu.Lock() - c := b.c - b.engineStatus = es - b.endpoints = append([]string{}, s.LocalAddrs...) - b.mu.Unlock() - - if c != nil { - c.UpdateEndpoints(0, s.LocalAddrs) - } - b.stateMachine() - - b.statusLock.Lock() - b.statusChanged.Broadcast() - b.statusLock.Unlock() - - b.send(Notify{Engine: &es}) - }) - + cli.SetStatusFunc(b.setClientStatus) + b.e.SetStatusCallback(b.setWgengineStatus) b.e.SetNetInfoCallback(b.setNetInfo) b.mu.Lock() @@ -477,13 +487,11 @@ func (b *LocalBackend) readPoller() { } b.mu.Lock() - if b.hiCache == nil { - // TODO(bradfitz): it's a little weird that this port poller - // is started (by NewLocalBackend) before the Start call. - b.hiCache = new(tailcfg.Hostinfo) + if b.hostinfo == nil { + b.hostinfo = new(tailcfg.Hostinfo) } - b.hiCache.Services = sl - hi := b.hiCache + b.hostinfo.Services = sl + hi := b.hostinfo b.mu.Unlock() b.doSetHostinfoFilterServices(hi) @@ -617,13 +625,23 @@ func (b *LocalBackend) StartLoginInteractive() { // FakeExpireAfter implements Backend. func (b *LocalBackend) FakeExpireAfter(x time.Duration) { b.logf("FakeExpireAfter: %v", x) - if b.netMapCache != nil { - e := b.netMapCache.Expiry - if e.IsZero() || time.Until(e) > x { - b.netMapCache.Expiry = time.Now().Add(x) - } - b.send(Notify{NetMap: b.netMapCache}) + + b.mu.Lock() + defer b.mu.Unlock() + + if b.netMap == nil { + return + } + + // This function is called very rarely, + // so we prefer to fully copy the netmap over introducing in-place modification here. + mapCopy := *b.netMap + e := mapCopy.Expiry + if e.IsZero() || time.Until(e) > x { + mapCopy.Expiry = time.Now().Add(x) } + b.netMap = &mapCopy + b.send(Notify{NetMap: b.netMap}) } func (b *LocalBackend) parseWgStatus(s *wgengine.Status) (ret EngineStatus) { @@ -680,13 +698,13 @@ func (b *LocalBackend) SetPrefs(new *Prefs) { b.logf("Failed to save new controlclient state: %v", err) } } - oldHi := b.hiCache + oldHi := b.hostinfo newHi := oldHi.Clone() newHi.RoutableIPs = append([]wgcfg.CIDR(nil), b.prefs.AdvertiseRoutes...) if h := new.Hostname; h != "" { newHi.Hostname = h } - b.hiCache = newHi + b.hostinfo = newHi b.mu.Unlock() b.logf("SetPrefs: %v", new.Pretty()) @@ -695,15 +713,15 @@ func (b *LocalBackend) SetPrefs(new *Prefs) { b.doSetHostinfoFilterServices(newHi) } - b.updateFilter(b.netMapCache) + b.updateFilter(b.netMap) // TODO(dmytro): when Prefs gain an EnableTailscaleDNS toggle, updateDNSMap here. turnDERPOff := new.DisableDERP && !old.DisableDERP turnDERPOn := !new.DisableDERP && old.DisableDERP if turnDERPOff { b.e.SetDERPMap(nil) - } else if turnDERPOn && b.netMapCache != nil { - b.e.SetDERPMap(b.netMapCache.DERPMap) + } else if turnDERPOn && b.netMap != nil { + b.e.SetDERPMap(b.netMap.DERPMap) } if old.WantRunning != new.WantRunning { @@ -741,7 +759,7 @@ func (b *LocalBackend) doSetHostinfoFilterServices(hi *tailcfg.Hostinfo) { // NetMap returns the latest cached network map received from // controlclient, or nil if no network map was received yet. func (b *LocalBackend) NetMap() *controlclient.NetworkMap { - return b.netMapCache + return b.netMap } // blockEngineUpdate sets b.blocked to block, while holding b.mu. Its @@ -762,7 +780,7 @@ func (b *LocalBackend) authReconfig() { b.mu.Lock() blocked := b.blocked uc := b.prefs - nm := b.netMapCache + nm := b.netMap b.mu.Unlock() if blocked { @@ -939,7 +957,7 @@ func (b *LocalBackend) nextState() State { b.assertClientLocked() var ( c = b.c - netMap = b.netMapCache + netMap = b.netMap state = b.state wantRunning = b.prefs.WantRunning ) @@ -1037,13 +1055,13 @@ func (b *LocalBackend) Logout() { b.mu.Lock() b.assertClientLocked() c := b.c - b.netMapCache = nil + b.netMap = nil b.mu.Unlock() c.Logout() b.mu.Lock() - b.netMapCache = nil + b.netMap = nil b.mu.Unlock() b.stateMachine() @@ -1056,13 +1074,13 @@ func (b *LocalBackend) assertClientLocked() { } } -// setNetInfo sets b.hiCache.NetInfo to ni, and passes ni along to the +// setNetInfo sets b.hostinfo.NetInfo to ni, and passes ni along to the // controlclient, if one exists. func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) { b.mu.Lock() c := b.c - if b.hiCache != nil { - b.hiCache.NetInfo = ni.Clone() + if b.hostinfo != nil { + b.hostinfo.NetInfo = ni.Clone() } b.mu.Unlock() diff --git a/types/logger/logger.go b/types/logger/logger.go index de7d124b1..42b520460 100644 --- a/types/logger/logger.go +++ b/types/logger/logger.go @@ -127,18 +127,23 @@ func RateLimitedFn(logf Logf, f time.Duration, burst int, maxCache int) Logf { // since the last time this identical line was logged. func LogOnChange(logf Logf, maxInterval time.Duration, timeNow func() time.Time) Logf { var ( + mu sync.Mutex sLastLogged string tLastLogged = timeNow() ) return func(format string, args ...interface{}) { s := fmt.Sprintf(format, args...) + + mu.Lock() if s == sLastLogged && timeNow().Sub(tLastLogged) < maxInterval { + mu.Unlock() return } - sLastLogged = s tLastLogged = timeNow() + mu.Unlock() + logf(s) } diff --git a/types/logger/logger_test.go b/types/logger/logger_test.go index 6d4608734..e5d1f5087 100644 --- a/types/logger/logger_test.go +++ b/types/logger/logger_test.go @@ -9,6 +9,7 @@ import ( "bytes" "fmt" "log" + "sync" "testing" "time" ) @@ -117,3 +118,31 @@ func TestArgWriter(t *testing.T) { t.Errorf("got %q; want %q", got, want) } } + +func TestSynchronization(t *testing.T) { + timeNow := testTimer(1 * time.Second) + tests := []struct { + name string + logf Logf + }{ + {"RateLimitedFn", RateLimitedFn(t.Logf, 1*time.Minute, 2, 50)}, + {"LogOnChange", LogOnChange(t.Logf, 5*time.Second, timeNow)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + f := func() { + tt.logf("1 2 3 4 5") + wg.Done() + } + + go f() + go f() + + wg.Wait() + }) + } +}