diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 7fb814151..b79a12695 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -15,7 +15,9 @@ import ( "io" "log" "net" + "strconv" "strings" + "sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -34,6 +36,7 @@ import ( "tailscale.com/net/socks5" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/util/dnsname" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" @@ -50,6 +53,9 @@ type Impl struct { e wgengine.Engine mc *magicsock.Conn logf logger.Logf + + mu sync.Mutex + dns map[string]netaddr.IP // Magic DNS names (both base + FQDN) => first IP } const nicID = 1 @@ -120,7 +126,33 @@ func (ns *Impl) Start() error { return nil } +func (ns *Impl) updateDNS(nm *netmap.NetworkMap) { + ns.mu.Lock() + defer ns.mu.Unlock() + ns.dns = make(map[string]netaddr.IP) + suffix := nm.MagicDNSSuffix() + + if nm.Name != "" && len(nm.Addresses) > 0 { + ip := nm.Addresses[0].IP + ns.dns[strings.TrimRight(nm.Name, ".")] = ip + if dnsname.HasSuffix(nm.Name, suffix) { + ns.dns[dnsname.TrimSuffix(nm.Name, suffix)] = ip + } + } + for _, p := range nm.Peers { + if p.Name != "" && len(p.Addresses) > 0 { + ip := p.Addresses[0].IP + ns.dns[strings.TrimRight(p.Name, ".")] = ip + if dnsname.HasSuffix(p.Name, suffix) { + ns.dns[dnsname.TrimSuffix(p.Name, suffix)] = ip + } + } + } +} + func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { + ns.updateDNS(nm) + oldIPs := make(map[tcpip.Address]bool) for _, ip := range ns.ipstack.AllAddresses()[nicID] { oldIPs[ip.AddressWithPrefix.Address] = true @@ -166,10 +198,54 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { } } -func (ns *Impl) dialContextTCP(ctx context.Context, address string) (*gonet.TCPConn, error) { - remoteIPPort, err := netaddr.ParseIPPort(address) +// resolve resolves addr into an IP:port. +func (ns *Impl) resolve(ctx context.Context, addr string) (netaddr.IPPort, error) { + ipp, pippErr := netaddr.ParseIPPort(addr) + if pippErr == nil { + return ipp, nil + } + host, port, err := net.SplitHostPort(addr) + if err != nil { + // addr is malformed. + return netaddr.IPPort{}, err + } + if net.ParseIP(host) != nil { + // The host part of addr was an IP, so the netaddr.ParseIPPort above should've + // passed. Must've been a bad port number. Return the original error. + return netaddr.IPPort{}, pippErr + } + port16, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return netaddr.IPPort{}, fmt.Errorf("invalid port in address %q", addr) + } + + // Host is not an IP, so assume it's a DNS name. + + // Try MagicDNS first, else otherwise a real DNS lookup. + ns.mu.Lock() + ip := ns.dns[host] + ns.mu.Unlock() + if !ip.IsZero() { + return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil + } + + // No Magic DNS name so try real DNS. + var r net.Resolver + ips, err := r.LookupIP(ctx, "ip", host) + if err != nil { + return netaddr.IPPort{}, err + } + if len(ips) == 0 { + return netaddr.IPPort{}, fmt.Errorf("DNS lookup returned no results for %q", host) + } + ip, _ = netaddr.FromStdIP(ips[0]) + return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil +} + +func (ns *Impl) dialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) { + remoteIPPort, err := ns.resolve(ctx, addr) if err != nil { - return nil, fmt.Errorf("could not parse IP:port: %w", err) + return nil, err } remoteAddress := tcpip.FullAddress{ NIC: nicID,