diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 5a83e3067..3df89b7c0 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -593,7 +593,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM DNS: resp.DNS, DNSDomains: resp.SearchPaths, Hostinfo: resp.Node.Hostinfo, - PacketFilter: resp.PacketFilter, + PacketFilter: c.parsePacketFilter(resp.PacketFilter), } for _, profile := range resp.UserProfiles { nm.UserProfiles[profile.ID] = profile diff --git a/control/controlclient/filter.go b/control/controlclient/filter.go new file mode 100644 index 000000000..a80e34ced --- /dev/null +++ b/control/controlclient/filter.go @@ -0,0 +1,80 @@ +package controlclient + +import ( + "fmt" + "net" + "tailscale.com/tailcfg" + "tailscale.com/wgengine/filter" +) + +func parseIP(host string, defaultBits int) (filter.Net, error) { + ip := net.ParseIP(host) + if ip != nil && ip.IsUnspecified() { + // For clarity, reject 0.0.0.0 as an input + return filter.NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) + } else if ip == nil && host == "*" { + // User explicitly requested wildcard dst ip + return filter.NetAny, nil + } else { + if ip != nil { + ip = ip.To4() + } + if ip == nil || len(ip) != 4 { + return filter.NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host) + } + return filter.Net{ + IP: filter.NewIP(ip), + Mask: filter.Netmask(defaultBits), + }, nil + } +} + +// Parse a backward-compatible FilterRule used by control's wire format, +// producing the most current filter.Matches format. +func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) filter.Matches { + mm := make([]filter.Match, 0, len(pf)) + var erracc error + + for _, r := range pf { + m := filter.Match{} + + for i, s := range r.SrcIPs { + bits := 32 + if len(r.SrcBits) > i { + bits = r.SrcBits[i] + } + net, err := parseIP(s, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Srcs = append(m.Srcs, net) + } + + for _, d := range r.DstPorts { + bits := 32 + if d.Bits != nil { + bits = *d.Bits + } + net, err := parseIP(d.IP, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Dsts = append(m.Dsts, filter.NetPortRange{ + Net: net, + Ports: filter.PortRange{ + First: d.Ports.First, + Last: d.Ports.Last, + }, + }) + } + + mm = append(mm, m) + } + + if erracc != nil { + c.logf("parsePacketFilter: %s\n", erracc) + } + return mm +} diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 88128096f..f37a59992 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -15,7 +15,6 @@ import ( "github.com/tailscale/wireguard-go/wgcfg" "golang.org/x/oauth2" "tailscale.com/types/opt" - "tailscale.com/wgengine/filter" ) type ID int64 @@ -404,6 +403,40 @@ type MapRequest struct { Hostinfo *Hostinfo } +// PortRange represents a range of UDP or TCP port numbers. +type PortRange struct { + First uint16 + Last uint16 +} + +var PortRangeAny = PortRange{0, 65535} + +// NetPortRange represents a single subnet:portrange. +type NetPortRange struct { + IP string + Bits *int // backward compatibility: if missing, means "all" bits + Ports PortRange +} + +// FilterRule represents one rule in a packet filter. +type FilterRule struct { + SrcIPs []string + SrcBits []int + DstPorts []NetPortRange +} + +var FilterAllowAll = []FilterRule{ + FilterRule{ + SrcIPs: []string{"*"}, + SrcBits: nil, + DstPorts: []NetPortRange{NetPortRange{ + IP: "*", + Bits: nil, + Ports: PortRange{0, 65535}, + }}, + }, +} + type MapResponse struct { KeepAlive bool // if set, all other fields are ignored @@ -415,7 +448,7 @@ type MapResponse struct { // ACLs Domain string - PacketFilter filter.Matches + PacketFilter []FilterRule UserProfiles []UserProfile Roles []Role // TODO: Groups []Group diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 5e18fb0cf..4d5e66ded 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -71,7 +71,7 @@ const lruMax = 512 // max entries in UDP LRU cache // MatchAllowAll matches all packets. var MatchAllowAll = Matches{ - Match{[]IPPortRange{IPPortRangeAny}, []IP{IPAny}}, + Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}}, } // NewAllowAll returns a packet filter that accepts everything. diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 7dc0b099d..ad5cc484b 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -21,23 +21,37 @@ var TCP = packet.TCP var UDP = packet.UDP var Fragment = packet.Fragment -func ippr(ip IP, start, end uint16) []IPPortRange { - return []IPPortRange{ - IPPortRange{ip, PortRange{start, end}}, +func nets(ips []IP) []Net { + out := make([]Net, 0, len(ips)) + for _, ip := range ips { + out = append(out, Net{ip, Netmask(32)}) + } + return out +} + +func ippr(ip IP, start, end uint16) []NetPortRange { + return []NetPortRange{ + NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}}, + } +} + +func netpr(ip IP, bits int, start, end uint16) []NetPortRange { + return []NetPortRange{ + NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}}, } } func TestFilter(t *testing.T) { mm := Matches{ - {SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{ - IPPortRange{0x01020304, PortRange{22, 22}}, - IPPortRange{0x05060708, PortRange{23, 24}}, + {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{ + NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}}, + NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}}, }}, - {SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: ippr(0x05060708, 27, 28)}, - {SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)}, - {SrcIPs: []IP{0}, DstPorts: ippr(0x647a6232, 0, 65535)}, - {SrcIPs: []IP{0}, DstPorts: ippr(0, 443, 443)}, - {SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)}, + {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)}, + {Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)}, + {Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)}, + {Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)}, + {Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)}, } acl := New(mm, nil) diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 2f89ab924..572595c7b 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -6,6 +6,8 @@ package filter import ( "fmt" + "math/bits" + "net" "strings" "tailscale.com/wgengine/packet" @@ -13,9 +15,42 @@ import ( type IP = packet.IP -const IPAny = IP(0) +func NewIP(ip net.IP) IP { + return packet.NewIP(ip) +} + +type Net struct { + IP IP + Mask IP +} + +func (n Net) Includes(ip IP) bool { + return (n.IP & n.Mask) == (ip & n.Mask) +} + +func (n Net) Bits() int { + return 32 - bits.TrailingZeros32(uint32(n.Mask)) +} -var NewIP = packet.NewIP +func (n Net) String() string { + b := n.Bits() + if b == 32 { + return n.IP.String() + } else if b == 0 { + return "*" + } else { + return fmt.Sprintf("%s/%d", n.IP, b) + } +} + +var NetAny = Net{0, 0} +var NetNone = Net{^IP(0), ^IP(0)} + +func Netmask(bits int) IP { + var b uint32 + b = ^uint32((1 << (32 - bits)) - 1) + return IP(b) +} type PortRange struct { First, Last uint16 @@ -33,39 +68,39 @@ func (pr PortRange) String() string { } } -type IPPortRange struct { - IP IP +type NetPortRange struct { + Net Net Ports PortRange } -var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny} +var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny} -func (ipr IPPortRange) String() string { - return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports) +func (ipr NetPortRange) String() string { + return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports) } type Match struct { - DstPorts []IPPortRange - SrcIPs []IP + Dsts []NetPortRange + Srcs []Net } func (m Match) Clone() (res Match) { - if m.DstPorts != nil { - res.DstPorts = append([]IPPortRange{}, m.DstPorts...) + if m.Dsts != nil { + res.Dsts = append([]NetPortRange{}, m.Dsts...) } - if m.SrcIPs != nil { - res.SrcIPs = append([]IP{}, m.SrcIPs...) + if m.Srcs != nil { + res.Srcs = append([]Net{}, m.Srcs...) } return res } func (m Match) String() string { srcs := []string{} - for _, srcip := range m.SrcIPs { - srcs = append(srcs, srcip.String()) + for _, src := range m.Srcs { + srcs = append(srcs, src.String()) } dsts := []string{} - for _, dst := range m.DstPorts { + for _, dst := range m.Dsts { dsts = append(dsts, dst.String()) } @@ -92,9 +127,9 @@ func (m Matches) Clone() (res Matches) { return res } -func ipInList(ip IP, iplist []IP) bool { - for _, ipp := range iplist { - if ipp == IPAny || ipp == ip { +func ipInList(ip IP, netlist []Net) bool { + for _, net := range netlist { + if net.Includes(ip) { return true } } @@ -103,14 +138,14 @@ func ipInList(ip IP, iplist []IP) bool { func matchIPPorts(mm Matches, q *packet.QDecode) bool { for _, acl := range mm { - for _, dst := range acl.DstPorts { - if dst.IP != IPAny && dst.IP != q.DstIP { + for _, dst := range acl.Dsts { + if !dst.Net.Includes(q.DstIP) { continue } if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last { continue } - if !ipInList(q.SrcIP, acl.SrcIPs) { + if !ipInList(q.SrcIP, acl.Srcs) { // Skip other dests in this acl, since // the src will never match. break @@ -123,11 +158,11 @@ func matchIPPorts(mm Matches, q *packet.QDecode) bool { func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool { for _, acl := range mm { - for _, dst := range acl.DstPorts { - if dst.IP != IPAny && dst.IP != q.DstIP { + for _, dst := range acl.Dsts { + if !dst.Net.Includes(q.DstIP) { continue } - if !ipInList(q.SrcIP, acl.SrcIPs) { + if !ipInList(q.SrcIP, acl.Srcs) { // Skip other dests in this acl, since // the src will never match. break diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go index 0481cf9d1..e2cd08fc7 100644 --- a/wgengine/packet/packet.go +++ b/wgengine/packet/packet.go @@ -6,7 +6,6 @@ package packet import ( "encoding/binary" - "encoding/json" "fmt" "log" "net" @@ -43,8 +42,6 @@ func (p IPProto) String() string { type IP uint32 -const IPAny = IP(0) - func NewIP(b net.IP) IP { b4 := b.To4() if b4 == nil { @@ -54,45 +51,11 @@ func NewIP(b net.IP) IP { } func (ip IP) String() string { - if ip == 0 { - return "*" - } b := make([]byte, 4) binary.BigEndian.PutUint32(b, uint32(ip)) return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3]) } -func (ipp *IP) MarshalJSON() ([]byte, error) { - s := "\"" + (*ipp).String() + "\"" - return []byte(s), nil -} - -func (ipp *IP) UnmarshalJSON(b []byte) error { - var hostp *string - err := json.Unmarshal(b, &hostp) - if err != nil { - return err - } - host := *hostp - ip := net.ParseIP(host) - if ip != nil && ip.IsUnspecified() { - // For clarity, reject 0.0.0.0 as an input - return fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) - } else if ip == nil && host == "*" { - // User explicitly requested wildcard dst ip - *ipp = IPAny - } else { - if ip != nil { - ip = ip.To4() - } - if ip == nil || len(ip) != 4 { - return fmt.Errorf("ports=%#v: invalid IPv4 address", host) - } - *ipp = NewIP(ip) - } - return nil -} - const ( EchoReply uint8 = 0x00 EchoRequest uint8 = 0x08