diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index a3fdec84a..0eb0e92c2 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -17,10 +17,12 @@ import ( "sync" "time" + dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" "tailscale.com/logtail/backoff" "tailscale.com/net/netns" "tailscale.com/types/logger" + "tailscale.com/util/dnsname" ) // headerBytes is the number of bytes in a DNS message header. @@ -100,6 +102,11 @@ func getTxID(packet []byte) txid { return (txid(hash) << 32) | txid(dnsid) } +type route struct { + suffix string + resolvers []netaddr.IPPort +} + // forwarder forwards DNS packets to a number of upstream nameservers. type forwarder struct { logf logger.Logf @@ -116,10 +123,9 @@ type forwarder struct { conns []*fwdConn mu sync.Mutex - // upstreams are the nameserver addresses that should be used for forwarding. - upstreams []net.Addr - // txMap maps DNS txids to active forwarding records. - txMap map[txid]forwardingRecord + // routes are per-suffix resolvers to use. + routes []route // most specific routes first + txMap map[txid]forwardingRecord // txids to in-flight requests } func init() { @@ -127,24 +133,22 @@ func init() { } func newForwarder(logf logger.Logf, responses chan packet) *forwarder { - return &forwarder{ + ret := &forwarder{ logf: logger.WithPrefix(logf, "forward: "), responses: responses, closed: make(chan struct{}), conns: make([]*fwdConn, connCount), txMap: make(map[txid]forwardingRecord), } -} -func (f *forwarder) Start() error { - f.wg.Add(connCount + 1) - for idx := range f.conns { - f.conns[idx] = newFwdConn(f.logf, idx) - go f.recv(f.conns[idx]) + ret.wg.Add(connCount + 1) + for idx := range ret.conns { + ret.conns[idx] = newFwdConn(ret.logf, idx) + go ret.recv(ret.conns[idx]) } - go f.cleanMap() + go ret.cleanMap() - return nil + return ret } func (f *forwarder) Close() { @@ -171,14 +175,15 @@ func (f *forwarder) rebindFromNetworkChange() { } } -func (f *forwarder) setUpstreams(upstreams []net.Addr) { +func (f *forwarder) setRoutes(routes []route) { + fmt.Println(routes) f.mu.Lock() - f.upstreams = upstreams + f.routes = routes f.mu.Unlock() } // send sends packet to dst. It is best effort. -func (f *forwarder) send(packet []byte, dst net.Addr) { +func (f *forwarder) send(packet []byte, dst netaddr.IPPort) { connIdx := rand.Intn(connCount) conn := f.conns[connIdx] conn.send(packet, dst) @@ -256,24 +261,38 @@ func (f *forwarder) cleanMap() { // forward forwards the query to all upstream nameservers and returns the first response. func (f *forwarder) forward(query packet) error { + domain, err := nameFromQuery(query.bs) + if err != nil { + return err + } + txid := getTxID(query.bs) f.mu.Lock() + routes := f.routes + f.mu.Unlock() - upstreams := f.upstreams - if len(upstreams) == 0 { - f.mu.Unlock() + var resolvers []netaddr.IPPort + for _, route := range routes { + if route.suffix != "." && !dnsname.HasSuffix(domain, route.suffix) { + continue + } + resolvers = route.resolvers + break + } + if len(resolvers) == 0 { return errNoUpstreams } + + f.mu.Lock() f.txMap[txid] = forwardingRecord{ src: query.addr, createdAt: time.Now(), } - f.mu.Unlock() - for _, upstream := range upstreams { - f.send(query.bs, upstream) + for _, resolver := range resolvers { + f.send(query.bs, resolver) } return nil @@ -309,7 +328,7 @@ func newFwdConn(logf logger.Logf, idx int) *fwdConn { // send sends packet to dst using c's connection. // It is best effort. It is UDP, after all. Failures are logged. -func (c *fwdConn) send(packet []byte, dst net.Addr) { +func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) { var b *backoff.Backoff // lazily initialized, since it is not needed in the common case backOff := func(err error) { if b == nil { @@ -335,8 +354,9 @@ func (c *fwdConn) send(packet []byte, dst net.Addr) { } c.mu.Unlock() + a := dst.UDPAddr() c.wg.Add(1) - _, err := conn.WriteTo(packet, dst) + _, err := conn.WriteTo(packet, a) c.wg.Done() if err == nil { // Success @@ -469,3 +489,24 @@ func (c *fwdConn) close() { // Unblock any remaining readers. c.change.Broadcast() } + +// nameFromQuery extracts the normalized query name from bs. +func nameFromQuery(bs []byte) (string, error) { + var parser dns.Parser + + hdr, err := parser.Start(bs) + if err != nil { + return "", err + } + if hdr.Response { + return "", errNotQuery + } + + q, err := parser.Question() + if err != nil { + return "", err + } + + n := q.Name.Data[:q.Name.Length] + return rawNameToLower(n), nil +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index d0741c614..ddd2278e3 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -10,7 +10,6 @@ import ( "encoding/hex" "errors" "fmt" - "net" "sort" "strings" "sync" @@ -68,11 +67,6 @@ type Config struct { LocalDomains []string } -type route struct { - suffix string - resolvers []netaddr.IPPort -} - // Resolver is a DNS resolver for nodes on the Tailscale network, // associating them with domain names of the form ... // If it is asked to resolve a domain that is not of that form, @@ -100,7 +94,6 @@ type Resolver struct { localDomains []string hostToIP map[string][]netaddr.IP ipToHost map[netaddr.IP]string - routes []route // most specific routes first } // New returns a new resolver. @@ -121,10 +114,6 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) { r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange) } - if err := r.forwarder.Start(); err != nil { - return nil, err - } - r.wg.Add(1) go r.poll() @@ -138,7 +127,6 @@ func isFQDN(s string) bool { func (r *Resolver) SetConfig(cfg Config) error { routes := make([]route, 0, len(cfg.Routes)) reverse := make(map[netaddr.IP]string, len(cfg.Hosts)) - var defaultUpstream []net.Addr for host, ips := range cfg.Hosts { if !isFQDN(host) { @@ -162,32 +150,19 @@ func (r *Resolver) SetConfig(cfg Config) error { suffix: suffix, resolvers: ips, }) - if suffix == "." { - // TODO: this is a temporary hack to forward upstream - // resolvers to the forwarder, which doesn't yet - // understand per-domain resolvers. Effectively, SetConfig - // currently ignores all routes except for ".", which it - // sets as the only resolver. - for _, ip := range ips { - up := ip.UDPAddr() - defaultUpstream = append(defaultUpstream, up) - } - } } // Sort from longest prefix to shortest. sort.Slice(routes, func(i, j int) bool { - return strings.Count(routes[i].suffix, ".") > strings.Count(routes[j].suffix, ".") + return dnsname.NumLabels(routes[i].suffix) > dnsname.NumLabels(routes[j].suffix) }) - r.forwarder.setUpstreams(defaultUpstream) + r.forwarder.setRoutes(routes) r.mu.Lock() defer r.mu.Unlock() r.localDomains = cfg.LocalDomains r.hostToIP = cfg.Hosts r.ipToHost = reverse - r.routes = routes - return nil } @@ -386,6 +361,8 @@ type response struct { } // parseQuery parses the query in given packet into a response struct. +// if the parse is successful, resp.Name contains the normalized name being queried. +// TODO: stuffing the query name in resp.Name temporarily is a hack. Clean it up. func parseQuery(query []byte, resp *response) error { var parser dns.Parser var err error diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index 5c0ff1325..bad2fedc0 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -5,7 +5,7 @@ package resolver import ( - "log" + "fmt" "testing" "github.com/miekg/dns" @@ -16,8 +16,6 @@ import ( // that depends on github.com/miekg/dns // from the rest, which only depends on dnsmessage. -var dnsHandleFunc = dns.HandleFunc - // resolveToIP returns a handler function which responds // to queries of type A it receives with an A record containing ipv4, // to queries of type AAAA with an AAAA record containing ipv6, @@ -68,28 +66,38 @@ func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc { } } -func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) { +var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetRcode(req, dns.RcodeNameError) w.WriteMsg(m) -} - -func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) { - server := &dns.Server{Addr: addr, Net: "udp"} +}) +func serveDNS(tb testing.TB, addr string, records ...interface{}) *dns.Server { + if len(records)%2 != 0 { + panic("must have an even number of record values") + } + mux := dns.NewServeMux() + for i := 0; i < len(records); i += 2 { + name := records[i].(string) + handler := records[i+1].(dns.Handler) + mux.Handle(name, handler) + } waitch := make(chan struct{}) - server.NotifyStartedFunc = func() { close(waitch) } + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: mux, + NotifyStartedFunc: func() { close(waitch) }, + ReusePort: true, + } - errch := make(chan error, 1) go func() { err := server.ListenAndServe() if err != nil { - log.Printf("ListenAndServe(%q): %v", addr, err) + panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) } - errch <- err - close(errch) }() <-waitch - return server, errch + return server } diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 19a2561af..343aea032 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -15,13 +15,8 @@ import ( "tailscale.com/tstest" ) -var testipv4 = netaddr.IPv4(1, 2, 3, 4) -var testipv6 = netaddr.IPv6Raw([16]byte{ - 0x00, 0x01, 0x02, 0x03, - 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, - 0x0c, 0x0d, 0x0e, 0x0f, -}) +var testipv4 = netaddr.MustParseIP("1.2.3.4") +var testipv6 = netaddr.MustParseIP("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f") var dnsCfg = Config{ Hosts: map[string][]netaddr.IP{ @@ -283,32 +278,14 @@ func TestDelegate(t *testing.T) { t.Skip("skipping test that requires localhost IPv6") } - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) - dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN) - - v4server, v4errch := serveDNS(t, "127.0.0.1:0") - v6server, v6errch := serveDNS(t, "[::1]:0") - - defer func() { - if err := <-v4errch; err != nil { - t.Errorf("v4 server error: %v", err) - } - if err := <-v6errch; err != nil { - t.Errorf("v6 server error: %v", err) - } - }() - if v4server != nil { - defer v4server.Shutdown() - } - if v6server != nil { - defer v6server.Shutdown() - } - - if v4server == nil || v6server == nil { - // There is an error in at least one of the channels - // and we cannot proceed; return to see it. - return - } + v4server := serveDNS(t, "127.0.0.1:0", + "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."), + "nxdomain.site.", resolveToNXDOMAIN) + defer v4server.Shutdown() + v6server := serveDNS(t, "[::1]:0", + "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."), + "nxdomain.site.", resolveToNXDOMAIN) + defer v6server.Shutdown() r, err := New(t.Logf, nil) if err != nil { @@ -377,19 +354,75 @@ func TestDelegate(t *testing.T) { } } -func TestDelegateCollision(t *testing.T) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) +func TestDelegateSplitRoute(t *testing.T) { + test4 := netaddr.MustParseIP("2.3.4.5") + test6 := netaddr.MustParseIP("ff::1") - server, errch := serveDNS(t, "127.0.0.1:0") - defer func() { - if err := <-errch; err != nil { - t.Errorf("server error: %v", err) - } - }() + server1 := serveDNS(t, "127.0.0.1:0", + "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) + defer server1.Shutdown() + server2 := serveDNS(t, "127.0.0.1:0", + "test.other.", resolveToIP(test4, test6, "dns.other.")) + defer server2.Shutdown() + + r, err := New(t.Logf, nil) + if err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() - if server == nil { - return + cfg := dnsCfg + cfg.Routes = map[string][]netaddr.IPPort{ + ".": {netaddr.MustParseIPPort(server1.PacketConn.LocalAddr().String())}, + "other.": {netaddr.MustParseIPPort(server2.PacketConn.LocalAddr().String())}, + } + r.SetConfig(cfg) + + tests := []struct { + title string + query []byte + response dnsResponse + }{ + { + "general", + dnspacket("test.site.", dns.TypeA), + dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess}, + }, + { + "override", + dnspacket("test.other.", dns.TypeA), + dnsResponse{ip: test4, rcode: dns.RCodeSuccess}, + }, } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + payload, err := syncRespond(r, tt.query) + if err != nil { + t.Errorf("err = %v; want nil", err) + return + } + response, err := unpackResponse(payload) + if err != nil { + t.Errorf("extract: err = %v; want nil (in %x)", err, payload) + return + } + if response.rcode != tt.response.rcode { + t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode) + } + if response.ip != tt.response.ip { + t.Errorf("ip = %v; want %v", response.ip, tt.response.ip) + } + if response.name != tt.response.name { + t.Errorf("name = %v; want %v", response.name, tt.response.name) + } + }) + } +} + +func TestDelegateCollision(t *testing.T) { + server := serveDNS(t, "127.0.0.1:0", + "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) defer server.Shutdown() r, err := New(t.Logf, nil) @@ -628,8 +661,8 @@ func TestFull(t *testing.T) { {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response}, {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse}, {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse}, - {"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, - {"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.", + {"ptr4", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, + {"ptr6", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.", dns.TypePTR), ptrResponse6}, {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, } @@ -702,18 +735,8 @@ func TestTrimRDNSBonjourPrefix(t *testing.T) { } func BenchmarkFull(b *testing.B) { - dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) - - server, errch := serveDNS(b, "127.0.0.1:0") - defer func() { - if err := <-errch; err != nil { - b.Errorf("server error: %v", err) - } - }() - - if server == nil { - return - } + server := serveDNS(b, "127.0.0.1:0", + "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) defer server.Shutdown() r, err := New(b.Logf, nil) diff --git a/util/dnsname/dnsname.go b/util/dnsname/dnsname.go index db9f8b0ff..e4d17940f 100644 --- a/util/dnsname/dnsname.go +++ b/util/dnsname/dnsname.go @@ -124,3 +124,12 @@ func SanitizeHostname(hostname string) string { hostname = TrimCommonSuffixes(hostname) return SanitizeLabel(hostname) } + +// NumLabels returns the number of DNS labels in hostname. +// If hostname is empty or the top-level name ".", returns 0. +func NumLabels(hostname string) int { + if hostname == "" || hostname == "." { + return 0 + } + return strings.Count(hostname, ".") +}