diff --git a/ipn/local.go b/ipn/local.go index 5f2af42e1..3c80ae188 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -297,10 +297,13 @@ func (b *LocalBackend) Start(opts Options) error { b.send(Notify{Prefs: prefs}) } if newSt.NetMap != nil { + // Netmap is unchanged only when the diff is empty. + changed := true b.mu.Lock() if b.netMapCache != nil { diff := newSt.NetMap.ConciseDiffFrom(b.netMapCache) if strings.TrimSpace(diff) == "" { + changed = false b.logf("netmap diff: (none)") } else { b.logf("netmap diff:\n%v", diff) @@ -311,8 +314,11 @@ func (b *LocalBackend) Start(opts Options) error { b.mu.Unlock() b.send(Notify{NetMap: newSt.NetMap}) - b.updateFilter(newSt.NetMap) - b.updateDNSMap(newSt.NetMap) + // There is nothing to update if the map hasn't changed. + if changed { + b.updateFilter(newSt.NetMap) + b.updateDNSMap(newSt.NetMap) + } if disableDERP { b.e.SetDERPMap(nil) } else { @@ -435,7 +441,7 @@ func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) { if netMap == nil { return } - dnsMap := &tsdns.Map{DomainToIP: make(map[string]netaddr.IP)} + domainToIP := make(map[string]netaddr.IP) for _, peer := range netMap.Peers { if len(peer.Addresses) == 0 { continue @@ -445,9 +451,9 @@ func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) { domain = strings.TrimSuffix(domain, ".local") domain = strings.TrimSuffix(domain, ".localdomain") domain = domain + ".ipn.dev" - dnsMap.DomainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr) + domainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr) } - b.e.SetDNSMap(dnsMap) + b.e.SetDNSMap(tsdns.NewMap(domainToIP)) } // readPoller is a goroutine that receives service lists from @@ -690,7 +696,15 @@ func (b *LocalBackend) SetPrefs(new *Prefs) { } b.updateFilter(b.netMapCache) - b.updateDNSMap(b.netMapCache) + // TODO(dmytro): when Prefs gain an EnableTailscaleDNS toggle, updateDNSMap here. + + turnDERPOff := new.DisableDERP && !old.DisableDERP + turnDERPOn := !new.DisableDERP && old.DisableDERP + if turnDERPOff { + b.e.SetDERPMap(nil) + } else if turnDERPOn && b.netMapCache != nil { + b.e.SetDERPMap(b.netMapCache.DERPMap) + } if old.WantRunning != new.WantRunning { b.stateMachine() diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go index 69f7c99a2..38d6f9e2b 100644 --- a/wgengine/packet/packet.go +++ b/wgengine/packet/packet.go @@ -242,16 +242,19 @@ func (q *ParsedPacket) UDPHeader() UDPHeader { } // Buffer returns the entire packet buffer. +// This is a read-only view; that is, q retains the ownership of the buffer. func (q *ParsedPacket) Buffer() []byte { return q.b } // Sub returns the IP subprotocol section. +// This is a read-only view; that is, q retains the ownership of the buffer. func (q *ParsedPacket) Sub(begin, n int) []byte { return q.b[q.subofs+begin : q.subofs+begin+n] } // Payload returns the payload of the IP subprotocol section. +// This is a read-only view; that is, q retains the ownership of the buffer. func (q *ParsedPacket) Payload() []byte { return q.b[q.dataofs:q.length] } diff --git a/wgengine/tsdns/tsdns.go b/wgengine/tsdns/tsdns.go index 8c0efdf56..71b425971 100644 --- a/wgengine/tsdns/tsdns.go +++ b/wgengine/tsdns/tsdns.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package tsdns provides a Resolver struct capable of resolving +// Package tsdns provides a Resolver capable of resolving // domains on a Tailscale network. package tsdns @@ -11,6 +11,7 @@ import ( "errors" "strings" "sync" + "time" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" @@ -18,15 +19,17 @@ import ( "tailscale.com/wgengine/packet" ) -// defaultTTL is the TTL in seconds of all responses from Resolver. -const defaultTTL = 600 +// defaultTTL is the TTL of all responses from Resolver. +const defaultTTL = 600 * time.Second var ( errMapNotSet = errors.New("domain map not set") errNoSuchDomain = errors.New("domain does not exist") errNotImplemented = errors.New("query type not implemented") errNotOurName = errors.New("not an *.ipn.dev domain") + errNotOurQuery = errors.New("query not for this resolver") errNotQuery = errors.New("not a DNS query") + errSmallBuffer = errors.New("response buffer too small") ) var ( @@ -36,13 +39,20 @@ var ( // Map is all the data Resolver needs to resolve DNS queries. type Map struct { - // DomainToIP is a mapping of Tailscale domains to their IP addresses. + // domainToIP is a mapping of Tailscale domains to their IP addresses. // For example, monitoring.ipn.dev -> 100.64.0.1. - DomainToIP map[string]netaddr.IP + domainToIP map[string]netaddr.IP } -// Resolver is a DNS resolver for domain names of the form *.ipn.dev -// It is intended +// NewMap returns a new Map with domain to address mapping given by domainToIP. +// It takes ownership of the provided map. +func NewMap(domainToIP map[string]netaddr.IP) *Map { + return &Map{ + domainToIP: domainToIP, + } +} + +// Resolver is a DNS resolver for domain names of the form *.ipn.dev. type Resolver struct { logf logger.Logf @@ -96,7 +106,7 @@ func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { r.mu.Unlock() return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet } - addr, found := r.dnsMap.DomainToIP[domain] + addr, found := r.dnsMap.domainToIP[domain] r.mu.Unlock() if !found { @@ -156,6 +166,7 @@ func (r *Resolver) makeResponse(resp *response) error { } // marshalAnswer serializes the answer record into an active builder. +// The caller may continue using the builder following the call. func marshalAnswer(resp *response, builder *dns.Builder) error { var answer dns.AResource @@ -168,7 +179,7 @@ func marshalAnswer(resp *response, builder *dns.Builder) error { Name: resp.Question.Name, Type: dns.TypeA, Class: dns.ClassINET, - TTL: defaultTTL, + TTL: uint32(defaultTTL / time.Second), } ip := resp.IP.As16() copy(answer.A[:], ip[12:]) @@ -176,44 +187,53 @@ func marshalAnswer(resp *response, builder *dns.Builder) error { } // marshalResponse serializes the DNS response into an active builder. -func marshalResponse(resp *response, builder *dns.Builder) ([]byte, error) { - resp.Header.Response = true - resp.Header.Authoritative = true - if resp.Header.RecursionDesired { - resp.Header.RecursionAvailable = true - } - +// The caller may continue using the builder following the call. +func marshalResponse(resp *response, builder *dns.Builder) error { err := builder.StartQuestions() if err != nil { - return nil, err + return err } err = builder.Question(resp.Question) if err != nil { - return nil, err + return err } if resp.Header.RCode == dns.RCodeSuccess { err = marshalAnswer(resp, builder) if err != nil { - return nil, err + return err } } - return builder.Finish() + return nil } +// marshalReponsePacket marshals a full DNS packet (including headers) +// representing resp, which is a response to query, into buf. +// It returns buf trimmed to the length of the response packet. func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) { udpHeader := query.UDPHeader() udpHeader.ToResponse() offset := udpHeader.Len() + resp.Header.Response = true + resp.Header.Authoritative = true + if resp.Header.RecursionDesired { + resp.Header.RecursionAvailable = true + } + // dns.Builder appends to the passed buffer (without reallocation when possible), // so we pass in a zero-length slice starting at the point it should start writing. builder := dns.NewBuilder(buf[offset:offset], resp.Header) + err := marshalResponse(resp, &builder) + if err != nil { + return nil, err + } + // rbuf is the response slice with the correct length starting at offset. - rbuf, err := marshalResponse(resp, &builder) + rbuf, err := builder.Finish() if err != nil { return nil, err } @@ -235,15 +255,11 @@ func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, erro // 0. Verify that contract is upheld. if !r.AcceptsPacket(query) { - r.logf("[unexpected] tsdns: Respond called on query not for this resolver") - resp.Header.RCode = dns.RCodeServerFailure - return marshalResponsePacket(query, &resp, buf) + return nil, errNotOurQuery } // A DNS response is at least as long as the query if len(buf) < len(query.Buffer()) { - r.logf("[unexpected] tsdns: response buffer is too small") - resp.Header.RCode = dns.RCodeServerFailure - return marshalResponsePacket(query, &resp, buf) + return nil, errSmallBuffer } // 1. Parse query packet. diff --git a/wgengine/tsdns/tsdns_test.go b/wgengine/tsdns/tsdns_test.go new file mode 100644 index 000000000..3ad8a449c --- /dev/null +++ b/wgengine/tsdns/tsdns_test.go @@ -0,0 +1,248 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tsdns + +import ( + "bytes" + "sync" + "testing" + + dns "golang.org/x/net/dns/dnsmessage" + "inet.af/netaddr" + "tailscale.com/wgengine/packet" +) + +var dnsMap = &Map{ + domainToIP: map[string]netaddr.IP{ + "test1.ipn.dev": netaddr.IPv4(1, 2, 3, 4), + "test2.ipn.dev": netaddr.IPv4(5, 6, 7, 8), + }, +} + +func dnspacket(srcip, dstip packet.IP, domain string, tp dns.Type, response bool) *packet.ParsedPacket { + dnsHeader := dns.Header{Response: response} + question := dns.Question{ + Name: dns.MustNewName(domain), + Type: tp, + Class: dns.ClassINET, + } + udpHeader := &packet.UDPHeader{ + IPHeader: packet.IPHeader{ + SrcIP: srcip, + DstIP: dstip, + IPProto: packet.UDP, + }, + SrcPort: 1234, + DstPort: 53, + } + + builder := dns.NewBuilder(nil, dnsHeader) + builder.StartQuestions() + builder.Question(question) + payload, _ := builder.Finish() + + buf := packet.Generate(udpHeader, payload) + + pp := new(packet.ParsedPacket) + pp.Decode(buf) + + return pp +} + +func TestAcceptsPacket(t *testing.T) { + r := NewResolver(t.Logf) + r.SetMap(dnsMap) + + src := packet.IP(0x64656667) // 100.101.102.103 + dst := packet.IP(0x64646464) // 100.100.100.100 + tests := []struct { + name string + request *packet.ParsedPacket + want bool + }{ + {"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), true}, + {"invalid", dnspacket(dst, src, "test1.ipn.dev.", dns.TypeA, false), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + accepts := r.AcceptsPacket(tt.request) + if accepts != tt.want { + t.Errorf("accepts = %v; want %v", accepts, tt.want) + } + }) + } +} + +func TestResolve(t *testing.T) { + r := NewResolver(t.Logf) + r.SetMap(dnsMap) + + tests := []struct { + name string + domain string + ip netaddr.IP + code dns.RCode + iserr bool + }{ + {"valid", "test1.ipn.dev", netaddr.IPv4(1, 2, 3, 4), dns.RCodeSuccess, false}, + {"nxdomain", "test3.ipn.dev", netaddr.IP{}, dns.RCodeNameError, true}, + {"not our domain", "google.com", netaddr.IP{}, dns.RCodeRefused, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip, code, err := r.Resolve(tt.domain) + if err != nil && !tt.iserr { + t.Errorf("err = %v; want nil", err) + } else if err == nil && tt.iserr { + t.Errorf("err = nil; want non-nil") + } + if code != tt.code { + t.Errorf("code = %v; want %v", code, tt.code) + } + // Only check ip for non-err + if !tt.iserr && ip != tt.ip { + t.Errorf("ip = %v; want %v", ip, tt.ip) + } + }) + } +} + +func TestConcurrentSet(t *testing.T) { + r := NewResolver(t.Logf) + + // This is purely to ensure that Resolve does not race with SetMap. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + r.SetMap(dnsMap) + }() + go func() { + defer wg.Done() + r.Resolve("test1.ipn.dev") + }() + wg.Wait() +} + +var validResponse = []byte{ + // IP header + 0x45, 0x00, 0x00, 0x58, 0xff, 0xff, 0x00, 0x00, 0x40, 0x11, 0xe7, 0x00, + // Source IP + 0x64, 0x64, 0x64, 0x64, + // Destination IP + 0x64, 0x65, 0x66, 0x67, + // UDP header + 0x00, 0x35, 0x04, 0xd2, 0x00, 0x44, 0x53, 0xdd, + // DNS payload + 0x00, 0x00, // transaction id: 0 + 0x84, 0x00, // flags: response, authoritative, no error + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN + // Answer: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN + 0x00, 0x00, 0x02, 0x58, // TTL: 600 + 0x00, 0x04, // length: 4 bytes + 0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4 +} + +var nxdomainResponse = []byte{ + // IP header + 0x45, 0x00, 0x00, 0x3b, 0xff, 0xff, 0x00, 0x00, 0x40, 0x11, 0xe7, 0x1d, + // Source IP + 0x64, 0x64, 0x64, 0x64, + // Destination IP + 0x64, 0x65, 0x66, 0x67, + // UDP header + 0x00, 0x35, 0x04, 0xd2, 0x00, 0x27, 0x25, 0x33, + // DNS payload + 0x00, 0x00, // transaction id: 0 + 0x84, 0x03, // flags: response, authoritative, error: nxdomain + 0x00, 0x01, // one question + 0x00, 0x00, // no answers + 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs + // Question: + 0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name + 0x00, 0x01, 0x00, 0x01, // type A, class IN +} + +func TestFull(t *testing.T) { + r := NewResolver(t.Logf) + r.SetMap(dnsMap) + + src := packet.IP(0x64656667) // 100.101.102.103 + dst := packet.IP(0x64646464) // 100.100.100.100 + // One full packet and one error packet + tests := []struct { + name string + request *packet.ParsedPacket + response []byte + }{ + {"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), validResponse}, + {"error", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false), nxdomainResponse}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := make([]byte, 512) + response, err := r.Respond(tt.request, buf) + if err != nil { + t.Errorf("err = %v; want nil", err) + } + if !bytes.Equal(response, tt.response) { + t.Errorf("response = %x; want %x", response, tt.response) + } + }) + } +} + +func TestAllocs(t *testing.T) { + r := NewResolver(t.Logf) + r.SetMap(dnsMap) + + src := packet.IP(0x64656667) // 100.101.102.103 + dst := packet.IP(0x64646464) // 100.100.100.100 + query := dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false) + + buf := make([]byte, 512) + allocs := testing.AllocsPerRun(100, func() { + r.Respond(query, buf) + }) + + if allocs > 0 { + t.Errorf("allocs = %v; want 0", allocs) + } +} + +func BenchmarkFull(b *testing.B) { + r := NewResolver(b.Logf) + r.SetMap(dnsMap) + + src := packet.IP(0x64656667) // 100.101.102.103 + dst := packet.IP(0x64646464) // 100.100.100.100 + // One full packet and one error packet + tests := []struct { + name string + request *packet.ParsedPacket + }{ + {"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false)}, + {"nxdomain", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false)}, + } + + buf := make([]byte, 512) + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + r.Respond(tt.request, buf) + } + }) + } +}