diff --git a/portlist/poller.go b/portlist/poller.go index 90c8e7838..d1f5b2ab0 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -7,7 +7,6 @@ package portlist import ( - "context" "errors" "fmt" "runtime" @@ -19,7 +18,8 @@ import ( ) var ( - pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs + newOSImpl func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl. + pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") ) @@ -37,8 +37,6 @@ type Poller struct { // This field should only be changed before calling Run. IncludeLocalhost bool - c chan List // unbuffered - // os, if non-nil, is an OS-specific implementation of the portlist getting // code. When non-nil, it's responsible for getting the complete list of // cached ports complete with the process name. That is, when set, @@ -49,12 +47,6 @@ type Poller struct { initOnce sync.Once // guards init of os initErr error - // closeCtx is the context that's canceled on Close. - closeCtx context.Context - closeCtxCancel context.CancelFunc - - runDone chan struct{} // closed when Run completes - // scatch is memory for Poller.getList to reuse between calls. scratch []Port @@ -75,36 +67,6 @@ type osImpl interface { AppendListeningPorts(base []Port) ([]Port, error) } -// newOSImpl, if non-nil, constructs a new osImpl. -var newOSImpl func(includeLocalhost bool) osImpl - -var ( - errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS) - errDisabled = errors.New("portlist disabled by envknob") -) - -// NewPoller returns a new portlist Poller. It returns an error -// if the portlist couldn't be obtained. -func NewPoller() (*Poller, error) { - p := &Poller{ - c: make(chan List), - runDone: make(chan struct{}), - } - p.initOnce.Do(p.init) - if p.initErr != nil { - return nil, p.initErr - } - p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) - // Do one initial poll synchronously so we can return an error - // early. - if pl, err := p.getList(); err != nil { - return nil, err - } else { - p.setPrev(pl) - } - return p, nil -} - func (p *Poller) setPrev(pl List) { // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want // that to except to the caller. @@ -114,22 +76,16 @@ func (p *Poller) setPrev(pl List) { // init initializes the Poller by ensuring it has an underlying // OS implementation and is not turned off by envknob. func (p *Poller) init() { - if debugDisablePortlist() { - p.initErr = errDisabled - return - } - if newOSImpl == nil { - p.initErr = errUnimplemented - return + switch { + case debugDisablePortlist(): + p.initErr = errors.New("portlist disabled by envknob") + case newOSImpl == nil: + p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS) + default: + p.os = newOSImpl(p.IncludeLocalhost) } - p.os = newOSImpl(p.IncludeLocalhost) } -// Updates return the channel that receives port list updates. -// -// The channel is closed when the Poller is closed. -func (p *Poller) Updates() <-chan List { return p.c } - // Close closes the Poller. func (p *Poller) Close() error { if p.initErr != nil { @@ -138,25 +94,9 @@ func (p *Poller) Close() error { if p.os == nil { return nil } - if p.closeCtxCancel != nil { - p.closeCtxCancel() - <-p.runDone - } return p.os.Close() } -// send sends pl to p.c and returns whether it was successfully sent. -func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) { - select { - case p.c <- pl: - return true, nil - case <-ctx.Done(): - return false, ctx.Err() - case <-p.closeCtx.Done(): - return false, nil - } -} - // Poll returns the list of listening ports, if changed from // a previous call as indicated by the changed result. func (p *Poller) Poll() (ports []Port, changed bool, err error) { @@ -175,55 +115,7 @@ func (p *Poller) Poll() (ports []Port, changed bool, err error) { return p.prev, true, nil } -// Run runs the Poller periodically until either the context -// is done, or the Close is called. -// -// Run may only be called once. -func (p *Poller) Run(ctx context.Context) error { - tick := time.NewTicker(pollInterval) - defer tick.Stop() - return p.runWithTickChan(ctx, tick.C) -} - -func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) error { - defer close(p.runDone) - defer close(p.c) - - // Send out the pre-generated initial value. - if sent, err := p.send(ctx, p.prev); !sent { - return err - } - - for { - select { - case <-tickChan: - pl, err := p.getList() - if err != nil { - return err - } - if pl.equal(p.prev) { - continue - } - p.setPrev(pl) - if sent, err := p.send(ctx, p.prev); !sent { - return err - } - case <-ctx.Done(): - return ctx.Err() - case <-p.closeCtx.Done(): - return nil - } - } -} - func (p *Poller) getList() (List, error) { - // TODO(marwan): this method does not - // need to do any init logic. Update tests - // once async API is removed. - p.initOnce.Do(p.init) - if p.initErr == errDisabled { - return nil, nil - } var err error p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) return p.scratch, err diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 86e8bd335..14cc490f7 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -4,11 +4,8 @@ package portlist import ( - "context" "net" - "sync" "testing" - "time" "tailscale.com/tstest" ) @@ -17,14 +14,14 @@ func TestGetList(t *testing.T) { tstest.ResourceCheck(t) var p Poller - pl, err := p.getList() + pl, _, err := p.Poll() if err != nil { t.Fatal(err) } for i, p := range pl { t.Logf("[%d] %+v", i, p) } - t.Logf("As String: %v", pl.String()) + t.Logf("As String: %s", List(pl)) } func TestIgnoreLocallyBoundPorts(t *testing.T) { @@ -38,7 +35,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) { ta := ln.Addr().(*net.TCPAddr) port := ta.Port var p Poller - pl, err := p.getList() + pl, _, err := p.Poll() if err != nil { t.Fatal(err) } @@ -49,16 +46,16 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) { } } -func TestChangesOverTime(t *testing.T) { +func TestPoller(t *testing.T) { var p Poller p.IncludeLocalhost = true get := func(t *testing.T) []Port { t.Helper() - s, err := p.getList() + s, _, err := p.Poll() if err != nil { t.Fatal(err) } - return append([]Port(nil), s...) + return s } p1 := get(t) @@ -192,74 +189,6 @@ func TestClose(t *testing.T) { } } -func TestPoller(t *testing.T) { - p, err := NewPoller() - if err != nil { - t.Skipf("not running test: %v", err) - } - t.Cleanup(func() { - if err := p.Close(); err != nil { - t.Errorf("error closing poller in test: %v", err) - } - }) - - var wg sync.WaitGroup - wg.Add(2) - - gotUpdate := make(chan bool, 16) - - go func() { - defer wg.Done() - for pl := range p.Updates() { - // Look at all the pl slice memory to maximize - // chance of race detector seeing violations. - for _, v := range pl { - if v == (Port{}) { - // Force use - panic("empty port") - } - } - select { - case gotUpdate <- true: - default: - } - } - }() - - tick := make(chan time.Time, 16) - go func() { - defer wg.Done() - if err := p.runWithTickChan(context.Background(), tick); err != nil { - t.Error("runWithTickChan:", err) - } - }() - for i := 0; i < 10; i++ { - ln, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - tick <- time.Time{} - - select { - case <-gotUpdate: - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for update") - } - } - - // And a bunch of ticks without waiting for updates, - // to make race tests more likely to fail, if any present. - for i := 0; i < 10; i++ { - tick <- time.Time{} - } - - if err := p.Close(); err != nil { - t.Fatal(err) - } - wg.Wait() -} - func BenchmarkGetList(b *testing.B) { benchmarkGetList(b, false) } @@ -271,6 +200,11 @@ func BenchmarkGetListIncremental(b *testing.B) { func benchmarkGetList(b *testing.B, incremental bool) { b.ReportAllocs() var p Poller + p.init() + if p.initErr != nil { + b.Skip(p.initErr) + } + b.Cleanup(func() { p.Close() }) for i := 0; i < b.N; i++ { pl, err := p.getList() if err != nil {