cmd/viewer, types/views, util/codegen: add viewer support for custom container types

This adds support for container-like types such as Container[T] that
don't explicitly specify a view type for T. Instead, a package implementing
a container type should also implement and export a ContainerView[T, V] type
and a ContainerViewOf(*Container[T]) ContainerView[T, V] function, which
returns a view for the specified container, inferring the element view type V
from the element type T.

Updates #12736

Signed-off-by: Nick Khyl <nickk@tailscale.com>
pull/12865/head
Nick Khyl 4 months ago committed by Nick Khyl
parent e7bf6e716b
commit 20562a4fb9

@ -9,10 +9,11 @@ import (
"net/netip"
"golang.org/x/exp/constraints"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
)
//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct --clone-only-type=OnlyGetClone
//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers --clone-only-type=OnlyGetClone
type StructWithoutPtrs struct {
Int int
@ -114,3 +115,50 @@ type GenericCloneableStruct[T views.ViewCloner[T, V], V views.StructView[T]] str
PtrValueMap map[string]*T
SliceMap map[string][]T
}
// Container is a pre-defined container type, such as a collection, an optional
// value or a generic wrapper.
type Container[T any] struct {
Item T
}
func (c *Container[T]) Clone() *Container[T] {
if c == nil {
return nil
}
if cloner, ok := any(c.Item).(views.Cloner[T]); ok {
return &Container[T]{cloner.Clone()}
}
if !views.ContainsPointers[T]() {
return ptr.To(*c)
}
panic(fmt.Errorf("%T contains pointers, but is not cloneable", c.Item))
}
// ContainerView is a pre-defined readonly view of a Container[T].
type ContainerView[T views.ViewCloner[T, V], V views.StructView[T]] struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *Container[T]
}
func (cv ContainerView[T, V]) Item() V {
return cv.ж.Item.View()
}
func ContainerViewOf[T views.ViewCloner[T, V], V views.StructView[T]](c *Container[T]) ContainerView[T, V] {
return ContainerView[T, V]{c}
}
type GenericBasicStruct[T BasicType] struct {
Value T
}
type StructWithContainers struct {
IntContainer Container[int]
CloneableContainer Container[*StructWithPtrs]
BasicGenericContainer Container[GenericBasicStruct[int]]
ClonableGenericContainer Container[*GenericNoPtrsStruct[int]]
}

@ -416,3 +416,24 @@ func _GenericCloneableStructCloneNeedsRegeneration[T views.ViewCloner[T, V], V v
SliceMap map[string][]T
}{})
}
// Clone makes a deep copy of StructWithContainers.
// The result aliases no memory with the original.
func (src *StructWithContainers) Clone() *StructWithContainers {
if src == nil {
return nil
}
dst := new(StructWithContainers)
*dst = *src
dst.CloneableContainer = *src.CloneableContainer.Clone()
dst.ClonableGenericContainer = *src.ClonableGenericContainer.Clone()
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _StructWithContainersCloneNeedsRegeneration = StructWithContainers(struct {
IntContainer Container[int]
CloneableContainer Container[*StructWithPtrs]
BasicGenericContainer Container[GenericBasicStruct[int]]
ClonableGenericContainer Container[*GenericNoPtrsStruct[int]]
}{})

@ -14,7 +14,7 @@ import (
"tailscale.com/types/views"
)
//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct
//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers
// View returns a readonly view of StructWithPtrs.
func (p *StructWithPtrs) View() StructWithPtrsView {
@ -604,3 +604,67 @@ func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V vi
SliceMap map[string][]T
}{})
}
// View returns a readonly view of StructWithContainers.
func (p *StructWithContainers) View() StructWithContainersView {
return StructWithContainersView{ж: p}
}
// StructWithContainersView provides a read-only view over StructWithContainers.
//
// Its methods should only be called if `Valid()` returns true.
type StructWithContainersView struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *StructWithContainers
}
// Valid reports whether underlying value is non-nil.
func (v StructWithContainersView) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v StructWithContainersView) AsStruct() *StructWithContainers {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v StructWithContainersView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *StructWithContainersView) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x StructWithContainers
if err := json.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v StructWithContainersView) IntContainer() Container[int] { return v.ж.IntContainer }
func (v StructWithContainersView) CloneableContainer() ContainerView[*StructWithPtrs, StructWithPtrsView] {
return ContainerViewOf(&v.ж.CloneableContainer)
}
func (v StructWithContainersView) BasicGenericContainer() Container[GenericBasicStruct[int]] {
return v.ж.BasicGenericContainer
}
func (v StructWithContainersView) ClonableGenericContainer() ContainerView[*GenericNoPtrsStruct[int], GenericNoPtrsStructView[int]] {
return ContainerViewOf(&v.ж.ClonableGenericContainer)
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _StructWithContainersViewNeedsRegeneration = StructWithContainers(struct {
IntContainer Container[int]
CloneableContainer Container[*StructWithPtrs]
BasicGenericContainer Container[GenericBasicStruct[int]]
ClonableGenericContainer Container[*GenericNoPtrsStruct[int]]
}{})

@ -13,9 +13,11 @@ import (
"html/template"
"log"
"os"
"slices"
"strings"
"tailscale.com/util/codegen"
"tailscale.com/util/must"
)
const viewTemplateStr = `{{define "common"}}
@ -75,6 +77,8 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error {
{{end}}
{{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() }
{{end}}
{{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.ж.{{.FieldName}}) }
{{end}}
{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {
if v.ж.{{.FieldName}} == nil {
return nil
@ -144,6 +148,9 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
MapValueType string
MapValueView string
MapFn string
// MakeViewFnName is the name of the function that accepts a value and returns a readonly view of it.
MakeViewFnName string
}{
StructName: typ.Obj().Name(),
ViewName: typ.Origin().Obj().Name() + "View",
@ -227,10 +234,20 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
strucT := underlying
args.FieldType = it.QualifiedName(fieldType)
if codegen.ContainsPointers(strucT) {
args.FieldViewName = appendNameSuffix(args.FieldType, "View")
if viewType := viewTypeForValueType(fieldType); viewType != nil {
args.FieldViewName = it.QualifiedName(viewType)
writeTemplate("viewField")
continue
}
if viewType, makeViewFn := viewTypeForContainerType(fieldType); viewType != nil {
args.FieldViewName = it.QualifiedName(viewType)
args.MakeViewFnName = it.PackagePrefix(makeViewFn.Pkg()) + makeViewFn.Name()
writeTemplate("makeViewField")
continue
}
writeTemplate("unsupportedField")
continue
}
writeTemplate("valueField")
continue
case *types.Map:
@ -388,6 +405,9 @@ func appendNameSuffix(name, suffix string) string {
}
func viewTypeForValueType(typ types.Type) types.Type {
if ptr, ok := typ.(*types.Pointer); ok {
return viewTypeForValueType(ptr.Elem())
}
viewMethod := codegen.LookupMethod(typ, "View")
if viewMethod == nil {
return nil
@ -399,12 +419,116 @@ func viewTypeForValueType(typ types.Type) types.Type {
return sig.Results().At(0).Type()
}
func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) {
// The container type should be an instantiated generic type,
// with its first type parameter specifying the element type.
containerType, ok := typ.(*types.Named)
if !ok || containerType.TypeArgs().Len() == 0 {
return nil, nil
}
// Look up the view type for the container type.
// It must include an additional type parameter specifying the element's view type.
// For example, Container[T] => ContainerView[T, V].
containerViewTypeName := containerType.Obj().Name() + "View"
containerViewTypeObj, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName).(*types.TypeName)
if !ok {
return nil, nil
}
containerViewGenericType, ok := containerViewTypeObj.Type().(*types.Named)
if !ok || containerViewGenericType.TypeParams().Len() != containerType.TypeArgs().Len()+1 {
return nil, nil
}
// Create a list of type arguments for instantiating the container view type.
// Include all type arguments specified for the container type...
containerViewTypeArgs := make([]types.Type, containerViewGenericType.TypeParams().Len())
for i := range containerType.TypeArgs().Len() {
containerViewTypeArgs[i] = containerType.TypeArgs().At(i)
}
// ...and add the element view type.
// For that, we need to first determine the named elem type...
elemType, ok := baseType(containerType.TypeArgs().At(0)).(*types.Named)
if !ok {
return nil, nil
}
// ...then infer the view type from it.
var elemViewType *types.Named
elemTypeName := elemType.Obj().Name()
elemViewTypeBaseName := elemType.Obj().Name() + "View"
if elemViewTypeName, ok := elemType.Obj().Pkg().Scope().Lookup(elemViewTypeBaseName).(*types.TypeName); ok {
// The elem's view type is already defined in the same package as the elem type.
elemViewType = elemViewTypeName.Type().(*types.Named)
} else if slices.Contains(typeNames, elemTypeName) {
// The elem's view type has not been generated yet, but we can define
// and use a blank type with the expected view type name.
elemViewTypeName = types.NewTypeName(0, elemType.Obj().Pkg(), elemViewTypeBaseName, nil)
elemViewType = types.NewNamed(elemViewTypeName, types.NewStruct(nil, nil), nil)
if elemTypeParams := elemType.TypeParams(); elemTypeParams != nil {
elemViewType.SetTypeParams(collectTypeParams(elemTypeParams))
}
} else {
// The elem view type does not exist and won't be generated.
return nil, nil
}
// If elemType is an instantiated generic type, instantiate the elemViewType as well.
if elemTypeArgs := elemType.TypeArgs(); elemTypeArgs != nil {
elemViewType = must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false)).(*types.Named)
}
// And finally set the elemViewType as the last type argument.
containerViewTypeArgs[len(containerViewTypeArgs)-1] = elemViewType
// Instantiate the container view type with the specified type arguments.
containerViewType := must.Get(types.Instantiate(nil, containerViewGenericType, containerViewTypeArgs, false))
// Look up a function to create a view of a container.
// It should be in the same package as the container type, named {ViewType}Of,
// and have a signature like {ViewType}Of(c *Container[T]) ContainerView[T, V].
makeContainerView, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName + "Of").(*types.Func)
if !ok {
return nil, nil
}
return containerViewType.(*types.Named), makeContainerView
}
func baseType(typ types.Type) types.Type {
if ptr, ok := typ.(*types.Pointer); ok {
return ptr.Elem()
}
return typ
}
func collectTypes(list *types.TypeList) []types.Type {
// TODO(nickkhyl): use slices.Collect in Go 1.23?
if list.Len() == 0 {
return nil
}
res := make([]types.Type, list.Len())
for i := range res {
res[i] = list.At(i)
}
return res
}
func collectTypeParams(list *types.TypeParamList) []*types.TypeParam {
if list.Len() == 0 {
return nil
}
res := make([]*types.TypeParam, list.Len())
for i := range res {
p := list.At(i)
res[i] = types.NewTypeParam(p.Obj(), p.Constraint())
}
return res
}
var (
flagTypes = flag.String("type", "", "comma-separated list of types; required")
flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func")
flagCloneOnlyTypes = flag.String("clone-only-type", "", "comma-separated list of types (a subset of --type) that should only generate a go:generate clone line and not actual views")
typeNames []string
)
func main() {
@ -415,7 +539,7 @@ func main() {
flag.Usage()
os.Exit(2)
}
typeNames := strings.Split(*flagTypes, ",")
typeNames = strings.Split(*flagTypes, ",")
var flagArgs []string
flagArgs = append(flagArgs, fmt.Sprintf("-clonefunc=%v", *flagCloneFunc))

@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"maps"
"reflect"
"slices"
"go4.org/mem"
@ -111,6 +112,13 @@ type StructView[T any] interface {
AsStruct() T
}
// Cloner is any type that has a Clone function returning a deep-clone of the receiver.
type Cloner[T any] interface {
// Clone returns a deep-clone of the receiver.
// It returns nil, when the receiver is nil.
Clone() T
}
// ViewCloner is any type that has had View and Clone funcs generated using
// tailscale.com/cmd/viewer.
type ViewCloner[T any, V StructView[T]] interface {
@ -555,3 +563,46 @@ func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) {
}
}
}
// ContainsPointers reports whether T contains any pointers,
// either explicitly or implicitly.
// It has special handling for some types that contain pointers
// that we know are free from memory aliasing/mutation concerns.
func ContainsPointers[T any]() bool {
return containsPointers(reflect.TypeFor[T]())
}
func containsPointers(typ reflect.Type) bool {
switch typ.Kind() {
case reflect.Pointer, reflect.UnsafePointer:
return true
case reflect.Chan, reflect.Map, reflect.Slice:
return true
case reflect.Array:
return containsPointers(typ.Elem())
case reflect.Interface, reflect.Func:
return true // err on the safe side.
case reflect.Struct:
if isWellKnownImmutableStruct(typ) {
return false
}
for i := range typ.NumField() {
if containsPointers(typ.Field(i).Type) {
return true
}
}
}
return false
}
func isWellKnownImmutableStruct(typ reflect.Type) bool {
switch typ.String() {
case "time.Time":
// time.Time contains a pointer that does not need copying
return true
case "netip.Addr", "netip.Prefix", "netip.AddrPort":
return true
default:
return false
}
}

@ -10,6 +10,7 @@ import (
"reflect"
"strings"
"testing"
"unsafe"
qt "github.com/frankban/quicktest"
)
@ -22,6 +23,16 @@ type viewStruct struct {
StringsPtr *Slice[string] `json:",omitempty"`
}
type noPtrStruct struct {
Int int
Str string
}
type withPtrStruct struct {
Int int
StrPtr *string
}
func BenchmarkSliceIteration(b *testing.B) {
var data []viewStruct
for i := range 10000 {
@ -189,3 +200,215 @@ func TestSliceMapKey(t *testing.T) {
}
}
}
func TestContainsPointers(t *testing.T) {
tests := []struct {
name string
typ reflect.Type
wantPtrs bool
}{
{
name: "bool",
typ: reflect.TypeFor[bool](),
wantPtrs: false,
},
{
name: "int",
typ: reflect.TypeFor[int](),
wantPtrs: false,
},
{
name: "int8",
typ: reflect.TypeFor[int8](),
wantPtrs: false,
},
{
name: "int16",
typ: reflect.TypeFor[int16](),
wantPtrs: false,
},
{
name: "int32",
typ: reflect.TypeFor[int32](),
wantPtrs: false,
},
{
name: "int64",
typ: reflect.TypeFor[int64](),
wantPtrs: false,
},
{
name: "uint",
typ: reflect.TypeFor[uint](),
wantPtrs: false,
},
{
name: "uint8",
typ: reflect.TypeFor[uint8](),
wantPtrs: false,
},
{
name: "uint16",
typ: reflect.TypeFor[uint16](),
wantPtrs: false,
},
{
name: "uint32",
typ: reflect.TypeFor[uint32](),
wantPtrs: false,
},
{
name: "uint64",
typ: reflect.TypeFor[uint64](),
wantPtrs: false,
},
{
name: "uintptr",
typ: reflect.TypeFor[uintptr](),
wantPtrs: false,
},
{
name: "string",
typ: reflect.TypeFor[string](),
wantPtrs: false,
},
{
name: "float32",
typ: reflect.TypeFor[float32](),
wantPtrs: false,
},
{
name: "float64",
typ: reflect.TypeFor[float64](),
wantPtrs: false,
},
{
name: "complex64",
typ: reflect.TypeFor[complex64](),
wantPtrs: false,
},
{
name: "complex128",
typ: reflect.TypeFor[complex128](),
wantPtrs: false,
},
{
name: "netip-Addr",
typ: reflect.TypeFor[netip.Addr](),
wantPtrs: false,
},
{
name: "netip-Prefix",
typ: reflect.TypeFor[netip.Prefix](),
wantPtrs: false,
},
{
name: "netip-AddrPort",
typ: reflect.TypeFor[netip.AddrPort](),
wantPtrs: false,
},
{
name: "bool-ptr",
typ: reflect.TypeFor[*bool](),
wantPtrs: true,
},
{
name: "string-ptr",
typ: reflect.TypeFor[*string](),
wantPtrs: true,
},
{
name: "netip-Addr-ptr",
typ: reflect.TypeFor[*netip.Addr](),
wantPtrs: true,
},
{
name: "unsafe-ptr",
typ: reflect.TypeFor[unsafe.Pointer](),
wantPtrs: true,
},
{
name: "no-ptr-struct",
typ: reflect.TypeFor[noPtrStruct](),
wantPtrs: false,
},
{
name: "ptr-struct",
typ: reflect.TypeFor[withPtrStruct](),
wantPtrs: true,
},
{
name: "string-array",
typ: reflect.TypeFor[[5]string](),
wantPtrs: false,
},
{
name: "int-ptr-array",
typ: reflect.TypeFor[[5]*int](),
wantPtrs: true,
},
{
name: "no-ptr-struct-array",
typ: reflect.TypeFor[[5]noPtrStruct](),
wantPtrs: false,
},
{
name: "with-ptr-struct-array",
typ: reflect.TypeFor[[5]withPtrStruct](),
wantPtrs: true,
},
{
name: "string-slice",
typ: reflect.TypeFor[[]string](),
wantPtrs: true,
},
{
name: "int-ptr-slice",
typ: reflect.TypeFor[[]int](),
wantPtrs: true,
},
{
name: "no-ptr-struct-slice",
typ: reflect.TypeFor[[]noPtrStruct](),
wantPtrs: true,
},
{
name: "string-map",
typ: reflect.TypeFor[map[string]string](),
wantPtrs: true,
},
{
name: "int-map",
typ: reflect.TypeFor[map[int]int](),
wantPtrs: true,
},
{
name: "no-ptr-struct-map",
typ: reflect.TypeFor[map[string]noPtrStruct](),
wantPtrs: true,
},
{
name: "chan",
typ: reflect.TypeFor[chan int](),
wantPtrs: true,
},
{
name: "func",
typ: reflect.TypeFor[func()](),
wantPtrs: true,
},
{
name: "interface",
typ: reflect.TypeFor[any](),
wantPtrs: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotPtrs := containsPointers(tt.typ); gotPtrs != tt.wantPtrs {
t.Errorf("got %v; want %v", gotPtrs, tt.wantPtrs)
}
})
}
}

@ -111,6 +111,14 @@ func (it *ImportTracker) QualifiedName(t types.Type) string {
return types.TypeString(t, it.qualifier)
}
// PackagePrefix returns the prefix to be used when referencing named objects from pkg.
func (it *ImportTracker) PackagePrefix(pkg *types.Package) string {
if s := it.qualifier(pkg); s != "" {
return s + "."
}
return ""
}
// Write prints all the tracked imports in a single import block to w.
func (it *ImportTracker) Write(w io.Writer) {
fmt.Fprintf(w, "import (\n")

Loading…
Cancel
Save