From 5b85f848ddcb890017f4b3fe8380df0d8c577a59 Mon Sep 17 00:00:00 2001 From: Tom DNetto Date: Thu, 14 Apr 2022 13:27:59 -0700 Subject: [PATCH] net/dns,net/dns/resolver: refactor channels/magicDNS out of Resolver Moves magicDNS-specific handling out of Resolver & into dns.Manager. This greatly simplifies the Resolver to solely issuing queries and returning responses, without channels. Enforcement of max number of in-flight magicDNS queries, assembly of synthetic UDP datagrams, and integration with wgengine for recieving/responding to magicDNS traffic is now entirely in Manager. This path is being kept around, but ultimately aims to be deleted and replaced with a netstack-based path. This commit is part of a series to implement magicDNS using netstack. Signed-off-by: Tom DNetto --- net/dns/manager.go | 135 ++++++++++++++++++++++++++- net/dns/resolver/forwarder.go | 27 ++---- net/dns/resolver/tsdns.go | 166 +++++---------------------------- net/dns/resolver/tsdns_test.go | 81 +--------------- 4 files changed, 159 insertions(+), 250 deletions(-) diff --git a/net/dns/manager.go b/net/dns/manager.go index 829a3fa8c..7bccb45d2 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -6,20 +6,51 @@ package dns import ( "bufio" + "context" + "errors" "runtime" + "sync/atomic" "time" "inet.af/netaddr" "tailscale.com/health" "tailscale.com/net/dns/resolver" + "tailscale.com/net/packet" + "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" + "tailscale.com/net/tstun" "tailscale.com/types/dnstype" "tailscale.com/types/ipproto" "tailscale.com/types/logger" + "tailscale.com/util/clientmetric" "tailscale.com/util/dnsname" "tailscale.com/wgengine/monitor" ) +var ( + magicDNSIP = tsaddr.TailscaleServiceIP() + magicDNSIPv6 = tsaddr.TailscaleServiceIPv6() +) + +var ( + errFullQueue = errors.New("request queue full") +) + +// maxActiveQueries returns the maximal number of DNS requests that be +// can running. +// If EnqueueRequest is called when this many requests are already pending, +// the request will be dropped to avoid blocking the caller. +func maxActiveQueries() int32 { + if runtime.GOOS == "ios" { + // For memory paranoia reasons on iOS, match the + // historical Tailscale 1.x..1.8 behavior for now + // (just before the 1.10 release). + return 64 + } + // But for other platforms, allow more burstiness: + return 256 +} + // We use file-ignore below instead of ignore because on some platforms, // the lint exception is necessary and on others it is not, // and plain ignore complains if the exception is unnecessary. @@ -31,10 +62,24 @@ import ( // Such operations should be wrapped in a timeout context. const reconfigTimeout = time.Second +type response struct { + pkt []byte + from netaddr.IPPort // where the packet needs to be sent +} + // Manager manages system DNS settings. type Manager struct { logf logger.Logf + // When netstack is not used, Manager implements magic DNS. + // In this case, responses tracks completed DNS requests + // which need a response, and NextPacket() synthesizes a + // fake IP+UDP header to finish assembling the response. + // + // TODO(tom): Rip out once all platforms use netstack. + responses chan response + activeQueriesAtomic int32 + resolver *resolver.Resolver os OSConfigurator } @@ -46,9 +91,10 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon, di } logf = logger.WithPrefix(logf, "dns: ") m := &Manager{ - logf: logf, - resolver: resolver.New(logf, linkMon, linkSel, dialer), - os: oscfg, + logf: logf, + resolver: resolver.New(logf, linkMon, linkSel, dialer), + os: oscfg, + responses: make(chan response), } m.logf("using %T", m.os) return m @@ -195,12 +241,87 @@ func toIPsOnly(resolvers []dnstype.Resolver) (ret []netaddr.IP) { return ret } +// EnqueuePacket is the legacy path for handling magic DNS traffic, and is +// called with a DNS request payload. +// +// TODO(tom): Rip out once all platforms use netstack. func (m *Manager) EnqueuePacket(bs []byte, proto ipproto.Proto, from, to netaddr.IPPort) error { - return m.resolver.EnqueuePacket(bs, proto, from, to) + if to.Port() != 53 || proto != ipproto.UDP { + return nil + } + + if n := atomic.AddInt32(&m.activeQueriesAtomic, 1); n > maxActiveQueries() { + atomic.AddInt32(&m.activeQueriesAtomic, -1) + metricDNSQueryErrorQueue.Add(1) + return errFullQueue + } + + go func() { + resp, err := m.resolver.Query(context.Background(), bs, from) + if err != nil { + atomic.AddInt32(&m.activeQueriesAtomic, -1) + m.logf("dns query: %v", err) + return + } + + m.responses <- response{resp, from} + }() + return nil } +// NextPacket is the legacy path for obtaining DNS results in response to +// magic DNS queries. It blocks until a response is available. +// +// TODO(tom): Rip out once all platforms use netstack. func (m *Manager) NextPacket() ([]byte, error) { - return m.resolver.NextPacket() + resp := <-m.responses + + // Unused space is needed further down the stack. To avoid extra + // allocations/copying later on, we allocate such space here. + const offset = tstun.PacketStartOffset + + var buf []byte + switch { + case resp.from.IP().Is4(): + h := packet.UDP4Header{ + IP4Header: packet.IP4Header{ + Src: magicDNSIP, + Dst: resp.from.IP(), + }, + SrcPort: 53, + DstPort: resp.from.Port(), + } + hlen := h.Len() + buf = make([]byte, offset+hlen+len(resp.pkt)) + copy(buf[offset+hlen:], resp.pkt) + h.Marshal(buf[offset:]) + case resp.from.IP().Is6(): + h := packet.UDP6Header{ + IP6Header: packet.IP6Header{ + Src: magicDNSIPv6, + Dst: resp.from.IP(), + }, + SrcPort: 53, + DstPort: resp.from.Port(), + } + hlen := h.Len() + buf = make([]byte, offset+hlen+len(resp.pkt)) + copy(buf[offset+hlen:], resp.pkt) + h.Marshal(buf[offset:]) + } + + atomic.AddInt32(&m.activeQueriesAtomic, -1) + return buf, nil +} + +func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) { + if n := atomic.AddInt32(&m.activeQueriesAtomic, 1); n > maxActiveQueries() { + atomic.AddInt32(&m.activeQueriesAtomic, -1) + metricDNSQueryErrorQueue.Add(1) + return nil, errFullQueue + } + defer atomic.AddInt32(&m.activeQueriesAtomic, -1) + return m.resolver.Query(ctx, bs, from) } func (m *Manager) Down() error { @@ -229,3 +350,7 @@ func Cleanup(logf logger.Logf, interfaceName string) { logf("dns down: %v", err) } } + +var ( + metricDNSQueryErrorQueue = clientmetric.NewCounter("dns_query_local_error_queue") +) \ No newline at end of file diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 162fd5f15..f00a887d2 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -190,9 +190,6 @@ type forwarder struct { ctx context.Context // good until Close ctxCancel context.CancelFunc // closes ctx - // responses is a channel by which responses are returned. - responses chan packet - mu sync.Mutex // guards following dohClient map[string]*http.Client // urlBase -> client @@ -229,14 +226,13 @@ func maxDoHInFlight(goos string) int { return 1000 } -func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder { +func newForwarder(logf logger.Logf, linkMon *monitor.Mon, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder { f := &forwarder{ - logf: logger.WithPrefix(logf, "forward: "), - linkMon: linkMon, - linkSel: linkSel, - dialer: dialer, - responses: responses, - dohSem: make(chan struct{}, maxDoHInFlight(runtime.GOOS)), + logf: logger.WithPrefix(logf, "forward: "), + linkMon: linkMon, + linkSel: linkSel, + dialer: dialer, + dohSem: make(chan struct{}, maxDoHInFlight(runtime.GOOS)), } f.ctx, f.ctxCancel = context.WithCancel(context.Background()) return f @@ -601,17 +597,6 @@ type forwardQuery struct { // ... } -// forward forwards the query to all upstream nameservers and waits for -// the first response. -// -// It either sends to f.responses and returns nil, or returns a -// non-nil error (without sending to the channel). -func (f *forwarder) forward(query packet) error { - ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) - defer cancel() - return f.forwardWithDestChan(ctx, query, f.responses) -} - // forwardWithDestChan forwards the query to all upstream nameservers // and waits for the first response. // diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 5d2b71107..64e9bb0d6 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -26,12 +26,9 @@ import ( dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" "tailscale.com/net/dns/resolvconffile" - tspacket "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" - "tailscale.com/net/tstun" "tailscale.com/types/dnstype" - "tailscale.com/types/ipproto" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" "tailscale.com/util/dnsname" @@ -40,33 +37,11 @@ import ( const dnsSymbolicFQDN = "magicdns.localhost-tailscale-daemon." -var ( - magicDNSIP = tsaddr.TailscaleServiceIP() - magicDNSIPv6 = tsaddr.TailscaleServiceIPv6() -) - -const magicDNSPort = 53 - // maxResponseBytes is the maximum size of a response from a Resolver. The // actual buffer size will be one larger than this so that we can detect // truncation in a platform-agnostic way. const maxResponseBytes = 4095 -// maxActiveQueries returns the maximal number of DNS requests that be -// can running. -// If EnqueueRequest is called when this many requests are already pending, -// the request will be dropped to avoid blocking the caller. -func maxActiveQueries() int32 { - if runtime.GOOS == "ios" { - // For memory paranoia reasons on iOS, match the - // historical Tailscale 1.x..1.8 behavior for now - // (just before the 1.10 release). - return 64 - } - // But for other platforms, allow more burstiness: - return 256 -} - // defaultTTL is the TTL of all responses from Resolver. const defaultTTL = 600 * time.Second @@ -74,7 +49,6 @@ const defaultTTL = 600 * time.Second var ErrClosed = errors.New("closed") var ( - errFullQueue = errors.New("request queue full") errNotQuery = errors.New("not a DNS query") errNotOurName = errors.New("not a Tailscale DNS name") ) @@ -209,12 +183,6 @@ type Resolver struct { // forwarder forwards requests to upstream nameservers. forwarder *forwarder - activeQueriesAtomic int32 // number of DNS queries in flight - - // responses is an unbuffered channel to which responses are returned. - responses chan packet - // errors is an unbuffered channel to which errors are returned. - errors chan error // closed signals all goroutines to stop. closed chan struct{} // wg signals when all goroutines have stopped. @@ -241,16 +209,14 @@ func New(logf logger.Logf, linkMon *monitor.Mon, linkSel ForwardLinkSelector, di panic("nil Dialer") } r := &Resolver{ - logf: logger.WithPrefix(logf, "resolver: "), - linkMon: linkMon, - responses: make(chan packet), - errors: make(chan error), - closed: make(chan struct{}), - hostToIP: map[dnsname.FQDN][]netaddr.IP{}, - ipToHost: map[netaddr.IP]dnsname.FQDN{}, - dialer: dialer, - } - r.forwarder = newForwarder(r.logf, r.responses, linkMon, linkSel, dialer) + logf: logger.WithPrefix(logf, "resolver: "), + linkMon: linkMon, + closed: make(chan struct{}), + hostToIP: map[dnsname.FQDN][]netaddr.IP{}, + ipToHost: map[netaddr.IP]dnsname.FQDN{}, + dialer: dialer, + } + r.forwarder = newForwarder(r.logf, linkMon, linkSel, dialer) return r } @@ -293,94 +259,29 @@ func (r *Resolver) Close() { r.forwarder.Close() } -// EnqueuePacket handles a packet to the magicDNS endpoint. -// It takes ownership of the payload and does not block. -// If the queue is full, the request will be dropped and an error will be returned. -func (r *Resolver) EnqueuePacket(bs []byte, proto ipproto.Proto, from, to netaddr.IPPort) error { - if to.Port() != magicDNSPort || proto != ipproto.UDP { - return nil - } - - return r.enqueueRequest(bs, proto, from, to) -} - -// enqueueRequest places the given DNS request in the resolver's queue. -// If the queue is full, the request will be dropped and an error will be returned. -func (r *Resolver) enqueueRequest(bs []byte, proto ipproto.Proto, from, to netaddr.IPPort) error { +func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) { metricDNSQueryLocal.Add(1) select { case <-r.closed: metricDNSQueryErrorClosed.Add(1) - return ErrClosed + return nil, ErrClosed default: } - if n := atomic.AddInt32(&r.activeQueriesAtomic, 1); n > maxActiveQueries() { - atomic.AddInt32(&r.activeQueriesAtomic, -1) - metricDNSQueryErrorQueue.Add(1) - return errFullQueue - } - go r.handleQuery(packet{bs, from}) - return nil -} - -// NextPacket returns the next packet to service traffic for magicDNS. The returned -// packet is prefixed with unused space consistent with the semantics of injection -// into tstun.Wrapper. -// It blocks until a response is available and gives up ownership of the response payload. -func (r *Resolver) NextPacket() (ipPacket []byte, err error) { - bs, to, err := r.nextResponse() - if err != nil { - return nil, err - } - - // Unused space is needed further down the stack. To avoid extra - // allocations/copying later on, we allocate such space here. - const offset = tstun.PacketStartOffset - var buf []byte - switch { - case to.IP().Is4(): - h := tspacket.UDP4Header{ - IP4Header: tspacket.IP4Header{ - Src: magicDNSIP, - Dst: to.IP(), - }, - SrcPort: magicDNSPort, - DstPort: to.Port(), - } - hlen := h.Len() - buf = make([]byte, offset+hlen+len(bs)) - copy(buf[offset+hlen:], bs) - h.Marshal(buf[offset:]) - case to.IP().Is6(): - h := tspacket.UDP6Header{ - IP6Header: tspacket.IP6Header{ - Src: magicDNSIPv6, - Dst: to.IP(), - }, - SrcPort: magicDNSPort, - DstPort: to.Port(), + out, err := r.respond(bs) + if err == errNotOurName { + responses := make(chan packet, 1) + ctx, cancel := context.WithCancel(ctx) + defer close(responses) + defer cancel() + err = r.forwarder.forwardWithDestChan(ctx, packet{bs, from}, responses) + if err != nil { + return nil, err } - hlen := h.Len() - buf = make([]byte, offset+hlen+len(bs)) - copy(buf[offset+hlen:], bs) - h.Marshal(buf[offset:]) + return (<-responses).bs, nil } - return buf, nil -} - -// nextResponse returns a DNS response to a previously enqueued request. -// It blocks until a response is available and gives up ownership of the response payload. -func (r *Resolver) nextResponse() (packet []byte, to netaddr.IPPort, err error) { - select { - case <-r.closed: - return nil, netaddr.IPPort{}, ErrClosed - case resp := <-r.responses: - return resp.bs, resp.addr, nil - case err := <-r.errors: - return nil, netaddr.IPPort{}, err - } + return out, err } // parseExitNodeQuery parses a DNS request packet. @@ -808,30 +709,6 @@ func (r *Resolver) fqdnForIPLocked(ip netaddr.IP, name dnsname.FQDN) (dnsname.FQ return ret, dns.RCodeSuccess } -func (r *Resolver) handleQuery(pkt packet) { - defer atomic.AddInt32(&r.activeQueriesAtomic, -1) - - out, err := r.respond(pkt.bs) - if err == errNotOurName { - err = r.forwarder.forward(pkt) - if err == nil { - // forward will send response into r.responses, nothing to do. - return - } - } - if err != nil { - select { - case <-r.closed: - case r.errors <- err: - } - } else { - select { - case <-r.closed: - case r.responses <- packet{out, pkt.addr}: - } - } -} - type response struct { Header dns.Header Question dns.Question @@ -1316,7 +1193,6 @@ func unARPA(a string) (ipStr string, ok bool) { var ( metricDNSQueryLocal = clientmetric.NewCounter("dns_query_local") metricDNSQueryErrorClosed = clientmetric.NewCounter("dns_query_local_error_closed") - metricDNSQueryErrorQueue = clientmetric.NewCounter("dns_query_local_error_queue") metricDNSErrorParseNoQ = clientmetric.NewCounter("dns_query_respond_error_no_question") metricDNSErrorParseQuery = clientmetric.NewCounter("dns_query_respond_error_parse") diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index cc1931da1..f46e8d693 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -27,7 +27,6 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tstest" "tailscale.com/types/dnstype" - "tailscale.com/types/ipproto" "tailscale.com/util/dnsname" "tailscale.com/wgengine/monitor" ) @@ -234,11 +233,7 @@ func unpackResponse(payload []byte) (dnsResponse, error) { } func syncRespond(r *Resolver, query []byte) ([]byte, error) { - if err := r.enqueueRequest(query, ipproto.UDP, netaddr.IPPort{}, magicDNSv4Port); err != nil { - return nil, fmt.Errorf("enqueueRequest: %w", err) - } - payload, _, err := r.nextResponse() - return payload, err + return r.Query(context.Background(), query, netaddr.IPPort{}) } func mustIP(str string) netaddr.IP { @@ -708,78 +703,6 @@ func TestDelegateSplitRoute(t *testing.T) { } } -func TestDelegateCollision(t *testing.T) { - server := serveDNS(t, "127.0.0.1:0", - "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) - defer server.Shutdown() - - r := newResolver(t) - defer r.Close() - - cfg := dnsCfg - cfg.Routes = map[dnsname.FQDN][]dnstype.Resolver{ - ".": {{Addr: server.PacketConn.LocalAddr().String()}}, - } - r.SetConfig(cfg) - - packets := []struct { - qname dnsname.FQDN - qtype dns.Type - addr netaddr.IPPort - }{ - {"test.site.", dns.TypeA, netaddr.IPPortFrom(netaddr.IPv4(1, 1, 1, 1), 1001)}, - {"test.site.", dns.TypeAAAA, netaddr.IPPortFrom(netaddr.IPv4(1, 1, 1, 1), 1002)}, - } - - // packets will have the same dns txid. - for _, p := range packets { - payload := dnspacket(p.qname, p.qtype, noEdns) - err := r.enqueueRequest(payload, ipproto.UDP, p.addr, magicDNSv4Port) - if err != nil { - t.Error(err) - } - } - - // Despite the txid collision, the answer(s) should still match the query. - resp, addr, err := r.nextResponse() - if err != nil { - t.Error(err) - } - - var p dns.Parser - _, err = p.Start(resp) - if err != nil { - t.Error(err) - } - err = p.SkipAllQuestions() - if err != nil { - t.Error(err) - } - ans, err := p.AllAnswers() - if err != nil { - t.Error(err) - } - if len(ans) == 0 { - t.Fatal("no answers") - } - - var wantType dns.Type - switch ans[0].Body.(type) { - case *dns.AResource: - wantType = dns.TypeA - case *dns.AAAAResource: - wantType = dns.TypeAAAA - default: - t.Errorf("unexpected answer type: %T", ans[0].Body) - } - - for _, p := range packets { - if p.qtype == wantType && p.addr != addr { - t.Errorf("addr = %v; want %v", addr, p.addr) - } - } -} - var allResponse = []byte{ 0x00, 0x00, // transaction id: 0 0x84, 0x00, // flags: response, authoritative, no error @@ -1076,7 +999,7 @@ func TestForwardLinkSelection(t *testing.T) { // routes differently. specialIP := netaddr.IPv4(1, 2, 3, 4) - fwd := newForwarder(t.Logf, nil, nil, linkSelFunc(func(ip netaddr.IP) string { + fwd := newForwarder(t.Logf, nil, linkSelFunc(func(ip netaddr.IP) string { if ip == netaddr.IPv4(1, 2, 3, 4) { return "special" }