wgengine/filter: treat * as both a v4 and v6 wildcard.

Part of #19.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/918/head
David Anderson 4 years ago
parent 2d604b3791
commit 5062131aad

@ -11,6 +11,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -188,21 +189,20 @@ func TestNoAllocs(t *testing.T) {
} }
func TestParseIP(t *testing.T) { func TestParseIP(t *testing.T) {
var noaddr netaddr.IPPrefix
tests := []struct { tests := []struct {
host string host string
bits int bits int
want netaddr.IPPrefix want []netaddr.IPPrefix
wantErr string wantErr string
}{ }{
{"8.8.8.8", 24, pfx("8.8.8.8/24"), ""}, {"8.8.8.8", 24, pfx("8.8.8.8/24"), ""},
{"2601:1234::", 64, pfx("2601:1234::/64"), ""}, {"2601:1234::", 64, pfx("2601:1234::/64"), ""},
{"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`}, {"8.8.8.8", 33, nil, `invalid CIDR size 33 for host "8.8.8.8"`},
{"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`}, {"8.8.8.8", -1, nil, `invalid CIDR size -1 for host "8.8.8.8"`},
{"2601:1234::", 129, noaddr, `invalid CIDR size 129 for host "2601:1234::"`}, {"2601:1234::", 129, nil, `invalid CIDR size 129 for host "2601:1234::"`},
{"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, {"0.0.0.0", 24, nil, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
{"::", 64, noaddr, `ports="::": to allow all IP addresses, use *:port, not [::]:port`}, {"::", 64, nil, `ports="::": to allow all IP addresses, use *:port, not [::]:port`},
{"*", 24, pfx("0.0.0.0/0"), ""}, {"*", 24, pfx("0.0.0.0/0", "::/0"), ""},
} }
for _, tt := range tests { for _, tt := range tests {
got, err := parseIP(tt.host, tt.bits) got, err := parseIP(tt.host, tt.bits)
@ -212,8 +212,8 @@ func TestParseIP(t *testing.T) {
} }
t.Errorf("parseIP(%q, %v) error: %v; want error %q", tt.host, tt.bits, err, tt.wantErr) t.Errorf("parseIP(%q, %v) error: %v; want error %q", tt.host, tt.bits, err, tt.wantErr)
} }
if got != tt.want { if diff := cmp.Diff(got, tt.want, cmp.Comparer(func(a, b netaddr.IP) bool { return a == b })); diff != "" {
t.Errorf("parseIP(%q, %v) = %#v; want %#v", tt.host, tt.bits, got, tt.want) t.Errorf("parseIP(%q, %v) = %s; want %s", tt.host, tt.bits, got, tt.want)
continue continue
} }
} }
@ -480,12 +480,15 @@ func mustIP4(s string) packet.IP4 {
return packet.IP4FromNetaddr(ip) return packet.IP4FromNetaddr(ip)
} }
func pfx(s string) netaddr.IPPrefix { func pfx(strs ...string) (ret []netaddr.IPPrefix) {
pfx, err := netaddr.ParseIPPrefix(s) for _, s := range strs {
if err != nil { pfx, err := netaddr.ParseIPPrefix(s)
panic(err) if err != nil {
panic(err)
}
ret = append(ret, pfx)
} }
return pfx return ret
} }
func nets(nets ...string) (ret []netaddr.IPPrefix) { func nets(nets ...string) (ret []netaddr.IPPrefix) {

@ -26,12 +26,12 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
if len(r.SrcBits) > i { if len(r.SrcBits) > i {
bits = r.SrcBits[i] bits = r.SrcBits[i]
} }
net, err := parseIP(s, bits) nets, err := parseIP(s, bits)
if err != nil && erracc == nil { if err != nil && erracc == nil {
erracc = err erracc = err
continue continue
} }
m.Srcs = append(m.Srcs, net) m.Srcs = append(m.Srcs, nets...)
} }
for _, d := range r.DstPorts { for _, d := range r.DstPorts {
@ -39,18 +39,20 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
if d.Bits != nil { if d.Bits != nil {
bits = *d.Bits bits = *d.Bits
} }
net, err := parseIP(d.IP, bits) nets, err := parseIP(d.IP, bits)
if err != nil && erracc == nil { if err != nil && erracc == nil {
erracc = err erracc = err
continue continue
} }
m.Dsts = append(m.Dsts, NetPortRange{ for _, net := range nets {
Net: net, m.Dsts = append(m.Dsts, NetPortRange{
Ports: PortRange{ Net: net,
First: d.Ports.First, Ports: PortRange{
Last: d.Ports.Last, First: d.Ports.First,
}, Last: d.Ports.Last,
}) },
})
}
} }
mm = append(mm, m) mm = append(mm, m)
@ -63,31 +65,35 @@ var (
zeroIP6 = netaddr.IPFrom16([16]byte{}) zeroIP6 = netaddr.IPFrom16([16]byte{})
) )
func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) { func parseIP(host string, defaultBits int) ([]netaddr.IPPrefix, error) {
if host == "*" { if host == "*" {
// User explicitly requested wildcard dst ip. // User explicitly requested wildcard dst ip.
// TODO: ipv6 return []netaddr.IPPrefix{
return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil {IP: zeroIP4, Bits: 0},
{IP: zeroIP6, Bits: 0},
}, nil
} }
ip, err := netaddr.ParseIP(host) ip, err := netaddr.ParseIP(host)
if err != nil { if err != nil {
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host) return nil, fmt.Errorf("ports=%#v: invalid IP address", host)
} }
if ip == zeroIP4 { if ip == zeroIP4 {
// For clarity, reject 0.0.0.0 as an input // For clarity, reject 0.0.0.0 as an input
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) return nil, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} }
if ip == zeroIP6 { if ip == zeroIP6 {
// For clarity, reject :: as an input // For clarity, reject :: as an input
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not [::]:port", host) return nil, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not [::]:port", host)
} }
if defaultBits < 0 || (ip.Is4() && defaultBits > 32) || (ip.Is6() && defaultBits > 128) { if defaultBits < 0 || (ip.Is4() && defaultBits > 32) || (ip.Is6() && defaultBits > 128) {
return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) return nil, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
} }
return netaddr.IPPrefix{ return []netaddr.IPPrefix{
IP: ip, {
Bits: uint8(defaultBits), IP: ip,
Bits: uint8(defaultBits),
},
}, nil }, nil
} }

Loading…
Cancel
Save