diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index fed1fceec..7e64e6158 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -1532,25 +1532,10 @@ func cidrDiff(kind string, old map[netip.Prefix]bool, new []netip.Prefix, add, d ret[cidr] = true } - var delFail []error - for cidr := range old { - if newMap[cidr] { - continue - } - if err := del(cidr); err != nil { - logf("%s del failed: %v", kind, err) - delFail = append(delFail, err) - } else { - delete(ret, cidr) - } - } - if len(delFail) == 1 { - return ret, delFail[0] - } - if len(delFail) > 0 { - return ret, fmt.Errorf("%d delete %s failures; first was: %w", len(delFail), kind, delFail[0]) - } - + // We want to add before we delete, so that if there is no overlap, we don't + // end up in a state where we have no addresses on an interface as that + // results in other kernel entities (like routes) pointing to that interface + // to also be deleted. var addFail []error for cidr := range newMap { if old[cidr] { @@ -1571,6 +1556,25 @@ func cidrDiff(kind string, old map[netip.Prefix]bool, new []netip.Prefix, add, d return ret, fmt.Errorf("%d add %s failures; first was: %w", len(addFail), kind, addFail[0]) } + var delFail []error + for cidr := range old { + if newMap[cidr] { + continue + } + if err := del(cidr); err != nil { + logf("%s del failed: %v", kind, err) + delFail = append(delFail, err) + } else { + delete(ret, cidr) + } + } + if len(delFail) == 1 { + return ret, delFail[0] + } + if len(delFail) > 0 { + return ret, fmt.Errorf("%d delete %s failures; first was: %w", len(delFail), kind, delFail[0]) + } + return ret, nil } diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index bcf477b07..9534cdfed 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -10,6 +10,7 @@ import ( "math/rand" "net/netip" "os" + "reflect" "sort" "strings" "sync/atomic" @@ -17,6 +18,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/vishvananda/netlink" + "golang.org/x/exp/slices" "golang.zx2c4.com/wireguard/tun" "tailscale.com/tstest" "tailscale.com/types/logger" @@ -839,3 +841,84 @@ Usage: busybox [function [arguments]...] t.Errorf("version = %q, want %q", got, want) } } + +func TestCIDRDiff(t *testing.T) { + pfx := func(p ...string) []netip.Prefix { + var ret []netip.Prefix + for _, s := range p { + ret = append(ret, netip.MustParsePrefix(s)) + } + return ret + } + tests := []struct { + old []netip.Prefix + new []netip.Prefix + wantAdd []netip.Prefix + wantDel []netip.Prefix + final []netip.Prefix + }{ + { + old: nil, + new: pfx("1.1.1.1/32"), + wantAdd: pfx("1.1.1.1/32"), + final: pfx("1.1.1.1/32"), + }, + { + old: pfx("1.1.1.1/32"), + new: pfx("1.1.1.1/32"), + final: pfx("1.1.1.1/32"), + }, + { + old: pfx("1.1.1.1/32", "2.3.4.5/32"), + new: pfx("1.1.1.1/32"), + wantDel: pfx("2.3.4.5/32"), + final: pfx("1.1.1.1/32"), + }, + { + old: pfx("1.1.1.1/32", "2.3.4.5/32"), + new: pfx("1.0.0.0/32", "3.4.5.6/32"), + wantDel: pfx("1.1.1.1/32", "2.3.4.5/32"), + wantAdd: pfx("1.0.0.0/32", "3.4.5.6/32"), + final: pfx("1.0.0.0/32", "3.4.5.6/32"), + }, + } + for _, tc := range tests { + om := make(map[netip.Prefix]bool) + for _, p := range tc.old { + om[p] = true + } + var added []netip.Prefix + var deleted []netip.Prefix + fm, err := cidrDiff("test", om, tc.new, func(p netip.Prefix) error { + if len(deleted) > 0 { + t.Error("delete called before add") + } + added = append(added, p) + return nil + }, func(p netip.Prefix) error { + deleted = append(deleted, p) + return nil + }, t.Logf) + if err != nil { + t.Fatal(err) + } + slices.SortFunc(added, func(a, b netip.Prefix) bool { return a.Addr().Less(b.Addr()) }) + slices.SortFunc(deleted, func(a, b netip.Prefix) bool { return a.Addr().Less(b.Addr()) }) + if !reflect.DeepEqual(added, tc.wantAdd) { + t.Errorf("added = %v, want %v", added, tc.wantAdd) + } + if !reflect.DeepEqual(deleted, tc.wantDel) { + t.Errorf("deleted = %v, want %v", deleted, tc.wantDel) + } + + // Check that the final state is correct. + if len(fm) != len(tc.final) { + t.Fatalf("final state = %v, want %v", fm, tc.final) + } + for _, p := range tc.final { + if !fm[p] { + t.Errorf("final state = %v, want %v", fm, tc.final) + } + } + } +}