cmd/derper: support TXT-mediated unpublished bootstrap DNS rollouts

Updates tailscale/coral#127

Change-Id: I2712c50630d0d1272c30305fa5a1899a19ffacef
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/12223/head
Brad Fitzpatrick 6 months ago committed by Brad Fitzpatrick
parent 72f0f53ed0
commit 3c9be07214

@ -5,26 +5,35 @@ package main
import ( import (
"context" "context"
"encoding/binary"
"encoding/json" "encoding/json"
"expvar" "expvar"
"log" "log"
"math/rand/v2"
"net" "net"
"net/http" "net/http"
"net/netip"
"strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/util/mak"
"tailscale.com/util/slicesx" "tailscale.com/util/slicesx"
) )
const refreshTimeout = time.Minute const refreshTimeout = time.Minute
type dnsEntryMap map[string][]net.IP type dnsEntryMap struct {
IPs map[string][]net.IP
Percent map[string]float64 // "foo.com" => 0.5 for 50%
}
var ( var (
dnsCache syncs.AtomicValue[dnsEntryMap] dnsCache atomic.Pointer[dnsEntryMap]
dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON
unpublishedDNSCache syncs.AtomicValue[dnsEntryMap] unpublishedDNSCache atomic.Pointer[dnsEntryMap]
bootstrapLookupMap syncs.Map[string, bool] bootstrapLookupMap syncs.Map[string, bool]
) )
@ -34,6 +43,7 @@ var (
publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses") publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses")
unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits") unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses") unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
unpublishedDNSPercentMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_percent_misses")
) )
func init() { func init() {
@ -59,15 +69,13 @@ func refreshBootstrapDNS() {
} }
ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout) ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
defer cancel() defer cancel()
dnsEntries := resolveList(ctx, strings.Split(*bootstrapDNS, ",")) dnsEntries := resolveList(ctx, *bootstrapDNS)
// Randomize the order of the IPs for each name to avoid the client biasing // Randomize the order of the IPs for each name to avoid the client biasing
// to IPv6 // to IPv6
for k := range dnsEntries { for _, vv := range dnsEntries.IPs {
ips := dnsEntries[k] slicesx.Shuffle(vv)
slicesx.Shuffle(ips)
dnsEntries[k] = ips
} }
j, err := json.MarshalIndent(dnsEntries, "", "\t") j, err := json.MarshalIndent(dnsEntries.IPs, "", "\t")
if err != nil { if err != nil {
// leave the old values in place // leave the old values in place
return return
@ -81,27 +89,50 @@ func refreshUnpublishedDNS() {
if *unpublishedDNS == "" { if *unpublishedDNS == "" {
return return
} }
ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout) ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
defer cancel() defer cancel()
dnsEntries := resolveList(ctx, *unpublishedDNS)
dnsEntries := resolveList(ctx, strings.Split(*unpublishedDNS, ","))
unpublishedDNSCache.Store(dnsEntries) unpublishedDNSCache.Store(dnsEntries)
} }
func resolveList(ctx context.Context, names []string) dnsEntryMap { // resolveList takes a comma-separated list of DNS names to resolve.
dnsEntries := make(dnsEntryMap) //
// If an entry contains a slash, it's two DNS names: the first is the one to
// resolve and the second is that of a TXT recording containing the rollout
// percentage in range "0".."100". If the TXT record doesn't exist or is
// malformed, the percentage is 0. If the TXT record is not provided (there's no
// slash), then the percentage is 100.
func resolveList(ctx context.Context, list string) *dnsEntryMap {
ents := strings.Split(list, ",")
ret := &dnsEntryMap{}
var r net.Resolver var r net.Resolver
for _, name := range names { for _, ent := range ents {
name, txtName, _ := strings.Cut(ent, "/")
addrs, err := r.LookupIP(ctx, "ip", name) addrs, err := r.LookupIP(ctx, "ip", name)
if err != nil { if err != nil {
log.Printf("bootstrap DNS lookup %q: %v", name, err) log.Printf("bootstrap DNS lookup %q: %v", name, err)
continue continue
} }
dnsEntries[name] = addrs mak.Set(&ret.IPs, name, addrs)
if txtName == "" {
mak.Set(&ret.Percent, name, 1.0)
continue
}
vals, err := r.LookupTXT(ctx, txtName)
if err != nil {
log.Printf("bootstrap DNS lookup %q: %v", txtName, err)
continue
}
for _, v := range vals {
if v, err := strconv.Atoi(v); err == nil && v >= 0 && v <= 100 {
mak.Set(&ret.Percent, name, float64(v)/100)
}
}
} }
return dnsEntries return ret
} }
func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) { func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
@ -115,22 +146,36 @@ func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
// Try answering a query from our hidden map first // Try answering a query from our hidden map first
if q := r.URL.Query().Get("q"); q != "" { if q := r.URL.Query().Get("q"); q != "" {
bootstrapLookupMap.Store(q, true) bootstrapLookupMap.Store(q, true)
if ips, ok := unpublishedDNSCache.Load()[q]; ok && len(ips) > 0 { if bootstrapLookupMap.Len() > 500 { // defensive
bootstrapLookupMap.Clear()
}
if m := unpublishedDNSCache.Load(); m != nil && len(m.IPs[q]) > 0 {
unpublishedDNSHits.Add(1) unpublishedDNSHits.Add(1)
percent := m.Percent[q]
if remoteAddrMatchesPercent(r.RemoteAddr, percent) {
// Only return the specific query, not everything. // Only return the specific query, not everything.
m := dnsEntryMap{q: ips} m := map[string][]net.IP{q: m.IPs[q]}
j, err := json.MarshalIndent(m, "", "\t") j, err := json.MarshalIndent(m, "", "\t")
if err == nil { if err == nil {
w.Write(j) w.Write(j)
return return
} }
} else {
unpublishedDNSPercentMisses.Add(1)
}
} }
// If we have a "q" query for a name in the published cache // If we have a "q" query for a name in the published cache
// list, then track whether that's a hit/miss. // list, then track whether that's a hit/miss.
if m, ok := dnsCache.Load()[q]; ok { m := dnsCache.Load()
if len(m) > 0 { var inPub bool
var ips []net.IP
if m != nil {
ips, inPub = m.IPs[q]
}
if inPub {
if len(ips) > 0 {
publishedDNSHits.Add(1) publishedDNSHits.Add(1)
} else { } else {
publishedDNSMisses.Add(1) publishedDNSMisses.Add(1)
@ -146,3 +191,29 @@ func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
j := dnsCacheBytes.Load() j := dnsCacheBytes.Load()
w.Write(j) w.Write(j)
} }
// percent is [0.0, 1.0].
func remoteAddrMatchesPercent(remoteAddr string, percent float64) bool {
if percent == 0 {
return false
}
if percent == 1 {
return true
}
reqIPStr, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return false
}
reqIP, err := netip.ParseAddr(reqIPStr)
if err != nil {
return false
}
if reqIP.IsLoopback() {
// For local testing.
return rand.Float64() < 0.5
}
reqIP16 := reqIP.As16()
rndSrc := rand.NewPCG(binary.LittleEndian.Uint64(reqIP16[:8]), binary.LittleEndian.Uint64(reqIP16[8:]))
rnd := rand.New(rndSrc)
return percent > rnd.Float64()
}

@ -4,10 +4,13 @@
package main package main
import ( import (
"bytes"
"encoding/json" "encoding/json"
"io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"net/url" "net/url"
"reflect" "reflect"
"testing" "testing"
@ -38,7 +41,7 @@ func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p),
func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {} func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {}
func getBootstrapDNS(t *testing.T, q string) dnsEntryMap { func getBootstrapDNS(t *testing.T, q string) map[string][]net.IP {
t.Helper() t.Helper()
req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil) req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -48,11 +51,12 @@ func getBootstrapDNS(t *testing.T, q string) dnsEntryMap {
if res.StatusCode != 200 { if res.StatusCode != 200 {
t.Fatalf("got status=%d; want %d", res.StatusCode, 200) t.Fatalf("got status=%d; want %d", res.StatusCode, 200)
} }
var ips dnsEntryMap var m map[string][]net.IP
if err := json.NewDecoder(res.Body).Decode(&ips); err != nil { var buf bytes.Buffer
t.Fatalf("error decoding response body: %v", err) if err := json.NewDecoder(io.TeeReader(res.Body, &buf)).Decode(&m); err != nil {
t.Fatalf("error decoding response body %q: %v", buf.Bytes(), err)
} }
return ips return m
} }
func TestUnpublishedDNS(t *testing.T) { func TestUnpublishedDNS(t *testing.T) {
@ -107,15 +111,21 @@ func resetMetrics() {
// Verify that we don't count an empty list in the unpublishedDNSCache as a // Verify that we don't count an empty list in the unpublishedDNSCache as a
// cache hit in our metrics. // cache hit in our metrics.
func TestUnpublishedDNSEmptyList(t *testing.T) { func TestUnpublishedDNSEmptyList(t *testing.T) {
pub := dnsEntryMap{ pub := &dnsEntryMap{
"tailscale.com": {net.IPv4(10, 10, 10, 10)}, IPs: map[string][]net.IP{"tailscale.com": {net.IPv4(10, 10, 10, 10)}},
} }
dnsCache.Store(pub) dnsCache.Store(pub)
dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`)) dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`))
unpublishedDNSCache.Store(dnsEntryMap{ unpublishedDNSCache.Store(&dnsEntryMap{
IPs: map[string][]net.IP{
"log.tailscale.io": {}, "log.tailscale.io": {},
"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}, "controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)},
},
Percent: map[string]float64{
"log.tailscale.io": 1.0,
"controlplane.tailscale.com": 1.0,
},
}) })
t.Run("CacheMiss", func(t *testing.T) { t.Run("CacheMiss", func(t *testing.T) {
@ -125,8 +135,8 @@ func TestUnpublishedDNSEmptyList(t *testing.T) {
ips := getBootstrapDNS(t, q) ips := getBootstrapDNS(t, q)
// Expected our public map to be returned on a cache miss // Expected our public map to be returned on a cache miss
if !reflect.DeepEqual(ips, pub) { if !reflect.DeepEqual(ips, pub.IPs) {
t.Errorf("got ips=%+v; want %+v", ips, pub) t.Errorf("got ips=%+v; want %+v", ips, pub.IPs)
} }
if v := unpublishedDNSHits.Value(); v != 0 { if v := unpublishedDNSHits.Value(); v != 0 {
t.Errorf("got hits=%d; want 0", v) t.Errorf("got hits=%d; want 0", v)
@ -141,7 +151,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) {
t.Run("CacheHit", func(t *testing.T) { t.Run("CacheHit", func(t *testing.T) {
resetMetrics() resetMetrics()
ips := getBootstrapDNS(t, "controlplane.tailscale.com") ips := getBootstrapDNS(t, "controlplane.tailscale.com")
want := dnsEntryMap{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}} want := map[string][]net.IP{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}}
if !reflect.DeepEqual(ips, want) { if !reflect.DeepEqual(ips, want) {
t.Errorf("got ips=%+v; want %+v", ips, want) t.Errorf("got ips=%+v; want %+v", ips, want)
} }
@ -166,3 +176,54 @@ func TestLookupMetric(t *testing.T) {
t.Errorf("bootstrapLookupMap.Len() want=5, got %v", bootstrapLookupMap.Len()) t.Errorf("bootstrapLookupMap.Len() want=5, got %v", bootstrapLookupMap.Len())
} }
} }
func TestRemoteAddrMatchesPercent(t *testing.T) {
tests := []struct {
remoteAddr string
percent float64
want bool
}{
// 0% and 100%.
{"10.0.0.1:1234", 0.0, false},
{"10.0.0.1:1234", 1.0, true},
// Invalid IP.
{"", 1.0, true},
{"", 0.0, false},
{"", 0.5, false},
// Small manual sample at 50%. The func uses a deterministic PRNG seed.
{"1.2.3.4:567", 0.5, true},
{"1.2.3.5:567", 0.5, true},
{"1.2.3.6:567", 0.5, false},
{"1.2.3.7:567", 0.5, true},
{"1.2.3.8:567", 0.5, false},
{"1.2.3.9:567", 0.5, true},
{"1.2.3.10:567", 0.5, true},
}
for _, tt := range tests {
got := remoteAddrMatchesPercent(tt.remoteAddr, tt.percent)
if got != tt.want {
t.Errorf("remoteAddrMatchesPercent(%q, %v) = %v; want %v", tt.remoteAddr, tt.percent, got, tt.want)
}
}
var match, all int
const wantPercent = 0.5
for a := range 256 {
for b := range 256 {
all++
if remoteAddrMatchesPercent(
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, byte(a), byte(b)}), 12345).String(),
wantPercent) {
match++
}
}
}
gotPercent := float64(match) / float64(all)
const tolerance = 0.005
t.Logf("got percent %v (goal %v)", gotPercent, wantPercent)
if gotPercent < wantPercent-tolerance || gotPercent > wantPercent+tolerance {
t.Errorf("got %v; want %v ± %v", gotPercent, wantPercent, tolerance)
}
}

@ -253,7 +253,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
math/big from crypto/dsa+ math/big from crypto/dsa+
math/bits from compress/flate+ math/bits from compress/flate+
math/rand from github.com/mdlayher/netlink+ math/rand from github.com/mdlayher/netlink+
math/rand/v2 from tailscale.com/util/fastuuid math/rand/v2 from tailscale.com/util/fastuuid+
mime from github.com/prometheus/common/expfmt+ mime from github.com/prometheus/common/expfmt+
mime/multipart from net/http mime/multipart from net/http
mime/quotedprintable from mime/multipart mime/quotedprintable from mime/multipart

@ -55,7 +55,7 @@ var (
meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.") meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.")
meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list")
bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns")
unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list") unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list. If an entry contains a slash, the second part names a DNS record to poll for its TXT record with a `0` to `100` value for rollout percentage.")
verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.")
verifyClientURL = flag.String("verify-client-url", "", "if non-empty, an admission controller URL for permitting client connections; see tailcfg.DERPAdmitClientRequest") verifyClientURL = flag.String("verify-client-url", "", "if non-empty, an admission controller URL for permitting client connections; see tailcfg.DERPAdmitClientRequest")
verifyFailOpen = flag.Bool("verify-client-url-fail-open", true, "whether we fail open if --verify-client-url is unreachable") verifyFailOpen = flag.Bool("verify-client-url-fail-open", true, "whether we fail open if --verify-client-url is unreachable")

Loading…
Cancel
Save