diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index c34d2765d..3aee79b8b 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package filter contains a stateful packet filter. +// Package filter is a stateful packet filter. package filter import ( @@ -14,15 +14,9 @@ import ( "golang.org/x/time/rate" "inet.af/netaddr" "tailscale.com/net/packet" - "tailscale.com/tailcfg" "tailscale.com/types/logger" ) -type filterState struct { - mu sync.Mutex - lru *lru.Cache // of tuple -} - // Filter is a stateful packet filter. type Filter struct { logf logger.Logf @@ -45,14 +39,31 @@ type Filter struct { state *filterState } -// Response is a verdict: either a Drop, Accept, or noVerdict skip to -// continue processing. +// tuple is a 4-tuple of source and destination IPv4 and port. It's +// used as a lookup key in filterState. +type tuple struct { + SrcIP packet.IP4 + DstIP packet.IP4 + SrcPort uint16 + DstPort uint16 +} + +// filterState is a state cache of past seen packets. +type filterState struct { + mu sync.Mutex + lru *lru.Cache // of tuple +} + +// lruMax is the size of the LRU cache in filterState. +const lruMax = 512 + +// Response is a verdict from the packet filter. type Response int const ( - Drop Response = iota - Accept - noVerdict // Returned from subfilters to continue processing. + Drop Response = iota // do not continue processing packet. + Accept // continue processing packet. + noVerdict // no verdict yet, continue running filter ) func (r Response) String() string { @@ -72,30 +83,16 @@ func (r Response) String() string { type RunFlags int const ( - LogDrops RunFlags = 1 << iota - LogAccepts - HexdumpDrops - HexdumpAccepts + LogDrops RunFlags = 1 << iota // write dropped packet info to logf + LogAccepts // write accepted packet info to logf + HexdumpDrops // print packet hexdump when logging drops + HexdumpAccepts // print packet hexdump when logging accepts ) -type tuple struct { - SrcIP packet.IP4 - DstIP packet.IP4 - SrcPort uint16 - DstPort uint16 -} - -const lruMax = 512 // max entries in UDP LRU cache - -// MatchAllowAll matches all packets. -var MatchAllowAll = Matches{ - Match{NetPortRangeAny, NetAny}, -} - // NewAllowAll returns a packet filter that accepts everything to and // from localNets. func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter { - return New(MatchAllowAll, localNets, nil, logf) + return New(Matches{Match{NetPortRangeAny, NetAny}}, localNets, nil, logf) } // NewAllowNone returns a packet filter that rejects everything. @@ -106,8 +103,8 @@ func NewAllowNone(logf logger.Logf) *Filter { // New creates a new packet filter. The filter enforces that incoming // packets must be destined to an IP in localNets, and must be allowed // by matches. If shareStateWith is non-nil, the returned filter -// shares state with the previous one, to enable rules to be changed -// at runtime without breaking existing flows. +// shares state with the previous one, to enable changing rules at +// runtime without breaking existing stateful flows. func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter { var state *filterState if shareStateWith != nil { @@ -133,82 +130,6 @@ func maybeHexdump(flag RunFlags, b []byte) string { return packet.Hexdump(b) + "\n" } -// MatchesFromFilterRules parse a number of wire-format FilterRule values into -// the Matches format. -// If an error is returned, the Matches result is still valid, containing the rules that -// were successfully converted. -func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) { - mm := make([]Match, 0, len(pf)) - var erracc error - - for _, r := range pf { - m := Match{} - - for i, s := range r.SrcIPs { - bits := 32 - if len(r.SrcBits) > i { - bits = r.SrcBits[i] - } - net, err := parseIP(s, bits) - if err != nil && erracc == nil { - erracc = err - continue - } - m.Srcs = append(m.Srcs, net) - } - - for _, d := range r.DstPorts { - bits := 32 - if d.Bits != nil { - bits = *d.Bits - } - net, err := parseIP(d.IP, bits) - if err != nil && erracc == nil { - erracc = err - continue - } - m.Dsts = append(m.Dsts, NetPortRange{ - Net: net, - Ports: PortRange{ - First: d.Ports.First, - Last: d.Ports.Last, - }, - }) - } - - mm = append(mm, m) - } - return mm, erracc -} - -func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) { - if host == "*" { - // User explicitly requested wildcard dst ip. - // TODO: ipv6 - return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil - } - - ip, err := netaddr.ParseIP(host) - if err != nil { - return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host) - } - if ip == netaddr.IPv4(0, 0, 0, 0) { - // For clarity, reject 0.0.0.0 as an input - return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) - } - if !ip.Is4() { - // TODO: ipv6 - return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host) - } - if defaultBits < 0 || defaultBits > 32 { - return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) - } - return netaddr.IPPrefix{ - IP: ip, - Bits: uint8(defaultBits), - }, nil -} - // TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging? // Logging is a quick way to record every newly opened TCP connection, but // we have to be cautious about flooding the logs vs letting people use @@ -240,7 +161,8 @@ func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, dir dir } } -// RunIn determines whether this node is allowed to receive q from a Tailscale peer. +// RunIn determines whether this node is allowed to receive q from a +// Tailscale peer. func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response { dir := in r := f.pre(q, rf, dir) @@ -254,7 +176,8 @@ func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response { return r } -// RunOut determines whether this node is allowed to send q to a Tailscale peer. +// RunOut determines whether this node is allowed to send q to a +// Tailscale peer. func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response { dir := out r := f.pre(q, rf, dir) @@ -267,6 +190,7 @@ func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response { return r } +// runIn runs the input-specific part of the filter logic. func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to @@ -327,6 +251,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) { return Drop, "no rules matched" } +// runIn runs the output-specific part of the filter logic. func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) { if q.IPProto == packet.UDP { t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort} @@ -339,12 +264,13 @@ func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) { return Accept, "ok out" } -// direction is whether a packet was flowing in to this machine, or flowing out. +// direction is whether a packet was flowing in to this machine, or +// flowing out. type direction int const ( - in direction = iota - out + in direction = iota // from Tailscale peer to local machine + out // from local machine to Tailscale peer ) func (d direction) String() string { @@ -358,6 +284,8 @@ func (d direction) String() string { } } +// pre runs the direction-agnostic filter logic. dir is only used for +// logging. func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Response { if len(q.Buffer()) == 0 { // wireguard keepalive packet, always permit. diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index da5bc367f..198533df3 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -29,6 +29,7 @@ func (pr PortRange) String() string { } } +// contains returns whether port is in pr. func (pr PortRange) contains(port uint16) bool { return port >= pr.First && port <= pr.Last } @@ -47,6 +48,7 @@ func (npr NetPortRange) String() string { return fmt.Sprintf("%v:%v", npr.Net, npr.Ports) } +// NetPortRangeAny matches any IP and port. var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}} // Match matches packets from any IP address in Srcs to any ip:port in diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go new file mode 100644 index 000000000..f6a4ef0f3 --- /dev/null +++ b/wgengine/filter/tailcfg.go @@ -0,0 +1,87 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package filter + +import ( + "fmt" + + "inet.af/netaddr" + "tailscale.com/tailcfg" +) + +// MatchesFromFilterRules converts tailcfg FilterRules into Matches. +// If an error is returned, the Matches result is still valid, +// containing the rules that were successfully converted. +func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) { + mm := make([]Match, 0, len(pf)) + var erracc error + + for _, r := range pf { + m := Match{} + + for i, s := range r.SrcIPs { + bits := 32 + if len(r.SrcBits) > i { + bits = r.SrcBits[i] + } + net, err := parseIP(s, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Srcs = append(m.Srcs, net) + } + + for _, d := range r.DstPorts { + bits := 32 + if d.Bits != nil { + bits = *d.Bits + } + net, err := parseIP(d.IP, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Dsts = append(m.Dsts, NetPortRange{ + Net: net, + Ports: PortRange{ + First: d.Ports.First, + Last: d.Ports.Last, + }, + }) + } + + mm = append(mm, m) + } + return mm, erracc +} + +func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) { + if host == "*" { + // User explicitly requested wildcard dst ip. + // TODO: ipv6 + return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil + } + + ip, err := netaddr.ParseIP(host) + if err != nil { + return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host) + } + if ip == netaddr.IPv4(0, 0, 0, 0) { + // For clarity, reject 0.0.0.0 as an input + return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) + } + if !ip.Is4() { + // TODO: ipv6 + return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host) + } + if defaultBits < 0 || defaultBits > 32 { + return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) + } + return netaddr.IPPrefix{ + IP: ip, + Bits: uint8(defaultBits), + }, nil +}