diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index aea7c9978..0d6c21553 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -11,6 +11,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/types/logger" @@ -188,21 +189,20 @@ func TestNoAllocs(t *testing.T) { } func TestParseIP(t *testing.T) { - var noaddr netaddr.IPPrefix tests := []struct { host string bits int - want netaddr.IPPrefix + want []netaddr.IPPrefix wantErr string }{ {"8.8.8.8", 24, pfx("8.8.8.8/24"), ""}, {"2601:1234::", 64, pfx("2601:1234::/64"), ""}, - {"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`}, - {"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`}, - {"2601:1234::", 129, noaddr, `invalid CIDR size 129 for host "2601:1234::"`}, - {"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, - {"::", 64, noaddr, `ports="::": to allow all IP addresses, use *:port, not [::]:port`}, - {"*", 24, pfx("0.0.0.0/0"), ""}, + {"8.8.8.8", 33, nil, `invalid CIDR size 33 for host "8.8.8.8"`}, + {"8.8.8.8", -1, nil, `invalid CIDR size -1 for host "8.8.8.8"`}, + {"2601:1234::", 129, nil, `invalid CIDR size 129 for host "2601:1234::"`}, + {"0.0.0.0", 24, nil, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, + {"::", 64, nil, `ports="::": to allow all IP addresses, use *:port, not [::]:port`}, + {"*", 24, pfx("0.0.0.0/0", "::/0"), ""}, } for _, tt := range tests { got, err := parseIP(tt.host, tt.bits) @@ -212,8 +212,8 @@ func TestParseIP(t *testing.T) { } t.Errorf("parseIP(%q, %v) error: %v; want error %q", tt.host, tt.bits, err, tt.wantErr) } - if got != tt.want { - t.Errorf("parseIP(%q, %v) = %#v; want %#v", tt.host, tt.bits, got, tt.want) + if diff := cmp.Diff(got, tt.want, cmp.Comparer(func(a, b netaddr.IP) bool { return a == b })); diff != "" { + t.Errorf("parseIP(%q, %v) = %s; want %s", tt.host, tt.bits, got, tt.want) continue } } @@ -480,12 +480,15 @@ func mustIP4(s string) packet.IP4 { return packet.IP4FromNetaddr(ip) } -func pfx(s string) netaddr.IPPrefix { - pfx, err := netaddr.ParseIPPrefix(s) - if err != nil { - panic(err) +func pfx(strs ...string) (ret []netaddr.IPPrefix) { + for _, s := range strs { + pfx, err := netaddr.ParseIPPrefix(s) + if err != nil { + panic(err) + } + ret = append(ret, pfx) } - return pfx + return ret } func nets(nets ...string) (ret []netaddr.IPPrefix) { diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go index c498e0936..db3b28469 100644 --- a/wgengine/filter/tailcfg.go +++ b/wgengine/filter/tailcfg.go @@ -26,12 +26,12 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { if len(r.SrcBits) > i { bits = r.SrcBits[i] } - net, err := parseIP(s, bits) + nets, err := parseIP(s, bits) if err != nil && erracc == nil { erracc = err continue } - m.Srcs = append(m.Srcs, net) + m.Srcs = append(m.Srcs, nets...) } for _, d := range r.DstPorts { @@ -39,18 +39,20 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { if d.Bits != nil { bits = *d.Bits } - net, err := parseIP(d.IP, bits) + nets, err := parseIP(d.IP, bits) if err != nil && erracc == nil { erracc = err continue } - m.Dsts = append(m.Dsts, NetPortRange{ - Net: net, - Ports: PortRange{ - First: d.Ports.First, - Last: d.Ports.Last, - }, - }) + for _, net := range nets { + m.Dsts = append(m.Dsts, NetPortRange{ + Net: net, + Ports: PortRange{ + First: d.Ports.First, + Last: d.Ports.Last, + }, + }) + } } mm = append(mm, m) @@ -63,31 +65,35 @@ var ( zeroIP6 = netaddr.IPFrom16([16]byte{}) ) -func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) { +func parseIP(host string, defaultBits int) ([]netaddr.IPPrefix, error) { if host == "*" { // User explicitly requested wildcard dst ip. - // TODO: ipv6 - return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil + return []netaddr.IPPrefix{ + {IP: zeroIP4, Bits: 0}, + {IP: zeroIP6, Bits: 0}, + }, nil } ip, err := netaddr.ParseIP(host) if err != nil { - return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host) + return nil, fmt.Errorf("ports=%#v: invalid IP address", host) } if ip == zeroIP4 { // For clarity, reject 0.0.0.0 as an input - return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) + return nil, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) } if ip == zeroIP6 { // For clarity, reject :: as an input - return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not [::]:port", host) + return nil, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not [::]:port", host) } if defaultBits < 0 || (ip.Is4() && defaultBits > 32) || (ip.Is6() && defaultBits > 128) { - return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) + return nil, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) } - return netaddr.IPPrefix{ - IP: ip, - Bits: uint8(defaultBits), + return []netaddr.IPPrefix{ + { + IP: ip, + Bits: uint8(defaultBits), + }, }, nil }