From 65fbb9c303a6c9c7e8d5823ef5b1122ca6a89108 Mon Sep 17 00:00:00 2001 From: Avery Pennarun Date: Thu, 30 Apr 2020 01:49:17 -0400 Subject: [PATCH] wgengine/filter: support subnet mask rules, not just /32 IPs. This depends on improved support from the control server, to send the new subnet width (Bits) fields. If these are missing, we fall back to assuming their value is /32. Conversely, if the server sends Bits fields to an older client, it will interpret them as /32 addresses. Since the only rules we allow are "accept" rules, this will be narrower or equal to the intended rule, so older clients will simply reject hosts on the wider subnet (fail closed). With this change, the internal filter.Matches format has diverged from the wire format used by controlclient, so move the wire format into tailcfg and convert it to filter.Matches in controlclient. Signed-off-by: Avery Pennarun --- control/controlclient/direct.go | 2 +- control/controlclient/filter.go | 80 +++++++++++++++++++++++++++++++ tailcfg/tailcfg.go | 37 +++++++++++++- wgengine/filter/filter.go | 2 +- wgengine/filter/filter_test.go | 36 +++++++++----- wgengine/filter/match.go | 85 +++++++++++++++++++++++---------- wgengine/packet/packet.go | 37 -------------- 7 files changed, 202 insertions(+), 77 deletions(-) create mode 100644 control/controlclient/filter.go diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 5a83e3067..3df89b7c0 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -593,7 +593,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM DNS: resp.DNS, DNSDomains: resp.SearchPaths, Hostinfo: resp.Node.Hostinfo, - PacketFilter: resp.PacketFilter, + PacketFilter: c.parsePacketFilter(resp.PacketFilter), } for _, profile := range resp.UserProfiles { nm.UserProfiles[profile.ID] = profile diff --git a/control/controlclient/filter.go b/control/controlclient/filter.go new file mode 100644 index 000000000..a80e34ced --- /dev/null +++ b/control/controlclient/filter.go @@ -0,0 +1,80 @@ +package controlclient + +import ( + "fmt" + "net" + "tailscale.com/tailcfg" + "tailscale.com/wgengine/filter" +) + +func parseIP(host string, defaultBits int) (filter.Net, error) { + ip := net.ParseIP(host) + if ip != nil && ip.IsUnspecified() { + // For clarity, reject 0.0.0.0 as an input + return filter.NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) + } else if ip == nil && host == "*" { + // User explicitly requested wildcard dst ip + return filter.NetAny, nil + } else { + if ip != nil { + ip = ip.To4() + } + if ip == nil || len(ip) != 4 { + return filter.NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host) + } + return filter.Net{ + IP: filter.NewIP(ip), + Mask: filter.Netmask(defaultBits), + }, nil + } +} + +// Parse a backward-compatible FilterRule used by control's wire format, +// producing the most current filter.Matches format. +func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) filter.Matches { + mm := make([]filter.Match, 0, len(pf)) + var erracc error + + for _, r := range pf { + m := filter.Match{} + + for i, s := range r.SrcIPs { + bits := 32 + if len(r.SrcBits) > i { + bits = r.SrcBits[i] + } + net, err := parseIP(s, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Srcs = append(m.Srcs, net) + } + + for _, d := range r.DstPorts { + bits := 32 + if d.Bits != nil { + bits = *d.Bits + } + net, err := parseIP(d.IP, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Dsts = append(m.Dsts, filter.NetPortRange{ + Net: net, + Ports: filter.PortRange{ + First: d.Ports.First, + Last: d.Ports.Last, + }, + }) + } + + mm = append(mm, m) + } + + if erracc != nil { + c.logf("parsePacketFilter: %s\n", erracc) + } + return mm +} diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 88128096f..f37a59992 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -15,7 +15,6 @@ import ( "github.com/tailscale/wireguard-go/wgcfg" "golang.org/x/oauth2" "tailscale.com/types/opt" - "tailscale.com/wgengine/filter" ) type ID int64 @@ -404,6 +403,40 @@ type MapRequest struct { Hostinfo *Hostinfo } +// PortRange represents a range of UDP or TCP port numbers. +type PortRange struct { + First uint16 + Last uint16 +} + +var PortRangeAny = PortRange{0, 65535} + +// NetPortRange represents a single subnet:portrange. +type NetPortRange struct { + IP string + Bits *int // backward compatibility: if missing, means "all" bits + Ports PortRange +} + +// FilterRule represents one rule in a packet filter. +type FilterRule struct { + SrcIPs []string + SrcBits []int + DstPorts []NetPortRange +} + +var FilterAllowAll = []FilterRule{ + FilterRule{ + SrcIPs: []string{"*"}, + SrcBits: nil, + DstPorts: []NetPortRange{NetPortRange{ + IP: "*", + Bits: nil, + Ports: PortRange{0, 65535}, + }}, + }, +} + type MapResponse struct { KeepAlive bool // if set, all other fields are ignored @@ -415,7 +448,7 @@ type MapResponse struct { // ACLs Domain string - PacketFilter filter.Matches + PacketFilter []FilterRule UserProfiles []UserProfile Roles []Role // TODO: Groups []Group diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 5e18fb0cf..4d5e66ded 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -71,7 +71,7 @@ const lruMax = 512 // max entries in UDP LRU cache // MatchAllowAll matches all packets. var MatchAllowAll = Matches{ - Match{[]IPPortRange{IPPortRangeAny}, []IP{IPAny}}, + Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}}, } // NewAllowAll returns a packet filter that accepts everything. diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 7dc0b099d..ad5cc484b 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -21,23 +21,37 @@ var TCP = packet.TCP var UDP = packet.UDP var Fragment = packet.Fragment -func ippr(ip IP, start, end uint16) []IPPortRange { - return []IPPortRange{ - IPPortRange{ip, PortRange{start, end}}, +func nets(ips []IP) []Net { + out := make([]Net, 0, len(ips)) + for _, ip := range ips { + out = append(out, Net{ip, Netmask(32)}) + } + return out +} + +func ippr(ip IP, start, end uint16) []NetPortRange { + return []NetPortRange{ + NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}}, + } +} + +func netpr(ip IP, bits int, start, end uint16) []NetPortRange { + return []NetPortRange{ + NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}}, } } func TestFilter(t *testing.T) { mm := Matches{ - {SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{ - IPPortRange{0x01020304, PortRange{22, 22}}, - IPPortRange{0x05060708, PortRange{23, 24}}, + {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{ + NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}}, + NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}}, }}, - {SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: ippr(0x05060708, 27, 28)}, - {SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)}, - {SrcIPs: []IP{0}, DstPorts: ippr(0x647a6232, 0, 65535)}, - {SrcIPs: []IP{0}, DstPorts: ippr(0, 443, 443)}, - {SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)}, + {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, + {Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, + {Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)}, + {Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)}, + {Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)}, } acl := New(mm, nil) diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 2f89ab924..572595c7b 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -6,6 +6,8 @@ package filter import ( "fmt" + "math/bits" + "net" "strings" "tailscale.com/wgengine/packet" @@ -13,9 +15,42 @@ import ( type IP = packet.IP -const IPAny = IP(0) +func NewIP(ip net.IP) IP { + return packet.NewIP(ip) +} + +type Net struct { + IP IP + Mask IP +} + +func (n Net) Includes(ip IP) bool { + return (n.IP & n.Mask) == (ip & n.Mask) +} + +func (n Net) Bits() int { + return 32 - bits.TrailingZeros32(uint32(n.Mask)) +} -var NewIP = packet.NewIP +func (n Net) String() string { + b := n.Bits() + if b == 32 { + return n.IP.String() + } else if b == 0 { + return "*" + } else { + return fmt.Sprintf("%s/%d", n.IP, b) + } +} + +var NetAny = Net{0, 0} +var NetNone = Net{^IP(0), ^IP(0)} + +func Netmask(bits int) IP { + var b uint32 + b = ^uint32((1 << (32 - bits)) - 1) + return IP(b) +} type PortRange struct { First, Last uint16 @@ -33,39 +68,39 @@ func (pr PortRange) String() string { } } -type IPPortRange struct { - IP IP +type NetPortRange struct { + Net Net Ports PortRange } -var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny} +var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny} -func (ipr IPPortRange) String() string { - return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports) +func (ipr NetPortRange) String() string { + return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports) } type Match struct { - DstPorts []IPPortRange - SrcIPs []IP + Dsts []NetPortRange + Srcs []Net } func (m Match) Clone() (res Match) { - if m.DstPorts != nil { - res.DstPorts = append([]IPPortRange{}, m.DstPorts...) + if m.Dsts != nil { + res.Dsts = append([]NetPortRange{}, m.Dsts...) } - if m.SrcIPs != nil { - res.SrcIPs = append([]IP{}, m.SrcIPs...) + if m.Srcs != nil { + res.Srcs = append([]Net{}, m.Srcs...) } return res } func (m Match) String() string { srcs := []string{} - for _, srcip := range m.SrcIPs { - srcs = append(srcs, srcip.String()) + for _, src := range m.Srcs { + srcs = append(srcs, src.String()) } dsts := []string{} - for _, dst := range m.DstPorts { + for _, dst := range m.Dsts { dsts = append(dsts, dst.String()) } @@ -92,9 +127,9 @@ func (m Matches) Clone() (res Matches) { return res } -func ipInList(ip IP, iplist []IP) bool { - for _, ipp := range iplist { - if ipp == IPAny || ipp == ip { +func ipInList(ip IP, netlist []Net) bool { + for _, net := range netlist { + if net.Includes(ip) { return true } } @@ -103,14 +138,14 @@ func ipInList(ip IP, iplist []IP) bool { func matchIPPorts(mm Matches, q *packet.QDecode) bool { for _, acl := range mm { - for _, dst := range acl.DstPorts { - if dst.IP != IPAny && dst.IP != q.DstIP { + for _, dst := range acl.Dsts { + if !dst.Net.Includes(q.DstIP) { continue } if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last { continue } - if !ipInList(q.SrcIP, acl.SrcIPs) { + if !ipInList(q.SrcIP, acl.Srcs) { // Skip other dests in this acl, since // the src will never match. break @@ -123,11 +158,11 @@ func matchIPPorts(mm Matches, q *packet.QDecode) bool { func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool { for _, acl := range mm { - for _, dst := range acl.DstPorts { - if dst.IP != IPAny && dst.IP != q.DstIP { + for _, dst := range acl.Dsts { + if !dst.Net.Includes(q.DstIP) { continue } - if !ipInList(q.SrcIP, acl.SrcIPs) { + if !ipInList(q.SrcIP, acl.Srcs) { // Skip other dests in this acl, since // the src will never match. break diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go index 0481cf9d1..e2cd08fc7 100644 --- a/wgengine/packet/packet.go +++ b/wgengine/packet/packet.go @@ -6,7 +6,6 @@ package packet import ( "encoding/binary" - "encoding/json" "fmt" "log" "net" @@ -43,8 +42,6 @@ func (p IPProto) String() string { type IP uint32 -const IPAny = IP(0) - func NewIP(b net.IP) IP { b4 := b.To4() if b4 == nil { @@ -54,45 +51,11 @@ func NewIP(b net.IP) IP { } func (ip IP) String() string { - if ip == 0 { - return "*" - } b := make([]byte, 4) binary.BigEndian.PutUint32(b, uint32(ip)) return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3]) } -func (ipp *IP) MarshalJSON() ([]byte, error) { - s := "\"" + (*ipp).String() + "\"" - return []byte(s), nil -} - -func (ipp *IP) UnmarshalJSON(b []byte) error { - var hostp *string - err := json.Unmarshal(b, &hostp) - if err != nil { - return err - } - host := *hostp - ip := net.ParseIP(host) - if ip != nil && ip.IsUnspecified() { - // For clarity, reject 0.0.0.0 as an input - return fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) - } else if ip == nil && host == "*" { - // User explicitly requested wildcard dst ip - *ipp = IPAny - } else { - if ip != nil { - ip = ip.To4() - } - if ip == nil || len(ip) != 4 { - return fmt.Errorf("ports=%#v: invalid IPv4 address", host) - } - *ipp = NewIP(ip) - } - return nil -} - const ( EchoReply uint8 = 0x00 EchoRequest uint8 = 0x08