cmd/cloner, cmd/viewer, util/codegen: add support for aliases of cloneable types

We have several checked type assertions to *types.Named in both cmd/cloner and cmd/viewer.
As Go 1.23 updates the go/types package to produce Alias type nodes for type aliases,
these type assertions no longer work as expected unless the new behavior is disabled
with gotypesalias=0.

In this PR, we add codegen.NamedTypeOf(t types.Type), which functions like t.(*types.Named)
but also unrolls type aliases. We then use it in place of type assertions in the cmd/cloner and
cmd/viewer packages where appropriate.

We also update type switches to include *types.Alias alongside *types.Named in relevant cases,
remove *types.Struct cases when switching on types.Type.Underlying and update the tests
with more cases where type aliases can be used.

Updates #13224
Updates #12912

Signed-off-by: Nick Khyl <nickk@tailscale.com>
pull/13252/head
Nick Khyl 3 months ago committed by Nick Khyl
parent a9dc6e07ad
commit 03acab2639

@ -115,7 +115,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
if !codegen.ContainsPointers(ft) || codegen.HasNoClone(t.Tag(i)) { if !codegen.ContainsPointers(ft) || codegen.HasNoClone(t.Tag(i)) {
continue continue
} }
if named, _ := ft.(*types.Named); named != nil { if named, _ := codegen.NamedTypeOf(ft); named != nil {
if codegen.IsViewType(ft) { if codegen.IsViewType(ft) {
writef("dst.%s = src.%s", fname, fname) writef("dst.%s = src.%s", fname, fname)
continue continue
@ -161,7 +161,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
case *types.Pointer: case *types.Pointer:
base := ft.Elem() base := ft.Elem()
hasPtrs := codegen.ContainsPointers(base) hasPtrs := codegen.ContainsPointers(base)
if named, _ := base.(*types.Named); named != nil && hasPtrs { if named, _ := codegen.NamedTypeOf(base); named != nil && hasPtrs {
writef("dst.%s = src.%s.Clone()", fname, fname) writef("dst.%s = src.%s.Clone()", fname, fname)
continue continue
} }

@ -206,11 +206,25 @@ type StructWithContainers struct {
type ( type (
StructWithPtrsAlias = StructWithPtrs StructWithPtrsAlias = StructWithPtrs
StructWithoutPtrsAlias = StructWithoutPtrs StructWithoutPtrsAlias = StructWithoutPtrs
StructWithPtrsAliasView = StructWithPtrsView
StructWithoutPtrsAliasView = StructWithoutPtrsView
) )
type StructWithTypeAliasFields struct { type StructWithTypeAliasFields struct {
WithPtr StructWithPtrsAlias WithPtr StructWithPtrsAlias
WithoutPtr StructWithoutPtrsAlias WithoutPtr StructWithoutPtrsAlias
WithPtrByPtr *StructWithPtrsAlias
WithoutPtrByPtr *StructWithoutPtrsAlias
SliceWithPtrs []*StructWithPtrsAlias
SliceWithoutPtrs []*StructWithoutPtrsAlias
MapWithPtrs map[string]*StructWithPtrsAlias
MapWithoutPtrs map[string]*StructWithoutPtrsAlias
MapOfSlicesWithPtrs map[string][]*StructWithPtrsAlias
MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias
} }
type integer = constraints.Integer type integer = constraints.Integer

@ -450,7 +450,63 @@ func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields {
} }
dst := new(StructWithTypeAliasFields) dst := new(StructWithTypeAliasFields)
*dst = *src *dst = *src
panic("TODO: WithPtr (*types.Struct)") dst.WithPtr = *src.WithPtr.Clone()
dst.WithPtrByPtr = src.WithPtrByPtr.Clone()
if dst.WithoutPtrByPtr != nil {
dst.WithoutPtrByPtr = ptr.To(*src.WithoutPtrByPtr)
}
if src.SliceWithPtrs != nil {
dst.SliceWithPtrs = make([]*StructWithPtrsAlias, len(src.SliceWithPtrs))
for i := range dst.SliceWithPtrs {
if src.SliceWithPtrs[i] == nil {
dst.SliceWithPtrs[i] = nil
} else {
dst.SliceWithPtrs[i] = src.SliceWithPtrs[i].Clone()
}
}
}
if src.SliceWithoutPtrs != nil {
dst.SliceWithoutPtrs = make([]*StructWithoutPtrsAlias, len(src.SliceWithoutPtrs))
for i := range dst.SliceWithoutPtrs {
if src.SliceWithoutPtrs[i] == nil {
dst.SliceWithoutPtrs[i] = nil
} else {
dst.SliceWithoutPtrs[i] = ptr.To(*src.SliceWithoutPtrs[i])
}
}
}
if dst.MapWithPtrs != nil {
dst.MapWithPtrs = map[string]*StructWithPtrsAlias{}
for k, v := range src.MapWithPtrs {
if v == nil {
dst.MapWithPtrs[k] = nil
} else {
dst.MapWithPtrs[k] = v.Clone()
}
}
}
if dst.MapWithoutPtrs != nil {
dst.MapWithoutPtrs = map[string]*StructWithoutPtrsAlias{}
for k, v := range src.MapWithoutPtrs {
if v == nil {
dst.MapWithoutPtrs[k] = nil
} else {
dst.MapWithoutPtrs[k] = ptr.To(*v)
}
}
}
if dst.MapOfSlicesWithPtrs != nil {
dst.MapOfSlicesWithPtrs = map[string][]*StructWithPtrsAlias{}
for k := range src.MapOfSlicesWithPtrs {
dst.MapOfSlicesWithPtrs[k] = append([]*StructWithPtrsAlias{}, src.MapOfSlicesWithPtrs[k]...)
}
}
if dst.MapOfSlicesWithoutPtrs != nil {
dst.MapOfSlicesWithoutPtrs = map[string][]*StructWithoutPtrsAlias{}
for k := range src.MapOfSlicesWithoutPtrs {
dst.MapOfSlicesWithoutPtrs[k] = append([]*StructWithoutPtrsAlias{}, src.MapOfSlicesWithoutPtrs[k]...)
}
}
return dst return dst
} }
@ -458,6 +514,14 @@ func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields {
var _StructWithTypeAliasFieldsCloneNeedsRegeneration = StructWithTypeAliasFields(struct { var _StructWithTypeAliasFieldsCloneNeedsRegeneration = StructWithTypeAliasFields(struct {
WithPtr StructWithPtrsAlias WithPtr StructWithPtrsAlias
WithoutPtr StructWithoutPtrsAlias WithoutPtr StructWithoutPtrsAlias
WithPtrByPtr *StructWithPtrsAlias
WithoutPtrByPtr *StructWithoutPtrsAlias
SliceWithPtrs []*StructWithPtrsAlias
SliceWithoutPtrs []*StructWithoutPtrsAlias
MapWithPtrs map[string]*StructWithPtrsAlias
MapWithoutPtrs map[string]*StructWithoutPtrsAlias
MapOfSlicesWithPtrs map[string][]*StructWithPtrsAlias
MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias
}{}) }{})
// Clone makes a deep copy of GenericTypeAliasStruct. // Clone makes a deep copy of GenericTypeAliasStruct.

@ -724,11 +724,60 @@ func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error {
func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.ж.WithPtr.View() } func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.ж.WithPtr.View() }
func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.ж.WithoutPtr } func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.ж.WithoutPtr }
func (v StructWithTypeAliasFieldsView) WithPtrByPtr() StructWithPtrsAliasView {
return v.ж.WithPtrByPtr.View()
}
func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() *StructWithoutPtrsAlias {
if v.ж.WithoutPtrByPtr == nil {
return nil
}
x := *v.ж.WithoutPtrByPtr
return &x
}
func (v StructWithTypeAliasFieldsView) SliceWithPtrs() views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] {
return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](v.ж.SliceWithPtrs)
}
func (v StructWithTypeAliasFieldsView) SliceWithoutPtrs() views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView] {
return views.SliceOfViews[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView](v.ж.SliceWithoutPtrs)
}
func (v StructWithTypeAliasFieldsView) MapWithPtrs() views.MapFn[string, *StructWithPtrsAlias, StructWithPtrsAliasView] {
return views.MapFnOf(v.ж.MapWithPtrs, func(t *StructWithPtrsAlias) StructWithPtrsAliasView {
return t.View()
})
}
func (v StructWithTypeAliasFieldsView) MapWithoutPtrs() views.MapFn[string, *StructWithoutPtrsAlias, StructWithoutPtrsAliasView] {
return views.MapFnOf(v.ж.MapWithoutPtrs, func(t *StructWithoutPtrsAlias) StructWithoutPtrsAliasView {
return t.View()
})
}
func (v StructWithTypeAliasFieldsView) MapOfSlicesWithPtrs() views.MapFn[string, []*StructWithPtrsAlias, views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView]] {
return views.MapFnOf(v.ж.MapOfSlicesWithPtrs, func(t []*StructWithPtrsAlias) views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] {
return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](t)
})
}
func (v StructWithTypeAliasFieldsView) MapOfSlicesWithoutPtrs() views.MapFn[string, []*StructWithoutPtrsAlias, views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView]] {
return views.MapFnOf(v.ж.MapOfSlicesWithoutPtrs, func(t []*StructWithoutPtrsAlias) views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView] {
return views.SliceOfViews[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView](t)
})
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file. // A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _StructWithTypeAliasFieldsViewNeedsRegeneration = StructWithTypeAliasFields(struct { var _StructWithTypeAliasFieldsViewNeedsRegeneration = StructWithTypeAliasFields(struct {
WithPtr StructWithPtrsAlias WithPtr StructWithPtrsAlias
WithoutPtr StructWithoutPtrsAlias WithoutPtr StructWithoutPtrsAlias
WithPtrByPtr *StructWithPtrsAlias
WithoutPtrByPtr *StructWithoutPtrsAlias
SliceWithPtrs []*StructWithPtrsAlias
SliceWithoutPtrs []*StructWithoutPtrsAlias
MapWithPtrs map[string]*StructWithPtrsAlias
MapWithoutPtrs map[string]*StructWithoutPtrsAlias
MapOfSlicesWithPtrs map[string][]*StructWithPtrsAlias
MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias
}{}) }{})
// View returns a readonly view of GenericTypeAliasStruct. // View returns a readonly view of GenericTypeAliasStruct.

@ -230,7 +230,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
writeTemplate("sliceField") writeTemplate("sliceField")
} }
continue continue
case *types.Struct, *types.Named: case *types.Struct:
strucT := underlying strucT := underlying
args.FieldType = it.QualifiedName(fieldType) args.FieldType = it.QualifiedName(fieldType)
if codegen.ContainsPointers(strucT) { if codegen.ContainsPointers(strucT) {
@ -262,7 +262,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
mElem := m.Elem() mElem := m.Elem()
var template string var template string
switch u := mElem.(type) { switch u := mElem.(type) {
case *types.Struct, *types.Named: case *types.Struct, *types.Named, *types.Alias:
strucT := u strucT := u
args.FieldType = it.QualifiedName(fieldType) args.FieldType = it.QualifiedName(fieldType)
if codegen.ContainsPointers(strucT) { if codegen.ContainsPointers(strucT) {
@ -281,7 +281,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
slice := u slice := u
sElem := slice.Elem() sElem := slice.Elem()
switch x := sElem.(type) { switch x := sElem.(type) {
case *types.Basic, *types.Named: case *types.Basic, *types.Named, *types.Alias:
sElem := it.QualifiedName(sElem) sElem := it.QualifiedName(sElem)
args.MapValueView = fmt.Sprintf("views.Slice[%v]", sElem) args.MapValueView = fmt.Sprintf("views.Slice[%v]", sElem)
args.MapValueType = sElem args.MapValueType = sElem
@ -292,7 +292,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
template = "unsupportedField" template = "unsupportedField"
if _, isIface := pElem.Underlying().(*types.Interface); !isIface { if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
switch pElem.(type) { switch pElem.(type) {
case *types.Struct, *types.Named: case *types.Struct, *types.Named, *types.Alias:
ptrType := it.QualifiedName(ptr) ptrType := it.QualifiedName(ptr)
viewType := appendNameSuffix(it.QualifiedName(pElem), "View") viewType := appendNameSuffix(it.QualifiedName(pElem), "View")
args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType) args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType)
@ -313,7 +313,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
pElem := ptr.Elem() pElem := ptr.Elem()
if _, isIface := pElem.Underlying().(*types.Interface); !isIface { if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
switch pElem.(type) { switch pElem.(type) {
case *types.Struct, *types.Named: case *types.Struct, *types.Named, *types.Alias:
args.MapValueType = it.QualifiedName(ptr) args.MapValueType = it.QualifiedName(ptr)
args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View") args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View")
args.MapFn = "t.View()" args.MapFn = "t.View()"
@ -422,7 +422,7 @@ func viewTypeForValueType(typ types.Type) types.Type {
func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) {
// The container type should be an instantiated generic type, // The container type should be an instantiated generic type,
// with its first type parameter specifying the element type. // with its first type parameter specifying the element type.
containerType, ok := typ.(*types.Named) containerType, ok := codegen.NamedTypeOf(typ)
if !ok || containerType.TypeArgs().Len() == 0 { if !ok || containerType.TypeArgs().Len() == 0 {
return nil, nil return nil, nil
} }
@ -435,7 +435,7 @@ func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) {
if !ok { if !ok {
return nil, nil return nil, nil
} }
containerViewGenericType, ok := containerViewTypeObj.Type().(*types.Named) containerViewGenericType, ok := codegen.NamedTypeOf(containerViewTypeObj.Type())
if !ok || containerViewGenericType.TypeParams().Len() != containerType.TypeArgs().Len()+1 { if !ok || containerViewGenericType.TypeParams().Len() != containerType.TypeArgs().Len()+1 {
return nil, nil return nil, nil
} }
@ -448,7 +448,7 @@ func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) {
} }
// ...and add the element view type. // ...and add the element view type.
// For that, we need to first determine the named elem type... // For that, we need to first determine the named elem type...
elemType, ok := baseType(containerType.TypeArgs().At(containerType.TypeArgs().Len() - 1)).(*types.Named) elemType, ok := codegen.NamedTypeOf(baseType(containerType.TypeArgs().At(containerType.TypeArgs().Len() - 1)))
if !ok { if !ok {
return nil, nil return nil, nil
} }
@ -473,7 +473,7 @@ func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) {
} }
// If elemType is an instantiated generic type, instantiate the elemViewType as well. // If elemType is an instantiated generic type, instantiate the elemViewType as well.
if elemTypeArgs := elemType.TypeArgs(); elemTypeArgs != nil { if elemTypeArgs := elemType.TypeArgs(); elemTypeArgs != nil {
elemViewType = must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false)).(*types.Named) elemViewType, _ = codegen.NamedTypeOf(must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false)))
} }
// And finally set the elemViewType as the last type argument. // And finally set the elemViewType as the last type argument.
containerViewTypeArgs[len(containerViewTypeArgs)-1] = elemViewType containerViewTypeArgs[len(containerViewTypeArgs)-1] = elemViewType

@ -382,3 +382,12 @@ func LookupMethod(t types.Type, name string) *types.Func {
} }
return nil return nil
} }
// NamedTypeOf is like t.(*types.Named), but also works with type aliases.
func NamedTypeOf(t types.Type) (named *types.Named, ok bool) {
if a, ok := t.(*types.Alias); ok {
return NamedTypeOf(types.Unalias(a))
}
named, ok = t.(*types.Named)
return
}

Loading…
Cancel
Save