cmd/derper: add support for unpublished bootstrap DNS entries (#5529)

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
pull/5542/head
Andrew Dunham 2 years ago committed by GitHub
parent 9132b31e43
commit a0bae4dac8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,16 +17,31 @@ import (
"tailscale.com/syncs" "tailscale.com/syncs"
) )
var dnsCache syncs.AtomicValue[[]byte] const refreshTimeout = time.Minute
var bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests") type dnsEntryMap map[string][]net.IP
var (
dnsCache syncs.AtomicValue[dnsEntryMap]
dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON
unpublishedDNSCache syncs.AtomicValue[dnsEntryMap]
)
var (
bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
publishedDNSHits = expvar.NewInt("counter_bootstrap_dns_published_hits")
publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses")
unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
)
func refreshBootstrapDNSLoop() { func refreshBootstrapDNSLoop() {
if *bootstrapDNS == "" { if *bootstrapDNS == "" && *unpublishedDNS == "" {
return return
} }
for { for {
refreshBootstrapDNS() refreshBootstrapDNS()
refreshUnpublishedDNS()
time.Sleep(10 * time.Minute) time.Sleep(10 * time.Minute)
} }
} }
@ -35,10 +50,34 @@ func refreshBootstrapDNS() {
if *bootstrapDNS == "" { if *bootstrapDNS == "" {
return return
} }
dnsEntries := make(map[string][]net.IP) ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
names := strings.Split(*bootstrapDNS, ",") dnsEntries := resolveList(ctx, strings.Split(*bootstrapDNS, ","))
j, err := json.MarshalIndent(dnsEntries, "", "\t")
if err != nil {
// leave the old values in place
return
}
dnsCache.Store(dnsEntries)
dnsCacheBytes.Store(j)
}
func refreshUnpublishedDNS() {
if *unpublishedDNS == "" {
return
}
ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
defer cancel()
dnsEntries := resolveList(ctx, strings.Split(*unpublishedDNS, ","))
unpublishedDNSCache.Store(dnsEntries)
}
func resolveList(ctx context.Context, names []string) dnsEntryMap {
dnsEntries := make(dnsEntryMap)
var r net.Resolver var r net.Resolver
for _, name := range names { for _, name := range names {
addrs, err := r.LookupIP(ctx, "ip", name) addrs, err := r.LookupIP(ctx, "ip", name)
@ -48,21 +87,47 @@ func refreshBootstrapDNS() {
} }
dnsEntries[name] = addrs dnsEntries[name] = addrs
} }
j, err := json.MarshalIndent(dnsEntries, "", "\t") return dnsEntries
if err != nil {
// leave the old values in place
return
}
dnsCache.Store(j)
} }
func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) { func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
bootstrapDNSRequests.Add(1) bootstrapDNSRequests.Add(1)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
j := dnsCache.Load() // Bootstrap DNS requests occur cross-regions, and are randomized per
// Bootstrap DNS requests occur cross-regions, // request, so keeping a connection open is pointlessly expensive.
// and are randomized per request,
// so keeping a connection open is pointlessly expensive.
w.Header().Set("Connection", "close") w.Header().Set("Connection", "close")
// Try answering a query from our hidden map first
if q := r.URL.Query().Get("q"); q != "" {
if ips, ok := unpublishedDNSCache.Load()[q]; ok && len(ips) > 0 {
unpublishedDNSHits.Add(1)
// Only return the specific query, not everything.
m := dnsEntryMap{q: ips}
j, err := json.MarshalIndent(m, "", "\t")
if err == nil {
w.Write(j)
return
}
}
// If we have a "q" query for a name in the published cache
// list, then track whether that's a hit/miss.
if m, ok := dnsCache.Load()[q]; ok {
if len(m) > 0 {
publishedDNSHits.Add(1)
} else {
publishedDNSMisses.Add(1)
}
} else {
// If it wasn't in either cache, treat this as a query
// for the unpublished cache, and thus a cache miss.
unpublishedDNSMisses.Add(1)
}
}
// Fall back to returning the public set of cached DNS names
j := dnsCacheBytes.Load()
w.Write(j) w.Write(j)
} }

@ -5,7 +5,12 @@
package main package main
import ( import (
"encoding/json"
"net"
"net/http" "net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing" "testing"
) )
@ -17,11 +22,12 @@ func BenchmarkHandleBootstrapDNS(b *testing.B) {
}() }()
refreshBootstrapDNS() refreshBootstrapDNS()
w := new(bitbucketResponseWriter) w := new(bitbucketResponseWriter)
req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil)
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(b *testing.PB) { b.RunParallel(func(b *testing.PB) {
for b.Next() { for b.Next() {
handleBootstrapDNS(w, nil) handleBootstrapDNS(w, req)
} }
}) })
} }
@ -33,3 +39,116 @@ func (b *bitbucketResponseWriter) Header() http.Header { return make(http.Header
func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), nil } func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {} func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {}
func getBootstrapDNS(t *testing.T, q string) dnsEntryMap {
t.Helper()
req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil)
w := httptest.NewRecorder()
handleBootstrapDNS(w, req)
res := w.Result()
if res.StatusCode != 200 {
t.Fatalf("got status=%d; want %d", res.StatusCode, 200)
}
var ips dnsEntryMap
if err := json.NewDecoder(res.Body).Decode(&ips); err != nil {
t.Fatalf("error decoding response body: %v", err)
}
return ips
}
func TestUnpublishedDNS(t *testing.T) {
const published = "login.tailscale.com"
const unpublished = "log.tailscale.io"
prev1, prev2 := *bootstrapDNS, *unpublishedDNS
*bootstrapDNS = published
*unpublishedDNS = unpublished
t.Cleanup(func() {
*bootstrapDNS = prev1
*unpublishedDNS = prev2
})
refreshBootstrapDNS()
refreshUnpublishedDNS()
hasResponse := func(q string) bool {
_, found := getBootstrapDNS(t, q)[q]
return found
}
if !hasResponse(published) {
t.Errorf("expected response for: %s", published)
}
if !hasResponse(unpublished) {
t.Errorf("expected response for: %s", unpublished)
}
// Verify that querying for a random query or a real query does not
// leak our unpublished domain
m1 := getBootstrapDNS(t, published)
if _, found := m1[unpublished]; found {
t.Errorf("found unpublished domain %s: %+v", unpublished, m1)
}
m2 := getBootstrapDNS(t, "random.example.com")
if _, found := m2[unpublished]; found {
t.Errorf("found unpublished domain %s: %+v", unpublished, m2)
}
}
func resetMetrics() {
publishedDNSHits.Set(0)
publishedDNSMisses.Set(0)
unpublishedDNSHits.Set(0)
unpublishedDNSMisses.Set(0)
}
// Verify that we don't count an empty list in the unpublishedDNSCache as a
// cache hit in our metrics.
func TestUnpublishedDNSEmptyList(t *testing.T) {
pub := dnsEntryMap{
"tailscale.com": {net.IPv4(10, 10, 10, 10)},
}
dnsCache.Store(pub)
dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`))
unpublishedDNSCache.Store(dnsEntryMap{
"log.tailscale.io": {},
"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)},
})
t.Run("CacheMiss", func(t *testing.T) {
// One domain in map but empty, one not in map at all
for _, q := range []string{"log.tailscale.io", "login.tailscale.com"} {
resetMetrics()
ips := getBootstrapDNS(t, q)
// Expected our public map to be returned on a cache miss
if !reflect.DeepEqual(ips, pub) {
t.Errorf("got ips=%+v; want %+v", ips, pub)
}
if v := unpublishedDNSHits.Value(); v != 0 {
t.Errorf("got hits=%d; want 0", v)
}
if v := unpublishedDNSMisses.Value(); v != 1 {
t.Errorf("got misses=%d; want 1", v)
}
}
})
// Verify that we do get a valid response and metric.
t.Run("CacheHit", func(t *testing.T) {
resetMetrics()
ips := getBootstrapDNS(t, "controlplane.tailscale.com")
want := dnsEntryMap{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}}
if !reflect.DeepEqual(ips, want) {
t.Errorf("got ips=%+v; want %+v", ips, want)
}
if v := unpublishedDNSHits.Value(); v != 1 {
t.Errorf("got hits=%d; want 1", v)
}
if v := unpublishedDNSMisses.Value(); v != 0 {
t.Errorf("got misses=%d; want 0", v)
}
})
}

@ -50,6 +50,7 @@ var (
meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.") meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.")
meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list")
bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns")
unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list")
verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.")
acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection")

Loading…
Cancel
Save