From ac574d875c7bf6ce16e744b47ce94b74622d550b Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Wed, 3 Apr 2024 17:02:36 -0400 Subject: [PATCH] prober: add helper function to check all IPs for a DNS hostname This allows us to check all IP addresses (and address families) for a given DNS hostname while dynamically discovering new IPs and removing old ones as they're no longer valid. Also add a testable example that demonstrates how to use it. Alternative to #11610 Updates tailscale/corp#16367 Signed-off-by: Andrew Dunham Change-Id: I6d6f39bafc30e6dfcf6708185d09faee2a374599 --- prober/dns.go | 126 +++++++++++++++++++++++++++++++++++++ prober/dns_example_test.go | 98 +++++++++++++++++++++++++++++ prober/dns_test.go | 115 +++++++++++++++++++++++++++++++++ 3 files changed, 339 insertions(+) create mode 100644 prober/dns.go create mode 100644 prober/dns_example_test.go create mode 100644 prober/dns_test.go diff --git a/prober/dns.go b/prober/dns.go new file mode 100644 index 000000000..58048766e --- /dev/null +++ b/prober/dns.go @@ -0,0 +1,126 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package prober + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "tailscale.com/types/logger" +) + +// ForEachAddrOpts contains options for ForEachAddr. The zero value for all +// fields is valid unless stated otherwise. +type ForEachAddrOpts struct { + // Logf is the logger to use for logging. If nil, no logging is done. + Logf logger.Logf + // Networks is the list of networks to resolve; if non-empty, it should + // contain at least one of "ip", "ip4", or "ip6". + // + // If empty, "ip" is assumed. + Networks []string + // LookupNetIP is the function to use to resolve the hostname to one or + // more IP addresses. + // + // If nil, net.DefaultResolver.LookupNetIP is used. + LookupNetIP func(context.Context, string, string) ([]netip.Addr, error) +} + +// ForEachAddr returns a Probe that resolves a given hostname into all +// available IP addresses, and then calls a function to create a new Probe +// every time a new IP is discovered. The Probe returned will be closed if an +// IP address is no longer in the DNS record for the given hostname. This can +// be used to healthcheck every IP address that a hostname resolves to. +func ForEachAddr(host string, newProbe func(netip.Addr) *Probe, opts ForEachAddrOpts) ProbeFunc { + return makeForEachAddr(host, newProbe, opts).run +} + +func makeForEachAddr(host string, newProbe func(netip.Addr) *Probe, opts ForEachAddrOpts) *forEachAddrProbe { + if opts.Logf == nil { + opts.Logf = logger.Discard + } + if len(opts.Networks) == 0 { + opts.Networks = []string{"ip"} + } + if opts.LookupNetIP == nil { + opts.LookupNetIP = net.DefaultResolver.LookupNetIP + } + + return &forEachAddrProbe{ + logf: opts.Logf, + host: host, + networks: opts.Networks, + newProbe: newProbe, + lookupNetIP: opts.LookupNetIP, + probes: make(map[netip.Addr]*Probe), + } +} + +type forEachAddrProbe struct { + // inputs; immutable + logf logger.Logf + host string + networks []string + newProbe func(netip.Addr) *Probe + lookupNetIP func(context.Context, string, string) ([]netip.Addr, error) + + // state + mu sync.Mutex // protects following + probes map[netip.Addr]*Probe +} + +// run matches the ProbeFunc signature +func (f *forEachAddrProbe) run(ctx context.Context) error { + var addrs []netip.Addr + for _, network := range f.networks { + naddrs, err := f.lookupNetIP(ctx, network, f.host) + if err != nil { + return fmt.Errorf("resolving %s addr for %q: %w", network, f.host, err) + } + addrs = append(addrs, naddrs...) + } + if len(addrs) == 0 { + return fmt.Errorf("no addrs for %q", f.host) + } + + // For each address, create a new probe if it doesn't already + // exist in our probe map. + f.mu.Lock() + defer f.mu.Unlock() + + sawIPs := make(map[netip.Addr]bool) + for _, addr := range addrs { + sawIPs[addr] = true + + if _, ok := f.probes[addr]; ok { + // Nothing to create + continue + } + + // Make a new probe, and add it to 'probes'; if the + // function returns nil, we skip it. + probe := f.newProbe(addr) + if probe == nil { + continue + } + + f.logf("adding new probe for %v", addr) + f.probes[addr] = probe + } + + // Remove probes that we didn't see during this address resolution. + for addr, probe := range f.probes { + if !sawIPs[addr] { + f.logf("removing probe for %v", addr) + + // This IP is no longer in the DNS record. Close and remove the probe. + probe.Close() + delete(f.probes, addr) + } + } + return nil +} diff --git a/prober/dns_example_test.go b/prober/dns_example_test.go new file mode 100644 index 000000000..7cd73bcf3 --- /dev/null +++ b/prober/dns_example_test.go @@ -0,0 +1,98 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package prober_test + +import ( + "context" + "flag" + "fmt" + "log" + "net" + "net/netip" + "os" + "os/signal" + "time" + + "tailscale.com/prober" + "tailscale.com/types/logger" +) + +const ( + every30s = 30 * time.Second +) + +var ( + hostname = flag.String("hostname", "tailscale.com", "hostname to probe") + oneshot = flag.Bool("oneshot", true, "run probes once and exit") + verbose = flag.Bool("verbose", false, "enable verbose logging") +) + +// This example demonstrates how to use ForEachAddr to create a TLS probe for +// each IP address in the DNS record of a given hostname. +func ExampleForEachAddr() { + flag.Parse() + + p := prober.New().WithSpread(true) + if *oneshot { + p = p.WithOnce(true) + } + + // This function is called every time we discover a new IP address to check. + makeTLSProbe := func(addr netip.Addr) *prober.Probe { + pf := prober.TLSWithIP(*hostname, netip.AddrPortFrom(addr, 443)) + if *verbose { + logger := logger.WithPrefix(log.Printf, fmt.Sprintf("[tls %s]: ", addr)) + pf = probeLogWrapper(logger, pf) + } + + return p.Run(fmt.Sprintf("website/%s/tls", addr), every30s, nil, pf) + } + + // Determine whether to use IPv4 or IPv6 based on whether we can create + // an IPv6 listening socket on localhost. + sock, err := net.Listen("tcp", "[::1]:0") + supportsIPv6 := err == nil + if sock != nil { + sock.Close() + } + + networks := []string{"ip4"} + if supportsIPv6 { + networks = append(networks, "ip6") + } + + var vlogf logger.Logf = logger.Discard + if *verbose { + vlogf = log.Printf + } + + // This is the outer probe that resolves the hostname and creates a new + // TLS probe for each IP. + p.Run("website/dns", every30s, nil, prober.ForEachAddr(*hostname, makeTLSProbe, prober.ForEachAddrOpts{ + Logf: vlogf, + Networks: networks, + })) + + defer log.Printf("done") + + // Wait until all probes have run if we're running in oneshot mode. + if *oneshot { + p.Wait() + return + } + + // Otherwise, wait until we get a signal. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt) + <-sigCh +} + +func probeLogWrapper(logf logger.Logf, pf prober.ProbeFunc) prober.ProbeFunc { + return func(ctx context.Context) error { + logf("starting probe") + err := pf(ctx) + logf("probe finished with %v", err) + return err + } +} diff --git a/prober/dns_test.go b/prober/dns_test.go new file mode 100644 index 000000000..d0148537d --- /dev/null +++ b/prober/dns_test.go @@ -0,0 +1,115 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package prober + +import ( + "context" + "fmt" + "net/netip" + "slices" + "sync" + "testing" + + "tailscale.com/syncs" +) + +func TestForEachAddr(t *testing.T) { + clk := newFakeTime() + p := newForTest(clk.Now, clk.NewTicker) + + opts := ForEachAddrOpts{ + Logf: t.Logf, + Networks: []string{"ip4", "ip6"}, + } + + var ( + addr4_1 = netip.MustParseAddr("76.76.21.21") + addr4_2 = netip.MustParseAddr("127.0.0.1") + + addr6_1 = netip.MustParseAddr("2600:9000:a602:b1e6:5b89:50a1:7cf7:67b8") + addr6_2 = netip.MustParseAddr("2600:9000:a51d:27c1:6748:d035:a989:fb3c") + ) + + var resolverAddrs4, resolverAddrs6 syncs.AtomicValue[[]netip.Addr] + resolverAddrs4.Store([]netip.Addr{addr4_1}) + resolverAddrs6.Store([]netip.Addr{addr6_1, addr6_2}) + + opts.LookupNetIP = func(_ context.Context, network string, _ string) ([]netip.Addr, error) { + if network == "ip4" { + return resolverAddrs4.Load(), nil + } else if network == "ip6" { + return resolverAddrs6.Load(), nil + } + return nil, fmt.Errorf("unknown network %q", network) + } + + var ( + mu sync.Mutex // protects following + registered []netip.Addr + ) + newProbe := func(addr netip.Addr) *Probe { + // Called to register a new prober + t.Logf("called to register new probe for %v", addr) + + mu.Lock() + defer mu.Unlock() + registered = append(registered, addr) + + // Return a probe that does nothing; we don't care about what this does. + return p.Run(fmt.Sprintf("website/%s", addr), probeInterval, nil, func(_ context.Context) error { + return nil + }) + } + + fep := makeForEachAddr("tailscale.com", newProbe, opts) + + // Mimic a call from the prober; we do this ourselves instead of + // calling it via p.Run so we know that the probe has actually run. + ctx := context.Background() + if err := fep.run(ctx); err != nil { + t.Fatalf("run: %v", err) + } + + mu.Lock() + wantAddrs := []netip.Addr{addr4_1, addr6_1, addr6_2} + if !slices.Equal(registered, wantAddrs) { + t.Errorf("got registered addrs %v; want %v", registered, wantAddrs) + } + mu.Unlock() + + // Now, update our IP addresses to force the prober to close and + // re-create our probes. + resolverAddrs4.Store([]netip.Addr{addr4_2}) + resolverAddrs6.Store([]netip.Addr{addr6_2}) + + // Clear out our test data. + mu.Lock() + registered = nil + mu.Unlock() + + // Run our individual prober again manually (so we don't have to wait + // or coordinate with the created probers). + if err := fep.run(ctx); err != nil { + t.Fatalf("run: %v", err) + } + + // Ensure that we only registered our net-new address (addr4_2). + mu.Lock() + wantAddrs = []netip.Addr{addr4_2} + if !slices.Equal(registered, wantAddrs) { + t.Errorf("got registered addrs %v; want %v", registered, wantAddrs) + } + mu.Unlock() + + // Check that we don't have a probe for the addresses that we expect to + // have been removed (addr4_1 and addr6_1). + p.mu.Lock() + for _, addr := range []netip.Addr{addr4_1, addr6_1} { + _, ok := fep.probes[addr] + if ok { + t.Errorf("probe for %v still exists", addr) + } + } + p.mu.Unlock() +}