|
|
|
@ -22,11 +22,6 @@ import (
|
|
|
|
|
// headerBytes is the number of bytes in a DNS message header.
|
|
|
|
|
const headerBytes = 12
|
|
|
|
|
|
|
|
|
|
// forwardQueueSize is the maximal number of requests that can be pending delegation.
|
|
|
|
|
// Note that this is distinct from the number of requests that are pending a response,
|
|
|
|
|
// which is not limited (except by txid collisions).
|
|
|
|
|
const forwardQueueSize = 64
|
|
|
|
|
|
|
|
|
|
// connCount is the number of UDP connections to use for forwarding.
|
|
|
|
|
const connCount = 32
|
|
|
|
|
|
|
|
|
@ -138,7 +133,6 @@ func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
|
|
|
|
|
return &forwarder{
|
|
|
|
|
logf: logger.WithPrefix(logf, "forward: "),
|
|
|
|
|
responses: responses,
|
|
|
|
|
queue: make(chan forwardedPacket, forwardQueueSize),
|
|
|
|
|
closed: make(chan struct{}),
|
|
|
|
|
conns: make([]*net.UDPConn, connCount),
|
|
|
|
|
txMap: make(map[txid]forwardingRecord),
|
|
|
|
@ -155,11 +149,10 @@ func (f *forwarder) Start() error {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.wg.Add(connCount + 2)
|
|
|
|
|
f.wg.Add(connCount + 1)
|
|
|
|
|
for idx, conn := range f.conns {
|
|
|
|
|
go f.recv(uint16(idx), conn)
|
|
|
|
|
}
|
|
|
|
|
go f.send()
|
|
|
|
|
go f.cleanMap()
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
@ -191,28 +184,13 @@ func (f *forwarder) setUpstreams(upstreams []net.Addr) {
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *forwarder) send() {
|
|
|
|
|
defer f.wg.Done()
|
|
|
|
|
|
|
|
|
|
var packet forwardedPacket
|
|
|
|
|
for {
|
|
|
|
|
select {
|
|
|
|
|
case <-f.closed:
|
|
|
|
|
return
|
|
|
|
|
case packet = <-f.queue:
|
|
|
|
|
// continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
connIdx := rand.Intn(connCount)
|
|
|
|
|
conn := f.conns[connIdx]
|
|
|
|
|
_, err := conn.WriteTo(packet.payload, packet.dst)
|
|
|
|
|
if err != nil {
|
|
|
|
|
// Do not log errors due to expired deadline.
|
|
|
|
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
|
|
|
|
f.logf("send: %v", err)
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
func (f *forwarder) send(packet []byte, dst net.Addr) {
|
|
|
|
|
connIdx := rand.Intn(connCount)
|
|
|
|
|
conn := f.conns[connIdx]
|
|
|
|
|
_, err := conn.WriteTo(packet, dst)
|
|
|
|
|
// Do not log errors due to expired deadline.
|
|
|
|
|
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
|
|
|
|
|
f.logf("send: %v", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -308,17 +286,8 @@ func (f *forwarder) forward(query Packet) error {
|
|
|
|
|
|
|
|
|
|
f.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
packet := forwardedPacket{
|
|
|
|
|
payload: query.Payload,
|
|
|
|
|
}
|
|
|
|
|
for _, upstream := range upstreams {
|
|
|
|
|
packet.dst = upstream
|
|
|
|
|
select {
|
|
|
|
|
case <-f.closed:
|
|
|
|
|
return ErrClosed
|
|
|
|
|
case f.queue <- packet:
|
|
|
|
|
// continue
|
|
|
|
|
}
|
|
|
|
|
f.send(query.Payload, upstream)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|