diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 22d249dd4..7b3d11bb2 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1797,10 +1797,9 @@ func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) { }) } switch runtime.GOOS { - case "linux", "freebsd", "openbsd", "illumos", "darwin": + case "linux", "freebsd", "openbsd", "illumos", "darwin", "windows": // These are the platforms currently supported by // net/dns/resolver/tsdns.go:Resolver.HandleExitNodeDNSQuery. - // TODO(bradfitz): add windows once it's done there. ret = append(ret, tailcfg.Service{ Proto: tailcfg.PeerAPIDNS, Port: 1, // version diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 8ca24bbaf..7c5258e0d 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -360,7 +360,8 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne case "windows": // TODO: use DnsQueryEx and write to ch. // See https://docs.microsoft.com/en-us/windows/win32/api/windns/nf-windns-dnsqueryex. - return nil, errors.New("TODO: windows exit node suport") + // For now just use the net package: + return handleExitNodeDNSQueryWithNetPkg(ctx, nil, resp) case "darwin": // /etc/resolv.conf is a lie and only says one upstream DNS // but for now that's probably good enough. Later we'll @@ -404,6 +405,106 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne } } +// handleExitNodeDNSQueryWithNetPkg takes a DNS query message in q and +// return a reply (for the ExitDNS DoH service) using the net package's +// native APIs. This is only used on Windows for now. +// +// If resolver is nil, the net.Resolver zero value is used. +// +// response contains the pre-serialized response, which notably +// includes the original question and its header. +func handleExitNodeDNSQueryWithNetPkg(ctx context.Context, resolver *net.Resolver, resp *response) (res []byte, err error) { + if resp.Question.Class != dns.ClassINET { + return nil, errors.New("unsupported class") + } + + r := resolver + if r == nil { + r = new(net.Resolver) + } + name := resp.Question.Name.String() + + handleError := func(err error) (res []byte, _ error) { + if isGoNoSuchHostError(err) { + resp.Header.RCode = dns.RCodeNameError + return marshalResponse(resp) + } + // TODO: map other errors to RCodeServerFailure? + // Or I guess our caller should do that? + return nil, err + } + + resp.Header.RCode = dns.RCodeSuccess // unless changed below + + switch resp.Question.Type { + case dns.TypeA, dns.TypeAAAA: + network := "ip4" + if resp.Question.Type == dns.TypeAAAA { + network = "ip6" + } + ips, err := r.LookupIP(ctx, network, name) + if err != nil { + return handleError(err) + } + for _, stdIP := range ips { + if ip, ok := netaddr.FromStdIP(stdIP); ok { + resp.IPs = append(resp.IPs, ip) + } + } + case dns.TypeTXT: + strs, err := r.LookupTXT(ctx, name) + if err != nil { + return handleError(err) + } + resp.TXT = strs + case dns.TypePTR: + ipStr, ok := unARPA(name) + if !ok { + // TODO: is this RCodeFormatError? + return nil, errors.New("bogus PTR name") + } + addrs, err := r.LookupAddr(ctx, ipStr) + if err != nil { + return handleError(err) + } + if len(addrs) > 0 { + resp.Name, _ = dnsname.ToFQDN(addrs[0]) + } + case dns.TypeCNAME: + cname, err := r.LookupCNAME(ctx, name) + if err != nil { + return handleError(err) + } + resp.CNAME = cname + case dns.TypeSRV: + // Thanks, Go: "To accommodate services publishing SRV + // records under non-standard names, if both service + // and proto are empty strings, LookupSRV looks up + // name directly." + _, srvs, err := r.LookupSRV(ctx, "", "", name) + if err != nil { + return handleError(err) + } + resp.SRVs = srvs + case dns.TypeNS: + nss, err := r.LookupNS(ctx, name) + if err != nil { + return handleError(err) + } + resp.NSs = nss + default: + return nil, fmt.Errorf("unsupported record type %v", resp.Question.Type) + } + return marshalResponse(resp) +} + +func isGoNoSuchHostError(err error) bool { + if de, ok := err.(*net.DNSError); ok { + return de.IsNotFound + } + return false +} + type resolvConfCache struct { mod time.Time size int64 @@ -604,10 +705,27 @@ func (r *Resolver) handleQuery(pkt packet) { type response struct { Header dns.Header Question dns.Question + // Name is the response to a PTR query. Name dnsname.FQDN - // IP is the response to an A, AAAA, or ALL query. - IP netaddr.IP + + // IP and IPs are the responses to an A, AAAA, or ALL query. + // Either/both/neither can be populated. + IP netaddr.IP + IPs []netaddr.IP + + // TXT is the response to a TXT query. + // Each one is its own RR with one string. + TXT []string + + // CNAME is the response to a CNAME query. + CNAME string + + // SRVs are the responses to a SRV query. + SRVs []*net.SRV + + // NSs are the responses to an NS query. + NSs []*net.NS } var dnsParserPool = &sync.Pool{ @@ -683,6 +801,16 @@ func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error return builder.AAAAResource(answerHeader, answer) } +func marshalIP(name dns.Name, ip netaddr.IP, builder *dns.Builder) error { + if ip.Is4() { + return marshalARecord(name, ip, builder) + } + if ip.Is6() { + return marshalAAAARecord(name, ip, builder) + } + return nil +} + // marshalPTRRecord serializes a PTR record into an active builder. // The caller may continue using the builder following the call. func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builder) error { @@ -702,6 +830,83 @@ func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builde return builder.PTRResource(answerHeader, answer) } +func marshalTXT(queryName dns.Name, txts []string, builder *dns.Builder) error { + for _, txt := range txts { + if err := builder.TXTResource(dns.ResourceHeader{ + Name: queryName, + Type: dns.TypeTXT, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), + }, dns.TXTResource{ + TXT: []string{txt}, + }); err != nil { + return err + } + } + return nil +} + +func marshalCNAME(queryName dns.Name, cname string, builder *dns.Builder) error { + if cname == "" { + return nil + } + name, err := dns.NewName(cname) + if err != nil { + return err + } + return builder.CNAMEResource(dns.ResourceHeader{ + Name: queryName, + Type: dns.TypeCNAME, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), + }, dns.CNAMEResource{ + CNAME: name, + }) +} + +func marshalNS(queryName dns.Name, nss []*net.NS, builder *dns.Builder) error { + for _, ns := range nss { + name, err := dns.NewName(ns.Host) + if err != nil { + return err + } + err = builder.NSResource(dns.ResourceHeader{ + Name: queryName, + Type: dns.TypeNS, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), + }, dns.NSResource{NS: name}) + if err != nil { + return err + } + } + return nil +} + +func marshalSRV(queryName dns.Name, srvs []*net.SRV, builder *dns.Builder) error { + for _, s := range srvs { + srvName, err := dns.NewName(s.Target) + if err != nil { + return err + } + err = builder.SRVResource(dns.ResourceHeader{ + Name: queryName, + Type: dns.TypeSRV, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), + }, dns.SRVResource{ + Target: srvName, + Priority: s.Priority, + Port: s.Port, + Weight: s.Weight, + }) + if err != nil { + return err + } + } + return nil +} + // marshalResponse serializes the DNS response into a new buffer. func marshalResponse(resp *response) ([]byte, error) { resp.Header.Response = true @@ -712,6 +917,14 @@ func marshalResponse(resp *response) ([]byte, error) { builder := dns.NewBuilder(nil, resp.Header) + // TODO(bradfitz): I'm not sure why this wasn't enabled + // before, but for now (2021-12-09) enable it at least when + // there's more than 1 record (which was never the case + // before), where it really helps. + if len(resp.IPs) > 1 { + builder.EnableCompression() + } + isSuccess := resp.Header.RCode == dns.RCodeSuccess if resp.Question.Type != 0 || isSuccess { @@ -738,13 +951,24 @@ func marshalResponse(resp *response) ([]byte, error) { switch resp.Question.Type { case dns.TypeA, dns.TypeAAAA, dns.TypeALL: - if resp.IP.Is4() { - err = marshalARecord(resp.Question.Name, resp.IP, &builder) - } else if resp.IP.Is6() { - err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder) + if err := marshalIP(resp.Question.Name, resp.IP, &builder); err != nil { + return nil, err + } + for _, ip := range resp.IPs { + if err := marshalIP(resp.Question.Name, ip, &builder); err != nil { + return nil, err + } } case dns.TypePTR: err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder) + case dns.TypeTXT: + err = marshalTXT(resp.Question.Name, resp.TXT, &builder) + case dns.TypeCNAME: + err = marshalCNAME(resp.Question.Name, resp.CNAME, &builder) + case dns.TypeSRV: + err = marshalSRV(resp.Question.Name, resp.SRVs, &builder) + case dns.TypeNS: + err = marshalNS(resp.Question.Name, resp.NSs, &builder) } if err != nil { return nil, err @@ -926,6 +1150,37 @@ func (r *Resolver) respond(query []byte) ([]byte, error) { return marshalResponse(resp) } +// unARPA maps from "4.4.8.8.in-addr.arpa." to "8.8.4.4", etc. +func unARPA(a string) (ipStr string, ok bool) { + const suf4 = ".in-addr.arpa." + if strings.HasSuffix(a, suf4) { + s := strings.TrimSuffix(a, suf4) + // Parse and reverse octets. + ip, err := netaddr.ParseIP(s) + if err != nil || !ip.Is4() { + return "", false + } + a4 := ip.As4() + return netaddr.IPv4(a4[3], a4[2], a4[1], a4[0]).String(), true + } + const suf6 = ".ip6.arpa." + if len(a) == len("e.0.0.2.0.0.0.0.0.0.0.0.0.0.0.0.b.0.8.0.a.0.0.4.0.b.8.f.7.0.6.2.ip6.arpa.") && + strings.HasSuffix(a, suf6) { + var hx [32]byte + var a16 [16]byte + for i := range hx { + hx[31-i] = a[i*2] + if a[i*2+1] != '.' { + return "", false + } + } + hex.Decode(a16[:], hx[:]) + return netaddr.IPFrom16(a16).String(), true + } + return "", false + +} + var ( metricDNSQueryLocal = clientmetric.NewCounter("dns_query_local") metricDNSQueryErrorClosed = clientmetric.NewCounter("dns_query_local_error_closed") diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index bf35a03ef..1e5f0294b 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -6,6 +6,7 @@ package resolver import ( "fmt" + "net" "strings" "testing" @@ -179,6 +180,129 @@ var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) w.WriteMsg(m) }) +// weirdoGoCNAMEHandler returns a DNS handler that satisfies +// Go's weird Resolver.LookupCNAME (read its godoc carefully!). +// +// This doesn't even return a CNAME record, because that's not +// what Go looks for. +func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + question := req.Question[0] + + switch question.Qtype { + case dns.TypeA: + m.Answer = append(m.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + }, + Target: target, + }) + case dns.TypeAAAA: + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 600, + }, + AAAA: net.ParseIP("1::2"), + }) + } + w.WriteMsg(m) + } +} + +// dnsHandler returns a handler that replies with the answers/options +// provided. +// +// Types supported: netaddr.IP. +func dnsHandler(answers ...interface{}) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + if len(req.Question) != 1 { + panic("not a single-question request") + } + m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies + + question := req.Question[0] + for _, a := range answers { + switch a := a.(type) { + default: + panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) + case netaddr.IP: + ip := a + if ip.Is4() { + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ip.IPAddr().IP, + }) + } else if ip.Is6() { + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ip.IPAddr().IP, + }) + } + case dns.PTR: + ptr := a + ptr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &ptr) + case dns.CNAME: + c := a + c.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + } + m.Answer = append(m.Answer, &c) + case dns.TXT: + txt := a + txt.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &txt) + case dns.SRV: + srv := a + srv.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &srv) + case dns.NS: + rr := a + rr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &rr) + } + } + w.WriteMsg(m) + } +} + func serveDNS(tb testing.TB, addr string, records ...interface{}) *dns.Server { if len(records)%2 != 0 { panic("must have an even number of record values") diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 6081c5ea6..1d648dd15 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -6,16 +6,22 @@ package resolver import ( "bytes" + "context" "encoding/hex" + "encoding/json" "errors" "fmt" "math/rand" "net" + "reflect" "runtime" "strconv" "strings" "testing" + "time" + miekdns "github.com/miekg/dns" + "golang.org/x/net/dns/dnsmessage" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" "tailscale.com/net/tsdial" @@ -43,6 +49,8 @@ var dnsCfg = Config{ const noEdns = 0 +const dnsHeaderLen = 12 + func dnspacket(domain dnsname.FQDN, tp dns.Type, ednsSize uint16) []byte { var dnsHeader dns.Header question := dns.Question{ @@ -1093,3 +1101,383 @@ func TestForwardLinkSelection(t *testing.T) { type linkSelFunc func(ip netaddr.IP) string func (f linkSelFunc) PickLink(ip netaddr.IP) string { return f(ip) } + +func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping test on Windows; waiting for golang.org/issue/33097") + } + + records := []interface{}{ + "no-records.test.", + dnsHandler(), + + "one-a.test.", + dnsHandler(netaddr.MustParseIP("1.2.3.4")), + + "two-a.test.", + dnsHandler(netaddr.MustParseIP("1.2.3.4"), netaddr.MustParseIP("5.6.7.8")), + + "one-aaaa.test.", + dnsHandler(netaddr.MustParseIP("1::2")), + + "two-aaaa.test.", + dnsHandler(netaddr.MustParseIP("1::2"), netaddr.MustParseIP("3::4")), + + "nx-domain.test.", + resolveToNXDOMAIN, + + "4.3.2.1.in-addr.arpa.", + dnsHandler(miekdns.PTR{Ptr: "foo.com."}), + + "cname.test.", + weirdoGoCNAMEHandler("the-target.foo."), + + "txt.test.", + dnsHandler( + miekdns.TXT{Txt: []string{"txt1=one"}}, + miekdns.TXT{Txt: []string{"txt2=two"}}, + miekdns.TXT{Txt: []string{"txt3=three"}}, + ), + + "srv.test.", + dnsHandler( + miekdns.SRV{ + Priority: 1, + Weight: 2, + Port: 3, + Target: "foo.com.", + }, + miekdns.SRV{ + Priority: 4, + Weight: 5, + Port: 6, + Target: "bar.com.", + }, + ), + + "ns.test.", + dnsHandler(miekdns.NS{Ns: "ns1.foo."}, miekdns.NS{Ns: "ns2.bar."}), + } + v4server := serveDNS(t, "127.0.0.1:0", records...) + defer v4server.Shutdown() + + // backendResolver is the resolver between + // handleExitNodeDNSQueryWithNetPkg and its upstream resolver, + // which in this test's case is the miekg/dns test DNS server + // (v4server). + backResolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "udp", v4server.PacketConn.LocalAddr().String()) + }, + } + + t.Run("no_such_host", func(t *testing.T) { + res, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), backResolver, &response{ + Header: dnsmessage.Header{ + ID: 123, + Response: true, + OpCode: 0, // query + }, + Question: dnsmessage.Question{ + Name: dnsmessage.MustNewName("nx-domain.test."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }) + if err != nil { + t.Fatal(err) + } + if len(res) < dnsHeaderLen { + t.Fatal("short reply") + } + rcode := dns.RCode(res[3] & 0x0f) + if rcode != dns.RCodeNameError { + t.Errorf("RCode = %v; want dns.RCodeNameError", rcode) + t.Logf("Response was: %q", res) + } + }) + + matchPacked := func(want string) func(t testing.TB, got []byte) { + return func(t testing.TB, got []byte) { + if string(got) == want { + return + } + t.Errorf("unexpected reply.\n got: %q\nwant: %q\n", got, want) + t.Errorf("\nin hex:\n got: % 2x\nwant: % 2x\n", got, want) + } + } + + tests := []struct { + Type dnsmessage.Type + Name string + Check func(t testing.TB, got []byte) + }{ + { + Type: dnsmessage.TypeA, + Name: "one-a.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x05one-a\x04test\x00\x00\x01\x00\x01\x05one-a\x04test\x00\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x01\x02\x03\x04"), + }, + { + Type: dnsmessage.TypeA, + Name: "two-a.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x05two-a\x04test\x00\x00\x01\x00\x01\xc0\f\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x01\x02\x03\x04\xc0\f\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x05\x06\a\b"), + }, + { + Type: dnsmessage.TypeAAAA, + Name: "one-aaaa.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\bone-aaaa\x04test\x00\x00\x1c\x00\x01\bone-aaaa\x04test\x00\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"), + }, + { + Type: dnsmessage.TypeAAAA, + Name: "two-aaaa.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\btwo-aaaa\x04test\x00\x00\x1c\x00\x01\xc0\f\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc0\f\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04"), + }, + { + Type: dnsmessage.TypePTR, + Name: "4.3.2.1.in-addr.arpa.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x014\x013\x012\x011\ain-addr\x04arpa\x00\x00\f\x00\x01\x014\x013\x012\x011\ain-addr\x04arpa\x00\x00\f\x00\x01\x00\x00\x02X\x00\t\x03foo\x03com\x00"), + }, + { + Type: dnsmessage.TypeCNAME, + Name: "cname.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x05cname\x04test\x00\x00\x05\x00\x01\x05cname\x04test\x00\x00\x05\x00\x01\x00\x00\x02X\x00\x10\nthe-target\x03foo\x00"), + }, + + // No records of various types + { + Type: dnsmessage.TypeA, + Name: "no-records.test.", + Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x01\x00\x01"), + }, + { + Type: dnsmessage.TypeAAAA, + Name: "no-records.test.", + Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x1c\x00\x01"), + }, + { + Type: dnsmessage.TypeCNAME, + Name: "no-records.test.", + Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x05\x00\x01"), + }, + { + Type: dnsmessage.TypeSRV, + Name: "no-records.test.", + Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00!\x00\x01"), + }, + { + Type: dnsmessage.TypeTXT, + Name: "txt.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x03\x00\x00\x00\x00\x03txt\x04test\x00\x00\x10\x00\x01\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\t\btxt1=one\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\t\btxt2=two\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\v\ntxt3=three"), + }, + { + Type: dnsmessage.TypeSRV, + Name: "srv.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x03srv\x04test\x00\x00!\x00\x01\x03srv\x04test\x00\x00!\x00\x01\x00\x00\x02X\x00\x0f\x00\x01\x00\x02\x00\x03\x03foo\x03com\x00\x03srv\x04test\x00\x00!\x00\x01\x00\x00\x02X\x00\x0f\x00\x04\x00\x05\x00\x06\x03bar\x03com\x00"), + }, + { + Type: dnsmessage.TypeNS, + Name: "ns.test.", + Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x02ns\x04test\x00\x00\x02\x00\x01\x02ns\x04test\x00\x00\x02\x00\x01\x00\x00\x02X\x00\t\x03ns1\x03foo\x00\x02ns\x04test\x00\x00\x02\x00\x01\x00\x00\x02X\x00\t\x03ns2\x03bar\x00"), + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%v_%v", tt.Type, strings.Trim(tt.Name, ".")), func(t *testing.T) { + got, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), backResolver, &response{ + Header: dnsmessage.Header{ + ID: 123, + Response: true, + OpCode: 0, // query + }, + Question: dnsmessage.Question{ + Name: dnsmessage.MustNewName(tt.Name), + Type: tt.Type, + Class: dnsmessage.ClassINET, + }, + }) + if err != nil { + t.Fatal(err) + } + if len(got) < dnsHeaderLen { + t.Errorf("short record") + } + if tt.Check != nil { + tt.Check(t, got) + if t.Failed() { + t.Errorf("Got: %q\nIn hex: % 02x", got, got) + } + } + }) + } + + wrapRes := newWrapResolver(backResolver) + ctx := context.Background() + + t.Run("wrap_ip_a", func(t *testing.T) { + ips, err := wrapRes.LookupIP(ctx, "ip", "two-a.test.") + if err != nil { + t.Fatal(err) + } + if got, want := ips, []net.IP{ + net.ParseIP("1.2.3.4").To4(), + net.ParseIP("5.6.7.8").To4(), + }; !reflect.DeepEqual(got, want) { + t.Errorf("LookupIP = %v; want %v", got, want) + } + }) + + t.Run("wrap_ip_aaaa", func(t *testing.T) { + ips, err := wrapRes.LookupIP(ctx, "ip", "two-aaaa.test.") + if err != nil { + t.Fatal(err) + } + if got, want := ips, []net.IP{ + net.ParseIP("1::2"), + net.ParseIP("3::4"), + }; !reflect.DeepEqual(got, want) { + t.Errorf("LookupIP(v6) = %v; want %v", got, want) + } + }) + + t.Run("wrap_ip_nx", func(t *testing.T) { + ips, err := wrapRes.LookupIP(ctx, "ip", "nx-domain.test.") + if !isGoNoSuchHostError(err) { + t.Errorf("no NX domain = (%v, %v); want no host error", ips, err) + } + }) + + t.Run("wrap_srv", func(t *testing.T) { + _, srvs, err := wrapRes.LookupSRV(ctx, "", "", "srv.test.") + if err != nil { + t.Fatal(err) + } + if got, want := srvs, []*net.SRV{ + { + Target: "foo.com.", + Priority: 1, + Weight: 2, + Port: 3, + }, + { + Target: "bar.com.", + Priority: 4, + Weight: 5, + Port: 6, + }, + }; !reflect.DeepEqual(got, want) { + jgot, _ := json.Marshal(got) + jwant, _ := json.Marshal(want) + t.Errorf("SRV = %s; want %s", jgot, jwant) + } + }) + + t.Run("wrap_txt", func(t *testing.T) { + txts, err := wrapRes.LookupTXT(ctx, "txt.test.") + if err != nil { + t.Fatal(err) + } + if got, want := txts, []string{"txt1=one", "txt2=two", "txt3=three"}; !reflect.DeepEqual(got, want) { + t.Errorf("TXT = %q; want %q", got, want) + } + }) + + t.Run("wrap_ns", func(t *testing.T) { + nss, err := wrapRes.LookupNS(ctx, "ns.test.") + if err != nil { + t.Fatal(err) + } + if got, want := nss, []*net.NS{ + {Host: "ns1.foo."}, + {Host: "ns2.bar."}, + }; !reflect.DeepEqual(got, want) { + jgot, _ := json.Marshal(got) + jwant, _ := json.Marshal(want) + t.Errorf("NS = %s; want %s", jgot, jwant) + } + }) +} + +// newWrapResolver returns a resolver that uses r (via handleExitNodeDNSQueryWithNetPkg) +// to make DNS requests. +func newWrapResolver(r *net.Resolver) *net.Resolver { + if runtime.GOOS == "windows" { + panic("doesn't work on Windows") // golang.org/issue/33097 + } + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + return &wrapResolverConn{ctx: ctx, r: r}, nil + }, + } +} + +type wrapResolverConn struct { + ctx context.Context + r *net.Resolver + buf bytes.Buffer +} + +var _ net.PacketConn = (*wrapResolverConn)(nil) + +func (*wrapResolverConn) Close() error { return nil } +func (*wrapResolverConn) LocalAddr() net.Addr { return fakeAddr{} } +func (*wrapResolverConn) RemoteAddr() net.Addr { return fakeAddr{} } +func (*wrapResolverConn) SetDeadline(t time.Time) error { return nil } +func (*wrapResolverConn) SetReadDeadline(t time.Time) error { return nil } +func (*wrapResolverConn) SetWriteDeadline(t time.Time) error { return nil } + +func (a *wrapResolverConn) Read(p []byte) (n int, err error) { + n, _, err = a.ReadFrom(p) + return +} + +func (a *wrapResolverConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = a.buf.Read(p) + return n, fakeAddr{}, err +} + +func (a *wrapResolverConn) Write(packet []byte) (n int, err error) { + return a.WriteTo(packet, fakeAddr{}) +} + +func (a *wrapResolverConn) WriteTo(q []byte, _ net.Addr) (n int, err error) { + resp := parseExitNodeQuery(q) + if resp == nil { + return 0, errors.New("bad query") + } + res, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), a.r, resp) + if err != nil { + return 0, err + } + a.buf.Write(res) + return len(q), nil +} + +type fakeAddr struct{} + +func (fakeAddr) Network() string { return "unused" } +func (fakeAddr) String() string { return "unused-todoAddr" } + +func TestUnARPA(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", ""}, + {"bad", ""}, + {"4.4.8.8.in-addr.arpa.", "8.8.4.4"}, + {".in-addr.arpa.", ""}, + {"e.0.0.2.0.0.0.0.0.0.0.0.0.0.0.0.b.0.8.0.a.0.0.4.0.b.8.f.7.0.6.2.ip6.arpa.", "2607:f8b0:400a:80b::200e"}, + {".ip6.arpa.", ""}, + } + for _, tt := range tests { + got, ok := unARPA(tt.in) + if ok != (got != "") { + t.Errorf("inconsistent results for %q: (%q, %v)", tt.in, got, ok) + } + if got != tt.want { + t.Errorf("unARPA(%q) = %q; want %q", tt.in, got, tt.want) + } + } +}