net/dns,net/dns/resolver: refactor channels/magicDNS out of Resolver

Moves magicDNS-specific handling out of Resolver & into dns.Manager. This
greatly simplifies the Resolver to solely issuing queries and returning
responses, without channels.

Enforcement of max number of in-flight magicDNS queries, assembly of
synthetic UDP datagrams, and integration with wgengine for
recieving/responding to magicDNS traffic is now entirely in Manager.
This path is being kept around, but ultimately aims to be deleted and
replaced with a netstack-based path.

This commit is part of a series to implement magicDNS using netstack.

Signed-off-by: Tom DNetto <tom@tailscale.com>
pull/4590/head
Tom DNetto 2 years ago committed by Tom
parent a54671529b
commit 5b85f848dd

@ -6,20 +6,51 @@ package dns
import (
"bufio"
"context"
"errors"
"runtime"
"sync/atomic"
"time"
"inet.af/netaddr"
"tailscale.com/health"
"tailscale.com/net/dns/resolver"
"tailscale.com/net/packet"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tsdial"
"tailscale.com/net/tstun"
"tailscale.com/types/dnstype"
"tailscale.com/types/ipproto"
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/util/dnsname"
"tailscale.com/wgengine/monitor"
)
var (
magicDNSIP = tsaddr.TailscaleServiceIP()
magicDNSIPv6 = tsaddr.TailscaleServiceIPv6()
)
var (
errFullQueue = errors.New("request queue full")
)
// maxActiveQueries returns the maximal number of DNS requests that be
// can running.
// If EnqueueRequest is called when this many requests are already pending,
// the request will be dropped to avoid blocking the caller.
func maxActiveQueries() int32 {
if runtime.GOOS == "ios" {
// For memory paranoia reasons on iOS, match the
// historical Tailscale 1.x..1.8 behavior for now
// (just before the 1.10 release).
return 64
}
// But for other platforms, allow more burstiness:
return 256
}
// We use file-ignore below instead of ignore because on some platforms,
// the lint exception is necessary and on others it is not,
// and plain ignore complains if the exception is unnecessary.
@ -31,10 +62,24 @@ import (
// Such operations should be wrapped in a timeout context.
const reconfigTimeout = time.Second
type response struct {
pkt []byte
from netaddr.IPPort // where the packet needs to be sent
}
// Manager manages system DNS settings.
type Manager struct {
logf logger.Logf
// When netstack is not used, Manager implements magic DNS.
// In this case, responses tracks completed DNS requests
// which need a response, and NextPacket() synthesizes a
// fake IP+UDP header to finish assembling the response.
//
// TODO(tom): Rip out once all platforms use netstack.
responses chan response
activeQueriesAtomic int32
resolver *resolver.Resolver
os OSConfigurator
}
@ -46,9 +91,10 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon, di
}
logf = logger.WithPrefix(logf, "dns: ")
m := &Manager{
logf: logf,
resolver: resolver.New(logf, linkMon, linkSel, dialer),
os: oscfg,
logf: logf,
resolver: resolver.New(logf, linkMon, linkSel, dialer),
os: oscfg,
responses: make(chan response),
}
m.logf("using %T", m.os)
return m
@ -195,12 +241,87 @@ func toIPsOnly(resolvers []dnstype.Resolver) (ret []netaddr.IP) {
return ret
}
// EnqueuePacket is the legacy path for handling magic DNS traffic, and is
// called with a DNS request payload.
//
// TODO(tom): Rip out once all platforms use netstack.
func (m *Manager) EnqueuePacket(bs []byte, proto ipproto.Proto, from, to netaddr.IPPort) error {
return m.resolver.EnqueuePacket(bs, proto, from, to)
if to.Port() != 53 || proto != ipproto.UDP {
return nil
}
if n := atomic.AddInt32(&m.activeQueriesAtomic, 1); n > maxActiveQueries() {
atomic.AddInt32(&m.activeQueriesAtomic, -1)
metricDNSQueryErrorQueue.Add(1)
return errFullQueue
}
go func() {
resp, err := m.resolver.Query(context.Background(), bs, from)
if err != nil {
atomic.AddInt32(&m.activeQueriesAtomic, -1)
m.logf("dns query: %v", err)
return
}
m.responses <- response{resp, from}
}()
return nil
}
// NextPacket is the legacy path for obtaining DNS results in response to
// magic DNS queries. It blocks until a response is available.
//
// TODO(tom): Rip out once all platforms use netstack.
func (m *Manager) NextPacket() ([]byte, error) {
return m.resolver.NextPacket()
resp := <-m.responses
// Unused space is needed further down the stack. To avoid extra
// allocations/copying later on, we allocate such space here.
const offset = tstun.PacketStartOffset
var buf []byte
switch {
case resp.from.IP().Is4():
h := packet.UDP4Header{
IP4Header: packet.IP4Header{
Src: magicDNSIP,
Dst: resp.from.IP(),
},
SrcPort: 53,
DstPort: resp.from.Port(),
}
hlen := h.Len()
buf = make([]byte, offset+hlen+len(resp.pkt))
copy(buf[offset+hlen:], resp.pkt)
h.Marshal(buf[offset:])
case resp.from.IP().Is6():
h := packet.UDP6Header{
IP6Header: packet.IP6Header{
Src: magicDNSIPv6,
Dst: resp.from.IP(),
},
SrcPort: 53,
DstPort: resp.from.Port(),
}
hlen := h.Len()
buf = make([]byte, offset+hlen+len(resp.pkt))
copy(buf[offset+hlen:], resp.pkt)
h.Marshal(buf[offset:])
}
atomic.AddInt32(&m.activeQueriesAtomic, -1)
return buf, nil
}
func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) {
if n := atomic.AddInt32(&m.activeQueriesAtomic, 1); n > maxActiveQueries() {
atomic.AddInt32(&m.activeQueriesAtomic, -1)
metricDNSQueryErrorQueue.Add(1)
return nil, errFullQueue
}
defer atomic.AddInt32(&m.activeQueriesAtomic, -1)
return m.resolver.Query(ctx, bs, from)
}
func (m *Manager) Down() error {
@ -229,3 +350,7 @@ func Cleanup(logf logger.Logf, interfaceName string) {
logf("dns down: %v", err)
}
}
var (
metricDNSQueryErrorQueue = clientmetric.NewCounter("dns_query_local_error_queue")
)

@ -190,9 +190,6 @@ type forwarder struct {
ctx context.Context // good until Close
ctxCancel context.CancelFunc // closes ctx
// responses is a channel by which responses are returned.
responses chan packet
mu sync.Mutex // guards following
dohClient map[string]*http.Client // urlBase -> client
@ -229,14 +226,13 @@ func maxDoHInFlight(goos string) int {
return 1000
}
func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder {
func newForwarder(logf logger.Logf, linkMon *monitor.Mon, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder {
f := &forwarder{
logf: logger.WithPrefix(logf, "forward: "),
linkMon: linkMon,
linkSel: linkSel,
dialer: dialer,
responses: responses,
dohSem: make(chan struct{}, maxDoHInFlight(runtime.GOOS)),
logf: logger.WithPrefix(logf, "forward: "),
linkMon: linkMon,
linkSel: linkSel,
dialer: dialer,
dohSem: make(chan struct{}, maxDoHInFlight(runtime.GOOS)),
}
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
return f
@ -601,17 +597,6 @@ type forwardQuery struct {
// ...
}
// forward forwards the query to all upstream nameservers and waits for
// the first response.
//
// It either sends to f.responses and returns nil, or returns a
// non-nil error (without sending to the channel).
func (f *forwarder) forward(query packet) error {
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
defer cancel()
return f.forwardWithDestChan(ctx, query, f.responses)
}
// forwardWithDestChan forwards the query to all upstream nameservers
// and waits for the first response.
//

@ -26,12 +26,9 @@ import (
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/net/dns/resolvconffile"
tspacket "tailscale.com/net/packet"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tsdial"
"tailscale.com/net/tstun"
"tailscale.com/types/dnstype"
"tailscale.com/types/ipproto"
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/util/dnsname"
@ -40,33 +37,11 @@ import (
const dnsSymbolicFQDN = "magicdns.localhost-tailscale-daemon."
var (
magicDNSIP = tsaddr.TailscaleServiceIP()
magicDNSIPv6 = tsaddr.TailscaleServiceIPv6()
)
const magicDNSPort = 53
// maxResponseBytes is the maximum size of a response from a Resolver. The
// actual buffer size will be one larger than this so that we can detect
// truncation in a platform-agnostic way.
const maxResponseBytes = 4095
// maxActiveQueries returns the maximal number of DNS requests that be
// can running.
// If EnqueueRequest is called when this many requests are already pending,
// the request will be dropped to avoid blocking the caller.
func maxActiveQueries() int32 {
if runtime.GOOS == "ios" {
// For memory paranoia reasons on iOS, match the
// historical Tailscale 1.x..1.8 behavior for now
// (just before the 1.10 release).
return 64
}
// But for other platforms, allow more burstiness:
return 256
}
// defaultTTL is the TTL of all responses from Resolver.
const defaultTTL = 600 * time.Second
@ -74,7 +49,6 @@ const defaultTTL = 600 * time.Second
var ErrClosed = errors.New("closed")
var (
errFullQueue = errors.New("request queue full")
errNotQuery = errors.New("not a DNS query")
errNotOurName = errors.New("not a Tailscale DNS name")
)
@ -209,12 +183,6 @@ type Resolver struct {
// forwarder forwards requests to upstream nameservers.
forwarder *forwarder
activeQueriesAtomic int32 // number of DNS queries in flight
// responses is an unbuffered channel to which responses are returned.
responses chan packet
// errors is an unbuffered channel to which errors are returned.
errors chan error
// closed signals all goroutines to stop.
closed chan struct{}
// wg signals when all goroutines have stopped.
@ -241,16 +209,14 @@ func New(logf logger.Logf, linkMon *monitor.Mon, linkSel ForwardLinkSelector, di
panic("nil Dialer")
}
r := &Resolver{
logf: logger.WithPrefix(logf, "resolver: "),
linkMon: linkMon,
responses: make(chan packet),
errors: make(chan error),
closed: make(chan struct{}),
hostToIP: map[dnsname.FQDN][]netaddr.IP{},
ipToHost: map[netaddr.IP]dnsname.FQDN{},
dialer: dialer,
}
r.forwarder = newForwarder(r.logf, r.responses, linkMon, linkSel, dialer)
logf: logger.WithPrefix(logf, "resolver: "),
linkMon: linkMon,
closed: make(chan struct{}),
hostToIP: map[dnsname.FQDN][]netaddr.IP{},
ipToHost: map[netaddr.IP]dnsname.FQDN{},
dialer: dialer,
}
r.forwarder = newForwarder(r.logf, linkMon, linkSel, dialer)
return r
}
@ -293,94 +259,29 @@ func (r *Resolver) Close() {
r.forwarder.Close()
}
// EnqueuePacket handles a packet to the magicDNS endpoint.
// It takes ownership of the payload and does not block.
// If the queue is full, the request will be dropped and an error will be returned.
func (r *Resolver) EnqueuePacket(bs []byte, proto ipproto.Proto, from, to netaddr.IPPort) error {
if to.Port() != magicDNSPort || proto != ipproto.UDP {
return nil
}
return r.enqueueRequest(bs, proto, from, to)
}
// enqueueRequest places the given DNS request in the resolver's queue.
// If the queue is full, the request will be dropped and an error will be returned.
func (r *Resolver) enqueueRequest(bs []byte, proto ipproto.Proto, from, to netaddr.IPPort) error {
func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) {
metricDNSQueryLocal.Add(1)
select {
case <-r.closed:
metricDNSQueryErrorClosed.Add(1)
return ErrClosed
return nil, ErrClosed
default:
}
if n := atomic.AddInt32(&r.activeQueriesAtomic, 1); n > maxActiveQueries() {
atomic.AddInt32(&r.activeQueriesAtomic, -1)
metricDNSQueryErrorQueue.Add(1)
return errFullQueue
}
go r.handleQuery(packet{bs, from})
return nil
}
// NextPacket returns the next packet to service traffic for magicDNS. The returned
// packet is prefixed with unused space consistent with the semantics of injection
// into tstun.Wrapper.
// It blocks until a response is available and gives up ownership of the response payload.
func (r *Resolver) NextPacket() (ipPacket []byte, err error) {
bs, to, err := r.nextResponse()
if err != nil {
return nil, err
}
// Unused space is needed further down the stack. To avoid extra
// allocations/copying later on, we allocate such space here.
const offset = tstun.PacketStartOffset
var buf []byte
switch {
case to.IP().Is4():
h := tspacket.UDP4Header{
IP4Header: tspacket.IP4Header{
Src: magicDNSIP,
Dst: to.IP(),
},
SrcPort: magicDNSPort,
DstPort: to.Port(),
}
hlen := h.Len()
buf = make([]byte, offset+hlen+len(bs))
copy(buf[offset+hlen:], bs)
h.Marshal(buf[offset:])
case to.IP().Is6():
h := tspacket.UDP6Header{
IP6Header: tspacket.IP6Header{
Src: magicDNSIPv6,
Dst: to.IP(),
},
SrcPort: magicDNSPort,
DstPort: to.Port(),
out, err := r.respond(bs)
if err == errNotOurName {
responses := make(chan packet, 1)
ctx, cancel := context.WithCancel(ctx)
defer close(responses)
defer cancel()
err = r.forwarder.forwardWithDestChan(ctx, packet{bs, from}, responses)
if err != nil {
return nil, err
}
hlen := h.Len()
buf = make([]byte, offset+hlen+len(bs))
copy(buf[offset+hlen:], bs)
h.Marshal(buf[offset:])
return (<-responses).bs, nil
}
return buf, nil
}
// nextResponse returns a DNS response to a previously enqueued request.
// It blocks until a response is available and gives up ownership of the response payload.
func (r *Resolver) nextResponse() (packet []byte, to netaddr.IPPort, err error) {
select {
case <-r.closed:
return nil, netaddr.IPPort{}, ErrClosed
case resp := <-r.responses:
return resp.bs, resp.addr, nil
case err := <-r.errors:
return nil, netaddr.IPPort{}, err
}
return out, err
}
// parseExitNodeQuery parses a DNS request packet.
@ -808,30 +709,6 @@ func (r *Resolver) fqdnForIPLocked(ip netaddr.IP, name dnsname.FQDN) (dnsname.FQ
return ret, dns.RCodeSuccess
}
func (r *Resolver) handleQuery(pkt packet) {
defer atomic.AddInt32(&r.activeQueriesAtomic, -1)
out, err := r.respond(pkt.bs)
if err == errNotOurName {
err = r.forwarder.forward(pkt)
if err == nil {
// forward will send response into r.responses, nothing to do.
return
}
}
if err != nil {
select {
case <-r.closed:
case r.errors <- err:
}
} else {
select {
case <-r.closed:
case r.responses <- packet{out, pkt.addr}:
}
}
}
type response struct {
Header dns.Header
Question dns.Question
@ -1316,7 +1193,6 @@ func unARPA(a string) (ipStr string, ok bool) {
var (
metricDNSQueryLocal = clientmetric.NewCounter("dns_query_local")
metricDNSQueryErrorClosed = clientmetric.NewCounter("dns_query_local_error_closed")
metricDNSQueryErrorQueue = clientmetric.NewCounter("dns_query_local_error_queue")
metricDNSErrorParseNoQ = clientmetric.NewCounter("dns_query_respond_error_no_question")
metricDNSErrorParseQuery = clientmetric.NewCounter("dns_query_respond_error_parse")

@ -27,7 +27,6 @@ import (
"tailscale.com/net/tsdial"
"tailscale.com/tstest"
"tailscale.com/types/dnstype"
"tailscale.com/types/ipproto"
"tailscale.com/util/dnsname"
"tailscale.com/wgengine/monitor"
)
@ -234,11 +233,7 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
}
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
if err := r.enqueueRequest(query, ipproto.UDP, netaddr.IPPort{}, magicDNSv4Port); err != nil {
return nil, fmt.Errorf("enqueueRequest: %w", err)
}
payload, _, err := r.nextResponse()
return payload, err
return r.Query(context.Background(), query, netaddr.IPPort{})
}
func mustIP(str string) netaddr.IP {
@ -708,78 +703,6 @@ func TestDelegateSplitRoute(t *testing.T) {
}
}
func TestDelegateCollision(t *testing.T) {
server := serveDNS(t, "127.0.0.1:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown()
r := newResolver(t)
defer r.Close()
cfg := dnsCfg
cfg.Routes = map[dnsname.FQDN][]dnstype.Resolver{
".": {{Addr: server.PacketConn.LocalAddr().String()}},
}
r.SetConfig(cfg)
packets := []struct {
qname dnsname.FQDN
qtype dns.Type
addr netaddr.IPPort
}{
{"test.site.", dns.TypeA, netaddr.IPPortFrom(netaddr.IPv4(1, 1, 1, 1), 1001)},
{"test.site.", dns.TypeAAAA, netaddr.IPPortFrom(netaddr.IPv4(1, 1, 1, 1), 1002)},
}
// packets will have the same dns txid.
for _, p := range packets {
payload := dnspacket(p.qname, p.qtype, noEdns)
err := r.enqueueRequest(payload, ipproto.UDP, p.addr, magicDNSv4Port)
if err != nil {
t.Error(err)
}
}
// Despite the txid collision, the answer(s) should still match the query.
resp, addr, err := r.nextResponse()
if err != nil {
t.Error(err)
}
var p dns.Parser
_, err = p.Start(resp)
if err != nil {
t.Error(err)
}
err = p.SkipAllQuestions()
if err != nil {
t.Error(err)
}
ans, err := p.AllAnswers()
if err != nil {
t.Error(err)
}
if len(ans) == 0 {
t.Fatal("no answers")
}
var wantType dns.Type
switch ans[0].Body.(type) {
case *dns.AResource:
wantType = dns.TypeA
case *dns.AAAAResource:
wantType = dns.TypeAAAA
default:
t.Errorf("unexpected answer type: %T", ans[0].Body)
}
for _, p := range packets {
if p.qtype == wantType && p.addr != addr {
t.Errorf("addr = %v; want %v", addr, p.addr)
}
}
}
var allResponse = []byte{
0x00, 0x00, // transaction id: 0
0x84, 0x00, // flags: response, authoritative, no error
@ -1076,7 +999,7 @@ func TestForwardLinkSelection(t *testing.T) {
// routes differently.
specialIP := netaddr.IPv4(1, 2, 3, 4)
fwd := newForwarder(t.Logf, nil, nil, linkSelFunc(func(ip netaddr.IP) string {
fwd := newForwarder(t.Logf, nil, linkSelFunc(func(ip netaddr.IP) string {
if ip == netaddr.IPv4(1, 2, 3, 4) {
return "special"
}

Loading…
Cancel
Save