From a0bae4dac8c195a7c81fe59fde04d6ff4b8ef9f0 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Fri, 2 Sep 2022 14:48:30 -0400 Subject: [PATCH] cmd/derper: add support for unpublished bootstrap DNS entries (#5529) Signed-off-by: Andrew Dunham --- cmd/derper/bootstrap_dns.go | 97 +++++++++++++++++++++---- cmd/derper/bootstrap_dns_test.go | 121 ++++++++++++++++++++++++++++++- cmd/derper/derper.go | 9 ++- 3 files changed, 206 insertions(+), 21 deletions(-) diff --git a/cmd/derper/bootstrap_dns.go b/cmd/derper/bootstrap_dns.go index 6c909bc36..f09e66653 100644 --- a/cmd/derper/bootstrap_dns.go +++ b/cmd/derper/bootstrap_dns.go @@ -17,16 +17,31 @@ import ( "tailscale.com/syncs" ) -var dnsCache syncs.AtomicValue[[]byte] +const refreshTimeout = time.Minute -var bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests") +type dnsEntryMap map[string][]net.IP + +var ( + dnsCache syncs.AtomicValue[dnsEntryMap] + dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON + unpublishedDNSCache syncs.AtomicValue[dnsEntryMap] +) + +var ( + bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests") + publishedDNSHits = expvar.NewInt("counter_bootstrap_dns_published_hits") + publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses") + unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits") + unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses") +) func refreshBootstrapDNSLoop() { - if *bootstrapDNS == "" { + if *bootstrapDNS == "" && *unpublishedDNS == "" { return } for { refreshBootstrapDNS() + refreshUnpublishedDNS() time.Sleep(10 * time.Minute) } } @@ -35,10 +50,34 @@ func refreshBootstrapDNS() { if *bootstrapDNS == "" { return } - dnsEntries := make(map[string][]net.IP) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout) + defer cancel() + dnsEntries := resolveList(ctx, strings.Split(*bootstrapDNS, ",")) + j, err := json.MarshalIndent(dnsEntries, "", "\t") + if err != nil { + // leave the old values in place + return + } + + dnsCache.Store(dnsEntries) + dnsCacheBytes.Store(j) +} + +func refreshUnpublishedDNS() { + if *unpublishedDNS == "" { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout) defer cancel() - names := strings.Split(*bootstrapDNS, ",") + + dnsEntries := resolveList(ctx, strings.Split(*unpublishedDNS, ",")) + unpublishedDNSCache.Store(dnsEntries) +} + +func resolveList(ctx context.Context, names []string) dnsEntryMap { + dnsEntries := make(dnsEntryMap) + var r net.Resolver for _, name := range names { addrs, err := r.LookupIP(ctx, "ip", name) @@ -48,21 +87,47 @@ func refreshBootstrapDNS() { } dnsEntries[name] = addrs } - j, err := json.MarshalIndent(dnsEntries, "", "\t") - if err != nil { - // leave the old values in place - return - } - dnsCache.Store(j) + return dnsEntries } func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) { bootstrapDNSRequests.Add(1) + w.Header().Set("Content-Type", "application/json") - j := dnsCache.Load() - // Bootstrap DNS requests occur cross-regions, - // and are randomized per request, - // so keeping a connection open is pointlessly expensive. + // Bootstrap DNS requests occur cross-regions, and are randomized per + // request, so keeping a connection open is pointlessly expensive. w.Header().Set("Connection", "close") + + // Try answering a query from our hidden map first + if q := r.URL.Query().Get("q"); q != "" { + if ips, ok := unpublishedDNSCache.Load()[q]; ok && len(ips) > 0 { + unpublishedDNSHits.Add(1) + + // Only return the specific query, not everything. + m := dnsEntryMap{q: ips} + j, err := json.MarshalIndent(m, "", "\t") + if err == nil { + w.Write(j) + return + } + } + + // If we have a "q" query for a name in the published cache + // list, then track whether that's a hit/miss. + if m, ok := dnsCache.Load()[q]; ok { + if len(m) > 0 { + publishedDNSHits.Add(1) + } else { + publishedDNSMisses.Add(1) + } + } else { + // If it wasn't in either cache, treat this as a query + // for the unpublished cache, and thus a cache miss. + unpublishedDNSMisses.Add(1) + } + } + + // Fall back to returning the public set of cached DNS names + j := dnsCacheBytes.Load() w.Write(j) } diff --git a/cmd/derper/bootstrap_dns_test.go b/cmd/derper/bootstrap_dns_test.go index 240711a5e..ecc73d331 100644 --- a/cmd/derper/bootstrap_dns_test.go +++ b/cmd/derper/bootstrap_dns_test.go @@ -5,7 +5,12 @@ package main import ( + "encoding/json" + "net" "net/http" + "net/http/httptest" + "net/url" + "reflect" "testing" ) @@ -17,11 +22,12 @@ func BenchmarkHandleBootstrapDNS(b *testing.B) { }() refreshBootstrapDNS() w := new(bitbucketResponseWriter) + req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(b *testing.PB) { for b.Next() { - handleBootstrapDNS(w, nil) + handleBootstrapDNS(w, req) } }) } @@ -33,3 +39,116 @@ func (b *bitbucketResponseWriter) Header() http.Header { return make(http.Header func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), nil } func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {} + +func getBootstrapDNS(t *testing.T, q string) dnsEntryMap { + t.Helper() + req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil) + w := httptest.NewRecorder() + handleBootstrapDNS(w, req) + + res := w.Result() + if res.StatusCode != 200 { + t.Fatalf("got status=%d; want %d", res.StatusCode, 200) + } + var ips dnsEntryMap + if err := json.NewDecoder(res.Body).Decode(&ips); err != nil { + t.Fatalf("error decoding response body: %v", err) + } + return ips +} + +func TestUnpublishedDNS(t *testing.T) { + const published = "login.tailscale.com" + const unpublished = "log.tailscale.io" + + prev1, prev2 := *bootstrapDNS, *unpublishedDNS + *bootstrapDNS = published + *unpublishedDNS = unpublished + t.Cleanup(func() { + *bootstrapDNS = prev1 + *unpublishedDNS = prev2 + }) + + refreshBootstrapDNS() + refreshUnpublishedDNS() + + hasResponse := func(q string) bool { + _, found := getBootstrapDNS(t, q)[q] + return found + } + + if !hasResponse(published) { + t.Errorf("expected response for: %s", published) + } + if !hasResponse(unpublished) { + t.Errorf("expected response for: %s", unpublished) + } + + // Verify that querying for a random query or a real query does not + // leak our unpublished domain + m1 := getBootstrapDNS(t, published) + if _, found := m1[unpublished]; found { + t.Errorf("found unpublished domain %s: %+v", unpublished, m1) + } + m2 := getBootstrapDNS(t, "random.example.com") + if _, found := m2[unpublished]; found { + t.Errorf("found unpublished domain %s: %+v", unpublished, m2) + } +} + +func resetMetrics() { + publishedDNSHits.Set(0) + publishedDNSMisses.Set(0) + unpublishedDNSHits.Set(0) + unpublishedDNSMisses.Set(0) +} + +// Verify that we don't count an empty list in the unpublishedDNSCache as a +// cache hit in our metrics. +func TestUnpublishedDNSEmptyList(t *testing.T) { + pub := dnsEntryMap{ + "tailscale.com": {net.IPv4(10, 10, 10, 10)}, + } + dnsCache.Store(pub) + dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`)) + + unpublishedDNSCache.Store(dnsEntryMap{ + "log.tailscale.io": {}, + "controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}, + }) + + t.Run("CacheMiss", func(t *testing.T) { + // One domain in map but empty, one not in map at all + for _, q := range []string{"log.tailscale.io", "login.tailscale.com"} { + resetMetrics() + ips := getBootstrapDNS(t, q) + + // Expected our public map to be returned on a cache miss + if !reflect.DeepEqual(ips, pub) { + t.Errorf("got ips=%+v; want %+v", ips, pub) + } + if v := unpublishedDNSHits.Value(); v != 0 { + t.Errorf("got hits=%d; want 0", v) + } + if v := unpublishedDNSMisses.Value(); v != 1 { + t.Errorf("got misses=%d; want 1", v) + } + } + }) + + // Verify that we do get a valid response and metric. + t.Run("CacheHit", func(t *testing.T) { + resetMetrics() + ips := getBootstrapDNS(t, "controlplane.tailscale.com") + want := dnsEntryMap{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}} + if !reflect.DeepEqual(ips, want) { + t.Errorf("got ips=%+v; want %+v", ips, want) + } + if v := unpublishedDNSHits.Value(); v != 1 { + t.Errorf("got hits=%d; want 1", v) + } + if v := unpublishedDNSMisses.Value(); v != 0 { + t.Errorf("got misses=%d; want 0", v) + } + }) +} diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index bd587c7b7..d0c019bdc 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -47,10 +47,11 @@ var ( hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443") runSTUN = flag.Bool("stun", true, "whether to run a STUN server. It will bind to the same IP (if any) as the --addr flag value.") - meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.") - meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") - bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") - verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") + meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.") + meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") + bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") + unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list") + verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection")