@ -5,35 +5,45 @@ package main
import (
"context"
"encoding/binary"
"encoding/json"
"expvar"
"log"
"math/rand/v2"
"net"
"net/http"
"net/netip"
"strconv"
"strings"
"sync/atomic"
"time"
"tailscale.com/syncs"
"tailscale.com/util/mak"
"tailscale.com/util/slicesx"
)
const refreshTimeout = time . Minute
type dnsEntryMap map [ string ] [ ] net . IP
type dnsEntryMap struct {
IPs map [ string ] [ ] net . IP
Percent map [ string ] float64 // "foo.com" => 0.5 for 50%
}
var (
dnsCache syncs . AtomicValue [ dnsEntryMap ]
dnsCache atomic. Pointer [ dnsEntryMap ]
dnsCacheBytes syncs . AtomicValue [ [ ] byte ] // of JSON
unpublishedDNSCache syncs. AtomicValue [ dnsEntryMap ]
unpublishedDNSCache atomic. Pointer [ dnsEntryMap ]
bootstrapLookupMap syncs . Map [ string , bool ]
)
var (
bootstrapDNSRequests = expvar . NewInt ( "counter_bootstrap_dns_requests" )
publishedDNSHits = expvar . NewInt ( "counter_bootstrap_dns_published_hits" )
publishedDNSMisses = expvar . NewInt ( "counter_bootstrap_dns_published_misses" )
unpublishedDNSHits = expvar . NewInt ( "counter_bootstrap_dns_unpublished_hits" )
unpublishedDNSMisses = expvar . NewInt ( "counter_bootstrap_dns_unpublished_misses" )
bootstrapDNSRequests = expvar . NewInt ( "counter_bootstrap_dns_requests" )
publishedDNSHits = expvar . NewInt ( "counter_bootstrap_dns_published_hits" )
publishedDNSMisses = expvar . NewInt ( "counter_bootstrap_dns_published_misses" )
unpublishedDNSHits = expvar . NewInt ( "counter_bootstrap_dns_unpublished_hits" )
unpublishedDNSMisses = expvar . NewInt ( "counter_bootstrap_dns_unpublished_misses" )
unpublishedDNSPercentMisses = expvar . NewInt ( "counter_bootstrap_dns_unpublished_percent_misses" )
)
func init ( ) {
@ -59,15 +69,13 @@ func refreshBootstrapDNS() {
}
ctx , cancel := context . WithTimeout ( context . Background ( ) , refreshTimeout )
defer cancel ( )
dnsEntries := resolveList ( ctx , strings . Split ( * bootstrapDNS , "," ) )
dnsEntries := resolveList ( ctx , * bootstrapDNS )
// Randomize the order of the IPs for each name to avoid the client biasing
// to IPv6
for k := range dnsEntries {
ips := dnsEntries [ k ]
slicesx . Shuffle ( ips )
dnsEntries [ k ] = ips
for _ , vv := range dnsEntries . IPs {
slicesx . Shuffle ( vv )
}
j , err := json . MarshalIndent ( dnsEntries , "" , "\t" )
j , err := json . MarshalIndent ( dnsEntries . IPs , "" , "\t" )
if err != nil {
// leave the old values in place
return
@ -81,27 +89,50 @@ func refreshUnpublishedDNS() {
if * unpublishedDNS == "" {
return
}
ctx , cancel := context . WithTimeout ( context . Background ( ) , refreshTimeout )
defer cancel ( )
dnsEntries := resolveList ( ctx , strings . Split ( * unpublishedDNS , "," ) )
dnsEntries := resolveList ( ctx , * unpublishedDNS )
unpublishedDNSCache . Store ( dnsEntries )
}
func resolveList ( ctx context . Context , names [ ] string ) dnsEntryMap {
dnsEntries := make ( dnsEntryMap )
// resolveList takes a comma-separated list of DNS names to resolve.
//
// If an entry contains a slash, it's two DNS names: the first is the one to
// resolve and the second is that of a TXT recording containing the rollout
// percentage in range "0".."100". If the TXT record doesn't exist or is
// malformed, the percentage is 0. If the TXT record is not provided (there's no
// slash), then the percentage is 100.
func resolveList ( ctx context . Context , list string ) * dnsEntryMap {
ents := strings . Split ( list , "," )
ret := & dnsEntryMap { }
var r net . Resolver
for _ , name := range names {
for _ , ent := range ents {
name , txtName , _ := strings . Cut ( ent , "/" )
addrs , err := r . LookupIP ( ctx , "ip" , name )
if err != nil {
log . Printf ( "bootstrap DNS lookup %q: %v" , name , err )
continue
}
dnsEntries [ name ] = addrs
mak . Set ( & ret . IPs , name , addrs )
if txtName == "" {
mak . Set ( & ret . Percent , name , 1.0 )
continue
}
vals , err := r . LookupTXT ( ctx , txtName )
if err != nil {
log . Printf ( "bootstrap DNS lookup %q: %v" , txtName , err )
continue
}
for _ , v := range vals {
if v , err := strconv . Atoi ( v ) ; err == nil && v >= 0 && v <= 100 {
mak . Set ( & ret . Percent , name , float64 ( v ) / 100 )
}
}
}
return dnsEntries
return ret
}
func handleBootstrapDNS ( w http . ResponseWriter , r * http . Request ) {
@ -115,22 +146,36 @@ func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
// Try answering a query from our hidden map first
if q := r . URL . Query ( ) . Get ( "q" ) ; q != "" {
bootstrapLookupMap . Store ( q , true )
if ips , ok := unpublishedDNSCache . Load ( ) [ q ] ; ok && len ( ips ) > 0 {
if bootstrapLookupMap . Len ( ) > 500 { // defensive
bootstrapLookupMap . Clear ( )
}
if m := unpublishedDNSCache . Load ( ) ; m != nil && len ( m . IPs [ q ] ) > 0 {
unpublishedDNSHits . Add ( 1 )
// Only return the specific query, not everything.
m := dnsEntryMap { q : ips }
j , err := json . MarshalIndent ( m , "" , "\t" )
if err == nil {
w . Write ( j )
return
percent := m . Percent [ q ]
if remoteAddrMatchesPercent ( r . RemoteAddr , percent ) {
// Only return the specific query, not everything.
m := map [ string ] [ ] net . IP { q : m . IPs [ q ] }
j , err := json . MarshalIndent ( m , "" , "\t" )
if err == nil {
w . Write ( j )
return
}
} else {
unpublishedDNSPercentMisses . Add ( 1 )
}
}
// If we have a "q" query for a name in the published cache
// list, then track whether that's a hit/miss.
if m , ok := dnsCache . Load ( ) [ q ] ; ok {
if len ( m ) > 0 {
m := dnsCache . Load ( )
var inPub bool
var ips [ ] net . IP
if m != nil {
ips , inPub = m . IPs [ q ]
}
if inPub {
if len ( ips ) > 0 {
publishedDNSHits . Add ( 1 )
} else {
publishedDNSMisses . Add ( 1 )
@ -146,3 +191,29 @@ func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
j := dnsCacheBytes . Load ( )
w . Write ( j )
}
// percent is [0.0, 1.0].
func remoteAddrMatchesPercent ( remoteAddr string , percent float64 ) bool {
if percent == 0 {
return false
}
if percent == 1 {
return true
}
reqIPStr , _ , err := net . SplitHostPort ( remoteAddr )
if err != nil {
return false
}
reqIP , err := netip . ParseAddr ( reqIPStr )
if err != nil {
return false
}
if reqIP . IsLoopback ( ) {
// For local testing.
return rand . Float64 ( ) < 0.5
}
reqIP16 := reqIP . As16 ( )
rndSrc := rand . NewPCG ( binary . LittleEndian . Uint64 ( reqIP16 [ : 8 ] ) , binary . LittleEndian . Uint64 ( reqIP16 [ 8 : ] ) )
rnd := rand . New ( rndSrc )
return percent > rnd . Float64 ( )
}