diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 8e561b3e0..40ab81447 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -3288,6 +3288,13 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, prefs ipn.PrefsView, logf logger. return dcfg } + // If we're using an exit node and that exit node is IsWireGuardOnly with + // ExitNodeDNSResolver set, then add that as the default. + if resolvers, ok := wireguardExitNodeDNSResolvers(nm, prefs.ExitNodeID()); ok { + addDefault(resolvers) + return dcfg + } + addDefault(nm.DNS.Resolvers) for suffix, resolvers := range nm.DNS.Routes { fqdn, err := dnsname.ToFQDN(suffix) @@ -4676,6 +4683,30 @@ func exitNodeCanProxyDNS(nm *netmap.NetworkMap, exitNodeID tailcfg.StableNodeID) return "", false } +// wireguardExitNodeDNSResolvers returns the DNS resolvers to use for a +// WireGuard-only exit node, if it has resolver addresses. +func wireguardExitNodeDNSResolvers(nm *netmap.NetworkMap, exitNodeID tailcfg.StableNodeID) ([]*dnstype.Resolver, bool) { + if exitNodeID.IsZero() { + return nil, false + } + + for _, p := range nm.Peers { + if p.StableID() == exitNodeID && p.IsWireGuardOnly() { + resolvers := p.ExitNodeDNSResolvers() + if !resolvers.IsNil() && resolvers.Len() > 0 { + copies := make([]*dnstype.Resolver, resolvers.Len()) + for i := range resolvers.LenIter() { + copies[i] = resolvers.At(i).AsStruct() + } + return copies, true + } + return nil, false + } + } + + return nil, false +} + func peerCanProxyDNS(p tailcfg.NodeView) bool { if p.Cap() >= 26 { // Actually added at 25 diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 6ad3a2b83..ad32a9757 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -22,6 +22,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" @@ -952,3 +953,109 @@ func TestUpdateNetmapDelta(t *testing.T) { } } } + +func TestWireguardExitNodeDNSResolvers(t *testing.T) { + type tc struct { + name string + id tailcfg.StableNodeID + peers []*tailcfg.Node + wantOK bool + wantResolvers []*dnstype.Resolver + } + + tests := []tc{ + { + name: "no peers", + id: "1", + wantOK: false, + wantResolvers: nil, + }, + { + name: "non wireguard peer", + id: "1", + peers: []*tailcfg.Node{ + { + StableID: "1", + IsWireGuardOnly: false, + ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + }, + wantOK: false, + wantResolvers: nil, + }, + { + name: "no matching IDs", + id: "2", + peers: []*tailcfg.Node{ + { + StableID: "1", + IsWireGuardOnly: true, + ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + }, + wantOK: false, + wantResolvers: nil, + }, + { + name: "wireguard peer", + id: "1", + peers: []*tailcfg.Node{ + { + StableID: "1", + IsWireGuardOnly: true, + ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + }, + wantOK: true, + wantResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + } + + for _, tc := range tests { + peers := nodeViews(tc.peers) + nm := &netmap.NetworkMap{ + Peers: peers, + } + gotResolvers, gotOK := wireguardExitNodeDNSResolvers(nm, tc.id) + + if gotOK != tc.wantOK || !resolversEqual(gotResolvers, tc.wantResolvers) { + t.Errorf("case: %s: got %v, %v, want %v, %v", tc.name, gotOK, gotResolvers, tc.wantOK, tc.wantResolvers) + } + } +} + +func TestDNSConfigForNetmapForWireguardExitNode(t *testing.T) { + resolvers := []*dnstype.Resolver{{Addr: "dns.example.com"}} + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + { + StableID: "1", + IsWireGuardOnly: true, + ExitNodeDNSResolvers: resolvers, + Hostinfo: (&tailcfg.Hostinfo{}).View(), + }, + }), + } + + prefs := &ipn.Prefs{ + ExitNodeID: "1", + CorpDNS: true, + } + + got := dnsConfigForNetmap(nm, prefs.View(), t.Logf, "") + if !resolversEqual(got.DefaultResolvers, resolvers) { + t.Errorf("got %v, want %v", got.DefaultResolvers, resolvers) + } +} + +func resolversEqual(a, b []*dnstype.Resolver) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true +}