@ -24,6 +24,7 @@ import (
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/net/netns"
"tailscale.com/types/dnstype"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/wgengine/monitor"
@ -133,8 +134,8 @@ type route struct {
// resolverAndDelay is an upstream DNS resolver and a delay for how
// long to wait before querying it.
type resolverAndDelay struct {
// ipp is the upstream resolver.
ipp netaddr . IPPort
// name is the upstream resolver.
name dnstype . Resolver
// startDelay is an amount to delay this resolver at
// start. It's used when, say, there are four Google or
@ -158,7 +159,7 @@ type forwarder struct {
mu sync . Mutex // guards following
dohClient map [ netaddr . IP ] * http . Client
dohClient map [ string ] * http . Client // urlBase -> client
// routes are per-suffix resolvers to use, with
// the most specific routes first.
@ -192,11 +193,11 @@ func (f *forwarder) Close() error {
return nil
}
// resolversWithDelays maps from a set of DNS server ip:ports (currently
// the port is always 53) to a slice of a type that included a
// startDelay. So if ipps contains e.g. four Google DNS IPs (two IPv4
// + twoIPv6), this function partition adds delays to some.
func resolversWithDelays ( ipps [ ] netaddr . IPPort ) [ ] resolverAndDelay {
// resolversWithDelays maps from a set of DNS server names to a slice of
// a type that included a startDelay. So if resolvers contains e.g. four
// Google DNS IPs (two IPv4 + twoIPv6), this function partition adds
// delays to some.
func resolversWithDelays ( resolvers [ ] dnstype . Resolver ) [ ] resolverAndDelay {
type hostAndFam struct {
host string // some arbitrary string representing DNS host (currently the DoH base)
bits uint8 // either 32 or 128 for IPv4 vs IPv6s address family
@ -206,18 +207,19 @@ func resolversWithDelays(ipps []netaddr.IPPort) []resolverAndDelay {
// per address family.
total := map [ hostAndFam ] int { }
rr := make ( [ ] resolverAndDelay , len ( ipp s) )
for _ , ipp := range ipp s {
ip := ipp . IP ( )
rr := make ( [ ] resolverAndDelay , len ( resolver s) )
for _ , r := range resolver s {
if ip , err := netaddr . ParseIP ( r . Addr ) ; err == nil {
if host , ok := knownDoH [ ip ] ; ok {
total [ hostAndFam { host , ip . BitLen ( ) } ] ++
}
}
}
done := map [ hostAndFam ] int { }
for i , ipp := range ipps {
ip := ipp . IP ( )
for i , r := range resolvers {
var startDelay time . Duration
if ip , err := netaddr . ParseIP ( r . Addr ) ; err == nil {
if host , ok := knownDoH [ ip ] ; ok {
key4 := hostAndFam { host , 32 }
key6 := hostAndFam { host , 128 }
@ -245,8 +247,9 @@ func resolversWithDelays(ipps []netaddr.IPPort) []resolverAndDelay {
}
done [ hostAndFam { host , ip . BitLen ( ) } ] ++
}
}
rr [ i ] = resolverAndDelay {
ipp: ipp ,
name: r ,
startDelay : startDelay ,
}
}
@ -257,12 +260,12 @@ func resolversWithDelays(ipps []netaddr.IPPort) []resolverAndDelay {
// Resolver.SetConfig on reconfig.
//
// The memory referenced by routesBySuffix should not be modified.
func ( f * forwarder ) setRoutes ( routesBySuffix map [ dnsname . FQDN ] [ ] netaddr. IPPort ) {
func ( f * forwarder ) setRoutes ( routesBySuffix map [ dnsname . FQDN ] [ ] dnstype. Resolver ) {
routes := make ( [ ] route , 0 , len ( routesBySuffix ) )
for suffix , ipp s := range routesBySuffix {
for suffix , r s := range routesBySuffix {
routes = append ( routes , route {
Suffix : suffix ,
Resolvers : resolversWithDelays ( ipp s) ,
Resolvers : resolversWithDelays ( r s) ,
} )
}
// Sort from longest prefix to shortest.
@ -296,18 +299,19 @@ func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) {
return lc , nil
}
func ( f * forwarder ) get DoHClient( ip netaddr . IP ) ( urlBase string , c * http . Client , ok bool ) {
func ( f * forwarder ) get Known DoHClient( ip netaddr . IP ) ( urlBase string , c * http . Client , ok bool ) {
urlBase , ok = knownDoH [ ip ]
if ! ok {
return
}
f . mu . Lock ( )
defer f . mu . Unlock ( )
if c , ok := f . dohClient [ ip ] ; ok {
if c , ok := f . dohClient [ urlBase ] ; ok {
return urlBase , c , true
}
if f . dohClient == nil {
f . dohClient = map [ netaddr . IP ] * http . Client { }
f . dohClient = map [ string ] * http . Client { }
}
nsDialer := netns . NewDialer ( )
c = & http . Client {
@ -330,7 +334,7 @@ func (f *forwarder) getDoHClient(ip netaddr.IP) (urlBase string, c *http.Client,
} ,
} ,
}
f . dohClient [ ip ] = c
f . dohClient [ urlBase ] = c
return urlBase , c , true
}
@ -380,20 +384,32 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
// send sends packet to dst. It is best effort.
//
// send expects the reply to have the same txid as txidOut.
//
func ( f * forwarder ) send ( ctx context . Context , fq * forwardQuery , dst netaddr . IPPort ) ( [ ] byte , error ) {
ip := dst . IP ( )
func ( f * forwarder ) send ( ctx context . Context , fq * forwardQuery , rr resolverAndDelay ) ( [ ] byte , error ) {
if strings . HasPrefix ( rr . name . Addr , "http://" ) {
return nil , fmt . Errorf ( "http:// resolvers not supported yet" )
}
if strings . HasPrefix ( rr . name . Addr , "https://" ) {
return nil , fmt . Errorf ( "https:// resolvers not supported yet" )
}
if strings . HasPrefix ( rr . name . Addr , "tls://" ) {
return nil , fmt . Errorf ( "tls:// resolvers not supported yet" )
}
ipp , err := netaddr . ParseIPPort ( rr . name . Addr )
if err != nil {
return nil , err
}
// Upgrade known DNS IPs to DoH (DNS-over-HTTPs).
if urlBase , dc , ok := f . getDoHClient ( ip ) ; ok {
// All known DoH is over port 53.
if urlBase , dc , ok := f . getKnownDoHClient ( ipp . IP ( ) ) ; ok {
res , err := f . sendDoH ( ctx , urlBase , dc , fq . packet )
if err == nil || ctx . Err ( ) != nil {
return res , err
}
f . logf ( "DoH error from %v: %v" , ip , err )
f . logf ( "DoH error from %v: %v" , ip p. IP ( ) , err )
}
ln , err := f . packetListener ( ip )
ln , err := f . packetListener ( ip p. IP ( ) )
if err != nil {
return nil , err
}
@ -407,7 +423,7 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, dst netaddr.IPPo
fq . closeOnCtxDone . Add ( conn )
defer fq . closeOnCtxDone . Remove ( conn )
if _ , err := conn . WriteTo ( fq . packet , dst . UDPAddr ( ) ) ; err != nil {
if _ , err := conn . WriteTo ( fq . packet , ipp . UDPAddr ( ) ) ; err != nil {
if err := ctx . Err ( ) ; err != nil {
return nil , err
}
@ -525,8 +541,8 @@ func (f *forwarder) forward(query packet) error {
firstErr error
)
for _, rr := range resolvers {
go func ( rr resolverAndDelay ) {
for i := range resolvers {
go func ( rr * resolverAndDelay ) {
if rr . startDelay > 0 {
timer := time . NewTimer ( rr . startDelay )
select {
@ -536,7 +552,7 @@ func (f *forwarder) forward(query packet) error {
return
}
}
resb , err := f . send ( ctx , fq , rr . ipp )
resb , err := f . send ( ctx , fq , * rr )
if err != nil {
mu . Lock ( )
defer mu . Unlock ( )
@ -549,7 +565,7 @@ func (f *forwarder) forward(query packet) error {
case resc <- resb :
default :
}
} ( r r)
} ( & r esolve rs[ i ] )
}
select {
@ -638,7 +654,7 @@ func (p *closePool) Close() error {
return nil
}
var knownDoH = map [ netaddr . IP ] string { }
var knownDoH = map [ netaddr . IP ] string { } // 8.8.8.8 => "https://..."
var dohIPsOfBase = map [ string ] [ ] netaddr . IP { }