diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 32131bdb2..0baee818c 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -155,7 +155,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/nocasemaps from tailscale.com/types/ipproto tailscale.com/util/quarantine from tailscale.com/cmd/tailscale/cli tailscale.com/util/set from tailscale.com/health+ - tailscale.com/util/singleflight from tailscale.com/net/dnscache + tailscale.com/util/singleflight from tailscale.com/net/dnscache+ tailscale.com/util/slicesx from tailscale.com/net/dnscache+ tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli diff --git a/net/dnsfallback/dnsfallback.go b/net/dnsfallback/dnsfallback.go index c7d08858b..59ca5f624 100644 --- a/net/dnsfallback/dnsfallback.go +++ b/net/dnsfallback/dnsfallback.go @@ -36,6 +36,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" + "tailscale.com/util/singleflight" "tailscale.com/util/slicesx" ) @@ -44,76 +45,165 @@ var ( disableRecursiveResolver = envknob.RegisterBool("TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER") // legacy pre-1.52 env knob name ) +type resolveResult struct { + addrs []netip.Addr + minTTL time.Duration +} + // MakeLookupFunc creates a function that can be used to resolve hostnames // (e.g. as a LookupIPFallback from dnscache.Resolver). // The netMon parameter is optional; if non-nil it's used to do faster interface lookups. func MakeLookupFunc(logf logger.Logf, netMon *netmon.Monitor) func(ctx context.Context, host string) ([]netip.Addr, error) { - return func(ctx context.Context, host string) ([]netip.Addr, error) { - // If they've explicitly disabled the recursive resolver with the legacy - // TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the - // newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the - // recursive resolver. (tailscale/corp#15261) In the future, we might - // change the default (the opt.Bool being unset) to mean enabled. - if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) { - return lookup(ctx, host, logf, netMon) - } + fr := &fallbackResolver{ + logf: logf, + netMon: netMon, + } + return fr.Lookup +} - addrsCh := make(chan []netip.Addr, 1) +// fallbackResolver contains the state and configuration for a DNS resolution +// function. +type fallbackResolver struct { + logf logger.Logf + netMon *netmon.Monitor // or nil + sf singleflight.Group[string, resolveResult] - // Run the recursive resolver in the background so we can - // compare the results. - go func() { - logf := logger.WithPrefix(logf, "recursive: ") - - // Ensure that we catch panics while we're testing this - // code path; this should never panic, but we don't - // want to take down the process by having the panic - // propagate to the top of the goroutine's stack and - // then terminate. - defer func() { - if r := recover(); r != nil { - logf("bootstrap DNS: recovered panic: %v", r) - metricRecursiveErrors.Add(1) - } - }() - - resolver := recursive.Resolver{ - Dialer: netns.NewDialer(logf, netMon), - Logf: logf, - } - addrs, minTTL, err := resolver.Resolve(ctx, host) - if err != nil { - logf("error using recursive resolver: %v", err) - metricRecursiveErrors.Add(1) - return - } + // for tests + waitForCompare bool +} - compareAddr := func(a, b netip.Addr) int { return a.Compare(b) } - slices.SortFunc(addrs, compareAddr) +func (fr *fallbackResolver) Lookup(ctx context.Context, host string) ([]netip.Addr, error) { + // If they've explicitly disabled the recursive resolver with the legacy + // TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the + // newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the + // recursive resolver. (tailscale/corp#15261) In the future, we might + // change the default (the opt.Bool being unset) to mean enabled. + if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) { + return lookup(ctx, host, fr.logf, fr.netMon) + } - // Wait for a response from the main function - oldAddrs := <-addrsCh - slices.SortFunc(oldAddrs, compareAddr) + addrsCh := make(chan []netip.Addr, 1) - matches := slices.Equal(addrs, oldAddrs) + // Run the recursive resolver in the background so we can + // compare the results. For tests, we also allow waiting for the + // comparison to complete; normally, we do this entirely asynchronously + // so as not to block the caller. + var done chan struct{} + if fr.waitForCompare { + done = make(chan struct{}) + go func() { + defer close(done) + fr.compareWithRecursive(ctx, addrsCh, host) + }() + } else { + go fr.compareWithRecursive(ctx, addrsCh, host) + } - logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL) + addrs, err := lookup(ctx, host, fr.logf, fr.netMon) + if err != nil { + addrsCh <- nil + return nil, err + } - if matches { - metricRecursiveMatches.Add(1) - } else { - metricRecursiveMismatches.Add(1) - } - }() + addrsCh <- slices.Clone(addrs) + if fr.waitForCompare { + select { + case <-done: + case <-ctx.Done(): + } + } + return addrs, nil +} - addrs, err := lookup(ctx, host, logf, netMon) +// compareWithRecursive is responsible for comparing the DNS resolution +// performed via the "normal" path (bootstrap DNS requests to the DERP servers) +// with DNS resolution performed with our in-process recursive DNS resolver. +// +// It will select on addrsCh to read exactly one set of addrs (returned by the +// "normal" path) and compare against the results returned by the recursive +// resolver. If ctx is canceled, then it will abort. +func (fr *fallbackResolver) compareWithRecursive( + ctx context.Context, + addrsCh <-chan []netip.Addr, + host string, +) { + logf := logger.WithPrefix(fr.logf, "recursive: ") + + // Ensure that we catch panics while we're testing this + // code path; this should never panic, but we don't + // want to take down the process by having the panic + // propagate to the top of the goroutine's stack and + // then terminate. + defer func() { + if r := recover(); r != nil { + logf("bootstrap DNS: recovered panic: %v", r) + metricRecursiveErrors.Add(1) + } + }() + + // Don't resolve the same host multiple times + // concurrently; if we end up in a tight loop, this can + // take up a lot of CPU. + var didRun bool + result, err, _ := fr.sf.Do(host, func() (resolveResult, error) { + didRun = true + resolver := &recursive.Resolver{ + Dialer: netns.NewDialer(logf, fr.netMon), + Logf: logf, + } + addrs, minTTL, err := resolver.Resolve(ctx, host) if err != nil { - addrsCh <- nil - return nil, err + logf("error using recursive resolver: %v", err) + metricRecursiveErrors.Add(1) + return resolveResult{}, err } + return resolveResult{addrs, minTTL}, nil + }) + + // The singleflight function handled errors; return if + // there was one. Additionally, don't bother doing the + // comparison if we waited on another singleflight + // caller; the results are likely to be the same, so + // rather than spam the logs we can just exit and let + // the singleflight call that did execute do the + // comparison. + // + // Returning here is safe because the addrsCh channel + // is buffered, so the main function won't block even + // if we never read from it. + if err != nil || !didRun { + return + } + + addrs, minTTL := result.addrs, result.minTTL + compareAddr := func(a, b netip.Addr) int { return a.Compare(b) } + slices.SortFunc(addrs, compareAddr) + + // Wait for a response from the main function; try this once before we + // check whether the context is canceled since selects are + // nondeterministic. + var oldAddrs []netip.Addr + select { + case oldAddrs = <-addrsCh: + // All good; continue + default: + // Now block. + select { + case oldAddrs = <-addrsCh: + case <-ctx.Done(): + return + } + } + slices.SortFunc(oldAddrs, compareAddr) + + matches := slices.Equal(addrs, oldAddrs) + + logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL) - addrsCh <- slices.Clone(addrs) - return addrs, nil + if matches { + metricRecursiveMatches.Add(1) + } else { + metricRecursiveMismatches.Add(1) } } diff --git a/net/dnsfallback/dnsfallback_test.go b/net/dnsfallback/dnsfallback_test.go index a60772b55..4298499b0 100644 --- a/net/dnsfallback/dnsfallback_test.go +++ b/net/dnsfallback/dnsfallback_test.go @@ -4,13 +4,17 @@ package dnsfallback import ( + "context" "encoding/json" + "flag" "os" "path/filepath" "reflect" "testing" + "tailscale.com/net/netmon" "tailscale.com/tailcfg" + "tailscale.com/types/logger" ) func TestGetDERPMap(t *testing.T) { @@ -170,3 +174,30 @@ func TestCacheUnchanged(t *testing.T) { t.Fatalf("didn't find non-empty regular file; mode=%v size=%d", st.Mode(), st.Size()) } } + +var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") + +func TestLookup(t *testing.T) { + if !*extNetwork { + t.Skip("skipping test without --use-external-network") + } + + logf, closeLogf := logger.LogfCloser(t.Logf) + defer closeLogf() + + netMon, err := netmon.New(logf) + if err != nil { + t.Fatal(err) + } + + resolver := &fallbackResolver{ + logf: logf, + netMon: netMon, + waitForCompare: true, + } + addrs, err := resolver.Lookup(context.Background(), "controlplane.tailscale.com") + if err != nil { + t.Fatal(err) + } + t.Logf("addrs: %+v", addrs) +}