diff --git a/portlist/poller.go b/portlist/poller.go index 5dc99e49c..5a3c5448b 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "golang.org/x/exp/slices" "tailscale.com/envknob" ) @@ -84,14 +85,20 @@ func NewPoller() (*Poller, error) { // Do one initial poll synchronously so we can return an error // early. - var err error - p.prev, err = p.getList() - if err != nil { + if pl, err := p.getList(); err != nil { return nil, err + } else { + p.setPrev(pl) } return p, nil } +func (p *Poller) setPrev(pl List) { + // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want + // that to except to the caller. + p.prev = slices.Clone(pl) +} + func (p *Poller) initOSField() { if newOSImpl != nil { p.os = newOSImpl() @@ -131,11 +138,14 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) { // // Run may only be called once. func (p *Poller) Run(ctx context.Context) error { - defer close(p.runDone) - defer close(p.c) - tick := time.NewTicker(pollInterval) defer tick.Stop() + return p.runWithTickChan(ctx, tick.C) +} + +func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) error { + defer close(p.runDone) + defer close(p.c) // Send out the pre-generated initial value. if sent, err := p.send(ctx, p.prev); !sent { @@ -144,7 +154,7 @@ func (p *Poller) Run(ctx context.Context) error { for { select { - case <-tick.C: + case <-tickChan: pl, err := p.getList() if err != nil { return err @@ -152,9 +162,7 @@ func (p *Poller) Run(ctx context.Context) error { if pl.equal(p.prev) { continue } - // New value. Make a copy, as pl might alias pl.scratch - // and prev must not. - p.prev = append([]Port(nil), pl...) + p.setPrev(pl) if sent, err := p.send(ctx, p.prev); !sent { return err } diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 9e824ce4f..a50496aeb 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -5,10 +5,13 @@ package portlist import ( + "context" "flag" "net" "runtime" + "sync" "testing" + "time" "tailscale.com/tstest" ) @@ -182,6 +185,70 @@ func TestEqualLessThan(t *testing.T) { } } +func TestPoller(t *testing.T) { + p, err := NewPoller() + if err != nil { + t.Skipf("not running test: %v", err) + } + defer p.Close() + + var wg sync.WaitGroup + wg.Add(2) + + gotUpdate := make(chan bool, 16) + + go func() { + defer wg.Done() + for pl := range p.Updates() { + // Look at all the pl slice memory to maximize + // chance of race detector seeing violations. + for _, v := range pl { + if v == (Port{}) { + // Force use + panic("empty port") + } + } + select { + case gotUpdate <- true: + default: + } + } + }() + + tick := make(chan time.Time, 16) + go func() { + defer wg.Done() + if err := p.runWithTickChan(context.Background(), tick); err != nil { + t.Error("runWithTickChan:", err) + } + }() + for i := 0; i < 10; i++ { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + tick <- time.Time{} + + select { + case <-gotUpdate: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for update") + } + } + + // And a bunch of ticks without waiting for updates, + // to make race tests more likely to fail, if any present. + for i := 0; i < 10; i++ { + tick <- time.Time{} + } + + if err := p.Close(); err != nil { + t.Fatal(err) + } + wg.Wait() +} + func BenchmarkGetList(b *testing.B) { benchmarkGetList(b, false) }