wgengine/filter: support subnet mask rules, not just /32 IPs.

This depends on improved support from the control server, to send the
new subnet width (Bits) fields. If these are missing, we fall back to
assuming their value is /32.

Conversely, if the server sends Bits fields to an older client, it will
interpret them as /32 addresses. Since the only rules we allow are
"accept" rules, this will be narrower or equal to the intended rule, so
older clients will simply reject hosts on the wider subnet (fail
closed).

With this change, the internal filter.Matches format has diverged
from the wire format used by controlclient, so move the wire format
into tailcfg and convert it to filter.Matches in controlclient.

Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
pull/348/head
Avery Pennarun 5 years ago
parent d6c34368e8
commit 65fbb9c303

@ -593,7 +593,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
DNS: resp.DNS,
DNSDomains: resp.SearchPaths,
Hostinfo: resp.Node.Hostinfo,
PacketFilter: resp.PacketFilter,
PacketFilter: c.parsePacketFilter(resp.PacketFilter),
}
for _, profile := range resp.UserProfiles {
nm.UserProfiles[profile.ID] = profile

@ -0,0 +1,80 @@
package controlclient
import (
"fmt"
"net"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/filter"
)
func parseIP(host string, defaultBits int) (filter.Net, error) {
ip := net.ParseIP(host)
if ip != nil && ip.IsUnspecified() {
// For clarity, reject 0.0.0.0 as an input
return filter.NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} else if ip == nil && host == "*" {
// User explicitly requested wildcard dst ip
return filter.NetAny, nil
} else {
if ip != nil {
ip = ip.To4()
}
if ip == nil || len(ip) != 4 {
return filter.NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
}
return filter.Net{
IP: filter.NewIP(ip),
Mask: filter.Netmask(defaultBits),
}, nil
}
}
// Parse a backward-compatible FilterRule used by control's wire format,
// producing the most current filter.Matches format.
func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) filter.Matches {
mm := make([]filter.Match, 0, len(pf))
var erracc error
for _, r := range pf {
m := filter.Match{}
for i, s := range r.SrcIPs {
bits := 32
if len(r.SrcBits) > i {
bits = r.SrcBits[i]
}
net, err := parseIP(s, bits)
if err != nil && erracc == nil {
erracc = err
continue
}
m.Srcs = append(m.Srcs, net)
}
for _, d := range r.DstPorts {
bits := 32
if d.Bits != nil {
bits = *d.Bits
}
net, err := parseIP(d.IP, bits)
if err != nil && erracc == nil {
erracc = err
continue
}
m.Dsts = append(m.Dsts, filter.NetPortRange{
Net: net,
Ports: filter.PortRange{
First: d.Ports.First,
Last: d.Ports.Last,
},
})
}
mm = append(mm, m)
}
if erracc != nil {
c.logf("parsePacketFilter: %s\n", erracc)
}
return mm
}

@ -15,7 +15,6 @@ import (
"github.com/tailscale/wireguard-go/wgcfg"
"golang.org/x/oauth2"
"tailscale.com/types/opt"
"tailscale.com/wgengine/filter"
)
type ID int64
@ -404,6 +403,40 @@ type MapRequest struct {
Hostinfo *Hostinfo
}
// PortRange represents a range of UDP or TCP port numbers.
type PortRange struct {
First uint16
Last uint16
}
var PortRangeAny = PortRange{0, 65535}
// NetPortRange represents a single subnet:portrange.
type NetPortRange struct {
IP string
Bits *int // backward compatibility: if missing, means "all" bits
Ports PortRange
}
// FilterRule represents one rule in a packet filter.
type FilterRule struct {
SrcIPs []string
SrcBits []int
DstPorts []NetPortRange
}
var FilterAllowAll = []FilterRule{
FilterRule{
SrcIPs: []string{"*"},
SrcBits: nil,
DstPorts: []NetPortRange{NetPortRange{
IP: "*",
Bits: nil,
Ports: PortRange{0, 65535},
}},
},
}
type MapResponse struct {
KeepAlive bool // if set, all other fields are ignored
@ -415,7 +448,7 @@ type MapResponse struct {
// ACLs
Domain string
PacketFilter filter.Matches
PacketFilter []FilterRule
UserProfiles []UserProfile
Roles []Role
// TODO: Groups []Group

@ -71,7 +71,7 @@ const lruMax = 512 // max entries in UDP LRU cache
// MatchAllowAll matches all packets.
var MatchAllowAll = Matches{
Match{[]IPPortRange{IPPortRangeAny}, []IP{IPAny}},
Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}},
}
// NewAllowAll returns a packet filter that accepts everything.

@ -21,23 +21,37 @@ var TCP = packet.TCP
var UDP = packet.UDP
var Fragment = packet.Fragment
func ippr(ip IP, start, end uint16) []IPPortRange {
return []IPPortRange{
IPPortRange{ip, PortRange{start, end}},
func nets(ips []IP) []Net {
out := make([]Net, 0, len(ips))
for _, ip := range ips {
out = append(out, Net{ip, Netmask(32)})
}
return out
}
func ippr(ip IP, start, end uint16) []NetPortRange {
return []NetPortRange{
NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}},
}
}
func netpr(ip IP, bits int, start, end uint16) []NetPortRange {
return []NetPortRange{
NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}},
}
}
func TestFilter(t *testing.T) {
mm := Matches{
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{
IPPortRange{0x01020304, PortRange{22, 22}},
IPPortRange{0x05060708, PortRange{23, 24}},
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
}},
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: ippr(0x05060708, 27, 28)},
{SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)},
{SrcIPs: []IP{0}, DstPorts: ippr(0x647a6232, 0, 65535)},
{SrcIPs: []IP{0}, DstPorts: ippr(0, 443, 443)},
{SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)},
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
{Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
{Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
}
acl := New(mm, nil)

@ -6,6 +6,8 @@ package filter
import (
"fmt"
"math/bits"
"net"
"strings"
"tailscale.com/wgengine/packet"
@ -13,9 +15,42 @@ import (
type IP = packet.IP
const IPAny = IP(0)
func NewIP(ip net.IP) IP {
return packet.NewIP(ip)
}
type Net struct {
IP IP
Mask IP
}
func (n Net) Includes(ip IP) bool {
return (n.IP & n.Mask) == (ip & n.Mask)
}
func (n Net) Bits() int {
return 32 - bits.TrailingZeros32(uint32(n.Mask))
}
var NewIP = packet.NewIP
func (n Net) String() string {
b := n.Bits()
if b == 32 {
return n.IP.String()
} else if b == 0 {
return "*"
} else {
return fmt.Sprintf("%s/%d", n.IP, b)
}
}
var NetAny = Net{0, 0}
var NetNone = Net{^IP(0), ^IP(0)}
func Netmask(bits int) IP {
var b uint32
b = ^uint32((1 << (32 - bits)) - 1)
return IP(b)
}
type PortRange struct {
First, Last uint16
@ -33,39 +68,39 @@ func (pr PortRange) String() string {
}
}
type IPPortRange struct {
IP IP
type NetPortRange struct {
Net Net
Ports PortRange
}
var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny}
var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny}
func (ipr IPPortRange) String() string {
return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports)
func (ipr NetPortRange) String() string {
return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
}
type Match struct {
DstPorts []IPPortRange
SrcIPs []IP
Dsts []NetPortRange
Srcs []Net
}
func (m Match) Clone() (res Match) {
if m.DstPorts != nil {
res.DstPorts = append([]IPPortRange{}, m.DstPorts...)
if m.Dsts != nil {
res.Dsts = append([]NetPortRange{}, m.Dsts...)
}
if m.SrcIPs != nil {
res.SrcIPs = append([]IP{}, m.SrcIPs...)
if m.Srcs != nil {
res.Srcs = append([]Net{}, m.Srcs...)
}
return res
}
func (m Match) String() string {
srcs := []string{}
for _, srcip := range m.SrcIPs {
srcs = append(srcs, srcip.String())
for _, src := range m.Srcs {
srcs = append(srcs, src.String())
}
dsts := []string{}
for _, dst := range m.DstPorts {
for _, dst := range m.Dsts {
dsts = append(dsts, dst.String())
}
@ -92,9 +127,9 @@ func (m Matches) Clone() (res Matches) {
return res
}
func ipInList(ip IP, iplist []IP) bool {
for _, ipp := range iplist {
if ipp == IPAny || ipp == ip {
func ipInList(ip IP, netlist []Net) bool {
for _, net := range netlist {
if net.Includes(ip) {
return true
}
}
@ -103,14 +138,14 @@ func ipInList(ip IP, iplist []IP) bool {
func matchIPPorts(mm Matches, q *packet.QDecode) bool {
for _, acl := range mm {
for _, dst := range acl.DstPorts {
if dst.IP != IPAny && dst.IP != q.DstIP {
for _, dst := range acl.Dsts {
if !dst.Net.Includes(q.DstIP) {
continue
}
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
continue
}
if !ipInList(q.SrcIP, acl.SrcIPs) {
if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since
// the src will never match.
break
@ -123,11 +158,11 @@ func matchIPPorts(mm Matches, q *packet.QDecode) bool {
func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool {
for _, acl := range mm {
for _, dst := range acl.DstPorts {
if dst.IP != IPAny && dst.IP != q.DstIP {
for _, dst := range acl.Dsts {
if !dst.Net.Includes(q.DstIP) {
continue
}
if !ipInList(q.SrcIP, acl.SrcIPs) {
if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since
// the src will never match.
break

@ -6,7 +6,6 @@ package packet
import (
"encoding/binary"
"encoding/json"
"fmt"
"log"
"net"
@ -43,8 +42,6 @@ func (p IPProto) String() string {
type IP uint32
const IPAny = IP(0)
func NewIP(b net.IP) IP {
b4 := b.To4()
if b4 == nil {
@ -54,45 +51,11 @@ func NewIP(b net.IP) IP {
}
func (ip IP) String() string {
if ip == 0 {
return "*"
}
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(ip))
return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3])
}
func (ipp *IP) MarshalJSON() ([]byte, error) {
s := "\"" + (*ipp).String() + "\""
return []byte(s), nil
}
func (ipp *IP) UnmarshalJSON(b []byte) error {
var hostp *string
err := json.Unmarshal(b, &hostp)
if err != nil {
return err
}
host := *hostp
ip := net.ParseIP(host)
if ip != nil && ip.IsUnspecified() {
// For clarity, reject 0.0.0.0 as an input
return fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} else if ip == nil && host == "*" {
// User explicitly requested wildcard dst ip
*ipp = IPAny
} else {
if ip != nil {
ip = ip.To4()
}
if ip == nil || len(ip) != 4 {
return fmt.Errorf("ports=%#v: invalid IPv4 address", host)
}
*ipp = NewIP(ip)
}
return nil
}
const (
EchoReply uint8 = 0x00
EchoRequest uint8 = 0x08

Loading…
Cancel
Save