net/dnscache: add overly simplistic DNS cache package for selective use

I started to write a full DNS caching resolver and I realized it was
overkill and wouldn't work on Windows even in Go 1.14 yet, so I'm
doing this tiny one instead for now, just for all our netcheck STUN
derp lookups, and connections to DERP servers. (This will be caching a
exactly 8 DNS entries, all ours.)

Fixes #145 (can be better later, of course)
pull/148/head
Brad Fitzpatrick 4 years ago
parent a36ccb8525
commit 2cff9016e4

@ -26,6 +26,7 @@ import (
"time" "time"
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/net/dnscache"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -37,7 +38,8 @@ import (
// Send/Recv will completely re-establish the connection (unless Close // Send/Recv will completely re-establish the connection (unless Close
// has been called). // has been called).
type Client struct { type Client struct {
TLSConfig *tls.Config // for sever connection, optional, nil means default TLSConfig *tls.Config // for sever connection, optional, nil means default
DNSCache *dnscache.Resolver // optional; if nil, no caching
privateKey key.Private privateKey key.Private
logf logger.Logf logf logger.Logf
@ -137,11 +139,23 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
} }
}() }()
host := c.url.Hostname()
hostOrIP := host
var d net.Dialer var d net.Dialer
log.Printf("Dialing: %q", net.JoinHostPort(c.url.Hostname(), urlPort(c.url))) log.Printf("Dialing: %q", net.JoinHostPort(host, urlPort(c.url)))
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
if c.DNSCache != nil {
ip, err := c.DNSCache.LookupIP(ctx, host)
if err != nil {
return nil, err
}
hostOrIP = ip.String()
}
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url)))
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("Dial of %q: %v", host, err)
} }
// Now that we have a TCP connection, force close it. // Now that we have a TCP connection, force close it.

@ -0,0 +1,151 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package dnscache contains a minimal DNS cache that makes a bunch of
// assumptions that are only valid for us. Not recommended for general use.
package dnscache
import (
"context"
"fmt"
"net"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
var single = &Resolver{
Forward: &net.Resolver{PreferGo: true},
}
// Get returns a caching Resolver singleton.
func Get() *Resolver { return single }
const fixedTTL = 10 * time.Minute
// Resolver is a minimal DNS caching resolver.
//
// The TTL is always fixed for now. It's not intended for general use.
// Cache entries are never cleaned up so it's intended that this is
// only used with a fixed set of hostnames.
type Resolver struct {
// Forward is the resolver to use to populate the cache.
// If nil, net.DefaultResolver is used.
Forward *net.Resolver
sf singleflight.Group
mu sync.Mutex
ipCache map[string]ipCacheEntry
}
type ipCacheEntry struct {
ip net.IP
expires time.Time
}
func (r *Resolver) fwd() *net.Resolver {
if r.Forward != nil {
return r.Forward
}
return net.DefaultResolver
}
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
return ip4, nil
}
return ip, nil
}
if ip, ok := r.lookupIPCache(host); ok {
return ip, nil
}
ch := r.sf.DoChan(host, func() (interface{}, error) {
ip, err := r.lookupIP(host)
if err != nil {
return nil, err
}
return ip, nil
})
select {
case res := <-ch:
if res.Err != nil {
return nil, res.Err
}
return res.Val.(net.IP), nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) {
r.mu.Lock()
defer r.mu.Unlock()
if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) {
return ent.ip, true
}
return nil, false
}
func (r *Resolver) lookupIP(host string) (net.IP, error) {
if ip, ok := r.lookupIPCache(host); ok {
return ip, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ips, err := r.fwd().LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IPs for %q found", host)
}
for _, ipa := range ips {
if ip4 := ipa.IP.To4(); ip4 != nil {
return r.addIPCache(host, ip4, fixedTTL), nil
}
}
return r.addIPCache(host, ips[0].IP, fixedTTL), nil
}
func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP {
if isPrivateIP(ip) {
// Don't cache obviously wrong entries from captive portals.
// TODO: use DoH or DoT for the forwarding resolver?
return ip
}
r.mu.Lock()
defer r.mu.Unlock()
if r.ipCache == nil {
r.ipCache = make(map[string]ipCacheEntry)
}
r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)}
return ip
}
func mustCIDR(s string) *net.IPNet {
_, ipNet, err := net.ParseCIDR("100.64.0.0/10")
if err != nil {
panic(err)
}
return ipNet
}
func isPrivateIP(ip net.IP) bool {
return private1.Contains(ip) || private2.Contains(ip) || private3.Contains(ip)
}
var (
private1 = mustCIDR("10.0.0.0/8")
private2 = mustCIDR("172.16.0.0/12")
private3 = mustCIDR("192.168.0.0/16")
)

@ -17,6 +17,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/interfaces" "tailscale.com/interfaces"
"tailscale.com/net/dnscache"
"tailscale.com/stun" "tailscale.com/stun"
"tailscale.com/stunner" "tailscale.com/stunner"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -181,6 +182,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
Endpoint: add, Endpoint: add,
Servers: stunServers, Servers: stunServers,
Logf: logf, Logf: logf,
DNSCache: dnscache.Get(),
} }
grp.Go(func() error { return s4.Run(ctx) }) grp.Go(func() error { return s4.Run(ctx) })
go reader(s4, pc4, unlimited) go reader(s4, pc4, unlimited)
@ -190,6 +192,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
Endpoint: addHair, Endpoint: addHair,
Servers: stunServers, Servers: stunServers,
Logf: logf, Logf: logf,
DNSCache: dnscache.Get(),
} }
grp.Go(func() error { return s4Hair.Run(ctx) }) grp.Go(func() error { return s4Hair.Run(ctx) })
go reader(s4Hair, pc4Hair, 2) go reader(s4Hair, pc4Hair, 2)
@ -201,6 +204,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
Servers: stunServers6, Servers: stunServers6,
Logf: logf, Logf: logf,
OnlyIPv6: true, OnlyIPv6: true,
DNSCache: dnscache.Get(),
} }
grp.Go(func() error { return s6.Run(ctx) }) grp.Go(func() error { return s6.Run(ctx) })
go reader(s6, pc6, unlimited) go reader(s6, pc6, unlimited)

@ -12,6 +12,7 @@ import (
"sync" "sync"
"time" "time"
"tailscale.com/net/dnscache"
"tailscale.com/stun" "tailscale.com/stun"
) )
@ -38,9 +39,9 @@ type Stunner struct {
Servers []string // STUN servers to contact Servers []string // STUN servers to contact
// Resolver optionally specifies a resolver to use for DNS lookups. // DNSCache optionally specifies a DNSCache to use.
// If nil, net.DefaultResolver is used. // If nil, a DNS cache is not used.
Resolver *net.Resolver DNSCache *dnscache.Resolver
// Logf optionally specifies a log function. If nil, logging is disabled. // Logf optionally specifies a log function. If nil, logging is disabled.
Logf func(format string, args ...interface{}) Logf func(format string, args ...interface{})
@ -118,9 +119,6 @@ func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
} }
func (s *Stunner) resolver() *net.Resolver { func (s *Stunner) resolver() *net.Resolver {
if s.Resolver != nil {
return s.Resolver
}
return net.DefaultResolver return net.DefaultResolver
} }
@ -192,9 +190,18 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
} }
addr := &net.UDPAddr{Port: addrPort} addr := &net.UDPAddr{Port: addrPort}
ipAddrs, err := s.resolver().LookupIPAddr(ctx, host) var ipAddrs []net.IPAddr
if err != nil { if s.DNSCache != nil {
return fmt.Errorf("lookup ip addr: %v", err) ip, err := s.DNSCache.LookupIP(ctx, host)
if err != nil {
return fmt.Errorf("lookup ip addr: %v", err)
}
ipAddrs = []net.IPAddr{{IP: ip}}
} else {
ipAddrs, err = s.resolver().LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("lookup ip addr: %v", err)
}
} }
for _, ipAddr := range ipAddrs { for _, ipAddr := range ipAddrs {
ip4 := ipAddr.IP.To4() ip4 := ipAddr.IP.To4()

@ -31,6 +31,7 @@ import (
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/derp/derphttp" "tailscale.com/derp/derphttp"
"tailscale.com/interfaces" "tailscale.com/interfaces"
"tailscale.com/net/dnscache"
"tailscale.com/netcheck" "tailscale.com/netcheck"
"tailscale.com/stun" "tailscale.com/stun"
"tailscale.com/stunner" "tailscale.com/stunner"
@ -638,6 +639,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest {
c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err) c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err)
return nil return nil
} }
dc.DNSCache = dnscache.Get()
dc.TLSConfig = c.derpTLSConfig dc.TLSConfig = c.derpTLSConfig
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())

Loading…
Cancel
Save