net/dnsfallback: add singleflight to recursive resolver

This prevents running more than one recursive resolution for the same
hostname in parallel, which can use excessive amounts of CPU when called
in a tight loop. Additionally, add tests that hit the network (when
run with the --use-external-network flag) to test the lookup behaviour.

Updates tailscale/corp#15261

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I39351e1d2a8782dd4c52cb04b3bd982eb651c81e
pull/10260/head
Andrew Dunham 1 year ago
parent a40e918d63
commit e33bc64cff

@ -155,7 +155,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/util/nocasemaps from tailscale.com/types/ipproto tailscale.com/util/nocasemaps from tailscale.com/types/ipproto
tailscale.com/util/quarantine from tailscale.com/cmd/tailscale/cli tailscale.com/util/quarantine from tailscale.com/cmd/tailscale/cli
tailscale.com/util/set from tailscale.com/health+ tailscale.com/util/set from tailscale.com/health+
tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/singleflight from tailscale.com/net/dnscache+
tailscale.com/util/slicesx from tailscale.com/net/dnscache+ tailscale.com/util/slicesx from tailscale.com/net/dnscache+
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli
tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli

@ -36,6 +36,7 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/singleflight"
"tailscale.com/util/slicesx" "tailscale.com/util/slicesx"
) )
@ -44,76 +45,165 @@ var (
disableRecursiveResolver = envknob.RegisterBool("TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER") // legacy pre-1.52 env knob name disableRecursiveResolver = envknob.RegisterBool("TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER") // legacy pre-1.52 env knob name
) )
// resolveResult bundles the values produced by a successful run of the
// recursive resolver so that they can be returned as a single value from the
// fallbackResolver's singleflight group.
type resolveResult struct {
// addrs holds the addresses returned by recursive.Resolver.Resolve.
addrs []netip.Addr
// minTTL is the minTTL value that Resolve returned alongside addrs.
minTTL time.Duration
}
// MakeLookupFunc creates a function that can be used to resolve hostnames // MakeLookupFunc creates a function that can be used to resolve hostnames
// (e.g. as a LookupIPFallback from dnscache.Resolver). // (e.g. as a LookupIPFallback from dnscache.Resolver).
// The netMon parameter is optional; if non-nil it's used to do faster interface lookups. // The netMon parameter is optional; if non-nil it's used to do faster interface lookups.
func MakeLookupFunc(logf logger.Logf, netMon *netmon.Monitor) func(ctx context.Context, host string) ([]netip.Addr, error) { func MakeLookupFunc(logf logger.Logf, netMon *netmon.Monitor) func(ctx context.Context, host string) ([]netip.Addr, error) {
return func(ctx context.Context, host string) ([]netip.Addr, error) { fr := &fallbackResolver{
// If they've explicitly disabled the recursive resolver with the legacy logf: logf,
// TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the netMon: netMon,
// newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the }
// recursive resolver. (tailscale/corp#15261) In the future, we might return fr.Lookup
// change the default (the opt.Bool being unset) to mean enabled. }
if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) {
return lookup(ctx, host, logf, netMon)
}
addrsCh := make(chan []netip.Addr, 1) // fallbackResolver contains the state and configuration for a DNS resolution
// function.
type fallbackResolver struct {
logf logger.Logf
netMon *netmon.Monitor // or nil
sf singleflight.Group[string, resolveResult]
// Run the recursive resolver in the background so we can // for tests
// compare the results. waitForCompare bool
go func() { }
logf := logger.WithPrefix(logf, "recursive: ")
// Ensure that we catch panics while we're testing this
// code path; this should never panic, but we don't
// want to take down the process by having the panic
// propagate to the top of the goroutine's stack and
// then terminate.
defer func() {
if r := recover(); r != nil {
logf("bootstrap DNS: recovered panic: %v", r)
metricRecursiveErrors.Add(1)
}
}()
resolver := recursive.Resolver{
Dialer: netns.NewDialer(logf, netMon),
Logf: logf,
}
addrs, minTTL, err := resolver.Resolve(ctx, host)
if err != nil {
logf("error using recursive resolver: %v", err)
metricRecursiveErrors.Add(1)
return
}
compareAddr := func(a, b netip.Addr) int { return a.Compare(b) } func (fr *fallbackResolver) Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
slices.SortFunc(addrs, compareAddr) // If they've explicitly disabled the recursive resolver with the legacy
// TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the
// newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the
// recursive resolver. (tailscale/corp#15261) In the future, we might
// change the default (the opt.Bool being unset) to mean enabled.
if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) {
return lookup(ctx, host, fr.logf, fr.netMon)
}
// Wait for a response from the main function addrsCh := make(chan []netip.Addr, 1)
oldAddrs := <-addrsCh
slices.SortFunc(oldAddrs, compareAddr)
matches := slices.Equal(addrs, oldAddrs) // Run the recursive resolver in the background so we can
// compare the results. For tests, we also allow waiting for the
// comparison to complete; normally, we do this entirely asynchronously
// so as not to block the caller.
var done chan struct{}
if fr.waitForCompare {
done = make(chan struct{})
go func() {
defer close(done)
fr.compareWithRecursive(ctx, addrsCh, host)
}()
} else {
go fr.compareWithRecursive(ctx, addrsCh, host)
}
logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL) addrs, err := lookup(ctx, host, fr.logf, fr.netMon)
if err != nil {
addrsCh <- nil
return nil, err
}
if matches { addrsCh <- slices.Clone(addrs)
metricRecursiveMatches.Add(1) if fr.waitForCompare {
} else { select {
metricRecursiveMismatches.Add(1) case <-done:
} case <-ctx.Done():
}() }
}
return addrs, nil
}
addrs, err := lookup(ctx, host, logf, netMon) // compareWithRecursive is responsible for comparing the DNS resolution
// performed via the "normal" path (bootstrap DNS requests to the DERP servers)
// with DNS resolution performed with our in-process recursive DNS resolver.
//
// It will select on addrsCh to read exactly one set of addrs (returned by the
// "normal" path) and compare against the results returned by the recursive
// resolver. If ctx is canceled, then it will abort.
func (fr *fallbackResolver) compareWithRecursive(
ctx context.Context,
addrsCh <-chan []netip.Addr,
host string,
) {
logf := logger.WithPrefix(fr.logf, "recursive: ")
// Ensure that we catch panics while we're testing this
// code path; this should never panic, but we don't
// want to take down the process by having the panic
// propagate to the top of the goroutine's stack and
// then terminate.
defer func() {
if r := recover(); r != nil {
logf("bootstrap DNS: recovered panic: %v", r)
metricRecursiveErrors.Add(1)
}
}()
// Don't resolve the same host multiple times
// concurrently; if we end up in a tight loop, this can
// take up a lot of CPU.
var didRun bool
result, err, _ := fr.sf.Do(host, func() (resolveResult, error) {
didRun = true
resolver := &recursive.Resolver{
Dialer: netns.NewDialer(logf, fr.netMon),
Logf: logf,
}
addrs, minTTL, err := resolver.Resolve(ctx, host)
if err != nil { if err != nil {
addrsCh <- nil logf("error using recursive resolver: %v", err)
return nil, err metricRecursiveErrors.Add(1)
return resolveResult{}, err
} }
return resolveResult{addrs, minTTL}, nil
})
// The singleflight function handled errors; return if
// there was one. Additionally, don't bother doing the
// comparison if we waited on another singleflight
// caller; the results are likely to be the same, so
// rather than spam the logs we can just exit and let
// the singleflight call that did execute do the
// comparison.
//
// Returning here is safe because the addrsCh channel
// is buffered, so the main function won't block even
// if we never read from it.
if err != nil || !didRun {
return
}
addrs, minTTL := result.addrs, result.minTTL
compareAddr := func(a, b netip.Addr) int { return a.Compare(b) }
slices.SortFunc(addrs, compareAddr)
// Wait for a response from the main function; try this once before we
// check whether the context is canceled since selects are
// nondeterministic.
var oldAddrs []netip.Addr
select {
case oldAddrs = <-addrsCh:
// All good; continue
default:
// Now block.
select {
case oldAddrs = <-addrsCh:
case <-ctx.Done():
return
}
}
slices.SortFunc(oldAddrs, compareAddr)
matches := slices.Equal(addrs, oldAddrs)
logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL)
addrsCh <- slices.Clone(addrs) if matches {
return addrs, nil metricRecursiveMatches.Add(1)
} else {
metricRecursiveMismatches.Add(1)
} }
} }

@ -4,13 +4,17 @@
package dnsfallback package dnsfallback
import ( import (
"context"
"encoding/json" "encoding/json"
"flag"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"testing" "testing"
"tailscale.com/net/netmon"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/logger"
) )
func TestGetDERPMap(t *testing.T) { func TestGetDERPMap(t *testing.T) {
@ -170,3 +174,30 @@ func TestCacheUnchanged(t *testing.T) {
t.Fatalf("didn't find non-empty regular file; mode=%v size=%d", st.Mode(), st.Size()) t.Fatalf("didn't find non-empty regular file; mode=%v size=%d", st.Mode(), st.Size())
} }
} }
// extNetwork gates tests that need to reach the external network; such tests
// (e.g. TestLookup) skip themselves unless --use-external-network is passed.
var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests")
// TestLookup performs a real lookup of the control plane hostname through a
// fallbackResolver that waits for the recursive-resolver comparison to finish
// before returning. It talks to the network, so it is skipped unless the
// --use-external-network flag is set.
func TestLookup(t *testing.T) {
	if !*extNetwork {
		t.Skip("skipping test without --use-external-network")
	}

	// Wrap t.Logf so that it can be closed when the test function returns.
	logf, closeLogf := logger.LogfCloser(t.Logf)
	defer closeLogf()

	mon, err := netmon.New(logf)
	if err != nil {
		t.Fatal(err)
	}

	const host = "controlplane.tailscale.com"
	ctx := context.Background()

	// waitForCompare makes Lookup block until the comparison goroutine is
	// done (or ctx is canceled) instead of running it fully asynchronously.
	fr := &fallbackResolver{logf: logf, netMon: mon, waitForCompare: true}

	got, err := fr.Lookup(ctx, host)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("addrs: %+v", got)
}

Loading…
Cancel
Save