diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index b747fb153..0ee0751ba 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -518,6 +518,8 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe return f.sendUDP(ctx, fq, rr) } +var errServerFailure = errors.New("response code indicates server issue") + func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { ipp, ok := rr.name.IPPort() if !ok { @@ -581,7 +583,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn if rcode == dns.RCodeServerFailure { f.logf("recv: response code indicating server failure: %d", rcode) metricDNSFwdUDPErrorServer.Add(1) - return nil, errors.New("response code indicates server issue") + return nil, errServerFailure } if truncated { @@ -751,6 +753,20 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } numErr++ if numErr == len(resolvers) { + if firstErr == errServerFailure { + res, err := servfailResponse(query) + if err != nil { + f.logf("building servfail response: %v", err) + return firstErr + } + + select { + case <-ctx.Done(): + metricDNSFwdErrorContext.Add(1) + metricDNSFwdErrorContextGotError.Add(1) + case responseChan <- res: + } + } return firstErr } case <-ctx.Done(): @@ -809,6 +825,27 @@ func nxDomainResponse(req packet) (res packet, err error) { return res, err } +// servfailResponse returns a SERVFAIL error reply for the provided request. 
+func servfailResponse(req packet) (res packet, err error) {
+	// Borrow a pooled parser for the duration of this call.
+	p := dnsParserPool.Get().(*dnsParser)
+	defer dnsParserPool.Put(p)
+
+	// A request that cannot be parsed cannot be answered: report the
+	// parse error together with an empty packet.
+	if err := p.parseQuery(req.bs); err != nil {
+		return packet{}, err
+	}
+
+	// Start from the parsed request header and turn it into a SERVFAIL
+	// response.
+	h := p.Header
+	h.Response = true
+	h.Authoritative = true
+	h.RCode = dns.RCodeServerFailure
+
+	// Echo the original question back. Every builder step is checked so
+	// that a half-built message is never handed out as a reply.
+	b := dns.NewBuilder(nil, h)
+	if err := b.StartQuestions(); err != nil {
+		return packet{}, err
+	}
+	if err := b.Question(p.Question); err != nil {
+		return packet{}, err
+	}
+	bs, err := b.Finish()
+	if err != nil {
+		return packet{}, err
+	}
+
+	// The reply is addressed to whoever sent the request.
+	return packet{bs: bs, addr: req.addr}, nil
+}
+
 // closePool is a dynamic set of io.Closers to close as a group.
 // It's intended to be Closed at most once.
 //
diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go
index 05df02484..625493da6 100644
--- a/net/dns/resolver/tsdns.go
+++ b/net/dns/resolver/tsdns.go
@@ -282,7 +282,15 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([
 	defer cancel()
 	err = r.forwarder.forwardWithDestChan(ctx, packet{bs, from}, responses)
 	if err != nil {
-		return nil, err
+		select {
+		// Best effort: use any error response sent by forwardWithDestChan.
+		// This is present in some error paths, such as when all upstream
+		// DNS servers replied with an error.
+		case resp := <-responses:
+			return resp.bs, err
+		default:
+			return nil, err
+		}
 	}
 	return (<-responses).bs, nil
 }
diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go
index cb8a000ed..4ca528654 100644
--- a/net/dns/resolver/tsdns_test.go
+++ b/net/dns/resolver/tsdns_test.go
@@ -1417,3 +1417,45 @@ func TestUnARPA(t *testing.T) {
 		}
 	}
 }
+
+// TestServfail validates that a SERVFAIL error response is returned if
+// all upstream resolvers respond with SERVFAIL.
+//
+// See: https://github.com/tailscale/tailscale/issues/4722
+func TestServfail(t *testing.T) {
+	// Upstream stub: answer every query with SERVFAIL. SetRcode derives the
+	// reply from the request, so the reply does not depend on the query
+	// happening to use transaction ID 0.
+	server := serveDNS(t, "127.0.0.1:0", "test.site.", miekdns.HandlerFunc(func(w miekdns.ResponseWriter, req *miekdns.Msg) {
+		m := new(miekdns.Msg)
+		m.SetRcode(req, miekdns.RcodeServerFailure)
+		if err := w.WriteMsg(m); err != nil {
+			t.Errorf("writing SERVFAIL reply: %v", err)
+		}
+	}))
+	defer server.Shutdown()
+
+	r := newResolver(t)
+	defer r.Close()
+
+	// Route the root domain to the stub so it is the only upstream
+	// resolver configured for this query.
+	cfg := dnsCfg
+	cfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{
+		".": {{Addr: server.PacketConn.LocalAddr().String()}},
+	}
+	r.SetConfig(cfg)
+
+	// The sentinel error is expected alongside a synthesized SERVFAIL
+	// packet, not instead of it.
+	pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns))
+	if err != errServerFailure {
+		t.Errorf("err = %v, want %v", err, errServerFailure)
+	}
+
+	wantPkt := []byte{
+		0x00, 0x00, // transaction id: 0
+		0x84, 0x02, // flags: response, authoritative, error: servfail
+		0x00, 0x01, // one question
+		0x00, 0x00, // no answers
+		0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
+		// Question:
+		0x04, 0x74, 0x65, 0x73, 0x74, 0x04, 0x73, 0x69, 0x74, 0x65, 0x00, // name
+		0x00, 0x01, 0x00, 0x01, // type A, class IN
+	}
+
+	if !bytes.Equal(pkt, wantPkt) {
+		t.Errorf("response was %X, want %X", pkt, wantPkt)
+	}
+}