From 67ebba90e1da096c52d335441bbbab35eff0a9d1 Mon Sep 17 00:00:00 2001 From: Dmytro Shynkevych Date: Tue, 7 Jul 2020 15:25:32 -0400 Subject: [PATCH] tsdns: dual resolution mode, IPv6 support (#526) This change adds to tsdns the ability to delegate lookups to upstream nameservers. This is crucial for setting Magic DNS as the system resolver. Signed-off-by: Dmytro Shynkevych --- go.mod | 2 +- go.sum | 5 +- ipn/local.go | 2 +- wgengine/packet/ip.go | 13 + wgengine/tsdns/tsdns.go | 464 ++++++++++++++++++++++++----------- wgengine/tsdns/tsdns_test.go | 258 ++++++++++++------- wgengine/userspace.go | 62 ++++- 7 files changed, 552 insertions(+), 254 deletions(-) diff --git a/go.mod b/go.mod index f14884bd4..f979b503d 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,6 @@ require ( golang.org/x/sys v0.0.0-20200501052902-10377860bb8e golang.org/x/time v0.0.0-20191024005414-555d28b269f0 honnef.co/go/tools v0.0.1-2020.1.4 - inet.af/netaddr v0.0.0-20200702150737-4591d218f82c + inet.af/netaddr v0.0.0-20200706235120-1ac1a40fae99 rsc.io/goversion v1.2.0 ) diff --git a/go.sum b/go.sum index 9c26b5cdb..87a7c480b 100644 --- a/go.sum +++ b/go.sum @@ -160,7 +160,8 @@ gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -inet.af/netaddr v0.0.0-20200702150737-4591d218f82c h1:j3Z4HL4KcLBDU1kmRpXTD5fikKBqIkE+7vFKS5mCz3Y= -inet.af/netaddr v0.0.0-20200702150737-4591d218f82c/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww= +inet.af v0.0.0-20181218191229-53da77bc832c h1:U3RoiyEF5b3Y1SVL6NNvpkgqUz2qS3a0OJh9kpSCN04= +inet.af/netaddr v0.0.0-20200706235120-1ac1a40fae99 h1:+43CBpWlrXThaOxixPS5JXEJZC8zaMCpDu3aKffe0bs= +inet.af/netaddr v0.0.0-20200706235120-1ac1a40fae99/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww= rsc.io/goversion v1.2.0 h1:SPn+NLTiAG7w30IRK/DKp1BjvpWabYgxlLp/+kx5J8w= rsc.io/goversion v1.2.0/go.mod h1:Eih9y/uIBS3ulggl7KNJ09xGSLcuNaLgmvvqa07sgfo= diff --git a/ipn/local.go b/ipn/local.go index 6571d84cc..9e3daa6bd 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -467,7 +467,7 @@ func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) { // Like PeerStatus.SimpleHostName() domain = strings.TrimSuffix(domain, ".local") domain = strings.TrimSuffix(domain, ".localdomain") - domain = domain + ".ipn.dev" + domain = domain + ".tailscale.us" domainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr) } b.e.SetDNSMap(tsdns.NewMap(domainToIP)) diff --git a/wgengine/packet/ip.go b/wgengine/packet/ip.go index 71bbb3cb6..487f0fb2b 100644 --- a/wgengine/packet/ip.go +++ b/wgengine/packet/ip.go @@ -7,6 +7,8 @@ package packet import ( "fmt" "net" + + "inet.af/netaddr" ) // IP is an IPv4 address. @@ -22,6 +24,17 @@ func NewIP(b net.IP) IP { return IP(get32(b4)) } +// IPFromNetaddr converts a netaddr.IP to an IP. +func IPFromNetaddr(ip netaddr.IP) IP { + ipbytes := ip.As4() + return IP(get32(ipbytes[:])) +} + +// Netaddr converts an IP to a netaddr.IP. +func (ip IP) Netaddr() netaddr.IP { + return netaddr.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) +} + func (ip IP) String() string { return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) } diff --git a/wgengine/tsdns/tsdns.go b/wgengine/tsdns/tsdns.go index 71b425971..7f4b57c74 100644 --- a/wgengine/tsdns/tsdns.go +++ b/wgengine/tsdns/tsdns.go @@ -7,128 +7,319 @@ package tsdns import ( - "encoding/binary" + "bytes" + "context" "errors" - "strings" "sync" "time" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" + "tailscale.com/net/netns" "tailscale.com/types/logger" - "tailscale.com/wgengine/packet" ) +// maxResponseSize is the maximum size of a response from a Resolver. +const maxResponseSize = 512 + +// queueSize is the maximal number of DNS requests that can be pending at a time. +// If EnqueueRequest is called when this many requests are already pending, +// the request will be dropped to avoid blocking the caller. +const queueSize = 8 + +// delegateTimeout is the maximal amount of time Resolver will wait +// for upstream nameservers to process a query. +const delegateTimeout = 5 * time.Second + // defaultTTL is the TTL of all responses from Resolver. const defaultTTL = 600 * time.Second +// ErrClosed indicates that the resolver has been closed and readers should exit. +var ErrClosed = errors.New("closed") + var ( + errAllFailed = errors.New("all upstream nameservers failed") + errFullQueue = errors.New("request queue full") errMapNotSet = errors.New("domain map not set") - errNoSuchDomain = errors.New("domain does not exist") errNotImplemented = errors.New("query type not implemented") - errNotOurName = errors.New("not an *.ipn.dev domain") - errNotOurQuery = errors.New("query not for this resolver") errNotQuery = errors.New("not a DNS query") - errSmallBuffer = errors.New("response buffer too small") ) -var ( - defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100})) - defaultPort = uint16(53) -) - -// Map is all the data Resolver needs to resolve DNS queries. +// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network. type Map struct { // domainToIP is a mapping of Tailscale domains to their IP addresses. - // For example, monitoring.ipn.dev -> 100.64.0.1. + // For example, monitoring.tailscale.us -> 100.64.0.1. domainToIP map[string]netaddr.IP } // NewMap returns a new Map with domain to address mapping given by domainToIP. -// It takes ownership of the provided map. func NewMap(domainToIP map[string]netaddr.IP) *Map { - return &Map{ - domainToIP: domainToIP, - } + return &Map{domainToIP: domainToIP} } -// Resolver is a DNS resolver for domain names of the form *.ipn.dev. +// Packet represents a DNS payload together with the address of its origin. +type Packet struct { + // Payload is the application layer DNS payload. + // Resolver assumes ownership of the request payload when it is enqueued + // and cedes ownership of the response payload when it is returned from NextResponse. + Payload []byte + // Addr is the source address for a request and the destination address for a response. + Addr netaddr.IPPort +} + +// Resolver is a DNS resolver for nodes on the Tailscale network, +// associating them with domain names of the form ... +// If it is asked to resolve a domain that is not of that form, +// it delegates to upstream nameservers if any are set. type Resolver struct { logf logger.Logf - // ip is the IP on which the resolver is listening. - ip packet.IP - // port is the port on which the resolver is listening. - port uint16 + // The asynchronous interface is due to the fact that resolution may potentially + // block for a long time (if the upstream nameserver is slow to reach). + + // queue is a buffered channel holding DNS requests queued for resolution. + queue chan Packet + // responses is an unbuffered channel to which responses are sent. + responses chan Packet + // errors is an unbuffered channel to which errors are sent. + errors chan error + // closed notifies the poll goroutines to stop. + closed chan struct{} + // pollGroup signals when all poll goroutines have stopped. + pollGroup sync.WaitGroup + + // rootDomain is in ... + rootDomain []byte + + // dialer is the netns.Dialer used for delegation. + dialer netns.Dialer // mu guards the following fields from being updated while used. - mu sync.Mutex + mu sync.RWMutex // dnsMap is the map most recently received from the control server. dnsMap *Map + // nameservers is the list of nameserver addresses that should be used + // if the received query is not for a Tailscale node. + // The addresses are strings of the form ip:port, as expected by Dial. + nameservers []string } -// NewResolver constructs a resolver with default parameters. -func NewResolver(logf logger.Logf) *Resolver { +// NewResolver constructs a resolver associated with the given root domain. +func NewResolver(logf logger.Logf, rootDomain string) *Resolver { r := &Resolver{ - logf: logf, - ip: defaultIP, - port: defaultPort, + logf: logger.WithPrefix(logf, "tsdns: "), + queue: make(chan Packet, queueSize), + responses: make(chan Packet), + errors: make(chan error), + closed: make(chan struct{}), + // Conform to the name format dnsmessage uses (trailing period, bytes). + rootDomain: []byte(rootDomain + "."), + dialer: netns.NewDialer(), } return r } -// AcceptsPacket determines if the given packet is -// directed to this resolver (by ip and port). -// We also require that UDP be used to simplify things for now. -func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool { - return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP +func (r *Resolver) Start() { + // TODO(dmytro): spawn more than one goroutine? They block on delegation. + r.pollGroup.Add(1) + go r.poll() +} + +// Close shuts down the resolver and ensures poll goroutines have exited. +// The Resolver cannot be used again after Close is called. +func (r *Resolver) Close() { + select { + case <-r.closed: + return + default: + // continue + } + close(r.closed) + r.pollGroup.Wait() } -// SetMap sets the resolver's DNS map. +// SetMap sets the resolver's DNS map, taking ownership of it. func (r *Resolver) SetMap(m *Map) { r.mu.Lock() r.dnsMap = m r.mu.Unlock() } -// Resolve maps a given domain name to the IP address of the host that owns it. -func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { - // If not a subdomain of ipn.dev, then we must refuse this query. - // We do this before checking the map to distinguish beween nonexistent domains - // and misdirected queries. - if !strings.HasSuffix(domain, ".ipn.dev") { - return netaddr.IP{}, dns.RCodeRefused, errNotOurName +// SetUpstreamNameservers sets the addresses of the resolver's +// upstream nameservers, taking ownership of the argument. +// The addresses should be strings of the form ip:port, +// matching what Dial("udp", addr) expects as addr. +func (r *Resolver) SetNameservers(nameservers []string) { + r.mu.Lock() + r.nameservers = nameservers + r.mu.Unlock() +} + +// EnqueueRequest places the given DNS request in the resolver's queue. +// It takes ownership of the payload and does not block. +// If the queue is full, the request will be dropped and an error will be returned. +func (r *Resolver) EnqueueRequest(request Packet) error { + select { + case r.queue <- request: + return nil + default: + return errFullQueue } +} - r.mu.Lock() +// NextResponse returns a DNS response to a previously enqueued request. +// It blocks until a response is available and gives up ownership of the response payload. +func (r *Resolver) NextResponse() (Packet, error) { + select { + case resp := <-r.responses: + return resp, nil + case err := <-r.errors: + return Packet{}, err + case <-r.closed: + return Packet{}, ErrClosed + } +} + +// Resolve maps a given domain name to the IP address of the host that owns it. +// The domain name must not have a trailing period. +func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { + r.mu.RLock() if r.dnsMap == nil { - r.mu.Unlock() + r.mu.RUnlock() return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet } addr, found := r.dnsMap.domainToIP[domain] - r.mu.Unlock() + r.mu.RUnlock() if !found { - return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain + return netaddr.IP{}, dns.RCodeNameError, nil } return addr, dns.RCodeSuccess, nil } +func (r *Resolver) poll() { + defer r.pollGroup.Done() + + var ( + packet Packet + err error + ) + for { + select { + case packet = <-r.queue: + // continue + case <-r.closed: + return + } + + packet.Payload, err = r.respond(packet.Payload) + if err != nil { + select { + case r.errors <- err: + // continue + case <-r.closed: + return + } + } else { + select { + case r.responses <- packet: + // continue + case <-r.closed: + return + } + } + } +} + +// queryServer obtains a DNS response by querying the given server. +func (r *Resolver) queryServer(ctx context.Context, server string, query []byte) ([]byte, error) { + conn, err := r.dialer.DialContext(ctx, "udp", server) + if err != nil { + return nil, err + } + defer conn.Close() + + // Interrupt the current operation when the context is cancelled. + go func() { + <-ctx.Done() + conn.SetDeadline(time.Unix(1, 0)) + }() + + _, err = conn.Write(query) + if err != nil { + return nil, err + } + + out := make([]byte, maxResponseSize) + n, err := conn.Read(out) + if err != nil { + return nil, err + } + + return out[:n], nil +} + +// delegate forwards the query to all upstream nameservers and returns the first response. +func (r *Resolver) delegate(query []byte) ([]byte, error) { + r.mu.RLock() + nameservers := r.nameservers + r.mu.RUnlock() + + if len(r.nameservers) == 0 { + return nil, errAllFailed + } + + ctx, cancel := context.WithTimeout(context.Background(), delegateTimeout) + defer cancel() + + // Common case, don't spawn goroutines. + if len(nameservers) == 1 { + return r.queryServer(ctx, nameservers[0], query) + } + + datach := make(chan []byte) + for _, server := range nameservers { + go func(s string) { + resp, err := r.queryServer(ctx, s, query) + // Only print errors not due to cancelation after first response. + if err != nil && ctx.Err() != context.Canceled { + r.logf("querying %s: %v", s, err) + } + + datach <- resp + }(server) + } + + var response []byte + for range nameservers { + cur := <-datach + if cur != nil && response == nil { + // Received first successful response + response = cur + cancel() + } + } + + if response == nil { + return nil, errAllFailed + } + return response, nil +} + type response struct { - Header dns.Header - ResourceHeader dns.ResourceHeader - Question dns.Question - // TODO(dmytro): support IPv6. - IP netaddr.IP + Header dns.Header + Question dns.Question + Name string + IP netaddr.IP } // parseQuery parses the query in given packet into a response struct. -func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error { +func (r *Resolver) parseQuery(query []byte, resp *response) error { var parser dns.Parser var err error - resp.Header, err = parser.Start(query.Payload()) + resp.Header, err = parser.Start(query) if err != nil { return err } @@ -145,146 +336,123 @@ func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error return nil } -// makeResponse resolves the question stored in resp and sets the answer fields. -func (r *Resolver) makeResponse(resp *response) error { - var err error - - name := resp.Question.Name.String() - if len(name) > 0 { - name = name[:len(name)-1] - } - - if resp.Question.Type == dns.TypeA { - // Remove final dot from name: *.ipn.dev. -> *.ipn.dev - resp.IP, resp.Header.RCode, err = r.Resolve(name) - } else { - resp.Header.RCode = dns.RCodeNotImplemented - err = errNotImplemented - } - - return err -} - -// marshalAnswer serializes the answer record into an active builder. +// marshalARecord serializes an A record into an active builder. // The caller may continue using the builder following the call. -func marshalAnswer(resp *response, builder *dns.Builder) error { +func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error { var answer dns.AResource - err := builder.StartAnswers() - if err != nil { - return err - } - answerHeader := dns.ResourceHeader{ - Name: resp.Question.Name, + Name: name, Type: dns.TypeA, Class: dns.ClassINET, TTL: uint32(defaultTTL / time.Second), } - ip := resp.IP.As16() - copy(answer.A[:], ip[12:]) + ipbytes := ip.As4() + copy(answer.A[:], ipbytes[:]) return builder.AResource(answerHeader, answer) } -// marshalResponse serializes the DNS response into an active builder. +// marshalAAAARecord serializes an AAAA record into an active builder. // The caller may continue using the builder following the call. -func marshalResponse(resp *response, builder *dns.Builder) error { - err := builder.StartQuestions() - if err != nil { - return err - } - - err = builder.Question(resp.Question) - if err != nil { - return err - } +func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error { + var answer dns.AAAAResource - if resp.Header.RCode == dns.RCodeSuccess { - err = marshalAnswer(resp, builder) - if err != nil { - return err - } + answerHeader := dns.ResourceHeader{ + Name: name, + Type: dns.TypeAAAA, + Class: dns.ClassINET, + TTL: uint32(defaultTTL / time.Second), } - - return nil + ipbytes := ip.As16() + copy(answer.AAAA[:], ipbytes[:]) + return builder.AAAAResource(answerHeader, answer) } -// marshalReponsePacket marshals a full DNS packet (including headers) -// representing resp, which is a response to query, into buf. -// It returns buf trimmed to the length of the response packet. -func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) { - udpHeader := query.UDPHeader() - udpHeader.ToResponse() - offset := udpHeader.Len() - +// marshalResponse serializes the DNS response into a new buffer. +func marshalResponse(resp *response) ([]byte, error) { resp.Header.Response = true resp.Header.Authoritative = true if resp.Header.RecursionDesired { resp.Header.RecursionAvailable = true } - // dns.Builder appends to the passed buffer (without reallocation when possible), - // so we pass in a zero-length slice starting at the point it should start writing. - builder := dns.NewBuilder(buf[offset:offset], resp.Header) + builder := dns.NewBuilder(nil, resp.Header) - err := marshalResponse(resp, &builder) + err := builder.StartQuestions() if err != nil { return nil, err } - // rbuf is the response slice with the correct length starting at offset. - rbuf, err := builder.Finish() + err = builder.Question(resp.Question) if err != nil { return nil, err } - end := offset + len(rbuf) - err = udpHeader.Marshal(buf[:end]) + // Only successful responses contain answers. + if resp.Header.RCode != dns.RCodeSuccess { + return builder.Finish() + } + + err = builder.StartAnswers() if err != nil { return nil, err } - return buf[:end], nil -} - -// Respond writes a response to query into buf and returns buf trimmed to the response length. -// It is assumed that r.AcceptsPacket(query) is true. -func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) { - var resp response - var err error - - // 0. Verify that contract is upheld. - if !r.AcceptsPacket(query) { - return nil, errNotOurQuery + if resp.IP.Is4() { + err = marshalARecord(resp.Question.Name, resp.IP, &builder) + } else { + err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder) } - // A DNS response is at least as long as the query - if len(buf) < len(query.Buffer()) { - return nil, errSmallBuffer + if err != nil { + return nil, err } - // 1. Parse query packet. - err = r.parseQuery(query, &resp) + return builder.Finish() +} + +// respond returns a DNS response to query. +func (r *Resolver) respond(query []byte) ([]byte, error) { + resp := new(response) + + // ParseQuery is sufficiently fast to run on every DNS packet. + // This is considerably simpler than extracting the name by hand + // to shave off microseconds in case of delegation. + err := r.parseQuery(query, resp) // We will not return this error: it is the sender's fault. if err != nil { - r.logf("tsdns: error during query parsing: %v", err) + r.logf("parsing query: %v", err) resp.Header.RCode = dns.RCodeFormatError - return marshalResponsePacket(query, &resp, buf) + return marshalResponse(resp) } - // 2. Service the query. - err = r.makeResponse(&resp) + // Delegate only when not a subdomain of rootDomain. + // We do this on bytes because Name.String() allocates. + rawName := resp.Question.Name.Data[:resp.Question.Name.Length] + if !bytes.HasSuffix(rawName, r.rootDomain) { + out, err := r.delegate(query) + if err != nil { + r.logf("delegating: %v", err) + resp.Header.RCode = dns.RCodeServerFailure + return marshalResponse(resp) + } + return out, nil + } + + switch resp.Question.Type { + case dns.TypeA, dns.TypeAAAA: + domain := resp.Question.Name.String() + // Strip off the trailing period. + // This is safe: Name is guaranteed to have a trailing period by construction. + domain = domain[:len(domain)-1] + resp.IP, resp.Header.RCode, err = r.Resolve(domain) + default: + resp.Header.RCode = dns.RCodeNotImplemented + err = errNotImplemented + } // We will not return this error: it is the sender's fault. if err != nil { - r.logf("tsdns: error during name resolution: %v", err) - return marshalResponsePacket(query, &resp, buf) - } - // For now, we require IPv4 in all cases. - // If we somehow came up with a non-IPv4 address, it's our fault. - if !resp.IP.Is4() { - resp.Header.RCode = dns.RCodeServerFailure - r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP) + r.logf("resolving: %v", err) } - // 3. Serialize the response. - return marshalResponsePacket(query, &resp, buf) + return marshalResponse(resp) } diff --git a/wgengine/tsdns/tsdns_test.go b/wgengine/tsdns/tsdns_test.go index 3ad8a449c..b487a582c 100644 --- a/wgengine/tsdns/tsdns_test.go +++ b/wgengine/tsdns/tsdns_test.go @@ -6,113 +6,173 @@ package tsdns import ( "bytes" + "errors" "sync" "testing" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" - "tailscale.com/wgengine/packet" ) +var test2bytes = [16]byte{ + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, +} + var dnsMap = &Map{ domainToIP: map[string]netaddr.IP{ "test1.ipn.dev": netaddr.IPv4(1, 2, 3, 4), - "test2.ipn.dev": netaddr.IPv4(5, 6, 7, 8), + "test2.ipn.dev": netaddr.IPv6Raw(test2bytes), }, } -func dnspacket(srcip, dstip packet.IP, domain string, tp dns.Type, response bool) *packet.ParsedPacket { - dnsHeader := dns.Header{Response: response} +func dnspacket(domain string, tp dns.Type) []byte { + var dnsHeader dns.Header question := dns.Question{ Name: dns.MustNewName(domain), Type: tp, Class: dns.ClassINET, } - udpHeader := &packet.UDPHeader{ - IPHeader: packet.IPHeader{ - SrcIP: srcip, - DstIP: dstip, - IPProto: packet.UDP, - }, - SrcPort: 1234, - DstPort: 53, - } builder := dns.NewBuilder(nil, dnsHeader) builder.StartQuestions() builder.Question(question) payload, _ := builder.Finish() - buf := packet.Generate(udpHeader, payload) + return payload +} - pp := new(packet.ParsedPacket) - pp.Decode(buf) +func extractipcode(response []byte) (netaddr.IP, dns.RCode, error) { + var ip netaddr.IP + var parser dns.Parser - return pp -} + h, err := parser.Start(response) + if err != nil { + return ip, 0, err + } -func TestAcceptsPacket(t *testing.T) { - r := NewResolver(t.Logf) - r.SetMap(dnsMap) + if !h.Response { + return ip, 0, errors.New("not a response") + } + if h.RCode != dns.RCodeSuccess { + return ip, h.RCode, nil + } - src := packet.IP(0x64656667) // 100.101.102.103 - dst := packet.IP(0x64646464) // 100.100.100.100 - tests := []struct { - name string - request *packet.ParsedPacket - want bool - }{ - {"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), true}, - {"invalid", dnspacket(dst, src, "test1.ipn.dev.", dns.TypeA, false), false}, + err = parser.SkipAllQuestions() + if err != nil { + return ip, 0, err } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - accepts := r.AcceptsPacket(tt.request) - if accepts != tt.want { - t.Errorf("accepts = %v; want %v", accepts, tt.want) - } - }) + ah, err := parser.AnswerHeader() + if err != nil { + return ip, 0, err } + switch ah.Type { + case dns.TypeA: + res, err := parser.AResource() + if err != nil { + return ip, 0, err + } + ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) + case dns.TypeAAAA: + res, err := parser.AAAAResource() + if err != nil { + return ip, 0, err + } + ip = netaddr.IPv6Raw(res.AAAA) + default: + return ip, 0, errors.New("type not in {A, AAAA}") + } + + return ip, h.RCode, nil +} + +func syncRespond(r *Resolver, query []byte) ([]byte, error) { + request := Packet{Payload: query} + r.EnqueueRequest(request) + resp, err := r.NextResponse() + return resp.Payload, err } func TestResolve(t *testing.T) { - r := NewResolver(t.Logf) + r := NewResolver(t.Logf, "ipn.dev") r.SetMap(dnsMap) + r.Start() tests := []struct { name string domain string ip netaddr.IP code dns.RCode - iserr bool }{ - {"valid", "test1.ipn.dev", netaddr.IPv4(1, 2, 3, 4), dns.RCodeSuccess, false}, - {"nxdomain", "test3.ipn.dev", netaddr.IP{}, dns.RCodeNameError, true}, - {"not our domain", "google.com", netaddr.IP{}, dns.RCodeRefused, true}, + {"ipv4", "test1.ipn.dev", netaddr.IPv4(1, 2, 3, 4), dns.RCodeSuccess}, + {"ipv6", "test2.ipn.dev", netaddr.IPv6Raw(test2bytes), dns.RCodeSuccess}, + {"nxdomain", "test3.ipn.dev", netaddr.IP{}, dns.RCodeNameError}, + {"foreign domain", "google.com", netaddr.IP{}, dns.RCodeNameError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ip, code, err := r.Resolve(tt.domain) - if err != nil && !tt.iserr { + if err != nil { t.Errorf("err = %v; want nil", err) - } else if err == nil && tt.iserr { - t.Errorf("err = nil; want non-nil") } if code != tt.code { t.Errorf("code = %v; want %v", code, tt.code) } // Only check ip for non-err - if !tt.iserr && ip != tt.ip { + if ip != tt.ip { + t.Errorf("ip = %v; want %v", ip, tt.ip) + } + }) + } +} + +func TestDelegate(t *testing.T) { + r := NewResolver(t.Logf, "ipn.dev") + r.SetNameservers([]string{"9.9.9.9:53", "[2620:fe::fe]:53"}) + r.Start() + + localhostv4, _ := netaddr.ParseIP("127.0.0.1") + localhostv6, _ := netaddr.ParseIP("::1") + tests := []struct { + name string + query []byte + ip netaddr.IP + code dns.RCode + }{ + {"ipv4", dnspacket("localhost.", dns.TypeA), localhostv4, dns.RCodeSuccess}, + {"ipv6", dnspacket("localhost.", dns.TypeAAAA), localhostv6, dns.RCodeSuccess}, + {"nxdomain", dnspacket("invalid.invalid.", dns.TypeA), netaddr.IP{}, dns.RCodeNameError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := syncRespond(r, tt.query) + if err != nil { + t.Errorf("err = %v; want nil", err) + return + } + ip, code, err := extractipcode(resp) + if err != nil { + t.Errorf("extract: err = %v; want nil (in %x)", err, resp) + return + } + if code != tt.code { + t.Errorf("code = %v; want %v", code, tt.code) + } + if ip != tt.ip { t.Errorf("ip = %v; want %v", ip, tt.ip) } }) } } -func TestConcurrentSet(t *testing.T) { - r := NewResolver(t.Logf) +func TestConcurrentSetMap(t *testing.T) { + r := NewResolver(t.Logf, "ipn.dev") + r.Start() // This is purely to ensure that Resolve does not race with SetMap. var wg sync.WaitGroup @@ -128,16 +188,26 @@ func TestConcurrentSet(t *testing.T) { wg.Wait() } -var validResponse = []byte{ - // IP header - 0x45, 0x00, 0x00, 0x58, 0xff, 0xff, 0x00, 0x00, 0x40, 0x11, 0xe7, 0x00, - // Source IP - 0x64, 0x64, 0x64, 0x64, - // Destination IP - 0x64, 0x65, 0x66, 0x67, - // UDP header - 0x00, 0x35, 0x04, 0xd2, 0x00, 0x44, 0x53, 0xdd, - // DNS payload +func TestConcurrentSetNameservers(t *testing.T) { + r := NewResolver(t.Logf, "ipn.dev") + r.Start() + packet := dnspacket("google.com.", dns.TypeA) + + // This is purely to ensure that delegation does not race with SetNameservers. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + r.SetNameservers([]string{"9.9.9.9:53"}) + }() + go func() { + defer wg.Done() + syncRespond(r, packet) + }() + wg.Wait() +} + +var validIPv4Response = []byte{ 0x00, 0x00, // transaction id: 0 0x84, 0x00, // flags: response, authoritative, no error 0x00, 0x01, // one question @@ -154,16 +224,25 @@ var validResponse = []byte{ 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 } +var validIPv6Response = []byte{ + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN + // Answer: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x10, // length: 16 bytes + // AAAA: 0001:0203:0405:0607:0809:0A0B:0C0D:0E0F + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xb, 0xc, 0xd, 0xe, 0xf, +} + var nxdomainResponse = []byte{ - // IP header - 0x45, 0x00, 0x00, 0x3b, 0xff, 0xff, 0x00, 0x00, 0x40, 0x11, 0xe7, 0x1d, - // Source IP - 0x64, 0x64, 0x64, 0x64, - // Destination IP - 0x64, 0x65, 0x66, 0x67, - // UDP header - 0x00, 0x35, 0x04, 0xd2, 0x00, 0x27, 0x25, 0x33, - // DNS payload 0x00, 0x00, // transaction id: 0 0x84, 0x03, // flags: response, authoritative, error: nxdomain 0x00, 0x01, // one question @@ -175,25 +254,24 @@ var nxdomainResponse = []byte{ } func TestFull(t *testing.T) { - r := NewResolver(t.Logf) + r := NewResolver(t.Logf, "ipn.dev") r.SetMap(dnsMap) + r.Start() - src := packet.IP(0x64656667) // 100.101.102.103 - dst := packet.IP(0x64646464) // 100.100.100.100 // One full packet and one error packet tests := []struct { name string - request *packet.ParsedPacket + request []byte response []byte }{ - {"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), validResponse}, - {"error", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false), nxdomainResponse}, + {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), validIPv4Response}, + {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), validIPv6Response}, + {"error", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - buf := make([]byte, 512) - response, err := r.Respond(tt.request, buf) + response, err := syncRespond(r, tt.request) if err != nil { t.Errorf("err = %v; want nil", err) } @@ -205,43 +283,41 @@ func TestFull(t *testing.T) { } func TestAllocs(t *testing.T) { - r := NewResolver(t.Logf) + r := NewResolver(t.Logf, "ipn.dev") r.SetMap(dnsMap) + r.Start() - src := packet.IP(0x64656667) // 100.101.102.103 - dst := packet.IP(0x64646464) // 100.100.100.100 - query := dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false) + // It is seemingly pointless to test allocs in the delegate path, + // as dialer.Dial -> Read -> Write alone comprise 12 allocs. + query := dnspacket("test1.ipn.dev.", dns.TypeA) - buf := make([]byte, 512) allocs := testing.AllocsPerRun(100, func() { - r.Respond(query, buf) + syncRespond(r, query) }) - if allocs > 0 { - t.Errorf("allocs = %v; want 0", allocs) + if allocs > 1 { + t.Errorf("allocs = %v; want 1", allocs) } } func BenchmarkFull(b *testing.B) { - r := NewResolver(b.Logf) + r := NewResolver(b.Logf, "ipn.dev") r.SetMap(dnsMap) + r.Start() - src := packet.IP(0x64656667) // 100.101.102.103 - dst := packet.IP(0x64646464) // 100.100.100.100 // One full packet and one error packet tests := []struct { name string - request *packet.ParsedPacket + request []byte }{ - {"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false)}, - {"nxdomain", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false)}, + {"valid", dnspacket("test1.ipn.dev.", dns.TypeA)}, + {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA)}, } - buf := make([]byte, 512) for _, tt := range tests { b.Run(tt.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - r.Respond(tt.request, buf) + syncRespond(r, tt.request) } }) } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 46b5cac65..e6292240c 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -25,6 +25,7 @@ import ( "github.com/tailscale/wireguard-go/tun" "github.com/tailscale/wireguard-go/wgcfg" "go4.org/mem" + "inet.af/netaddr" "tailscale.com/control/controlclient" "tailscale.com/internal/deepprint" "tailscale.com/ipn/ipnstate" @@ -51,6 +52,11 @@ import ( // discovery. const minimalMTU = 1280 +const ( + magicDNSIP = 0x64646464 // 100.100.100.100 + magicDNSPort = 53 +) + type userspaceEngine struct { logf logger.Logf reqCh chan struct{} @@ -100,7 +106,7 @@ type EngineConfig struct { // EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers // will be intercepted and responded to, regardless of the source host. EchoRespondToAll bool - // UseTailscaleDNS determines whether DNS requests for names of the form *.ipn.dev + // UseTailscaleDNS determines whether DNS requests for names of the form .. // directed to the designated Taislcale DNS address (see wgengine/tsdns) // will be intercepted and resolved by a tsdns.Resolver. UseTailscaleDNS bool @@ -174,7 +180,7 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { reqCh: make(chan struct{}, 1), waitCh: make(chan struct{}), tundev: tstun.WrapTUN(logf, conf.TUN), - resolver: tsdns.NewResolver(logf), + resolver: tsdns.NewResolver(logf, "tailscale.us"), useTailscaleDNS: conf.UseTailscaleDNS, pingers: make(map[wgcfg.Key]*pinger), } @@ -308,6 +314,9 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { e.linkMon.Start() e.magicConn.Start() + e.resolver.Start() + go e.pollResolver() + return e, nil } @@ -360,22 +369,52 @@ func (e *userspaceEngine) isLocalAddr(ip packet.IP) bool { // handleDNS is an outbound pre-filter resolving Tailscale domains. func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response { - if e.resolver.AcceptsPacket(p) { - // TODO(dmytro): avoid this allocation without having tsdns know tstun quirks. - buf := make([]byte, tstun.MaxPacketSize) - offset := tstun.PacketStartOffset - response, err := e.resolver.Respond(p, buf[offset:]) + if p.DstIP == magicDNSIP && p.DstPort == magicDNSPort && p.IPProto == packet.UDP { + request := tsdns.Packet{ + Payload: p.Payload(), + Addr: netaddr.IPPort{IP: p.SrcIP.Netaddr(), Port: p.SrcPort}, + } + err := e.resolver.EnqueueRequest(request) if err != nil { - e.logf("DNS resolver error: %v", err) - } else { - t.InjectInboundDirect(buf[:offset+len(response)], offset) + e.logf("tsdns: enqueue: %v", err) } - // We already handled it, stop. return filter.Drop } return filter.Accept } +// pollResolver reads responses from the DNS resolver and injects them inbound. +func (e *userspaceEngine) pollResolver() { + for { + resp, err := e.resolver.NextResponse() + if err == tsdns.ErrClosed { + return + } + if err != nil { + e.logf("tsdns: error: %v", err) + continue + } + + h := packet.UDPHeader{ + IPHeader: packet.IPHeader{ + SrcIP: packet.IP(magicDNSIP), + DstIP: packet.IPFromNetaddr(resp.Addr.IP), + }, + SrcPort: magicDNSPort, + DstPort: resp.Addr.Port, + } + hlen := h.Len() + + // TODO(dmytro): avoid this allocation without importing tstun quirks into tsdns. + const offset = tstun.PacketStartOffset + buf := make([]byte, offset+hlen+len(resp.Payload)) + copy(buf[offset+hlen:], resp.Payload) + h.Marshal(buf[offset:]) + + e.tundev.InjectInboundDirect(buf, offset) + } +} + // pinger sends ping packets for a few seconds. // // These generated packets are used to ensure we trigger the spray logic in @@ -759,6 +798,7 @@ func (e *userspaceEngine) Close() { r := bufio.NewReader(strings.NewReader("")) e.wgdev.IpcSetOperation(r) + e.resolver.Close() e.magicConn.Close() e.linkMon.Close() e.router.Close()