net/tstun: use gaissmai/bart instead of tempfork/device

This implementation uses less memory than tempfork/device,
which helps avoid OOM conditions in the iOS VPN extension when
switching to a Tailnet with ExitNode routing enabled.

Updates tailscale/corp#18514

Signed-off-by: Percy Wegmann <percy@tailscale.com>
pull/11516/head
Percy Wegmann 8 months ago committed by Percy Wegmann
parent 1e7050e73a
commit 8b8b315258

@ -78,6 +78,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+
L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http
L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm
github.com/bits-and-blooms/bitset from github.com/gaissmai/bart
L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw
L github.com/coreos/go-systemd/v22/dbus from tailscale.com/clientupdate L github.com/coreos/go-systemd/v22/dbus from tailscale.com/clientupdate
LD 💣 github.com/creack/pty from tailscale.com/ssh/tailssh LD 💣 github.com/creack/pty from tailscale.com/ssh/tailssh
@ -89,6 +90,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
LW 💣 github.com/digitalocean/go-smbios/smbios from tailscale.com/posture LW 💣 github.com/digitalocean/go-smbios/smbios from tailscale.com/posture
💣 github.com/djherbis/times from tailscale.com/tailfs/tailfsimpl 💣 github.com/djherbis/times from tailscale.com/tailfs/tailfsimpl
github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/fxamacker/cbor/v2 from tailscale.com/tka
github.com/gaissmai/bart from tailscale.com/net/tstun
W 💣 github.com/go-ole/go-ole from github.com/go-ole/go-ole/oleutil+ W 💣 github.com/go-ole/go-ole from github.com/go-ole/go-ole/oleutil+
W 💣 github.com/go-ole/go-ole/oleutil from tailscale.com/wgengine/winnet W 💣 github.com/go-ole/go-ole/oleutil from tailscale.com/wgengine/winnet
L 💣 github.com/godbus/dbus/v5 from github.com/coreos/go-systemd/v22/dbus+ L 💣 github.com/godbus/dbus/v5 from github.com/coreos/go-systemd/v22/dbus+
@ -308,7 +310,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+
💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+
tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+
tailscale.com/net/tstun/table from tailscale.com/net/tstun
tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ tailscale.com/net/wsconn from tailscale.com/control/controlhttp+
tailscale.com/paths from tailscale.com/client/tailscale+ tailscale.com/paths from tailscale.com/client/tailscale+
💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal
@ -324,7 +325,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/tailfs/tailfsimpl/compositedav from tailscale.com/tailfs/tailfsimpl tailscale.com/tailfs/tailfsimpl/compositedav from tailscale.com/tailfs/tailfsimpl
tailscale.com/tailfs/tailfsimpl/dirfs from tailscale.com/tailfs/tailfsimpl+ tailscale.com/tailfs/tailfsimpl/dirfs from tailscale.com/tailfs/tailfsimpl+
tailscale.com/tailfs/tailfsimpl/shared from tailscale.com/tailfs/tailfsimpl+ tailscale.com/tailfs/tailfsimpl/shared from tailscale.com/tailfs/tailfsimpl+
💣 tailscale.com/tempfork/device from tailscale.com/net/tstun/table
LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh
tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock
tailscale.com/tka from tailscale.com/client/tailscale+ tailscale.com/tka from tailscale.com/client/tailscale+

@ -28,6 +28,7 @@ require (
github.com/evanw/esbuild v0.19.11 github.com/evanw/esbuild v0.19.11
github.com/frankban/quicktest v1.14.6 github.com/frankban/quicktest v1.14.6
github.com/fxamacker/cbor/v2 v2.5.0 github.com/fxamacker/cbor/v2 v2.5.0
github.com/gaissmai/bart v0.4.0
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0
github.com/go-logr/zapr v1.3.0 github.com/go-logr/zapr v1.3.0
github.com/go-ole/go-ole v1.3.0 github.com/go-ole/go-ole v1.3.0
@ -115,6 +116,7 @@ require (
require ( require (
github.com/Microsoft/go-winio v0.6.1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect
github.com/bits-and-blooms/bitset v1.13.0 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14 // indirect github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14 // indirect
github.com/dave/brenda v1.1.0 // indirect github.com/dave/brenda v1.1.0 // indirect

@ -167,6 +167,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/bkielbasa/cyclop v1.2.0 h1:7Jmnh0yL2DjKfw28p86YTd/B4lRGcNuu12sKE35sM7A= github.com/bkielbasa/cyclop v1.2.0 h1:7Jmnh0yL2DjKfw28p86YTd/B4lRGcNuu12sKE35sM7A=
github.com/bkielbasa/cyclop v1.2.0/go.mod h1:qOI0yy6A7dYC4Zgsa72Ppm9kONl0RoIlPbzot9mhmeI= github.com/bkielbasa/cyclop v1.2.0/go.mod h1:qOI0yy6A7dYC4Zgsa72Ppm9kONl0RoIlPbzot9mhmeI=
github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb h1:m935MPodAbYS46DG4pJSv7WO+VECIWUQ7OJYSoTrMh4= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb h1:m935MPodAbYS46DG4pJSv7WO+VECIWUQ7OJYSoTrMh4=
@ -294,6 +296,8 @@ github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADi
github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo=
github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo=
github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
github.com/gaissmai/bart v0.4.0 h1:ImIFoETsNMBzUr21tMGD82GQIwAb555fI6uxEyCHBTI=
github.com/gaissmai/bart v0.4.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I= github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I=
github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo= github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo=
github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY=

@ -111,5 +111,5 @@ Some packages may only be included on certain architectures or operating systems
- [sigs.k8s.io/yaml/goyaml.v2](https://pkg.go.dev/sigs.k8s.io/yaml/goyaml.v2) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/goyaml.v2/LICENSE)) - [sigs.k8s.io/yaml/goyaml.v2](https://pkg.go.dev/sigs.k8s.io/yaml/goyaml.v2) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/goyaml.v2/LICENSE))
- [software.sslmate.com/src/go-pkcs12](https://pkg.go.dev/software.sslmate.com/src/go-pkcs12) ([BSD-3-Clause](https://github.com/SSLMate/go-pkcs12/blob/v0.4.0/LICENSE)) - [software.sslmate.com/src/go-pkcs12](https://pkg.go.dev/software.sslmate.com/src/go-pkcs12) ([BSD-3-Clause](https://github.com/SSLMate/go-pkcs12/blob/v0.4.0/LICENSE))
- [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE))
- [tailscale.com/tempfork/device](https://pkg.go.dev/tailscale.com/tempfork/device) ([MIT](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/device/LICENSE)) - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart?tab=MIT-1-ov-file#readme))
- [tailscale.com/tempfork/gliderlabs/ssh](https://pkg.go.dev/tailscale.com/tempfork/gliderlabs/ssh) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/gliderlabs/ssh/LICENSE)) - [tailscale.com/tempfork/gliderlabs/ssh](https://pkg.go.dev/tailscale.com/tempfork/gliderlabs/ssh) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/gliderlabs/ssh/LICENSE))

@ -1,90 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package table provides a Routing Table implementation which allows
// looking up the peer that should be used to route a given IP address.
package table
import (
"net/netip"
"tailscale.com/tempfork/device"
"tailscale.com/types/key"
"tailscale.com/util/mak"
)
// RoutingTableBuilder is a builder for a RoutingTable.
// It is not safe for concurrent use.
type RoutingTableBuilder struct {
// peers is a map from node public key to the peer that owns that key.
// It is only used to handle insertions and deletions.
peers map[key.NodePublic]*device.Peer
// prefixTrie is a trie of prefixes which facilitates looking up the
// peer that owns a given IP address.
prefixTrie *device.AllowedIPs
}
// Remove removes the given peer from the routing table.
func (t *RoutingTableBuilder) Remove(peer key.NodePublic) {
p, ok := t.peers[peer]
if !ok {
return
}
t.prefixTrie.RemoveByPeer(p)
delete(t.peers, peer)
}
// InsertOrReplace inserts the given peer and prefixes into the routing table.
func (t *RoutingTableBuilder) InsertOrReplace(peer key.NodePublic, pfxs ...netip.Prefix) {
p, ok := t.peers[peer]
if !ok {
p = device.NewPeer(peer)
mak.Set(&t.peers, peer, p)
} else {
t.prefixTrie.RemoveByPeer(p)
}
if len(pfxs) == 0 {
return
}
if t.prefixTrie == nil {
t.prefixTrie = new(device.AllowedIPs)
}
for _, pfx := range pfxs {
t.prefixTrie.Insert(pfx, p)
}
}
// Build returns a RoutingTable that can be used to look up peers.
// Build resets the RoutingTableBuilder to its zero value.
func (t *RoutingTableBuilder) Build() *RoutingTable {
pt := t.prefixTrie
t.prefixTrie = nil
t.peers = nil
return &RoutingTable{
prefixTrie: pt,
}
}
// RoutingTable provides a mapping from IP addresses to peers identified by
// their public node key. It is used to find the peer that should be used to
// route a given IP address.
// It is immutable after creation.
//
// It is safe for concurrent use.
type RoutingTable struct {
prefixTrie *device.AllowedIPs
}
// Lookup returns the peer that would be used to route the given IP address.
// If no peer is found, Lookup returns the zero value.
func (t *RoutingTable) Lookup(ip netip.Addr) (_ key.NodePublic, ok bool) {
if t == nil {
return key.NodePublic{}, false
}
p := t.prefixTrie.Lookup(ip.AsSlice())
if p == nil {
return key.NodePublic{}, false
}
return p.Key(), true
}

@ -18,6 +18,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gaissmai/bart"
"github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun" "github.com/tailscale/wireguard-go/tun"
"go4.org/mem" "go4.org/mem"
@ -27,7 +28,6 @@ import (
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/net/packet/checksum" "tailscale.com/net/packet/checksum"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/net/tstun/table"
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/tstime/mono" "tailscale.com/tstime/mono"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
@ -611,15 +611,13 @@ type natFamilyConfig struct {
// peers will use to connect to this node. // peers will use to connect to this node.
listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{} listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{}
// dstMasqAddrs is map of dst addresses to their respective MasqueradeAsIP // dstMasqAddrs is the routing table used to map a given dst IP to the
// addresses. The MasqueradeAsIP address is the address that should be used // respective MasqueradeAsIP address. The MasqueradeAsIP address is the
// as the source address for packets to dst. // address that should be used as the source address for packets to dst.
dstMasqAddrs views.Map[key.NodePublic, netip.Addr] // dst -> masqAddr dstMasqAddrs *bart.Table[netip.Addr]
// dstAddrToPeerKeyMapper is the routing table used to map a given dst IP to // masqAddrCounts is a count of peers by MasqueradeAsIP.
// the peer key responsible for that IP. masqAddrCounts map[netip.Addr]int
// It only contains peers that require a MasqueradeAsIP address.
dstAddrToPeerKeyMapper *table.RoutingTable
} }
func (c *natFamilyConfig) String() string { func (c *natFamilyConfig) String() string {
@ -640,15 +638,10 @@ func (c *natFamilyConfig) String() string {
i++ i++
return true return true
}) })
count := map[netip.Addr]int{}
c.dstMasqAddrs.Range(func(_ key.NodePublic, v netip.Addr) bool {
count[v]++
return true
})
i = 0 i = 0
b.WriteString("], dstMasqAddrs: [") b.WriteString("], dstMasqAddrs: [")
for k, v := range count { for k, v := range c.masqAddrCounts {
if i > 0 { if i > 0 {
b.WriteString(", ") b.WriteString(", ")
} }
@ -682,14 +675,11 @@ func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if oldSrc != c.nativeAddr { if oldSrc != c.nativeAddr {
return oldSrc return oldSrc
} }
p, ok := c.dstAddrToPeerKeyMapper.Lookup(dst) eip, ok := c.dstMasqAddrs.Get(dst)
if !ok { if !ok {
return oldSrc return oldSrc
} }
if eip, ok := c.dstMasqAddrs.GetOk(p); ok { return eip
return eip
}
return oldSrc
} }
// natConfigFromWGConfig generates a natFamilyConfig from nm, // natConfigFromWGConfig generates a natFamilyConfig from nm,
@ -712,9 +702,9 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
} }
var ( var (
rt table.RoutingTableBuilder rt bart.Table[netip.Addr]
dstMasqAddrs map[key.NodePublic]netip.Addr masqAddrCounts = map[netip.Addr]int{}
listenAddrs set.Set[netip.Addr] listenAddrs set.Set[netip.Addr]
) )
// When using an exit node that requires masquerading, we need to // When using an exit node that requires masquerading, we need to
@ -747,17 +737,20 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
} else { } else {
continue continue
} }
rt.InsertOrReplace(p.PublicKey, p.AllowedIPs...)
mak.Set(&dstMasqAddrs, p.PublicKey, addrToUse) masqAddrCounts[addrToUse]++
for _, ip := range p.AllowedIPs {
rt.Insert(ip, addrToUse)
}
} }
if len(listenAddrs) == 0 && len(dstMasqAddrs) == 0 { if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 {
return nil return nil
} }
return &natFamilyConfig{ return &natFamilyConfig{
nativeAddr: nativeAddr, nativeAddr: nativeAddr,
listenAddrs: views.MapOf(listenAddrs), listenAddrs: views.MapOf(listenAddrs),
dstMasqAddrs: views.MapOf(dstMasqAddrs), dstMasqAddrs: &rt,
dstAddrToPeerKeyMapper: rt.Build(), masqAddrCounts: masqAddrCounts,
} }
} }

@ -1,17 +0,0 @@
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -1,3 +0,0 @@
This is a fork of golang.zx2c4.com/wireguard/device that only keeps the bare
minimum data structures required for AllowedIPs. It is meant to be short lived
until we replace it with our version of a routing table.

@ -1,294 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
"encoding/binary"
"errors"
"math/bits"
"net"
"net/netip"
"sync"
"unsafe"
)
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct {
peer *Peer
child [2]*trieEntry
parent parentIndirection
cidr uint8
bitAtByte uint8
bitAtShift uint8
bits []byte
perPeerElem *list.Element
}
func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := binary.BigEndian.Uint32(ip1)
b := binary.BigEndian.Uint32(ip2)
x := a ^ b
return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
a := binary.BigEndian.Uint64(ip1)
b := binary.BigEndian.Uint64(ip2)
x := a ^ b
if x != 0 {
return uint8(bits.LeadingZeros64(x))
}
a = binary.BigEndian.Uint64(ip1[8:])
b = binary.BigEndian.Uint64(ip2[8:])
x = a ^ b
return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
}
func (node *trieEntry) addToPeerEntries() {
node.perPeerElem = node.peer.trieEntries.PushBack(node)
}
func (node *trieEntry) removeFromPeerEntries() {
if node.perPeerElem != nil {
node.peer.trieEntries.Remove(node.perPeerElem)
node.perPeerElem = nil
}
}
func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
func (node *trieEntry) maskSelf() {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
for i := 0; i < len(mask); i++ {
node.bits[i] &= mask[i]
}
}
func (node *trieEntry) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
node.child[1] = nil
node.parent.parentBit = nil
}
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
parent: trie,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
node.addToPeerEntries()
*trie.parentBit = node
return
}
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
if exact {
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return
}
newNode := &trieEntry{
peer: peer,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
newNode.maskSelf()
newNode.addToPeerEntries()
var down *trieEntry
if node == nil {
down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
parent := node
if newNode.cidr == cidr {
bit := newNode.choose(down.bits)
down.parent = parentIndirection{&newNode.child[bit], bit}
newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
}
node = &trieEntry{
bits: append([]byte{}, newNode.bits...),
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
bit := node.choose(down.bits)
down.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = down
bit = node.choose(newNode.bits)
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
}
}
func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
if node.bitAtByte == size {
break
}
bit := node.choose(ip)
node = node.child[bit]
}
return found
}
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
a, _ := netip.AddrFromSlice(node.bits)
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
}
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
if prefix.Addr().Is6() {
ip := prefix.Addr().As16()
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() {
ip := prefix.Addr().As4()
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type"))
}
}
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
switch len(ip) {
case net.IPv6len:
return table.IPv6.lookup(ip)
case net.IPv4len:
return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
}

@ -1,141 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"net"
"net/netip"
"sort"
"testing"
)
const (
NumberOfPeers = 100
NumberOfPeerRemovals = 4
NumberOfAddresses = 250
NumberOfTests = 10000
)
type SlowNode struct {
peer *Peer
cidr uint8
bits []byte
}
type SlowRouter []*SlowNode
func (r SlowRouter) Len() int {
return len(r)
}
func (r SlowRouter) Less(i, j int) bool {
return r[i].cidr > r[j].cidr
}
func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
t.bits = addr
return r
}
}
r = append(r, &SlowNode{
cidr: cidr,
bits: addr,
peer: peer,
})
sort.Sort(r)
return r
}
func (r SlowRouter) Lookup(addr []byte) *Peer {
for _, t := range r {
common := commonBits(t.bits, addr)
if common >= t.cidr {
return t.peer
}
}
return nil
}
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
n := 0
for _, x := range r {
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1)
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
var p int
for p = 0; ; p++ {
for n := 0; n < NumberOfTests; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
peer1 := slow4.Lookup(addr4[:])
peer2 := allowedIPs.Lookup(addr4[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
}
var addr6 [16]byte
rand.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
}
if p >= len(peers) || p >= NumberOfPeerRemovals {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
for ; p < len(peers); p++ {
allowedIPs.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
}
}

@ -1,247 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"net"
"net/netip"
"testing"
)
type testPairCommonBits struct {
s1 []byte
s2 []byte
match uint8
}
func TestCommonBits(t *testing.T) {
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
{s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
{s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
{s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
}
for _, p := range tests {
v := commonBits(p.s1, p.s2)
if v != p.match {
t.Error(
"For slice", p.s1, p.s2,
"expected match", p.match,
",but got", v,
)
}
}
}
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1)
const AddressLength = 4
for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
root.insert(addr[:], cidr, peers[index])
}
for n := 0; n < b.N; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
trie.lookup(addr[:])
}
}
func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv4len, b)
}
func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv4len, b)
}
func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv6len, b)
}
func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv6len, b)
}
/* Test ported from kernel implementation:
* selftest/allowedips.h
*/
func TestTrieIPv4(t *testing.T) {
a := &Peer{}
b := &Peer{}
c := &Peer{}
d := &Peer{}
e := &Peer{}
g := &Peer{}
h := &Peer{}
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
}
insert(a, 192, 168, 4, 0, 24)
insert(b, 192, 168, 4, 4, 32)
insert(c, 192, 168, 0, 0, 16)
insert(d, 192, 95, 5, 64, 27)
insert(c, 192, 95, 5, 65, 27)
insert(e, 0, 0, 0, 0, 0)
insert(g, 64, 15, 112, 0, 20)
insert(h, 64, 15, 123, 211, 25)
insert(a, 10, 0, 0, 0, 25)
insert(b, 10, 0, 0, 128, 25)
insert(a, 10, 1, 0, 0, 30)
insert(b, 10, 1, 0, 4, 30)
insert(c, 10, 1, 0, 8, 29)
insert(d, 10, 1, 0, 16, 29)
assertEQ(a, 192, 168, 4, 20)
assertEQ(a, 192, 168, 4, 0)
assertEQ(b, 192, 168, 4, 4)
assertEQ(c, 192, 168, 200, 182)
assertEQ(c, 192, 95, 5, 68)
assertEQ(e, 192, 95, 5, 96)
assertEQ(g, 64, 15, 116, 26)
assertEQ(g, 64, 15, 127, 3)
insert(a, 1, 0, 0, 0, 32)
insert(a, 64, 0, 0, 0, 32)
insert(a, 128, 0, 0, 0, 32)
insert(a, 192, 0, 0, 0, 32)
insert(a, 255, 0, 0, 0, 32)
assertEQ(a, 1, 0, 0, 0)
assertEQ(a, 64, 0, 0, 0)
assertEQ(a, 128, 0, 0, 0)
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
assertNEQ(a, 128, 0, 0, 0)
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
allowedIPs.RemoveByPeer(a)
allowedIPs.RemoveByPeer(b)
allowedIPs.RemoveByPeer(c)
allowedIPs.RemoveByPeer(d)
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
}
/* Test ported from kernel implementation:
* selftest/allowedips.h
*/
func TestTrieIPv6(t *testing.T) {
a := &Peer{}
b := &Peer{}
c := &Peer{}
d := &Peer{}
e := &Peer{}
f := &Peer{}
g := &Peer{}
h := &Peer{}
var allowedIPs AllowedIPs
expand := func(a uint32) []byte {
var out [4]byte
out[0] = byte(a >> 24 & 0xff)
out[1] = byte(a >> 16 & 0xff)
out[2] = byte(a >> 8 & 0xff)
out[3] = byte(a & 0xff)
return out[:]
}
insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
p := allowedIPs.Lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
}
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0)
insert(f, 0, 0, 0, 0, 0)
insert(g, 0x24046800, 0, 0, 0, 32)
insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
}

@ -1,28 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
"tailscale.com/types/key"
)
type Peer struct {
trieEntries list.List
key key.NodePublic
}
func NewPeer(k key.NodePublic) *Peer {
return &Peer{
key: k,
}
}
func (p *Peer) Key() key.NodePublic {
return p.key
}
Loading…
Cancel
Save