diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go index b75a2662d..880695387 100644 --- a/net/tsaddr/tsaddr.go +++ b/net/tsaddr/tsaddr.go @@ -191,6 +191,17 @@ func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { return v4 && v6 } +// ContainsExitRoute reports whether rr contains at least one of IPv4 or +// IPv6 /0 (exit) routes. +func ContainsExitRoute(rr views.Slice[netip.Prefix]) bool { + for i := range rr.Len() { + if rr.At(i).Bits() == 0 { + return true + } + } + return false +} + // ContainsNonExitSubnetRoutes reports whether v contains Subnet // Routes other than ExitNode Routes. func ContainsNonExitSubnetRoutes(rr views.Slice[netip.Prefix]) bool { @@ -202,6 +213,38 @@ func ContainsNonExitSubnetRoutes(rr views.Slice[netip.Prefix]) bool { return false } +// WithoutExitRoutes returns rr unchanged if it has only 1 or 0 /0 +// routes. If it has both IPv4 and IPv6 /0 routes, then it returns +// a copy with all /0 routes removed. +func WithoutExitRoutes(rr views.Slice[netip.Prefix]) views.Slice[netip.Prefix] { + if !ContainsExitRoutes(rr) { + return rr + } + var out []netip.Prefix + for _, r := range rr.All() { + if r.Bits() > 0 { + out = append(out, r) + } + } + return views.SliceOf(out) +} + +// WithoutExitRoute returns rr unchanged if it has 0 /0 +// routes. If it has a IPv4 or IPv6 /0 routes, then it returns +// a copy with all /0 routes removed. +func WithoutExitRoute(rr views.Slice[netip.Prefix]) views.Slice[netip.Prefix] { + if !ContainsExitRoute(rr) { + return rr + } + var out []netip.Prefix + for _, r := range rr.All() { + if r.Bits() > 0 { + out = append(out, r) + } + } + return views.SliceOf(out) +} + var ( allIPv4 = netip.MustParsePrefix("0.0.0.0/0") allIPv6 = netip.MustParsePrefix("::/0") @@ -216,6 +259,11 @@ func AllIPv6() netip.Prefix { return allIPv6 } // ExitRoutes returns a slice containing AllIPv4 and AllIPv6. func ExitRoutes() []netip.Prefix { return []netip.Prefix{allIPv4, allIPv6} } +// IsExitRoute reports whether p is an exit node route. +func IsExitRoute(p netip.Prefix) bool { + return p == allIPv4 || p == allIPv6 +} + // SortPrefixes sorts the prefixes in place. func SortPrefixes(p []netip.Prefix) { slices.SortFunc(p, netipx.ComparePrefix) diff --git a/net/tsaddr/tsaddr_test.go b/net/tsaddr/tsaddr_test.go index dccc34271..4aa2f8c60 100644 --- a/net/tsaddr/tsaddr_test.go +++ b/net/tsaddr/tsaddr_test.go @@ -7,7 +7,10 @@ import ( "net/netip" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/net/netaddr" + "tailscale.com/types/views" ) func TestInCrostiniRange(t *testing.T) { @@ -89,3 +92,133 @@ func TestUnmapVia(t *testing.T) { } } } + +func TestIsExitNodeRoute(t *testing.T) { + tests := []struct { + pref netip.Prefix + want bool + }{ + { + pref: AllIPv4(), + want: true, + }, + { + pref: AllIPv6(), + want: true, + }, + { + pref: netip.MustParsePrefix("1.1.1.1/0"), + want: false, + }, + { + pref: netip.MustParsePrefix("1.1.1.1/1"), + want: false, + }, + { + pref: netip.MustParsePrefix("192.168.0.0/24"), + want: false, + }, + } + + for _, tt := range tests { + if got := IsExitRoute(tt.pref); got != tt.want { + t.Errorf("for %q: got %v, want %v", tt.pref, got, tt.want) + } + } +} + +func TestWithoutExitRoutes(t *testing.T) { + tests := []struct { + prefs []netip.Prefix + want []netip.Prefix + }{ + { + prefs: []netip.Prefix{AllIPv4(), AllIPv6()}, + want: []netip.Prefix{}, + }, + { + prefs: []netip.Prefix{AllIPv4()}, + want: []netip.Prefix{AllIPv4()}, + }, + { + prefs: []netip.Prefix{AllIPv4(), AllIPv6(), netip.MustParsePrefix("10.0.0.0/10")}, + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/10")}, + }, + { + prefs: []netip.Prefix{AllIPv6(), netip.MustParsePrefix("10.0.0.0/10")}, + want: []netip.Prefix{AllIPv6(), netip.MustParsePrefix("10.0.0.0/10")}, + }, + } + + for _, tt := range tests { + got := WithoutExitRoutes(views.SliceOf(tt.prefs)) + if diff := cmp.Diff(tt.want, got.AsSlice(), cmpopts.EquateEmpty(), cmp.Comparer(func(a, b netip.Prefix) bool { return a == b })); diff != "" { + t.Errorf("unexpected route difference (-want +got):\n%s", diff) + } + } +} + +func TestWithoutExitRoute(t *testing.T) { + tests := []struct { + prefs []netip.Prefix + want []netip.Prefix + }{ + { + prefs: []netip.Prefix{AllIPv4(), AllIPv6()}, + want: []netip.Prefix{}, + }, + { + prefs: []netip.Prefix{AllIPv4()}, + want: []netip.Prefix{}, + }, + { + prefs: []netip.Prefix{AllIPv4(), AllIPv6(), netip.MustParsePrefix("10.0.0.0/10")}, + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/10")}, + }, + { + prefs: []netip.Prefix{AllIPv6(), netip.MustParsePrefix("10.0.0.0/10")}, + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/10")}, + }, + } + + for _, tt := range tests { + got := WithoutExitRoute(views.SliceOf(tt.prefs)) + if diff := cmp.Diff(tt.want, got.AsSlice(), cmpopts.EquateEmpty(), cmp.Comparer(func(a, b netip.Prefix) bool { return a == b })); diff != "" { + t.Errorf("unexpected route difference (-want +got):\n%s", diff) + } + } +} + +func TestContainsExitRoute(t *testing.T) { + tests := []struct { + prefs []netip.Prefix + want bool + }{ + { + prefs: []netip.Prefix{AllIPv4(), AllIPv6()}, + want: true, + }, + { + prefs: []netip.Prefix{AllIPv4()}, + want: true, + }, + { + prefs: []netip.Prefix{AllIPv4(), AllIPv6(), netip.MustParsePrefix("10.0.0.0/10")}, + want: true, + }, + { + prefs: []netip.Prefix{AllIPv6(), netip.MustParsePrefix("10.0.0.0/10")}, + want: true, + }, + { + prefs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/10")}, + want: false, + }, + } + + for _, tt := range tests { + if got := ContainsExitRoute(views.SliceOf(tt.prefs)); got != tt.want { + t.Errorf("for %q: got %v, want %v", tt.prefs, got, tt.want) + } + } +}