Fix concurrency issues in controlclient, ipn, types/logger (#456)

Signed-Off-By: Dmytro Shynkevych <dmytro@tailscale.com>
pull/475/head
Dmytro Shynkevych 5 years ago committed by GitHub
parent c8cf3169ba
commit c12d87c54b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -92,9 +92,10 @@ type Direct struct {
authKey string authKey string
tryingNewKey wgcfg.PrivateKey tryingNewKey wgcfg.PrivateKey
expiry *time.Time expiry *time.Time
hostinfo *tailcfg.Hostinfo // always non-nil // hostinfo is mutated in-place while mu is held.
endpoints []string hostinfo *tailcfg.Hostinfo // always non-nil
localPort uint16 // or zero to mean auto endpoints []string
localPort uint16 // or zero to mean auto
} }
type Options struct { type Options struct {
@ -262,6 +263,8 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
tryingNewKey := c.tryingNewKey tryingNewKey := c.tryingNewKey
serverKey := c.serverKey serverKey := c.serverKey
authKey := c.authKey authKey := c.authKey
hostinfo := c.hostinfo
backendLogID := hostinfo.BackendLogID
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
c.mu.Unlock() c.mu.Unlock()
@ -318,7 +321,7 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
if tryingNewKey == (wgcfg.PrivateKey{}) { if tryingNewKey == (wgcfg.PrivateKey{}) {
log.Fatalf("tryingNewKey is empty, give up") log.Fatalf("tryingNewKey is empty, give up")
} }
if c.hostinfo.BackendLogID == "" { if backendLogID == "" {
err = errors.New("hostinfo: BackendLogID missing") err = errors.New("hostinfo: BackendLogID missing")
return regen, url, err return regen, url, err
} }
@ -326,7 +329,7 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
Version: 1, Version: 1,
OldNodeKey: tailcfg.NodeKey(oldNodeKey), OldNodeKey: tailcfg.NodeKey(oldNodeKey),
NodeKey: tailcfg.NodeKey(tryingNewKey.Public()), NodeKey: tailcfg.NodeKey(tryingNewKey.Public()),
Hostinfo: c.hostinfo, Hostinfo: hostinfo,
Followup: url, Followup: url,
} }
c.logf("RegisterReq: onode=%v node=%v fup=%v", c.logf("RegisterReq: onode=%v node=%v fup=%v",
@ -453,11 +456,12 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
serverURL := c.serverURL serverURL := c.serverURL
serverKey := c.serverKey serverKey := c.serverKey
hostinfo := c.hostinfo hostinfo := c.hostinfo
backendLogID := hostinfo.BackendLogID
localPort := c.localPort localPort := c.localPort
ep := append([]string(nil), c.endpoints...) ep := append([]string(nil), c.endpoints...)
c.mu.Unlock() c.mu.Unlock()
if hostinfo.BackendLogID == "" { if backendLogID == "" {
return errors.New("hostinfo: BackendLogID missing") return errors.New("hostinfo: BackendLogID missing")
} }

@ -57,14 +57,16 @@ type LocalBackend struct {
lastFilterPrint time.Time lastFilterPrint time.Time
// The mutex protects the following elements. // The mutex protects the following elements.
mu sync.Mutex mu sync.Mutex
notify func(Notify) notify func(Notify)
c *controlclient.Client c *controlclient.Client
stateKey StateKey stateKey StateKey
prefs *Prefs prefs *Prefs
state State state State
hiCache *tailcfg.Hostinfo // hostinfo is mutated in-place while mu is held.
netMapCache *controlclient.NetworkMap hostinfo *tailcfg.Hostinfo
// netMap is not mutated in-place once set.
netMap *controlclient.NetworkMap
engineStatus EngineStatus engineStatus EngineStatus
endpoints []string endpoints []string
blocked bool blocked bool
@ -106,11 +108,6 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin
} }
b.statusChanged = sync.NewCond(&b.statusLock) b.statusChanged = sync.NewCond(&b.statusLock)
if b.portpoll != nil {
go b.portpoll.Run(ctx)
go b.readPoller()
}
return b, nil return b, nil
} }
@ -146,11 +143,11 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) {
// TODO: hostinfo, and its networkinfo // TODO: hostinfo, and its networkinfo
// TODO: EngineStatus copy (and deprecate it?) // TODO: EngineStatus copy (and deprecate it?)
if b.netMapCache != nil { if b.netMap != nil {
for id, up := range b.netMapCache.UserProfiles { for id, up := range b.netMap.UserProfiles {
sb.AddUser(id, up) sb.AddUser(id, up)
} }
for _, p := range b.netMapCache.Peers { for _, p := range b.netMap.Peers {
var lastSeen time.Time var lastSeen time.Time
if p.LastSeen != nil { if p.LastSeen != nil {
lastSeen = *p.LastSeen lastSeen = *p.LastSeen
@ -184,6 +181,122 @@ func (b *LocalBackend) SetDecompressor(fn func() (controlclient.Decompressor, er
b.newDecompressor = fn b.newDecompressor = fn
} }
// setClientStatus is the callback invoked by the control client whenever it posts a new status.
// Among other things, this is where we update the netmap, packet filters, DNS and DERP maps.
func (b *LocalBackend) setClientStatus(st controlclient.Status) {
if st.LoginFinished != nil {
// Auth completed, unblock the engine
b.blockEngineUpdates(false)
b.authReconfig()
b.send(Notify{LoginFinished: &empty.Message{}})
}
if st.Persist != nil {
persist := *st.Persist // copy
b.mu.Lock()
b.prefs.Persist = &persist
prefs := b.prefs.Clone()
stateKey := b.stateKey
b.mu.Unlock()
if stateKey != "" {
if err := b.store.WriteState(stateKey, prefs.ToBytes()); err != nil {
b.logf("Failed to save new controlclient state: %v", err)
}
}
b.send(Notify{Prefs: prefs})
}
if st.NetMap != nil {
// Netmap is unchanged only when the diff is empty.
changed := true
b.mu.Lock()
if b.netMap != nil {
diff := st.NetMap.ConciseDiffFrom(b.netMap)
if strings.TrimSpace(diff) == "" {
changed = false
b.logf("netmap diff: (none)")
} else {
b.logf("netmap diff:\n%v", diff)
}
}
disableDERP := b.prefs != nil && b.prefs.DisableDERP
b.netMap = st.NetMap
b.mu.Unlock()
b.send(Notify{NetMap: st.NetMap})
// There is nothing to update if the map hasn't changed.
if changed {
b.updateFilter(st.NetMap)
b.updateDNSMap(st.NetMap)
}
if disableDERP {
b.e.SetDERPMap(nil)
} else {
b.e.SetDERPMap(st.NetMap.DERPMap)
}
}
if st.URL != "" {
b.logf("Received auth URL: %.20v...", st.URL)
b.mu.Lock()
interact := b.interact
b.authURL = st.URL
b.mu.Unlock()
if interact > 0 {
b.popBrowserAuthNow()
}
}
if st.Err != "" {
// TODO(crawshaw): display in the UI.
b.logf("Received error: %v", st.Err)
return
}
if st.NetMap != nil {
b.mu.Lock()
if b.state == NeedsLogin {
b.prefs.WantRunning = true
}
prefs := b.prefs
b.mu.Unlock()
b.SetPrefs(prefs)
}
b.stateMachine()
}
// setWgengineStatus is the callback by the wireguard engine whenever it posts a new status.
// This updates the endpoints both in the backend and in the control client.
func (b *LocalBackend) setWgengineStatus(s *wgengine.Status, err error) {
if err != nil {
b.logf("wgengine status error: %#v", err)
return
}
if s == nil {
b.logf("[unexpected] non-error wgengine update with status=nil: %v", s)
return
}
es := b.parseWgStatus(s)
b.mu.Lock()
c := b.c
b.engineStatus = es
b.endpoints = append([]string{}, s.LocalAddrs...)
b.mu.Unlock()
if c != nil {
c.UpdateEndpoints(0, s.LocalAddrs)
}
b.stateMachine()
b.statusLock.Lock()
b.statusChanged.Broadcast()
b.statusLock.Unlock()
b.send(Notify{Engine: &es})
}
// Start applies the configuration specified in opts, and starts the // Start applies the configuration specified in opts, and starts the
// state machine. // state machine.
// //
@ -205,9 +318,9 @@ func (b *LocalBackend) Start(opts Options) error {
b.logf("Start") b.logf("Start")
} }
hi := controlclient.NewHostinfo() hostinfo := controlclient.NewHostinfo()
hi.BackendLogID = b.backendLogID hostinfo.BackendLogID = b.backendLogID
hi.FrontendLogID = opts.FrontendLogID hostinfo.FrontendLogID = opts.FrontendLogID
b.mu.Lock() b.mu.Lock()
@ -222,11 +335,11 @@ func (b *LocalBackend) Start(opts Options) error {
b.c.Shutdown() b.c.Shutdown()
} }
if b.hiCache != nil { if b.hostinfo != nil {
hi.Services = b.hiCache.Services // keep any previous session and netinfo hostinfo.Services = b.hostinfo.Services // keep any previous session and netinfo
hi.NetInfo = b.hiCache.NetInfo hostinfo.NetInfo = b.hostinfo.NetInfo
} }
b.hiCache = hi b.hostinfo = hostinfo
b.state = NoState b.state = NoState
if err := b.loadStateLocked(opts.StateKey, opts.Prefs, opts.LegacyConfigPath); err != nil { if err := b.loadStateLocked(opts.StateKey, opts.Prefs, opts.LegacyConfigPath); err != nil {
@ -235,11 +348,11 @@ func (b *LocalBackend) Start(opts Options) error {
} }
b.serverURL = b.prefs.ControlURL b.serverURL = b.prefs.ControlURL
hi.RoutableIPs = append(hi.RoutableIPs, b.prefs.AdvertiseRoutes...) hostinfo.RoutableIPs = append(hostinfo.RoutableIPs, b.prefs.AdvertiseRoutes...)
hi.RequestTags = append(hi.RequestTags, b.prefs.AdvertiseTags...) hostinfo.RequestTags = append(hostinfo.RequestTags, b.prefs.AdvertiseTags...)
b.notify = opts.Notify b.notify = opts.Notify
b.netMapCache = nil b.netMap = nil
persist := b.prefs.Persist persist := b.prefs.Persist
b.mu.Unlock() b.mu.Unlock()
@ -255,7 +368,7 @@ func (b *LocalBackend) Start(opts Options) error {
Persist: *persist, Persist: *persist,
ServerURL: b.serverURL, ServerURL: b.serverURL,
AuthKey: opts.AuthKey, AuthKey: opts.AuthKey,
Hostinfo: hi, Hostinfo: hostinfo,
KeepAlive: true, KeepAlive: true,
NewDecompressor: b.newDecompressor, NewDecompressor: b.newDecompressor,
HTTPTestClient: opts.HTTPTestClient, HTTPTestClient: opts.HTTPTestClient,
@ -264,6 +377,13 @@ func (b *LocalBackend) Start(opts Options) error {
return err return err
} }
// At this point, we have finished using hostinfo without synchronization,
// so it is safe to start readPoller which concurrently writes to it.
if b.portpoll != nil {
go b.portpoll.Run(b.ctx)
go b.readPoller()
}
b.mu.Lock() b.mu.Lock()
b.c = cli b.c = cli
endpoints := b.endpoints endpoints := b.endpoints
@ -273,118 +393,8 @@ func (b *LocalBackend) Start(opts Options) error {
cli.UpdateEndpoints(0, endpoints) cli.UpdateEndpoints(0, endpoints)
} }
cli.SetStatusFunc(func(newSt controlclient.Status) { cli.SetStatusFunc(b.setClientStatus)
if newSt.LoginFinished != nil { b.e.SetStatusCallback(b.setWgengineStatus)
// Auth completed, unblock the engine
b.blockEngineUpdates(false)
b.authReconfig()
b.send(Notify{LoginFinished: &empty.Message{}})
}
if newSt.Persist != nil {
persist := *newSt.Persist // copy
b.mu.Lock()
b.prefs.Persist = &persist
prefs := b.prefs.Clone()
stateKey := b.stateKey
b.mu.Unlock()
if stateKey != "" {
if err := b.store.WriteState(stateKey, prefs.ToBytes()); err != nil {
b.logf("Failed to save new controlclient state: %v", err)
}
}
b.send(Notify{Prefs: prefs})
}
if newSt.NetMap != nil {
// Netmap is unchanged only when the diff is empty.
changed := true
b.mu.Lock()
if b.netMapCache != nil {
diff := newSt.NetMap.ConciseDiffFrom(b.netMapCache)
if strings.TrimSpace(diff) == "" {
changed = false
b.logf("netmap diff: (none)")
} else {
b.logf("netmap diff:\n%v", diff)
}
}
disableDERP := b.prefs != nil && b.prefs.DisableDERP
b.netMapCache = newSt.NetMap
b.mu.Unlock()
b.send(Notify{NetMap: newSt.NetMap})
// There is nothing to update if the map hasn't changed.
if changed {
b.updateFilter(newSt.NetMap)
b.updateDNSMap(newSt.NetMap)
}
if disableDERP {
b.e.SetDERPMap(nil)
} else {
b.e.SetDERPMap(newSt.NetMap.DERPMap)
}
}
if newSt.URL != "" {
b.logf("Received auth URL: %.20v...", newSt.URL)
b.mu.Lock()
interact := b.interact
b.authURL = newSt.URL
b.mu.Unlock()
if interact > 0 {
b.popBrowserAuthNow()
}
}
if newSt.Err != "" {
// TODO(crawshaw): display in the UI.
b.logf("Received error: %v", newSt.Err)
return
}
if newSt.NetMap != nil {
b.mu.Lock()
if b.state == NeedsLogin {
b.prefs.WantRunning = true
}
prefs := b.prefs
b.mu.Unlock()
b.SetPrefs(prefs)
}
b.stateMachine()
})
b.e.SetStatusCallback(func(s *wgengine.Status, err error) {
if err != nil {
b.logf("wgengine status error: %#v", err)
return
}
if s == nil {
b.logf("weird: non-error wgengine update with status=nil: %v", s)
return
}
es := b.parseWgStatus(s)
b.mu.Lock()
c := b.c
b.engineStatus = es
b.endpoints = append([]string{}, s.LocalAddrs...)
b.mu.Unlock()
if c != nil {
c.UpdateEndpoints(0, s.LocalAddrs)
}
b.stateMachine()
b.statusLock.Lock()
b.statusChanged.Broadcast()
b.statusLock.Unlock()
b.send(Notify{Engine: &es})
})
b.e.SetNetInfoCallback(b.setNetInfo) b.e.SetNetInfoCallback(b.setNetInfo)
b.mu.Lock() b.mu.Lock()
@ -477,13 +487,11 @@ func (b *LocalBackend) readPoller() {
} }
b.mu.Lock() b.mu.Lock()
if b.hiCache == nil { if b.hostinfo == nil {
// TODO(bradfitz): it's a little weird that this port poller b.hostinfo = new(tailcfg.Hostinfo)
// is started (by NewLocalBackend) before the Start call.
b.hiCache = new(tailcfg.Hostinfo)
} }
b.hiCache.Services = sl b.hostinfo.Services = sl
hi := b.hiCache hi := b.hostinfo
b.mu.Unlock() b.mu.Unlock()
b.doSetHostinfoFilterServices(hi) b.doSetHostinfoFilterServices(hi)
@ -617,13 +625,23 @@ func (b *LocalBackend) StartLoginInteractive() {
// FakeExpireAfter implements Backend. // FakeExpireAfter implements Backend.
func (b *LocalBackend) FakeExpireAfter(x time.Duration) { func (b *LocalBackend) FakeExpireAfter(x time.Duration) {
b.logf("FakeExpireAfter: %v", x) b.logf("FakeExpireAfter: %v", x)
if b.netMapCache != nil {
e := b.netMapCache.Expiry b.mu.Lock()
if e.IsZero() || time.Until(e) > x { defer b.mu.Unlock()
b.netMapCache.Expiry = time.Now().Add(x)
} if b.netMap == nil {
b.send(Notify{NetMap: b.netMapCache}) return
}
// This function is called very rarely,
// so we prefer to fully copy the netmap over introducing in-place modification here.
mapCopy := *b.netMap
e := mapCopy.Expiry
if e.IsZero() || time.Until(e) > x {
mapCopy.Expiry = time.Now().Add(x)
} }
b.netMap = &mapCopy
b.send(Notify{NetMap: b.netMap})
} }
func (b *LocalBackend) parseWgStatus(s *wgengine.Status) (ret EngineStatus) { func (b *LocalBackend) parseWgStatus(s *wgengine.Status) (ret EngineStatus) {
@ -680,13 +698,13 @@ func (b *LocalBackend) SetPrefs(new *Prefs) {
b.logf("Failed to save new controlclient state: %v", err) b.logf("Failed to save new controlclient state: %v", err)
} }
} }
oldHi := b.hiCache oldHi := b.hostinfo
newHi := oldHi.Clone() newHi := oldHi.Clone()
newHi.RoutableIPs = append([]wgcfg.CIDR(nil), b.prefs.AdvertiseRoutes...) newHi.RoutableIPs = append([]wgcfg.CIDR(nil), b.prefs.AdvertiseRoutes...)
if h := new.Hostname; h != "" { if h := new.Hostname; h != "" {
newHi.Hostname = h newHi.Hostname = h
} }
b.hiCache = newHi b.hostinfo = newHi
b.mu.Unlock() b.mu.Unlock()
b.logf("SetPrefs: %v", new.Pretty()) b.logf("SetPrefs: %v", new.Pretty())
@ -695,15 +713,15 @@ func (b *LocalBackend) SetPrefs(new *Prefs) {
b.doSetHostinfoFilterServices(newHi) b.doSetHostinfoFilterServices(newHi)
} }
b.updateFilter(b.netMapCache) b.updateFilter(b.netMap)
// TODO(dmytro): when Prefs gain an EnableTailscaleDNS toggle, updateDNSMap here. // TODO(dmytro): when Prefs gain an EnableTailscaleDNS toggle, updateDNSMap here.
turnDERPOff := new.DisableDERP && !old.DisableDERP turnDERPOff := new.DisableDERP && !old.DisableDERP
turnDERPOn := !new.DisableDERP && old.DisableDERP turnDERPOn := !new.DisableDERP && old.DisableDERP
if turnDERPOff { if turnDERPOff {
b.e.SetDERPMap(nil) b.e.SetDERPMap(nil)
} else if turnDERPOn && b.netMapCache != nil { } else if turnDERPOn && b.netMap != nil {
b.e.SetDERPMap(b.netMapCache.DERPMap) b.e.SetDERPMap(b.netMap.DERPMap)
} }
if old.WantRunning != new.WantRunning { if old.WantRunning != new.WantRunning {
@ -741,7 +759,7 @@ func (b *LocalBackend) doSetHostinfoFilterServices(hi *tailcfg.Hostinfo) {
// NetMap returns the latest cached network map received from // NetMap returns the latest cached network map received from
// controlclient, or nil if no network map was received yet. // controlclient, or nil if no network map was received yet.
func (b *LocalBackend) NetMap() *controlclient.NetworkMap { func (b *LocalBackend) NetMap() *controlclient.NetworkMap {
return b.netMapCache return b.netMap
} }
// blockEngineUpdate sets b.blocked to block, while holding b.mu. Its // blockEngineUpdate sets b.blocked to block, while holding b.mu. Its
@ -762,7 +780,7 @@ func (b *LocalBackend) authReconfig() {
b.mu.Lock() b.mu.Lock()
blocked := b.blocked blocked := b.blocked
uc := b.prefs uc := b.prefs
nm := b.netMapCache nm := b.netMap
b.mu.Unlock() b.mu.Unlock()
if blocked { if blocked {
@ -939,7 +957,7 @@ func (b *LocalBackend) nextState() State {
b.assertClientLocked() b.assertClientLocked()
var ( var (
c = b.c c = b.c
netMap = b.netMapCache netMap = b.netMap
state = b.state state = b.state
wantRunning = b.prefs.WantRunning wantRunning = b.prefs.WantRunning
) )
@ -1037,13 +1055,13 @@ func (b *LocalBackend) Logout() {
b.mu.Lock() b.mu.Lock()
b.assertClientLocked() b.assertClientLocked()
c := b.c c := b.c
b.netMapCache = nil b.netMap = nil
b.mu.Unlock() b.mu.Unlock()
c.Logout() c.Logout()
b.mu.Lock() b.mu.Lock()
b.netMapCache = nil b.netMap = nil
b.mu.Unlock() b.mu.Unlock()
b.stateMachine() b.stateMachine()
@ -1056,13 +1074,13 @@ func (b *LocalBackend) assertClientLocked() {
} }
} }
// setNetInfo sets b.hiCache.NetInfo to ni, and passes ni along to the // setNetInfo sets b.hostinfo.NetInfo to ni, and passes ni along to the
// controlclient, if one exists. // controlclient, if one exists.
func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) { func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) {
b.mu.Lock() b.mu.Lock()
c := b.c c := b.c
if b.hiCache != nil { if b.hostinfo != nil {
b.hiCache.NetInfo = ni.Clone() b.hostinfo.NetInfo = ni.Clone()
} }
b.mu.Unlock() b.mu.Unlock()

@ -127,18 +127,23 @@ func RateLimitedFn(logf Logf, f time.Duration, burst int, maxCache int) Logf {
// since the last time this identical line was logged. // since the last time this identical line was logged.
func LogOnChange(logf Logf, maxInterval time.Duration, timeNow func() time.Time) Logf { func LogOnChange(logf Logf, maxInterval time.Duration, timeNow func() time.Time) Logf {
var ( var (
mu sync.Mutex
sLastLogged string sLastLogged string
tLastLogged = timeNow() tLastLogged = timeNow()
) )
return func(format string, args ...interface{}) { return func(format string, args ...interface{}) {
s := fmt.Sprintf(format, args...) s := fmt.Sprintf(format, args...)
mu.Lock()
if s == sLastLogged && timeNow().Sub(tLastLogged) < maxInterval { if s == sLastLogged && timeNow().Sub(tLastLogged) < maxInterval {
mu.Unlock()
return return
} }
sLastLogged = s sLastLogged = s
tLastLogged = timeNow() tLastLogged = timeNow()
mu.Unlock()
logf(s) logf(s)
} }

@ -9,6 +9,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"log" "log"
"sync"
"testing" "testing"
"time" "time"
) )
@ -117,3 +118,31 @@ func TestArgWriter(t *testing.T) {
t.Errorf("got %q; want %q", got, want) t.Errorf("got %q; want %q", got, want)
} }
} }
func TestSynchronization(t *testing.T) {
timeNow := testTimer(1 * time.Second)
tests := []struct {
name string
logf Logf
}{
{"RateLimitedFn", RateLimitedFn(t.Logf, 1*time.Minute, 2, 50)},
{"LogOnChange", LogOnChange(t.Logf, 5*time.Second, timeNow)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(2)
f := func() {
tt.logf("1 2 3 4 5")
wg.Done()
}
go f()
go f()
wg.Wait()
})
}
}

Loading…
Cancel
Save