// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package dnscache import ( "context" "errors" "flag" "fmt" "net" "reflect" "testing" "time" "inet.af/netaddr" ) var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") func TestDialer(t *testing.T) { if *dialTest == "" { t.Skip("skipping; --dial-test is blank") } r := new(Resolver) var std net.Dialer dialer := Dialer(std.DialContext, r) t0 := time.Now() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() c, err := dialer(ctx, "tcp", *dialTest) if err != nil { t.Fatal(err) } t.Logf("dialed in %v", time.Since(t0)) c.Close() } func TestDialCall_DNSWasTrustworthy(t *testing.T) { type step struct { ip netaddr.IP // IP we pretended to dial err error // the dial error or nil for success } mustIP := netaddr.MustParseIP errFail := errors.New("some connect failure") tests := []struct { name string steps []step want bool }{ { name: "no-info", want: false, }, { name: "previous-dial", steps: []step{ {mustIP("2003::1"), nil}, {mustIP("2003::1"), errFail}, }, want: true, }, { name: "no-previous-dial", steps: []step{ {mustIP("2003::1"), errFail}, }, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { d := &dialer{ pastConnect: map[netaddr.IP]time.Time{}, } dc := &dialCall{ d: d, } for _, st := range tt.steps { dc.noteDialResult(st.ip, st.err) } got := dc.dnsWasTrustworthy() if got != tt.want { t.Errorf("got %v; want %v", got, tt.want) } }) } } func TestDialCall_uniqueIPs(t *testing.T) { dc := &dialCall{} mustIP := netaddr.MustParseIP errFail := errors.New("some connect failure") dc.noteDialResult(mustIP("2003::1"), errFail) dc.noteDialResult(mustIP("2003::2"), errFail) got := dc.uniqueIPs([]netaddr.IP{ mustIP("2003::1"), mustIP("2003::2"), mustIP("2003::2"), mustIP("2003::3"), mustIP("2003::3"), mustIP("2003::4"), mustIP("2003::4"), }) want := []netaddr.IP{ mustIP("2003::3"), mustIP("2003::4"), } if !reflect.DeepEqual(got, want) { t.Errorf("got %v; want %v", got, want) } } func TestResolverAllHostStaticResult(t *testing.T) { r := &Resolver{ SingleHost: "foo.bar", SingleHostStaticResult: []netaddr.IP{ netaddr.MustParseIP("2001:4860:4860::8888"), netaddr.MustParseIP("2001:4860:4860::8844"), netaddr.MustParseIP("8.8.8.8"), netaddr.MustParseIP("8.8.4.4"), }, } ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") if err != nil { t.Fatal(err) } if got, want := ip4.String(), "8.8.8.8"; got != want { t.Errorf("ip4 got %q; want %q", got, want) } if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { t.Errorf("ip4 got %q; want %q", got, want) } if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want { t.Errorf("allIPs got %q; want %q", got, want) } _, _, _, err = r.LookupIP(context.Background(), "bad") if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { t.Errorf("bad dial error got %q; want %q", got, want) } }