From 25525b7754937dd87c04a6390aa35831761f60c3 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 23 Nov 2021 09:58:34 -0800 Subject: [PATCH] net/dns/resolver, ipn/ipnlocal: wire up peerapi DoH server to DNS forwarder Updates #1713 Change-Id: Ia4ed9d8c9cef0e70aa6d30f2852eaab80f5f695a Signed-off-by: Brad Fitzpatrick --- ipn/ipnlocal/peerapi.go | 122 +++++++++++++++++++++++++++++++++- net/dns/resolver/forwarder.go | 28 ++++++-- net/dns/resolver/tsdns.go | 53 +++++++++++++++ 3 files changed, 196 insertions(+), 7 deletions(-) diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index d2b24e5c0..0a7615471 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -29,6 +29,7 @@ import ( "unicode" "unicode/utf8" + "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" "tailscale.com/client/tailscale/apitype" "tailscale.com/hostinfo" @@ -767,6 +768,8 @@ func (h *peerAPIHandler) replyToDNSQueries() bool { return h.isSelf || h.ps.b.OfferingExitNode() } +// handleDNSQuery implements a DoH server (RFC 8484) over the peerapi. +// It's not over HTTPS as the spec dictates, but rather HTTP-over-WireGuard. func (h *peerAPIHandler) handleDNSQuery(w http.ResponseWriter, r *http.Request) { if h.ps.resolver == nil { http.Error(w, "DNS not wired up", http.StatusNotImplemented) @@ -776,13 +779,45 @@ func (h *peerAPIHandler) handleDNSQuery(w http.ResponseWriter, r *http.Request) http.Error(w, "DNS access denied", http.StatusForbidden) return } + pretty := false // non-DoH debug mode for humans q, publicError := dohQuery(r) + if publicError != "" && r.Method == "GET" { + if name := r.FormValue("q"); name != "" { + pretty = true + publicError = "" + q = dnsQueryForName(name, r.FormValue("t")) + } + } if publicError != "" { http.Error(w, publicError, http.StatusBadRequest) return } - // TODO(bradfitz): owl. - fmt.Fprintf(w, "## TODO: got %d bytes of DNS query", len(q)) + + // Some timeout that's short enough to be noticed by humans + // but long enough that it's longer than real DNS timeouts. + const arbitraryTimeout = 5 * time.Second + + ctx, cancel := context.WithTimeout(r.Context(), arbitraryTimeout) + defer cancel() + res, err := h.ps.resolver.HandleExitNodeDNSQuery(ctx, q, h.remoteAddr) + if err != nil { + h.logf("handleDNS fwd error: %v", err) + if err := ctx.Err(); err != nil { + http.Error(w, err.Error(), 500) + } else { + http.Error(w, "DNS forwarding error", 500) + } + return + } + if pretty { + // Non-standard response for interactive debugging. + w.Header().Set("Content-Type", "application/json") + writePrettyDNSReply(w, res) + return + } + w.Header().Set("Content-Type", "application/dns-message") + w.Header().Set("Content-Length", strconv.Itoa(len(q))) + w.Write(res) } func dohQuery(r *http.Request) (dnsQuery []byte, publicErr string) { @@ -817,3 +852,86 @@ func dohQuery(r *http.Request) (dnsQuery []byte, publicErr string) { return q, "" } } + +func dnsQueryForName(name, typStr string) []byte { + typ := dnsmessage.TypeA + switch strings.ToLower(typStr) { + case "aaaa": + typ = dnsmessage.TypeAAAA + case "txt": + typ = dnsmessage.TypeTXT + } + b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ + OpCode: 0, // query + RecursionDesired: true, + ID: 0, + }) + if !strings.HasSuffix(name, ".") { + name += "." + } + b.StartQuestions() + b.Question(dnsmessage.Question{ + Name: dnsmessage.MustNewName(name), + Type: typ, + Class: dnsmessage.ClassINET, + }) + msg, _ := b.Finish() + return msg +} + +func writePrettyDNSReply(w io.Writer, res []byte) (err error) { + defer func() { + if err != nil { + j, _ := json.Marshal(struct { + Error string + }{err.Error()}) + w.Write(j) + return + } + }() + var p dnsmessage.Parser + if _, err := p.Start(res); err != nil { + return err + } + if err := p.SkipAllQuestions(); err != nil { + return err + } + + var gotIPs []string + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return err + } + if h.Class != dnsmessage.ClassINET { + continue + } + switch h.Type { + case dnsmessage.TypeA: + r, err := p.AResource() + if err != nil { + return err + } + gotIPs = append(gotIPs, net.IP(r.A[:]).String()) + case dnsmessage.TypeAAAA: + r, err := p.AAAAResource() + if err != nil { + return err + } + gotIPs = append(gotIPs, net.IP(r.AAAA[:]).String()) + case dnsmessage.TypeTXT: + r, err := p.TXTResource() + if err != nil { + return err + } + gotIPs = append(gotIPs, r.TXT...) + } + } + j, _ := json.Marshal(gotIPs) + j = append(j, '\n') + w.Write(j) + return nil +} diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 15ca87cdf..98c6e1e36 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -546,8 +546,26 @@ type forwardQuery struct { // ... } -// forward forwards the query to all upstream nameservers and returns the first response. +// forward forwards the query to all upstream nameservers and waits for +// the first response. +// +// It either sends to f.responses and returns nil, or returns a +// non-nil error (without sending to the channel). func (f *forwarder) forward(query packet) error { + ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) + defer cancel() + return f.forwardWithDestChan(ctx, query, f.responses) +} + +// forward forwards the query to all upstream nameservers and waits +// for the first response. +// +// It either sends to responseChan and returns nil, or returns a +// non-nil error (without sending to the channel). +// +// If backupResolvers are specified, they're used in the case that no +// upstreams are available. +func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, responseChan chan<- packet, backupResolvers ...resolverAndDelay) error { domain, err := nameFromQuery(query.bs) if err != nil { return err @@ -564,6 +582,9 @@ func (f *forwarder) forward(query packet) error { clampEDNSSize(query.bs, maxResponseBytes) resolvers := f.resolvers(domain) + if len(resolvers) == 0 { + resolvers = backupResolvers + } if len(resolvers) == 0 { return errNoUpstreams } @@ -575,9 +596,6 @@ func (f *forwarder) forward(query packet) error { } defer fq.closeOnCtxDone.Close() - ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) - defer cancel() - resc := make(chan []byte, 1) var ( mu sync.Mutex @@ -616,7 +634,7 @@ func (f *forwarder) forward(query packet) error { select { case <-ctx.Done(): return ctx.Err() - case f.responses <- packet{v, query.addr}: + case responseChan <- packet{v, query.addr}: return nil } case <-ctx.Done(): diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index b4d4b7c1a..f68247af3 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -8,6 +8,7 @@ package resolver import ( "bufio" + "context" "encoding/hex" "errors" "fmt" @@ -298,6 +299,58 @@ func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error) } } +// HandleExitNodeDNSQuery handles a DNS query that arrived from a peer +// via the peerapi's DoH server. This is only used when the local +// node is being an exit node. +func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from netaddr.IPPort) (res []byte, err error) { + ch := make(chan packet, 1) + + err = r.forwarder.forwardWithDestChan(ctx, packet{q, from}, ch) + if err == errNoUpstreams { + // Handle to the system resolver. + switch runtime.GOOS { + case "linux": + // Assume for now that we don't have an upstream because + // they're using systemd-resolved and we're in Split DNS mode + // where we don't know the base config. + // + // TODO(bradfitz): this is a lazy assumption. Do better, and + // maybe move the HandleExitNodeDNSQuery method to the dns.Manager + // instead? But this works for now. + err = r.forwarder.forwardWithDestChan(ctx, packet{q, from}, ch, resolverAndDelay{ + name: dnstype.Resolver{ + Addr: "127.0.0.1:53", + }, + }) + default: + // TODO(bradfitz): if we're on an exit node + // on, say, Windows, we need to parse the DNS + // packet in q and call OS-native APIs for + // each question. But we'll want to strip out + // questions for MagicDNS names probably, so + // they don't loop back into + // 100.100.100.100. We don't want to resolve + // MagicDNS names across Tailnets once we + // permit sharing exit nodes. + // + // For now, just return an error. + return nil, err + } + } + if err != nil { + return nil, err + } + select { + case p, ok := <-ch: + if ok { + return p.bs, nil + } + panic("unexpected close chan") + default: + panic("unexpected unreadable chan") + } +} + // resolveLocal returns an IP for the given domain, if domain is in // the local hosts map and has an IP corresponding to the requested // typ (A, AAAA, ALL).