diff --git a/net/dns/manager.go b/net/dns/manager.go index 2fc54492e..36040a7e8 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -8,6 +8,7 @@ import ( "bufio" "context" "errors" + "net" "runtime" "sync/atomic" "time" @@ -63,8 +64,8 @@ func maxActiveQueries() int32 { const reconfigTimeout = time.Second type response struct { - pkt []byte - to netaddr.IPPort // response destination (request source) + pkt []byte + to netaddr.IPPort // response destination (request source) } // Manager manages system DNS settings. @@ -80,6 +81,9 @@ type Manager struct { responses chan response activeQueriesAtomic int32 + ctx context.Context // good until Down + ctxCancel context.CancelFunc // closes ctx + resolver *resolver.Resolver os OSConfigurator } @@ -96,6 +100,7 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon, di os: oscfg, responses: make(chan response), } + m.ctx, m.ctxCancel = context.WithCancel(context.Background()) m.logf("using %T", m.os) return m } @@ -257,14 +262,18 @@ func (m *Manager) EnqueuePacket(bs []byte, proto ipproto.Proto, from, to netaddr } go func() { - resp, err := m.resolver.Query(context.Background(), bs, from) + resp, err := m.resolver.Query(m.ctx, bs, from) if err != nil { atomic.AddInt32(&m.activeQueriesAtomic, -1) m.logf("dns query: %v", err) return } - m.responses <- response{resp, from} + select { + case <-m.ctx.Done(): + return + case m.responses <- response{resp, from}: + } }() return nil } @@ -274,7 +283,13 @@ func (m *Manager) EnqueuePacket(bs []byte, proto ipproto.Proto, from, to netaddr // // TODO(tom): Rip out once all platforms use netstack. func (m *Manager) NextPacket() ([]byte, error) { - resp := <-m.responses + var resp response + select { + case <-m.ctx.Done(): + return nil, net.ErrClosed + case resp = <-m.responses: + // continue + } // Unused space is needed further down the stack. To avoid extra // allocations/copying later on, we allocate such space here. @@ -315,6 +330,13 @@ func (m *Manager) NextPacket() ([]byte, error) { } func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) { + select { + case <-m.ctx.Done(): + return nil, net.ErrClosed + default: + // continue + } + if n := atomic.AddInt32(&m.activeQueriesAtomic, 1); n > maxActiveQueries() { atomic.AddInt32(&m.activeQueriesAtomic, -1) metricDNSQueryErrorQueue.Add(1) @@ -325,6 +347,7 @@ func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([] } func (m *Manager) Down() error { + m.ctxCancel() if err := m.os.Close(); err != nil { return err } @@ -353,4 +376,4 @@ func Cleanup(logf logger.Logf, interfaceName string) { var ( metricDNSQueryErrorQueue = clientmetric.NewCounter("dns_query_local_error_queue") -) \ No newline at end of file +) diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 64e9bb0d6..dbb665bc3 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -45,9 +45,6 @@ const maxResponseBytes = 4095 // defaultTTL is the TTL of all responses from Resolver. const defaultTTL = 600 * time.Second -// ErrClosed indicates that the resolver has been closed and readers should exit. -var ErrClosed = errors.New("closed") - var ( errNotQuery = errors.New("not a DNS query") errNotOurName = errors.New("not a Tailscale DNS name") @@ -264,7 +261,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([ select { case <-r.closed: metricDNSQueryErrorClosed.Add(1) - return nil, ErrClosed + return nil, net.ErrClosed default: } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index f732f766b..300abe4bb 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "net" "reflect" "runtime" "strings" @@ -499,7 +500,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) func (e *userspaceEngine) pollResolver() { for { bs, err := e.dns.NextPacket() - if err == resolver.ErrClosed { + if errors.Is(err, net.ErrClosed) { return } if err != nil {