From 228d0c6aeabcc01ae973442ae1eac5e82b07da11 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Fri, 14 Apr 2023 12:13:03 -0400 Subject: [PATCH] net/netcheck: use dnscache.Resolver when resolving DERP IPs This also adds a bunch of tests for this function to ensure that we're returning the proper IP(s) in all cases. Signed-off-by: Andrew Dunham Change-Id: I0d9d57170dbab5f2bf07abdf78ecd17e0e635399 --- cmd/tailscale/cli/netcheck.go | 1 + net/netcheck/netcheck.go | 56 +++++++++++++++++++++++++++++--- net/netcheck/netcheck_test.go | 57 +++++++++++++++++++++++++++++++++ wgengine/magicsock/magicsock.go | 1 + 4 files changed, 110 insertions(+), 5 deletions(-) diff --git a/cmd/tailscale/cli/netcheck.go b/cmd/tailscale/cli/netcheck.go index fef2e22a6..a1f79b1fe 100644 --- a/cmd/tailscale/cli/netcheck.go +++ b/cmd/tailscale/cli/netcheck.go @@ -48,6 +48,7 @@ func runNetcheck(ctx context.Context, args []string) error { c := &netcheck.Client{ UDPBindAddr: envknob.String("TS_DEBUG_NETCHECK_UDP_BIND"), PortMapper: portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: "), nil, nil), + UseDNSCache: false, // always resolve, don't cache } if netcheckArgs.verbose { c.Logf = logger.WithPrefix(log.Printf, "netcheck: ") diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index b941340f1..54f31f065 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -25,6 +25,7 @@ import ( "github.com/tcnksm/go-httpstat" "tailscale.com/derp/derphttp" "tailscale.com/envknob" + "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" "tailscale.com/net/netaddr" "tailscale.com/net/neterror" @@ -181,6 +182,15 @@ type Client struct { // If nil, portmap discovery is not done. PortMapper *portmapper.Client // lazily initialized on first use + // UseDNSCache controls whether this client should use a + // *dnscache.Resolver to resolve DERP hostnames, when no IP address is + // provided in the DERP map. Note that Tailscale-provided DERP servers + // all specify explicit IPv4 and IPv6 addresses, so this is mostly + // helpful for users with custom DERP servers. + // + // If false, the default net.Resolver will be used, with no caching. + UseDNSCache bool + // For tests testEnoughRegions int testCaptivePortalDelay time.Duration @@ -191,6 +201,7 @@ type Client struct { last *Report // most recent report lastFull time.Time // time of last full (non-incremental) report curState *reportState // non-nil if we're in a call to GetReportn + resolver *dnscache.Resolver // only set if UseDNSCache is true } // STUNConn is the interface required by the netcheck Client when @@ -1514,6 +1525,7 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe addr := c.nodeAddr(ctx, node, probe.proto) if !addr.IsValid() { + c.logf("netcheck.runProbe: named node %q has no address", probe.node) return } @@ -1597,12 +1609,46 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP return } - // TODO(bradfitz): add singleflight+dnscache here. - addrs, _ := net.DefaultResolver.LookupIPAddr(ctx, n.HostName) + // The default lookup function if we don't set UseDNSCache is to use net.DefaultResolver. + lookupIPAddr := func(ctx context.Context, host string) ([]netip.Addr, error) { + addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + + var naddrs []netip.Addr + for _, addr := range addrs { + na, ok := netip.AddrFromSlice(addr.IP) + if !ok { + continue + } + naddrs = append(naddrs, na.Unmap()) + } + return naddrs, nil + } + + c.mu.Lock() + if c.UseDNSCache { + if c.resolver == nil { + c.resolver = &dnscache.Resolver{ + Forward: net.DefaultResolver, + UseLastGood: true, + Logf: c.logf, + } + } + resolver := c.resolver + lookupIPAddr = func(ctx context.Context, host string) ([]netip.Addr, error) { + _, _, allIPs, err := resolver.LookupIP(ctx, host) + return allIPs, err + } + } + c.mu.Unlock() + + probeIsV4 := proto == probeIPv4 + addrs, _ := lookupIPAddr(ctx, n.HostName) for _, a := range addrs { - if (a.IP.To4() != nil) == (proto == probeIPv4) { - na, _ := netip.AddrFromSlice(a.IP.To4()) - return netip.AddrPortFrom(na.Unmap(), uint16(port)) + if (a.Is4() && probeIsV4) || (a.Is6() && !probeIsV4) { + return netip.AddrPortFrom(a, uint16(port)) } } return diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 4d4bc4a2f..ee9ef308d 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -824,3 +824,60 @@ type RoundTripFunc func(req *http.Request) *http.Response func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil } + +func TestNodeAddrResolve(t *testing.T) { + c := &Client{ + Logf: t.Logf, + UDPBindAddr: "127.0.0.1:0", + UseDNSCache: true, + } + + dn := &tailcfg.DERPNode{ + Name: "derptest1a", + RegionID: 901, + HostName: "tailscale.com", + // No IPv4 or IPv6 addrs + } + dnV4Only := &tailcfg.DERPNode{ + Name: "derptest1b", + RegionID: 901, + HostName: "ipv4.google.com", + // No IPv4 or IPv6 addrs + } + + ctx := context.Background() + for _, tt := range []bool{true, false} { + t.Run(fmt.Sprintf("UseDNSCache=%v", tt), func(t *testing.T) { + c.resolver = nil + c.UseDNSCache = tt + + t.Run("IPv4", func(t *testing.T) { + ap := c.nodeAddr(ctx, dn, probeIPv4) + if !ap.IsValid() { + t.Fatal("expected valid AddrPort") + } + if !ap.Addr().Is4() { + t.Fatalf("expected IPv4 addr, got: %v", ap.Addr()) + } + t.Logf("got IPv4 addr: %v", ap) + }) + t.Run("IPv6", func(t *testing.T) { + ap := c.nodeAddr(ctx, dn, probeIPv6) + if !ap.IsValid() { + t.Fatal("expected valid AddrPort") + } + if !ap.Addr().Is6() { + t.Fatalf("expected IPv6 addr, got: %v", ap.Addr()) + } + t.Logf("got IPv6 addr: %v", ap) + }) + t.Run("IPv6 Failure", func(t *testing.T) { + ap := c.nodeAddr(ctx, dnV4Only, probeIPv6) + if ap.IsValid() { + t.Fatalf("expected no addr but got: %v", ap) + } + t.Logf("correctly got invalid addr") + }) + }) + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index c4511b33f..6fd7f4b1b 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -660,6 +660,7 @@ func NewConn(opts Options) (*Conn, error) { GetSTUNConn6: func() netcheck.STUNConn { return &c.pconn6 }, SkipExternalNetwork: inTest(), PortMapper: c.portMapper, + UseDNSCache: true, } c.ignoreSTUNPackets()