diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index ed4d6914a..17bc1edc3 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -9,10 +9,11 @@ import ( "net/netip" "golang.org/x/exp/constraints" + "tailscale.com/types/ptr" "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct --clone-only-type=OnlyGetClone +//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers --clone-only-type=OnlyGetClone type StructWithoutPtrs struct { Int int @@ -114,3 +115,50 @@ type GenericCloneableStruct[T views.ViewCloner[T, V], V views.StructView[T]] str PtrValueMap map[string]*T SliceMap map[string][]T } + +// Container is a pre-defined container type, such as a collection, an optional +// value or a generic wrapper. +type Container[T any] struct { + Item T +} + +func (c *Container[T]) Clone() *Container[T] { + if c == nil { + return nil + } + if cloner, ok := any(c.Item).(views.Cloner[T]); ok { + return &Container[T]{cloner.Clone()} + } + if !views.ContainsPointers[T]() { + return ptr.To(*c) + } + panic(fmt.Errorf("%T contains pointers, but is not cloneable", c.Item)) +} + +// ContainerView is a pre-defined readonly view of a Container[T]. +type ContainerView[T views.ViewCloner[T, V], V views.StructView[T]] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *Container[T] +} + +func (cv ContainerView[T, V]) Item() V { + return cv.ж.Item.View() +} + +func ContainerViewOf[T views.ViewCloner[T, V], V views.StructView[T]](c *Container[T]) ContainerView[T, V] { + return ContainerView[T, V]{c} +} + +type GenericBasicStruct[T BasicType] struct { + Value T +} + +type StructWithContainers struct { + IntContainer Container[int] + CloneableContainer Container[*StructWithPtrs] + BasicGenericContainer Container[GenericBasicStruct[int]] + ClonableGenericContainer Container[*GenericNoPtrsStruct[int]] +} diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index ec5631da9..b4d92d3ec 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -416,3 +416,24 @@ func _GenericCloneableStructCloneNeedsRegeneration[T views.ViewCloner[T, V], V v SliceMap map[string][]T }{}) } + +// Clone makes a deep copy of StructWithContainers. +// The result aliases no memory with the original. +func (src *StructWithContainers) Clone() *StructWithContainers { + if src == nil { + return nil + } + dst := new(StructWithContainers) + *dst = *src + dst.CloneableContainer = *src.CloneableContainer.Clone() + dst.ClonableGenericContainer = *src.ClonableGenericContainer.Clone() + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithContainersCloneNeedsRegeneration = StructWithContainers(struct { + IntContainer Container[int] + CloneableContainer Container[*StructWithPtrs] + BasicGenericContainer Container[GenericBasicStruct[int]] + ClonableGenericContainer Container[*GenericNoPtrsStruct[int]] +}{}) diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index 9a337f5aa..44618e79e 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -14,7 +14,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers // View returns a readonly view of StructWithPtrs. func (p *StructWithPtrs) View() StructWithPtrsView { @@ -604,3 +604,67 @@ func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V vi SliceMap map[string][]T }{}) } + +// View returns a readonly view of StructWithContainers. +func (p *StructWithContainers) View() StructWithContainersView { + return StructWithContainersView{ж: p} +} + +// StructWithContainersView provides a read-only view over StructWithContainers. +// +// Its methods should only be called if `Valid()` returns true. +type StructWithContainersView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *StructWithContainers +} + +// Valid reports whether underlying value is non-nil. +func (v StructWithContainersView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v StructWithContainersView) AsStruct() *StructWithContainers { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v StructWithContainersView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *StructWithContainersView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x StructWithContainers + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v StructWithContainersView) IntContainer() Container[int] { return v.ж.IntContainer } +func (v StructWithContainersView) CloneableContainer() ContainerView[*StructWithPtrs, StructWithPtrsView] { + return ContainerViewOf(&v.ж.CloneableContainer) +} +func (v StructWithContainersView) BasicGenericContainer() Container[GenericBasicStruct[int]] { + return v.ж.BasicGenericContainer +} +func (v StructWithContainersView) ClonableGenericContainer() ContainerView[*GenericNoPtrsStruct[int], GenericNoPtrsStructView[int]] { + return ContainerViewOf(&v.ж.ClonableGenericContainer) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithContainersViewNeedsRegeneration = StructWithContainers(struct { + IntContainer Container[int] + CloneableContainer Container[*StructWithPtrs] + BasicGenericContainer Container[GenericBasicStruct[int]] + ClonableGenericContainer Container[*GenericNoPtrsStruct[int]] +}{}) diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index d2be6af66..d77875ba8 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -13,9 +13,11 @@ import ( "html/template" "log" "os" + "slices" "strings" "tailscale.com/util/codegen" + "tailscale.com/util/must" ) const viewTemplateStr = `{{define "common"}} @@ -75,6 +77,8 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { {{end}} {{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() } {{end}} +{{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.ж.{{.FieldName}}) } +{{end}} {{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { if v.ж.{{.FieldName}} == nil { return nil @@ -144,6 +148,9 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi MapValueType string MapValueView string MapFn string + + // MakeViewFnName is the name of the function that accepts a value and returns a readonly view of it. + MakeViewFnName string }{ StructName: typ.Obj().Name(), ViewName: typ.Origin().Obj().Name() + "View", @@ -227,8 +234,18 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi strucT := underlying args.FieldType = it.QualifiedName(fieldType) if codegen.ContainsPointers(strucT) { - args.FieldViewName = appendNameSuffix(args.FieldType, "View") - writeTemplate("viewField") + if viewType := viewTypeForValueType(fieldType); viewType != nil { + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewField") + continue + } + if viewType, makeViewFn := viewTypeForContainerType(fieldType); viewType != nil { + args.FieldViewName = it.QualifiedName(viewType) + args.MakeViewFnName = it.PackagePrefix(makeViewFn.Pkg()) + makeViewFn.Name() + writeTemplate("makeViewField") + continue + } + writeTemplate("unsupportedField") continue } writeTemplate("valueField") @@ -388,6 +405,9 @@ func appendNameSuffix(name, suffix string) string { } func viewTypeForValueType(typ types.Type) types.Type { + if ptr, ok := typ.(*types.Pointer); ok { + return viewTypeForValueType(ptr.Elem()) + } viewMethod := codegen.LookupMethod(typ, "View") if viewMethod == nil { return nil @@ -399,12 +419,116 @@ func viewTypeForValueType(typ types.Type) types.Type { return sig.Results().At(0).Type() } +func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { + // The container type should be an instantiated generic type, + // with its first type parameter specifying the element type. + containerType, ok := typ.(*types.Named) + if !ok || containerType.TypeArgs().Len() == 0 { + return nil, nil + } + + // Look up the view type for the container type. + // It must include an additional type parameter specifying the element's view type. + // For example, Container[T] => ContainerView[T, V]. + containerViewTypeName := containerType.Obj().Name() + "View" + containerViewTypeObj, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName).(*types.TypeName) + if !ok { + return nil, nil + } + containerViewGenericType, ok := containerViewTypeObj.Type().(*types.Named) + if !ok || containerViewGenericType.TypeParams().Len() != containerType.TypeArgs().Len()+1 { + return nil, nil + } + + // Create a list of type arguments for instantiating the container view type. + // Include all type arguments specified for the container type... + containerViewTypeArgs := make([]types.Type, containerViewGenericType.TypeParams().Len()) + for i := range containerType.TypeArgs().Len() { + containerViewTypeArgs[i] = containerType.TypeArgs().At(i) + } + // ...and add the element view type. + // For that, we need to first determine the named elem type... + elemType, ok := baseType(containerType.TypeArgs().At(0)).(*types.Named) + if !ok { + return nil, nil + } + // ...then infer the view type from it. + var elemViewType *types.Named + elemTypeName := elemType.Obj().Name() + elemViewTypeBaseName := elemType.Obj().Name() + "View" + if elemViewTypeName, ok := elemType.Obj().Pkg().Scope().Lookup(elemViewTypeBaseName).(*types.TypeName); ok { + // The elem's view type is already defined in the same package as the elem type. + elemViewType = elemViewTypeName.Type().(*types.Named) + } else if slices.Contains(typeNames, elemTypeName) { + // The elem's view type has not been generated yet, but we can define + // and use a blank type with the expected view type name. + elemViewTypeName = types.NewTypeName(0, elemType.Obj().Pkg(), elemViewTypeBaseName, nil) + elemViewType = types.NewNamed(elemViewTypeName, types.NewStruct(nil, nil), nil) + if elemTypeParams := elemType.TypeParams(); elemTypeParams != nil { + elemViewType.SetTypeParams(collectTypeParams(elemTypeParams)) + } + } else { + // The elem view type does not exist and won't be generated. + return nil, nil + } + // If elemType is an instantiated generic type, instantiate the elemViewType as well. + if elemTypeArgs := elemType.TypeArgs(); elemTypeArgs != nil { + elemViewType = must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false)).(*types.Named) + } + // And finally set the elemViewType as the last type argument. + containerViewTypeArgs[len(containerViewTypeArgs)-1] = elemViewType + + // Instantiate the container view type with the specified type arguments. + containerViewType := must.Get(types.Instantiate(nil, containerViewGenericType, containerViewTypeArgs, false)) + // Look up a function to create a view of a container. + // It should be in the same package as the container type, named {ViewType}Of, + // and have a signature like {ViewType}Of(c *Container[T]) ContainerView[T, V]. + makeContainerView, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName + "Of").(*types.Func) + if !ok { + return nil, nil + } + return containerViewType.(*types.Named), makeContainerView +} + +func baseType(typ types.Type) types.Type { + if ptr, ok := typ.(*types.Pointer); ok { + return ptr.Elem() + } + return typ +} + +func collectTypes(list *types.TypeList) []types.Type { + // TODO(nickkhyl): use slices.Collect in Go 1.23? + if list.Len() == 0 { + return nil + } + res := make([]types.Type, list.Len()) + for i := range res { + res[i] = list.At(i) + } + return res +} + +func collectTypeParams(list *types.TypeParamList) []*types.TypeParam { + if list.Len() == 0 { + return nil + } + res := make([]*types.TypeParam, list.Len()) + for i := range res { + p := list.At(i) + res[i] = types.NewTypeParam(p.Obj(), p.Constraint()) + } + return res +} + var ( flagTypes = flag.String("type", "", "comma-separated list of types; required") flagBuildTags = flag.String("tags", "", "compiler build tags to apply") flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func") flagCloneOnlyTypes = flag.String("clone-only-type", "", "comma-separated list of types (a subset of --type) that should only generate a go:generate clone line and not actual views") + + typeNames []string ) func main() { @@ -415,7 +539,7 @@ func main() { flag.Usage() os.Exit(2) } - typeNames := strings.Split(*flagTypes, ",") + typeNames = strings.Split(*flagTypes, ",") var flagArgs []string flagArgs = append(flagArgs, fmt.Sprintf("-clonefunc=%v", *flagCloneFunc)) diff --git a/types/views/views.go b/types/views/views.go index 42758966f..4edd72688 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "maps" + "reflect" "slices" "go4.org/mem" @@ -111,6 +112,13 @@ type StructView[T any] interface { AsStruct() T } +// Cloner is any type that has a Clone function returning a deep-clone of the receiver. +type Cloner[T any] interface { + // Clone returns a deep-clone of the receiver. + // It returns nil, when the receiver is nil. + Clone() T +} + // ViewCloner is any type that has had View and Clone funcs generated using // tailscale.com/cmd/viewer. type ViewCloner[T any, V StructView[T]] interface { @@ -555,3 +563,46 @@ func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) { } } } + +// ContainsPointers reports whether T contains any pointers, +// either explicitly or implicitly. +// It has special handling for some types that contain pointers +// that we know are free from memory aliasing/mutation concerns. +func ContainsPointers[T any]() bool { + return containsPointers(reflect.TypeFor[T]()) +} + +func containsPointers(typ reflect.Type) bool { + switch typ.Kind() { + case reflect.Pointer, reflect.UnsafePointer: + return true + case reflect.Chan, reflect.Map, reflect.Slice: + return true + case reflect.Array: + return containsPointers(typ.Elem()) + case reflect.Interface, reflect.Func: + return true // err on the safe side. + case reflect.Struct: + if isWellKnownImmutableStruct(typ) { + return false + } + for i := range typ.NumField() { + if containsPointers(typ.Field(i).Type) { + return true + } + } + } + return false +} + +func isWellKnownImmutableStruct(typ reflect.Type) bool { + switch typ.String() { + case "time.Time": + // time.Time contains a pointer that does not need copying + return true + case "netip.Addr", "netip.Prefix", "netip.AddrPort": + return true + default: + return false + } +} diff --git a/types/views/views_test.go b/types/views/views_test.go index 0173d3207..1a4f1f2d4 100644 --- a/types/views/views_test.go +++ b/types/views/views_test.go @@ -10,6 +10,7 @@ import ( "reflect" "strings" "testing" + "unsafe" qt "github.com/frankban/quicktest" ) @@ -22,6 +23,16 @@ type viewStruct struct { StringsPtr *Slice[string] `json:",omitempty"` } +type noPtrStruct struct { + Int int + Str string +} + +type withPtrStruct struct { + Int int + StrPtr *string +} + func BenchmarkSliceIteration(b *testing.B) { var data []viewStruct for i := range 10000 { @@ -189,3 +200,215 @@ func TestSliceMapKey(t *testing.T) { } } } + +func TestContainsPointers(t *testing.T) { + tests := []struct { + name string + typ reflect.Type + wantPtrs bool + }{ + { + name: "bool", + typ: reflect.TypeFor[bool](), + wantPtrs: false, + }, + { + name: "int", + typ: reflect.TypeFor[int](), + wantPtrs: false, + }, + { + name: "int8", + typ: reflect.TypeFor[int8](), + wantPtrs: false, + }, + { + name: "int16", + typ: reflect.TypeFor[int16](), + wantPtrs: false, + }, + { + name: "int32", + typ: reflect.TypeFor[int32](), + wantPtrs: false, + }, + { + name: "int64", + typ: reflect.TypeFor[int64](), + wantPtrs: false, + }, + { + name: "uint", + typ: reflect.TypeFor[uint](), + wantPtrs: false, + }, + { + name: "uint8", + typ: reflect.TypeFor[uint8](), + wantPtrs: false, + }, + { + name: "uint16", + typ: reflect.TypeFor[uint16](), + wantPtrs: false, + }, + { + name: "uint32", + typ: reflect.TypeFor[uint32](), + wantPtrs: false, + }, + { + name: "uint64", + typ: reflect.TypeFor[uint64](), + wantPtrs: false, + }, + { + name: "uintptr", + typ: reflect.TypeFor[uintptr](), + wantPtrs: false, + }, + { + name: "string", + typ: reflect.TypeFor[string](), + wantPtrs: false, + }, + { + name: "float32", + typ: reflect.TypeFor[float32](), + wantPtrs: false, + }, + { + name: "float64", + typ: reflect.TypeFor[float64](), + wantPtrs: false, + }, + { + name: "complex64", + typ: reflect.TypeFor[complex64](), + wantPtrs: false, + }, + { + name: "complex128", + typ: reflect.TypeFor[complex128](), + wantPtrs: false, + }, + { + name: "netip-Addr", + typ: reflect.TypeFor[netip.Addr](), + wantPtrs: false, + }, + { + name: "netip-Prefix", + typ: reflect.TypeFor[netip.Prefix](), + wantPtrs: false, + }, + { + name: "netip-AddrPort", + typ: reflect.TypeFor[netip.AddrPort](), + wantPtrs: false, + }, + { + name: "bool-ptr", + typ: reflect.TypeFor[*bool](), + wantPtrs: true, + }, + { + name: "string-ptr", + typ: reflect.TypeFor[*string](), + wantPtrs: true, + }, + { + name: "netip-Addr-ptr", + typ: reflect.TypeFor[*netip.Addr](), + wantPtrs: true, + }, + { + name: "unsafe-ptr", + typ: reflect.TypeFor[unsafe.Pointer](), + wantPtrs: true, + }, + { + name: "no-ptr-struct", + typ: reflect.TypeFor[noPtrStruct](), + wantPtrs: false, + }, + { + name: "ptr-struct", + typ: reflect.TypeFor[withPtrStruct](), + wantPtrs: true, + }, + { + name: "string-array", + typ: reflect.TypeFor[[5]string](), + wantPtrs: false, + }, + { + name: "int-ptr-array", + typ: reflect.TypeFor[[5]*int](), + wantPtrs: true, + }, + { + name: "no-ptr-struct-array", + typ: reflect.TypeFor[[5]noPtrStruct](), + wantPtrs: false, + }, + { + name: "with-ptr-struct-array", + typ: reflect.TypeFor[[5]withPtrStruct](), + wantPtrs: true, + }, + { + name: "string-slice", + typ: reflect.TypeFor[[]string](), + wantPtrs: true, + }, + { + name: "int-ptr-slice", + typ: reflect.TypeFor[[]int](), + wantPtrs: true, + }, + { + name: "no-ptr-struct-slice", + typ: reflect.TypeFor[[]noPtrStruct](), + wantPtrs: true, + }, + { + name: "string-map", + typ: reflect.TypeFor[map[string]string](), + wantPtrs: true, + }, + { + name: "int-map", + typ: reflect.TypeFor[map[int]int](), + wantPtrs: true, + }, + { + name: "no-ptr-struct-map", + typ: reflect.TypeFor[map[string]noPtrStruct](), + wantPtrs: true, + }, + { + name: "chan", + typ: reflect.TypeFor[chan int](), + wantPtrs: true, + }, + { + name: "func", + typ: reflect.TypeFor[func()](), + wantPtrs: true, + }, + { + name: "interface", + typ: reflect.TypeFor[any](), + wantPtrs: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotPtrs := containsPointers(tt.typ); gotPtrs != tt.wantPtrs { + t.Errorf("got %v; want %v", gotPtrs, tt.wantPtrs) + } + }) + } +} diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 13dbc94a4..dea8faef6 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -111,6 +111,14 @@ func (it *ImportTracker) QualifiedName(t types.Type) string { return types.TypeString(t, it.qualifier) } +// PackagePrefix returns the prefix to be used when referencing named objects from pkg. +func (it *ImportTracker) PackagePrefix(pkg *types.Package) string { + if s := it.qualifier(pkg); s != "" { + return s + "." + } + return "" +} + // Write prints all the tracked imports in a single import block to w. func (it *ImportTracker) Write(w io.Writer) { fmt.Fprintf(w, "import (\n")