// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package dnscache import ( "context" "errors" "flag" "fmt" "net" "net/netip" "reflect" "testing" "time" ) var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") func TestDialer(t *testing.T) { if *dialTest == "" { t.Skip("skipping; --dial-test is blank") } r := new(Resolver) var std net.Dialer dialer := Dialer(std.DialContext, r) t0 := time.Now() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() c, err := dialer(ctx, "tcp", *dialTest) if err != nil { t.Fatal(err) } t.Logf("dialed in %v", time.Since(t0)) c.Close() } func TestDialCall_DNSWasTrustworthy(t *testing.T) { type step struct { ip netip.Addr // IP we pretended to dial err error // the dial error or nil for success } mustIP := netip.MustParseAddr errFail := errors.New("some connect failure") tests := []struct { name string steps []step want bool }{ { name: "no-info", want: false, }, { name: "previous-dial", steps: []step{ {mustIP("2003::1"), nil}, {mustIP("2003::1"), errFail}, }, want: true, }, { name: "no-previous-dial", steps: []step{ {mustIP("2003::1"), errFail}, }, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { d := &dialer{ pastConnect: map[netip.Addr]time.Time{}, } dc := &dialCall{ d: d, } for _, st := range tt.steps { dc.noteDialResult(st.ip, st.err) } got := dc.dnsWasTrustworthy() if got != tt.want { t.Errorf("got %v; want %v", got, tt.want) } }) } } func TestDialCall_uniqueIPs(t *testing.T) { dc := &dialCall{} mustIP := netip.MustParseAddr errFail := errors.New("some connect failure") dc.noteDialResult(mustIP("2003::1"), errFail) dc.noteDialResult(mustIP("2003::2"), errFail) got := dc.uniqueIPs([]netip.Addr{ mustIP("2003::1"), mustIP("2003::2"), mustIP("2003::2"), mustIP("2003::3"), mustIP("2003::3"), mustIP("2003::4"), mustIP("2003::4"), }) want := []netip.Addr{ mustIP("2003::3"), mustIP("2003::4"), } if !reflect.DeepEqual(got, want) { t.Errorf("got %v; want %v", got, want) } } func TestResolverAllHostStaticResult(t *testing.T) { r := &Resolver{ SingleHost: "foo.bar", SingleHostStaticResult: []netip.Addr{ netip.MustParseAddr("2001:4860:4860::8888"), netip.MustParseAddr("2001:4860:4860::8844"), netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4"), }, } ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") if err != nil { t.Fatal(err) } if got, want := ip4.String(), "8.8.8.8"; got != want { t.Errorf("ip4 got %q; want %q", got, want) } if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { t.Errorf("ip4 got %q; want %q", got, want) } if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { t.Errorf("allIPs got %q; want %q", got, want) } _, _, _, err = r.LookupIP(context.Background(), "bad") if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { t.Errorf("bad dial error got %q; want %q", got, want) } } func TestInterleaveSlices(t *testing.T) { testCases := []struct { name string a, b []int want []int }{ {name: "equal", a: []int{1, 3, 5}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 5, 6}}, {name: "short_b", a: []int{1, 3, 5}, b: []int{2, 4}, want: []int{1, 2, 3, 4, 5}}, {name: "short_a", a: []int{1, 3}, b: []int{2, 4, 6}, want: []int{1, 2, 3, 4, 6}}, {name: "len_1", a: []int{1}, b: []int{2, 4, 6}, want: []int{1, 2, 4, 6}}, {name: "nil_a", a: nil, b: []int{2, 4, 6}, want: []int{2, 4, 6}}, {name: "nil_all", a: nil, b: nil, want: []int{}}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { merged := interleaveSlices(tc.a, tc.b) if !reflect.DeepEqual(merged, tc.want) { t.Errorf("got %v; want %v", merged, tc.want) } }) } } func TestShouldTryBootstrap(t *testing.T) { oldDebug := debug t.Cleanup(func() { debug = oldDebug }) debug = func() bool { return true } type step struct { ip netip.Addr // IP we pretended to dial err error // the dial error or nil for success } canceled, cancel := context.WithCancel(context.Background()) cancel() deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) defer cancel() ctx := context.Background() errFailed := errors.New("some failure") cacheWithFallback := &Resolver{ LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { panic("unimplemented") }, } cacheNoFallback := &Resolver{} testCases := []struct { name string steps []step ctx context.Context err error noFallback bool want bool }{ { name: "no-error", ctx: ctx, err: nil, want: false, }, { name: "canceled", ctx: canceled, err: errFailed, want: false, }, { name: "deadline-exceeded", ctx: deadlineExceeded, err: errFailed, want: false, }, { name: "no-fallback", ctx: ctx, err: errFailed, noFallback: true, want: false, }, { name: "dns-was-trustworthy", ctx: ctx, err: errFailed, steps: []step{ {netip.MustParseAddr("2003::1"), nil}, {netip.MustParseAddr("2003::1"), errFailed}, }, want: false, }, { name: "should-bootstrap", ctx: ctx, err: errFailed, want: true, }, } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { d := &dialer{ pastConnect: map[netip.Addr]time.Time{}, } if tt.noFallback { d.dnsCache = cacheNoFallback } else { d.dnsCache = cacheWithFallback } dc := &dialCall{d: d} for _, st := range tt.steps { dc.noteDialResult(st.ip, st.err) } got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) if got != tt.want { t.Errorf("got %v; want %v", got, tt.want) } }) } }