diff --git a/net/dns/osconfig.go b/net/dns/osconfig.go index 012601b74..842c5ac60 100644 --- a/net/dns/osconfig.go +++ b/net/dns/osconfig.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net/netip" + "slices" "strings" "tailscale.com/types/logger" @@ -103,10 +104,16 @@ func (o *OSConfig) WriteToBufioWriter(w *bufio.Writer) { } func (o OSConfig) IsZero() bool { - return len(o.Nameservers) == 0 && len(o.SearchDomains) == 0 && len(o.MatchDomains) == 0 + return len(o.Hosts) == 0 && + len(o.Nameservers) == 0 && + len(o.SearchDomains) == 0 && + len(o.MatchDomains) == 0 } func (a OSConfig) Equal(b OSConfig) bool { + if len(a.Hosts) != len(b.Hosts) { + return false + } if len(a.Nameservers) != len(b.Nameservers) { return false } @@ -117,6 +124,15 @@ func (a OSConfig) Equal(b OSConfig) bool { return false } + for i := range a.Hosts { + ha, hb := a.Hosts[i], b.Hosts[i] + if ha.Addr != hb.Addr { + return false + } + if !slices.Equal(ha.Hosts, hb.Hosts) { + return false + } + } for i := range a.Nameservers { if a.Nameservers[i] != b.Nameservers[i] { return false diff --git a/net/dns/osconfig_test.go b/net/dns/osconfig_test.go index 02b1cad9e..c19db299f 100644 --- a/net/dns/osconfig_test.go +++ b/net/dns/osconfig_test.go @@ -6,8 +6,10 @@ package dns import ( "fmt" "net/netip" + "reflect" "testing" + "tailscale.com/tstest" "tailscale.com/util/dnsname" ) @@ -41,3 +43,13 @@ func TestOSConfigPrintable(t *testing.T) { t.Errorf("format mismatch:\n got: %s\n want: %s", s, expected) } } + +func TestIsZero(t *testing.T) { + tstest.CheckIsZero[OSConfig](t, map[reflect.Type]any{ + reflect.TypeFor[dnsname.FQDN](): dnsname.FQDN("foo.bar."), + reflect.TypeFor[*HostEntry](): &HostEntry{ + Addr: netip.AddrFrom4([4]byte{100, 1, 2, 3}), + Hosts: []string{"foo", "bar"}, + }, + }) +} diff --git a/tstest/reflect.go b/tstest/reflect.go new file mode 100644 index 000000000..125391349 --- /dev/null +++ b/tstest/reflect.go @@ -0,0 +1,114 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "net/netip" + "reflect" + "testing" + "time" + + "tailscale.com/types/ptr" +) + +// IsZeroable is the interface for things with an IsZero method. +type IsZeroable interface { + IsZero() bool +} + +var ( + netipAddrType = reflect.TypeFor[netip.Addr]() + netipAddrPortType = reflect.TypeFor[netip.AddrPort]() + netipPrefixType = reflect.TypeFor[netip.Prefix]() + timeType = reflect.TypeFor[time.Time]() + timePtrType = reflect.TypeFor[*time.Time]() +) + +// CheckIsZero checks that the IsZero method of a given type functions +// correctly, by instantiating a new value of that type, changing a field, and +// then checking that the IsZero method returns false. +// +// The nonzeroValues map should contain non-zero values for each type that +// exists in the type T or any contained types. Basic types like string, bool, +// and numeric types are handled automatically. +func CheckIsZero[T IsZeroable](t testing.TB, nonzeroValues map[reflect.Type]any) { + t.Helper() + + var zero T + if !zero.IsZero() { + t.Errorf("zero value of %T is not IsZero", zero) + return + } + + var nonEmptyValue func(t reflect.Type) reflect.Value + nonEmptyValue = func(ty reflect.Type) reflect.Value { + if v, ok := nonzeroValues[ty]; ok { + return reflect.ValueOf(v) + } + + switch ty { + // Given that we're a networking company, probably fine to have + // a special case for netip.Addr :) + case netipAddrType: + return reflect.ValueOf(netip.MustParseAddr("1.2.3.4")) + case netipAddrPortType: + return reflect.ValueOf(netip.MustParseAddrPort("1.2.3.4:9999")) + case netipPrefixType: + return reflect.ValueOf(netip.MustParsePrefix("1.2.3.4/24")) + + case timeType: + return reflect.ValueOf(time.Unix(1704067200, 0)) + case timePtrType: + return reflect.ValueOf(ptr.To(time.Unix(1704067200, 0))) + } + + switch ty.Kind() { + case reflect.String: + return reflect.ValueOf("foo").Convert(ty) + case reflect.Bool: + return reflect.ValueOf(true) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.ValueOf(int64(-42)).Convert(ty) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return reflect.ValueOf(uint64(42)).Convert(ty) + case reflect.Float32, reflect.Float64: + return reflect.ValueOf(float64(3.14)).Convert(ty) + case reflect.Complex64, reflect.Complex128: + return reflect.ValueOf(complex(3.14, 2.71)).Convert(ty) + case reflect.Chan: + return reflect.MakeChan(ty, 1) + + // For slices, ensure that the slice is non-empty. + case reflect.Slice: + v := nonEmptyValue(ty.Elem()) + sl := reflect.MakeSlice(ty, 1, 1) + sl.Index(0).Set(v) + return sl + + case reflect.Map: + // Create a map with a single key-value pair, recursively creating each. + k := nonEmptyValue(ty.Key()) + v := nonEmptyValue(ty.Elem()) + + m := reflect.MakeMap(ty) + m.SetMapIndex(k, v) + return m + + default: + panic("unhandled type " + ty.String()) + } + } + + typ := reflect.TypeFor[T]() + for i, n := 0, typ.NumField(); i < n; i++ { + sf := typ.Field(i) + + var nonzero T + rv := reflect.ValueOf(&nonzero).Elem() + rv.Field(i).Set(nonEmptyValue(sf.Type)) + if nonzero.IsZero() { + t.Errorf("IsZero = true with %v set; want false\nvalue: %#v", sf.Name, nonzero) + } + } +}