tsdns: fix accidental rejection of all non-{A, AAAA} questions.

This is a bug introduced in a903d6c2ed.

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
reviewable/pr720/r1
Dmytro Shynkevych 4 years ago
parent 28f9cd06f5
commit bc34788e65
No known key found for this signature in database
GPG Key ID: FF5E2F3DAD97EA23

@ -184,7 +184,7 @@ func (r *Resolver) NextResponse() (Packet, error) {
// Resolve maps a given domain name to the IP address of the host that owns it. // Resolve maps a given domain name to the IP address of the host that owns it.
// The domain name must be in canonical form (with a trailing period). // The domain name must be in canonical form (with a trailing period).
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { func (r *Resolver) Resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) {
r.mu.Lock() r.mu.Lock()
dnsMap := r.dnsMap dnsMap := r.dnsMap
r.mu.Unlock() r.mu.Unlock()
@ -208,7 +208,13 @@ func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
if !found { if !found {
return netaddr.IP{}, dns.RCodeNameError, nil return netaddr.IP{}, dns.RCodeNameError, nil
} }
return addr, dns.RCodeSuccess, nil
switch tp {
case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
return addr, dns.RCodeSuccess, nil
default:
return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented
}
} }
// ResolveReverse returns the unique domain name that maps to the given address. // ResolveReverse returns the unique domain name that maps to the given address.
@ -501,7 +507,6 @@ func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]
// It is more likely that we failed in parsing the name than that it is actually malformed. // It is more likely that we failed in parsing the name than that it is actually malformed.
// To avoid frustrating users, just log and delegate. // To avoid frustrating users, just log and delegate.
if !ok { if !ok {
// Without this conversion, escape analysis rules that resp escapes.
r.logf("parsing rdns: malformed name: %s", name) r.logf("parsing rdns: malformed name: %s", name)
return nil, errNotOurName return nil, errNotOurName
} }
@ -542,17 +547,12 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
return r.respondReverse(query, name, resp) return r.respondReverse(query, name, resp)
} }
switch resp.Question.Type { resp.IP, resp.Header.RCode, err = r.Resolve(name, resp.Question.Type)
case dns.TypeA, dns.TypeAAAA, dns.TypeALL: // This return code is special: it requests forwarding.
resp.IP, resp.Header.RCode, err = r.Resolve(name) if resp.Header.RCode == dns.RCodeRefused {
// This return code is special: it requests forwarding. return nil, errNotOurName
if resp.Header.RCode == dns.RCodeRefused {
return nil, errNotOurName
}
default:
resp.Header.RCode = dns.RCodeNotImplemented
err = errNotImplemented
} }
// We will not return this error: it is the sender's fault. // We will not return this error: it is the sender's fault.
if err != nil { if err != nil {
r.logf("resolving: %v", err) r.logf("resolving: %v", err)

@ -16,9 +16,10 @@ import (
var dnsHandleFunc = dns.HandleFunc var dnsHandleFunc = dns.HandleFunc
// resolveToIP returns a handler function which responds // resolveToIP returns a handler function which responds
// to queries of type A it receives with an A record containing ipv4 // to queries of type A it receives with an A record containing ipv4,
// and to queries of type AAAA with an AAAA records containing ipv6. // to queries of type AAAA with an AAAA record containing ipv6,
func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc { // to queries of type NS with an NS record containg name.
func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc {
return func(w dns.ResponseWriter, req *dns.Msg) { return func(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(req) m.SetReply(req)
@ -29,7 +30,8 @@ func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc {
question := req.Question[0] question := req.Question[0]
var ans dns.RR var ans dns.RR
if question.Qtype == dns.TypeA { switch question.Qtype {
case dns.TypeA:
ans = &dns.A{ ans = &dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: question.Name, Name: question.Name,
@ -38,7 +40,7 @@ func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc {
}, },
A: ipv4.IPAddr().IP, A: ipv4.IPAddr().IP,
} }
} else { case dns.TypeAAAA:
ans = &dns.AAAA{ ans = &dns.AAAA{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: question.Name, Name: question.Name,
@ -47,9 +49,18 @@ func resolveToIP(ipv4, ipv6 netaddr.IP) dns.HandlerFunc {
}, },
AAAA: ipv6.IPAddr().IP, AAAA: ipv6.IPAddr().IP,
} }
case dns.TypeNS:
ans = &dns.NS{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
},
Ns: ns,
}
} }
m.Answer = append(m.Answer, ans)
m.Answer = append(m.Answer, ans)
w.WriteMsg(m) w.WriteMsg(m)
} }
} }

@ -48,49 +48,64 @@ func dnspacket(domain string, tp dns.Type) []byte {
return payload return payload
} }
func extractipcode(response []byte) (netaddr.IP, dns.RCode, error) { type dnsResponse struct {
var ip netaddr.IP ip netaddr.IP
name string
rcode dns.RCode
}
func unpackResponse(payload []byte) (dnsResponse, error) {
var response dnsResponse
var parser dns.Parser var parser dns.Parser
h, err := parser.Start(response) h, err := parser.Start(payload)
if err != nil { if err != nil {
return ip, 0, err return response, err
} }
if !h.Response { if !h.Response {
return ip, 0, errors.New("not a response") return response, errors.New("not a response")
} }
if h.RCode != dns.RCodeSuccess {
return ip, h.RCode, nil response.rcode = h.RCode
if response.rcode != dns.RCodeSuccess {
return response, nil
} }
err = parser.SkipAllQuestions() err = parser.SkipAllQuestions()
if err != nil { if err != nil {
return ip, 0, err return response, err
} }
ah, err := parser.AnswerHeader() ah, err := parser.AnswerHeader()
if err != nil { if err != nil {
return ip, 0, err return response, err
} }
switch ah.Type { switch ah.Type {
case dns.TypeA: case dns.TypeA:
res, err := parser.AResource() res, err := parser.AResource()
if err != nil { if err != nil {
return ip, 0, err return response, err
} }
ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3])
case dns.TypeAAAA: case dns.TypeAAAA:
res, err := parser.AAAAResource() res, err := parser.AAAAResource()
if err != nil { if err != nil {
return ip, 0, err return response, err
} }
ip = netaddr.IPv6Raw(res.AAAA) response.ip = netaddr.IPv6Raw(res.AAAA)
case dns.TypeNS:
res, err := parser.NSResource()
if err != nil {
return response, err
}
response.name = res.NS.String()
default: default:
return ip, 0, errors.New("type not in {A, AAAA}") return response, errors.New("type not in {A, AAAA, NS}")
} }
return ip, h.RCode, nil return response, nil
} }
func syncRespond(r *Resolver, query []byte) ([]byte, error) { func syncRespond(r *Resolver, query []byte) ([]byte, error) {
@ -188,20 +203,21 @@ func TestResolve(t *testing.T) {
defer r.Close() defer r.Close()
tests := []struct { tests := []struct {
name string name string
domain string qname string
ip netaddr.IP qtype dns.Type
code dns.RCode ip netaddr.IP
code dns.RCode
}{ }{
{"ipv4", "test1.ipn.dev.", testipv4, dns.RCodeSuccess}, {"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
{"ipv6", "test2.ipn.dev.", testipv6, dns.RCodeSuccess}, {"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess},
{"nxdomain", "test3.ipn.dev.", netaddr.IP{}, dns.RCodeNameError}, {"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
{"foreign domain", "google.com.", netaddr.IP{}, dns.RCodeRefused}, {"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ip, code, err := r.Resolve(tt.domain) ip, code, err := r.Resolve(tt.qname, tt.qtype)
if err != nil { if err != nil {
t.Errorf("err = %v; want nil", err) t.Errorf("err = %v; want nil", err)
} }
@ -256,7 +272,7 @@ func TestDelegate(t *testing.T) {
rc := tstest.NewResourceCheck() rc := tstest.NewResourceCheck()
defer rc.Assert(t) defer rc.Assert(t)
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN) dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
v4server, v4errch := serveDNS("127.0.0.1:0") v4server, v4errch := serveDNS("127.0.0.1:0")
@ -296,40 +312,59 @@ func TestDelegate(t *testing.T) {
defer r.Close() defer r.Close()
tests := []struct { tests := []struct {
name string title string
query []byte query []byte
ip netaddr.IP response dnsResponse
code dns.RCode
}{ }{
{"ipv4", dnspacket("test.site.", dns.TypeA), testipv4, dns.RCodeSuccess}, {
{"ipv6", dnspacket("test.site.", dns.TypeAAAA), testipv6, dns.RCodeSuccess}, "ipv4",
{"nxdomain", dnspacket("nxdomain.site.", dns.TypeA), netaddr.IP{}, dns.RCodeNameError}, dnspacket("test.site.", dns.TypeA),
dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
},
{
"ipv6",
dnspacket("test.site.", dns.TypeAAAA),
dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess},
},
{
"ns",
dnspacket("test.site.", dns.TypeNS),
dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess},
},
{
"nxdomain",
dnspacket("nxdomain.site.", dns.TypeA),
dnsResponse{rcode: dns.RCodeNameError},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.title, func(t *testing.T) {
resp, err := syncRespond(r, tt.query) payload, err := syncRespond(r, tt.query)
if err != nil { if err != nil {
t.Errorf("err = %v; want nil", err) t.Errorf("err = %v; want nil", err)
return return
} }
ip, code, err := extractipcode(resp) response, err := unpackResponse(payload)
if err != nil { if err != nil {
t.Errorf("extract: err = %v; want nil (in %x)", err, resp) t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return return
} }
if code != tt.code { if response.rcode != tt.response.rcode {
t.Errorf("code = %v; want %v", code, tt.code) t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
} }
if ip != tt.ip { if response.ip != tt.response.ip {
t.Errorf("ip = %v; want %v", ip, tt.ip) t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
}
if response.name != tt.response.name {
t.Errorf("name = %v; want %v", response.name, tt.response.name)
} }
}) })
} }
} }
func TestDelegateCollision(t *testing.T) { func TestDelegateCollision(t *testing.T) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS("127.0.0.1:0") server, errch := serveDNS("127.0.0.1:0")
defer func() { defer func() {
@ -425,13 +460,13 @@ func TestConcurrentSetMap(t *testing.T) {
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
r.Resolve("test1.ipn.dev") r.Resolve("test1.ipn.dev", dns.TypeA)
}() }()
wg.Wait() wg.Wait()
} }
func TestConcurrentSetUpstreams(t *testing.T) { func TestConcurrentSetUpstreams(t *testing.T) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS("127.0.0.1:0") server, errch := serveDNS("127.0.0.1:0")
defer func() { defer func() {
@ -570,7 +605,7 @@ func TestFull(t *testing.T) {
{"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response}, {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
{"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse}, {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
{"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, {"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
{"error", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
} }
for _, tt := range tests { for _, tt := range tests {
@ -619,7 +654,7 @@ func TestAllocs(t *testing.T) {
} }
func BenchmarkFull(b *testing.B) { func BenchmarkFull(b *testing.B) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS("127.0.0.1:0") server, errch := serveDNS("127.0.0.1:0")
defer func() { defer func() {

Loading…
Cancel
Save