// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package dns import ( "io/ioutil" "os" "path/filepath" "testing" "inet.af/netaddr" "tailscale.com/util/dnsname" ) func TestSetDNS(t *testing.T) { const orig = "nameserver 9.9.9.9 # orig" tmp := t.TempDir() resolvPath := filepath.Join(tmp, "etc", "resolv.conf") backupPath := filepath.Join(tmp, "etc", "resolv.pre-tailscale-backup.conf") if err := os.MkdirAll(filepath.Dir(resolvPath), 0777); err != nil { t.Fatal(err) } if err := ioutil.WriteFile(resolvPath, []byte(orig), 0644); err != nil { t.Fatal(err) } readFile := func(t *testing.T, path string) string { t.Helper() b, err := ioutil.ReadFile(path) if err != nil { t.Fatal(err) } return string(b) } assertBaseState := func(t *testing.T) { if got := readFile(t, resolvPath); got != orig { t.Fatalf("resolv.conf:\n%s, want:\n%s", got, orig) } if _, err := os.Stat(backupPath); !os.IsNotExist(err) { t.Fatalf("resolv.conf backup: want it to be gone but: %v", err) } } m := directManager{fs: directFS{prefix: tmp}} if err := m.SetDNS(OSConfig{ Nameservers: []netaddr.IP{netaddr.MustParseIP("8.8.8.8"), netaddr.MustParseIP("8.8.4.4")}, SearchDomains: []dnsname.FQDN{"ts.net.", "ts-dns.test."}, MatchDomains: []dnsname.FQDN{"ignored."}, }); err != nil { t.Fatal(err) } want := `# resolv.conf(5) file generated by tailscale # DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN nameserver 8.8.8.8 nameserver 8.8.4.4 search ts.net ts-dns.test ` if got := readFile(t, resolvPath); got != want { t.Fatalf("resolv.conf:\n%s, want:\n%s", got, want) } if got := readFile(t, backupPath); got != orig { t.Fatalf("resolv.conf backup:\n%s, want:\n%s", got, orig) } // Test that a nil OSConfig cleans up resolv.conf. if err := m.SetDNS(OSConfig{}); err != nil { t.Fatal(err) } assertBaseState(t) // Test that Close cleans up resolv.conf. if err := m.SetDNS(OSConfig{Nameservers: []netaddr.IP{netaddr.MustParseIP("8.8.8.8")}}); err != nil { t.Fatal(err) } if err := m.Close(); err != nil { t.Fatal(err) } assertBaseState(t) }