net/dns/resolver: teach the forwarder to do per-domain routing.

Given a DNS route map, the forwarder selects the right set of
upstreams for a given name.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/1644/head
David Anderson 4 years ago
parent 4ed111281b
commit 9f105d3968

@ -17,10 +17,12 @@ import (
"sync" "sync"
"time" "time"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/logtail/backoff" "tailscale.com/logtail/backoff"
"tailscale.com/net/netns" "tailscale.com/net/netns"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/dnsname"
) )
// headerBytes is the number of bytes in a DNS message header. // headerBytes is the number of bytes in a DNS message header.
@ -100,6 +102,11 @@ func getTxID(packet []byte) txid {
return (txid(hash) << 32) | txid(dnsid) return (txid(hash) << 32) | txid(dnsid)
} }
type route struct {
suffix string
resolvers []netaddr.IPPort
}
// forwarder forwards DNS packets to a number of upstream nameservers. // forwarder forwards DNS packets to a number of upstream nameservers.
type forwarder struct { type forwarder struct {
logf logger.Logf logf logger.Logf
@ -116,10 +123,9 @@ type forwarder struct {
conns []*fwdConn conns []*fwdConn
mu sync.Mutex mu sync.Mutex
// upstreams are the nameserver addresses that should be used for forwarding. // routes are per-suffix resolvers to use.
upstreams []net.Addr routes []route // most specific routes first
// txMap maps DNS txids to active forwarding records. txMap map[txid]forwardingRecord // txids to in-flight requests
txMap map[txid]forwardingRecord
} }
func init() { func init() {
@ -127,24 +133,22 @@ func init() {
} }
func newForwarder(logf logger.Logf, responses chan packet) *forwarder { func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
return &forwarder{ ret := &forwarder{
logf: logger.WithPrefix(logf, "forward: "), logf: logger.WithPrefix(logf, "forward: "),
responses: responses, responses: responses,
closed: make(chan struct{}), closed: make(chan struct{}),
conns: make([]*fwdConn, connCount), conns: make([]*fwdConn, connCount),
txMap: make(map[txid]forwardingRecord), txMap: make(map[txid]forwardingRecord),
} }
}
func (f *forwarder) Start() error { ret.wg.Add(connCount + 1)
f.wg.Add(connCount + 1) for idx := range ret.conns {
for idx := range f.conns { ret.conns[idx] = newFwdConn(ret.logf, idx)
f.conns[idx] = newFwdConn(f.logf, idx) go ret.recv(ret.conns[idx])
go f.recv(f.conns[idx])
} }
go f.cleanMap() go ret.cleanMap()
return nil return ret
} }
func (f *forwarder) Close() { func (f *forwarder) Close() {
@ -171,14 +175,15 @@ func (f *forwarder) rebindFromNetworkChange() {
} }
} }
func (f *forwarder) setUpstreams(upstreams []net.Addr) { func (f *forwarder) setRoutes(routes []route) {
fmt.Println(routes)
f.mu.Lock() f.mu.Lock()
f.upstreams = upstreams f.routes = routes
f.mu.Unlock() f.mu.Unlock()
} }
// send sends packet to dst. It is best effort. // send sends packet to dst. It is best effort.
func (f *forwarder) send(packet []byte, dst net.Addr) { func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
connIdx := rand.Intn(connCount) connIdx := rand.Intn(connCount)
conn := f.conns[connIdx] conn := f.conns[connIdx]
conn.send(packet, dst) conn.send(packet, dst)
@ -256,24 +261,38 @@ func (f *forwarder) cleanMap() {
// forward forwards the query to all upstream nameservers and returns the first response. // forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query packet) error { func (f *forwarder) forward(query packet) error {
domain, err := nameFromQuery(query.bs)
if err != nil {
return err
}
txid := getTxID(query.bs) txid := getTxID(query.bs)
f.mu.Lock() f.mu.Lock()
routes := f.routes
f.mu.Unlock()
upstreams := f.upstreams var resolvers []netaddr.IPPort
if len(upstreams) == 0 { for _, route := range routes {
f.mu.Unlock() if route.suffix != "." && !dnsname.HasSuffix(domain, route.suffix) {
continue
}
resolvers = route.resolvers
break
}
if len(resolvers) == 0 {
return errNoUpstreams return errNoUpstreams
} }
f.mu.Lock()
f.txMap[txid] = forwardingRecord{ f.txMap[txid] = forwardingRecord{
src: query.addr, src: query.addr,
createdAt: time.Now(), createdAt: time.Now(),
} }
f.mu.Unlock() f.mu.Unlock()
for _, upstream := range upstreams { for _, resolver := range resolvers {
f.send(query.bs, upstream) f.send(query.bs, resolver)
} }
return nil return nil
@ -309,7 +328,7 @@ func newFwdConn(logf logger.Logf, idx int) *fwdConn {
// send sends packet to dst using c's connection. // send sends packet to dst using c's connection.
// It is best effort. It is UDP, after all. Failures are logged. // It is best effort. It is UDP, after all. Failures are logged.
func (c *fwdConn) send(packet []byte, dst net.Addr) { func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
backOff := func(err error) { backOff := func(err error) {
if b == nil { if b == nil {
@ -335,8 +354,9 @@ func (c *fwdConn) send(packet []byte, dst net.Addr) {
} }
c.mu.Unlock() c.mu.Unlock()
a := dst.UDPAddr()
c.wg.Add(1) c.wg.Add(1)
_, err := conn.WriteTo(packet, dst) _, err := conn.WriteTo(packet, a)
c.wg.Done() c.wg.Done()
if err == nil { if err == nil {
// Success // Success
@ -469,3 +489,24 @@ func (c *fwdConn) close() {
// Unblock any remaining readers. // Unblock any remaining readers.
c.change.Broadcast() c.change.Broadcast()
} }
// nameFromQuery extracts the normalized query name from bs.
func nameFromQuery(bs []byte) (string, error) {
var parser dns.Parser
hdr, err := parser.Start(bs)
if err != nil {
return "", err
}
if hdr.Response {
return "", errNotQuery
}
q, err := parser.Question()
if err != nil {
return "", err
}
n := q.Name.Data[:q.Name.Length]
return rawNameToLower(n), nil
}

@ -10,7 +10,6 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -68,11 +67,6 @@ type Config struct {
LocalDomains []string LocalDomains []string
} }
type route struct {
suffix string
resolvers []netaddr.IPPort
}
// Resolver is a DNS resolver for nodes on the Tailscale network, // Resolver is a DNS resolver for nodes on the Tailscale network,
// associating them with domain names of the form <mynode>.<mydomain>.<root>. // associating them with domain names of the form <mynode>.<mydomain>.<root>.
// If it is asked to resolve a domain that is not of that form, // If it is asked to resolve a domain that is not of that form,
@ -100,7 +94,6 @@ type Resolver struct {
localDomains []string localDomains []string
hostToIP map[string][]netaddr.IP hostToIP map[string][]netaddr.IP
ipToHost map[netaddr.IP]string ipToHost map[netaddr.IP]string
routes []route // most specific routes first
} }
// New returns a new resolver. // New returns a new resolver.
@ -121,10 +114,6 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange) r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
} }
if err := r.forwarder.Start(); err != nil {
return nil, err
}
r.wg.Add(1) r.wg.Add(1)
go r.poll() go r.poll()
@ -138,7 +127,6 @@ func isFQDN(s string) bool {
func (r *Resolver) SetConfig(cfg Config) error { func (r *Resolver) SetConfig(cfg Config) error {
routes := make([]route, 0, len(cfg.Routes)) routes := make([]route, 0, len(cfg.Routes))
reverse := make(map[netaddr.IP]string, len(cfg.Hosts)) reverse := make(map[netaddr.IP]string, len(cfg.Hosts))
var defaultUpstream []net.Addr
for host, ips := range cfg.Hosts { for host, ips := range cfg.Hosts {
if !isFQDN(host) { if !isFQDN(host) {
@ -162,32 +150,19 @@ func (r *Resolver) SetConfig(cfg Config) error {
suffix: suffix, suffix: suffix,
resolvers: ips, resolvers: ips,
}) })
if suffix == "." {
// TODO: this is a temporary hack to forward upstream
// resolvers to the forwarder, which doesn't yet
// understand per-domain resolvers. Effectively, SetConfig
// currently ignores all routes except for ".", which it
// sets as the only resolver.
for _, ip := range ips {
up := ip.UDPAddr()
defaultUpstream = append(defaultUpstream, up)
}
}
} }
// Sort from longest prefix to shortest. // Sort from longest prefix to shortest.
sort.Slice(routes, func(i, j int) bool { sort.Slice(routes, func(i, j int) bool {
return strings.Count(routes[i].suffix, ".") > strings.Count(routes[j].suffix, ".") return dnsname.NumLabels(routes[i].suffix) > dnsname.NumLabels(routes[j].suffix)
}) })
r.forwarder.setUpstreams(defaultUpstream) r.forwarder.setRoutes(routes)
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.localDomains = cfg.LocalDomains r.localDomains = cfg.LocalDomains
r.hostToIP = cfg.Hosts r.hostToIP = cfg.Hosts
r.ipToHost = reverse r.ipToHost = reverse
r.routes = routes
return nil return nil
} }
@ -386,6 +361,8 @@ type response struct {
} }
// parseQuery parses the query in given packet into a response struct. // parseQuery parses the query in given packet into a response struct.
// if the parse is successful, resp.Name contains the normalized name being queried.
// TODO: stuffing the query name in resp.Name temporarily is a hack. Clean it up.
func parseQuery(query []byte, resp *response) error { func parseQuery(query []byte, resp *response) error {
var parser dns.Parser var parser dns.Parser
var err error var err error

@ -5,7 +5,7 @@
package resolver package resolver
import ( import (
"log" "fmt"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -16,8 +16,6 @@ import (
// that depends on github.com/miekg/dns // that depends on github.com/miekg/dns
// from the rest, which only depends on dnsmessage. // from the rest, which only depends on dnsmessage.
var dnsHandleFunc = dns.HandleFunc
// resolveToIP returns a handler function which responds // resolveToIP returns a handler function which responds
// to queries of type A it receives with an A record containing ipv4, // to queries of type A it receives with an A record containing ipv4,
// to queries of type AAAA with an AAAA record containing ipv6, // to queries of type AAAA with an AAAA record containing ipv6,
@ -68,28 +66,38 @@ func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc {
} }
} }
func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) { var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(req, dns.RcodeNameError) m.SetRcode(req, dns.RcodeNameError)
w.WriteMsg(m) w.WriteMsg(m)
} })
func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) {
server := &dns.Server{Addr: addr, Net: "udp"}
func serveDNS(tb testing.TB, addr string, records ...interface{}) *dns.Server {
if len(records)%2 != 0 {
panic("must have an even number of record values")
}
mux := dns.NewServeMux()
for i := 0; i < len(records); i += 2 {
name := records[i].(string)
handler := records[i+1].(dns.Handler)
mux.Handle(name, handler)
}
waitch := make(chan struct{}) waitch := make(chan struct{})
server.NotifyStartedFunc = func() { close(waitch) } server := &dns.Server{
Addr: addr,
Net: "udp",
Handler: mux,
NotifyStartedFunc: func() { close(waitch) },
ReusePort: true,
}
errch := make(chan error, 1)
go func() { go func() {
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != nil {
log.Printf("ListenAndServe(%q): %v", addr, err) panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err))
} }
errch <- err
close(errch)
}() }()
<-waitch <-waitch
return server, errch return server
} }

@ -15,13 +15,8 @@ import (
"tailscale.com/tstest" "tailscale.com/tstest"
) )
var testipv4 = netaddr.IPv4(1, 2, 3, 4) var testipv4 = netaddr.MustParseIP("1.2.3.4")
var testipv6 = netaddr.IPv6Raw([16]byte{ var testipv6 = netaddr.MustParseIP("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f")
0x00, 0x01, 0x02, 0x03,
0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b,
0x0c, 0x0d, 0x0e, 0x0f,
})
var dnsCfg = Config{ var dnsCfg = Config{
Hosts: map[string][]netaddr.IP{ Hosts: map[string][]netaddr.IP{
@ -283,32 +278,14 @@ func TestDelegate(t *testing.T) {
t.Skip("skipping test that requires localhost IPv6") t.Skip("skipping test that requires localhost IPv6")
} }
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) v4server := serveDNS(t, "127.0.0.1:0",
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN) "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."),
"nxdomain.site.", resolveToNXDOMAIN)
v4server, v4errch := serveDNS(t, "127.0.0.1:0") defer v4server.Shutdown()
v6server, v6errch := serveDNS(t, "[::1]:0") v6server := serveDNS(t, "[::1]:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."),
defer func() { "nxdomain.site.", resolveToNXDOMAIN)
if err := <-v4errch; err != nil { defer v6server.Shutdown()
t.Errorf("v4 server error: %v", err)
}
if err := <-v6errch; err != nil {
t.Errorf("v6 server error: %v", err)
}
}()
if v4server != nil {
defer v4server.Shutdown()
}
if v6server != nil {
defer v6server.Shutdown()
}
if v4server == nil || v6server == nil {
// There is an error in at least one of the channels
// and we cannot proceed; return to see it.
return
}
r, err := New(t.Logf, nil) r, err := New(t.Logf, nil)
if err != nil { if err != nil {
@ -377,19 +354,75 @@ func TestDelegate(t *testing.T) {
} }
} }
func TestDelegateCollision(t *testing.T) { func TestDelegateSplitRoute(t *testing.T) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) test4 := netaddr.MustParseIP("2.3.4.5")
test6 := netaddr.MustParseIP("ff::1")
server, errch := serveDNS(t, "127.0.0.1:0") server1 := serveDNS(t, "127.0.0.1:0",
defer func() { "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
if err := <-errch; err != nil { defer server1.Shutdown()
t.Errorf("server error: %v", err) server2 := serveDNS(t, "127.0.0.1:0",
} "test.other.", resolveToIP(test4, test6, "dns.other."))
}() defer server2.Shutdown()
r, err := New(t.Logf, nil)
if err != nil {
t.Fatalf("start: %v", err)
}
defer r.Close()
if server == nil { cfg := dnsCfg
return cfg.Routes = map[string][]netaddr.IPPort{
".": {netaddr.MustParseIPPort(server1.PacketConn.LocalAddr().String())},
"other.": {netaddr.MustParseIPPort(server2.PacketConn.LocalAddr().String())},
}
r.SetConfig(cfg)
tests := []struct {
title string
query []byte
response dnsResponse
}{
{
"general",
dnspacket("test.site.", dns.TypeA),
dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
},
{
"override",
dnspacket("test.other.", dns.TypeA),
dnsResponse{ip: test4, rcode: dns.RCodeSuccess},
},
} }
for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
payload, err := syncRespond(r, tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
}
response, err := unpackResponse(payload)
if err != nil {
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return
}
if response.rcode != tt.response.rcode {
t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
}
if response.ip != tt.response.ip {
t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
}
if response.name != tt.response.name {
t.Errorf("name = %v; want %v", response.name, tt.response.name)
}
})
}
}
func TestDelegateCollision(t *testing.T) {
server := serveDNS(t, "127.0.0.1:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown() defer server.Shutdown()
r, err := New(t.Logf, nil) r, err := New(t.Logf, nil)
@ -628,8 +661,8 @@ func TestFull(t *testing.T) {
{"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response}, {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
{"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse}, {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse},
{"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse}, {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
{"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, {"ptr4", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
{"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.", {"ptr6", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
dns.TypePTR), ptrResponse6}, dns.TypePTR), ptrResponse6},
{"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
} }
@ -702,18 +735,8 @@ func TestTrimRDNSBonjourPrefix(t *testing.T) {
} }
func BenchmarkFull(b *testing.B) { func BenchmarkFull(b *testing.B) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) server := serveDNS(b, "127.0.0.1:0",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS(b, "127.0.0.1:0")
defer func() {
if err := <-errch; err != nil {
b.Errorf("server error: %v", err)
}
}()
if server == nil {
return
}
defer server.Shutdown() defer server.Shutdown()
r, err := New(b.Logf, nil) r, err := New(b.Logf, nil)

@ -124,3 +124,12 @@ func SanitizeHostname(hostname string) string {
hostname = TrimCommonSuffixes(hostname) hostname = TrimCommonSuffixes(hostname)
return SanitizeLabel(hostname) return SanitizeLabel(hostname)
} }
// NumLabels returns the number of DNS labels in hostname.
// If hostname is empty or the top-level name ".", returns 0.
func NumLabels(hostname string) int {
if hostname == "" || hostname == "." {
return 0
}
return strings.Count(hostname, ".")
}

Loading…
Cancel
Save