From 4d56d19b46932f804b1b207f8cc3153e1096381e Mon Sep 17 00:00:00 2001 From: chungdaniel Date: Thu, 20 Aug 2020 13:36:19 -0400 Subject: [PATCH] =?UTF-8?q?control/controlclient,=20wgengine/filter:=20ext?= =?UTF-8?q?ract=20parsePacketFilter=20to=20=E2=80=A6=20(#696)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit control/controlclient, wgengine/filter: extract parsePacketFilter to new constructor in wgengine/filter Signed-off-by: chungdaniel --- control/controlclient/filter.go | 74 ++------------------------- control/controlclient/filter_test.go | 42 ---------------- wgengine/filter/filter.go | 75 ++++++++++++++++++++++++++++ wgengine/filter/filter_test.go | 30 +++++++++++ 4 files changed, 108 insertions(+), 113 deletions(-) delete mode 100644 control/controlclient/filter_test.go diff --git a/control/controlclient/filter.go b/control/controlclient/filter.go index 098079cdf..9d2752a74 100644 --- a/control/controlclient/filter.go +++ b/control/controlclient/filter.go @@ -5,84 +5,16 @@ package controlclient import ( - "fmt" - "net" - "tailscale.com/tailcfg" "tailscale.com/wgengine/filter" ) -func parseIP(host string, defaultBits int) (filter.Net, error) { - ip := net.ParseIP(host) - if ip != nil && ip.IsUnspecified() { - // For clarity, reject 0.0.0.0 as an input - return filter.NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) - } else if ip == nil && host == "*" { - // User explicitly requested wildcard dst ip - return filter.NetAny, nil - } else { - if ip != nil { - ip = ip.To4() - } - if ip == nil || len(ip) != 4 { - return filter.NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host) - } - if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) { - return filter.NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) - } - return filter.Net{ - IP: filter.NewIP(ip), - Mask: filter.Netmask(defaultBits), - }, nil - } -} - // Parse a backward-compatible FilterRule used by control's wire format, // producing the most current filter.Matches format. func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) filter.Matches { - mm := make([]filter.Match, 0, len(pf)) - var erracc error - - for _, r := range pf { - m := filter.Match{} - - for i, s := range r.SrcIPs { - bits := 32 - if len(r.SrcBits) > i { - bits = r.SrcBits[i] - } - net, err := parseIP(s, bits) - if err != nil && erracc == nil { - erracc = err - continue - } - m.Srcs = append(m.Srcs, net) - } - - for _, d := range r.DstPorts { - bits := 32 - if d.Bits != nil { - bits = *d.Bits - } - net, err := parseIP(d.IP, bits) - if err != nil && erracc == nil { - erracc = err - continue - } - m.Dsts = append(m.Dsts, filter.NetPortRange{ - Net: net, - Ports: filter.PortRange{ - First: d.Ports.First, - Last: d.Ports.Last, - }, - }) - } - - mm = append(mm, m) - } - - if erracc != nil { - c.logf("parsePacketFilter: %s\n", erracc) + mm, err := filter.MatchesFromFilterRules(pf) + if err != nil { + c.logf("parsePacketFilter: %s\n", err) } return mm } diff --git a/control/controlclient/filter_test.go b/control/controlclient/filter_test.go deleted file mode 100644 index 9a69be81f..000000000 --- a/control/controlclient/filter_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package controlclient - -import ( - "net" - "testing" - - "tailscale.com/wgengine/filter" - "tailscale.com/wgengine/packet" -) - -func TestParseIP(t *testing.T) { - tests := []struct { - host string - bits int - want filter.Net - wantErr string - }{ - {"8.8.8.8", 24, filter.Net{IP: packet.NewIP(net.ParseIP("8.8.8.8")), Mask: packet.NewIP(net.ParseIP("255.255.255.0"))}, ""}, - {"8.8.8.8", 33, filter.Net{}, `invalid CIDR size 33 for host "8.8.8.8"`}, - {"8.8.8.8", -1, filter.Net{}, `invalid CIDR size -1 for host "8.8.8.8"`}, - {"0.0.0.0", 24, filter.Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, - {"*", 24, filter.NetAny, ""}, - {"fe80::1", 128, filter.NetNone, `ports="fe80::1": invalid IPv4 address`}, - } - for _, tt := range tests { - got, err := parseIP(tt.host, tt.bits) - if err != nil { - if err.Error() == tt.wantErr { - continue - } - t.Errorf("parseIP(%q, %v) error: %v; want error %q", tt.host, tt.bits, err, tt.wantErr) - } - if got != tt.want { - t.Errorf("parseIP(%q, %v) = %#v; want %#v", tt.host, tt.bits, got, tt.want) - continue - } - } -} diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 4875b40c2..4cb509b28 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -7,11 +7,13 @@ package filter import ( "fmt" + "net" "sync" "time" "github.com/golang/groupcache/lru" "golang.org/x/time/rate" + "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/wgengine/packet" ) @@ -129,6 +131,79 @@ func maybeHexdump(flag RunFlags, b []byte) string { return packet.Hexdump(b) + "\n" } +// MatchesFromFilterRules parse a number of wire-format FilterRule values into +// the Matches format. +// If an error is returned, the Matches result is still valid, containing the rules that +// were successfully converted. +func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) { + mm := make([]Match, 0, len(pf)) + var erracc error + + for _, r := range pf { + m := Match{} + + for i, s := range r.SrcIPs { + bits := 32 + if len(r.SrcBits) > i { + bits = r.SrcBits[i] + } + net, err := parseIP(s, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Srcs = append(m.Srcs, net) + } + + for _, d := range r.DstPorts { + bits := 32 + if d.Bits != nil { + bits = *d.Bits + } + net, err := parseIP(d.IP, bits) + if err != nil && erracc == nil { + erracc = err + continue + } + m.Dsts = append(m.Dsts, NetPortRange{ + Net: net, + Ports: PortRange{ + First: d.Ports.First, + Last: d.Ports.Last, + }, + }) + } + + mm = append(mm, m) + } + return mm, erracc +} + +func parseIP(host string, defaultBits int) (Net, error) { + ip := net.ParseIP(host) + if ip != nil && ip.IsUnspecified() { + // For clarity, reject 0.0.0.0 as an input + return NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) + } else if ip == nil && host == "*" { + // User explicitly requested wildcard dst ip + return NetAny, nil + } else { + if ip != nil { + ip = ip.To4() + } + if ip == nil || len(ip) != 4 { + return NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host) + } + if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) { + return NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) + } + return Net{ + IP: NewIP(ip), + Mask: Netmask(defaultBits), + }, nil + } +} + // TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging? // Logging is a quick way to record every newly opened TCP connection, but // we have to be cautious about flooding the logs vs letting people use diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index c95f43897..ce1e46cac 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "net" "strings" "testing" @@ -162,6 +163,35 @@ func TestNoAllocs(t *testing.T) { } } +func TestParseIP(t *testing.T) { + tests := []struct { + host string + bits int + want Net + wantErr string + }{ + {"8.8.8.8", 24, Net{IP: packet.NewIP(net.ParseIP("8.8.8.8")), Mask: packet.NewIP(net.ParseIP("255.255.255.0"))}, ""}, + {"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`}, + {"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`}, + {"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, + {"*", 24, NetAny, ""}, + {"fe80::1", 128, NetNone, `ports="fe80::1": invalid IPv4 address`}, + } + for _, tt := range tests { + got, err := parseIP(tt.host, tt.bits) + if err != nil { + if err.Error() == tt.wantErr { + continue + } + t.Errorf("parseIP(%q, %v) error: %v; want error %q", tt.host, tt.bits, err, tt.wantErr) + } + if got != tt.want { + t.Errorf("parseIP(%q, %v) = %#v; want %#v", tt.host, tt.bits, got, tt.want) + continue + } + } +} + func BenchmarkFilter(b *testing.B) { acl := newFilter(b.Logf)