net/dns/resolver: teach the forwarder to do per-domain routing.

Given a DNS route map, the forwarder selects the right set of
upstreams for a given name.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/1644/head
David Anderson 3 years ago
parent 4ed111281b
commit 9f105d3968

@ -17,10 +17,12 @@ import (
"sync"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/logtail/backoff"
"tailscale.com/net/netns"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
)
// headerBytes is the number of bytes in a DNS message header.
@ -100,6 +102,11 @@ func getTxID(packet []byte) txid {
return (txid(hash) << 32) | txid(dnsid)
}
type route struct {
suffix string
resolvers []netaddr.IPPort
}
// forwarder forwards DNS packets to a number of upstream nameservers.
type forwarder struct {
logf logger.Logf
@ -116,10 +123,9 @@ type forwarder struct {
conns []*fwdConn
mu sync.Mutex
// upstreams are the nameserver addresses that should be used for forwarding.
upstreams []net.Addr
// txMap maps DNS txids to active forwarding records.
txMap map[txid]forwardingRecord
// routes are per-suffix resolvers to use.
routes []route // most specific routes first
txMap map[txid]forwardingRecord // txids to in-flight requests
}
func init() {
@ -127,24 +133,22 @@ func init() {
}
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
return &forwarder{
ret := &forwarder{
logf: logger.WithPrefix(logf, "forward: "),
responses: responses,
closed: make(chan struct{}),
conns: make([]*fwdConn, connCount),
txMap: make(map[txid]forwardingRecord),
}
}
func (f *forwarder) Start() error {
f.wg.Add(connCount + 1)
for idx := range f.conns {
f.conns[idx] = newFwdConn(f.logf, idx)
go f.recv(f.conns[idx])
ret.wg.Add(connCount + 1)
for idx := range ret.conns {
ret.conns[idx] = newFwdConn(ret.logf, idx)
go ret.recv(ret.conns[idx])
}
go f.cleanMap()
go ret.cleanMap()
return nil
return ret
}
func (f *forwarder) Close() {
@ -171,14 +175,15 @@ func (f *forwarder) rebindFromNetworkChange() {
}
}
func (f *forwarder) setUpstreams(upstreams []net.Addr) {
func (f *forwarder) setRoutes(routes []route) {
fmt.Println(routes)
f.mu.Lock()
f.upstreams = upstreams
f.routes = routes
f.mu.Unlock()
}
// send sends packet to dst. It is best effort.
func (f *forwarder) send(packet []byte, dst net.Addr) {
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
connIdx := rand.Intn(connCount)
conn := f.conns[connIdx]
conn.send(packet, dst)
@ -256,24 +261,38 @@ func (f *forwarder) cleanMap() {
// forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query packet) error {
domain, err := nameFromQuery(query.bs)
if err != nil {
return err
}
txid := getTxID(query.bs)
f.mu.Lock()
routes := f.routes
f.mu.Unlock()
upstreams := f.upstreams
if len(upstreams) == 0 {
f.mu.Unlock()
var resolvers []netaddr.IPPort
for _, route := range routes {
if route.suffix != "." && !dnsname.HasSuffix(domain, route.suffix) {
continue
}
resolvers = route.resolvers
break
}
if len(resolvers) == 0 {
return errNoUpstreams
}
f.mu.Lock()
f.txMap[txid] = forwardingRecord{
src: query.addr,
createdAt: time.Now(),
}
f.mu.Unlock()
for _, upstream := range upstreams {
f.send(query.bs, upstream)
for _, resolver := range resolvers {
f.send(query.bs, resolver)
}
return nil
@ -309,7 +328,7 @@ func newFwdConn(logf logger.Logf, idx int) *fwdConn {
// send sends packet to dst using c's connection.
// It is best effort. It is UDP, after all. Failures are logged.
func (c *fwdConn) send(packet []byte, dst net.Addr) {
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
backOff := func(err error) {
if b == nil {
@ -335,8 +354,9 @@ func (c *fwdConn) send(packet []byte, dst net.Addr) {
}
c.mu.Unlock()
a := dst.UDPAddr()
c.wg.Add(1)
_, err := conn.WriteTo(packet, dst)
_, err := conn.WriteTo(packet, a)
c.wg.Done()
if err == nil {
// Success
@ -469,3 +489,24 @@ func (c *fwdConn) close() {
// Unblock any remaining readers.
c.change.Broadcast()
}
// nameFromQuery extracts the normalized query name from bs.
func nameFromQuery(bs []byte) (string, error) {
var parser dns.Parser
hdr, err := parser.Start(bs)
if err != nil {
return "", err
}
if hdr.Response {
return "", errNotQuery
}
q, err := parser.Question()
if err != nil {
return "", err
}
n := q.Name.Data[:q.Name.Length]
return rawNameToLower(n), nil
}

@ -10,7 +10,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"net"
"sort"
"strings"
"sync"
@ -68,11 +67,6 @@ type Config struct {
LocalDomains []string
}
type route struct {
suffix string
resolvers []netaddr.IPPort
}
// Resolver is a DNS resolver for nodes on the Tailscale network,
// associating them with domain names of the form <mynode>.<mydomain>.<root>.
// If it is asked to resolve a domain that is not of that form,
@ -100,7 +94,6 @@ type Resolver struct {
localDomains []string
hostToIP map[string][]netaddr.IP
ipToHost map[netaddr.IP]string
routes []route // most specific routes first
}
// New returns a new resolver.
@ -121,10 +114,6 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
}
if err := r.forwarder.Start(); err != nil {
return nil, err
}
r.wg.Add(1)
go r.poll()
@ -138,7 +127,6 @@ func isFQDN(s string) bool {
func (r *Resolver) SetConfig(cfg Config) error {
routes := make([]route, 0, len(cfg.Routes))
reverse := make(map[netaddr.IP]string, len(cfg.Hosts))
var defaultUpstream []net.Addr
for host, ips := range cfg.Hosts {
if !isFQDN(host) {
@ -162,32 +150,19 @@ func (r *Resolver) SetConfig(cfg Config) error {
suffix: suffix,
resolvers: ips,
})
if suffix == "." {
// TODO: this is a temporary hack to forward upstream
// resolvers to the forwarder, which doesn't yet
// understand per-domain resolvers. Effectively, SetConfig
// currently ignores all routes except for ".", which it
// sets as the only resolver.
for _, ip := range ips {
up := ip.UDPAddr()
defaultUpstream = append(defaultUpstream, up)
}
}
}
// Sort from longest prefix to shortest.
sort.Slice(routes, func(i, j int) bool {
return strings.Count(routes[i].suffix, ".") > strings.Count(routes[j].suffix, ".")
return dnsname.NumLabels(routes[i].suffix) > dnsname.NumLabels(routes[j].suffix)
})
r.forwarder.setUpstreams(defaultUpstream)
r.forwarder.setRoutes(routes)
r.mu.Lock()
defer r.mu.Unlock()
r.localDomains = cfg.LocalDomains
r.hostToIP = cfg.Hosts
r.ipToHost = reverse
r.routes = routes
return nil
}
@ -386,6 +361,8 @@ type response struct {
}
// parseQuery parses the query in given packet into a response struct.
// if the parse is successful, resp.Name contains the normalized name being queried.
// TODO: stuffing the query name in resp.Name temporarily is a hack. Clean it up.
func parseQuery(query []byte, resp *response) error {
var parser dns.Parser
var err error

@ -5,7 +5,7 @@
package resolver
import (
"log"
"fmt"
"testing"
"github.com/miekg/dns"
@ -16,8 +16,6 @@ import (
// that depends on github.com/miekg/dns
// from the rest, which only depends on dnsmessage.
var dnsHandleFunc = dns.HandleFunc
// resolveToIP returns a handler function which responds
// to queries of type A it receives with an A record containing ipv4,
// to queries of type AAAA with an AAAA record containing ipv6,
@ -68,28 +66,38 @@ func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc {
}
}
func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) {
var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(req, dns.RcodeNameError)
w.WriteMsg(m)
}
func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) {
server := &dns.Server{Addr: addr, Net: "udp"}
})
func serveDNS(tb testing.TB, addr string, records ...interface{}) *dns.Server {
if len(records)%2 != 0 {
panic("must have an even number of record values")
}
mux := dns.NewServeMux()
for i := 0; i < len(records); i += 2 {
name := records[i].(string)
handler := records[i+1].(dns.Handler)
mux.Handle(name, handler)
}
waitch := make(chan struct{})
server.NotifyStartedFunc = func() { close(waitch) }
server := &dns.Server{
Addr: addr,
Net: "udp",
Handler: mux,
NotifyStartedFunc: func() { close(waitch) },
ReusePort: true,
}
errch := make(chan error, 1)
go func() {
err := server.ListenAndServe()
if err != nil {
log.Printf("ListenAndServe(%q): %v", addr, err)
panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err))
}
errch <- err
close(errch)
}()
<-waitch
return server, errch
return server
}

@ -15,13 +15,8 @@ import (
"tailscale.com/tstest"
)
var testipv4 = netaddr.IPv4(1, 2, 3, 4)
var testipv6 = netaddr.IPv6Raw([16]byte{
0x00, 0x01, 0x02, 0x03,
0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b,
0x0c, 0x0d, 0x0e, 0x0f,
})
var testipv4 = netaddr.MustParseIP("1.2.3.4")
var testipv6 = netaddr.MustParseIP("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f")
var dnsCfg = Config{
Hosts: map[string][]netaddr.IP{
@ -283,32 +278,14 @@ func TestDelegate(t *testing.T) {
t.Skip("skipping test that requires localhost IPv6")
}
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
v4server, v4errch := serveDNS(t, "127.0.0.1:0")
v6server, v6errch := serveDNS(t, "[::1]:0")
defer func() {
if err := <-v4errch; err != nil {
t.Errorf("v4 server error: %v", err)
}
if err := <-v6errch; err != nil {
t.Errorf("v6 server error: %v", err)
}
}()
if v4server != nil {
defer v4server.Shutdown()
}
if v6server != nil {
defer v6server.Shutdown()
}
if v4server == nil || v6server == nil {
// There is an error in at least one of the channels
// and we cannot proceed; return to see it.
return
}
v4server := serveDNS(t, "127.0.0.1:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."),
"nxdomain.site.", resolveToNXDOMAIN)
defer v4server.Shutdown()
v6server := serveDNS(t, "[::1]:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."),
"nxdomain.site.", resolveToNXDOMAIN)
defer v6server.Shutdown()
r, err := New(t.Logf, nil)
if err != nil {
@ -377,19 +354,75 @@ func TestDelegate(t *testing.T) {
}
}
func TestDelegateCollision(t *testing.T) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
func TestDelegateSplitRoute(t *testing.T) {
test4 := netaddr.MustParseIP("2.3.4.5")
test6 := netaddr.MustParseIP("ff::1")
server, errch := serveDNS(t, "127.0.0.1:0")
defer func() {
if err := <-errch; err != nil {
t.Errorf("server error: %v", err)
}
}()
server1 := serveDNS(t, "127.0.0.1:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server1.Shutdown()
server2 := serveDNS(t, "127.0.0.1:0",
"test.other.", resolveToIP(test4, test6, "dns.other."))
defer server2.Shutdown()
r, err := New(t.Logf, nil)
if err != nil {
t.Fatalf("start: %v", err)
}
defer r.Close()
if server == nil {
return
cfg := dnsCfg
cfg.Routes = map[string][]netaddr.IPPort{
".": {netaddr.MustParseIPPort(server1.PacketConn.LocalAddr().String())},
"other.": {netaddr.MustParseIPPort(server2.PacketConn.LocalAddr().String())},
}
r.SetConfig(cfg)
tests := []struct {
title string
query []byte
response dnsResponse
}{
{
"general",
dnspacket("test.site.", dns.TypeA),
dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
},
{
"override",
dnspacket("test.other.", dns.TypeA),
dnsResponse{ip: test4, rcode: dns.RCodeSuccess},
},
}
for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
payload, err := syncRespond(r, tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
}
response, err := unpackResponse(payload)
if err != nil {
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return
}
if response.rcode != tt.response.rcode {
t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
}
if response.ip != tt.response.ip {
t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
}
if response.name != tt.response.name {
t.Errorf("name = %v; want %v", response.name, tt.response.name)
}
})
}
}
func TestDelegateCollision(t *testing.T) {
server := serveDNS(t, "127.0.0.1:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown()
r, err := New(t.Logf, nil)
@ -628,8 +661,8 @@ func TestFull(t *testing.T) {
{"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
{"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse},
{"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
{"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
{"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
{"ptr4", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
{"ptr6", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
dns.TypePTR), ptrResponse6},
{"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
}
@ -702,18 +735,8 @@ func TestTrimRDNSBonjourPrefix(t *testing.T) {
}
func BenchmarkFull(b *testing.B) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS(b, "127.0.0.1:0")
defer func() {
if err := <-errch; err != nil {
b.Errorf("server error: %v", err)
}
}()
if server == nil {
return
}
server := serveDNS(b, "127.0.0.1:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown()
r, err := New(b.Logf, nil)

@ -124,3 +124,12 @@ func SanitizeHostname(hostname string) string {
hostname = TrimCommonSuffixes(hostname)
return SanitizeLabel(hostname)
}
// NumLabels returns the number of DNS labels in hostname.
// If hostname is empty or the top-level name ".", returns 0.
func NumLabels(hostname string) int {
if hostname == "" || hostname == "." {
return 0
}
return strings.Count(hostname, ".")
}

Loading…
Cancel
Save