|
|
|
@ -9,41 +9,30 @@ import (
|
|
|
|
|
"context"
|
|
|
|
|
"encoding/binary"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"hash/crc32"
|
|
|
|
|
"io"
|
|
|
|
|
"math/rand"
|
|
|
|
|
"net"
|
|
|
|
|
"sync"
|
|
|
|
|
"syscall"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
dns "golang.org/x/net/dns/dnsmessage"
|
|
|
|
|
"inet.af/netaddr"
|
|
|
|
|
"tailscale.com/logtail/backoff"
|
|
|
|
|
"tailscale.com/types/logger"
|
|
|
|
|
"tailscale.com/util/dnsname"
|
|
|
|
|
"tailscale.com/wgengine/monitor"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// headerBytes is the number of bytes in a DNS message header.
|
|
|
|
|
const headerBytes = 12
|
|
|
|
|
|
|
|
|
|
// connCount is the number of UDP connections to use for forwarding.
|
|
|
|
|
const connCount = 32
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
// cleanupInterval is the interval between purged of timed-out entries from txMap.
|
|
|
|
|
cleanupInterval = 30 * time.Second
|
|
|
|
|
// responseTimeout is the maximal amount of time to wait for a DNS response.
|
|
|
|
|
responseTimeout = 5 * time.Second
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var errNoUpstreams = errors.New("upstream nameservers not set")
|
|
|
|
|
|
|
|
|
|
type forwardingRecord struct {
|
|
|
|
|
src netaddr.IPPort
|
|
|
|
|
createdAt time.Time
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// txid identifies a DNS transaction.
|
|
|
|
|
//
|
|
|
|
|
// As the standard DNS Request ID is only 16 bits, we extend it:
|
|
|
|
@ -100,178 +89,164 @@ func getTxID(packet []byte) txid {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type route struct {
|
|
|
|
|
suffix dnsname.FQDN
|
|
|
|
|
resolvers []netaddr.IPPort
|
|
|
|
|
Suffix dnsname.FQDN
|
|
|
|
|
Resolvers []netaddr.IPPort
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// forwarder forwards DNS packets to a number of upstream nameservers.
|
|
|
|
|
type forwarder struct {
|
|
|
|
|
logf logger.Logf
|
|
|
|
|
logf logger.Logf
|
|
|
|
|
linkMon *monitor.Mon
|
|
|
|
|
linkSel ForwardLinkSelector
|
|
|
|
|
|
|
|
|
|
ctx context.Context // good until Close
|
|
|
|
|
ctxCancel context.CancelFunc // closes ctx
|
|
|
|
|
|
|
|
|
|
// responses is a channel by which responses are returned.
|
|
|
|
|
responses chan packet
|
|
|
|
|
// closed signals all goroutines to stop.
|
|
|
|
|
closed chan struct{}
|
|
|
|
|
// wg signals when all goroutines have stopped.
|
|
|
|
|
wg sync.WaitGroup
|
|
|
|
|
|
|
|
|
|
// conns are the UDP connections used for forwarding.
|
|
|
|
|
// A random one is selected for each request, regardless of the target upstream.
|
|
|
|
|
conns []*fwdConn
|
|
|
|
|
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
// routes are per-suffix resolvers to use.
|
|
|
|
|
routes []route // most specific routes first
|
|
|
|
|
txMap map[txid]forwardingRecord // txids to in-flight requests
|
|
|
|
|
|
|
|
|
|
mu sync.Mutex // guards following
|
|
|
|
|
|
|
|
|
|
// routes are per-suffix resolvers to use, with
|
|
|
|
|
// the most specific routes first.
|
|
|
|
|
routes []route
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
|
|
|
|
|
ret := &forwarder{
|
|
|
|
|
func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *forwarder {
|
|
|
|
|
f := &forwarder{
|
|
|
|
|
logf: logger.WithPrefix(logf, "forward: "),
|
|
|
|
|
linkMon: linkMon,
|
|
|
|
|
linkSel: linkSel,
|
|
|
|
|
responses: responses,
|
|
|
|
|
closed: make(chan struct{}),
|
|
|
|
|
conns: make([]*fwdConn, connCount),
|
|
|
|
|
txMap: make(map[txid]forwardingRecord),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ret.wg.Add(connCount + 1)
|
|
|
|
|
for idx := range ret.conns {
|
|
|
|
|
ret.conns[idx] = newFwdConn(ret.logf, idx)
|
|
|
|
|
go ret.recv(ret.conns[idx])
|
|
|
|
|
}
|
|
|
|
|
go ret.cleanMap()
|
|
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *forwarder) Close() {
|
|
|
|
|
select {
|
|
|
|
|
case <-f.closed:
|
|
|
|
|
return
|
|
|
|
|
default:
|
|
|
|
|
// continue
|
|
|
|
|
}
|
|
|
|
|
close(f.closed)
|
|
|
|
|
|
|
|
|
|
for _, conn := range f.conns {
|
|
|
|
|
conn.close()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.wg.Wait()
|
|
|
|
|
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
|
|
|
|
|
return f
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *forwarder) rebindFromNetworkChange() {
|
|
|
|
|
for _, c := range f.conns {
|
|
|
|
|
c.mu.Lock()
|
|
|
|
|
c.reconnectLocked()
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
func (f *forwarder) Close() error {
|
|
|
|
|
f.ctxCancel()
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *forwarder) setRoutes(routes []route) {
|
|
|
|
|
f.mu.Lock()
|
|
|
|
|
defer f.mu.Unlock()
|
|
|
|
|
f.routes = routes
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// send sends packet to dst. It is best effort.
|
|
|
|
|
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
|
|
|
|
|
connIdx := rand.Intn(connCount)
|
|
|
|
|
conn := f.conns[connIdx]
|
|
|
|
|
conn.send(packet, dst)
|
|
|
|
|
}
|
|
|
|
|
var stdNetPacketListener packetListener = new(net.ListenConfig)
|
|
|
|
|
|
|
|
|
|
func (f *forwarder) recv(conn *fwdConn) {
|
|
|
|
|
defer f.wg.Done()
|
|
|
|
|
type packetListener interface {
|
|
|
|
|
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
select {
|
|
|
|
|
case <-f.closed:
|
|
|
|
|
return
|
|
|
|
|
default:
|
|
|
|
|
}
|
|
|
|
|
// The 1 extra byte is to detect packet truncation.
|
|
|
|
|
out := make([]byte, maxResponseBytes+1)
|
|
|
|
|
n := conn.read(out)
|
|
|
|
|
var truncated bool
|
|
|
|
|
if n > maxResponseBytes {
|
|
|
|
|
n = maxResponseBytes
|
|
|
|
|
truncated = true
|
|
|
|
|
}
|
|
|
|
|
if n == 0 {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if n < headerBytes {
|
|
|
|
|
f.logf("recv: packet too small (%d bytes)", n)
|
|
|
|
|
}
|
|
|
|
|
func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) {
|
|
|
|
|
if f.linkSel == nil || initListenConfig == nil {
|
|
|
|
|
return stdNetPacketListener, nil
|
|
|
|
|
}
|
|
|
|
|
linkName := f.linkSel.PickLink(ip)
|
|
|
|
|
if linkName == "" {
|
|
|
|
|
return stdNetPacketListener, nil
|
|
|
|
|
}
|
|
|
|
|
lc := new(net.ListenConfig)
|
|
|
|
|
if err := initListenConfig(lc, f.linkMon, linkName); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return lc, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out = out[:n]
|
|
|
|
|
txid := getTxID(out)
|
|
|
|
|
// send sends packet to dst. It is best effort.
|
|
|
|
|
//
|
|
|
|
|
// send expects the reply to have the same txid as txidOut.
|
|
|
|
|
//
|
|
|
|
|
// The provided closeOnCtxDone lets send register values to Close if
|
|
|
|
|
// the caller's ctx expires. This avoids send from allocating its own
|
|
|
|
|
// waiting goroutine to interrupt the ReadFrom, as memory is tight on
|
|
|
|
|
// iOS and we want the number of pending DNS lookups to be bursty
|
|
|
|
|
// without too much associated goroutine/memory cost.
|
|
|
|
|
func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) {
|
|
|
|
|
// TODO(bradfitz): if dst.IP is 8.8.8.8 or 8.8.4.4 or 1.1.1.1, etc, or
|
|
|
|
|
// something dynamically probed earlier to support DoH or DoT,
|
|
|
|
|
// do that here instead.
|
|
|
|
|
|
|
|
|
|
ln, err := f.packetListener(dst.IP())
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
conn, err := ln.ListenPacket(ctx, "udp", ":0")
|
|
|
|
|
if err != nil {
|
|
|
|
|
f.logf("ListenPacket failed: %v", err)
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
defer conn.Close()
|
|
|
|
|
|
|
|
|
|
if truncated {
|
|
|
|
|
const dnsFlagTruncated = 0x200
|
|
|
|
|
flags := binary.BigEndian.Uint16(out[2:4])
|
|
|
|
|
flags |= dnsFlagTruncated
|
|
|
|
|
binary.BigEndian.PutUint16(out[2:4], flags)
|
|
|
|
|
closeOnCtxDone.Add(conn)
|
|
|
|
|
defer closeOnCtxDone.Remove(conn)
|
|
|
|
|
|
|
|
|
|
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
|
|
|
|
|
// states that truncation should head drop so that the authority
|
|
|
|
|
// section can be preserved if possible. However, the UDP read with
|
|
|
|
|
// a too-small buffer has already dropped the end, so that's the
|
|
|
|
|
// best we can do.
|
|
|
|
|
if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil {
|
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.mu.Lock()
|
|
|
|
|
|
|
|
|
|
record, found := f.txMap[txid]
|
|
|
|
|
// At most one nameserver will return a response:
|
|
|
|
|
// the first one to do so will delete txid from the map.
|
|
|
|
|
if !found {
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
continue
|
|
|
|
|
// The 1 extra byte is to detect packet truncation.
|
|
|
|
|
out := make([]byte, maxResponseBytes+1)
|
|
|
|
|
n, _, err := conn.ReadFrom(out)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
delete(f.txMap, txid)
|
|
|
|
|
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
pkt := packet{out, record.src}
|
|
|
|
|
select {
|
|
|
|
|
case <-f.closed:
|
|
|
|
|
return
|
|
|
|
|
case f.responses <- pkt:
|
|
|
|
|
// continue
|
|
|
|
|
if packetWasTruncated(err) {
|
|
|
|
|
err = nil
|
|
|
|
|
} else {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
|
|
|
|
|
func (f *forwarder) cleanMap() {
|
|
|
|
|
defer f.wg.Done()
|
|
|
|
|
truncated := n > maxResponseBytes
|
|
|
|
|
if truncated {
|
|
|
|
|
n = maxResponseBytes
|
|
|
|
|
}
|
|
|
|
|
if n < headerBytes {
|
|
|
|
|
f.logf("recv: packet too small (%d bytes)", n)
|
|
|
|
|
}
|
|
|
|
|
out = out[:n]
|
|
|
|
|
txid := getTxID(out)
|
|
|
|
|
if txid != txidOut {
|
|
|
|
|
return nil, errors.New("txid doesn't match")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
t := time.NewTicker(cleanupInterval)
|
|
|
|
|
defer t.Stop()
|
|
|
|
|
if truncated {
|
|
|
|
|
const dnsFlagTruncated = 0x200
|
|
|
|
|
flags := binary.BigEndian.Uint16(out[2:4])
|
|
|
|
|
flags |= dnsFlagTruncated
|
|
|
|
|
binary.BigEndian.PutUint16(out[2:4], flags)
|
|
|
|
|
|
|
|
|
|
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
|
|
|
|
|
// states that truncation should head drop so that the authority
|
|
|
|
|
// section can be preserved if possible. However, the UDP read with
|
|
|
|
|
// a too-small buffer has already dropped the end, so that's the
|
|
|
|
|
// best we can do.
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var now time.Time
|
|
|
|
|
for {
|
|
|
|
|
select {
|
|
|
|
|
case <-f.closed:
|
|
|
|
|
return
|
|
|
|
|
case now = <-t.C:
|
|
|
|
|
// continue
|
|
|
|
|
}
|
|
|
|
|
return out, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.mu.Lock()
|
|
|
|
|
for k, v := range f.txMap {
|
|
|
|
|
if now.Sub(v.createdAt) > responseTimeout {
|
|
|
|
|
delete(f.txMap, k)
|
|
|
|
|
}
|
|
|
|
|
// resolvers returns the resolvers to use for domain.
|
|
|
|
|
func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort {
|
|
|
|
|
f.mu.Lock()
|
|
|
|
|
routes := f.routes
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
for _, route := range routes {
|
|
|
|
|
if route.Suffix == "." || route.Suffix.Contains(domain) {
|
|
|
|
|
return route.Resolvers
|
|
|
|
|
}
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// forward forwards the query to all upstream nameservers and returns the first response.
|
|
|
|
@ -283,225 +258,60 @@ func (f *forwarder) forward(query packet) error {
|
|
|
|
|
|
|
|
|
|
txid := getTxID(query.bs)
|
|
|
|
|
|
|
|
|
|
f.mu.Lock()
|
|
|
|
|
routes := f.routes
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
var resolvers []netaddr.IPPort
|
|
|
|
|
for _, route := range routes {
|
|
|
|
|
if route.suffix != "." && !route.suffix.Contains(domain) {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
resolvers = route.resolvers
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
resolvers := f.resolvers(domain)
|
|
|
|
|
if len(resolvers) == 0 {
|
|
|
|
|
return errNoUpstreams
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.mu.Lock()
|
|
|
|
|
f.txMap[txid] = forwardingRecord{
|
|
|
|
|
src: query.addr,
|
|
|
|
|
createdAt: time.Now(),
|
|
|
|
|
}
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
// TODO(#2066): EDNS size clamping
|
|
|
|
|
|
|
|
|
|
for _, resolver := range resolvers {
|
|
|
|
|
f.send(query.bs, resolver)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// A fwdConn manages a single connection used to forward DNS requests.
|
|
|
|
|
// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
|
|
|
|
|
// fwdConn detects such situations and transparently creates new connections.
|
|
|
|
|
type fwdConn struct {
|
|
|
|
|
// logf allows a fwdConn to log.
|
|
|
|
|
logf logger.Logf
|
|
|
|
|
|
|
|
|
|
// change allows calls to read to block until a the network connection has been replaced.
|
|
|
|
|
change *sync.Cond
|
|
|
|
|
|
|
|
|
|
// mu protects fields that follow it; it is also change's Locker.
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
// closed tracks whether fwdConn has been permanently closed.
|
|
|
|
|
closed bool
|
|
|
|
|
// conn is the current active connection.
|
|
|
|
|
conn net.PacketConn
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newFwdConn(logf logger.Logf, idx int) *fwdConn {
|
|
|
|
|
c := new(fwdConn)
|
|
|
|
|
c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx))
|
|
|
|
|
c.change = sync.NewCond(&c.mu)
|
|
|
|
|
// c.conn is created lazily in send
|
|
|
|
|
return c
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// send sends packet to dst using c's connection.
|
|
|
|
|
// It is best effort. It is UDP, after all. Failures are logged.
|
|
|
|
|
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
|
|
|
|
|
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
|
|
|
|
|
backOff := func(err error) {
|
|
|
|
|
if b == nil {
|
|
|
|
|
b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second)
|
|
|
|
|
}
|
|
|
|
|
b.BackOff(context.Background(), err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
// Gather the current connection.
|
|
|
|
|
// We can't hold the lock while we call WriteTo.
|
|
|
|
|
c.mu.Lock()
|
|
|
|
|
conn := c.conn
|
|
|
|
|
closed := c.closed
|
|
|
|
|
if closed {
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if conn == nil {
|
|
|
|
|
c.reconnectLocked()
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
_, err := conn.WriteTo(packet, dst.UDPAddr())
|
|
|
|
|
if err == nil {
|
|
|
|
|
// Success
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if errors.Is(err, net.ErrClosed) {
|
|
|
|
|
// We intentionally closed this connection.
|
|
|
|
|
// It has been replaced by a new connection. Try again.
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
// Something else went wrong.
|
|
|
|
|
// We have three choices here: try again, give up, or create a new connection.
|
|
|
|
|
var opErr *net.OpError
|
|
|
|
|
if !errors.As(err, &opErr) {
|
|
|
|
|
// Weird. All errors from the net package should be *net.OpError. Bail.
|
|
|
|
|
c.logf("send: non-*net.OpErr %v (%T)", err, err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if opErr.Temporary() || opErr.Timeout() {
|
|
|
|
|
// I doubt that either of these can happen (this is UDP),
|
|
|
|
|
// but go ahead and try again.
|
|
|
|
|
backOff(err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if errors.Is(err, syscall.EHOSTUNREACH) {
|
|
|
|
|
// "No route to host." The network stack is fine, but
|
|
|
|
|
// can't talk to this destination. Not much we can do
|
|
|
|
|
// about that, don't spam logs.
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if networkIsDown(err) {
|
|
|
|
|
// Fail.
|
|
|
|
|
c.logf("send: network is down")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if networkIsUnreachable(err) {
|
|
|
|
|
// This can be caused by a link change.
|
|
|
|
|
// Replace the existing connection with a new one.
|
|
|
|
|
c.mu.Lock()
|
|
|
|
|
// It's possible that multiple senders discovered simultaneously
|
|
|
|
|
// that the network is unreachable. Avoid reconnecting multiple times:
|
|
|
|
|
// Only reconnect if the current connection is the one that we
|
|
|
|
|
// discovered to be problematic.
|
|
|
|
|
if c.conn == conn {
|
|
|
|
|
backOff(err)
|
|
|
|
|
c.reconnectLocked()
|
|
|
|
|
closeOnCtxDone := new(closePool)
|
|
|
|
|
defer closeOnCtxDone.Close()
|
|
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
resc := make(chan []byte, 1)
|
|
|
|
|
var (
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
firstErr error
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
for _, ipp := range resolvers {
|
|
|
|
|
go func(ipp netaddr.IPPort) {
|
|
|
|
|
resb, err := f.send(ctx, txid, closeOnCtxDone, query.bs, ipp)
|
|
|
|
|
if err != nil {
|
|
|
|
|
mu.Lock()
|
|
|
|
|
defer mu.Unlock()
|
|
|
|
|
if firstErr == nil {
|
|
|
|
|
firstErr = err
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
// Try again with our new network connection.
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
// Unrecognized error. Fail.
|
|
|
|
|
c.logf("send: unrecognized error: %v", err)
|
|
|
|
|
return
|
|
|
|
|
select {
|
|
|
|
|
case resc <- resb:
|
|
|
|
|
default:
|
|
|
|
|
}
|
|
|
|
|
}(ipp)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// read waits for a response from c's connection.
|
|
|
|
|
// It returns the number of bytes read, which may be 0
|
|
|
|
|
// in case of an error or a closed connection.
|
|
|
|
|
func (c *fwdConn) read(out []byte) int {
|
|
|
|
|
for {
|
|
|
|
|
// Gather the current connection.
|
|
|
|
|
// We can't hold the lock while we call ReadFrom.
|
|
|
|
|
c.mu.Lock()
|
|
|
|
|
conn := c.conn
|
|
|
|
|
closed := c.closed
|
|
|
|
|
if closed {
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
if conn == nil {
|
|
|
|
|
// There is no current connection.
|
|
|
|
|
// Wait for the connection to change, then try again.
|
|
|
|
|
c.change.Wait()
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
c.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
n, _, err := conn.ReadFrom(out)
|
|
|
|
|
if err == nil || packetWasTruncated(err) {
|
|
|
|
|
// Success.
|
|
|
|
|
return n
|
|
|
|
|
select {
|
|
|
|
|
case v := <-resc:
|
|
|
|
|
select {
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
return ctx.Err()
|
|
|
|
|
case f.responses <- packet{v, query.addr}:
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
if errors.Is(err, net.ErrClosed) {
|
|
|
|
|
// We intentionally closed this connection.
|
|
|
|
|
// It has been replaced by a new connection. Try again.
|
|
|
|
|
continue
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
mu.Lock()
|
|
|
|
|
defer mu.Unlock()
|
|
|
|
|
if firstErr != nil {
|
|
|
|
|
return firstErr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.logf("read: unrecognized error: %v", err)
|
|
|
|
|
return 0
|
|
|
|
|
return ctx.Err()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// reconnectLocked replaces the current connection with a new one.
|
|
|
|
|
// c.mu must be locked.
|
|
|
|
|
func (c *fwdConn) reconnectLocked() {
|
|
|
|
|
c.closeConnLocked()
|
|
|
|
|
// Make a new connection.
|
|
|
|
|
conn, err := net.ListenPacket("udp", "")
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.logf("ListenPacket failed: %v", err)
|
|
|
|
|
} else {
|
|
|
|
|
c.conn = conn
|
|
|
|
|
}
|
|
|
|
|
// Broadcast that a new connection is available.
|
|
|
|
|
c.change.Broadcast()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// closeCurrentConn closes the current connection.
|
|
|
|
|
// c.mu must be locked.
|
|
|
|
|
func (c *fwdConn) closeConnLocked() {
|
|
|
|
|
if c.conn == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
c.conn.Close() // unblocks all readers/writers, they'll pick up the next connection.
|
|
|
|
|
c.conn = nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// close permanently closes c.
|
|
|
|
|
func (c *fwdConn) close() {
|
|
|
|
|
c.mu.Lock()
|
|
|
|
|
defer c.mu.Unlock()
|
|
|
|
|
if c.closed {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
c.closed = true
|
|
|
|
|
c.closeConnLocked()
|
|
|
|
|
// Unblock any remaining readers.
|
|
|
|
|
c.change.Broadcast()
|
|
|
|
|
}
|
|
|
|
|
var initListenConfig func(_ *net.ListenConfig, _ *monitor.Mon, tunName string) error
|
|
|
|
|
|
|
|
|
|
// nameFromQuery extracts the normalized query name from bs.
|
|
|
|
|
func nameFromQuery(bs []byte) (dnsname.FQDN, error) {
|
|
|
|
@ -523,3 +333,48 @@ func nameFromQuery(bs []byte) (dnsname.FQDN, error) {
|
|
|
|
|
n := q.Name.Data[:q.Name.Length]
|
|
|
|
|
return dnsname.ToFQDN(rawNameToLower(n))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// closePool is a dynamic set of io.Closers to close as a group.
|
|
|
|
|
// It's intended to be Closed at most once.
|
|
|
|
|
//
|
|
|
|
|
// The zero value is ready for use.
|
|
|
|
|
type closePool struct {
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
m map[io.Closer]bool
|
|
|
|
|
closed bool
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (p *closePool) Add(c io.Closer) {
|
|
|
|
|
p.mu.Lock()
|
|
|
|
|
defer p.mu.Unlock()
|
|
|
|
|
if p.closed {
|
|
|
|
|
c.Close()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if p.m == nil {
|
|
|
|
|
p.m = map[io.Closer]bool{}
|
|
|
|
|
}
|
|
|
|
|
p.m[c] = true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (p *closePool) Remove(c io.Closer) {
|
|
|
|
|
p.mu.Lock()
|
|
|
|
|
defer p.mu.Unlock()
|
|
|
|
|
if p.closed {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
delete(p.m, c)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (p *closePool) Close() error {
|
|
|
|
|
p.mu.Lock()
|
|
|
|
|
defer p.mu.Unlock()
|
|
|
|
|
if p.closed {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
p.closed = true
|
|
|
|
|
for c := range p.m {
|
|
|
|
|
c.Close()
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|