diff --git a/control/controlclient/filter.go b/control/controlclient/filter.go index 920f7619d..098079cdf 100644 --- a/control/controlclient/filter.go +++ b/control/controlclient/filter.go @@ -7,6 +7,7 @@ package controlclient import ( "fmt" "net" + "tailscale.com/tailcfg" "tailscale.com/wgengine/filter" ) @@ -26,6 +27,9 @@ func parseIP(host string, defaultBits int) (filter.Net, error) { if ip == nil || len(ip) != 4 { return filter.NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host) } + if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) { + return filter.NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host) + } return filter.Net{ IP: filter.NewIP(ip), Mask: filter.Netmask(defaultBits), diff --git a/control/controlclient/filter_test.go b/control/controlclient/filter_test.go new file mode 100644 index 000000000..9a69be81f --- /dev/null +++ b/control/controlclient/filter_test.go @@ -0,0 +1,42 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlclient + +import ( + "net" + "testing" + + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/packet" +) + +func TestParseIP(t *testing.T) { + tests := []struct { + host string + bits int + want filter.Net + wantErr string + }{ + {"8.8.8.8", 24, filter.Net{IP: packet.NewIP(net.ParseIP("8.8.8.8")), Mask: packet.NewIP(net.ParseIP("255.255.255.0"))}, ""}, + {"8.8.8.8", 33, filter.Net{}, `invalid CIDR size 33 for host "8.8.8.8"`}, + {"8.8.8.8", -1, filter.Net{}, `invalid CIDR size -1 for host "8.8.8.8"`}, + {"0.0.0.0", 24, filter.Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`}, + {"*", 24, filter.NetAny, ""}, + {"fe80::1", 128, filter.NetNone, `ports="fe80::1": invalid IPv4 address`}, + } + for _, tt := range tests { + got, err := parseIP(tt.host, tt.bits) + if err != nil { + if err.Error() == tt.wantErr { + continue + } + t.Errorf("parseIP(%q, %v) error: %v; want error %q", tt.host, tt.bits, err, tt.wantErr) + } + if got != tt.want { + t.Errorf("parseIP(%q, %v) = %#v; want %#v", tt.host, tt.bits, got, tt.want) + continue + } + } +}