// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package tstest import ( "net/netip" "reflect" "testing" "time" "tailscale.com/types/ptr" ) // IsZeroable is the interface for things with an IsZero method. type IsZeroable interface { IsZero() bool } var ( netipAddrType = reflect.TypeFor[netip.Addr]() netipAddrPortType = reflect.TypeFor[netip.AddrPort]() netipPrefixType = reflect.TypeFor[netip.Prefix]() timeType = reflect.TypeFor[time.Time]() timePtrType = reflect.TypeFor[*time.Time]() ) // CheckIsZero checks that the IsZero method of a given type functions // correctly, by instantiating a new value of that type, changing a field, and // then checking that the IsZero method returns false. // // The nonzeroValues map should contain non-zero values for each type that // exists in the type T or any contained types. Basic types like string, bool, // and numeric types are handled automatically. func CheckIsZero[T IsZeroable](t testing.TB, nonzeroValues map[reflect.Type]any) { t.Helper() var zero T if !zero.IsZero() { t.Errorf("zero value of %T is not IsZero", zero) return } var nonEmptyValue func(t reflect.Type) reflect.Value nonEmptyValue = func(ty reflect.Type) reflect.Value { if v, ok := nonzeroValues[ty]; ok { return reflect.ValueOf(v) } switch ty { // Given that we're a networking company, probably fine to have // a special case for netip.Addr :) case netipAddrType: return reflect.ValueOf(netip.MustParseAddr("1.2.3.4")) case netipAddrPortType: return reflect.ValueOf(netip.MustParseAddrPort("1.2.3.4:9999")) case netipPrefixType: return reflect.ValueOf(netip.MustParsePrefix("1.2.3.4/24")) case timeType: return reflect.ValueOf(time.Unix(1704067200, 0)) case timePtrType: return reflect.ValueOf(ptr.To(time.Unix(1704067200, 0))) } switch ty.Kind() { case reflect.String: return reflect.ValueOf("foo").Convert(ty) case reflect.Bool: return reflect.ValueOf(true) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return reflect.ValueOf(int64(-42)).Convert(ty) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return reflect.ValueOf(uint64(42)).Convert(ty) case reflect.Float32, reflect.Float64: return reflect.ValueOf(float64(3.14)).Convert(ty) case reflect.Complex64, reflect.Complex128: return reflect.ValueOf(complex(3.14, 2.71)).Convert(ty) case reflect.Chan: return reflect.MakeChan(ty, 1) // For slices, ensure that the slice is non-empty. case reflect.Slice: v := nonEmptyValue(ty.Elem()) sl := reflect.MakeSlice(ty, 1, 1) sl.Index(0).Set(v) return sl case reflect.Map: // Create a map with a single key-value pair, recursively creating each. k := nonEmptyValue(ty.Key()) v := nonEmptyValue(ty.Elem()) m := reflect.MakeMap(ty) m.SetMapIndex(k, v) return m default: panic("unhandled type " + ty.String()) } } typ := reflect.TypeFor[T]() for i, n := 0, typ.NumField(); i < n; i++ { sf := typ.Field(i) var nonzero T rv := reflect.ValueOf(&nonzero).Elem() rv.Field(i).Set(nonEmptyValue(sf.Type)) if nonzero.IsZero() { t.Errorf("IsZero = true with %v set; want false\nvalue: %#v", sf.Name, nonzero) } } }