From b3634f020dab68b61db54814993781b2662e3014 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Mon, 9 Nov 2020 20:12:21 -0800 Subject: [PATCH] wgengine/filter: use netaddr types in public API. We still use the packet.* alloc-free types in the data path, but the compilation from netaddr to packet happens within the filter package. Signed-off-by: David Anderson --- ipn/local.go | 35 ++----- wgengine/filter/filter.go | 97 +++++++++-------- wgengine/filter/filter_test.go | 111 +++++++++++++++----- wgengine/filter/match.go | 122 +++++----------------- wgengine/filter/match4.go | 151 +++++++++++++++++++++++++++ wgengine/magicsock/magicsock_test.go | 2 +- wgengine/tstun/tun_test.go | 81 ++++++++++---- 7 files changed, 385 insertions(+), 214 deletions(-) create mode 100644 wgengine/filter/match4.go diff --git a/ipn/local.go b/ipn/local.go index ad11fe202..e09315378 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -546,7 +546,7 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap, prefs *Pre return } - localNets := wgCIDRsToFilter(netMap.Addresses, advRoutes) + localNets := wgCIDRsToNetaddr(netMap.Addresses, advRoutes) if shieldsUp { b.logf("netmap packet filter: (shields up)") @@ -1266,14 +1266,14 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config { } rs := &router.Config{ - LocalAddrs: wgCIDRToNetaddr(addrs), - SubnetRoutes: wgCIDRToNetaddr(prefs.AdvertiseRoutes), + LocalAddrs: wgCIDRsToNetaddr(addrs), + SubnetRoutes: wgCIDRsToNetaddr(prefs.AdvertiseRoutes), SNATSubnetRoutes: !prefs.NoSNAT, NetfilterMode: prefs.NetfilterMode, } for _, peer := range cfg.Peers { - rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...) + rs.Routes = append(rs.Routes, wgCIDRsToNetaddr(peer.AllowedIPs)...) } rs.Routes = append(rs.Routes, netaddr.IPPrefix{ @@ -1284,31 +1284,16 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config { return rs } -// wgCIDRsToFilter converts lists of wgcfg.CIDR into a single list of -// filter.Net. -func wgCIDRsToFilter(cidrLists ...[]wgcfg.CIDR) (ret []filter.Net) { +func wgCIDRsToNetaddr(cidrLists ...[]wgcfg.CIDR) (ret []netaddr.IPPrefix) { for _, cidrs := range cidrLists { for _, cidr := range cidrs { - if !cidr.IP.Is4() { - continue + ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet()) + if !ok { + panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr)) } - ret = append(ret, filter.Net{ - IP: filter.NewIP(cidr.IP.IP()), - Mask: filter.Netmask(int(cidr.Mask)), - }) - } - } - return ret -} - -func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) { - for _, cidr := range cidrs { - ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet()) - if !ok { - panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr)) + ncidr.IP = ncidr.IP.Unmap() + ret = append(ret, ncidr) } - ncidr.IP = ncidr.IP.Unmap() - ret = append(ret, ncidr) } return ret } diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 8eeaeb328..c34d2765d 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -7,12 +7,12 @@ package filter import ( "fmt" - "net" "sync" "time" "github.com/golang/groupcache/lru" "golang.org/x/time/rate" + "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/types/logger" @@ -26,16 +26,18 @@ type filterState struct { // Filter is a stateful packet filter. type Filter struct { logf logger.Logf - // localNets is the list of IP prefixes that we know to be "local" - // to this node. All packets coming in over tailscale must have a - // destination within localNets, regardless of the policy filter - // below. A nil localNets rejects all incoming traffic. - localNets []Net - // matches is a list of match->action rules applied to all packets - // arriving over tailscale tunnels. Matches are checked in order, - // and processing stops at the first matching rule. The default - // policy if no rules match is to drop the packet. - matches Matches + // localNets is the list of IP prefixes that we know to be + // "local" to this node. All packets coming in over tailscale + // must have a destination within localNets, regardless of the + // policy filter below. A nil localNets rejects all incoming + // traffic. + local4 []net4 + // matches4 is a list of match->action rules applied to all + // packets arriving over tailscale tunnels. Matches are + // checked in order, and processing stops at the first + // matching rule. The default policy if no rules match is to + // drop the packet. + matches4 matches4 // state is the connection tracking state attached to this // filter. It is used to allow incoming traffic that is a response // to an outbound connection that this node made, even if those @@ -87,12 +89,12 @@ const lruMax = 512 // max entries in UDP LRU cache // MatchAllowAll matches all packets. var MatchAllowAll = Matches{ - Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}}, + Match{NetPortRangeAny, NetAny}, } // NewAllowAll returns a packet filter that accepts everything to and // from localNets. -func NewAllowAll(localNets []Net, logf logger.Logf) *Filter { +func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter { return New(MatchAllowAll, localNets, nil, logf) } @@ -106,7 +108,7 @@ func NewAllowNone(logf logger.Logf) *Filter { // by matches. If shareStateWith is non-nil, the returned filter // shares state with the previous one, to enable rules to be changed // at runtime without breaking existing flows. -func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.Logf) *Filter { +func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter { var state *filterState if shareStateWith != nil { state = shareStateWith.state @@ -116,10 +118,10 @@ func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.L } } f := &Filter{ - logf: logf, - matches: matches, - localNets: localNets, - state: state, + logf: logf, + matches4: newMatches4(matches), + local4: nets4FromIPPrefixes(localNets), + state: state, } return f } @@ -179,29 +181,32 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) { return mm, erracc } -func parseIP(host string, defaultBits int) (Net, error) { - ip := net.ParseIP(host) - if ip != nil && ip.IsUnspecified() { +func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) { + if host == "*" { + // User explicitly requested wildcard dst ip. + // TODO: ipv6 + return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil + } + + ip, err := netaddr.ParseIP(host) + if err != nil { + return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host) + } + if ip == netaddr.IPv4(0, 0, 0, 0) { // For clarity, reject 0.0.0.0 as an input - return NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) - } else if ip == nil && host == "*" { - // User explicitly requested wildcard dst ip - return NetAny, nil - } else { - if ip != nil { - ip = ip.To4() - } - if ip == nil || len(ip) != 4 { - return NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host) - } - if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) { - return NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) - } - return Net{ - IP: NewIP(ip), - Mask: Netmask(defaultBits), - }, nil + return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) + } + if !ip.Is4() { + // TODO: ipv6 + return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host) + } + if defaultBits < 0 || defaultBits > 32 { + return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) } + return netaddr.IPPrefix{ + IP: ip, + Bits: uint8(defaultBits), + }, nil } // TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging? @@ -266,7 +271,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to // prevent that. - if !ipInList(q.DstIP, f.localNets) { + if !ip4InList(q.DstIP, f.local4) { return Drop, "destination not allowed" } @@ -284,7 +289,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) { // related to an existing ICMP-Echo, TCP, or UDP // session. return Accept, "icmp response ok" - } else if matchIPWithoutPorts(f.matches, q) { + } else if f.matches4.matchIPsOnly(q) { // If any port is open to an IP, allow ICMP to it. return Accept, "icmp ok" } @@ -300,7 +305,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) { if q.IPProto == packet.TCP && !q.IsTCPSyn() { return Accept, "tcp non-syn" } - if matchIPPorts(f.matches, q) { + if f.matches4.match(q) { return Accept, "tcp ok" } case packet.UDP: @@ -313,7 +318,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) { if ok { return Accept, "udp cached" } - if matchIPPorts(f.matches, q) { + if f.matches4.match(q) { return Accept, "udp ok" } default: @@ -399,9 +404,9 @@ const ( ) // omitDropLogging reports whether packet p, which has already been -// deemded a packet to Drop, should bypass the [rate-limited] logging. -// We don't want to log scary & spammy reject warnings for packets that -// are totally normal, like IPv6 route announcements. +// deemed a packet to Drop, should bypass the [rate-limited] logging. +// We don't want to log scary & spammy reject warnings for packets +// that are totally normal, like IPv6 route announcements. func omitDropLogging(p *packet.ParsedPacket, dir direction) bool { b := p.Buffer() switch dir { diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index f197b0da5..2b98eb836 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -8,10 +8,13 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "fmt" "net" + "strconv" "strings" "testing" + "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/types/logger" ) @@ -22,43 +25,91 @@ var TCP = packet.TCP var UDP = packet.UDP var Fragment = packet.Fragment -func nets(ips []packet.IP4) []Net { - out := make([]Net, 0, len(ips)) - for _, ip := range ips { - out = append(out, Net{ip, Netmask(32)}) +func pfx(s string) netaddr.IPPrefix { + pfx, err := netaddr.ParseIPPrefix(s) + if err != nil { + panic(err) } - return out + return pfx } -func ippr(ip packet.IP4, start, end uint16) []NetPortRange { - return []NetPortRange{ - NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}}, +func nets(nets ...string) (ret []netaddr.IPPrefix) { + for _, s := range nets { + if i := strings.IndexByte(s, '/'); i == -1 { + ip, err := netaddr.ParseIP(s) + if err != nil { + panic(err) + } + bits := uint8(32) + if ip.Is6() { + bits = 128 + } + ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits}) + } else { + pfx, err := netaddr.ParseIPPrefix(s) + if err != nil { + panic(err) + } + ret = append(ret, pfx) + } } + return ret } -func netpr(ip packet.IP4, bits int, start, end uint16) []NetPortRange { - return []NetPortRange{ - NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}}, +func ports(s string) PortRange { + if s == "*" { + return PortRangeAny + } + + var fs, ls string + i := strings.IndexByte(s, '-') + if i == -1 { + fs = s + ls = fs + } else { + fs = s[:i] + ls = s[i+1:] + } + first, err := strconv.ParseInt(fs, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + last, err := strconv.ParseInt(ls, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + return PortRange{uint16(first), uint16(last)} +} + +func netports(netPorts ...string) (ret []NetPortRange) { + for _, s := range netPorts { + i := strings.LastIndexByte(s, ':') + if i == -1 { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + + npr := NetPortRange{ + Net: nets(s[:i])[0], + Ports: ports(s[i+1:]), + } + ret = append(ret, npr) } + return ret } var matches = Matches{ - {Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: []NetPortRange{ - NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}}, - NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}}, - }}, - {Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, - {Srcs: nets([]packet.IP4{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, - {Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)}, - {Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)}, - {Srcs: nets([]packet.IP4{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)}, + {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")}, + {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")}, + {Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")}, + {Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")}, + {Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")}, + {Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")}, } func newFilter(logf logger.Logf) *Filter { // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, // 102.102.102.102, 119.119.119.119, 8.1.0.0/16 - localNets := nets([]packet.IP4{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777}) - localNets = append(localNets, Net{packet.IP4(0x08010000), Netmask(16)}) + localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16") return New(matches, localNets, nil, logf) } @@ -160,18 +211,19 @@ func TestNoAllocs(t *testing.T) { } func TestParseIP(t *testing.T) { + var noaddr netaddr.IPPrefix tests := []struct { host string bits int - want Net + want netaddr.IPPrefix wantErr string }{ - {"8.8.8.8", 24, Net{IP: packet.NewIP4(net.ParseIP("8.8.8.8")), Mask: packet.NewIP4(net.ParseIP("255.255.255.0"))}, ""}, - {"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`}, - {"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`}, - {"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, - {"*", 24, NetAny, ""}, - {"fe80::1", 128, NetNone, `ports="fe80::1": invalid IPv4 address`}, + {"8.8.8.8", 24, pfx("8.8.8.8/24"), ""}, + {"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`}, + {"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`}, + {"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, + {"*", 24, pfx("0.0.0.0/0"), ""}, + {"fe80::1", 128, pfx("255.255.255.255/32"), `ports="fe80::1": invalid IPv4 address`}, } for _, tt := range tests { got, err := parseIP(tt.host, tt.bits) @@ -215,6 +267,7 @@ func BenchmarkFilter(b *testing.B) { for _, bench := range benches { b.Run(bench.name, func(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { q := &packet.ParsedPacket{} q.Decode(bench.packet) diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 41e721652..2b9d78461 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -6,53 +6,17 @@ package filter import ( "fmt" - "math/bits" - "net" "strings" - "tailscale.com/net/packet" + "inet.af/netaddr" ) -func NewIP(ip net.IP) packet.IP4 { - return packet.NewIP4(ip) -} - -type Net struct { - IP packet.IP4 - Mask packet.IP4 -} - -func (n Net) Includes(ip packet.IP4) bool { - return (n.IP & n.Mask) == (ip & n.Mask) -} - -func (n Net) Bits() int { - return 32 - bits.TrailingZeros32(uint32(n.Mask)) -} - -func (n Net) String() string { - b := n.Bits() - if b == 32 { - return n.IP.String() - } else if b == 0 { - return "*" - } else { - return fmt.Sprintf("%s/%d", n.IP, b) - } -} - -var NetAny = Net{0, 0} -var NetNone = Net{^packet.IP4(0), ^packet.IP4(0)} - -func Netmask(bits int) packet.IP4 { - b := ^uint32((1 << (32 - bits)) - 1) - return packet.IP4(b) -} - +// PortRange is a range of TCP and UDP ports. type PortRange struct { - First, Last uint16 + First, Last uint16 // inclusive } +// PortRangeAny represents all TCP and UDP ports. var PortRangeAny = PortRange{0, 65535} func (pr PortRange) String() string { @@ -65,28 +29,40 @@ func (pr PortRange) String() string { } } +func (pr PortRange) contains(port uint16) bool { + return port >= pr.First && port <= pr.Last +} + +// NetAny matches all IP addresses. +// TODO: add ipv6. +var NetAny = []netaddr.IPPrefix{{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}} + +// NetPortRange combines an IP address prefix and PortRange. type NetPortRange struct { - Net Net + Net netaddr.IPPrefix Ports PortRange } -var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny} - -func (ipr NetPortRange) String() string { - return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports) +func (npr NetPortRange) String() string { + return fmt.Sprintf("%v:%v", npr.Net, npr.Ports) } +var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}} + +// Match matches packets from any IP address in Srcs to any ip:port in +// Dsts. type Match struct { Dsts []NetPortRange - Srcs []Net + Srcs []netaddr.IPPrefix } +// Clone returns a deep copy of m. func (m Match) Clone() (res Match) { if m.Dsts != nil { res.Dsts = append([]NetPortRange{}, m.Dsts...) } if m.Srcs != nil { - res.Srcs = append([]Net{}, m.Srcs...) + res.Srcs = append([]netaddr.IPPrefix{}, m.Srcs...) } return res } @@ -115,57 +91,13 @@ func (m Match) String() string { return fmt.Sprintf("%v=>%v", ss, ds) } +// Matches is a list of packet matchers. type Matches []Match -func (m Matches) Clone() (res Matches) { - for _, match := range m { +// Clone returns a deep copy of ms. +func (ms Matches) Clone() (res Matches) { + for _, match := range ms { res = append(res, match.Clone()) } return res } - -func ipInList(ip packet.IP4, netlist []Net) bool { - for _, net := range netlist { - if net.Includes(ip) { - return true - } - } - return false -} - -func matchIPPorts(mm Matches, q *packet.ParsedPacket) bool { - for _, acl := range mm { - for _, dst := range acl.Dsts { - if !dst.Net.Includes(q.DstIP) { - continue - } - if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last { - continue - } - if !ipInList(q.SrcIP, acl.Srcs) { - // Skip other dests in this acl, since - // the src will never match. - break - } - return true - } - } - return false -} - -func matchIPWithoutPorts(mm Matches, q *packet.ParsedPacket) bool { - for _, acl := range mm { - for _, dst := range acl.Dsts { - if !dst.Net.Includes(q.DstIP) { - continue - } - if !ipInList(q.SrcIP, acl.Srcs) { - // Skip other dests in this acl, since - // the src will never match. - break - } - return true - } - } - return false -} diff --git a/wgengine/filter/match4.go b/wgengine/filter/match4.go new file mode 100644 index 000000000..d9329fcd4 --- /dev/null +++ b/wgengine/filter/match4.go @@ -0,0 +1,151 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package filter + +import ( + "fmt" + "math/bits" + "strings" + + "inet.af/netaddr" + "tailscale.com/net/packet" +) + +type net4 struct { + ip packet.IP4 + mask packet.IP4 +} + +func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 { + if !pfx.IP.Is4() { + panic("net4FromIPPrefix given non-ipv4 prefix") + } + return net4{ + ip: packet.IP4FromNetaddr(pfx.IP), + mask: netmask4(pfx.Bits), + } +} + +func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) { + for _, pfx := range pfxs { + if pfx.IP.Is4() { + ret = append(ret, net4FromIPPrefix(pfx)) + } + } + return ret +} + +func (n net4) Contains(ip packet.IP4) bool { + return (n.ip & n.mask) == (ip & n.mask) +} + +func (n net4) Bits() int { + return 32 - bits.TrailingZeros32(uint32(n.mask)) +} + +func (n net4) String() string { + b := n.Bits() + if b == 32 { + return n.ip.String() + } else if b == 0 { + return "*" + } else { + return fmt.Sprintf("%s/%d", n.ip, b) + } +} + +type npr4 struct { + net net4 + ports PortRange +} + +func (npr npr4) String() string { + return fmt.Sprintf("%s:%s", npr.net, npr.ports) +} + +type match4 struct { + dsts []npr4 + srcs []net4 +} + +type matches4 []match4 + +func (ms matches4) String() string { + var b strings.Builder + for _, m := range ms { + fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts) + } + return b.String() +} + +func newMatches4(ms Matches) (ret matches4) { + for _, m := range ms { + var m4 match4 + for _, src := range m.Srcs { + if src.IP.Is4() { + m4.srcs = append(m4.srcs, net4FromIPPrefix(src)) + } + } + for _, dst := range m.Dsts { + if dst.Net.IP.Is4() { + m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports}) + } + } + if len(m4.srcs) > 0 && len(m4.dsts) > 0 { + ret = append(ret, m4) + } + } + return ret +} + +// match returns whether q's source IP and destination IP:port match +// any of ms. +func (ms matches4) match(q *packet.ParsedPacket) bool { + for _, m := range ms { + if !ip4InList(q.SrcIP, m.srcs) { + continue + } + for _, dst := range m.dsts { + if !dst.net.Contains(q.DstIP) { + continue + } + if !dst.ports.contains(q.DstPort) { + continue + } + return true + } + } + return false +} + +// matchIPsOnly returns whether q's source and destination IP match +// any of ms. +func (ms matches4) matchIPsOnly(q *packet.ParsedPacket) bool { + for _, m := range ms { + if !ip4InList(q.SrcIP, m.srcs) { + continue + } + for _, dst := range m.dsts { + if dst.net.Contains(q.DstIP) { + return true + } + } + } + return false +} + +func netmask4(bits uint8) packet.IP4 { + b := ^uint32((1 << (32 - bits)) - 1) + return packet.IP4(b) +} + +func ip4InList(ip packet.IP4, netlist []net4) bool { + for _, net := range netlist { + if net.Contains(ip) { + return true + } + } + return false +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 8bf55b179..9a369d241 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -158,7 +158,7 @@ func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, der tun := tuntest.NewChannelTUN() tsTun := tstun.WrapTUN(logf, tun.TUN()) - tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf)) + tsTun.SetFilter(filter.NewAllowAll(filter.NetAny, logf)) dev := device.NewDevice(tsTun, &device.DeviceOptions{ Logger: &device.Logger{ diff --git a/wgengine/tstun/tun_test.go b/wgengine/tstun/tun_test.go index d13d03aba..ee7a11711 100644 --- a/wgengine/tstun/tun_test.go +++ b/wgengine/tstun/tun_test.go @@ -6,11 +6,15 @@ package tstun import ( "bytes" + "fmt" + "strconv" + "strings" "sync/atomic" "testing" "unsafe" "github.com/tailscale/wireguard-go/tun/tuntest" + "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/types/logger" "tailscale.com/wgengine/filter" @@ -29,35 +33,76 @@ func udp(src, dst packet.IP4, sport, dport uint16) []byte { return packet.Generate(header, []byte("udp_payload")) } -func filterNet(ip, mask packet.IP4) filter.Net { - return filter.Net{IP: ip, Mask: mask} +func nets(nets ...string) (ret []netaddr.IPPrefix) { + for _, s := range nets { + if i := strings.IndexByte(s, '/'); i == -1 { + ip, err := netaddr.ParseIP(s) + if err != nil { + panic(err) + } + bits := uint8(32) + if ip.Is6() { + bits = 128 + } + ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits}) + } else { + pfx, err := netaddr.ParseIPPrefix(s) + if err != nil { + panic(err) + } + ret = append(ret, pfx) + } + } + return ret } -func nets(ips []packet.IP4) []filter.Net { - out := make([]filter.Net, 0, len(ips)) - for _, ip := range ips { - out = append(out, filterNet(ip, filter.Netmask(32))) +func ports(s string) filter.PortRange { + if s == "*" { + return filter.PortRangeAny } - return out + + var fs, ls string + i := strings.IndexByte(s, '-') + if i == -1 { + fs = s + ls = fs + } else { + fs = s[:i] + ls = s[i+1:] + } + first, err := strconv.ParseInt(fs, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + last, err := strconv.ParseInt(ls, 10, 16) + if err != nil { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + return filter.PortRange{First: uint16(first), Last: uint16(last)} } -func ippr(ip packet.IP4, start, end uint16) []filter.NetPortRange { - return []filter.NetPortRange{ - filter.NetPortRange{ - Net: filterNet(ip, filter.Netmask(32)), - Ports: filter.PortRange{First: start, Last: end}, - }, +func netports(netPorts ...string) (ret []filter.NetPortRange) { + for _, s := range netPorts { + i := strings.LastIndexByte(s, ':') + if i == -1 { + panic(fmt.Sprintf("invalid NetPortRange %q", s)) + } + + npr := filter.NetPortRange{ + Net: nets(s[:i])[0], + Ports: ports(s[i+1:]), + } + ret = append(ret, npr) } + return ret } func setfilter(logf logger.Logf, tun *TUN) { matches := filter.Matches{ - {Srcs: nets([]packet.IP4{0x05060708}), Dsts: ippr(0x01020304, 89, 90)}, - {Srcs: nets([]packet.IP4{0x01020304}), Dsts: ippr(0x05060708, 98, 98)}, - } - localNets := []filter.Net{ - filterNet(packet.IP4(0x01020304), filter.Netmask(16)), + {Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")}, + {Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")}, } + localNets := nets("1.2.0.0/16") tun.SetFilter(filter.New(matches, localNets, nil, logf)) }