diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go index 26d837c6e..ea75470a8 100644 --- a/net/tsaddr/tsaddr.go +++ b/net/tsaddr/tsaddr.go @@ -138,3 +138,50 @@ type onceIP struct { sync.Once v netaddr.IP } + +// NewContainsIPFunc returns a func that reports whether ip is in addrs. +// +// It's optimized for the cases of addrs being empty and addrs +// containing 1 or 2 single-IP prefixes (such as one IPv4 address and +// one IPv6 address). +// +// Otherwise the implementation is somewhat slow. +func NewContainsIPFunc(addrs []netaddr.IPPrefix) func(ip netaddr.IP) bool { + // Specialize the three common cases: no address, just IPv4 + // (or just IPv6), and both IPv4 and IPv6. + if len(addrs) == 0 { + return func(netaddr.IP) bool { return false } + } + // If any addr is more than a single IP, then just do the slow + // linear thing until + // https://github.com/inetaf/netaddr/issues/139 is done. + for _, a := range addrs { + if a.IsSingleIP() { + continue + } + acopy := append([]netaddr.IPPrefix(nil), addrs...) + return func(ip netaddr.IP) bool { + for _, a := range acopy { + if a.Contains(ip) { + return true + } + } + return false + } + } + // Fast paths for 1 and 2 IPs: + if len(addrs) == 1 { + a := addrs[0] + return func(ip netaddr.IP) bool { return ip == a.IP } + } + if len(addrs) == 2 { + a, b := addrs[0], addrs[1] + return func(ip netaddr.IP) bool { return ip == a.IP || ip == b.IP } + } + // General case: + m := map[netaddr.IP]bool{} + for _, a := range addrs { + m[a.IP] = true + } + return func(ip netaddr.IP) bool { return m[ip] } +} diff --git a/net/tsaddr/tsaddr_test.go b/net/tsaddr/tsaddr_test.go index e66aa8d8d..eebd61445 100644 --- a/net/tsaddr/tsaddr_test.go +++ b/net/tsaddr/tsaddr_test.go @@ -64,3 +64,32 @@ func TestIsUla(t *testing.T) { } } } + +func TestNewContainsIPFunc(t *testing.T) { + f := NewContainsIPFunc([]netaddr.IPPrefix{netaddr.MustParseIPPrefix("10.0.0.0/8")}) + if f(netaddr.MustParseIP("8.8.8.8")) { + t.Fatal("bad") + } + if !f(netaddr.MustParseIP("10.1.2.3")) { + t.Fatal("bad") + } + f = NewContainsIPFunc([]netaddr.IPPrefix{netaddr.MustParseIPPrefix("10.1.2.3/32")}) + if !f(netaddr.MustParseIP("10.1.2.3")) { + t.Fatal("bad") + } + f = NewContainsIPFunc([]netaddr.IPPrefix{ + netaddr.MustParseIPPrefix("10.1.2.3/32"), + netaddr.MustParseIPPrefix("::2/128"), + }) + if !f(netaddr.MustParseIP("::2")) { + t.Fatal("bad") + } + f = NewContainsIPFunc([]netaddr.IPPrefix{ + netaddr.MustParseIPPrefix("10.1.2.3/32"), + netaddr.MustParseIPPrefix("10.1.2.4/32"), + netaddr.MustParseIPPrefix("::2/128"), + }) + if !f(netaddr.MustParseIP("10.1.2.4")) { + t.Fatal("bad") + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 5ab060efa..b959aa43b 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -243,7 +243,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) router: conf.Router, pingers: make(map[wgkey.Key]*pinger), } - e.isLocalAddr.Store(genLocalAddrFunc(nil)) + e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(nil)) if conf.LinkMonitor != nil { e.linkMon = conf.LinkMonitor @@ -936,28 +936,6 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs) } -// genLocalAddrFunc returns a func that reports whether an IP is in addrs. -// addrs is assumed to be all /32 or /128 entries. -func genLocalAddrFunc(addrs []netaddr.IPPrefix) func(netaddr.IP) bool { - // Specialize the three common cases: no address, just IPv4 - // (or just IPv6), and both IPv4 and IPv6. - if len(addrs) == 0 { - return func(netaddr.IP) bool { return false } - } - if len(addrs) == 1 { - return func(t netaddr.IP) bool { return t == addrs[0].IP } - } - if len(addrs) == 2 { - return func(t netaddr.IP) bool { return t == addrs[0].IP || t == addrs[1].IP } - } - // Otherwise, the general implementation: a map lookup. - m := map[netaddr.IP]bool{} - for _, a := range addrs { - m[a.IP] = true - } - return func(t netaddr.IP) bool { return m[t] } -} - func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") @@ -966,7 +944,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, panic("dnsCfg must not be nil") } - e.isLocalAddr.Store(genLocalAddrFunc(routerCfg.LocalAddrs)) + e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(routerCfg.LocalAddrs)) e.wgLock.Lock() defer e.wgLock.Unlock()