// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package dnscache import ( "bytes" "context" "errors" "fmt" "net" "runtime" "testing" "time" "golang.org/x/net/dns/dnsmessage" "tailscale.com/tstest" ) func TestMessageCache(t *testing.T) { clock := &tstest.Clock{ Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), } mc := &MessageCache{Clock: clock.Now} mc.SetMaxCacheSize(2) clock.Advance(time.Second) var out bytes.Buffer if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { t.Fatalf("unexpected error: %v", err) } if err := mc.AddCacheEntry( makeQ(2, "foo.com."), makeRes(2, "FOO.COM.", ttlOpt(10), &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { t.Fatal(err) } // Expect cache hit, with 10 seconds remaining. out.Reset() if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { t.Fatalf("expected cache hit; got: %v", err) } if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { t.Errorf("TxID = %v; want %v", p.TxID, 3) } else if p.TTL != 10 { t.Errorf("TTL = %v; want 10", p.TTL) } // One second elapses, expect a cache hit, with 9 seconds // remaining. clock.Advance(time.Second) out.Reset() if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { t.Fatalf("expected cache hit; got: %v", err) } if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { t.Errorf("TxID = %v; want %v", p.TxID, 4) } else if p.TTL != 9 { t.Errorf("TTL = %v; want 9", p.TTL) } // Expect cache miss on MX record. if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { t.Fatalf("expected cache miss on MX; got: %v", err) } // Expect cache miss on CHAOS class. if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { t.Fatalf("expected cache miss on CHAOS; got: %v", err) } // Ten seconds elapses; expect a cache miss. clock.Advance(10 * time.Second) if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { t.Fatalf("expected cache miss, got: %v", err) } } type parsedMeta struct { TxID uint16 TTL uint32 } func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { t.Helper() var p dnsmessage.Parser h, err := p.Start(r) if err != nil { t.Fatal(err) } ret.TxID = h.ID qq, err := p.AllQuestions() if err != nil { t.Fatalf("AllQuestions: %v", err) } if len(qq) != 1 { t.Fatalf("num questions = %v; want 1", len(qq)) } aa, err := p.AllAnswers() if err != nil { t.Fatalf("AllAnswers: %v", err) } for _, r := range aa { if ret.TTL == 0 { ret.TTL = r.Header.TTL } if ret.TTL != r.Header.TTL { t.Fatal("mixed TTLs") } } return ret } type responseOpt bool type ttlOpt uint32 func makeQ(txID uint16, name string, opt ...interface{}) []byte { opt = append(opt, responseOpt(false)) return makeDNSPkt(txID, name, opt...) } func makeRes(txID uint16, name string, opt ...interface{}) []byte { opt = append(opt, responseOpt(true)) return makeDNSPkt(txID, name, opt...) } func makeDNSPkt(txID uint16, name string, opt ...interface{}) []byte { typ := dnsmessage.TypeA class := dnsmessage.ClassINET var response bool var answers []dnsmessage.ResourceBody var ttl uint32 = 1 // one second by default for _, o := range opt { switch o := o.(type) { case dnsmessage.Type: typ = o case dnsmessage.Class: class = o case responseOpt: response = bool(o) case dnsmessage.ResourceBody: answers = append(answers, o) case ttlOpt: ttl = uint32(o) default: panic(fmt.Sprintf("unknown opt type %T", o)) } } qname := dnsmessage.MustNewName(name) msg := dnsmessage.Message{ Header: dnsmessage.Header{ID: txID, Response: response}, Questions: []dnsmessage.Question{ { Name: qname, Type: typ, Class: class, }, }, } for _, rb := range answers { msg.Answers = append(msg.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: qname, Type: typ, Class: class, TTL: ttl, }, Body: rb, }) } buf, err := msg.Pack() if err != nil { panic(err) } return buf } func TestASCIILowerName(t *testing.T) { n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) if got, want := n.String(), "foo.com."; got != want { t.Errorf("got = %q; want %q", got, want) } } func TestGetDNSQueryCacheKey(t *testing.T) { tests := []struct { name string pkt []byte want msgQ txID uint16 anyTX bool }{ { name: "empty", }, { name: "a", pkt: makeQ(123, "foo.com."), want: msgQ{"foo.com.", dnsmessage.TypeA}, txID: 123, }, { name: "aaaa", pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, txID: 6, }, { name: "normalize_case", pkt: makeQ(123, "FoO.CoM."), want: msgQ{"foo.com.", dnsmessage.TypeA}, txID: 123, }, { name: "ignore_response", pkt: makeRes(123, "foo.com."), }, { name: "ignore_question_with_answers", pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), }, { name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle pkt: getGoNetPacketDNSQuery("from-go.foo."), want: msgQ{"from-go.foo.", dnsmessage.TypeA}, anyTX: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) if !ok { if tt.txID == 0 && got == (msgQ{}) { return } t.Fatal("failed") } if got != tt.want { t.Errorf("got %+v, want %+v", got, tt.want) } if gotTX != tt.txID && !tt.anyTX { t.Errorf("got tx %v, want %v", gotTX, tt.txID) } }) } } func getGoNetPacketDNSQuery(name string) []byte { if runtime.GOOS == "windows" { // On Windows, Go's net.Resolver doesn't use the DNS client. // See https://github.com/golang/go/issues/33097 which // was approved but not yet implemented. // For now just pretend it's implemented to make this test // pass on Windows with complicated the caller. return makeQ(123, name) } res := make(chan []byte, 1) r := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return goResolverConn(res), nil }, } r.LookupIP(context.Background(), "ip4", name) return <-res } type goResolverConn chan<- []byte func (goResolverConn) Close() error { return nil } func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } func (goResolverConn) SetDeadline(t time.Time) error { return nil } func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } func (c goResolverConn) Write(p []byte) (int, error) { select { case c <- p[2:]: // skip 2 byte length for TCP mode DNS query default: } return 0, errors.New("boom") } type todoAddr struct{} func (todoAddr) Network() string { return "unused" } func (todoAddr) String() string { return "unused-todoAddr" }