diff --git a/appc/appconnector.go b/appc/appconnector.go index a08fbb0b8..ef59343d3 100644 --- a/appc/appconnector.go +++ b/appc/appconnector.go @@ -22,6 +22,7 @@ import ( "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/execqueue" + "tailscale.com/util/mak" ) // RouteAdvertiser is an interface that allows the AppConnector to advertise @@ -206,7 +207,16 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) { return } -nextAnswer: + // cnameChain tracks a chain of CNAMEs for a given query in order to reverse + // a CNAME chain back to the original query for flattening. The keys are + // CNAME record targets, and the value is the name the record answers, so + // for www.example.com CNAME example.com, the map would contain + // ["example.com"] = "www.example.com". + var cnameChain map[string]string + + // addressRecords is a list of address records found in the response. + var addressRecords map[string][]netip.Addr + for { h, err := p.AnswerHeader() if err == dnsmessage.ErrSectionDone { @@ -222,83 +232,147 @@ nextAnswer: } continue } - if h.Type != dnsmessage.TypeA && h.Type != dnsmessage.TypeAAAA { + + switch h.Type { + case dnsmessage.TypeCNAME, dnsmessage.TypeA, dnsmessage.TypeAAAA: + default: if err := p.SkipAnswer(); err != nil { return } continue - } - domain := h.Name.String() - if len(domain) == 0 { - return } - domain = strings.TrimSuffix(domain, ".") - domain = strings.ToLower(domain) - e.logf("[v2] observed DNS response for %s", domain) - e.mu.Lock() - addrs, ok := e.domains[domain] - // match wildcard domains - if !ok { - for _, wc := range e.wildcards { - if dnsname.HasSuffix(domain, wc) { - e.domains[domain] = nil - ok = true - break - } - } + domain := strings.TrimSuffix(strings.ToLower(h.Name.String()), ".") + if len(domain) == 0 { + continue } - e.mu.Unlock() - if !ok { - if err := p.SkipAnswer(); err != nil { + if h.Type == dnsmessage.TypeCNAME { + res, err := p.CNAMEResource() + if err != nil { return } + cname := strings.TrimSuffix(strings.ToLower(res.CNAME.String()), ".") + if len(cname) == 0 { + continue + } + mak.Set(&cnameChain, cname, domain) continue } - var addr netip.Addr switch h.Type { case dnsmessage.TypeA: r, err := p.AResource() if err != nil { return } - addr = netip.AddrFrom4(r.A) + addr := netip.AddrFrom4(r.A) + mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) case dnsmessage.TypeAAAA: r, err := p.AAAAResource() if err != nil { return } - addr = netip.AddrFrom16(r.AAAA) + addr := netip.AddrFrom16(r.AAAA) + mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) default: if err := p.SkipAnswer(); err != nil { return } continue } - if slices.Contains(addrs, addr) { + } + + e.mu.Lock() + defer e.mu.Unlock() + + for domain, addrs := range addressRecords { + domain, isRouted := e.findRoutedDomainLocked(domain, cnameChain) + + // domain and none of the CNAMEs in the chain are routed + if !isRouted { continue } - for _, route := range e.controlRoutes { - if route.Contains(addr) { - // record the new address associated with the domain for faster matching in subsequent - // requests and for diagnostic records. - e.mu.Lock() - e.domains[domain] = append(addrs, addr) - e.mu.Unlock() - continue nextAnswer + + // advertise each address we have learned for the routed domain, that + // was not already known. + for _, addr := range addrs { + e.logf("[v2] observed routed DNS response for %s: %s", domain, addr) + if e.isAddrKnownLocked(domain, addr) { + continue + } + + e.scheduleAdvertisement(domain, addr) + } + } +} + +// starting from the given domain that resolved to an address, find it, or any +// of the domains in the CNAME chain toward resolving it, that are routed +// domains, returning the routed domain name and a bool indicating whether a +// routed domain was found. +// e.mu must be held. +func (e *AppConnector) findRoutedDomainLocked(domain string, cnameChain map[string]string) (string, bool) { + var isRouted bool + for { + _, isRouted = e.domains[domain] + if isRouted { + break + } + + // match wildcard domains + for _, wc := range e.wildcards { + if dnsname.HasSuffix(domain, wc) { + e.domains[domain] = nil + isRouted = true + break } } + + next, ok := cnameChain[domain] + if !ok { + break + } + domain = next + } + return domain, isRouted +} + +// isAddrKnownLocked returns true if the address is known to be associated with +// the given domain. Known domain tables are updated for covered routes to speed +// up future matches. +// e.mu must be held. +func (e *AppConnector) isAddrKnownLocked(domain string, addr netip.Addr) bool { + if slices.Contains(e.domains[domain], addr) { + return true + } + for _, route := range e.controlRoutes { + if route.Contains(addr) { + // record the new address associated with the domain for faster matching in subsequent + // requests and for diagnostic records. + e.domains[domain] = append(e.domains[domain], addr) + return true + } + } + return false + +} + +// scheduleAdvertisement schedules an advertisement of the given address +// associated with the given domain. +func (e *AppConnector) scheduleAdvertisement(domain string, addr netip.Addr) { + e.queue.Add(func() { if err := e.routeAdvertiser.AdvertiseRoute(netip.PrefixFrom(addr, addr.BitLen())); err != nil { e.logf("failed to advertise route for %s: %v: %v", domain, addr, err) - continue + return } - e.logf("[v2] advertised route for %v: %v", domain, addr) - e.mu.Lock() - e.domains[domain] = append(addrs, addr) - e.mu.Unlock() - } + defer e.mu.Unlock() + + if !slices.Contains(e.domains[domain], addr) { + e.logf("[v2] advertised route for %v: %v", domain, addr) + e.domains[domain] = append(e.domains[domain], addr) + } + }) } diff --git a/appc/appconnector_test.go b/appc/appconnector_test.go index 510ee8361..5b739b097 100644 --- a/appc/appconnector_test.go +++ b/appc/appconnector_test.go @@ -99,6 +99,7 @@ func TestDomainRoutes(t *testing.T) { a := NewAppConnector(t.Logf, rc) a.updateDomains([]string{"example.com"}) a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + a.Wait(context.Background()) want := map[string][]netip.Addr{ "example.com": {netip.MustParseAddr("192.0.0.8")}, @@ -110,6 +111,7 @@ func TestDomainRoutes(t *testing.T) { } func TestObserveDNSResponse(t *testing.T) { + ctx := context.Background() rc := &appctest.RouteCollector{} a := NewAppConnector(t.Logf, rc) @@ -123,6 +125,26 @@ func TestObserveDNSResponse(t *testing.T) { a.updateDomains([]string{"example.com"}) a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + a.Wait(ctx) + if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { + t.Errorf("got %v; want %v", got, want) + } + + // a CNAME record chain should result in a route being added if the chain + // matches a routed domain. + a.updateDomains([]string{"www.example.com", "example.com"}) + a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.9", "www.example.com.", "chain.example.com.", "example.com.")) + a.Wait(ctx) + wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.9/32")) + if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { + t.Errorf("got %v; want %v", got, want) + } + + // a CNAME record chain should result in a route being added if the chain + // even if only found in the middle of the chain + a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.10", "outside.example.org.", "www.example.com.", "example.org.")) + a.Wait(ctx) + wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.10/32")) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } @@ -130,12 +152,14 @@ func TestObserveDNSResponse(t *testing.T) { wantRoutes = append(wantRoutes, netip.MustParsePrefix("2001:db8::1/128")) a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")) + a.Wait(ctx) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } // don't re-advertise routes that have already been advertised a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")) + a.Wait(ctx) if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes) } @@ -145,6 +169,7 @@ func TestObserveDNSResponse(t *testing.T) { a.updateRoutes([]netip.Prefix{pfx}) wantRoutes = append(wantRoutes, pfx) a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.2.1")) + a.Wait(ctx) if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes) } @@ -154,11 +179,13 @@ func TestObserveDNSResponse(t *testing.T) { } func TestWildcardDomains(t *testing.T) { + ctx := context.Background() rc := &appctest.RouteCollector{} a := NewAppConnector(t.Logf, rc) a.updateDomains([]string{"*.example.com"}) a.ObserveDNSResponse(dnsResponse("foo.example.com.", "192.0.0.8")) + a.Wait(ctx) if got, want := rc.Routes(), []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")}; !slices.Equal(got, want) { t.Errorf("routes: got %v; want %v", got, want) } @@ -218,6 +245,61 @@ func dnsResponse(domain, address string) []byte { return must.Get(b.Finish()) } +func dnsCNAMEResponse(address string, domains ...string) []byte { + addr := netip.MustParseAddr(address) + b := dnsmessage.NewBuilder(nil, dnsmessage.Header{}) + b.EnableCompression() + b.StartAnswers() + + if len(domains) >= 2 { + for i, domain := range domains[:len(domains)-1] { + b.CNAMEResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.CNAMEResource{ + CNAME: dnsmessage.MustNewName(domains[i+1]), + }, + ) + } + } + + domain := domains[len(domains)-1] + + switch addr.BitLen() { + case 32: + b.AResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.AResource{ + A: addr.As4(), + }, + ) + case 128: + b.AAAAResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.AAAAResource{ + AAAA: addr.As16(), + }, + ) + default: + panic("invalid address length") + } + return must.Get(b.Finish()) +} + func prefixEqual(a, b netip.Prefix) bool { return a == b } diff --git a/ipn/ipnlocal/peerapi_test.go b/ipn/ipnlocal/peerapi_test.go index 2f495b7e7..85f5423e3 100644 --- a/ipn/ipnlocal/peerapi_test.go +++ b/ipn/ipnlocal/peerapi_test.go @@ -803,6 +803,72 @@ func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { } } +func TestPeerAPIReplyToDNSQueriesAreObservedWithCNAMEFlattening(t *testing.T) { + ctx := context.Background() + var h peerAPIHandler + h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") + + rc := &appctest.RouteCollector{} + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0) + pm := must.Get(newProfileManager(new(mem.Store), t.Logf)) + h.ps = &peerAPIServer{ + b: &LocalBackend{ + e: eng, + pm: pm, + store: pm.Store(), + appConnector: appc.NewAppConnector(t.Logf, rc), + }, + } + h.ps.b.appConnector.UpdateDomains([]string{"www.example.com"}) + h.ps.b.appConnector.Wait(ctx) + + h.ps.resolver = &fakeResolver{build: func(b *dnsmessage.Builder) { + b.CNAMEResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("www.example.com."), + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.CNAMEResource{ + CNAME: dnsmessage.MustNewName("example.com."), + }, + ) + b.AResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.AResource{ + A: [4]byte{192, 0, 0, 8}, + }, + ) + }} + f := filter.NewAllowAllForTest(logger.Discard) + h.ps.b.setFilter(f) + + if !h.ps.b.OfferingAppConnector() { + t.Fatal("expecting to be offering app connector") + } + if !h.replyToDNSQueries() { + t.Errorf("unexpectedly deny; wanted to be a DNS server") + } + + w := httptest.NewRecorder() + h.handleDNSQuery(w, httptest.NewRequest("GET", "/dns-query?q=www.example.com.", nil)) + if w.Code != http.StatusOK { + t.Errorf("unexpected status code: %v", w.Code) + } + h.ps.b.appConnector.Wait(ctx) + + wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} + if !slices.Equal(rc.Routes(), wantRoutes) { + t.Errorf("got %v; want %v", rc.Routes(), wantRoutes) + } +} + type fakeResolver struct { build func(*dnsmessage.Builder) }