From bc34788e653496d7618c42c09ebd908f67cac7d0 Mon Sep 17 00:00:00 2001 From: Dmytro Shynkevych Date: Thu, 27 Aug 2020 00:07:15 -0400 Subject: [PATCH] tsdns: fix accidental rejection of all non-{A, AAAA} questions. This is a bug introduced in a903d6c2ed4d068cf2d212d2541cddeeaa1aca43. Signed-off-by: Dmytro Shynkevych --- wgengine/tsdns/tsdns.go | 26 +++--- wgengine/tsdns/tsdns_server_test.go | 23 +++-- wgengine/tsdns/tsdns_test.go | 125 ++++++++++++++++++---------- 3 files changed, 110 insertions(+), 64 deletions(-) diff --git a/wgengine/tsdns/tsdns.go b/wgengine/tsdns/tsdns.go index 930f0f8ea..7376168bf 100644 --- a/wgengine/tsdns/tsdns.go +++ b/wgengine/tsdns/tsdns.go @@ -184,7 +184,7 @@ func (r *Resolver) NextResponse() (Packet, error) { // Resolve maps a given domain name to the IP address of the host that owns it. // The domain name must be in canonical form (with a trailing period). -func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { +func (r *Resolver) Resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) { r.mu.Lock() dnsMap := r.dnsMap r.mu.Unlock() @@ -208,7 +208,13 @@ func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { if !found { return netaddr.IP{}, dns.RCodeNameError, nil } - return addr, dns.RCodeSuccess, nil + + switch tp { + case dns.TypeA, dns.TypeAAAA, dns.TypeALL: + return addr, dns.RCodeSuccess, nil + default: + return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented + } } // ResolveReverse returns the unique domain name that maps to the given address. @@ -501,7 +507,6 @@ func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([] // It is more likely that we failed in parsing the name than that it is actually malformed. // To avoid frustrating users, just log and delegate. if !ok { - // Without this conversion, escape analysis rules that resp escapes. r.logf("parsing rdns: malformed name: %s", name) return nil, errNotOurName } @@ -542,17 +547,12 @@ func (r *Resolver) respond(query []byte) ([]byte, error) { return r.respondReverse(query, name, resp) } - switch resp.Question.Type { - case dns.TypeA, dns.TypeAAAA, dns.TypeALL: - resp.IP, resp.Header.RCode, err = r.Resolve(name) - // This return code is special: it requests forwarding. - if resp.Header.RCode == dns.RCodeRefused { - return nil, errNotOurName - } - default: - resp.Header.RCode = dns.RCodeNotImplemented - err = errNotImplemented + resp.IP, resp.Header.RCode, err = r.Resolve(name, resp.Question.Type) + // This return code is special: it requests forwarding. + if resp.Header.RCode == dns.RCodeRefused { + return nil, errNotOurName } + // We will not return this error: it is the sender's fault. if err != nil { r.logf("resolving: %v", err) diff --git a/wgengine/tsdns/tsdns_server_test.go b/wgengine/tsdns/tsdns_server_test.go index ae2a86f19..bffb8b869 100644 --- a/wgengine/tsdns/tsdns_server_test.go +++ b/wgengine/tsdns/tsdns_server_test.go @@ -16,9 +16,10 @@ import ( var dnsHandleFunc = dns.HandleFunc // resolveToIP returns a handler function which responds -// to queries of type A it receives with an A record containing ipv4 -// and to queries of type AAAA with an AAAA records containing ipv6. -func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc { +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containg name. +func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc { return func(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) @@ -29,7 +30,8 @@ func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc { question := req.Question[0] var ans dns.RR - if question.Qtype == dns.TypeA { + switch question.Qtype { + case dns.TypeA: ans = &dns.A{ Hdr: dns.RR_Header{ Name: question.Name, @@ -38,7 +40,7 @@ func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc { }, A: ipv4.IPAddr().IP, } - } else { + case dns.TypeAAAA: ans = &dns.AAAA{ Hdr: dns.RR_Header{ Name: question.Name, @@ -47,9 +49,18 @@ func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc { }, AAAA: ipv6.IPAddr().IP, } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } } - m.Answer = append(m.Answer, ans) + m.Answer = append(m.Answer, ans) w.WriteMsg(m) } } diff --git a/wgengine/tsdns/tsdns_test.go b/wgengine/tsdns/tsdns_test.go index fcdb8ef18..427683c91 100644 --- a/wgengine/tsdns/tsdns_test.go +++ b/wgengine/tsdns/tsdns_test.go @@ -48,49 +48,64 @@ func dnspacket(domain string, tp dns.Type) []byte { return payload } -func extractipcode(response []byte) (netaddr.IP, dns.RCode, error) { - var ip netaddr.IP +type dnsResponse struct { + ip netaddr.IP + name string + rcode dns.RCode +} + +func unpackResponse(payload []byte) (dnsResponse, error) { + var response dnsResponse var parser dns.Parser - h, err := parser.Start(response) + h, err := parser.Start(payload) if err != nil { - return ip, 0, err + return response, err } if !h.Response { - return ip, 0, errors.New("not a response") + return response, errors.New("not a response") } - if h.RCode != dns.RCodeSuccess { - return ip, h.RCode, nil + + response.rcode = h.RCode + if response.rcode != dns.RCodeSuccess { + return response, nil } err = parser.SkipAllQuestions() if err != nil { - return ip, 0, err + return response, err } ah, err := parser.AnswerHeader() if err != nil { - return ip, 0, err + return response, err } + switch ah.Type { case dns.TypeA: res, err := parser.AResource() if err != nil { - return ip, 0, err + return response, err } - ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) + response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) case dns.TypeAAAA: res, err := parser.AAAAResource() if err != nil { - return ip, 0, err + return response, err } - ip = netaddr.IPv6Raw(res.AAAA) + response.ip = netaddr.IPv6Raw(res.AAAA) + case dns.TypeNS: + res, err := parser.NSResource() + if err != nil { + return response, err + } + response.name = res.NS.String() default: - return ip, 0, errors.New("type not in {A, AAAA}") + return response, errors.New("type not in {A, AAAA, NS}") } - return ip, h.RCode, nil + return response, nil } func syncRespond(r *Resolver, query []byte) ([]byte, error) { @@ -188,20 +203,21 @@ func TestResolve(t *testing.T) { defer r.Close() tests := []struct { - name string - domain string - ip netaddr.IP - code dns.RCode + name string + qname string + qtype dns.Type + ip netaddr.IP + code dns.RCode }{ - {"ipv4", "test1.ipn.dev.", testipv4, dns.RCodeSuccess}, - {"ipv6", "test2.ipn.dev.", testipv6, dns.RCodeSuccess}, - {"nxdomain", "test3.ipn.dev.", netaddr.IP{}, dns.RCodeNameError}, - {"foreign domain", "google.com.", netaddr.IP{}, dns.RCodeRefused}, + {"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess}, + {"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess}, + {"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError}, + {"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ip, code, err := r.Resolve(tt.domain) + ip, code, err := r.Resolve(tt.qname, tt.qtype) if err != nil { t.Errorf("err = %v; want nil", err) } @@ -256,7 +272,7 @@ func TestDelegate(t *testing.T) { rc := tstest.NewResourceCheck() defer rc.Assert(t) - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN) v4server, v4errch := serveDNS("127.0.0.1:0") @@ -296,40 +312,59 @@ func TestDelegate(t *testing.T) { defer r.Close() tests := []struct { - name string - query []byte - ip netaddr.IP - code dns.RCode + title string + query []byte + response dnsResponse }{ - {"ipv4", dnspacket("test.site.", dns.TypeA), testipv4, dns.RCodeSuccess}, - {"ipv6", dnspacket("test.site.", dns.TypeAAAA), testipv6, dns.RCodeSuccess}, - {"nxdomain", dnspacket("nxdomain.site.", dns.TypeA), netaddr.IP{}, dns.RCodeNameError}, + { + "ipv4", + dnspacket("test.site.", dns.TypeA), + dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess}, + }, + { + "ipv6", + dnspacket("test.site.", dns.TypeAAAA), + dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess}, + }, + { + "ns", + dnspacket("test.site.", dns.TypeNS), + dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess}, + }, + { + "nxdomain", + dnspacket("nxdomain.site.", dns.TypeA), + dnsResponse{rcode: dns.RCodeNameError}, + }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - resp, err := syncRespond(r, tt.query) + t.Run(tt.title, func(t *testing.T) { + payload, err := syncRespond(r, tt.query) if err != nil { t.Errorf("err = %v; want nil", err) return } - ip, code, err := extractipcode(resp) + response, err := unpackResponse(payload) if err != nil { - t.Errorf("extract: err = %v; want nil (in %x)", err, resp) + t.Errorf("extract: err = %v; want nil (in %x)", err, payload) return } - if code != tt.code { - t.Errorf("code = %v; want %v", code, tt.code) + if response.rcode != tt.response.rcode { + t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode) } - if ip != tt.ip { - t.Errorf("ip = %v; want %v", ip, tt.ip) + if response.ip != tt.response.ip { + t.Errorf("ip = %v; want %v", response.ip, tt.response.ip) + } + if response.name != tt.response.name { + t.Errorf("name = %v; want %v", response.name, tt.response.name) } }) } } func TestDelegateCollision(t *testing.T) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) server, errch := serveDNS("127.0.0.1:0") defer func() { @@ -425,13 +460,13 @@ func TestConcurrentSetMap(t *testing.T) { }() go func() { defer wg.Done() - r.Resolve("test1.ipn.dev") + r.Resolve("test1.ipn.dev", dns.TypeA) }() wg.Wait() } func TestConcurrentSetUpstreams(t *testing.T) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) server, errch := serveDNS("127.0.0.1:0") defer func() { @@ -570,7 +605,7 @@ func TestFull(t *testing.T) { {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response}, {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse}, {"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, - {"error", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, + {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, } for _, tt := range tests { @@ -619,7 +654,7 @@ func TestAllocs(t *testing.T) { } func BenchmarkFull(b *testing.B) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) server, errch := serveDNS("127.0.0.1:0") defer func() {