|
|
|
@ -40,14 +40,31 @@ const (
|
|
|
|
|
testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
|
|
|
|
|
if protos == nil {
|
|
|
|
|
// m returnns a Match with the given srcs and dsts.
|
|
|
|
|
//
|
|
|
|
|
// opts can be ipproto.Proto values (if none, defaultProtos is used)
|
|
|
|
|
// or tailcfg.NodeCapability values. Other values panic.
|
|
|
|
|
func m(srcs []netip.Prefix, dsts []NetPortRange, opts ...any) Match {
|
|
|
|
|
var protos []ipproto.Proto
|
|
|
|
|
var caps []tailcfg.NodeCapability
|
|
|
|
|
for _, o := range opts {
|
|
|
|
|
switch o := o.(type) {
|
|
|
|
|
case ipproto.Proto:
|
|
|
|
|
protos = append(protos, o)
|
|
|
|
|
case tailcfg.NodeCapability:
|
|
|
|
|
caps = append(caps, o)
|
|
|
|
|
default:
|
|
|
|
|
panic(fmt.Sprintf("unknown option type %T", o))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if len(protos) == 0 {
|
|
|
|
|
protos = defaultProtos
|
|
|
|
|
}
|
|
|
|
|
return Match{
|
|
|
|
|
IPProto: views.SliceOf(protos),
|
|
|
|
|
Srcs: srcs,
|
|
|
|
|
SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)),
|
|
|
|
|
SrcCaps: caps,
|
|
|
|
|
Dsts: dsts,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -65,6 +82,7 @@ func newFilter(logf logger.Logf) *Filter {
|
|
|
|
|
m(nets("::/0"), netports("::/0:443")),
|
|
|
|
|
m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
|
|
|
|
|
m(nets("::/0"), netports("::/0:*"), testAllowedProto),
|
|
|
|
|
m(nil, netports("1.2.3.4:22"), tailcfg.NodeCapability("cap-hit-1234-ssh")),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
|
|
|
@ -79,11 +97,17 @@ func newFilter(logf logger.Logf) *Filter {
|
|
|
|
|
localNetsSet, _ := localNets.IPSet()
|
|
|
|
|
logBSet, _ := logB.IPSet()
|
|
|
|
|
|
|
|
|
|
return New(matches, localNetsSet, logBSet, nil, logf)
|
|
|
|
|
return New(matches, nil, localNetsSet, logBSet, nil, logf)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestFilter(t *testing.T) {
|
|
|
|
|
acl := newFilter(t.Logf)
|
|
|
|
|
filt := newFilter(t.Logf)
|
|
|
|
|
|
|
|
|
|
ipWithCap := netip.MustParseAddr("10.0.0.1")
|
|
|
|
|
ipWithoutCap := netip.MustParseAddr("10.0.0.2")
|
|
|
|
|
filt.srcIPHasCap = func(ip netip.Addr, cap tailcfg.NodeCapability) bool {
|
|
|
|
|
return cap == "cap-hit-1234-ssh" && ip == ipWithCap
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type InOut struct {
|
|
|
|
|
want Response
|
|
|
|
@ -139,21 +163,27 @@ func TestFilter(t *testing.T) {
|
|
|
|
|
{Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
|
|
|
|
|
{Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
|
|
|
|
|
{Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
|
|
|
|
|
|
|
|
|
|
// Test use of a node capability to grant access.
|
|
|
|
|
// 10.0.0.1 has the capability; 10.0.0.2 does not (see srcIPHasCap at top of func)
|
|
|
|
|
{Accept, parsed(ipproto.TCP, ipWithCap.String(), "1.2.3.4", 30000, 22)},
|
|
|
|
|
{Drop, parsed(ipproto.TCP, ipWithoutCap.String(), "1.2.3.4", 30000, 22)},
|
|
|
|
|
}
|
|
|
|
|
for i, test := range tests {
|
|
|
|
|
aclFunc := acl.runIn4
|
|
|
|
|
aclFunc := filt.runIn4
|
|
|
|
|
if test.p.IPVersion == 6 {
|
|
|
|
|
aclFunc = acl.runIn6
|
|
|
|
|
aclFunc = filt.runIn6
|
|
|
|
|
}
|
|
|
|
|
if got, why := aclFunc(&test.p); test.want != got {
|
|
|
|
|
t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if test.p.IPProto == ipproto.TCP {
|
|
|
|
|
var got Response
|
|
|
|
|
if test.p.IPVersion == 4 {
|
|
|
|
|
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
|
|
|
|
got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
|
|
|
|
} else {
|
|
|
|
|
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
|
|
|
|
got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
|
|
|
|
}
|
|
|
|
|
if test.want != got {
|
|
|
|
|
t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
|
|
|
|
@ -165,7 +195,7 @@ func TestFilter(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Update UDP state
|
|
|
|
|
_, _ = acl.runOut(&test.p)
|
|
|
|
|
_, _ = filt.runOut(&test.p)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -264,13 +294,16 @@ func TestParseIPSet(t *testing.T) {
|
|
|
|
|
{"*", pfx("0.0.0.0/0", "::/0"), ""},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
got, err := parseIPSet(tt.host)
|
|
|
|
|
got, gotCap, err := parseIPSet(tt.host)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if err.Error() == tt.wantErr {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
t.Errorf("parseIPSet(%q) error: %v; want error %q", tt.host, err, tt.wantErr)
|
|
|
|
|
}
|
|
|
|
|
if gotCap != "" {
|
|
|
|
|
t.Errorf("parseIPSet(%q) cap: %q; want empty", tt.host, gotCap)
|
|
|
|
|
}
|
|
|
|
|
compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b })
|
|
|
|
|
compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b })
|
|
|
|
|
if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" {
|
|
|
|
@ -278,6 +311,27 @@ func TestParseIPSet(t *testing.T) {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
capTests := []struct {
|
|
|
|
|
in string
|
|
|
|
|
want tailcfg.NodeCapability
|
|
|
|
|
}{
|
|
|
|
|
{"cap:foo", "foo"},
|
|
|
|
|
{"cap:people-in-8.8.8.0/24", "people-in-8.8.8.0/24"}, // test precedence of "/" search
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range capTests {
|
|
|
|
|
pfxes, gotCap, err := parseIPSet(tt.in)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Errorf("parseIPSet(%q) error: %v; want no error", tt.in, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if gotCap != tt.want {
|
|
|
|
|
t.Errorf("parseIPSet(%q) cap: %q; want %q", tt.in, gotCap, tt.want)
|
|
|
|
|
}
|
|
|
|
|
if len(pfxes) != 0 {
|
|
|
|
|
t.Errorf("parseIPSet(%q) pfxes: %v; want empty", tt.in, pfxes)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func BenchmarkFilter(b *testing.B) {
|
|
|
|
@ -904,7 +958,7 @@ func TestPeerCaps(t *testing.T) {
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
filt := New(mm, nil, nil, nil, t.Logf)
|
|
|
|
|
filt := New(mm, nil, nil, nil, nil, t.Logf)
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
src, dst string // IP
|
|
|
|
@ -1037,7 +1091,7 @@ func benchmarkFile(b *testing.B, file string, opt benchOpt) {
|
|
|
|
|
logIPs.AddPrefix(tsaddr.CGNATRange())
|
|
|
|
|
logIPs.AddPrefix(tsaddr.TailscaleULARange())
|
|
|
|
|
|
|
|
|
|
f := New(matches, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard)
|
|
|
|
|
f := New(matches, nil, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard)
|
|
|
|
|
var srcIP, dstIP netip.Addr
|
|
|
|
|
if opt.v4 {
|
|
|
|
|
srcIP = netip.MustParseAddr("1.2.3.4")
|
|
|
|
|