wgengine/filter: use netaddr types in public API.

We still use the packet.* alloc-free types in the data path, but
the compilation from netaddr to packet happens within the filter
package.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/910/head
David Anderson 4 years ago committed by Dave Anderson
parent 7988f75b87
commit b3634f020d

@ -546,7 +546,7 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap, prefs *Pre
return return
} }
localNets := wgCIDRsToFilter(netMap.Addresses, advRoutes) localNets := wgCIDRsToNetaddr(netMap.Addresses, advRoutes)
if shieldsUp { if shieldsUp {
b.logf("netmap packet filter: (shields up)") b.logf("netmap packet filter: (shields up)")
@ -1266,14 +1266,14 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
} }
rs := &router.Config{ rs := &router.Config{
LocalAddrs: wgCIDRToNetaddr(addrs), LocalAddrs: wgCIDRsToNetaddr(addrs),
SubnetRoutes: wgCIDRToNetaddr(prefs.AdvertiseRoutes), SubnetRoutes: wgCIDRsToNetaddr(prefs.AdvertiseRoutes),
SNATSubnetRoutes: !prefs.NoSNAT, SNATSubnetRoutes: !prefs.NoSNAT,
NetfilterMode: prefs.NetfilterMode, NetfilterMode: prefs.NetfilterMode,
} }
for _, peer := range cfg.Peers { for _, peer := range cfg.Peers {
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...) rs.Routes = append(rs.Routes, wgCIDRsToNetaddr(peer.AllowedIPs)...)
} }
rs.Routes = append(rs.Routes, netaddr.IPPrefix{ rs.Routes = append(rs.Routes, netaddr.IPPrefix{
@ -1284,31 +1284,16 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
return rs return rs
} }
// wgCIDRsToFilter converts lists of wgcfg.CIDR into a single list of func wgCIDRsToNetaddr(cidrLists ...[]wgcfg.CIDR) (ret []netaddr.IPPrefix) {
// filter.Net.
func wgCIDRsToFilter(cidrLists ...[]wgcfg.CIDR) (ret []filter.Net) {
for _, cidrs := range cidrLists { for _, cidrs := range cidrLists {
for _, cidr := range cidrs { for _, cidr := range cidrs {
if !cidr.IP.Is4() { ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
continue if !ok {
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
} }
ret = append(ret, filter.Net{ ncidr.IP = ncidr.IP.Unmap()
IP: filter.NewIP(cidr.IP.IP()), ret = append(ret, ncidr)
Mask: filter.Netmask(int(cidr.Mask)),
})
}
}
return ret
}
func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) {
for _, cidr := range cidrs {
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
if !ok {
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
} }
ncidr.IP = ncidr.IP.Unmap()
ret = append(ret, ncidr)
} }
return ret return ret
} }

@ -7,12 +7,12 @@ package filter
import ( import (
"fmt" "fmt"
"net"
"sync" "sync"
"time" "time"
"github.com/golang/groupcache/lru" "github.com/golang/groupcache/lru"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"inet.af/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -26,16 +26,18 @@ type filterState struct {
// Filter is a stateful packet filter. // Filter is a stateful packet filter.
type Filter struct { type Filter struct {
logf logger.Logf logf logger.Logf
// localNets is the list of IP prefixes that we know to be "local" // localNets is the list of IP prefixes that we know to be
// to this node. All packets coming in over tailscale must have a // "local" to this node. All packets coming in over tailscale
// destination within localNets, regardless of the policy filter // must have a destination within localNets, regardless of the
// below. A nil localNets rejects all incoming traffic. // policy filter below. A nil localNets rejects all incoming
localNets []Net // traffic.
// matches is a list of match->action rules applied to all packets local4 []net4
// arriving over tailscale tunnels. Matches are checked in order, // matches4 is a list of match->action rules applied to all
// and processing stops at the first matching rule. The default // packets arriving over tailscale tunnels. Matches are
// policy if no rules match is to drop the packet. // checked in order, and processing stops at the first
matches Matches // matching rule. The default policy if no rules match is to
// drop the packet.
matches4 matches4
// state is the connection tracking state attached to this // state is the connection tracking state attached to this
// filter. It is used to allow incoming traffic that is a response // filter. It is used to allow incoming traffic that is a response
// to an outbound connection that this node made, even if those // to an outbound connection that this node made, even if those
@ -87,12 +89,12 @@ const lruMax = 512 // max entries in UDP LRU cache
// MatchAllowAll matches all packets. // MatchAllowAll matches all packets.
var MatchAllowAll = Matches{ var MatchAllowAll = Matches{
Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}}, Match{NetPortRangeAny, NetAny},
} }
// NewAllowAll returns a packet filter that accepts everything to and // NewAllowAll returns a packet filter that accepts everything to and
// from localNets. // from localNets.
func NewAllowAll(localNets []Net, logf logger.Logf) *Filter { func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter {
return New(MatchAllowAll, localNets, nil, logf) return New(MatchAllowAll, localNets, nil, logf)
} }
@ -106,7 +108,7 @@ func NewAllowNone(logf logger.Logf) *Filter {
// by matches. If shareStateWith is non-nil, the returned filter // by matches. If shareStateWith is non-nil, the returned filter
// shares state with the previous one, to enable rules to be changed // shares state with the previous one, to enable rules to be changed
// at runtime without breaking existing flows. // at runtime without breaking existing flows.
func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.Logf) *Filter { func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
var state *filterState var state *filterState
if shareStateWith != nil { if shareStateWith != nil {
state = shareStateWith.state state = shareStateWith.state
@ -116,10 +118,10 @@ func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.L
} }
} }
f := &Filter{ f := &Filter{
logf: logf, logf: logf,
matches: matches, matches4: newMatches4(matches),
localNets: localNets, local4: nets4FromIPPrefixes(localNets),
state: state, state: state,
} }
return f return f
} }
@ -179,29 +181,32 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) {
return mm, erracc return mm, erracc
} }
func parseIP(host string, defaultBits int) (Net, error) { func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) {
ip := net.ParseIP(host) if host == "*" {
if ip != nil && ip.IsUnspecified() { // User explicitly requested wildcard dst ip.
// TODO: ipv6
return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil
}
ip, err := netaddr.ParseIP(host)
if err != nil {
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host)
}
if ip == netaddr.IPv4(0, 0, 0, 0) {
// For clarity, reject 0.0.0.0 as an input // For clarity, reject 0.0.0.0 as an input
return NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} else if ip == nil && host == "*" { }
// User explicitly requested wildcard dst ip if !ip.Is4() {
return NetAny, nil // TODO: ipv6
} else { return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
if ip != nil { }
ip = ip.To4() if defaultBits < 0 || defaultBits > 32 {
} return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
if ip == nil || len(ip) != 4 {
return NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
}
if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) {
return NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
}
return Net{
IP: NewIP(ip),
Mask: Netmask(defaultBits),
}, nil
} }
return netaddr.IPPrefix{
IP: ip,
Bits: uint8(defaultBits),
}, nil
} }
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging? // TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
@ -266,7 +271,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
// A compromised peer could try to send us packets for // A compromised peer could try to send us packets for
// destinations we didn't explicitly advertise. This check is to // destinations we didn't explicitly advertise. This check is to
// prevent that. // prevent that.
if !ipInList(q.DstIP, f.localNets) { if !ip4InList(q.DstIP, f.local4) {
return Drop, "destination not allowed" return Drop, "destination not allowed"
} }
@ -284,7 +289,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
// related to an existing ICMP-Echo, TCP, or UDP // related to an existing ICMP-Echo, TCP, or UDP
// session. // session.
return Accept, "icmp response ok" return Accept, "icmp response ok"
} else if matchIPWithoutPorts(f.matches, q) { } else if f.matches4.matchIPsOnly(q) {
// If any port is open to an IP, allow ICMP to it. // If any port is open to an IP, allow ICMP to it.
return Accept, "icmp ok" return Accept, "icmp ok"
} }
@ -300,7 +305,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
if q.IPProto == packet.TCP && !q.IsTCPSyn() { if q.IPProto == packet.TCP && !q.IsTCPSyn() {
return Accept, "tcp non-syn" return Accept, "tcp non-syn"
} }
if matchIPPorts(f.matches, q) { if f.matches4.match(q) {
return Accept, "tcp ok" return Accept, "tcp ok"
} }
case packet.UDP: case packet.UDP:
@ -313,7 +318,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
if ok { if ok {
return Accept, "udp cached" return Accept, "udp cached"
} }
if matchIPPorts(f.matches, q) { if f.matches4.match(q) {
return Accept, "udp ok" return Accept, "udp ok"
} }
default: default:
@ -399,9 +404,9 @@ const (
) )
// omitDropLogging reports whether packet p, which has already been // omitDropLogging reports whether packet p, which has already been
// deemded a packet to Drop, should bypass the [rate-limited] logging. // deemed a packet to Drop, should bypass the [rate-limited] logging.
// We don't want to log scary & spammy reject warnings for packets that // We don't want to log scary & spammy reject warnings for packets
// are totally normal, like IPv6 route announcements. // that are totally normal, like IPv6 route announcements.
func omitDropLogging(p *packet.ParsedPacket, dir direction) bool { func omitDropLogging(p *packet.ParsedPacket, dir direction) bool {
b := p.Buffer() b := p.Buffer()
switch dir { switch dir {

@ -8,10 +8,13 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"strconv"
"strings" "strings"
"testing" "testing"
"inet.af/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -22,43 +25,91 @@ var TCP = packet.TCP
var UDP = packet.UDP var UDP = packet.UDP
var Fragment = packet.Fragment var Fragment = packet.Fragment
func nets(ips []packet.IP4) []Net { func pfx(s string) netaddr.IPPrefix {
out := make([]Net, 0, len(ips)) pfx, err := netaddr.ParseIPPrefix(s)
for _, ip := range ips { if err != nil {
out = append(out, Net{ip, Netmask(32)}) panic(err)
} }
return out return pfx
} }
func ippr(ip packet.IP4, start, end uint16) []NetPortRange { func nets(nets ...string) (ret []netaddr.IPPrefix) {
return []NetPortRange{ for _, s := range nets {
NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}}, if i := strings.IndexByte(s, '/'); i == -1 {
ip, err := netaddr.ParseIP(s)
if err != nil {
panic(err)
}
bits := uint8(32)
if ip.Is6() {
bits = 128
}
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
} else {
pfx, err := netaddr.ParseIPPrefix(s)
if err != nil {
panic(err)
}
ret = append(ret, pfx)
}
} }
return ret
} }
func netpr(ip packet.IP4, bits int, start, end uint16) []NetPortRange { func ports(s string) PortRange {
return []NetPortRange{ if s == "*" {
NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}}, return PortRangeAny
}
var fs, ls string
i := strings.IndexByte(s, '-')
if i == -1 {
fs = s
ls = fs
} else {
fs = s[:i]
ls = s[i+1:]
}
first, err := strconv.ParseInt(fs, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
last, err := strconv.ParseInt(ls, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
return PortRange{uint16(first), uint16(last)}
}
func netports(netPorts ...string) (ret []NetPortRange) {
for _, s := range netPorts {
i := strings.LastIndexByte(s, ':')
if i == -1 {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
npr := NetPortRange{
Net: nets(s[:i])[0],
Ports: ports(s[i+1:]),
}
ret = append(ret, npr)
} }
return ret
} }
var matches = Matches{ var matches = Matches{
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: []NetPortRange{ {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")},
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}}, {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")},
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}}, {Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")},
}}, {Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")},
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, {Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")},
{Srcs: nets([]packet.IP4{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, {Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")},
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
{Srcs: nets([]packet.IP4{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
} }
func newFilter(logf logger.Logf) *Filter { func newFilter(logf logger.Logf) *Filter {
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
// 102.102.102.102, 119.119.119.119, 8.1.0.0/16 // 102.102.102.102, 119.119.119.119, 8.1.0.0/16
localNets := nets([]packet.IP4{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777}) localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16")
localNets = append(localNets, Net{packet.IP4(0x08010000), Netmask(16)})
return New(matches, localNets, nil, logf) return New(matches, localNets, nil, logf)
} }
@ -160,18 +211,19 @@ func TestNoAllocs(t *testing.T) {
} }
func TestParseIP(t *testing.T) { func TestParseIP(t *testing.T) {
var noaddr netaddr.IPPrefix
tests := []struct { tests := []struct {
host string host string
bits int bits int
want Net want netaddr.IPPrefix
wantErr string wantErr string
}{ }{
{"8.8.8.8", 24, Net{IP: packet.NewIP4(net.ParseIP("8.8.8.8")), Mask: packet.NewIP4(net.ParseIP("255.255.255.0"))}, ""}, {"8.8.8.8", 24, pfx("8.8.8.8/24"), ""},
{"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`}, {"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`},
{"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`}, {"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`},
{"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, {"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
{"*", 24, NetAny, ""}, {"*", 24, pfx("0.0.0.0/0"), ""},
{"fe80::1", 128, NetNone, `ports="fe80::1": invalid IPv4 address`}, {"fe80::1", 128, pfx("255.255.255.255/32"), `ports="fe80::1": invalid IPv4 address`},
} }
for _, tt := range tests { for _, tt := range tests {
got, err := parseIP(tt.host, tt.bits) got, err := parseIP(tt.host, tt.bits)
@ -215,6 +267,7 @@ func BenchmarkFilter(b *testing.B) {
for _, bench := range benches { for _, bench := range benches {
b.Run(bench.name, func(b *testing.B) { b.Run(bench.name, func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
q := &packet.ParsedPacket{} q := &packet.ParsedPacket{}
q.Decode(bench.packet) q.Decode(bench.packet)

@ -6,53 +6,17 @@ package filter
import ( import (
"fmt" "fmt"
"math/bits"
"net"
"strings" "strings"
"tailscale.com/net/packet" "inet.af/netaddr"
) )
func NewIP(ip net.IP) packet.IP4 { // PortRange is a range of TCP and UDP ports.
return packet.NewIP4(ip)
}
type Net struct {
IP packet.IP4
Mask packet.IP4
}
func (n Net) Includes(ip packet.IP4) bool {
return (n.IP & n.Mask) == (ip & n.Mask)
}
func (n Net) Bits() int {
return 32 - bits.TrailingZeros32(uint32(n.Mask))
}
func (n Net) String() string {
b := n.Bits()
if b == 32 {
return n.IP.String()
} else if b == 0 {
return "*"
} else {
return fmt.Sprintf("%s/%d", n.IP, b)
}
}
var NetAny = Net{0, 0}
var NetNone = Net{^packet.IP4(0), ^packet.IP4(0)}
func Netmask(bits int) packet.IP4 {
b := ^uint32((1 << (32 - bits)) - 1)
return packet.IP4(b)
}
type PortRange struct { type PortRange struct {
First, Last uint16 First, Last uint16 // inclusive
} }
// PortRangeAny represents all TCP and UDP ports.
var PortRangeAny = PortRange{0, 65535} var PortRangeAny = PortRange{0, 65535}
func (pr PortRange) String() string { func (pr PortRange) String() string {
@ -65,28 +29,40 @@ func (pr PortRange) String() string {
} }
} }
func (pr PortRange) contains(port uint16) bool {
return port >= pr.First && port <= pr.Last
}
// NetAny matches all IP addresses.
// TODO: add ipv6.
var NetAny = []netaddr.IPPrefix{{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}}
// NetPortRange combines an IP address prefix and PortRange.
type NetPortRange struct { type NetPortRange struct {
Net Net Net netaddr.IPPrefix
Ports PortRange Ports PortRange
} }
var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny} func (npr NetPortRange) String() string {
return fmt.Sprintf("%v:%v", npr.Net, npr.Ports)
func (ipr NetPortRange) String() string {
return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
} }
var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}}
// Match matches packets from any IP address in Srcs to any ip:port in
// Dsts.
type Match struct { type Match struct {
Dsts []NetPortRange Dsts []NetPortRange
Srcs []Net Srcs []netaddr.IPPrefix
} }
// Clone returns a deep copy of m.
func (m Match) Clone() (res Match) { func (m Match) Clone() (res Match) {
if m.Dsts != nil { if m.Dsts != nil {
res.Dsts = append([]NetPortRange{}, m.Dsts...) res.Dsts = append([]NetPortRange{}, m.Dsts...)
} }
if m.Srcs != nil { if m.Srcs != nil {
res.Srcs = append([]Net{}, m.Srcs...) res.Srcs = append([]netaddr.IPPrefix{}, m.Srcs...)
} }
return res return res
} }
@ -115,57 +91,13 @@ func (m Match) String() string {
return fmt.Sprintf("%v=>%v", ss, ds) return fmt.Sprintf("%v=>%v", ss, ds)
} }
// Matches is a list of packet matchers.
type Matches []Match type Matches []Match
func (m Matches) Clone() (res Matches) { // Clone returns a deep copy of ms.
for _, match := range m { func (ms Matches) Clone() (res Matches) {
for _, match := range ms {
res = append(res, match.Clone()) res = append(res, match.Clone())
} }
return res return res
} }
func ipInList(ip packet.IP4, netlist []Net) bool {
for _, net := range netlist {
if net.Includes(ip) {
return true
}
}
return false
}
func matchIPPorts(mm Matches, q *packet.ParsedPacket) bool {
for _, acl := range mm {
for _, dst := range acl.Dsts {
if !dst.Net.Includes(q.DstIP) {
continue
}
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
continue
}
if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since
// the src will never match.
break
}
return true
}
}
return false
}
func matchIPWithoutPorts(mm Matches, q *packet.ParsedPacket) bool {
for _, acl := range mm {
for _, dst := range acl.Dsts {
if !dst.Net.Includes(q.DstIP) {
continue
}
if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since
// the src will never match.
break
}
return true
}
}
return false
}

@ -0,0 +1,151 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import (
"fmt"
"math/bits"
"strings"
"inet.af/netaddr"
"tailscale.com/net/packet"
)
type net4 struct {
ip packet.IP4
mask packet.IP4
}
func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 {
if !pfx.IP.Is4() {
panic("net4FromIPPrefix given non-ipv4 prefix")
}
return net4{
ip: packet.IP4FromNetaddr(pfx.IP),
mask: netmask4(pfx.Bits),
}
}
func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) {
for _, pfx := range pfxs {
if pfx.IP.Is4() {
ret = append(ret, net4FromIPPrefix(pfx))
}
}
return ret
}
func (n net4) Contains(ip packet.IP4) bool {
return (n.ip & n.mask) == (ip & n.mask)
}
func (n net4) Bits() int {
return 32 - bits.TrailingZeros32(uint32(n.mask))
}
func (n net4) String() string {
b := n.Bits()
if b == 32 {
return n.ip.String()
} else if b == 0 {
return "*"
} else {
return fmt.Sprintf("%s/%d", n.ip, b)
}
}
type npr4 struct {
net net4
ports PortRange
}
func (npr npr4) String() string {
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
}
type match4 struct {
dsts []npr4
srcs []net4
}
type matches4 []match4
func (ms matches4) String() string {
var b strings.Builder
for _, m := range ms {
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
}
return b.String()
}
func newMatches4(ms Matches) (ret matches4) {
for _, m := range ms {
var m4 match4
for _, src := range m.Srcs {
if src.IP.Is4() {
m4.srcs = append(m4.srcs, net4FromIPPrefix(src))
}
}
for _, dst := range m.Dsts {
if dst.Net.IP.Is4() {
m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports})
}
}
if len(m4.srcs) > 0 && len(m4.dsts) > 0 {
ret = append(ret, m4)
}
}
return ret
}
// match returns whether q's source IP and destination IP:port match
// any of ms.
func (ms matches4) match(q *packet.ParsedPacket) bool {
for _, m := range ms {
if !ip4InList(q.SrcIP, m.srcs) {
continue
}
for _, dst := range m.dsts {
if !dst.net.Contains(q.DstIP) {
continue
}
if !dst.ports.contains(q.DstPort) {
continue
}
return true
}
}
return false
}
// matchIPsOnly returns whether q's source and destination IP match
// any of ms.
func (ms matches4) matchIPsOnly(q *packet.ParsedPacket) bool {
for _, m := range ms {
if !ip4InList(q.SrcIP, m.srcs) {
continue
}
for _, dst := range m.dsts {
if dst.net.Contains(q.DstIP) {
return true
}
}
}
return false
}
func netmask4(bits uint8) packet.IP4 {
b := ^uint32((1 << (32 - bits)) - 1)
return packet.IP4(b)
}
func ip4InList(ip packet.IP4, netlist []net4) bool {
for _, net := range netlist {
if net.Contains(ip) {
return true
}
}
return false
}

@ -158,7 +158,7 @@ func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, der
tun := tuntest.NewChannelTUN() tun := tuntest.NewChannelTUN()
tsTun := tstun.WrapTUN(logf, tun.TUN()) tsTun := tstun.WrapTUN(logf, tun.TUN())
tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf)) tsTun.SetFilter(filter.NewAllowAll(filter.NetAny, logf))
dev := device.NewDevice(tsTun, &device.DeviceOptions{ dev := device.NewDevice(tsTun, &device.DeviceOptions{
Logger: &device.Logger{ Logger: &device.Logger{

@ -6,11 +6,15 @@ package tstun
import ( import (
"bytes" "bytes"
"fmt"
"strconv"
"strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
"unsafe" "unsafe"
"github.com/tailscale/wireguard-go/tun/tuntest" "github.com/tailscale/wireguard-go/tun/tuntest"
"inet.af/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
@ -29,35 +33,76 @@ func udp(src, dst packet.IP4, sport, dport uint16) []byte {
return packet.Generate(header, []byte("udp_payload")) return packet.Generate(header, []byte("udp_payload"))
} }
func filterNet(ip, mask packet.IP4) filter.Net { func nets(nets ...string) (ret []netaddr.IPPrefix) {
return filter.Net{IP: ip, Mask: mask} for _, s := range nets {
if i := strings.IndexByte(s, '/'); i == -1 {
ip, err := netaddr.ParseIP(s)
if err != nil {
panic(err)
}
bits := uint8(32)
if ip.Is6() {
bits = 128
}
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
} else {
pfx, err := netaddr.ParseIPPrefix(s)
if err != nil {
panic(err)
}
ret = append(ret, pfx)
}
}
return ret
} }
func nets(ips []packet.IP4) []filter.Net { func ports(s string) filter.PortRange {
out := make([]filter.Net, 0, len(ips)) if s == "*" {
for _, ip := range ips { return filter.PortRangeAny
out = append(out, filterNet(ip, filter.Netmask(32)))
} }
return out
var fs, ls string
i := strings.IndexByte(s, '-')
if i == -1 {
fs = s
ls = fs
} else {
fs = s[:i]
ls = s[i+1:]
}
first, err := strconv.ParseInt(fs, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
last, err := strconv.ParseInt(ls, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
return filter.PortRange{First: uint16(first), Last: uint16(last)}
} }
func ippr(ip packet.IP4, start, end uint16) []filter.NetPortRange { func netports(netPorts ...string) (ret []filter.NetPortRange) {
return []filter.NetPortRange{ for _, s := range netPorts {
filter.NetPortRange{ i := strings.LastIndexByte(s, ':')
Net: filterNet(ip, filter.Netmask(32)), if i == -1 {
Ports: filter.PortRange{First: start, Last: end}, panic(fmt.Sprintf("invalid NetPortRange %q", s))
}, }
npr := filter.NetPortRange{
Net: nets(s[:i])[0],
Ports: ports(s[i+1:]),
}
ret = append(ret, npr)
} }
return ret
} }
func setfilter(logf logger.Logf, tun *TUN) { func setfilter(logf logger.Logf, tun *TUN) {
matches := filter.Matches{ matches := filter.Matches{
{Srcs: nets([]packet.IP4{0x05060708}), Dsts: ippr(0x01020304, 89, 90)}, {Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
{Srcs: nets([]packet.IP4{0x01020304}), Dsts: ippr(0x05060708, 98, 98)}, {Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
}
localNets := []filter.Net{
filterNet(packet.IP4(0x01020304), filter.Netmask(16)),
} }
localNets := nets("1.2.0.0/16")
tun.SetFilter(filter.New(matches, localNets, nil, logf)) tun.SetFilter(filter.New(matches, localNets, nil, logf))
} }

Loading…
Cancel
Save