diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 72f23b012..136ac5003 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -693,6 +693,19 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang if va == nil || vb == nil || *va != *vb { return nil, false } + case "ExitNodeDNSResolvers": + va, vb := was.ExitNodeDNSResolvers(), views.SliceOfViews(n.ExitNodeDNSResolvers) + + if va.Len() != vb.Len() { + return nil, false + } + + for i := range va.LenIter() { + if !va.At(i).Equal(vb.At(i)) { + return nil, false + } + } + } } if ret != nil { diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index bddec375f..f4daf074e 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -20,6 +20,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstime" + "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -835,6 +836,40 @@ func TestPatchifyPeersChanged(t *testing.T) { }, }, }, + { + name: "change_exitnodednsresolvers", + mr0: &tailcfg.MapResponse{ + Node: &tailcfg.Node{Name: "foo.bar.ts.net."}, + Peers: []*tailcfg.Node{ + {ID: 1, ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.exmaple.com"}}, Hostinfo: hi}, + }, + }, + mr1: &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{ + {ID: 1, ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns2.exmaple.com"}}, Hostinfo: hi}, + }, + }, + want: &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{ + {ID: 1, ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns2.exmaple.com"}}, Hostinfo: hi}, + }, + }, + }, + { + name: "same_exitnoderesolvers", + mr0: &tailcfg.MapResponse{ + Node: &tailcfg.Node{Name: "foo.bar.ts.net."}, + Peers: []*tailcfg.Node{ + {ID: 1, ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.exmaple.com"}}, Hostinfo: hi}, + }, + }, + mr1: &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{ + {ID: 1, ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.exmaple.com"}}, Hostinfo: hi}, + }, + }, + want: &tailcfg.MapResponse{}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 4fff81089..f5c9f3d41 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -336,6 +336,10 @@ type Node struct { // is not expected to speak Disco or DERP, and it must have Endpoints in // order to be reachable. IsWireGuardOnly bool `json:",omitempty"` + + // ExitNodeDNSResolvers is the list of DNS servers that should be used when this + // node is marked IsWireGuardOnly and being used as an exit node. + ExitNodeDNSResolvers []*dnstype.Resolver `json:",omitempty"` } // DisplayName returns the user-facing name for a node which should diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 6fb2c81b5..76e727444 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -65,6 +65,12 @@ func (src *Node) Clone() *Node { if dst.SelfNodeV4MasqAddrForThisPeer != nil { dst.SelfNodeV4MasqAddrForThisPeer = ptr.To(*src.SelfNodeV4MasqAddrForThisPeer) } + if src.ExitNodeDNSResolvers != nil { + dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers)) + for i := range dst.ExitNodeDNSResolvers { + dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() + } + } return dst } @@ -101,6 +107,7 @@ var _NodeCloneNeedsRegeneration = Node(struct { Expired bool SelfNodeV4MasqAddrForThisPeer *netip.Addr IsWireGuardOnly bool + ExitNodeDNSResolvers []*dnstype.Resolver }{}) // Clone makes a deep copy of Hostinfo. diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index de0641506..5ed99c9cd 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -350,7 +350,7 @@ func TestNodeEqual(t *testing.T) { "UnsignedPeerAPIOnly", "ComputedName", "computedHostIfDifferent", "ComputedNameWithHost", "DataPlaneAuditLogID", "Expired", "SelfNodeV4MasqAddrForThisPeer", - "IsWireGuardOnly", + "IsWireGuardOnly", "ExitNodeDNSResolvers", } if have := fieldsOf(reflect.TypeOf(Node{})); !reflect.DeepEqual(have, nodeHandles) { t.Errorf("Node.Equal check might be out of sync\nfields: %q\nhandled: %q\n", diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index 0b9250412..0ec0142f6 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -180,7 +180,10 @@ func (v NodeView) SelfNodeV4MasqAddrForThisPeer() *netip.Addr { return &x } -func (v NodeView) IsWireGuardOnly() bool { return v.ж.IsWireGuardOnly } +func (v NodeView) IsWireGuardOnly() bool { return v.ж.IsWireGuardOnly } +func (v NodeView) ExitNodeDNSResolvers() views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { + return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](v.ж.ExitNodeDNSResolvers) +} func (v NodeView) Equal(v2 NodeView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -216,6 +219,7 @@ var _NodeViewNeedsRegeneration = Node(struct { Expired bool SelfNodeV4MasqAddrForThisPeer *netip.Addr IsWireGuardOnly bool + ExitNodeDNSResolvers []*dnstype.Resolver }{}) // View returns a readonly view of Hostinfo. diff --git a/types/dnstype/dnstype.go b/types/dnstype/dnstype.go index a5137fa79..ae3d1defc 100644 --- a/types/dnstype/dnstype.go +++ b/types/dnstype/dnstype.go @@ -8,6 +8,7 @@ package dnstype import ( "net/netip" + "slices" ) // Resolver is the configuration for one DNS resolver. @@ -51,3 +52,15 @@ func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { } return } + +// Equal reports whether r and other are equal. +func (r *Resolver) Equal(other *Resolver) bool { + if r == nil || other == nil { + return r == other + } + if r == other { + return true + } + + return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) +} diff --git a/types/dnstype/dnstype_test.go b/types/dnstype/dnstype_test.go new file mode 100644 index 000000000..bd8986e7f --- /dev/null +++ b/types/dnstype/dnstype_test.go @@ -0,0 +1,81 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dnstype + +import ( + "net/netip" + "reflect" + "slices" + "sort" + "testing" +) + +func TestResolverEqual(t *testing.T) { + var fieldNames []string + for _, field := range reflect.VisibleFields(reflect.TypeOf(Resolver{})) { + fieldNames = append(fieldNames, field.Name) + } + sort.Strings(fieldNames) + if !slices.Equal(fieldNames, []string{"Addr", "BootstrapResolution"}) { + t.Errorf("Resolver fields changed; update test") + } + + tests := []struct { + name string + a, b *Resolver + want bool + }{ + { + name: "nil", + a: nil, + b: nil, + want: true, + }, + { + name: "nil vs non-nil", + a: nil, + b: &Resolver{}, + want: false, + }, + { + name: "non-nil vs nil", + a: &Resolver{}, + b: nil, + want: false, + }, + { + name: "equal", + a: &Resolver{Addr: "dns.example.com"}, + b: &Resolver{Addr: "dns.example.com"}, + want: true, + }, + { + name: "not equal addrs", + a: &Resolver{Addr: "dns.example.com"}, + b: &Resolver{Addr: "dns2.example.com"}, + want: false, + }, + { + name: "not equal bootstrap", + a: &Resolver{ + Addr: "dns.example.com", + BootstrapResolution: []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + }, + b: &Resolver{ + Addr: "dns.example.com", + BootstrapResolution: []netip.Addr{netip.MustParseAddr("8.8.4.4")}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.a.Equal(tt.b) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/types/dnstype/dnstype_view.go b/types/dnstype/dnstype_view.go index b8f1e0312..c0e2b28ff 100644 --- a/types/dnstype/dnstype_view.go +++ b/types/dnstype/dnstype_view.go @@ -64,6 +64,7 @@ func (v ResolverView) Addr() string { return v.ж.Addr } func (v ResolverView) BootstrapResolution() views.Slice[netip.Addr] { return views.SliceOf(v.ж.BootstrapResolution) } +func (v ResolverView) Equal(v2 ResolverView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _ResolverViewNeedsRegeneration = Resolver(struct {