cmd/cloner, cmd/viewer, util/codegen: add support for generic types and interfaces

This adds support for generic types and interfaces to our cloner and viewer codegens.
It updates these packages to determine whether to make shallow or deep copies based
on the type parameter constraints. Additionally, if a template parameter or an interface
type has View() and Clone() methods, we'll use them for getters and the cloner of the
owning structure.

Updates #12736

Signed-off-by: Nick Khyl <nickk@tailscale.com>
pull/12791/head
Nick Khyl 4 months ago committed by Nick Khyl
parent b7c3cfe049
commit fc28c8e7f3

@ -91,16 +91,19 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
} }
name := typ.Obj().Name() name := typ.Obj().Name()
typeParams := typ.Origin().TypeParams()
_, typeParamNames := codegen.FormatTypeParams(typeParams, it)
nameWithParams := name + typeParamNames
fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name) fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name)
fmt.Fprintf(buf, "// The result aliases no memory with the original.\n") fmt.Fprintf(buf, "// The result aliases no memory with the original.\n")
fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name) fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", nameWithParams, nameWithParams)
writef := func(format string, args ...any) { writef := func(format string, args ...any) {
fmt.Fprintf(buf, "\t"+format+"\n", args...) fmt.Fprintf(buf, "\t"+format+"\n", args...)
} }
writef("if src == nil {") writef("if src == nil {")
writef("\treturn nil") writef("\treturn nil")
writef("}") writef("}")
writef("dst := new(%s)", name) writef("dst := new(%s)", nameWithParams)
writef("*dst = *src") writef("*dst = *src")
for i := range t.NumFields() { for i := range t.NumFields() {
fname := t.Field(i).Name() fname := t.Field(i).Name()
@ -126,16 +129,23 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname)
writef("for i := range dst.%s {", fname) writef("for i := range dst.%s {", fname)
if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr { if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr {
if _, isBasic := ptr.Elem().Underlying().(*types.Basic); isBasic {
it.Import("tailscale.com/types/ptr")
writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname)
writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) if codegen.ContainsPointers(ptr.Elem()) {
writef("}") if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface {
it.Import("tailscale.com/types/ptr")
writef("\tdst.%s[i] = ptr.To((*src.%s[i]).Clone())", fname, fname)
} else { } else {
writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname)
} }
} else {
it.Import("tailscale.com/types/ptr")
writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname)
}
writef("}")
} else if ft.Elem().String() == "encoding/json.RawMessage" { } else if ft.Elem().String() == "encoding/json.RawMessage" {
writef("\tdst.%s[i] = append(src.%s[i][:0:0], src.%s[i]...)", fname, fname, fname) writef("\tdst.%s[i] = append(src.%s[i][:0:0], src.%s[i]...)", fname, fname, fname)
} else if _, isIface := ft.Elem().Underlying().(*types.Interface); isIface {
writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname)
} else { } else {
writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname)
} }
@ -145,14 +155,19 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname)
} }
case *types.Pointer: case *types.Pointer:
if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) { base := ft.Elem()
hasPtrs := codegen.ContainsPointers(base)
if named, _ := base.(*types.Named); named != nil && hasPtrs {
writef("dst.%s = src.%s.Clone()", fname, fname) writef("dst.%s = src.%s.Clone()", fname, fname)
continue continue
} }
it.Import("tailscale.com/types/ptr") it.Import("tailscale.com/types/ptr")
writef("if dst.%s != nil {", fname) writef("if dst.%s != nil {", fname)
if _, isIface := base.Underlying().(*types.Interface); isIface && hasPtrs {
writef("\tdst.%s = ptr.To((*src.%s).Clone())", fname, fname)
} else if !hasPtrs {
writef("\tdst.%s = ptr.To(*src.%s)", fname, fname) writef("\tdst.%s = ptr.To(*src.%s)", fname, fname)
if codegen.ContainsPointers(ft.Elem()) { } else {
writef("\t" + `panic("TODO pointers in pointers")`) writef("\t" + `panic("TODO pointers in pointers")`)
} }
writef("}") writef("}")
@ -172,18 +187,50 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
writef("if dst.%s != nil {", fname) writef("if dst.%s != nil {", fname)
writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem))
writef("\tfor k, v := range src.%s {", fname) writef("\tfor k, v := range src.%s {", fname)
switch elem.(type) {
switch elem := elem.Underlying().(type) {
case *types.Pointer: case *types.Pointer:
writef("\t\tif v == nil { dst.%s[k] = nil } else {", fname)
if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) {
if _, isIface := base.(*types.Interface); isIface {
it.Import("tailscale.com/types/ptr")
writef("\t\t\tdst.%s[k] = ptr.To((*v).Clone())", fname)
} else {
writef("\t\t\tdst.%s[k] = v.Clone()", fname)
}
} else {
it.Import("tailscale.com/types/ptr")
writef("\t\t\tdst.%s[k] = ptr.To(*v)", fname)
}
writef("}")
case *types.Interface:
if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil {
if _, isPtr := cloneResultType.(*types.Pointer); isPtr {
writef("\t\tdst.%s[k] = *(v.Clone())", fname)
} else {
writef("\t\tdst.%s[k] = v.Clone()", fname) writef("\t\tdst.%s[k] = v.Clone()", fname)
}
} else {
writef(`panic("%s (%v) does not have a Clone method")`, fname, elem)
}
default: default:
writef("\t\tdst.%s[k] = *(v.Clone())", fname) writef("\t\tdst.%s[k] = *(v.Clone())", fname)
} }
writef("\t}") writef("\t}")
writef("}") writef("}")
} else { } else {
it.Import("maps") it.Import("maps")
writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) writef("\tdst.%s = maps.Clone(src.%s)", fname, fname)
} }
case *types.Interface:
// If ft is an interface with a "Clone() ft" method, it can be used to clone the field.
// This includes scenarios where ft is a constrained type parameter.
if cloneResultType := methodResultType(ft, "Clone"); cloneResultType.Underlying() == ft {
writef("dst.%s = src.%s.Clone()", fname, fname)
continue
}
writef(`panic("%s (%v) does not have a compatible Clone method")`, fname, ft)
default: default:
writef(`panic("TODO: %s (%T)")`, fname, ft) writef(`panic("TODO: %s (%T)")`, fname, ft)
} }
@ -191,7 +238,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
writef("return dst") writef("return dst")
fmt.Fprintf(buf, "}\n\n") fmt.Fprintf(buf, "}\n\n")
buf.Write(codegen.AssertStructUnchanged(t, name, "Clone", it)) buf.Write(codegen.AssertStructUnchanged(t, name, typeParams, "Clone", it))
} }
// hasBasicUnderlying reports true when typ.Underlying() is a slice or a map. // hasBasicUnderlying reports true when typ.Underlying() is a slice or a map.
@ -203,3 +250,15 @@ func hasBasicUnderlying(typ types.Type) bool {
return false return false
} }
} }
func methodResultType(typ types.Type, method string) types.Type {
viewMethod := codegen.LookupMethod(typ, method)
if viewMethod == nil {
return nil
}
sig, ok := viewMethod.Type().(*types.Signature)
if !ok || sig.Results().Len() != 1 {
return nil
}
return sig.Results().At(0).Type()
}

@ -7,9 +7,12 @@ package tests
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"golang.org/x/exp/constraints"
"tailscale.com/types/views"
) )
//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded --clone-only-type=OnlyGetClone //go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct --clone-only-type=OnlyGetClone
type StructWithoutPtrs struct { type StructWithoutPtrs struct {
Int int Int int
@ -25,12 +28,12 @@ type Map struct {
SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithPtrs map[string][]*StructWithPtrs
SlicesWithoutPtrs map[string][]*StructWithoutPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs
StructWithoutPtrKey map[StructWithoutPtrs]int `json:"-"` StructWithoutPtrKey map[StructWithoutPtrs]int `json:"-"`
StructWithPtr map[string]StructWithPtrs
// Unsupported views. // Unsupported views.
SliceIntPtr map[string][]*int SliceIntPtr map[string][]*int
PointerKey map[*string]int `json:"-"` PointerKey map[*string]int `json:"-"`
StructWithPtrKey map[StructWithPtrs]int `json:"-"` StructWithPtrKey map[StructWithPtrs]int `json:"-"`
StructWithPtr map[string]StructWithPtrs
} }
type StructWithPtrs struct { type StructWithPtrs struct {
@ -50,12 +53,14 @@ type StructWithSlices struct {
Values []StructWithoutPtrs Values []StructWithoutPtrs
ValuePointers []*StructWithoutPtrs ValuePointers []*StructWithoutPtrs
StructPointers []*StructWithPtrs StructPointers []*StructWithPtrs
Structs []StructWithPtrs
Ints []*int
Slice []string Slice []string
Prefixes []netip.Prefix Prefixes []netip.Prefix
Data []byte Data []byte
// Unsupported views.
Structs []StructWithPtrs
Ints []*int
} }
type OnlyGetClone struct { type OnlyGetClone struct {
@ -66,3 +71,46 @@ type StructWithEmbedded struct {
A *StructWithPtrs A *StructWithPtrs
StructWithSlices StructWithSlices
} }
type GenericIntStruct[T constraints.Integer] struct {
Value T
Pointer *T
Slice []T
Map map[string]T
// Unsupported views.
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}
type BasicType interface {
~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string
}
type GenericNoPtrsStruct[T StructWithoutPtrs | netip.Prefix | BasicType] struct {
Value T
Pointer *T
Slice []T
Map map[string]T
// Unsupported views.
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}
type GenericCloneableStruct[T views.ViewCloner[T, V], V views.StructView[T]] struct {
Value T
Slice []T
Map map[string]T
// Unsupported views.
Pointer *T
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}

@ -9,7 +9,9 @@ import (
"maps" "maps"
"net/netip" "net/netip"
"golang.org/x/exp/constraints"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
"tailscale.com/types/views"
) )
// Clone makes a deep copy of StructWithPtrs. // Clone makes a deep copy of StructWithPtrs.
@ -71,13 +73,21 @@ func (src *Map) Clone() *Map {
if dst.StructPtrWithPtr != nil { if dst.StructPtrWithPtr != nil {
dst.StructPtrWithPtr = map[string]*StructWithPtrs{} dst.StructPtrWithPtr = map[string]*StructWithPtrs{}
for k, v := range src.StructPtrWithPtr { for k, v := range src.StructPtrWithPtr {
if v == nil {
dst.StructPtrWithPtr[k] = nil
} else {
dst.StructPtrWithPtr[k] = v.Clone() dst.StructPtrWithPtr[k] = v.Clone()
} }
} }
}
if dst.StructPtrWithoutPtr != nil { if dst.StructPtrWithoutPtr != nil {
dst.StructPtrWithoutPtr = map[string]*StructWithoutPtrs{} dst.StructPtrWithoutPtr = map[string]*StructWithoutPtrs{}
for k, v := range src.StructPtrWithoutPtr { for k, v := range src.StructPtrWithoutPtr {
dst.StructPtrWithoutPtr[k] = v.Clone() if v == nil {
dst.StructPtrWithoutPtr[k] = nil
} else {
dst.StructPtrWithoutPtr[k] = ptr.To(*v)
}
} }
} }
dst.StructWithoutPtr = maps.Clone(src.StructWithoutPtr) dst.StructWithoutPtr = maps.Clone(src.StructWithoutPtr)
@ -94,6 +104,12 @@ func (src *Map) Clone() *Map {
} }
} }
dst.StructWithoutPtrKey = maps.Clone(src.StructWithoutPtrKey) dst.StructWithoutPtrKey = maps.Clone(src.StructWithoutPtrKey)
if dst.StructWithPtr != nil {
dst.StructWithPtr = map[string]StructWithPtrs{}
for k, v := range src.StructWithPtr {
dst.StructWithPtr[k] = *(v.Clone())
}
}
if dst.SliceIntPtr != nil { if dst.SliceIntPtr != nil {
dst.SliceIntPtr = map[string][]*int{} dst.SliceIntPtr = map[string][]*int{}
for k := range src.SliceIntPtr { for k := range src.SliceIntPtr {
@ -102,12 +118,6 @@ func (src *Map) Clone() *Map {
} }
dst.PointerKey = maps.Clone(src.PointerKey) dst.PointerKey = maps.Clone(src.PointerKey)
dst.StructWithPtrKey = maps.Clone(src.StructWithPtrKey) dst.StructWithPtrKey = maps.Clone(src.StructWithPtrKey)
if dst.StructWithPtr != nil {
dst.StructWithPtr = map[string]StructWithPtrs{}
for k, v := range src.StructWithPtr {
dst.StructWithPtr[k] = *(v.Clone())
}
}
return dst return dst
} }
@ -121,10 +131,10 @@ var _MapCloneNeedsRegeneration = Map(struct {
SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithPtrs map[string][]*StructWithPtrs
SlicesWithoutPtrs map[string][]*StructWithoutPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs
StructWithoutPtrKey map[StructWithoutPtrs]int StructWithoutPtrKey map[StructWithoutPtrs]int
StructWithPtr map[string]StructWithPtrs
SliceIntPtr map[string][]*int SliceIntPtr map[string][]*int
PointerKey map[*string]int PointerKey map[*string]int
StructWithPtrKey map[StructWithPtrs]int StructWithPtrKey map[StructWithPtrs]int
StructWithPtr map[string]StructWithPtrs
}{}) }{})
// Clone makes a deep copy of StructWithSlices. // Clone makes a deep copy of StructWithSlices.
@ -139,15 +149,26 @@ func (src *StructWithSlices) Clone() *StructWithSlices {
if src.ValuePointers != nil { if src.ValuePointers != nil {
dst.ValuePointers = make([]*StructWithoutPtrs, len(src.ValuePointers)) dst.ValuePointers = make([]*StructWithoutPtrs, len(src.ValuePointers))
for i := range dst.ValuePointers { for i := range dst.ValuePointers {
dst.ValuePointers[i] = src.ValuePointers[i].Clone() if src.ValuePointers[i] == nil {
dst.ValuePointers[i] = nil
} else {
dst.ValuePointers[i] = ptr.To(*src.ValuePointers[i])
}
} }
} }
if src.StructPointers != nil { if src.StructPointers != nil {
dst.StructPointers = make([]*StructWithPtrs, len(src.StructPointers)) dst.StructPointers = make([]*StructWithPtrs, len(src.StructPointers))
for i := range dst.StructPointers { for i := range dst.StructPointers {
if src.StructPointers[i] == nil {
dst.StructPointers[i] = nil
} else {
dst.StructPointers[i] = src.StructPointers[i].Clone() dst.StructPointers[i] = src.StructPointers[i].Clone()
} }
} }
}
dst.Slice = append(src.Slice[:0:0], src.Slice...)
dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...)
dst.Data = append(src.Data[:0:0], src.Data...)
if src.Structs != nil { if src.Structs != nil {
dst.Structs = make([]StructWithPtrs, len(src.Structs)) dst.Structs = make([]StructWithPtrs, len(src.Structs))
for i := range dst.Structs { for i := range dst.Structs {
@ -164,9 +185,6 @@ func (src *StructWithSlices) Clone() *StructWithSlices {
} }
} }
} }
dst.Slice = append(src.Slice[:0:0], src.Slice...)
dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...)
dst.Data = append(src.Data[:0:0], src.Data...)
return dst return dst
} }
@ -175,11 +193,11 @@ var _StructWithSlicesCloneNeedsRegeneration = StructWithSlices(struct {
Values []StructWithoutPtrs Values []StructWithoutPtrs
ValuePointers []*StructWithoutPtrs ValuePointers []*StructWithoutPtrs
StructPointers []*StructWithPtrs StructPointers []*StructWithPtrs
Structs []StructWithPtrs
Ints []*int
Slice []string Slice []string
Prefixes []netip.Prefix Prefixes []netip.Prefix
Data []byte Data []byte
Structs []StructWithPtrs
Ints []*int
}{}) }{})
// Clone makes a deep copy of OnlyGetClone. // Clone makes a deep copy of OnlyGetClone.
@ -216,3 +234,185 @@ var _StructWithEmbeddedCloneNeedsRegeneration = StructWithEmbedded(struct {
A *StructWithPtrs A *StructWithPtrs
StructWithSlices StructWithSlices
}{}) }{})
// Clone makes a deep copy of GenericIntStruct.
// The result aliases no memory with the original.
func (src *GenericIntStruct[T]) Clone() *GenericIntStruct[T] {
if src == nil {
return nil
}
dst := new(GenericIntStruct[T])
*dst = *src
if dst.Pointer != nil {
dst.Pointer = ptr.To(*src.Pointer)
}
dst.Slice = append(src.Slice[:0:0], src.Slice...)
dst.Map = maps.Clone(src.Map)
if src.PtrSlice != nil {
dst.PtrSlice = make([]*T, len(src.PtrSlice))
for i := range dst.PtrSlice {
if src.PtrSlice[i] == nil {
dst.PtrSlice[i] = nil
} else {
dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i])
}
}
}
dst.PtrKeyMap = maps.Clone(src.PtrKeyMap)
if dst.PtrValueMap != nil {
dst.PtrValueMap = map[string]*T{}
for k, v := range src.PtrValueMap {
if v == nil {
dst.PtrValueMap[k] = nil
} else {
dst.PtrValueMap[k] = ptr.To(*v)
}
}
}
if dst.SliceMap != nil {
dst.SliceMap = map[string][]T{}
for k := range src.SliceMap {
dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...)
}
}
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
func _GenericIntStructCloneNeedsRegeneration[T constraints.Integer](GenericIntStruct[T]) {
_GenericIntStructCloneNeedsRegeneration(struct {
Value T
Pointer *T
Slice []T
Map map[string]T
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}{})
}
// Clone makes a deep copy of GenericNoPtrsStruct.
// The result aliases no memory with the original.
func (src *GenericNoPtrsStruct[T]) Clone() *GenericNoPtrsStruct[T] {
if src == nil {
return nil
}
dst := new(GenericNoPtrsStruct[T])
*dst = *src
if dst.Pointer != nil {
dst.Pointer = ptr.To(*src.Pointer)
}
dst.Slice = append(src.Slice[:0:0], src.Slice...)
dst.Map = maps.Clone(src.Map)
if src.PtrSlice != nil {
dst.PtrSlice = make([]*T, len(src.PtrSlice))
for i := range dst.PtrSlice {
if src.PtrSlice[i] == nil {
dst.PtrSlice[i] = nil
} else {
dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i])
}
}
}
dst.PtrKeyMap = maps.Clone(src.PtrKeyMap)
if dst.PtrValueMap != nil {
dst.PtrValueMap = map[string]*T{}
for k, v := range src.PtrValueMap {
if v == nil {
dst.PtrValueMap[k] = nil
} else {
dst.PtrValueMap[k] = ptr.To(*v)
}
}
}
if dst.SliceMap != nil {
dst.SliceMap = map[string][]T{}
for k := range src.SliceMap {
dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...)
}
}
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
func _GenericNoPtrsStructCloneNeedsRegeneration[T StructWithoutPtrs | netip.Prefix | BasicType](GenericNoPtrsStruct[T]) {
_GenericNoPtrsStructCloneNeedsRegeneration(struct {
Value T
Pointer *T
Slice []T
Map map[string]T
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}{})
}
// Clone makes a deep copy of GenericCloneableStruct.
// The result aliases no memory with the original.
func (src *GenericCloneableStruct[T, V]) Clone() *GenericCloneableStruct[T, V] {
if src == nil {
return nil
}
dst := new(GenericCloneableStruct[T, V])
*dst = *src
dst.Value = src.Value.Clone()
if src.Slice != nil {
dst.Slice = make([]T, len(src.Slice))
for i := range dst.Slice {
dst.Slice[i] = src.Slice[i].Clone()
}
}
if dst.Map != nil {
dst.Map = map[string]T{}
for k, v := range src.Map {
dst.Map[k] = v.Clone()
}
}
if dst.Pointer != nil {
dst.Pointer = ptr.To((*src.Pointer).Clone())
}
if src.PtrSlice != nil {
dst.PtrSlice = make([]*T, len(src.PtrSlice))
for i := range dst.PtrSlice {
if src.PtrSlice[i] == nil {
dst.PtrSlice[i] = nil
} else {
dst.PtrSlice[i] = ptr.To((*src.PtrSlice[i]).Clone())
}
}
}
dst.PtrKeyMap = maps.Clone(src.PtrKeyMap)
if dst.PtrValueMap != nil {
dst.PtrValueMap = map[string]*T{}
for k, v := range src.PtrValueMap {
if v == nil {
dst.PtrValueMap[k] = nil
} else {
dst.PtrValueMap[k] = ptr.To((*v).Clone())
}
}
}
if dst.SliceMap != nil {
dst.SliceMap = map[string][]T{}
for k := range src.SliceMap {
dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...)
}
}
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
func _GenericCloneableStructCloneNeedsRegeneration[T views.ViewCloner[T, V], V views.StructView[T]](GenericCloneableStruct[T, V]) {
_GenericCloneableStructCloneNeedsRegeneration(struct {
Value T
Slice []T
Map map[string]T
Pointer *T
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}{})
}

@ -10,10 +10,11 @@ import (
"errors" "errors"
"net/netip" "net/netip"
"golang.org/x/exp/constraints"
"tailscale.com/types/views" "tailscale.com/types/views"
) )
//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct
// View returns a readonly view of StructWithPtrs. // View returns a readonly view of StructWithPtrs.
func (p *StructWithPtrs) View() StructWithPtrsView { func (p *StructWithPtrs) View() StructWithPtrsView {
@ -221,15 +222,15 @@ func (v MapView) SlicesWithoutPtrs() views.MapFn[string, []*StructWithoutPtrs, v
func (v MapView) StructWithoutPtrKey() views.Map[StructWithoutPtrs, int] { func (v MapView) StructWithoutPtrKey() views.Map[StructWithoutPtrs, int] {
return views.MapOf(v.ж.StructWithoutPtrKey) return views.MapOf(v.ж.StructWithoutPtrKey)
} }
func (v MapView) SliceIntPtr() map[string][]*int { panic("unsupported") }
func (v MapView) PointerKey() map[*string]int { panic("unsupported") }
func (v MapView) StructWithPtrKey() map[StructWithPtrs]int { panic("unsupported") }
func (v MapView) StructWithPtr() views.MapFn[string, StructWithPtrs, StructWithPtrsView] { func (v MapView) StructWithPtr() views.MapFn[string, StructWithPtrs, StructWithPtrsView] {
return views.MapFnOf(v.ж.StructWithPtr, func(t StructWithPtrs) StructWithPtrsView { return views.MapFnOf(v.ж.StructWithPtr, func(t StructWithPtrs) StructWithPtrsView {
return t.View() return t.View()
}) })
} }
func (v MapView) SliceIntPtr() map[string][]*int { panic("unsupported") }
func (v MapView) PointerKey() map[*string]int { panic("unsupported") }
func (v MapView) StructWithPtrKey() map[StructWithPtrs]int { panic("unsupported") }
// A compilation failure here means this code must be regenerated, with the command at the top of this file. // A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _MapViewNeedsRegeneration = Map(struct { var _MapViewNeedsRegeneration = Map(struct {
@ -241,10 +242,10 @@ var _MapViewNeedsRegeneration = Map(struct {
SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithPtrs map[string][]*StructWithPtrs
SlicesWithoutPtrs map[string][]*StructWithoutPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs
StructWithoutPtrKey map[StructWithoutPtrs]int StructWithoutPtrKey map[StructWithoutPtrs]int
StructWithPtr map[string]StructWithPtrs
SliceIntPtr map[string][]*int SliceIntPtr map[string][]*int
PointerKey map[*string]int PointerKey map[*string]int
StructWithPtrKey map[StructWithPtrs]int StructWithPtrKey map[StructWithPtrs]int
StructWithPtr map[string]StructWithPtrs
}{}) }{})
// View returns a readonly view of StructWithSlices. // View returns a readonly view of StructWithSlices.
@ -301,24 +302,24 @@ func (v StructWithSlicesView) ValuePointers() views.SliceView[*StructWithoutPtrs
func (v StructWithSlicesView) StructPointers() views.SliceView[*StructWithPtrs, StructWithPtrsView] { func (v StructWithSlicesView) StructPointers() views.SliceView[*StructWithPtrs, StructWithPtrsView] {
return views.SliceOfViews[*StructWithPtrs, StructWithPtrsView](v.ж.StructPointers) return views.SliceOfViews[*StructWithPtrs, StructWithPtrsView](v.ж.StructPointers)
} }
func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") }
func (v StructWithSlicesView) Ints() *int { panic("unsupported") }
func (v StructWithSlicesView) Slice() views.Slice[string] { return views.SliceOf(v.ж.Slice) } func (v StructWithSlicesView) Slice() views.Slice[string] { return views.SliceOf(v.ж.Slice) }
func (v StructWithSlicesView) Prefixes() views.Slice[netip.Prefix] { func (v StructWithSlicesView) Prefixes() views.Slice[netip.Prefix] {
return views.SliceOf(v.ж.Prefixes) return views.SliceOf(v.ж.Prefixes)
} }
func (v StructWithSlicesView) Data() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Data) } func (v StructWithSlicesView) Data() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Data) }
func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") }
func (v StructWithSlicesView) Ints() *int { panic("unsupported") }
// A compilation failure here means this code must be regenerated, with the command at the top of this file. // A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _StructWithSlicesViewNeedsRegeneration = StructWithSlices(struct { var _StructWithSlicesViewNeedsRegeneration = StructWithSlices(struct {
Values []StructWithoutPtrs Values []StructWithoutPtrs
ValuePointers []*StructWithoutPtrs ValuePointers []*StructWithoutPtrs
StructPointers []*StructWithPtrs StructPointers []*StructWithPtrs
Structs []StructWithPtrs
Ints []*int
Slice []string Slice []string
Prefixes []netip.Prefix Prefixes []netip.Prefix
Data []byte Data []byte
Structs []StructWithPtrs
Ints []*int
}{}) }{})
// View returns a readonly view of StructWithEmbedded. // View returns a readonly view of StructWithEmbedded.
@ -376,3 +377,230 @@ var _StructWithEmbeddedViewNeedsRegeneration = StructWithEmbedded(struct {
A *StructWithPtrs A *StructWithPtrs
StructWithSlices StructWithSlices
}{}) }{})
// View returns a readonly view of GenericIntStruct.
func (p *GenericIntStruct[T]) View() GenericIntStructView[T] {
return GenericIntStructView[T]{ж: p}
}
// GenericIntStructView[T] provides a read-only view over GenericIntStruct[T].
//
// Its methods should only be called if `Valid()` returns true.
type GenericIntStructView[T constraints.Integer] struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *GenericIntStruct[T]
}
// Valid reports whether underlying value is non-nil.
func (v GenericIntStructView[T]) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v GenericIntStructView[T]) AsStruct() *GenericIntStruct[T] {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v GenericIntStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x GenericIntStruct[T]
if err := json.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v GenericIntStructView[T]) Value() T { return v.ж.Value }
func (v GenericIntStructView[T]) Pointer() *T {
if v.ж.Pointer == nil {
return nil
}
x := *v.ж.Pointer
return &x
}
func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) }
func (v GenericIntStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.ж.Map) }
func (v GenericIntStructView[T]) PtrSlice() *T { panic("unsupported") }
func (v GenericIntStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") }
func (v GenericIntStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") }
func (v GenericIntStructView[T]) SliceMap() map[string][]T { panic("unsupported") }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
func _GenericIntStructViewNeedsRegeneration[T constraints.Integer](GenericIntStruct[T]) {
_GenericIntStructViewNeedsRegeneration(struct {
Value T
Pointer *T
Slice []T
Map map[string]T
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}{})
}
// View returns a readonly view of GenericNoPtrsStruct.
func (p *GenericNoPtrsStruct[T]) View() GenericNoPtrsStructView[T] {
return GenericNoPtrsStructView[T]{ж: p}
}
// GenericNoPtrsStructView[T] provides a read-only view over GenericNoPtrsStruct[T].
//
// Its methods should only be called if `Valid()` returns true.
type GenericNoPtrsStructView[T StructWithoutPtrs | netip.Prefix | BasicType] struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *GenericNoPtrsStruct[T]
}
// Valid reports whether underlying value is non-nil.
func (v GenericNoPtrsStructView[T]) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v GenericNoPtrsStructView[T]) AsStruct() *GenericNoPtrsStruct[T] {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v GenericNoPtrsStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x GenericNoPtrsStruct[T]
if err := json.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v GenericNoPtrsStructView[T]) Value() T { return v.ж.Value }
func (v GenericNoPtrsStructView[T]) Pointer() *T {
if v.ж.Pointer == nil {
return nil
}
x := *v.ж.Pointer
return &x
}
func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) }
func (v GenericNoPtrsStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.ж.Map) }
func (v GenericNoPtrsStructView[T]) PtrSlice() *T { panic("unsupported") }
func (v GenericNoPtrsStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") }
func (v GenericNoPtrsStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") }
func (v GenericNoPtrsStructView[T]) SliceMap() map[string][]T { panic("unsupported") }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
func _GenericNoPtrsStructViewNeedsRegeneration[T StructWithoutPtrs | netip.Prefix | BasicType](GenericNoPtrsStruct[T]) {
_GenericNoPtrsStructViewNeedsRegeneration(struct {
Value T
Pointer *T
Slice []T
Map map[string]T
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}{})
}
// View returns a readonly view of GenericCloneableStruct.
func (p *GenericCloneableStruct[T, V]) View() GenericCloneableStructView[T, V] {
return GenericCloneableStructView[T, V]{ж: p}
}
// GenericCloneableStructView[T, V] provides a read-only view over GenericCloneableStruct[T, V].
//
// Its methods should only be called if `Valid()` returns true.
type GenericCloneableStructView[T views.ViewCloner[T, V], V views.StructView[T]] struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *GenericCloneableStruct[T, V]
}
// Valid reports whether underlying value is non-nil.
func (v GenericCloneableStructView[T, V]) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v GenericCloneableStructView[T, V]) AsStruct() *GenericCloneableStruct[T, V] {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v GenericCloneableStructView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *GenericCloneableStructView[T, V]) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x GenericCloneableStruct[T, V]
if err := json.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v GenericCloneableStructView[T, V]) Value() V { return v.ж.Value.View() }
func (v GenericCloneableStructView[T, V]) Slice() views.SliceView[T, V] {
return views.SliceOfViews[T, V](v.ж.Slice)
}
func (v GenericCloneableStructView[T, V]) Map() views.MapFn[string, T, V] {
return views.MapFnOf(v.ж.Map, func(t T) V {
return t.View()
})
}
func (v GenericCloneableStructView[T, V]) Pointer() map[string]T { panic("unsupported") }
func (v GenericCloneableStructView[T, V]) PtrSlice() *T { panic("unsupported") }
func (v GenericCloneableStructView[T, V]) PtrKeyMap() map[*T]string { panic("unsupported") }
func (v GenericCloneableStructView[T, V]) PtrValueMap() map[string]*T { panic("unsupported") }
func (v GenericCloneableStructView[T, V]) SliceMap() map[string][]T { panic("unsupported") }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V views.StructView[T]](GenericCloneableStruct[T, V]) {
_GenericCloneableStructViewNeedsRegeneration(struct {
Value T
Slice []T
Map map[string]T
Pointer *T
PtrSlice []*T
PtrKeyMap map[*T]string `json:"-"`
PtrValueMap map[string]*T
SliceMap map[string][]T
}{})
}

@ -20,43 +20,43 @@ import (
const viewTemplateStr = `{{define "common"}} const viewTemplateStr = `{{define "common"}}
// View returns a readonly view of {{.StructName}}. // View returns a readonly view of {{.StructName}}.
func (p *{{.StructName}}) View() {{.ViewName}} { func (p *{{.StructName}}{{.TypeParamNames}}) View() {{.ViewName}}{{.TypeParamNames}} {
return {{.ViewName}}{ж: p} return {{.ViewName}}{{.TypeParamNames}}{ж: p}
} }
// {{.ViewName}} provides a read-only view over {{.StructName}}. // {{.ViewName}}{{.TypeParamNames}} provides a read-only view over {{.StructName}}{{.TypeParamNames}}.
// //
// Its methods should only be called if ` + "`Valid()`" + ` returns true. // Its methods should only be called if ` + "`Valid()`" + ` returns true.
type {{.ViewName}} struct { type {{.ViewName}}{{.TypeParams}} struct {
// ж is the underlying mutable value, named with a hard-to-type // ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer. // character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape // It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it. // to callers. You must not let callers be able to mutate it.
ж *{{.StructName}} ж *{{.StructName}}{{.TypeParamNames}}
} }
// Valid reports whether underlying value is non-nil. // Valid reports whether underlying value is non-nil.
func (v {{.ViewName}}) Valid() bool { return v.ж != nil } func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with // AsStruct returns a clone of the underlying value which aliases no memory with
// the original. // the original.
func (v {{.ViewName}}) AsStruct() *{{.StructName}}{ func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{
if v.ж == nil { if v.ж == nil {
return nil return nil
} }
return v.ж.Clone() return v.ж.Clone()
} }
func (v {{.ViewName}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error { func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error {
if v.ж != nil { if v.ж != nil {
return errors.New("already initialized") return errors.New("already initialized")
} }
if len(b) == 0 { if len(b) == 0 {
return nil return nil
} }
var x {{.StructName}} var x {{.StructName}}{{.TypeParamNames}}
if err := json.Unmarshal(b, &x); err != nil { if err := json.Unmarshal(b, &x); err != nil {
return err return err
} }
@ -65,17 +65,17 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
} }
{{end}} {{end}}
{{define "valueField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} } {{define "valueField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} }
{{end}} {{end}}
{{define "byteSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) } {{define "byteSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) }
{{end}} {{end}}
{{define "sliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) } {{define "sliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) }
{{end}} {{end}}
{{define "viewSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) } {{define "viewSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) }
{{end}} {{end}}
{{define "viewField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}}View { return v.ж.{{.FieldName}}.View() } {{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() }
{{end}} {{end}}
{{define "valuePointerField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { {{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {
if v.ж.{{.FieldName}} == nil { if v.ж.{{.FieldName}} == nil {
return nil return nil
} }
@ -85,21 +85,21 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
{{end}} {{end}}
{{define "mapField"}} {{define "mapField"}}
func(v {{.ViewName}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})} func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})}
{{end}} {{end}}
{{define "mapFnField"}} {{define "mapFnField"}}
func(v {{.ViewName}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} { func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} {
return {{.MapFn}} return {{.MapFn}}
})} })}
{{end}} {{end}}
{{define "mapSliceField"}} {{define "mapSliceField"}}
func(v {{.ViewName}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) } func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) }
{{end}} {{end}}
{{define "unsupportedField"}}func(v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")} {{define "unsupportedField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")}
{{end}} {{end}}
{{define "stringFunc"}}func(v {{.ViewName}}) String() string { return v.ж.String() } {{define "stringFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) String() string { return v.ж.String() }
{{end}} {{end}}
{{define "equalFunc"}}func(v {{.ViewName}}) Equal(v2 {{.ViewName}}) bool { return v.ж.Equal(v2.ж) } {{define "equalFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) Equal(v2 {{.ViewName}}{{.TypeParamNames}}) bool { return v.ж.Equal(v2.ж) }
{{end}} {{end}}
` `
@ -133,6 +133,9 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
args := struct { args := struct {
StructName string StructName string
ViewName string ViewName string
TypeParams string // e.g. [T constraints.Integer]
TypeParamNames string // e.g. [T]
FieldName string FieldName string
FieldType string FieldType string
FieldViewName string FieldViewName string
@ -143,9 +146,12 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
MapFn string MapFn string
}{ }{
StructName: typ.Obj().Name(), StructName: typ.Obj().Name(),
ViewName: typ.Obj().Name() + "View", ViewName: typ.Origin().Obj().Name() + "View",
} }
typeParams := typ.Origin().TypeParams()
args.TypeParams, args.TypeParamNames = codegen.FormatTypeParams(typeParams, it)
writeTemplate := func(name string) { writeTemplate := func(name string) {
if err := viewTemplate.ExecuteTemplate(buf, name, args); err != nil { if err := viewTemplate.ExecuteTemplate(buf, name, args); err != nil {
log.Fatal(err) log.Fatal(err)
@ -182,19 +188,35 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
it.Import("tailscale.com/types/views") it.Import("tailscale.com/types/views")
shallow, deep, base := requiresCloning(elem) shallow, deep, base := requiresCloning(elem)
if deep { if deep {
if _, isPtr := elem.(*types.Pointer); isPtr { switch elem.Underlying().(type) {
args.FieldViewName = it.QualifiedName(base) + "View" case *types.Pointer:
if _, isIface := base.Underlying().(*types.Interface); !isIface {
args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View")
writeTemplate("viewSliceField") writeTemplate("viewSliceField")
} else { } else {
writeTemplate("unsupportedField") writeTemplate("unsupportedField")
} }
continue continue
case *types.Interface:
if viewType := viewTypeForValueType(elem); viewType != nil {
args.FieldViewName = it.QualifiedName(viewType)
writeTemplate("viewSliceField")
continue
}
}
writeTemplate("unsupportedField")
continue
} else if shallow { } else if shallow {
if _, isBasic := base.(*types.Basic); isBasic { switch base.Underlying().(type) {
case *types.Basic, *types.Interface:
writeTemplate("unsupportedField") writeTemplate("unsupportedField")
} else { default:
args.FieldViewName = it.QualifiedName(base) + "View" if _, isIface := base.Underlying().(*types.Interface); !isIface {
args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View")
writeTemplate("viewSliceField") writeTemplate("viewSliceField")
} else {
writeTemplate("unsupportedField")
}
} }
continue continue
} }
@ -205,6 +227,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
strucT := underlying strucT := underlying
args.FieldType = it.QualifiedName(fieldType) args.FieldType = it.QualifiedName(fieldType)
if codegen.ContainsPointers(strucT) { if codegen.ContainsPointers(strucT) {
args.FieldViewName = appendNameSuffix(args.FieldType, "View")
writeTemplate("viewField") writeTemplate("viewField")
continue continue
} }
@ -229,7 +252,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
args.MapFn = "t.View()" args.MapFn = "t.View()"
template = "mapFnField" template = "mapFnField"
args.MapValueType = it.QualifiedName(mElem) args.MapValueType = it.QualifiedName(mElem)
args.MapValueView = args.MapValueType + "View" args.MapValueView = appendNameSuffix(args.MapValueType, "View")
} else { } else {
template = "mapField" template = "mapField"
args.MapValueType = it.QualifiedName(mElem) args.MapValueType = it.QualifiedName(mElem)
@ -249,10 +272,12 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
case *types.Pointer: case *types.Pointer:
ptr := x ptr := x
pElem := ptr.Elem() pElem := ptr.Elem()
template = "unsupportedField"
if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
switch pElem.(type) { switch pElem.(type) {
case *types.Struct, *types.Named: case *types.Struct, *types.Named:
ptrType := it.QualifiedName(ptr) ptrType := it.QualifiedName(ptr)
viewType := it.QualifiedName(pElem) + "View" viewType := appendNameSuffix(it.QualifiedName(pElem), "View")
args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType) args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType)
args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType) args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType)
args.MapValueType = "[]" + ptrType args.MapValueType = "[]" + ptrType
@ -260,21 +285,40 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
default: default:
template = "unsupportedField" template = "unsupportedField"
} }
} else {
template = "unsupportedField"
}
default: default:
template = "unsupportedField" template = "unsupportedField"
} }
case *types.Pointer: case *types.Pointer:
ptr := u ptr := u
pElem := ptr.Elem() pElem := ptr.Elem()
if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
switch pElem.(type) { switch pElem.(type) {
case *types.Struct, *types.Named: case *types.Struct, *types.Named:
args.MapValueType = it.QualifiedName(ptr) args.MapValueType = it.QualifiedName(ptr)
args.MapValueView = it.QualifiedName(pElem) + "View" args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View")
args.MapFn = "t.View()" args.MapFn = "t.View()"
template = "mapFnField" template = "mapFnField"
default: default:
template = "unsupportedField" template = "unsupportedField"
} }
} else {
template = "unsupportedField"
}
case *types.Interface, *types.TypeParam:
if viewType := viewTypeForValueType(u); viewType != nil {
args.MapValueType = it.QualifiedName(u)
args.MapValueView = it.QualifiedName(viewType)
args.MapFn = "t.View()"
template = "mapFnField"
} else if !codegen.ContainsPointers(u) {
args.MapValueType = it.QualifiedName(mElem)
template = "mapField"
} else {
template = "unsupportedField"
}
default: default:
template = "unsupportedField" template = "unsupportedField"
} }
@ -283,14 +327,28 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
case *types.Pointer: case *types.Pointer:
ptr := underlying ptr := underlying
_, deep, base := requiresCloning(ptr) _, deep, base := requiresCloning(ptr)
if deep { if deep {
if _, isIface := base.Underlying().(*types.Interface); !isIface {
args.FieldType = it.QualifiedName(base) args.FieldType = it.QualifiedName(base)
args.FieldViewName = appendNameSuffix(args.FieldType, "View")
writeTemplate("viewField") writeTemplate("viewField")
} else {
writeTemplate("unsupportedField")
}
} else { } else {
args.FieldType = it.QualifiedName(ptr) args.FieldType = it.QualifiedName(ptr)
writeTemplate("valuePointerField") writeTemplate("valuePointerField")
} }
continue continue
case *types.Interface:
// If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field.
// This includes scenarios where fieldType is a constrained type parameter.
if viewType := viewTypeForValueType(underlying); viewType != nil {
args.FieldViewName = it.QualifiedName(viewType)
writeTemplate("viewField")
continue
}
} }
writeTemplate("unsupportedField") writeTemplate("unsupportedField")
} }
@ -318,7 +376,27 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
} }
} }
fmt.Fprintf(buf, "\n") fmt.Fprintf(buf, "\n")
buf.Write(codegen.AssertStructUnchanged(t, args.StructName, "View", it)) buf.Write(codegen.AssertStructUnchanged(t, args.StructName, typeParams, "View", it))
}
func appendNameSuffix(name, suffix string) string {
if idx := strings.IndexRune(name, '['); idx != -1 {
// Insert suffix after the type name, but before type parameters.
return name[:idx] + suffix + name[idx:]
}
return name + suffix
}
func viewTypeForValueType(typ types.Type) types.Type {
viewMethod := codegen.LookupMethod(typ, "View")
if viewMethod == nil {
return nil
}
sig, ok := viewMethod.Type().(*types.Signature)
if !ok || sig.Results().Len() != 1 {
return nil
}
return sig.Results().At(0).Type()
} }
var ( var (

@ -14,6 +14,7 @@ import (
"tailscale.com/types/opt" "tailscale.com/types/opt"
"tailscale.com/types/persist" "tailscale.com/types/persist"
"tailscale.com/types/preftype" "tailscale.com/types/preftype"
"tailscale.com/types/ptr"
) )
// Clone makes a deep copy of Prefs. // Clone makes a deep copy of Prefs.
@ -29,9 +30,13 @@ func (src *Prefs) Clone() *Prefs {
if src.DriveShares != nil { if src.DriveShares != nil {
dst.DriveShares = make([]*drive.Share, len(src.DriveShares)) dst.DriveShares = make([]*drive.Share, len(src.DriveShares))
for i := range dst.DriveShares { for i := range dst.DriveShares {
if src.DriveShares[i] == nil {
dst.DriveShares[i] = nil
} else {
dst.DriveShares[i] = src.DriveShares[i].Clone() dst.DriveShares[i] = src.DriveShares[i].Clone()
} }
} }
}
dst.Persist = src.Persist.Clone() dst.Persist = src.Persist.Clone()
return dst return dst
} }
@ -81,22 +86,34 @@ func (src *ServeConfig) Clone() *ServeConfig {
if dst.TCP != nil { if dst.TCP != nil {
dst.TCP = map[uint16]*TCPPortHandler{} dst.TCP = map[uint16]*TCPPortHandler{}
for k, v := range src.TCP { for k, v := range src.TCP {
dst.TCP[k] = v.Clone() if v == nil {
dst.TCP[k] = nil
} else {
dst.TCP[k] = ptr.To(*v)
}
} }
} }
if dst.Web != nil { if dst.Web != nil {
dst.Web = map[HostPort]*WebServerConfig{} dst.Web = map[HostPort]*WebServerConfig{}
for k, v := range src.Web { for k, v := range src.Web {
if v == nil {
dst.Web[k] = nil
} else {
dst.Web[k] = v.Clone() dst.Web[k] = v.Clone()
} }
} }
}
dst.AllowFunnel = maps.Clone(src.AllowFunnel) dst.AllowFunnel = maps.Clone(src.AllowFunnel)
if dst.Foreground != nil { if dst.Foreground != nil {
dst.Foreground = map[string]*ServeConfig{} dst.Foreground = map[string]*ServeConfig{}
for k, v := range src.Foreground { for k, v := range src.Foreground {
if v == nil {
dst.Foreground[k] = nil
} else {
dst.Foreground[k] = v.Clone() dst.Foreground[k] = v.Clone()
} }
} }
}
return dst return dst
} }
@ -157,7 +174,11 @@ func (src *WebServerConfig) Clone() *WebServerConfig {
if dst.Handlers != nil { if dst.Handlers != nil {
dst.Handlers = map[string]*HTTPHandler{} dst.Handlers = map[string]*HTTPHandler{}
for k, v := range src.Handlers { for k, v := range src.Handlers {
dst.Handlers[k] = v.Clone() if v == nil {
dst.Handlers[k] = nil
} else {
dst.Handlers[k] = ptr.To(*v)
}
} }
} }
return dst return dst

@ -77,9 +77,13 @@ func (src *Node) Clone() *Node {
if src.ExitNodeDNSResolvers != nil { if src.ExitNodeDNSResolvers != nil {
dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers)) dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers))
for i := range dst.ExitNodeDNSResolvers { for i := range dst.ExitNodeDNSResolvers {
if src.ExitNodeDNSResolvers[i] == nil {
dst.ExitNodeDNSResolvers[i] = nil
} else {
dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone()
} }
} }
}
return dst return dst
} }
@ -244,9 +248,13 @@ func (src *DNSConfig) Clone() *DNSConfig {
if src.Resolvers != nil { if src.Resolvers != nil {
dst.Resolvers = make([]*dnstype.Resolver, len(src.Resolvers)) dst.Resolvers = make([]*dnstype.Resolver, len(src.Resolvers))
for i := range dst.Resolvers { for i := range dst.Resolvers {
if src.Resolvers[i] == nil {
dst.Resolvers[i] = nil
} else {
dst.Resolvers[i] = src.Resolvers[i].Clone() dst.Resolvers[i] = src.Resolvers[i].Clone()
} }
} }
}
if dst.Routes != nil { if dst.Routes != nil {
dst.Routes = map[string][]*dnstype.Resolver{} dst.Routes = map[string][]*dnstype.Resolver{}
for k := range src.Routes { for k := range src.Routes {
@ -256,9 +264,13 @@ func (src *DNSConfig) Clone() *DNSConfig {
if src.FallbackResolvers != nil { if src.FallbackResolvers != nil {
dst.FallbackResolvers = make([]*dnstype.Resolver, len(src.FallbackResolvers)) dst.FallbackResolvers = make([]*dnstype.Resolver, len(src.FallbackResolvers))
for i := range dst.FallbackResolvers { for i := range dst.FallbackResolvers {
if src.FallbackResolvers[i] == nil {
dst.FallbackResolvers[i] = nil
} else {
dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone()
} }
} }
}
dst.Domains = append(src.Domains[:0:0], src.Domains...) dst.Domains = append(src.Domains[:0:0], src.Domains...)
dst.Nameservers = append(src.Nameservers[:0:0], src.Nameservers...) dst.Nameservers = append(src.Nameservers[:0:0], src.Nameservers...)
dst.CertDomains = append(src.CertDomains[:0:0], src.CertDomains...) dst.CertDomains = append(src.CertDomains[:0:0], src.CertDomains...)
@ -393,7 +405,11 @@ func (src *DERPRegion) Clone() *DERPRegion {
if src.Nodes != nil { if src.Nodes != nil {
dst.Nodes = make([]*DERPNode, len(src.Nodes)) dst.Nodes = make([]*DERPNode, len(src.Nodes))
for i := range dst.Nodes { for i := range dst.Nodes {
dst.Nodes[i] = src.Nodes[i].Clone() if src.Nodes[i] == nil {
dst.Nodes[i] = nil
} else {
dst.Nodes[i] = ptr.To(*src.Nodes[i])
}
} }
} }
return dst return dst
@ -422,9 +438,13 @@ func (src *DERPMap) Clone() *DERPMap {
if dst.Regions != nil { if dst.Regions != nil {
dst.Regions = map[int]*DERPRegion{} dst.Regions = map[int]*DERPRegion{}
for k, v := range src.Regions { for k, v := range src.Regions {
if v == nil {
dst.Regions[k] = nil
} else {
dst.Regions[k] = v.Clone() dst.Regions[k] = v.Clone()
} }
} }
}
return dst return dst
} }
@ -476,9 +496,13 @@ func (src *SSHRule) Clone() *SSHRule {
if src.Principals != nil { if src.Principals != nil {
dst.Principals = make([]*SSHPrincipal, len(src.Principals)) dst.Principals = make([]*SSHPrincipal, len(src.Principals))
for i := range dst.Principals { for i := range dst.Principals {
if src.Principals[i] == nil {
dst.Principals[i] = nil
} else {
dst.Principals[i] = src.Principals[i].Clone() dst.Principals[i] = src.Principals[i].Clone()
} }
} }
}
dst.SSHUsers = maps.Clone(src.SSHUsers) dst.SSHUsers = maps.Clone(src.SSHUsers)
dst.Action = src.Action.Clone() dst.Action = src.Action.Clone()
return dst return dst

@ -27,9 +27,9 @@ var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to gen
func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) { func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) {
cfg := &packages.Config{ cfg := &packages.Config{
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
Tests: false, Tests: buildTags == "test",
} }
if buildTags != "" { if buildTags != "" && !cfg.Tests {
cfg.BuildFlags = []string{"-tags=" + buildTags} cfg.BuildFlags = []string{"-tags=" + buildTags}
} }
@ -37,6 +37,9 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if cfg.Tests {
pkgs = testPackages(pkgs)
}
if len(pkgs) != 1 { if len(pkgs) != 1 {
return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs)) return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
} }
@ -44,6 +47,17 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]
return pkg, namedTypes(pkg), nil return pkg, namedTypes(pkg), nil
} }
func testPackages(pkgs []*packages.Package) []*packages.Package {
var testPackages []*packages.Package
for _, pkg := range pkgs {
testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath)
if pkg.ID == testPackageID {
testPackages = append(testPackages, pkg)
}
}
return testPackages
}
// HasNoClone reports whether the provided tag has `codegen:noclone`. // HasNoClone reports whether the provided tag has `codegen:noclone`.
func HasNoClone(structTag string) bool { func HasNoClone(structTag string) bool {
val := reflect.StructTag(structTag).Get("codegen") val := reflect.StructTag(structTag).Get("codegen")
@ -193,13 +207,21 @@ func namedTypes(pkg *packages.Package) map[string]*types.Named {
// ctx is a single-word context for this assertion, such as "Clone". // ctx is a single-word context for this assertion, such as "Clone".
// If non-nil, AssertStructUnchanged will add elements to imports // If non-nil, AssertStructUnchanged will add elements to imports
// for each package path that the caller must import for the returned code to compile. // for each package path that the caller must import for the returned code to compile.
func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker) []byte { func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
w := func(format string, args ...any) { w := func(format string, args ...any) {
fmt.Fprintf(buf, format+"\n", args...) fmt.Fprintf(buf, format+"\n", args...)
} }
w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.") w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
hasTypeParams := params != nil && params.Len() > 0
if hasTypeParams {
constraints, identifiers := FormatTypeParams(params, it)
w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers)
w("_%s%sNeedsRegeneration(struct {", tname, ctx)
} else {
w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname) w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
}
for i := range t.NumFields() { for i := range t.NumFields() {
st := t.Field(i) st := t.Field(i)
@ -209,14 +231,25 @@ func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker
continue continue
} }
qname := it.QualifiedName(ft) qname := it.QualifiedName(ft)
var tag string
if hasTypeParams {
tag = t.Tag(i)
if tag != "" {
tag = "`" + tag + "`"
}
}
if st.Anonymous() { if st.Anonymous() {
w("\t%s ", fname) w("\t%s %s", fname, tag)
} else { } else {
w("\t%s %s", fname, qname) w("\t%s %s %s", fname, qname, tag)
} }
} }
w("}{})\n") if hasTypeParams {
w("}{})\n}")
} else {
w("}{})")
}
return buf.Bytes() return buf.Bytes()
} }
@ -242,10 +275,21 @@ func ContainsPointers(typ types.Type) bool {
switch ft := typ.Underlying().(type) { switch ft := typ.Underlying().(type) {
case *types.Array: case *types.Array:
return ContainsPointers(ft.Elem()) return ContainsPointers(ft.Elem())
case *types.Basic:
if ft.Kind() == types.UnsafePointer {
return true
}
case *types.Chan: case *types.Chan:
return true return true
case *types.Interface: case *types.Interface:
return true // a little too broad if ft.Empty() || ft.IsMethodSet() {
return true
}
for i := 0; i < ft.NumEmbeddeds(); i++ {
if ContainsPointers(ft.EmbeddedType(i)) {
return true
}
}
case *types.Map: case *types.Map:
return true return true
case *types.Pointer: case *types.Pointer:
@ -258,6 +302,12 @@ func ContainsPointers(typ types.Type) bool {
return true return true
} }
} }
case *types.Union:
for i := range ft.Len() {
if ContainsPointers(ft.Term(i).Type()) {
return true
}
}
} }
return false return false
} }
@ -273,3 +323,44 @@ func IsViewType(typ types.Type) bool {
} }
return t.Field(0).Name() == "ж" return t.Field(0).Name() == "ж"
} }
// FormatTypeParams formats the specified params and returns two strings:
// - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer])
// - names are comma-separated type parameter names in square brackets (e.g. [T, V])
//
// If params is nil or empty, both return values are empty strings.
func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) {
if params == nil || params.Len() == 0 {
return "", ""
}
var constraintList, nameList []string
for i := range params.Len() {
param := params.At(i)
name := param.Obj().Name()
constraint := it.QualifiedName(param.Constraint())
nameList = append(nameList, name)
constraintList = append(constraintList, name+" "+constraint)
}
constraints = "[" + strings.Join(constraintList, ", ") + "]"
names = "[" + strings.Join(nameList, ", ") + "]"
return constraints, names
}
// LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
func LookupMethod(t types.Type, name string) *types.Func {
if t, ok := t.(*types.Named); ok {
for i := 0; i < t.NumMethods(); i++ {
if method := t.Method(i); method.Name() == name {
return method
}
}
}
if t, ok := t.Underlying().(*types.Interface); ok {
for i := 0; i < t.NumMethods(); i++ {
if method := t.Method(i); method.Name() == name {
return method
}
}
}
return nil
}

@ -0,0 +1,176 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package codegen
import (
"log"
"net/netip"
"testing"
"unsafe"
"golang.org/x/exp/constraints"
)
type AnyParam[T any] struct {
V T
}
type AnyParamPhantom[T any] struct {
}
type IntegerParam[T constraints.Integer] struct {
V T
}
type FloatParam[T constraints.Float] struct {
V T
}
type StringLikeParam[T ~string] struct {
V T
}
type BasicType interface {
~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string
}
type BasicTypeParam[T BasicType] struct {
V T
}
type IntPtr *int
type IntPtrParam[T IntPtr] struct {
V T
}
type IntegerPtr interface {
*int | *int32 | *int64
}
type IntegerPtrParam[T IntegerPtr] struct {
V T
}
type IntegerParamPtr[T constraints.Integer] struct {
V *T
}
type IntegerSliceParam[T constraints.Integer] struct {
V []T
}
type IntegerMapParam[T constraints.Integer] struct {
V []T
}
type UnsafePointerParam[T unsafe.Pointer] struct {
V T
}
type ValueUnionParam[T netip.Prefix | BasicType] struct {
V T
}
type ValueUnionParamPtr[T netip.Prefix | BasicType] struct {
V *T
}
type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct {
V T
}
type Interface interface {
Method()
}
type InterfaceParam[T Interface] struct {
V T
}
func TestGenericContainsPointers(t *testing.T) {
tests := []struct {
typ string
wantPointer bool
}{
{
typ: "AnyParam",
wantPointer: true,
},
{
typ: "AnyParamPhantom",
wantPointer: false, // has a pointer type parameter, but no pointer fields
},
{
typ: "IntegerParam",
wantPointer: false,
},
{
typ: "FloatParam",
wantPointer: false,
},
{
typ: "StringLikeParam",
wantPointer: false,
},
{
typ: "BasicTypeParam",
wantPointer: false,
},
{
typ: "IntPtrParam",
wantPointer: true,
},
{
typ: "IntegerPtrParam",
wantPointer: true,
},
{
typ: "IntegerParamPtr",
wantPointer: true,
},
{
typ: "IntegerSliceParam",
wantPointer: true,
},
{
typ: "IntegerMapParam",
wantPointer: true,
},
{
typ: "UnsafePointerParam",
wantPointer: true,
},
{
typ: "InterfaceParam",
wantPointer: true,
},
{
typ: "ValueUnionParam",
wantPointer: false,
},
{
typ: "ValueUnionParamPtr",
wantPointer: true,
},
{
typ: "PointerUnionParam",
wantPointer: true,
},
}
_, namedTypes, err := LoadTypes("test", ".")
if err != nil {
log.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.typ, func(t *testing.T) {
typ := namedTypes[tt.typ]
if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer {
t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer)
}
})
}
}
Loading…
Cancel
Save