WIP expiring caps

knyar/sshcap
Anton Tolchanov 2 months ago
parent 99f0c03e9b
commit 00088aea1e

@ -173,6 +173,7 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t
if DevKnob.StripCaps() {
resp.Node.Capabilities = nil
resp.Node.CapMap = nil
resp.Node.ExtraCapMap = nil
}
// If the server is old and is still sending us Capabilities instead of
// CapMap, convert it to CapMap early so the rest of the client code can

@ -268,6 +268,8 @@ func TestRedactNetmapPrivateKeys(t *testing.T) {
f(tailcfg.DisplayMessage{}, "Severity"): false,
f(tailcfg.DisplayMessage{}, "Text"): false,
f(tailcfg.DisplayMessage{}, "Title"): false,
f(tailcfg.ExtraCapMapValue{}, "Expiry"): false,
f(tailcfg.ExtraCapMapValue{}, "Value"): false,
f(tailcfg.FilterRule{}, "CapGrant"): false,
f(tailcfg.FilterRule{}, "DstPorts"): false,
f(tailcfg.FilterRule{}, "IPProto"): false,
@ -353,6 +355,7 @@ func TestRedactNetmapPrivateKeys(t *testing.T) {
f(tailcfg.Node{}, "DiscoKey"): false,
f(tailcfg.Node{}, "Endpoints"): false,
f(tailcfg.Node{}, "ExitNodeDNSResolvers"): false,
f(tailcfg.Node{}, "ExtraCapMap"): false,
f(tailcfg.Node{}, "Expired"): false,
f(tailcfg.Node{}, "HomeDERP"): false,
f(tailcfg.Node{}, "Hostinfo"): false,

@ -153,6 +153,39 @@ func (em *expiryManager) flagExpiredPeers(netmap *netmap.NetworkMap, localNow ti
}
}
// expireNodeCaps drops entries from the ExtraCapMap of the self node and of
// every peer whose Expiry is non-zero and in the past, using the
// delta-adjusted control time.
func (em *expiryManager) expireNodeCaps(netmap *netmap.NetworkMap, localNow time.Time) {
controlNow := localNow.Add(em.clockDelta.Load())
if controlNow.Before(flagExpiredPeersEpoch) {
em.logf("netmap: expireNodeCaps: [unexpected] delta-adjusted current time is before hardcoded epoch; skipping")
return
}
expireCaps := func(n *tailcfg.Node) (changed bool) {
if len(n.ExtraCapMap) == 0 {
return false
}
for capName, cap := range n.ExtraCapMap {
if !cap.Expiry.IsZero() && cap.Expiry.Before(controlNow) {
delete(n.ExtraCapMap, capName)
changed = true
}
}
return changed
}
if netmap.SelfNode.Valid() {
// TODO(anton): don't clone if there's nothing to change.
self := netmap.SelfNode.AsStruct()
if expireCaps(self) {
netmap.SelfNode = self.View()
}
}
for i, peer := range netmap.Peers {
p := peer.AsStruct()
if expireCaps(p) {
netmap.Peers[i] = p.View()
}
}
}
// nextPeerExpiry returns the time that the next node in the netmap expires
// (including the self node), based on their KeyExpiry. It skips nodes that are
// already marked as Expired. If there are no nodes expiring in the future,
@ -174,43 +207,42 @@ func (em *expiryManager) nextPeerExpiry(nm *netmap.NetworkMap, localNow time.Tim
}
var nextExpiry time.Time // zero if none
for _, peer := range nm.Peers {
if peer.KeyExpiry().IsZero() {
continue // tagged node
} else if peer.Expired() {
// Peer already expired; Expired is set by the
// flagExpiredPeers function, above.
continue
} else if peer.KeyExpiry().Before(controlNow) {
// This peer already expired, and peer.Expired
// isn't set for some reason. Skip this node.
continue
update := func(expiry time.Time) {
if expiry.IsZero() {
return
}
// nextExpiry being zero is a sentinel that we haven't yet set
// an expiry; otherwise, only update if this node's expiry is
// sooner than the currently-stored one (since we want the
// soonest-occurring expiry time).
if nextExpiry.IsZero() || peer.KeyExpiry().Before(nextExpiry) {
nextExpiry = peer.KeyExpiry()
if nextExpiry.IsZero() || expiry.Before(nextExpiry) {
nextExpiry = expiry
}
}
handleNode := func(n tailcfg.NodeView) {
if n.KeyExpiry().IsZero() {
// tagged node
} else if n.Expired() {
// Already expired; Expired is set by the
// flagExpiredPeers function, above.
} else if n.KeyExpiry().Before(controlNow) {
// Already expired, but Expired
// isn't set for some reason. Skip it.
} else {
update(n.KeyExpiry())
}
// Also handle expiring caps.
for _, c := range n.ExtraCapMap().All() {
update(c.Expiry())
}
}
for _, peer := range nm.Peers {
handleNode(peer)
}
// Ensure that we also fire this timer if our own node key expires.
if nm.SelfNode.Valid() {
selfExpiry := nm.SelfNode.KeyExpiry()
if selfExpiry.IsZero() {
// No expiry for self node
} else if selfExpiry.Before(controlNow) {
// Self node already expired; we don't want to return a
// time in the past, so skip this.
} else if nextExpiry.IsZero() || selfExpiry.Before(nextExpiry) {
// Self node expires after now, but before the soonest
// peer in the netmap; update our next expiry to this
// time.
nextExpiry = selfExpiry
}
handleNode(nm.SelfNode)
}
// As an additional defense in depth, never return a time that is

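For reference, the pruning rule that the expireCaps closure applies above can be exercised in isolation. The following is a minimal, hedged sketch; the names capValue and pruneExpired are illustrative and not part of this commit:

package main

import (
	"fmt"
	"time"
)

type capValue struct {
	Expiry time.Time // zero means the capability never expires
}

// pruneExpired deletes entries whose expiry is non-zero and in the past,
// and reports whether the map was modified.
func pruneExpired(caps map[string]capValue, now time.Time) bool {
	changed := false
	for name, v := range caps {
		if !v.Expiry.IsZero() && v.Expiry.Before(now) {
			delete(caps, name)
			changed = true
		}
	}
	return changed
}

func main() {
	now := time.Now()
	caps := map[string]capValue{
		"cap/expired":  {Expiry: now.Add(-time.Minute)},
		"cap/forever":  {}, // no expiry, never pruned
		"cap/upcoming": {Expiry: now.Add(time.Hour)},
	}
	fmt.Println(pruneExpired(caps, now)) // true
	fmt.Println(len(caps))               // 2
}
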
@ -15,6 +15,7 @@ import (
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/util/eventbus/eventbustest"
"tailscale.com/util/mak"
)
func TestFlagExpiredPeers(t *testing.T) {
@ -238,6 +239,32 @@ func TestNextPeerExpiry(t *testing.T) {
},
want: noExpiry,
},
{
name: "self_attribute",
netmap: &netmap.NetworkMap{
Peers: nodeViews([]*tailcfg.Node{
n(1, "foo", timeInMoreFuture, func(n *tailcfg.Node) {
mak.Set(&n.ExtraCapMap, "foo", tailcfg.ExtraCapMapValue{Expiry: timeInMoreFuture})
}),
}),
SelfNode: n(2, "self", noExpiry, func(n *tailcfg.Node) {
mak.Set(&n.ExtraCapMap, "foo", tailcfg.ExtraCapMapValue{Expiry: timeInFuture})
}).View(),
},
want: timeInFuture,
},
{
name: "peer_attribute",
netmap: &netmap.NetworkMap{
Peers: nodeViews([]*tailcfg.Node{
n(1, "foo", timeInMoreFuture, func(n *tailcfg.Node) {
mak.Set(&n.ExtraCapMap, "foo", tailcfg.ExtraCapMapValue{Expiry: timeInFuture})
}),
}),
SelfNode: n(2, "self", timeInMoreFuture).View(),
},
want: timeInFuture,
},
}
for _, tt := range tests {

@ -1575,6 +1575,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control
if st.NetMap != nil {
now := b.clock.Now()
b.em.flagExpiredPeers(st.NetMap, now)
b.em.expireNodeCaps(st.NetMap, now)
// Always stop the existing netmap timer if we have a netmap;
// it's possible that we have no nodes expiring, so we should

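The ordering above matters: expireNodeCaps runs right after flagExpiredPeers on every incoming netmap, and nextPeerExpiry (used further down to arm the expiry timer) now also considers cap expiries, so the netmap is reprocessed when the soonest cap lapses. A simplified sketch of that loop; the timer plumbing is hypothetical and stands in for the existing netmap-expiry handling in LocalBackend:

now := b.clock.Now()
b.em.flagExpiredPeers(st.NetMap, now) // mark nodes whose key expiry has passed
b.em.expireNodeCaps(st.NetMap, now)   // drop ExtraCapMap entries whose Expiry has passed
// ...netmap applied; packet filter rebuilt without the expired caps...
if next := b.em.nextPeerExpiry(st.NetMap, now); !next.IsZero() {
	// Re-arm a timer for the soonest key *or* cap expiry, so the netmap is
	// reprocessed (and the filter rebuilt) as soon as something lapses.
	rearmExpiryTimer(next.Sub(now)) // hypothetical helper
}
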
@ -5,7 +5,7 @@
// the node and the coordination server.
package tailcfg
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService --clonefunc
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService,ExtraCapMapValue --clonefunc
import (
"bytes"
@ -442,6 +442,8 @@ type Node struct {
// for a particular task vs other peers that could also be chosen.
CapMap NodeCapMap `json:",omitempty"`
ExtraCapMap ExtraCapMap `json:",omitempty"`
// UnsignedPeerAPIOnly means that this node is not signed nor subject to TKA
// restrictions. However, in exchange for that privilege, it does not get
// network access. It can only access this node's peerapi, which may not let
@ -522,7 +524,16 @@ func (v NodeView) HasCap(cap NodeCapability) bool {
// HasCap reports whether the node has the given capability.
// It is safe to call on a nil Node.
func (v *Node) HasCap(cap NodeCapability) bool {
return v != nil && v.CapMap.Contains(cap)
if v == nil {
return false
}
if v.CapMap.Contains(cap) {
return true
}
if v.ExtraCapMap.Contains(cap) {
return true
}
return false
}
// DisplayName returns the user-facing name for a node which should
@ -1548,6 +1559,23 @@ const (
PeerCapabilityTsIDP PeerCapability = "tailscale.com/cap/tsidp"
)
// ExtraCapMap maps capability names to values that carry extra metadata,
// currently an optional expiry time.
type ExtraCapMap map[NodeCapability]ExtraCapMapValue
// ExtraCapMapValue is the value of an ExtraCapMap entry.
type ExtraCapMapValue struct {
Expiry time.Time // optional; the zero value means the capability never expires
Value []RawMessage // optional raw JSON values, as in NodeCapMap
}
// Contains reports whether c includes the capability cap, ignoring expiry.
func (c ExtraCapMap) Contains(cap NodeCapability) bool {
_, ok := c[cap]
return ok
}
// Equal reports whether c and c2 hold the same capabilities with equal
// expiry times and values.
func (c ExtraCapMap) Equal(c2 ExtraCapMap) bool {
return maps.EqualFunc(c, c2, func(v1, v2 ExtraCapMapValue) bool {
return v1.Expiry.Equal(v2.Expiry) && slices.Equal(v1.Value, v2.Value)
})
}
// NodeCapMap is a map of capabilities to their optional values. It is valid for
// a capability to have no values (nil slice); such capabilities can be tested
// for by using the [NodeCapMap.Contains] method.
@ -2348,6 +2376,7 @@ func (n *Node) Equal(n2 *Node) bool {
n.MachineAuthorized == n2.MachineAuthorized &&
slices.Equal(n.Capabilities, n2.Capabilities) &&
n.CapMap.Equal(n2.CapMap) &&
n.ExtraCapMap.Equal(n2.ExtraCapMap) &&
n.ComputedName == n2.ComputedName &&
n.computedHostIfDifferent == n2.computedHostIfDifferent &&
n.ComputedNameWithHost == n2.ComputedNameWithHost &&

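As a usage sketch (the capability names and JSON payload below are made up for illustration), a coordination server can attach an expiring capability alongside the ordinary CapMap, and HasCap treats both maps uniformly until expireNodeCaps prunes the expired entry:

package example

import (
	"time"

	"tailscale.com/tailcfg"
)

func expiringCapNode(now time.Time) *tailcfg.Node {
	return &tailcfg.Node{
		CapMap: tailcfg.NodeCapMap{
			"example.com/cap/always": nil, // ordinary, non-expiring capability
		},
		ExtraCapMap: tailcfg.ExtraCapMap{
			"example.com/cap/on-call-ssh": tailcfg.ExtraCapMapValue{
				Expiry: now.Add(30 * time.Minute), // pruned by expireNodeCaps after this
				Value:  []tailcfg.RawMessage{`{"reason":"on-call"}`},
			},
		},
	}
}

// n.HasCap("example.com/cap/on-call-ssh") reports true until the entry expires
// and expireNodeCaps removes it from the netmap.
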
@ -65,6 +65,12 @@ func (src *Node) Clone() *Node {
dst.CapMap[k] = append([]RawMessage{}, src.CapMap[k]...)
}
}
if dst.ExtraCapMap != nil {
dst.ExtraCapMap = map[NodeCapability]ExtraCapMapValue{}
for k, v := range src.ExtraCapMap {
dst.ExtraCapMap[k] = *(v.Clone())
}
}
if dst.SelfNodeV4MasqAddrForThisPeer != nil {
dst.SelfNodeV4MasqAddrForThisPeer = ptr.To(*src.SelfNodeV4MasqAddrForThisPeer)
}
@ -111,6 +117,7 @@ var _NodeCloneNeedsRegeneration = Node(struct {
MachineAuthorized bool
Capabilities []NodeCapability
CapMap NodeCapMap
ExtraCapMap ExtraCapMap
UnsignedPeerAPIOnly bool
ComputedName string
computedHostIfDifferent string
@ -652,9 +659,27 @@ var _VIPServiceCloneNeedsRegeneration = VIPService(struct {
Active bool
}{})
// Clone makes a deep copy of ExtraCapMapValue.
// The result aliases no memory with the original.
func (src *ExtraCapMapValue) Clone() *ExtraCapMapValue {
if src == nil {
return nil
}
dst := new(ExtraCapMapValue)
*dst = *src
dst.Value = append(src.Value[:0:0], src.Value...)
return dst
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _ExtraCapMapValueCloneNeedsRegeneration = ExtraCapMapValue(struct {
Expiry time.Time
Value []RawMessage
}{})
// Clone duplicates src into dst and reports whether it succeeded.
// To succeed, <src, dst> must be of types <*T, *T> or <*T, **T>,
// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService.
// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService,ExtraCapMapValue.
func Clone(dst, src any) bool {
switch src := src.(type) {
case *User:
@ -837,6 +862,15 @@ func Clone(dst, src any) bool {
*dst = src.Clone()
return true
}
case *ExtraCapMapValue:
switch dst := dst.(type) {
case *ExtraCapMapValue:
*dst = *src.Clone()
return true
case **ExtraCapMapValue:
*dst = src.Clone()
return true
}
}
return false
}

@ -21,7 +21,7 @@ import (
"tailscale.com/types/views"
)
//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService
//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService,ExtraCapMapValue
// View returns a read-only view of User.
func (p *User) View() UserView {
@ -302,6 +302,11 @@ func (v NodeView) Capabilities() views.Slice[NodeCapability] { return views.Slic
func (v NodeView) CapMap() views.MapSlice[NodeCapability, RawMessage] {
return views.MapSliceOf(v.ж.CapMap)
}
func (v NodeView) ExtraCapMap() views.MapFn[NodeCapability, ExtraCapMapValue, ExtraCapMapValueView] {
return views.MapFnOf(v.ж.ExtraCapMap, func(t ExtraCapMapValue) ExtraCapMapValueView {
return t.View()
})
}
// UnsignedPeerAPIOnly means that this node is not signed nor subject to TKA
// restrictions. However, in exchange for that privilege, it does not get
@ -404,6 +409,7 @@ var _NodeViewNeedsRegeneration = Node(struct {
MachineAuthorized bool
Capabilities []NodeCapability
CapMap NodeCapMap
ExtraCapMap ExtraCapMap
UnsignedPeerAPIOnly bool
ComputedName string
computedHostIfDifferent string
@ -2606,3 +2612,79 @@ var _VIPServiceViewNeedsRegeneration = VIPService(struct {
Ports []ProtoPortRange
Active bool
}{})
// View returns a read-only view of ExtraCapMapValue.
func (p *ExtraCapMapValue) View() ExtraCapMapValueView {
return ExtraCapMapValueView{ж: p}
}
// ExtraCapMapValueView provides a read-only view over ExtraCapMapValue.
//
// Its methods should only be called if `Valid()` returns true.
type ExtraCapMapValueView struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *ExtraCapMapValue
}
// Valid reports whether v's underlying value is non-nil.
func (v ExtraCapMapValueView) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v ExtraCapMapValueView) AsStruct() *ExtraCapMapValue {
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
// MarshalJSON implements [jsonv1.Marshaler].
func (v ExtraCapMapValueView) MarshalJSON() ([]byte, error) {
return jsonv1.Marshal(v.ж)
}
// MarshalJSONTo implements [jsonv2.MarshalerTo].
func (v ExtraCapMapValueView) MarshalJSONTo(enc *jsontext.Encoder) error {
return jsonv2.MarshalEncode(enc, v.ж)
}
// UnmarshalJSON implements [jsonv1.Unmarshaler].
func (v *ExtraCapMapValueView) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x ExtraCapMapValue
if err := jsonv1.Unmarshal(b, &x); err != nil {
return err
}
v.ж = &x
return nil
}
// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom].
func (v *ExtraCapMapValueView) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
if v.ж != nil {
return errors.New("already initialized")
}
var x ExtraCapMapValue
if err := jsonv2.UnmarshalDecode(dec, &x); err != nil {
return err
}
v.ж = &x
return nil
}
func (v ExtraCapMapValueView) Expiry() time.Time { return v.ж.Expiry }
func (v ExtraCapMapValueView) Value() views.Slice[RawMessage] { return views.SliceOf(v.ж.Value) }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _ExtraCapMapValueViewNeedsRegeneration = ExtraCapMapValue(struct {
Expiry time.Time
Value []RawMessage
}{})

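Callers that only hold a NodeView can inspect expiring caps through the generated accessor; a small sketch (the helper name soonestCapExpiry is hypothetical), following the same iteration pattern nextPeerExpiry uses above:

// soonestCapExpiry returns the earliest non-zero Expiry among the node's
// extra caps, or the zero time if none of them expire.
func soonestCapExpiry(n tailcfg.NodeView) (soonest time.Time) {
	for _, v := range n.ExtraCapMap().All() {
		e := v.Expiry()
		if e.IsZero() {
			continue
		}
		if soonest.IsZero() || e.Before(soonest) {
			soonest = e
		}
	}
	return soonest
}
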
@ -8,6 +8,7 @@ import (
"fmt"
"net/netip"
"testing"
"testing/synctest"
"time"
"tailscale.com/ipn"
@ -180,56 +181,83 @@ func TestPacketFilterFromNetmap(t *testing.T) {
{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
},
},
{
name: "capmap_based_peers_with_expiry",
mapResponse: &tailcfg.MapResponse{
Node: &tailcfg.Node{
Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")},
},
Peers: []*tailcfg.Node{{
ID: 2,
Name: "foo",
Key: key,
Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")},
ExtraCapMap: tailcfg.ExtraCapMap{"X": {
Expiry: time.Now().Add(1 * time.Minute),
}},
}},
PacketFilter: []tailcfg.FilterRule{{
SrcIPs: []string{"cap:X"},
DstPorts: []tailcfg.NetPortRange{{
IP: "1.1.1.1/32",
Ports: tailcfg.PortRange{
First: 22,
Last: 22,
},
}},
IPProto: []int{int(ipproto.TCP)},
}},
},
waitTest: func(nm *netmap.NetworkMap) bool {
return len(nm.Peers) > 0
},
checks: []check{
{src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept},
{src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
{src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
},
incrementalMapResponse: &tailcfg.MapResponse{
PeersChanged: []*tailcfg.Node{{
ID: 2,
Name: "foo",
Key: key,
Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.3/32")},
}},
},
incrementalWaitTest: func(nm *netmap.NetworkMap) bool {
// Wait until the peer's extra cap expires.
time.Sleep(time.Minute)
if len(nm.Peers) == 0 {
return false
}
peer := nm.Peers[0]
if peer.Addresses().AsSlice()[0].Addr() != netip.MustParseAddr("2.2.2.3") {
return false
}
return true
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
defer cancel()
controlURL, c := startControl(t)
s, _, pubKey := startServer(t, ctx, controlURL, "node")
if test.waitTest(s.lb.NetMap()) {
t.Fatal("waitTest already passes before sending initial netmap: this will be flaky")
}
if !c.AddRawMapResponse(pubKey, test.mapResponse) {
t.Fatalf("could not send map response to %s", pubKey)
}
if err := waitFor(t, ctx, s, test.waitTest); err != nil {
t.Fatalf("waitFor: %s", err)
}
pf := s.lb.GetFilterForTest()
for _, check := range test.checks {
got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP)
want := check.want
if test.incrementalMapResponse != nil {
want = filter.Drop
}
if got != want {
t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, want)
}
}
synctest.Test(t, func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second)
defer cancel()
if test.incrementalMapResponse != nil {
if test.incrementalWaitTest == nil {
t.Fatal("incrementalWaitTest must be set if incrementalMapResponse is set")
}
controlURL, c := startControl(t)
s, _, pubKey := startServer(t, ctx, controlURL, "node")
if test.incrementalWaitTest(s.lb.NetMap()) {
t.Fatal("incrementalWaitTest already passes before sending incremental netmap: this will be flaky")
if test.waitTest(s.lb.NetMap()) {
t.Fatal("waitTest already passes before sending initial netmap: this will be flaky")
}
if !c.AddRawMapResponse(pubKey, test.incrementalMapResponse) {
if !c.AddRawMapResponse(pubKey, test.mapResponse) {
t.Fatalf("could not send map response to %s", pubKey)
}
if err := waitFor(t, ctx, s, test.incrementalWaitTest); err != nil {
if err := waitFor(t, ctx, s, test.waitTest); err != nil {
t.Fatalf("waitFor: %s", err)
}
@ -237,12 +265,44 @@ func TestPacketFilterFromNetmap(t *testing.T) {
for _, check := range test.checks {
got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP)
if got != check.want {
t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, check.want)
want := check.want
if test.incrementalMapResponse != nil {
want = filter.Drop
}
if got != want {
t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, want)
}
}
if test.incrementalMapResponse != nil {
if test.incrementalWaitTest == nil {
t.Fatal("incrementalWaitTest must be set if incrementalMapResponse is set")
}
if test.incrementalWaitTest(s.lb.NetMap()) {
t.Fatal("incrementalWaitTest already passes before sending incremental netmap: this will be flaky")
}
if !c.AddRawMapResponse(pubKey, test.incrementalMapResponse) {
t.Fatalf("could not send map response to %s", pubKey)
}
if err := waitFor(t, ctx, s, test.incrementalWaitTest); err != nil {
t.Fatalf("waitFor: %s", err)
}
pf := s.lb.GetFilterForTest()
for _, check := range test.checks {
got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP)
if got != check.want {
t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, check.want)
}
}
}
}
})
})
}
}

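Wrapping the test body in testing/synctest is what makes the one-minute wait in incrementalWaitTest practical: inside synctest.Test, goroutines run against a fake clock, so time.Sleep advances virtual time once everything in the bubble is durably blocked instead of stalling the test for real. A standalone illustration (requires a Go toolchain with synctest.Test, i.e. Go 1.25+):

package example

import (
	"testing"
	"testing/synctest"
	"time"
)

func TestVirtualTime(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		start := time.Now()
		time.Sleep(time.Minute) // completes instantly in wall-clock terms
		if elapsed := time.Since(start); elapsed != time.Minute {
			t.Fatalf("virtual elapsed = %v, want 1m", elapsed)
		}
	})
}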