tsdns: dual resolution mode, IPv6 support (#526)

This change adds to tsdns the ability to delegate lookups to upstream nameservers.
This is crucial for setting Magic DNS as the system resolver.

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
reviewable/pr534/r1
Dmytro Shynkevych 4 years ago committed by GitHub
parent ce1b52bb71
commit 67ebba90e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,6 +30,6 @@ require (
golang.org/x/sys v0.0.0-20200501052902-10377860bb8e golang.org/x/sys v0.0.0-20200501052902-10377860bb8e
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 golang.org/x/time v0.0.0-20191024005414-555d28b269f0
honnef.co/go/tools v0.0.1-2020.1.4 honnef.co/go/tools v0.0.1-2020.1.4
inet.af/netaddr v0.0.0-20200702150737-4591d218f82c inet.af/netaddr v0.0.0-20200706235120-1ac1a40fae99
rsc.io/goversion v1.2.0 rsc.io/goversion v1.2.0
) )

@ -160,7 +160,8 @@ gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
inet.af/netaddr v0.0.0-20200702150737-4591d218f82c h1:j3Z4HL4KcLBDU1kmRpXTD5fikKBqIkE+7vFKS5mCz3Y= inet.af v0.0.0-20181218191229-53da77bc832c h1:U3RoiyEF5b3Y1SVL6NNvpkgqUz2qS3a0OJh9kpSCN04=
inet.af/netaddr v0.0.0-20200702150737-4591d218f82c/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww= inet.af/netaddr v0.0.0-20200706235120-1ac1a40fae99 h1:+43CBpWlrXThaOxixPS5JXEJZC8zaMCpDu3aKffe0bs=
inet.af/netaddr v0.0.0-20200706235120-1ac1a40fae99/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww=
rsc.io/goversion v1.2.0 h1:SPn+NLTiAG7w30IRK/DKp1BjvpWabYgxlLp/+kx5J8w= rsc.io/goversion v1.2.0 h1:SPn+NLTiAG7w30IRK/DKp1BjvpWabYgxlLp/+kx5J8w=
rsc.io/goversion v1.2.0/go.mod h1:Eih9y/uIBS3ulggl7KNJ09xGSLcuNaLgmvvqa07sgfo= rsc.io/goversion v1.2.0/go.mod h1:Eih9y/uIBS3ulggl7KNJ09xGSLcuNaLgmvvqa07sgfo=

@ -467,7 +467,7 @@ func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) {
// Like PeerStatus.SimpleHostName() // Like PeerStatus.SimpleHostName()
domain = strings.TrimSuffix(domain, ".local") domain = strings.TrimSuffix(domain, ".local")
domain = strings.TrimSuffix(domain, ".localdomain") domain = strings.TrimSuffix(domain, ".localdomain")
domain = domain + ".ipn.dev" domain = domain + ".tailscale.us"
domainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr) domainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr)
} }
b.e.SetDNSMap(tsdns.NewMap(domainToIP)) b.e.SetDNSMap(tsdns.NewMap(domainToIP))

@ -7,6 +7,8 @@ package packet
import ( import (
"fmt" "fmt"
"net" "net"
"inet.af/netaddr"
) )
// IP is an IPv4 address. // IP is an IPv4 address.
@ -22,6 +24,17 @@ func NewIP(b net.IP) IP {
return IP(get32(b4)) return IP(get32(b4))
} }
// IPFromNetaddr converts a netaddr.IP to an IP.
func IPFromNetaddr(ip netaddr.IP) IP {
ipbytes := ip.As4()
return IP(get32(ipbytes[:]))
}
// Netaddr converts an IP to a netaddr.IP.
func (ip IP) Netaddr() netaddr.IP {
return netaddr.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
}
func (ip IP) String() string { func (ip IP) String() string {
return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
} }

@ -7,128 +7,319 @@
package tsdns package tsdns
import ( import (
"encoding/binary" "bytes"
"context"
"errors" "errors"
"strings"
"sync" "sync"
"time" "time"
dns "golang.org/x/net/dns/dnsmessage" dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/net/netns"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/wgengine/packet"
) )
// maxResponseSize is the maximum size of a response from a Resolver.
const maxResponseSize = 512
// queueSize is the maximal number of DNS requests that can be pending at a time.
// If EnqueueRequest is called when this many requests are already pending,
// the request will be dropped to avoid blocking the caller.
const queueSize = 8
// delegateTimeout is the maximal amount of time Resolver will wait
// for upstream nameservers to process a query.
const delegateTimeout = 5 * time.Second
// defaultTTL is the TTL of all responses from Resolver. // defaultTTL is the TTL of all responses from Resolver.
const defaultTTL = 600 * time.Second const defaultTTL = 600 * time.Second
// ErrClosed indicates that the resolver has been closed and readers should exit.
var ErrClosed = errors.New("closed")
var ( var (
errAllFailed = errors.New("all upstream nameservers failed")
errFullQueue = errors.New("request queue full")
errMapNotSet = errors.New("domain map not set") errMapNotSet = errors.New("domain map not set")
errNoSuchDomain = errors.New("domain does not exist")
errNotImplemented = errors.New("query type not implemented") errNotImplemented = errors.New("query type not implemented")
errNotOurName = errors.New("not an *.ipn.dev domain")
errNotOurQuery = errors.New("query not for this resolver")
errNotQuery = errors.New("not a DNS query") errNotQuery = errors.New("not a DNS query")
errSmallBuffer = errors.New("response buffer too small")
) )
var ( // Map is all the data Resolver needs to resolve DNS queries within the Tailscale network.
defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100}))
defaultPort = uint16(53)
)
// Map is all the data Resolver needs to resolve DNS queries.
type Map struct { type Map struct {
// domainToIP is a mapping of Tailscale domains to their IP addresses. // domainToIP is a mapping of Tailscale domains to their IP addresses.
// For example, monitoring.ipn.dev -> 100.64.0.1. // For example, monitoring.tailscale.us -> 100.64.0.1.
domainToIP map[string]netaddr.IP domainToIP map[string]netaddr.IP
} }
// NewMap returns a new Map with domain to address mapping given by domainToIP. // NewMap returns a new Map with domain to address mapping given by domainToIP.
// It takes ownership of the provided map.
func NewMap(domainToIP map[string]netaddr.IP) *Map { func NewMap(domainToIP map[string]netaddr.IP) *Map {
return &Map{ return &Map{domainToIP: domainToIP}
domainToIP: domainToIP,
}
} }
// Resolver is a DNS resolver for domain names of the form *.ipn.dev. // Packet represents a DNS payload together with the address of its origin.
type Packet struct {
// Payload is the application layer DNS payload.
// Resolver assumes ownership of the request payload when it is enqueued
// and cedes ownership of the response payload when it is returned from NextResponse.
Payload []byte
// Addr is the source address for a request and the destination address for a response.
Addr netaddr.IPPort
}
// Resolver is a DNS resolver for nodes on the Tailscale network,
// associating them with domain names of the form <mynode>.<mydomain>.<root>.
// If it is asked to resolve a domain that is not of that form,
// it delegates to upstream nameservers if any are set.
type Resolver struct { type Resolver struct {
logf logger.Logf logf logger.Logf
// ip is the IP on which the resolver is listening. // The asynchronous interface is due to the fact that resolution may potentially
ip packet.IP // block for a long time (if the upstream nameserver is slow to reach).
// port is the port on which the resolver is listening.
port uint16 // queue is a buffered channel holding DNS requests queued for resolution.
queue chan Packet
// responses is an unbuffered channel to which responses are sent.
responses chan Packet
// errors is an unbuffered channel to which errors are sent.
errors chan error
// closed notifies the poll goroutines to stop.
closed chan struct{}
// pollGroup signals when all poll goroutines have stopped.
pollGroup sync.WaitGroup
// rootDomain is <root> in <mynode>.<mydomain>.<root>.
rootDomain []byte
// dialer is the netns.Dialer used for delegation.
dialer netns.Dialer
// mu guards the following fields from being updated while used. // mu guards the following fields from being updated while used.
mu sync.Mutex mu sync.RWMutex
// dnsMap is the map most recently received from the control server. // dnsMap is the map most recently received from the control server.
dnsMap *Map dnsMap *Map
// nameservers is the list of nameserver addresses that should be used
// if the received query is not for a Tailscale node.
// The addresses are strings of the form ip:port, as expected by Dial.
nameservers []string
} }
// NewResolver constructs a resolver with default parameters. // NewResolver constructs a resolver associated with the given root domain.
func NewResolver(logf logger.Logf) *Resolver { func NewResolver(logf logger.Logf, rootDomain string) *Resolver {
r := &Resolver{ r := &Resolver{
logf: logf, logf: logger.WithPrefix(logf, "tsdns: "),
ip: defaultIP, queue: make(chan Packet, queueSize),
port: defaultPort, responses: make(chan Packet),
errors: make(chan error),
closed: make(chan struct{}),
// Conform to the name format dnsmessage uses (trailing period, bytes).
rootDomain: []byte(rootDomain + "."),
dialer: netns.NewDialer(),
} }
return r return r
} }
// AcceptsPacket determines if the given packet is func (r *Resolver) Start() {
// directed to this resolver (by ip and port). // TODO(dmytro): spawn more than one goroutine? They block on delegation.
// We also require that UDP be used to simplify things for now. r.pollGroup.Add(1)
func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool { go r.poll()
return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP }
// Close shuts down the resolver and ensures poll goroutines have exited.
// The Resolver cannot be used again after Close is called.
func (r *Resolver) Close() {
select {
case <-r.closed:
return
default:
// continue
}
close(r.closed)
r.pollGroup.Wait()
} }
// SetMap sets the resolver's DNS map. // SetMap sets the resolver's DNS map, taking ownership of it.
func (r *Resolver) SetMap(m *Map) { func (r *Resolver) SetMap(m *Map) {
r.mu.Lock() r.mu.Lock()
r.dnsMap = m r.dnsMap = m
r.mu.Unlock() r.mu.Unlock()
} }
// Resolve maps a given domain name to the IP address of the host that owns it. // SetUpstreamNameservers sets the addresses of the resolver's
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { // upstream nameservers, taking ownership of the argument.
// If not a subdomain of ipn.dev, then we must refuse this query. // The addresses should be strings of the form ip:port,
// We do this before checking the map to distinguish beween nonexistent domains // matching what Dial("udp", addr) expects as addr.
// and misdirected queries. func (r *Resolver) SetNameservers(nameservers []string) {
if !strings.HasSuffix(domain, ".ipn.dev") { r.mu.Lock()
return netaddr.IP{}, dns.RCodeRefused, errNotOurName r.nameservers = nameservers
r.mu.Unlock()
}
// EnqueueRequest places the given DNS request in the resolver's queue.
// It takes ownership of the payload and does not block.
// If the queue is full, the request will be dropped and an error will be returned.
func (r *Resolver) EnqueueRequest(request Packet) error {
select {
case r.queue <- request:
return nil
default:
return errFullQueue
} }
}
r.mu.Lock() // NextResponse returns a DNS response to a previously enqueued request.
// It blocks until a response is available and gives up ownership of the response payload.
func (r *Resolver) NextResponse() (Packet, error) {
select {
case resp := <-r.responses:
return resp, nil
case err := <-r.errors:
return Packet{}, err
case <-r.closed:
return Packet{}, ErrClosed
}
}
// Resolve maps a given domain name to the IP address of the host that owns it.
// The domain name must not have a trailing period.
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
r.mu.RLock()
if r.dnsMap == nil { if r.dnsMap == nil {
r.mu.Unlock() r.mu.RUnlock()
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
} }
addr, found := r.dnsMap.domainToIP[domain] addr, found := r.dnsMap.domainToIP[domain]
r.mu.Unlock() r.mu.RUnlock()
if !found { if !found {
return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain return netaddr.IP{}, dns.RCodeNameError, nil
} }
return addr, dns.RCodeSuccess, nil return addr, dns.RCodeSuccess, nil
} }
func (r *Resolver) poll() {
defer r.pollGroup.Done()
var (
packet Packet
err error
)
for {
select {
case packet = <-r.queue:
// continue
case <-r.closed:
return
}
packet.Payload, err = r.respond(packet.Payload)
if err != nil {
select {
case r.errors <- err:
// continue
case <-r.closed:
return
}
} else {
select {
case r.responses <- packet:
// continue
case <-r.closed:
return
}
}
}
}
// queryServer obtains a DNS response by querying the given server.
func (r *Resolver) queryServer(ctx context.Context, server string, query []byte) ([]byte, error) {
conn, err := r.dialer.DialContext(ctx, "udp", server)
if err != nil {
return nil, err
}
defer conn.Close()
// Interrupt the current operation when the context is cancelled.
go func() {
<-ctx.Done()
conn.SetDeadline(time.Unix(1, 0))
}()
_, err = conn.Write(query)
if err != nil {
return nil, err
}
out := make([]byte, maxResponseSize)
n, err := conn.Read(out)
if err != nil {
return nil, err
}
return out[:n], nil
}
// delegate forwards the query to all upstream nameservers and returns the first response.
func (r *Resolver) delegate(query []byte) ([]byte, error) {
r.mu.RLock()
nameservers := r.nameservers
r.mu.RUnlock()
if len(r.nameservers) == 0 {
return nil, errAllFailed
}
ctx, cancel := context.WithTimeout(context.Background(), delegateTimeout)
defer cancel()
// Common case, don't spawn goroutines.
if len(nameservers) == 1 {
return r.queryServer(ctx, nameservers[0], query)
}
datach := make(chan []byte)
for _, server := range nameservers {
go func(s string) {
resp, err := r.queryServer(ctx, s, query)
// Only print errors not due to cancelation after first response.
if err != nil && ctx.Err() != context.Canceled {
r.logf("querying %s: %v", s, err)
}
datach <- resp
}(server)
}
var response []byte
for range nameservers {
cur := <-datach
if cur != nil && response == nil {
// Received first successful response
response = cur
cancel()
}
}
if response == nil {
return nil, errAllFailed
}
return response, nil
}
type response struct { type response struct {
Header dns.Header Header dns.Header
ResourceHeader dns.ResourceHeader Question dns.Question
Question dns.Question Name string
// TODO(dmytro): support IPv6. IP netaddr.IP
IP netaddr.IP
} }
// parseQuery parses the query in given packet into a response struct. // parseQuery parses the query in given packet into a response struct.
func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error { func (r *Resolver) parseQuery(query []byte, resp *response) error {
var parser dns.Parser var parser dns.Parser
var err error var err error
resp.Header, err = parser.Start(query.Payload()) resp.Header, err = parser.Start(query)
if err != nil { if err != nil {
return err return err
} }
@ -145,146 +336,123 @@ func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error
return nil return nil
} }
// makeResponse resolves the question stored in resp and sets the answer fields. // marshalARecord serializes an A record into an active builder.
func (r *Resolver) makeResponse(resp *response) error {
var err error
name := resp.Question.Name.String()
if len(name) > 0 {
name = name[:len(name)-1]
}
if resp.Question.Type == dns.TypeA {
// Remove final dot from name: *.ipn.dev. -> *.ipn.dev
resp.IP, resp.Header.RCode, err = r.Resolve(name)
} else {
resp.Header.RCode = dns.RCodeNotImplemented
err = errNotImplemented
}
return err
}
// marshalAnswer serializes the answer record into an active builder.
// The caller may continue using the builder following the call. // The caller may continue using the builder following the call.
func marshalAnswer(resp *response, builder *dns.Builder) error { func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
var answer dns.AResource var answer dns.AResource
err := builder.StartAnswers()
if err != nil {
return err
}
answerHeader := dns.ResourceHeader{ answerHeader := dns.ResourceHeader{
Name: resp.Question.Name, Name: name,
Type: dns.TypeA, Type: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second), TTL: uint32(defaultTTL / time.Second),
} }
ip := resp.IP.As16() ipbytes := ip.As4()
copy(answer.A[:], ip[12:]) copy(answer.A[:], ipbytes[:])
return builder.AResource(answerHeader, answer) return builder.AResource(answerHeader, answer)
} }
// marshalResponse serializes the DNS response into an active builder. // marshalAAAARecord serializes an AAAA record into an active builder.
// The caller may continue using the builder following the call. // The caller may continue using the builder following the call.
func marshalResponse(resp *response, builder *dns.Builder) error { func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
err := builder.StartQuestions() var answer dns.AAAAResource
if err != nil {
return err
}
err = builder.Question(resp.Question)
if err != nil {
return err
}
if resp.Header.RCode == dns.RCodeSuccess { answerHeader := dns.ResourceHeader{
err = marshalAnswer(resp, builder) Name: name,
if err != nil { Type: dns.TypeAAAA,
return err Class: dns.ClassINET,
} TTL: uint32(defaultTTL / time.Second),
} }
ipbytes := ip.As16()
return nil copy(answer.AAAA[:], ipbytes[:])
return builder.AAAAResource(answerHeader, answer)
} }
// marshalReponsePacket marshals a full DNS packet (including headers) // marshalResponse serializes the DNS response into a new buffer.
// representing resp, which is a response to query, into buf. func marshalResponse(resp *response) ([]byte, error) {
// It returns buf trimmed to the length of the response packet.
func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) {
udpHeader := query.UDPHeader()
udpHeader.ToResponse()
offset := udpHeader.Len()
resp.Header.Response = true resp.Header.Response = true
resp.Header.Authoritative = true resp.Header.Authoritative = true
if resp.Header.RecursionDesired { if resp.Header.RecursionDesired {
resp.Header.RecursionAvailable = true resp.Header.RecursionAvailable = true
} }
// dns.Builder appends to the passed buffer (without reallocation when possible), builder := dns.NewBuilder(nil, resp.Header)
// so we pass in a zero-length slice starting at the point it should start writing.
builder := dns.NewBuilder(buf[offset:offset], resp.Header)
err := marshalResponse(resp, &builder) err := builder.StartQuestions()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// rbuf is the response slice with the correct length starting at offset. err = builder.Question(resp.Question)
rbuf, err := builder.Finish()
if err != nil { if err != nil {
return nil, err return nil, err
} }
end := offset + len(rbuf) // Only successful responses contain answers.
err = udpHeader.Marshal(buf[:end]) if resp.Header.RCode != dns.RCodeSuccess {
return builder.Finish()
}
err = builder.StartAnswers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return buf[:end], nil if resp.IP.Is4() {
} err = marshalARecord(resp.Question.Name, resp.IP, &builder)
} else {
// Respond writes a response to query into buf and returns buf trimmed to the response length. err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
// It is assumed that r.AcceptsPacket(query) is true.
func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) {
var resp response
var err error
// 0. Verify that contract is upheld.
if !r.AcceptsPacket(query) {
return nil, errNotOurQuery
} }
// A DNS response is at least as long as the query if err != nil {
if len(buf) < len(query.Buffer()) { return nil, err
return nil, errSmallBuffer
} }
// 1. Parse query packet. return builder.Finish()
err = r.parseQuery(query, &resp) }
// respond returns a DNS response to query.
func (r *Resolver) respond(query []byte) ([]byte, error) {
resp := new(response)
// ParseQuery is sufficiently fast to run on every DNS packet.
// This is considerably simpler than extracting the name by hand
// to shave off microseconds in case of delegation.
err := r.parseQuery(query, resp)
// We will not return this error: it is the sender's fault. // We will not return this error: it is the sender's fault.
if err != nil { if err != nil {
r.logf("tsdns: error during query parsing: %v", err) r.logf("parsing query: %v", err)
resp.Header.RCode = dns.RCodeFormatError resp.Header.RCode = dns.RCodeFormatError
return marshalResponsePacket(query, &resp, buf) return marshalResponse(resp)
} }
// 2. Service the query. // Delegate only when not a subdomain of rootDomain.
err = r.makeResponse(&resp) // We do this on bytes because Name.String() allocates.
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
if !bytes.HasSuffix(rawName, r.rootDomain) {
out, err := r.delegate(query)
if err != nil {
r.logf("delegating: %v", err)
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponse(resp)
}
return out, nil
}
switch resp.Question.Type {
case dns.TypeA, dns.TypeAAAA:
domain := resp.Question.Name.String()
// Strip off the trailing period.
// This is safe: Name is guaranteed to have a trailing period by construction.
domain = domain[:len(domain)-1]
resp.IP, resp.Header.RCode, err = r.Resolve(domain)
default:
resp.Header.RCode = dns.RCodeNotImplemented
err = errNotImplemented
}
// We will not return this error: it is the sender's fault. // We will not return this error: it is the sender's fault.
if err != nil { if err != nil {
r.logf("tsdns: error during name resolution: %v", err) r.logf("resolving: %v", err)
return marshalResponsePacket(query, &resp, buf)
}
// For now, we require IPv4 in all cases.
// If we somehow came up with a non-IPv4 address, it's our fault.
if !resp.IP.Is4() {
resp.Header.RCode = dns.RCodeServerFailure
r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP)
} }
// 3. Serialize the response. return marshalResponse(resp)
return marshalResponsePacket(query, &resp, buf)
} }

@ -6,113 +6,173 @@ package tsdns
import ( import (
"bytes" "bytes"
"errors"
"sync" "sync"
"testing" "testing"
dns "golang.org/x/net/dns/dnsmessage" dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/wgengine/packet"
) )
var test2bytes = [16]byte{
0x00, 0x01, 0x02, 0x03,
0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b,
0x0c, 0x0d, 0x0e, 0x0f,
}
var dnsMap = &Map{ var dnsMap = &Map{
domainToIP: map[string]netaddr.IP{ domainToIP: map[string]netaddr.IP{
"test1.ipn.dev": netaddr.IPv4(1, 2, 3, 4), "test1.ipn.dev": netaddr.IPv4(1, 2, 3, 4),
"test2.ipn.dev": netaddr.IPv4(5, 6, 7, 8), "test2.ipn.dev": netaddr.IPv6Raw(test2bytes),
}, },
} }
func dnspacket(srcip, dstip packet.IP, domain string, tp dns.Type, response bool) *packet.ParsedPacket { func dnspacket(domain string, tp dns.Type) []byte {
dnsHeader := dns.Header{Response: response} var dnsHeader dns.Header
question := dns.Question{ question := dns.Question{
Name: dns.MustNewName(domain), Name: dns.MustNewName(domain),
Type: tp, Type: tp,
Class: dns.ClassINET, Class: dns.ClassINET,
} }
udpHeader := &packet.UDPHeader{
IPHeader: packet.IPHeader{
SrcIP: srcip,
DstIP: dstip,
IPProto: packet.UDP,
},
SrcPort: 1234,
DstPort: 53,
}
builder := dns.NewBuilder(nil, dnsHeader) builder := dns.NewBuilder(nil, dnsHeader)
builder.StartQuestions() builder.StartQuestions()
builder.Question(question) builder.Question(question)
payload, _ := builder.Finish() payload, _ := builder.Finish()
buf := packet.Generate(udpHeader, payload) return payload
}
pp := new(packet.ParsedPacket) func extractipcode(response []byte) (netaddr.IP, dns.RCode, error) {
pp.Decode(buf) var ip netaddr.IP
var parser dns.Parser
return pp h, err := parser.Start(response)
} if err != nil {
return ip, 0, err
}
func TestAcceptsPacket(t *testing.T) { if !h.Response {
r := NewResolver(t.Logf) return ip, 0, errors.New("not a response")
r.SetMap(dnsMap) }
if h.RCode != dns.RCodeSuccess {
return ip, h.RCode, nil
}
src := packet.IP(0x64656667) // 100.101.102.103 err = parser.SkipAllQuestions()
dst := packet.IP(0x64646464) // 100.100.100.100 if err != nil {
tests := []struct { return ip, 0, err
name string
request *packet.ParsedPacket
want bool
}{
{"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), true},
{"invalid", dnspacket(dst, src, "test1.ipn.dev.", dns.TypeA, false), false},
} }
for _, tt := range tests { ah, err := parser.AnswerHeader()
t.Run(tt.name, func(t *testing.T) { if err != nil {
accepts := r.AcceptsPacket(tt.request) return ip, 0, err
if accepts != tt.want {
t.Errorf("accepts = %v; want %v", accepts, tt.want)
}
})
} }
switch ah.Type {
case dns.TypeA:
res, err := parser.AResource()
if err != nil {
return ip, 0, err
}
ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3])
case dns.TypeAAAA:
res, err := parser.AAAAResource()
if err != nil {
return ip, 0, err
}
ip = netaddr.IPv6Raw(res.AAAA)
default:
return ip, 0, errors.New("type not in {A, AAAA}")
}
return ip, h.RCode, nil
}
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
request := Packet{Payload: query}
r.EnqueueRequest(request)
resp, err := r.NextResponse()
return resp.Payload, err
} }
func TestResolve(t *testing.T) { func TestResolve(t *testing.T) {
r := NewResolver(t.Logf) r := NewResolver(t.Logf, "ipn.dev")
r.SetMap(dnsMap) r.SetMap(dnsMap)
r.Start()
tests := []struct { tests := []struct {
name string name string
domain string domain string
ip netaddr.IP ip netaddr.IP
code dns.RCode code dns.RCode
iserr bool
}{ }{
{"valid", "test1.ipn.dev", netaddr.IPv4(1, 2, 3, 4), dns.RCodeSuccess, false}, {"ipv4", "test1.ipn.dev", netaddr.IPv4(1, 2, 3, 4), dns.RCodeSuccess},
{"nxdomain", "test3.ipn.dev", netaddr.IP{}, dns.RCodeNameError, true}, {"ipv6", "test2.ipn.dev", netaddr.IPv6Raw(test2bytes), dns.RCodeSuccess},
{"not our domain", "google.com", netaddr.IP{}, dns.RCodeRefused, true}, {"nxdomain", "test3.ipn.dev", netaddr.IP{}, dns.RCodeNameError},
{"foreign domain", "google.com", netaddr.IP{}, dns.RCodeNameError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ip, code, err := r.Resolve(tt.domain) ip, code, err := r.Resolve(tt.domain)
if err != nil && !tt.iserr { if err != nil {
t.Errorf("err = %v; want nil", err) t.Errorf("err = %v; want nil", err)
} else if err == nil && tt.iserr {
t.Errorf("err = nil; want non-nil")
} }
if code != tt.code { if code != tt.code {
t.Errorf("code = %v; want %v", code, tt.code) t.Errorf("code = %v; want %v", code, tt.code)
} }
// Only check ip for non-err // Only check ip for non-err
if !tt.iserr && ip != tt.ip { if ip != tt.ip {
t.Errorf("ip = %v; want %v", ip, tt.ip)
}
})
}
}
func TestDelegate(t *testing.T) {
r := NewResolver(t.Logf, "ipn.dev")
r.SetNameservers([]string{"9.9.9.9:53", "[2620:fe::fe]:53"})
r.Start()
localhostv4, _ := netaddr.ParseIP("127.0.0.1")
localhostv6, _ := netaddr.ParseIP("::1")
tests := []struct {
name string
query []byte
ip netaddr.IP
code dns.RCode
}{
{"ipv4", dnspacket("localhost.", dns.TypeA), localhostv4, dns.RCodeSuccess},
{"ipv6", dnspacket("localhost.", dns.TypeAAAA), localhostv6, dns.RCodeSuccess},
{"nxdomain", dnspacket("invalid.invalid.", dns.TypeA), netaddr.IP{}, dns.RCodeNameError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := syncRespond(r, tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
}
ip, code, err := extractipcode(resp)
if err != nil {
t.Errorf("extract: err = %v; want nil (in %x)", err, resp)
return
}
if code != tt.code {
t.Errorf("code = %v; want %v", code, tt.code)
}
if ip != tt.ip {
t.Errorf("ip = %v; want %v", ip, tt.ip) t.Errorf("ip = %v; want %v", ip, tt.ip)
} }
}) })
} }
} }
func TestConcurrentSet(t *testing.T) { func TestConcurrentSetMap(t *testing.T) {
r := NewResolver(t.Logf) r := NewResolver(t.Logf, "ipn.dev")
r.Start()
// This is purely to ensure that Resolve does not race with SetMap. // This is purely to ensure that Resolve does not race with SetMap.
var wg sync.WaitGroup var wg sync.WaitGroup
@ -128,16 +188,26 @@ func TestConcurrentSet(t *testing.T) {
wg.Wait() wg.Wait()
} }
var validResponse = []byte{ func TestConcurrentSetNameservers(t *testing.T) {
// IP header r := NewResolver(t.Logf, "ipn.dev")
0x45, 0x00, 0x00, 0x58, 0xff, 0xff, 0x00, 0x00, 0x40, 0x11, 0xe7, 0x00, r.Start()
// Source IP packet := dnspacket("google.com.", dns.TypeA)
0x64, 0x64, 0x64, 0x64,
// Destination IP // This is purely to ensure that delegation does not race with SetNameservers.
0x64, 0x65, 0x66, 0x67, var wg sync.WaitGroup
// UDP header wg.Add(2)
0x00, 0x35, 0x04, 0xd2, 0x00, 0x44, 0x53, 0xdd, go func() {
// DNS payload defer wg.Done()
r.SetNameservers([]string{"9.9.9.9:53"})
}()
go func() {
defer wg.Done()
syncRespond(r, packet)
}()
wg.Wait()
}
var validIPv4Response = []byte{
0x00, 0x00, // transaction id: 0 0x00, 0x00, // transaction id: 0
0x84, 0x00, // flags: response, authoritative, no error 0x84, 0x00, // flags: response, authoritative, no error
0x00, 0x01, // one question 0x00, 0x01, // one question
@ -154,16 +224,25 @@ var validResponse = []byte{
0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
} }
var validIPv6Response = []byte{
0x00, 0x00, // transaction id: 0
0x84, 0x00, // flags: response, authoritative, no error
0x00, 0x01, // one question
0x00, 0x01, // one answer
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
// Question:
0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
// Answer:
0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
0x00, 0x00, 0x02, 0x58, // TTL: 600
0x00, 0x10, // length: 16 bytes
// AAAA: 0001:0203:0405:0607:0809:0A0B:0C0D:0E0F
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xb, 0xc, 0xd, 0xe, 0xf,
}
var nxdomainResponse = []byte{ var nxdomainResponse = []byte{
// IP header
0x45, 0x00, 0x00, 0x3b, 0xff, 0xff, 0x00, 0x00, 0x40, 0x11, 0xe7, 0x1d,
// Source IP
0x64, 0x64, 0x64, 0x64,
// Destination IP
0x64, 0x65, 0x66, 0x67,
// UDP header
0x00, 0x35, 0x04, 0xd2, 0x00, 0x27, 0x25, 0x33,
// DNS payload
0x00, 0x00, // transaction id: 0 0x00, 0x00, // transaction id: 0
0x84, 0x03, // flags: response, authoritative, error: nxdomain 0x84, 0x03, // flags: response, authoritative, error: nxdomain
0x00, 0x01, // one question 0x00, 0x01, // one question
@ -175,25 +254,24 @@ var nxdomainResponse = []byte{
} }
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
r := NewResolver(t.Logf) r := NewResolver(t.Logf, "ipn.dev")
r.SetMap(dnsMap) r.SetMap(dnsMap)
r.Start()
src := packet.IP(0x64656667) // 100.101.102.103
dst := packet.IP(0x64646464) // 100.100.100.100
// One full packet and one error packet // One full packet and one error packet
tests := []struct { tests := []struct {
name string name string
request *packet.ParsedPacket request []byte
response []byte response []byte
}{ }{
{"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), validResponse}, {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), validIPv4Response},
{"error", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false), nxdomainResponse}, {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), validIPv6Response},
{"error", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
buf := make([]byte, 512) response, err := syncRespond(r, tt.request)
response, err := r.Respond(tt.request, buf)
if err != nil { if err != nil {
t.Errorf("err = %v; want nil", err) t.Errorf("err = %v; want nil", err)
} }
@ -205,43 +283,41 @@ func TestFull(t *testing.T) {
} }
func TestAllocs(t *testing.T) { func TestAllocs(t *testing.T) {
r := NewResolver(t.Logf) r := NewResolver(t.Logf, "ipn.dev")
r.SetMap(dnsMap) r.SetMap(dnsMap)
r.Start()
src := packet.IP(0x64656667) // 100.101.102.103 // It is seemingly pointless to test allocs in the delegate path,
dst := packet.IP(0x64646464) // 100.100.100.100 // as dialer.Dial -> Read -> Write alone comprise 12 allocs.
query := dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false) query := dnspacket("test1.ipn.dev.", dns.TypeA)
buf := make([]byte, 512)
allocs := testing.AllocsPerRun(100, func() { allocs := testing.AllocsPerRun(100, func() {
r.Respond(query, buf) syncRespond(r, query)
}) })
if allocs > 0 { if allocs > 1 {
t.Errorf("allocs = %v; want 0", allocs) t.Errorf("allocs = %v; want 1", allocs)
} }
} }
func BenchmarkFull(b *testing.B) { func BenchmarkFull(b *testing.B) {
r := NewResolver(b.Logf) r := NewResolver(b.Logf, "ipn.dev")
r.SetMap(dnsMap) r.SetMap(dnsMap)
r.Start()
src := packet.IP(0x64656667) // 100.101.102.103
dst := packet.IP(0x64646464) // 100.100.100.100
// One full packet and one error packet // One full packet and one error packet
tests := []struct { tests := []struct {
name string name string
request *packet.ParsedPacket request []byte
}{ }{
{"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false)}, {"valid", dnspacket("test1.ipn.dev.", dns.TypeA)},
{"nxdomain", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false)}, {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA)},
} }
buf := make([]byte, 512)
for _, tt := range tests { for _, tt := range tests {
b.Run(tt.name, func(b *testing.B) { b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
r.Respond(tt.request, buf) syncRespond(r, tt.request)
} }
}) })
} }

@ -25,6 +25,7 @@ import (
"github.com/tailscale/wireguard-go/tun" "github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/wgcfg" "github.com/tailscale/wireguard-go/wgcfg"
"go4.org/mem" "go4.org/mem"
"inet.af/netaddr"
"tailscale.com/control/controlclient" "tailscale.com/control/controlclient"
"tailscale.com/internal/deepprint" "tailscale.com/internal/deepprint"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
@ -51,6 +52,11 @@ import (
// discovery. // discovery.
const minimalMTU = 1280 const minimalMTU = 1280
const (
magicDNSIP = 0x64646464 // 100.100.100.100
magicDNSPort = 53
)
type userspaceEngine struct { type userspaceEngine struct {
logf logger.Logf logf logger.Logf
reqCh chan struct{} reqCh chan struct{}
@ -100,7 +106,7 @@ type EngineConfig struct {
// EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers // EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers
// will be intercepted and responded to, regardless of the source host. // will be intercepted and responded to, regardless of the source host.
EchoRespondToAll bool EchoRespondToAll bool
// UseTailscaleDNS determines whether DNS requests for names of the form *.ipn.dev // UseTailscaleDNS determines whether DNS requests for names of the form <mynode>.<mydomain>.<root>
// directed to the designated Taislcale DNS address (see wgengine/tsdns) // directed to the designated Taislcale DNS address (see wgengine/tsdns)
// will be intercepted and resolved by a tsdns.Resolver. // will be intercepted and resolved by a tsdns.Resolver.
UseTailscaleDNS bool UseTailscaleDNS bool
@ -174,7 +180,7 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
reqCh: make(chan struct{}, 1), reqCh: make(chan struct{}, 1),
waitCh: make(chan struct{}), waitCh: make(chan struct{}),
tundev: tstun.WrapTUN(logf, conf.TUN), tundev: tstun.WrapTUN(logf, conf.TUN),
resolver: tsdns.NewResolver(logf), resolver: tsdns.NewResolver(logf, "tailscale.us"),
useTailscaleDNS: conf.UseTailscaleDNS, useTailscaleDNS: conf.UseTailscaleDNS,
pingers: make(map[wgcfg.Key]*pinger), pingers: make(map[wgcfg.Key]*pinger),
} }
@ -308,6 +314,9 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
e.linkMon.Start() e.linkMon.Start()
e.magicConn.Start() e.magicConn.Start()
e.resolver.Start()
go e.pollResolver()
return e, nil return e, nil
} }
@ -360,22 +369,52 @@ func (e *userspaceEngine) isLocalAddr(ip packet.IP) bool {
// handleDNS is an outbound pre-filter resolving Tailscale domains. // handleDNS is an outbound pre-filter resolving Tailscale domains.
func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response { func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
if e.resolver.AcceptsPacket(p) { if p.DstIP == magicDNSIP && p.DstPort == magicDNSPort && p.IPProto == packet.UDP {
// TODO(dmytro): avoid this allocation without having tsdns know tstun quirks. request := tsdns.Packet{
buf := make([]byte, tstun.MaxPacketSize) Payload: p.Payload(),
offset := tstun.PacketStartOffset Addr: netaddr.IPPort{IP: p.SrcIP.Netaddr(), Port: p.SrcPort},
response, err := e.resolver.Respond(p, buf[offset:]) }
err := e.resolver.EnqueueRequest(request)
if err != nil { if err != nil {
e.logf("DNS resolver error: %v", err) e.logf("tsdns: enqueue: %v", err)
} else {
t.InjectInboundDirect(buf[:offset+len(response)], offset)
} }
// We already handled it, stop.
return filter.Drop return filter.Drop
} }
return filter.Accept return filter.Accept
} }
// pollResolver reads responses from the DNS resolver and injects them inbound.
func (e *userspaceEngine) pollResolver() {
for {
resp, err := e.resolver.NextResponse()
if err == tsdns.ErrClosed {
return
}
if err != nil {
e.logf("tsdns: error: %v", err)
continue
}
h := packet.UDPHeader{
IPHeader: packet.IPHeader{
SrcIP: packet.IP(magicDNSIP),
DstIP: packet.IPFromNetaddr(resp.Addr.IP),
},
SrcPort: magicDNSPort,
DstPort: resp.Addr.Port,
}
hlen := h.Len()
// TODO(dmytro): avoid this allocation without importing tstun quirks into tsdns.
const offset = tstun.PacketStartOffset
buf := make([]byte, offset+hlen+len(resp.Payload))
copy(buf[offset+hlen:], resp.Payload)
h.Marshal(buf[offset:])
e.tundev.InjectInboundDirect(buf, offset)
}
}
// pinger sends ping packets for a few seconds. // pinger sends ping packets for a few seconds.
// //
// These generated packets are used to ensure we trigger the spray logic in // These generated packets are used to ensure we trigger the spray logic in
@ -759,6 +798,7 @@ func (e *userspaceEngine) Close() {
r := bufio.NewReader(strings.NewReader("")) r := bufio.NewReader(strings.NewReader(""))
e.wgdev.IpcSetOperation(r) e.wgdev.IpcSetOperation(r)
e.resolver.Close()
e.magicConn.Close() e.magicConn.Close()
e.linkMon.Close() e.linkMon.Close()
e.router.Close() e.router.Close()

Loading…
Cancel
Save