net/dns/resolver: add live reconfig, plumb through to ipnlocal.

The resolver still only supports a single upstream config, and
ipn/wgengine still have to split up the DNS config, but this moves
closer to unifying the DNS configs.

As a handy side-effect of the refactor, IPv6 MagicDNS records exist
now.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/1635/head
David Anderson 4 years ago
parent caeafc4a32
commit 90f82b6946

@ -91,7 +91,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/logtail/filch from tailscale.com/logpolicy tailscale.com/logtail/filch from tailscale.com/logpolicy
tailscale.com/metrics from tailscale.com/derp tailscale.com/metrics from tailscale.com/derp
tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+ tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+
tailscale.com/net/dns/resolver from tailscale.com/ipn/ipnlocal+ tailscale.com/net/dns/resolver from tailscale.com/wgengine
tailscale.com/net/dnscache from tailscale.com/control/controlclient+ tailscale.com/net/dnscache from tailscale.com/control/controlclient+
tailscale.com/net/dnsfallback from tailscale.com/control/controlclient tailscale.com/net/dnsfallback from tailscale.com/control/controlclient
tailscale.com/net/flowtrack from tailscale.com/wgengine/filter+ tailscale.com/net/flowtrack from tailscale.com/wgengine/filter+

@ -28,7 +28,6 @@ import (
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/ipn/policy" "tailscale.com/ipn/policy"
"tailscale.com/net/dns" "tailscale.com/net/dns"
"tailscale.com/net/dns/resolver"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/paths" "tailscale.com/paths"
@ -440,9 +439,6 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) {
b.updateFilter(st.NetMap, prefs) b.updateFilter(st.NetMap, prefs)
b.e.SetNetworkMap(st.NetMap) b.e.SetNetworkMap(st.NetMap)
if !dnsMapsEqual(st.NetMap, netMap) {
b.updateDNSMap(st.NetMap)
}
b.e.SetDERPMap(st.NetMap.DERPMap) b.e.SetDERPMap(st.NetMap.DERPMap)
b.send(ipn.Notify{NetMap: st.NetMap}) b.send(ipn.Notify{NetMap: st.NetMap})
@ -851,32 +847,6 @@ func dnsMapsEqual(new, old *netmap.NetworkMap) bool {
return true return true
} }
// updateDNSMap updates the domain map in the DNS resolver in wgengine
// based on the given netMap and user preferences.
func (b *LocalBackend) updateDNSMap(netMap *netmap.NetworkMap) {
if netMap == nil {
b.logf("dns map: (not ready)")
return
}
nameToIP := make(map[string]netaddr.IP)
set := func(name string, addrs []netaddr.IPPrefix) {
if len(addrs) == 0 || name == "" {
return
}
nameToIP[name] = addrs[0].IP
}
for _, peer := range netMap.Peers {
set(peer.Name, peer.Addresses)
}
set(netMap.Name, netMap.Addresses)
dnsMap := resolver.NewMap(nameToIP, magicDNSRootDomains(netMap))
// map diff will be logged in dns.Resolver.SetMap.
b.e.SetDNSMap(dnsMap)
}
// readPoller is a goroutine that receives service lists from // readPoller is a goroutine that receives service lists from
// b.portpoll and propagates them into the controlclient's HostInfo. // b.portpoll and propagates them into the controlclient's HostInfo.
func (b *LocalBackend) readPoller() { func (b *LocalBackend) readPoller() {
@ -1487,7 +1457,21 @@ func (b *LocalBackend) authReconfig() {
} }
} }
err = b.e.Reconfig(cfg, rcfg) nameToIP := make(map[string][]netaddr.IP)
set := func(name string, addrs []netaddr.IPPrefix) {
if len(addrs) == 0 || name == "" {
return
}
for _, addr := range addrs {
nameToIP[name] = append(nameToIP[name], addr.IP)
}
}
for _, peer := range nm.Peers {
set(peer.Name, peer.Addresses)
}
set(nm.Name, nm.Addresses)
err = b.e.Reconfig(cfg, rcfg, nameToIP, magicDNSRootDomains(nm))
if err == wgengine.ErrNoChanges { if err == wgengine.ErrNoChanges {
return return
} }
@ -1743,7 +1727,7 @@ func (b *LocalBackend) enterState(newState ipn.State) {
b.blockEngineUpdates(true) b.blockEngineUpdates(true)
fallthrough fallthrough
case ipn.Stopped: case ipn.Stopped:
err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}) err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, nil, nil)
if err != nil { if err != nil {
b.logf("Reconfig(down): %v", err) b.logf("Reconfig(down): %v", err)
} }
@ -1835,7 +1819,7 @@ func (b *LocalBackend) stateMachine() {
// a status update that predates the "I've shut down" update. // a status update that predates the "I've shut down" update.
func (b *LocalBackend) stopEngineAndWait() { func (b *LocalBackend) stopEngineAndWait() {
b.logf("stopEngineAndWait...") b.logf("stopEngineAndWait...")
b.e.Reconfig(&wgcfg.Config{}, &router.Config{}) b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, nil, nil)
b.requestEngineStatusAndWait() b.requestEngineStatusAndWait()
b.logf("stopEngineAndWait: done.") b.logf("stopEngineAndWait: done.")
} }

@ -1,160 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver
import (
"sort"
"strings"
"inet.af/netaddr"
)
// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network.
type Map struct {
// nameToIP is a mapping of Tailscale domain names to their IP addresses.
// For example, monitoring.tailscale.us -> 100.64.0.1.
nameToIP map[string]netaddr.IP
// ipToName is the inverse of nameToIP.
ipToName map[netaddr.IP]string
// names are the keys of nameToIP in sorted order.
names []string
// rootDomains are the domains whose subdomains should always
// be resolved locally to prevent leakage of sensitive names.
rootDomains []string // e.g. "user.provider.beta.tailscale.net."
}
// NewMap returns a new Map with name to address mapping given by nameToIP.
//
// rootDomains are the domains whose subdomains should always be
// resolved locally to prevent leakage of sensitive names. They should
// end in a period ("user-foo.tailscale.net.").
func NewMap(initNameToIP map[string]netaddr.IP, rootDomains []string) *Map {
// TODO(dmytro): we have to allocate names and ipToName, but nameToIP can be avoided.
// It is here because control sends us names not in canonical form. Change this.
names := make([]string, 0, len(initNameToIP))
nameToIP := make(map[string]netaddr.IP, len(initNameToIP))
ipToName := make(map[netaddr.IP]string, len(initNameToIP))
for name, ip := range initNameToIP {
if len(name) == 0 {
// Nothing useful can be done with empty names.
continue
}
if name[len(name)-1] != '.' {
name += "."
}
names = append(names, name)
nameToIP[name] = ip
ipToName[ip] = name
}
sort.Strings(names)
return &Map{
nameToIP: nameToIP,
ipToName: ipToName,
names: names,
rootDomains: rootDomains,
}
}
func printSingleNameIP(buf *strings.Builder, name string, ip netaddr.IP) {
buf.WriteString(name)
buf.WriteByte('\t')
buf.WriteString(ip.String())
buf.WriteByte('\n')
}
func (m *Map) Pretty() string {
buf := new(strings.Builder)
for _, name := range m.names {
printSingleNameIP(buf, name, m.nameToIP[name])
}
return buf.String()
}
func (m *Map) PrettyDiffFrom(old *Map) string {
var (
oldNameToIP map[string]netaddr.IP
newNameToIP map[string]netaddr.IP
oldNames []string
newNames []string
)
if old != nil {
oldNameToIP = old.nameToIP
oldNames = old.names
}
if m != nil {
newNameToIP = m.nameToIP
newNames = m.names
}
buf := new(strings.Builder)
space := func() bool {
return buf.Len() < (1 << 10)
}
for len(oldNames) > 0 && len(newNames) > 0 {
var name string
newName, oldName := newNames[0], oldNames[0]
switch {
case oldName < newName:
name = oldName
oldNames = oldNames[1:]
case oldName > newName:
name = newName
newNames = newNames[1:]
case oldNames[0] == newNames[0]:
name = oldNames[0]
oldNames = oldNames[1:]
newNames = newNames[1:]
}
if !space() {
continue
}
ipOld, inOld := oldNameToIP[name]
ipNew, inNew := newNameToIP[name]
switch {
case !inOld:
buf.WriteByte('+')
printSingleNameIP(buf, name, ipNew)
case !inNew:
buf.WriteByte('-')
printSingleNameIP(buf, name, ipOld)
case ipOld != ipNew:
buf.WriteByte('-')
printSingleNameIP(buf, name, ipOld)
buf.WriteByte('+')
printSingleNameIP(buf, name, ipNew)
}
}
for _, name := range oldNames {
if !space() {
break
}
if _, ok := newNameToIP[name]; !ok {
buf.WriteByte('-')
printSingleNameIP(buf, name, oldNameToIP[name])
}
}
for _, name := range newNames {
if !space() {
break
}
if _, ok := oldNameToIP[name]; !ok {
buf.WriteByte('+')
printSingleNameIP(buf, name, newNameToIP[name])
}
}
if !space() {
buf.WriteString("... [truncated]\n")
}
return buf.String()
}

@ -1,156 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver
import (
"fmt"
"strings"
"testing"
"inet.af/netaddr"
)
func TestPretty(t *testing.T) {
tests := []struct {
name string
dmap *Map
want string
}{
{"empty", NewMap(nil, nil), ""},
{
"single",
NewMap(map[string]netaddr.IP{
"hello.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
}, nil),
"hello.ipn.dev.\t100.101.102.103\n",
},
{
"multiple",
NewMap(map[string]netaddr.IP{
"test1.domain.": netaddr.IPv4(100, 101, 102, 103),
"test2.sub.domain.": netaddr.IPv4(100, 99, 9, 1),
}, nil),
"test1.domain.\t100.101.102.103\ntest2.sub.domain.\t100.99.9.1\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.dmap.Pretty()
if tt.want != got {
t.Errorf("want %v; got %v", tt.want, got)
}
})
}
}
func TestPrettyDiffFrom(t *testing.T) {
tests := []struct {
name string
map1 *Map
map2 *Map
want string
}{
{
"from_empty",
nil,
NewMap(map[string]netaddr.IP{
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
}, nil),
"+test1.ipn.dev.\t100.101.102.103\n+test2.ipn.dev.\t100.103.102.101\n",
},
{
"equal",
NewMap(map[string]netaddr.IP{
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
}, nil),
NewMap(map[string]netaddr.IP{
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
}, nil),
"",
},
{
"changed_ip",
NewMap(map[string]netaddr.IP{
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
}, nil),
NewMap(map[string]netaddr.IP{
"test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
}, nil),
"-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n",
},
{
"new_domain",
NewMap(map[string]netaddr.IP{
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
}, nil),
NewMap(map[string]netaddr.IP{
"test3.ipn.dev.": netaddr.IPv4(100, 105, 106, 107),
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
}, nil),
"+test3.ipn.dev.\t100.105.106.107\n",
},
{
"gone_domain",
NewMap(map[string]netaddr.IP{
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
}, nil),
NewMap(map[string]netaddr.IP{
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
}, nil),
"-test2.ipn.dev.\t100.103.102.101\n",
},
{
"mixed",
NewMap(map[string]netaddr.IP{
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
"test4.ipn.dev.": netaddr.IPv4(100, 107, 106, 105),
"test5.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
}, nil),
NewMap(map[string]netaddr.IP{
"test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
"test1.ipn.dev.": netaddr.IPv4(100, 100, 101, 102),
"test3.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
}, nil),
"-test1.ipn.dev.\t100.101.102.103\n+test1.ipn.dev.\t100.100.101.102\n" +
"-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n" +
"+test3.ipn.dev.\t100.64.1.1\n-test4.ipn.dev.\t100.107.106.105\n-test5.ipn.dev.\t100.64.1.1\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.map2.PrettyDiffFrom(tt.map1)
if tt.want != got {
t.Errorf("want %v; got %v", tt.want, got)
}
})
}
t.Run("truncated", func(t *testing.T) {
small := NewMap(nil, nil)
m := map[string]netaddr.IP{}
for i := 0; i < 5000; i++ {
m[fmt.Sprintf("host%d.ipn.dev.", i)] = netaddr.IPv4(100, 64, 1, 1)
}
veryBig := NewMap(m, nil)
diff := veryBig.PrettyDiffFrom(small)
if len(diff) > 3<<10 {
t.Errorf("pretty diff too large: %d bytes", len(diff))
}
if !strings.Contains(diff, "truncated") {
t.Errorf("big diff not truncated")
}
})
}

@ -9,7 +9,9 @@ package resolver
import ( import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"net" "net"
"sort"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -38,8 +40,6 @@ var ErrClosed = errors.New("closed")
var ( var (
errFullQueue = errors.New("request queue full") errFullQueue = errors.New("request queue full")
errMapNotSet = errors.New("domain map not set")
errNotImplemented = errors.New("query type not implemented")
errNotQuery = errors.New("not a DNS query") errNotQuery = errors.New("not a DNS query")
errNotOurName = errors.New("not a Tailscale DNS name") errNotOurName = errors.New("not a Tailscale DNS name")
) )
@ -49,6 +49,30 @@ type packet struct {
addr netaddr.IPPort // src for a request, dst for a response addr netaddr.IPPort // src for a request, dst for a response
} }
// Config is a resolver configuration.
// Given a Config, queries are resolved in the following order:
// If the query is an exact match for an entry in LocalHosts, return that.
// Else if the query suffix matches an entry in LocalDomains, return NXDOMAIN.
// Else forward the query to the most specific matching entry in Routes.
// Else return SERVFAIL.
type Config struct {
// Routes is a map of DNS name suffix to the resolvers to use for
// queries within that suffix.
// Queries only match the most specific suffix.
// To register a "default route", add an entry for ".".
Routes map[string][]netaddr.IPPort
// LocalHosts is a map of FQDNs to corresponding IPs.
Hosts map[string][]netaddr.IP
// LocalDomains is a list of DNS name suffixes that should not be
// routed to upstream resolvers.
LocalDomains []string
}
type route struct {
suffix string
resolvers []netaddr.IPPort
}
// Resolver is a DNS resolver for nodes on the Tailscale network, // Resolver is a DNS resolver for nodes on the Tailscale network,
// associating them with domain names of the form <mynode>.<mydomain>.<root>. // associating them with domain names of the form <mynode>.<mydomain>.<root>.
// If it is asked to resolve a domain that is not of that form, // If it is asked to resolve a domain that is not of that form,
@ -73,8 +97,10 @@ type Resolver struct {
// mu guards the following fields from being updated while used. // mu guards the following fields from being updated while used.
mu sync.Mutex mu sync.Mutex
// dnsMap is the map most recently received from the control server. localDomains []string
dnsMap *Map hostToIP map[string][]netaddr.IP
ipToHost map[netaddr.IP]string
routes []route // most specific routes first
} }
// New returns a new resolver. // New returns a new resolver.
@ -87,6 +113,8 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
responses: make(chan packet), responses: make(chan packet),
errors: make(chan error), errors: make(chan error),
closed: make(chan struct{}), closed: make(chan struct{}),
hostToIP: map[string][]netaddr.IP{},
ipToHost: map[netaddr.IP]string{},
} }
r.forwarder = newForwarder(r.logf, r.responses) r.forwarder = newForwarder(r.logf, r.responses)
if r.linkMon != nil { if r.linkMon != nil {
@ -103,6 +131,66 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
return r, nil return r, nil
} }
func isFQDN(s string) bool {
return strings.HasSuffix(s, ".")
}
func (r *Resolver) SetConfig(cfg Config) error {
routes := make([]route, 0, len(cfg.Routes))
reverse := make(map[netaddr.IP]string, len(cfg.Hosts))
var defaultUpstream []net.Addr
for host, ips := range cfg.Hosts {
if !isFQDN(host) {
return fmt.Errorf("host entry %q is not a FQDN", host)
}
for _, ip := range ips {
reverse[ip] = host
}
}
for _, domain := range cfg.LocalDomains {
if !isFQDN(domain) {
return fmt.Errorf("local domain %q is not a FQDN", domain)
}
}
for suffix, ips := range cfg.Routes {
if !strings.HasSuffix(suffix, ".") {
return fmt.Errorf("route suffix %q is not a FQDN", suffix)
}
routes = append(routes, route{
suffix: suffix,
resolvers: ips,
})
if suffix == "." {
// TODO: this is a temporary hack to forward upstream
// resolvers to the forwarder, which doesn't yet
// understand per-domain resolvers. Effectively, SetConfig
// currently ignores all routes except for ".", which it
// sets as the only resolver.
for _, ip := range ips {
up := ip.UDPAddr()
defaultUpstream = append(defaultUpstream, up)
}
}
}
// Sort from longest prefix to shortest.
sort.Slice(routes, func(i, j int) bool {
return strings.Count(routes[i].suffix, ".") > strings.Count(routes[j].suffix, ".")
})
r.forwarder.setUpstreams(defaultUpstream)
r.mu.Lock()
defer r.mu.Unlock()
r.localDomains = cfg.LocalDomains
r.hostToIP = cfg.Hosts
r.ipToHost = reverse
r.routes = routes
return nil
}
// Close shuts down the resolver and ensures poll goroutines have exited. // Close shuts down the resolver and ensures poll goroutines have exited.
// The Resolver cannot be used again after Close is called. // The Resolver cannot be used again after Close is called.
func (r *Resolver) Close() { func (r *Resolver) Close() {
@ -129,22 +217,6 @@ func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) {
r.forwarder.rebindFromNetworkChange() r.forwarder.rebindFromNetworkChange()
} }
// SetMap sets the resolver's DNS map, taking ownership of it.
func (r *Resolver) SetMap(m *Map) {
r.mu.Lock()
oldMap := r.dnsMap
r.dnsMap = m
r.mu.Unlock()
r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap))
}
// SetUpstreams sets the addresses of the resolver's
// upstream nameservers, taking ownership of the argument.
func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
r.forwarder.setUpstreams(upstreams)
r.logf("set upstreams: %v", upstreams)
}
// EnqueueRequest places the given DNS request in the resolver's queue. // EnqueueRequest places the given DNS request in the resolver's queue.
// It takes ownership of the payload and does not block. // It takes ownership of the payload and does not block.
// If the queue is full, the request will be dropped and an error will be returned. // If the queue is full, the request will be dropped and an error will be returned.
@ -172,62 +244,70 @@ func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error)
} }
} }
// resolve maps a given domain name to the IP address of the host that owns it, // resolveLocal returns an IP for the given domain, if domain is in
// if the IP address conforms to the DNS resource type given by tp (one of A, AAAA, ALL). // the local hosts map and has an IP corresponding to the requested
// typ (A, AAAA, ALL).
// The domain name must be in canonical form (with a trailing period). // The domain name must be in canonical form (with a trailing period).
func (r *Resolver) resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) { // Returns dns.RCodeRefused to indicate that the local map is not
r.mu.Lock() // authoritative for domain.
dnsMap := r.dnsMap func (r *Resolver) resolveLocal(domain string, typ dns.Type) (netaddr.IP, dns.RCode) {
r.mu.Unlock()
if dnsMap == nil {
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
}
// Reject .onion domains per RFC 7686. // Reject .onion domains per RFC 7686.
if dnsname.HasSuffix(domain, ".onion") { if dnsname.HasSuffix(domain, ".onion") {
return netaddr.IP{}, dns.RCodeNameError, nil return netaddr.IP{}, dns.RCodeNameError
} }
anyHasSuffix := false r.mu.Lock()
for _, suffix := range dnsMap.rootDomains { hosts := r.hostToIP
localDomains := r.localDomains
r.mu.Unlock()
addrs, found := hosts[domain]
if !found {
for _, suffix := range localDomains {
if dnsname.HasSuffix(domain, suffix) { if dnsname.HasSuffix(domain, suffix) {
anyHasSuffix = true // We are authoritative for the queried domain.
break return netaddr.IP{}, dns.RCodeNameError
} }
} }
addr, found := dnsMap.nameToIP[domain] // Not authoritative, signal that forwarding is advisable.
if !found { return netaddr.IP{}, dns.RCodeRefused
if !anyHasSuffix {
return netaddr.IP{}, dns.RCodeRefused, nil
}
return netaddr.IP{}, dns.RCodeNameError, nil
} }
// Refactoring note: this must happen after we check suffixes, // Refactoring note: this must happen after we check suffixes,
// otherwise we will respond with NOTIMP to requests that should be forwarded. // otherwise we will respond with NOTIMP to requests that should be forwarded.
switch tp { //
// DNS semantics subtlety: when a DNS name exists, but no records
// are available for the requested record type, we must return
// RCodeSuccess with no data, not NXDOMAIN.
switch typ {
case dns.TypeA: case dns.TypeA:
if !addr.Is4() { for _, ip := range addrs {
return netaddr.IP{}, dns.RCodeSuccess, nil if ip.Is4() {
return ip, dns.RCodeSuccess
} }
return addr, dns.RCodeSuccess, nil }
return netaddr.IP{}, dns.RCodeSuccess
case dns.TypeAAAA: case dns.TypeAAAA:
if !addr.Is6() { for _, ip := range addrs {
return netaddr.IP{}, dns.RCodeSuccess, nil if ip.Is6() {
return ip, dns.RCodeSuccess
}
} }
return addr, dns.RCodeSuccess, nil return netaddr.IP{}, dns.RCodeSuccess
case dns.TypeALL: case dns.TypeALL:
// Answer with whatever we've got. // Answer with whatever we've got.
// It could be IPv4, IPv6, or a zero addr. // It could be IPv4, IPv6, or a zero addr.
// TODO: Return all available resolutions (A and AAAA, if we have them). // TODO: Return all available resolutions (A and AAAA, if we have them).
return addr, dns.RCodeSuccess, nil if len(addrs) == 0 {
return netaddr.IP{}, dns.RCodeSuccess
}
return addrs[0], dns.RCodeSuccess
// Leave some some record types explicitly unimplemented. // Leave some some record types explicitly unimplemented.
// These types relate to recursive resolution or special // These types relate to recursive resolution or special
// DNS sematics and might be implemented in the future. // DNS semantics and might be implemented in the future.
case dns.TypeNS, dns.TypeSOA, dns.TypeAXFR, dns.TypeHINFO: case dns.TypeNS, dns.TypeSOA, dns.TypeAXFR, dns.TypeHINFO:
return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented return netaddr.IP{}, dns.RCodeNotImplemented
// For everything except for the few types above that are explictly not implemented, return no records. // For everything except for the few types above that are explictly not implemented, return no records.
// This is what other DNS systems do: always return NOERROR // This is what other DNS systems do: always return NOERROR
@ -236,26 +316,23 @@ func (r *Resolver) resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, e
// dig -t TYPE9824 example.com // dig -t TYPE9824 example.com
// and note that NOERROR is returned, despite that record type being made up. // and note that NOERROR is returned, despite that record type being made up.
default: default:
// no records exist of this type // The name exists, but no records exist of the requested type.
return netaddr.IP{}, dns.RCodeSuccess, nil return netaddr.IP{}, dns.RCodeSuccess
} }
} }
// resolveReverse returns the unique domain name that maps to the given address. // resolveReverse returns the unique domain name that maps to the given address.
// The returned domain name is in canonical form (with a trailing period). // The returned domain name is in canonical form (with a trailing period).
func (r *Resolver) resolveReverse(ip netaddr.IP) (string, dns.RCode, error) { func (r *Resolver) resolveLocalReverse(ip netaddr.IP) (string, dns.RCode) {
r.mu.Lock() r.mu.Lock()
dnsMap := r.dnsMap ips := r.ipToHost
r.mu.Unlock() r.mu.Unlock()
if dnsMap == nil { name, found := ips[ip]
return "", dns.RCodeServerFailure, errMapNotSet
}
name, found := dnsMap.ipToName[ip]
if !found { if !found {
return "", dns.RCodeNameError, nil return "", dns.RCodeNameError
} }
return name, dns.RCodeSuccess, nil return name, dns.RCodeSuccess
} }
func (r *Resolver) poll() { func (r *Resolver) poll() {
@ -567,11 +644,7 @@ func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]
return nil, errNotOurName return nil, errNotOurName
} }
var err error resp.Name, resp.Header.RCode = r.resolveLocalReverse(ip)
resp.Name, resp.Header.RCode, err = r.resolveReverse(ip)
if err != nil {
r.logf("resolving rdns: %v", ip, err)
}
if resp.Header.RCode == dns.RCodeNameError { if resp.Header.RCode == dns.RCodeNameError {
return nil, errNotOurName return nil, errNotOurName
} }
@ -608,16 +681,11 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
return r.respondReverse(query, name, resp) return r.respondReverse(query, name, resp)
} }
resp.IP, resp.Header.RCode, err = r.resolve(name, resp.Question.Type) resp.IP, resp.Header.RCode = r.resolveLocal(name, resp.Question.Type)
// This return code is special: it requests forwarding. // This return code is special: it requests forwarding.
if resp.Header.RCode == dns.RCodeRefused { if resp.Header.RCode == dns.RCodeRefused {
return nil, errNotOurName return nil, errNotOurName
} }
// We will not return this error: it is the sender's fault.
if err != nil {
r.logf("resolving: %v", err)
}
return marshalResponse(resp) return marshalResponse(resp)
} }

@ -8,7 +8,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"net" "net"
"sync"
"testing" "testing"
dns "golang.org/x/net/dns/dnsmessage" dns "golang.org/x/net/dns/dnsmessage"
@ -24,13 +23,13 @@ var testipv6 = netaddr.IPv6Raw([16]byte{
0x0c, 0x0d, 0x0e, 0x0f, 0x0c, 0x0d, 0x0e, 0x0f,
}) })
var dnsMap = NewMap( var dnsCfg = Config{
map[string]netaddr.IP{ Hosts: map[string][]netaddr.IP{
"test1.ipn.dev.": testipv4, "test1.ipn.dev.": []netaddr.IP{testipv4},
"test2.ipn.dev.": testipv6, "test2.ipn.dev.": []netaddr.IP{testipv6},
}, },
[]string{"ipn.dev."}, LocalDomains: []string{"ipn.dev."},
) }
func dnspacket(domain string, tp dns.Type) []byte { func dnspacket(domain string, tp dns.Type) []byte {
var dnsHeader dns.Header var dnsHeader dns.Header
@ -192,14 +191,14 @@ func TestRDNSNameToIPv6(t *testing.T) {
} }
} }
func TestResolve(t *testing.T) { func TestResolveLocal(t *testing.T) {
r, err := New(t.Logf, nil) r, err := New(t.Logf, nil)
if err != nil { if err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap) r.SetConfig(dnsCfg)
tests := []struct { tests := []struct {
name string name string
@ -223,10 +222,7 @@ func TestResolve(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ip, code, err := r.resolve(tt.qname, tt.qtype) ip, code := r.resolveLocal(tt.qname, tt.qtype)
if err != nil {
t.Errorf("err = %v; want nil", err)
}
if code != tt.code { if code != tt.code {
t.Errorf("code = %v; want %v", code, tt.code) t.Errorf("code = %v; want %v", code, tt.code)
} }
@ -238,14 +234,14 @@ func TestResolve(t *testing.T) {
} }
} }
func TestResolveReverse(t *testing.T) { func TestResolveLocalReverse(t *testing.T) {
r, err := New(t.Logf, nil) r, err := New(t.Logf, nil)
if err != nil { if err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap) r.SetConfig(dnsCfg)
tests := []struct { tests := []struct {
name string name string
@ -260,10 +256,7 @@ func TestResolveReverse(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
name, code, err := r.resolveReverse(tt.ip) name, code := r.resolveLocalReverse(tt.ip)
if err != nil {
t.Errorf("err = %v; want nil", err)
}
if code != tt.code { if code != tt.code {
t.Errorf("code = %v; want %v", code, tt.code) t.Errorf("code = %v; want %v", code, tt.code)
} }
@ -323,11 +316,14 @@ func TestDelegate(t *testing.T) {
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap) cfg := dnsCfg
r.SetUpstreams([]net.Addr{ cfg.Routes = map[string][]netaddr.IPPort{
v4server.PacketConn.LocalAddr(), ".": {
v6server.PacketConn.LocalAddr(), netaddr.MustParseIPPort(v4server.PacketConn.LocalAddr().String()),
}) netaddr.MustParseIPPort(v6server.PacketConn.LocalAddr().String()),
},
}
r.SetConfig(cfg)
tests := []struct { tests := []struct {
title string title string
@ -402,8 +398,13 @@ func TestDelegateCollision(t *testing.T) {
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap) cfg := dnsCfg
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) cfg.Routes = map[string][]netaddr.IPPort{
".": {
netaddr.MustParseIPPort(server.PacketConn.LocalAddr().String()),
},
}
r.SetConfig(cfg)
packets := []struct { packets := []struct {
qname string qname string
@ -460,65 +461,6 @@ func TestDelegateCollision(t *testing.T) {
} }
} }
func TestConcurrentSetMap(t *testing.T) {
r, err := New(t.Logf, nil)
if err != nil {
t.Fatalf("start: %v", err)
}
defer r.Close()
// This is purely to ensure that Resolve does not race with SetMap.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
r.SetMap(dnsMap)
}()
go func() {
defer wg.Done()
r.resolve("test1.ipn.dev", dns.TypeA)
}()
wg.Wait()
}
func TestConcurrentSetUpstreams(t *testing.T) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS(t, "127.0.0.1:0")
defer func() {
if err := <-errch; err != nil {
t.Errorf("server error: %v", err)
}
}()
if server == nil {
return
}
defer server.Shutdown()
r, err := New(t.Logf, nil)
if err != nil {
t.Fatalf("start: %v", err)
}
defer r.Close()
r.SetMap(dnsMap)
packet := dnspacket("test.site.", dns.TypeA)
// This is purely to ensure that delegation does not race with SetUpstreams.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
}()
go func() {
defer wg.Done()
syncRespond(r, packet)
}()
wg.Wait()
}
var allResponse = []byte{ var allResponse = []byte{
0x00, 0x00, // transaction id: 0 0x00, 0x00, // transaction id: 0
0x84, 0x00, // flags: response, authoritative, no error 0x84, 0x00, // flags: response, authoritative, no error
@ -673,7 +615,7 @@ func TestFull(t *testing.T) {
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap) r.SetConfig(dnsCfg)
// One full packet and one error packet // One full packet and one error packet
tests := []struct { tests := []struct {
@ -711,7 +653,7 @@ func TestAllocs(t *testing.T) {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap) r.SetConfig(dnsCfg)
// It is seemingly pointless to test allocs in the delegate path, // It is seemingly pointless to test allocs in the delegate path,
// as dialer.Dial -> Read -> Write alone comprise 12 allocs. // as dialer.Dial -> Read -> Write alone comprise 12 allocs.
@ -780,8 +722,12 @@ func BenchmarkFull(b *testing.B) {
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap) cfg := dnsCfg
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) cfg.Routes = map[string][]netaddr.IPPort{
".": {
netaddr.MustParseIPPort(server.PacketConn.LocalAddr().String()),
},
}
tests := []struct { tests := []struct {
name string name string

@ -913,7 +913,7 @@ func genLocalAddrFunc(addrs []netaddr.IPPrefix) func(netaddr.IP) bool {
return func(t netaddr.IP) bool { return m[t] } return func(t netaddr.IP) bool { return m[t] }
} }
func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, hosts map[string][]netaddr.IP, localDomains []string) error {
if routerCfg == nil { if routerCfg == nil {
panic("routerCfg must not be nil") panic("routerCfg must not be nil")
} }
@ -933,7 +933,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config)
e.mu.Unlock() e.mu.Unlock()
engineChanged := deepprint.UpdateHash(&e.lastEngineSigFull, cfg) engineChanged := deepprint.UpdateHash(&e.lastEngineSigFull, cfg)
routerChanged := deepprint.UpdateHash(&e.lastRouterSig, routerCfg) routerChanged := deepprint.UpdateHash(&e.lastRouterSig, routerCfg, hosts, localDomains)
if !engineChanged && !routerChanged { if !engineChanged && !routerChanged {
return ErrNoChanges return ErrNoChanges
} }
@ -979,20 +979,24 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config)
} }
if routerChanged { if routerChanged {
resolverCfg := resolver.Config{
Hosts: hosts,
LocalDomains: localDomains,
Routes: map[string][]netaddr.IPPort{},
}
if routerCfg.DNS.Proxied { if routerCfg.DNS.Proxied {
ips := routerCfg.DNS.Nameservers ips := routerCfg.DNS.Nameservers
upstreams := make([]net.Addr, len(ips)) upstreams := make([]netaddr.IPPort, len(ips))
for i, ip := range ips { for i, ip := range ips {
stdIP := ip.IPAddr() upstreams[i] = netaddr.IPPort{
upstreams[i] = &net.UDPAddr{ IP: ip,
IP: stdIP.IP,
Port: 53, Port: 53,
Zone: stdIP.Zone,
} }
} }
e.resolver.SetUpstreams(upstreams) resolverCfg.Routes["."] = upstreams
routerCfg.DNS.Nameservers = []netaddr.IP{tsaddr.TailscaleServiceIP()} routerCfg.DNS.Nameservers = []netaddr.IP{tsaddr.TailscaleServiceIP()}
} }
e.resolver.SetConfig(resolverCfg) // TODO: check error and propagate to health pkg
e.logf("wgengine: Reconfig: configuring router") e.logf("wgengine: Reconfig: configuring router")
err := e.router.Set(routerCfg) err := e.router.Set(routerCfg)
health.SetRouterHealth(err) health.SetRouterHealth(err)
@ -1018,10 +1022,6 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
e.tundev.SetFilter(filt) e.tundev.SetFilter(filt)
} }
func (e *userspaceEngine) SetDNSMap(dm *resolver.Map) {
e.resolver.SetMap(dm)
}
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) { func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()

@ -108,7 +108,7 @@ func TestUserspaceEngineReconfig(t *testing.T) {
}, },
} }
err = e.Reconfig(cfg, routerCfg) err = e.Reconfig(cfg, routerCfg, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -14,7 +14,6 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/dns/resolver"
"tailscale.com/net/tstun" "tailscale.com/net/tstun"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
@ -74,8 +73,8 @@ func (e *watchdogEngine) watchdog(name string, fn func()) {
}) })
} }
func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, hosts map[string][]netaddr.IP, localDomains []string) error {
return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, routerCfg) }) return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, routerCfg, hosts, localDomains) })
} }
func (e *watchdogEngine) GetLinkMonitor() *monitor.Mon { func (e *watchdogEngine) GetLinkMonitor() *monitor.Mon {
return e.wrap.GetLinkMonitor() return e.wrap.GetLinkMonitor()
@ -86,9 +85,6 @@ func (e *watchdogEngine) GetFilter() *filter.Filter {
func (e *watchdogEngine) SetFilter(filt *filter.Filter) { func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) }) e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
} }
func (e *watchdogEngine) SetDNSMap(dm *resolver.Map) {
e.watchdog("SetDNSMap", func() { e.wrap.SetDNSMap(dm) })
}
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) { func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) }) e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
} }

@ -9,7 +9,6 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/dns/resolver"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
@ -57,7 +56,7 @@ type Engine interface {
// sends an updated network map. // sends an updated network map.
// //
// The returned error is ErrNoChanges if no changes were made. // The returned error is ErrNoChanges if no changes were made.
Reconfig(*wgcfg.Config, *router.Config) error Reconfig(*wgcfg.Config, *router.Config, map[string][]netaddr.IP, []string) error
// GetFilter returns the current packet filter, if any. // GetFilter returns the current packet filter, if any.
GetFilter() *filter.Filter GetFilter() *filter.Filter
@ -65,9 +64,6 @@ type Engine interface {
// SetFilter updates the packet filter. // SetFilter updates the packet filter.
SetFilter(*filter.Filter) SetFilter(*filter.Filter)
// SetDNSMap updates the DNS map.
SetDNSMap(*resolver.Map)
// SetStatusCallback sets the function to call when the // SetStatusCallback sets the function to call when the
// WireGuard status changes. // WireGuard status changes.
SetStatusCallback(StatusCallback) SetStatusCallback(StatusCallback)

Loading…
Cancel
Save