wgengine: optimize isLocalAddr a bit

On macOS/iOS, this removes a map lookup per outgoing packet.

Noticed it while reading code, not from profiles, but can't hurt.

BenchmarkGenLocalAddrFunc
BenchmarkGenLocalAddrFunc/map1
BenchmarkGenLocalAddrFunc/map1-4                16184868                69.78 ns/op
BenchmarkGenLocalAddrFunc/map2
BenchmarkGenLocalAddrFunc/map2-4                16878140                70.73 ns/op
BenchmarkGenLocalAddrFunc/or1
BenchmarkGenLocalAddrFunc/or1-4                 623055721                1.950 ns/op
BenchmarkGenLocalAddrFunc/or2
BenchmarkGenLocalAddrFunc/or2-4                 472493098                2.589 ns/op

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/1585/head
Brad Fitzpatrick 4 years ago committed by Brad Fitzpatrick
parent 95ca86c048
commit e18c3a7d84

@ -91,10 +91,10 @@ type userspaceEngine struct {
testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called
// localAddrs is the set of IP addresses assigned to the local // isLocalAddr reports the whether an IP is assigned to the local
// tunnel interface. It's used to reflect local packets // tunnel interface. It's used to reflect local packets
// incorrectly sent to us. // incorrectly sent to us.
localAddrs atomic.Value // of map[netaddr.IP]bool isLocalAddr atomic.Value // of func(netaddr.IP)bool
wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below
lastCfgFull wgcfg.Config lastCfgFull wgcfg.Config
@ -180,7 +180,7 @@ func NewUserspaceEngine(logf logger.Logf, dev tun.Device, conf Config) (_ Engine
router: conf.Router, router: conf.Router,
pingers: make(map[wgkey.Key]*pinger), pingers: make(map[wgkey.Key]*pinger),
} }
e.localAddrs.Store(map[netaddr.IP]bool{}) e.isLocalAddr.Store(genLocalAddrFunc(nil))
if conf.LinkMonitor != nil { if conf.LinkMonitor != nil {
e.linkMon = conf.LinkMonitor e.linkMon = conf.LinkMonitor
@ -390,28 +390,24 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper)
return filter.Drop return filter.Drop
} }
if (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && e.isLocalAddr(p.Dst.IP) { if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
// macOS NetworkExtension directs packets destined to the isLocalAddr, ok := e.isLocalAddr.Load().(func(netaddr.IP) bool)
// tunnel's local IP address into the tunnel, instead of if !ok {
// looping back within the kernel network stack. We have to e.logf("[unexpected] e.isLocalAddr was nil, can't check for loopback packet")
// notice that an outbound packet is actually destined for } else if isLocalAddr(p.Dst.IP) {
// ourselves, and loop it back into macOS. // macOS NetworkExtension directs packets destined to the
t.InjectInboundCopy(p.Buffer()) // tunnel's local IP address into the tunnel, instead of
return filter.Drop // looping back within the kernel network stack. We have to
// notice that an outbound packet is actually destined for
// ourselves, and loop it back into macOS.
t.InjectInboundCopy(p.Buffer())
return filter.Drop
}
} }
return filter.Accept return filter.Accept
} }
func (e *userspaceEngine) isLocalAddr(ip netaddr.IP) bool {
localAddrs, ok := e.localAddrs.Load().(map[netaddr.IP]bool)
if !ok {
e.logf("[unexpected] e.localAddrs was nil, can't check for loopback packet")
return false
}
return localAddrs[ip]
}
// handleDNS is an outbound pre-filter resolving Tailscale domains. // handleDNS is an outbound pre-filter resolving Tailscale domains.
func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response { func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP { if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP {
@ -877,16 +873,34 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey
e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs) e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs)
} }
// genLocalAddrFunc returns a func that reports whether an IP is in addrs.
// addrs is assumed to be all /32 or /128 entries.
func genLocalAddrFunc(addrs []netaddr.IPPrefix) func(netaddr.IP) bool {
// Specialize the three common cases: no address, just IPv4
// (or just IPv6), and both IPv4 and IPv6.
if len(addrs) == 0 {
return func(netaddr.IP) bool { return false }
}
if len(addrs) == 1 {
return func(t netaddr.IP) bool { return t == addrs[0].IP }
}
if len(addrs) == 2 {
return func(t netaddr.IP) bool { return t == addrs[0].IP || t == addrs[1].IP }
}
// Otherwise, the general implementation: a map lookup.
m := map[netaddr.IP]bool{}
for _, a := range addrs {
m[a.IP] = true
}
return func(t netaddr.IP) bool { return m[t] }
}
func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error {
if routerCfg == nil { if routerCfg == nil {
panic("routerCfg must not be nil") panic("routerCfg must not be nil")
} }
localAddrs := map[netaddr.IP]bool{} e.isLocalAddr.Store(genLocalAddrFunc(routerCfg.LocalAddrs))
for _, addr := range routerCfg.LocalAddrs {
localAddrs[addr.IP] = true
}
e.localAddrs.Store(localAddrs)
e.wgLock.Lock() e.wgLock.Lock()
defer e.wgLock.Unlock() defer e.wgLock.Unlock()

@ -139,3 +139,50 @@ func dkFromHex(hex string) tailcfg.DiscoKey {
} }
return tailcfg.DiscoKey(k) return tailcfg.DiscoKey(k)
} }
// an experiment to see if genLocalAddrFunc was worth it. As of Go
// 1.16, it still very much is. (30-40x faster)
func BenchmarkGenLocalAddrFunc(b *testing.B) {
la1 := netaddr.MustParseIP("1.2.3.4")
la2 := netaddr.MustParseIP("::4")
lanot := netaddr.MustParseIP("5.5.5.5")
var x bool
b.Run("map1", func(b *testing.B) {
m := map[netaddr.IP]bool{
la1: true,
}
for i := 0; i < b.N; i++ {
x = m[la1]
x = m[lanot]
}
})
b.Run("map2", func(b *testing.B) {
m := map[netaddr.IP]bool{
la1: true,
la2: true,
}
for i := 0; i < b.N; i++ {
x = m[la1]
x = m[lanot]
}
})
b.Run("or1", func(b *testing.B) {
f := func(t netaddr.IP) bool {
return t == la1
}
for i := 0; i < b.N; i++ {
x = f(la1)
x = f(lanot)
}
})
b.Run("or2", func(b *testing.B) {
f := func(t netaddr.IP) bool {
return t == la1 || t == la2
}
for i := 0; i < b.N; i++ {
x = f(la1)
x = f(lanot)
}
})
b.Logf("x = %v", x)
}

Loading…
Cancel
Save