mirror of https://github.com/tailscale/tailscale/
net/dnscache, net/tsdial: add DNS caching to tsdial UserDial
This is enough to handle the DNS queries as generated by Go's net package (which our HTTP/SOCKS client uses), and the responses generated by the ExitDNS DoH server. This isn't yet suitable for putting on 100.100.100.100 where a number of different DNS clients would hit it, as this doesn't yet do EDNS0. It might work, but it's untested and likely incomplete. Likewise, this doesn't handle anything about truncation, as the exchanges are entirely in memory between Go or DoH. That would also need to be handled later, if/when it's hooked up to 100.100.100.100. Updates #3507 Change-Id: I1736b0ad31eea85ea853b310c52c5e6bf65c6e2a Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>pull/3546/head
parent
b59e7669c1
commit
39ffa16853
@ -0,0 +1,314 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package dnscache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/groupcache/lru"
|
||||||
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageCache is a cache that works at the DNS message layer,
|
||||||
|
// with its cache keyed on a DNS wire-level question, and capable
|
||||||
|
// of replying to DNS messages.
|
||||||
|
//
|
||||||
|
// Its zero value is ready for use with a default cache size.
|
||||||
|
// Use SetMaxCacheSize to specify the cache size.
|
||||||
|
//
|
||||||
|
// It's safe for concurrent use.
|
||||||
|
type MessageCache struct {
|
||||||
|
// Clock is a clock, for testing.
|
||||||
|
// If nil, time.Now is used.
|
||||||
|
Clock func() time.Time
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
cacheSizeSet int // 0 means default
|
||||||
|
cache lru.Cache // msgQ => *msgCacheValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MessageCache) now() time.Time {
|
||||||
|
if c.Clock != nil {
|
||||||
|
return c.Clock()
|
||||||
|
}
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMaxCacheSize sets the maximum number of DNS cache entries that
|
||||||
|
// can be stored.
|
||||||
|
func (c *MessageCache) SetMaxCacheSize(n int) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.cacheSizeSet = n
|
||||||
|
c.pruneLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush clears the cache.
|
||||||
|
func (c *MessageCache) Flush() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.cache.Clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
// pruneLocked prunes down the cache size to the configured (or
|
||||||
|
// default) max size.
|
||||||
|
func (c *MessageCache) pruneLocked() {
|
||||||
|
max := c.cacheSizeSet
|
||||||
|
if max == 0 {
|
||||||
|
max = 500
|
||||||
|
}
|
||||||
|
for c.cache.Len() > max {
|
||||||
|
c.cache.RemoveOldest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// msgQ is the MessageCache cache key.
|
||||||
|
//
|
||||||
|
// It's basically a golang.org/x/net/dns/dnsmessage#Question but the
|
||||||
|
// Class is omitted (we only cache ClassINET) and we store a Go string
|
||||||
|
// instead of a 256 byte dnsmessage.Name array.
|
||||||
|
type msgQ struct {
|
||||||
|
Name string
|
||||||
|
Type dnsmessage.Type // A, AAAA, MX, etc
|
||||||
|
}
|
||||||
|
|
||||||
|
// A *msgCacheValue is the cached value for a msgQ (question) key.
|
||||||
|
//
|
||||||
|
// Despite using pointers for storage and methods, the value is
|
||||||
|
// immutable once placed in the cache.
|
||||||
|
type msgCacheValue struct {
|
||||||
|
Expires time.Time
|
||||||
|
|
||||||
|
// Answers are the minimum data to reconstruct a DNS response
|
||||||
|
// message. TTLs are added later when converting to a
|
||||||
|
// dnsmessage.Resource.
|
||||||
|
Answers []msgResource
|
||||||
|
}
|
||||||
|
|
||||||
|
type msgResource struct {
|
||||||
|
Name string
|
||||||
|
Type dnsmessage.Type // dnsmessage.UnknownResource.Type
|
||||||
|
Data []byte // dnsmessage.UnknownResource.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrCacheMiss is a sentinel error returned by MessageCache.ReplyFromCache
|
||||||
|
// when the request can not be satisified from cache.
|
||||||
|
var ErrCacheMiss = errors.New("cache miss")
|
||||||
|
|
||||||
|
var parserPool = &sync.Pool{
|
||||||
|
New: func() interface{} { return new(dnsmessage.Parser) },
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplyFromCache writes a DNS reply to w for the provided DNS query message,
|
||||||
|
// which must begin with the two ID bytes of a DNS message.
|
||||||
|
//
|
||||||
|
// If there's a cache miss, the message is invalid or unexpected,
|
||||||
|
// ErrCacheMiss is returned. On cache hit, either nil or an error from
|
||||||
|
// a w.Write call is returned.
|
||||||
|
func (c *MessageCache) ReplyFromCache(w io.Writer, dnsQueryMessage []byte) error {
|
||||||
|
cacheKey, txID, ok := getDNSQueryCacheKey(dnsQueryMessage)
|
||||||
|
if !ok {
|
||||||
|
return ErrCacheMiss
|
||||||
|
}
|
||||||
|
now := c.now()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
cacheEntI, _ := c.cache.Get(cacheKey)
|
||||||
|
v, ok := cacheEntI.(*msgCacheValue)
|
||||||
|
if ok && now.After(v.Expires) {
|
||||||
|
c.cache.Remove(cacheKey)
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return ErrCacheMiss
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := uint32(v.Expires.Sub(now).Seconds())
|
||||||
|
|
||||||
|
packedRes, err := packDNSResponse(cacheKey, txID, ttl, v.Answers)
|
||||||
|
if err != nil {
|
||||||
|
return ErrCacheMiss
|
||||||
|
}
|
||||||
|
_, err = w.Write(packedRes)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNotCacheable = errors.New("question not cacheable")
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddCacheEntry adds a cache entry to the cache.
|
||||||
|
// It returns an error if the entry could not be cached.
|
||||||
|
func (c *MessageCache) AddCacheEntry(qPacket, res []byte) error {
|
||||||
|
cacheKey, qID, ok := getDNSQueryCacheKey(qPacket)
|
||||||
|
if !ok {
|
||||||
|
return errNotCacheable
|
||||||
|
}
|
||||||
|
now := c.now()
|
||||||
|
v := &msgCacheValue{}
|
||||||
|
|
||||||
|
p := parserPool.Get().(*dnsmessage.Parser)
|
||||||
|
defer parserPool.Put(p)
|
||||||
|
|
||||||
|
resh, err := p.Start(res)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading header in response: %w", err)
|
||||||
|
}
|
||||||
|
if resh.ID != qID {
|
||||||
|
return fmt.Errorf("response ID doesn't match query ID")
|
||||||
|
}
|
||||||
|
q, err := p.Question()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading 1st question in response: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := p.Question(); err != dnsmessage.ErrSectionDone {
|
||||||
|
if err == nil {
|
||||||
|
return errors.New("unexpected 2nd question in response")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("after reading 1st question in response: %w", err)
|
||||||
|
}
|
||||||
|
if resName := asciiLowerName(q.Name).String(); resName != cacheKey.Name {
|
||||||
|
return fmt.Errorf("response question name %q != question name %q", resName, cacheKey.Name)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
rh, err := p.AnswerHeader()
|
||||||
|
if err == dnsmessage.ErrSectionDone {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading answer: %w", err)
|
||||||
|
}
|
||||||
|
res, err := p.UnknownResource()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading resource: %w", err)
|
||||||
|
}
|
||||||
|
if rh.Class != dnsmessage.ClassINET {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the cache entry's expiration to the soonest
|
||||||
|
// we've seen. (They should all be the same, though)
|
||||||
|
expires := now.Add(time.Duration(rh.TTL) * time.Second)
|
||||||
|
if v.Expires.IsZero() || expires.Before(v.Expires) {
|
||||||
|
v.Expires = expires
|
||||||
|
}
|
||||||
|
v.Answers = append(v.Answers, msgResource{
|
||||||
|
Name: rh.Name.String(),
|
||||||
|
Type: rh.Type,
|
||||||
|
Data: res.Data, // doesn't alias; a copy from dnsmessage.unpackUnknownResource
|
||||||
|
})
|
||||||
|
}
|
||||||
|
c.addCacheValue(cacheKey, v)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MessageCache) addCacheValue(cacheKey msgQ, v *msgCacheValue) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.cache.Add(cacheKey, v)
|
||||||
|
c.pruneLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDNSQueryCacheKey(msg []byte) (cacheKey msgQ, txID uint16, ok bool) {
|
||||||
|
p := parserPool.Get().(*dnsmessage.Parser)
|
||||||
|
defer parserPool.Put(p)
|
||||||
|
h, err := p.Start(msg)
|
||||||
|
const dnsHeaderSize = 12
|
||||||
|
if err != nil || h.OpCode != 0 || h.Response || h.Truncated ||
|
||||||
|
len(msg) < dnsHeaderSize { // p.Start checks this anyway, but to be explicit for slicing below
|
||||||
|
return cacheKey, 0, false
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
numQ = binary.BigEndian.Uint16(msg[4:6])
|
||||||
|
numAns = binary.BigEndian.Uint16(msg[6:8])
|
||||||
|
numAuth = binary.BigEndian.Uint16(msg[8:10])
|
||||||
|
numAddn = binary.BigEndian.Uint16(msg[10:12])
|
||||||
|
)
|
||||||
|
_ = numAddn // ignore this for now; do client OSes send EDNS additional? assume so, ignore.
|
||||||
|
if !(numQ == 1 && numAns == 0 && numAuth == 0) {
|
||||||
|
// Something weird. We don't want to deal with it.
|
||||||
|
return cacheKey, 0, false
|
||||||
|
}
|
||||||
|
q, err := p.Question()
|
||||||
|
if err != nil {
|
||||||
|
// Already verified numQ == 1 so shouldn't happen, but:
|
||||||
|
return cacheKey, 0, false
|
||||||
|
}
|
||||||
|
if q.Class != dnsmessage.ClassINET {
|
||||||
|
// We only cache the Internet class.
|
||||||
|
return cacheKey, 0, false
|
||||||
|
}
|
||||||
|
return msgQ{Name: asciiLowerName(q.Name).String(), Type: q.Type}, h.ID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func asciiLowerName(n dnsmessage.Name) dnsmessage.Name {
|
||||||
|
nb := n.Data[:]
|
||||||
|
if int(n.Length) < len(n.Data) {
|
||||||
|
nb = nb[:n.Length]
|
||||||
|
}
|
||||||
|
for i, b := range nb {
|
||||||
|
if 'A' <= b && b <= 'Z' {
|
||||||
|
n.Data[i] += 0x20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// packDNSResponse builds a DNS response for the given question and
|
||||||
|
// transaction ID. The response resource records will have have the
|
||||||
|
// same provided TTL.
|
||||||
|
func packDNSResponse(q msgQ, txID uint16, ttl uint32, answers []msgResource) ([]byte, error) {
|
||||||
|
var baseMem []byte // TODO: guess a max size based on looping over answers?
|
||||||
|
b := dnsmessage.NewBuilder(baseMem, dnsmessage.Header{
|
||||||
|
ID: txID,
|
||||||
|
Response: true,
|
||||||
|
OpCode: 0,
|
||||||
|
Authoritative: false,
|
||||||
|
Truncated: false,
|
||||||
|
RCode: dnsmessage.RCodeSuccess,
|
||||||
|
})
|
||||||
|
name, err := dnsmessage.NewName(q.Name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := b.StartQuestions(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := b.Question(dnsmessage.Question{
|
||||||
|
Name: name,
|
||||||
|
Type: q.Type,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := b.StartAnswers(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, r := range answers {
|
||||||
|
name, err := dnsmessage.NewName(r.Name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := b.UnknownResource(dnsmessage.ResourceHeader{
|
||||||
|
Name: name,
|
||||||
|
Type: r.Type,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
TTL: ttl,
|
||||||
|
}, dnsmessage.UnknownResource{
|
||||||
|
Type: r.Type,
|
||||||
|
Data: r.Data,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.Finish()
|
||||||
|
}
|
@ -0,0 +1,292 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package dnscache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
|
"tailscale.com/tstest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMessageCache(t *testing.T) {
|
||||||
|
clock := &tstest.Clock{
|
||||||
|
Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
}
|
||||||
|
mc := &MessageCache{Clock: clock.Now}
|
||||||
|
mc.SetMaxCacheSize(2)
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
|
||||||
|
var out bytes.Buffer
|
||||||
|
if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mc.AddCacheEntry(
|
||||||
|
makeQ(2, "foo.com."),
|
||||||
|
makeRes(2, "FOO.COM.", ttlOpt(10),
|
||||||
|
&dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
|
||||||
|
&dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect cache hit, with 10 seconds remaining.
|
||||||
|
out.Reset()
|
||||||
|
if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil {
|
||||||
|
t.Fatalf("expected cache hit; got: %v", err)
|
||||||
|
}
|
||||||
|
if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 {
|
||||||
|
t.Errorf("TxID = %v; want %v", p.TxID, 3)
|
||||||
|
} else if p.TTL != 10 {
|
||||||
|
t.Errorf("TTL = %v; want 10", p.TTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// One second elapses, expect a cache hit, with 9 seconds
|
||||||
|
// remaining.
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
out.Reset()
|
||||||
|
if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil {
|
||||||
|
t.Fatalf("expected cache hit; got: %v", err)
|
||||||
|
}
|
||||||
|
if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 {
|
||||||
|
t.Errorf("TxID = %v; want %v", p.TxID, 4)
|
||||||
|
} else if p.TTL != 9 {
|
||||||
|
t.Errorf("TTL = %v; want 9", p.TTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect cache miss on MX record.
|
||||||
|
if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss {
|
||||||
|
t.Fatalf("expected cache miss on MX; got: %v", err)
|
||||||
|
}
|
||||||
|
// Expect cache miss on CHAOS class.
|
||||||
|
if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss {
|
||||||
|
t.Fatalf("expected cache miss on CHAOS; got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ten seconds elapses; expect a cache miss.
|
||||||
|
clock.Advance(10 * time.Second)
|
||||||
|
if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss {
|
||||||
|
t.Fatalf("expected cache miss, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type parsedMeta struct {
|
||||||
|
TxID uint16
|
||||||
|
TTL uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) {
|
||||||
|
t.Helper()
|
||||||
|
var p dnsmessage.Parser
|
||||||
|
h, err := p.Start(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ret.TxID = h.ID
|
||||||
|
qq, err := p.AllQuestions()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AllQuestions: %v", err)
|
||||||
|
}
|
||||||
|
if len(qq) != 1 {
|
||||||
|
t.Fatalf("num questions = %v; want 1", len(qq))
|
||||||
|
}
|
||||||
|
aa, err := p.AllAnswers()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AllAnswers: %v", err)
|
||||||
|
}
|
||||||
|
for _, r := range aa {
|
||||||
|
if ret.TTL == 0 {
|
||||||
|
ret.TTL = r.Header.TTL
|
||||||
|
}
|
||||||
|
if ret.TTL != r.Header.TTL {
|
||||||
|
t.Fatal("mixed TTLs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseOpt bool
|
||||||
|
|
||||||
|
type ttlOpt uint32
|
||||||
|
|
||||||
|
func makeQ(txID uint16, name string, opt ...interface{}) []byte {
|
||||||
|
opt = append(opt, responseOpt(false))
|
||||||
|
return makeDNSPkt(txID, name, opt...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeRes(txID uint16, name string, opt ...interface{}) []byte {
|
||||||
|
opt = append(opt, responseOpt(true))
|
||||||
|
return makeDNSPkt(txID, name, opt...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDNSPkt(txID uint16, name string, opt ...interface{}) []byte {
|
||||||
|
typ := dnsmessage.TypeA
|
||||||
|
class := dnsmessage.ClassINET
|
||||||
|
var response bool
|
||||||
|
var answers []dnsmessage.ResourceBody
|
||||||
|
var ttl uint32 = 1 // one second by default
|
||||||
|
for _, o := range opt {
|
||||||
|
switch o := o.(type) {
|
||||||
|
case dnsmessage.Type:
|
||||||
|
typ = o
|
||||||
|
case dnsmessage.Class:
|
||||||
|
class = o
|
||||||
|
case responseOpt:
|
||||||
|
response = bool(o)
|
||||||
|
case dnsmessage.ResourceBody:
|
||||||
|
answers = append(answers, o)
|
||||||
|
case ttlOpt:
|
||||||
|
ttl = uint32(o)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown opt type %T", o))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
qname := dnsmessage.MustNewName(name)
|
||||||
|
msg := dnsmessage.Message{
|
||||||
|
Header: dnsmessage.Header{ID: txID, Response: response},
|
||||||
|
Questions: []dnsmessage.Question{
|
||||||
|
{
|
||||||
|
Name: qname,
|
||||||
|
Type: typ,
|
||||||
|
Class: class,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, rb := range answers {
|
||||||
|
msg.Answers = append(msg.Answers, dnsmessage.Resource{
|
||||||
|
Header: dnsmessage.ResourceHeader{
|
||||||
|
Name: qname,
|
||||||
|
Type: typ,
|
||||||
|
Class: class,
|
||||||
|
TTL: ttl,
|
||||||
|
},
|
||||||
|
Body: rb,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
buf, err := msg.Pack()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestASCIILowerName(t *testing.T) {
|
||||||
|
n := asciiLowerName(dnsmessage.MustNewName("Foo.COM."))
|
||||||
|
if got, want := n.String(), "foo.com."; got != want {
|
||||||
|
t.Errorf("got = %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDNSQueryCacheKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pkt []byte
|
||||||
|
want msgQ
|
||||||
|
txID uint16
|
||||||
|
anyTX bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "a",
|
||||||
|
pkt: makeQ(123, "foo.com."),
|
||||||
|
want: msgQ{"foo.com.", dnsmessage.TypeA},
|
||||||
|
txID: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "aaaa",
|
||||||
|
pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA),
|
||||||
|
want: msgQ{"foo.com.", dnsmessage.TypeAAAA},
|
||||||
|
txID: 6,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normalize_case",
|
||||||
|
pkt: makeQ(123, "FoO.CoM."),
|
||||||
|
want: msgQ{"foo.com.", dnsmessage.TypeA},
|
||||||
|
txID: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ignore_response",
|
||||||
|
pkt: makeRes(123, "foo.com."),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ignore_question_with_answers",
|
||||||
|
pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle
|
||||||
|
pkt: getGoNetPacketDNSQuery("from-go.foo."),
|
||||||
|
want: msgQ{"from-go.foo.", dnsmessage.TypeA},
|
||||||
|
anyTX: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, gotTX, ok := getDNSQueryCacheKey(tt.pkt)
|
||||||
|
if !ok {
|
||||||
|
if tt.txID == 0 && got == (msgQ{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatal("failed")
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("got %+v, want %+v", got, tt.want)
|
||||||
|
}
|
||||||
|
if gotTX != tt.txID && !tt.anyTX {
|
||||||
|
t.Errorf("got tx %v, want %v", gotTX, tt.txID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getGoNetPacketDNSQuery(name string) []byte {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
// On Windows, Go's net.Resolver doesn't use the DNS client.
|
||||||
|
// See https://github.com/golang/go/issues/33097 which
|
||||||
|
// was approved but not yet implemented.
|
||||||
|
// For now just pretend it's implemented to make this test
|
||||||
|
// pass on Windows with complicated the caller.
|
||||||
|
return makeQ(123, name)
|
||||||
|
}
|
||||||
|
res := make(chan []byte, 1)
|
||||||
|
r := &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return goResolverConn(res), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r.LookupIP(context.Background(), "ip4", name)
|
||||||
|
return <-res
|
||||||
|
}
|
||||||
|
|
||||||
|
type goResolverConn chan<- []byte
|
||||||
|
|
||||||
|
func (goResolverConn) Close() error { return nil }
|
||||||
|
func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} }
|
||||||
|
func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} }
|
||||||
|
func (goResolverConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (goResolverConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") }
|
||||||
|
func (c goResolverConn) Write(p []byte) (int, error) {
|
||||||
|
select {
|
||||||
|
case c <- p[2:]: // skip 2 byte length for TCP mode DNS query
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return 0, errors.New("boom")
|
||||||
|
}
|
||||||
|
|
||||||
|
type todoAddr struct{}
|
||||||
|
|
||||||
|
func (todoAddr) Network() string { return "unused" }
|
||||||
|
func (todoAddr) String() string { return "unused-todoAddr" }
|
Loading…
Reference in New Issue