diff --git a/prober/dns.go b/prober/dns.go index 58048766e..4302994fe 100644 --- a/prober/dns.go +++ b/prober/dns.go @@ -31,15 +31,15 @@ type ForEachAddrOpts struct { } // ForEachAddr returns a Probe that resolves a given hostname into all -// available IP addresses, and then calls a function to create a new Probe -// every time a new IP is discovered. The Probe returned will be closed if an +// available IP addresses, and then calls a function to create new Probes +// every time a new IP is discovered. The Probes returned will be closed if an // IP address is no longer in the DNS record for the given hostname. This can // be used to healthcheck every IP address that a hostname resolves to. -func ForEachAddr(host string, newProbe func(netip.Addr) *Probe, opts ForEachAddrOpts) ProbeFunc { - return makeForEachAddr(host, newProbe, opts).run +func ForEachAddr(host string, makeProbes func(netip.Addr) []*Probe, opts ForEachAddrOpts) ProbeFunc { + return makeForEachAddr(host, makeProbes, opts).run } -func makeForEachAddr(host string, newProbe func(netip.Addr) *Probe, opts ForEachAddrOpts) *forEachAddrProbe { +func makeForEachAddr(host string, makeProbes func(netip.Addr) []*Probe, opts ForEachAddrOpts) *forEachAddrProbe { if opts.Logf == nil { opts.Logf = logger.Discard } @@ -54,9 +54,9 @@ func makeForEachAddr(host string, newProbe func(netip.Addr) *Probe, opts ForEach logf: opts.Logf, host: host, networks: opts.Networks, - newProbe: newProbe, + makeProbes: makeProbes, lookupNetIP: opts.LookupNetIP, - probes: make(map[netip.Addr]*Probe), + probes: make(map[netip.Addr][]*Probe), } } @@ -65,12 +65,12 @@ type forEachAddrProbe struct { logf logger.Logf host string networks []string - newProbe func(netip.Addr) *Probe + makeProbes func(netip.Addr) []*Probe lookupNetIP func(context.Context, string, string) ([]netip.Addr, error) // state mu sync.Mutex // protects following - probes map[netip.Addr]*Probe + probes map[netip.Addr][]*Probe } // run matches the ProbeFunc signature @@ -102,23 +102,25 @@ func (f *forEachAddrProbe) run(ctx context.Context) error { } // Make a new probe, and add it to 'probes'; if the - // function returns nil, we skip it. - probe := f.newProbe(addr) - if probe == nil { + // function returns an empty list, we skip it. + probes := f.makeProbes(addr) + if len(probes) == 0 { continue } - f.logf("adding new probe for %v", addr) - f.probes[addr] = probe + f.logf("adding %d new probes for %v", len(probes), addr) + f.probes[addr] = probes } // Remove probes that we didn't see during this address resolution. - for addr, probe := range f.probes { + for addr, probes := range f.probes { if !sawIPs[addr] { - f.logf("removing probe for %v", addr) + f.logf("removing %d probes for %v", len(probes), addr) - // This IP is no longer in the DNS record. Close and remove the probe. - probe.Close() + // This IP is no longer in the DNS record. Close and remove all probes + for _, probe := range probes { + probe.Close() + } delete(f.probes, addr) } } diff --git a/prober/dns_example_test.go b/prober/dns_example_test.go index 7cd73bcf3..4bb7471a2 100644 --- a/prober/dns_example_test.go +++ b/prober/dns_example_test.go @@ -39,14 +39,15 @@ func ExampleForEachAddr() { } // This function is called every time we discover a new IP address to check. - makeTLSProbe := func(addr netip.Addr) *prober.Probe { + makeTLSProbe := func(addr netip.Addr) []*prober.Probe { pf := prober.TLSWithIP(*hostname, netip.AddrPortFrom(addr, 443)) if *verbose { logger := logger.WithPrefix(log.Printf, fmt.Sprintf("[tls %s]: ", addr)) pf = probeLogWrapper(logger, pf) } - return p.Run(fmt.Sprintf("website/%s/tls", addr), every30s, nil, pf) + probe := p.Run(fmt.Sprintf("website/%s/tls", addr), every30s, nil, pf) + return []*prober.Probe{probe} } // Determine whether to use IPv4 or IPv6 based on whether we can create diff --git a/prober/dns_test.go b/prober/dns_test.go index d0148537d..b7c432d11 100644 --- a/prober/dns_test.go +++ b/prober/dns_test.go @@ -48,7 +48,7 @@ func TestForEachAddr(t *testing.T) { mu sync.Mutex // protects following registered []netip.Addr ) - newProbe := func(addr netip.Addr) *Probe { + newProbe := func(addr netip.Addr) []*Probe { // Called to register a new prober t.Logf("called to register new probe for %v", addr) @@ -57,9 +57,10 @@ func TestForEachAddr(t *testing.T) { registered = append(registered, addr) // Return a probe that does nothing; we don't care about what this does. - return p.Run(fmt.Sprintf("website/%s", addr), probeInterval, nil, func(_ context.Context) error { + probe := p.Run(fmt.Sprintf("website/%s", addr), probeInterval, nil, func(_ context.Context) error { return nil }) + return []*Probe{probe} } fep := makeForEachAddr("tailscale.com", newProbe, opts)