diff --git a/control/controlknobs/controlknobs.go b/control/controlknobs/controlknobs.go
index 4d57b30a3..3ea0575a5 100644
--- a/control/controlknobs/controlknobs.go
+++ b/control/controlknobs/controlknobs.go
@@ -45,6 +45,9 @@ type Knobs struct {
 	// incremental (delta) netmap updates and should treat all netmap
 	// changes as "full" ones as tailscaled did in 1.48.x and earlier.
 	DisableDeltaUpdates atomic.Bool
+
+	// PeerMTUEnable is whether the node should do peer path MTU discovery.
+	PeerMTUEnable atomic.Bool
 }
 
 // UpdateFromNodeAttributes updates k (if non-nil) based on the provided self
@@ -65,6 +68,7 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability,
 		disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates)
 		oneCGNAT            opt.Bool
 		forceBackgroundSTUN = has(tailcfg.NodeAttrDebugForceBackgroundSTUN)
+		peerMTUEnable       = has(tailcfg.NodeAttrPeerMTUEnable)
 	)
 
 	if has(tailcfg.NodeAttrOneCGNATEnable) {
@@ -80,6 +84,7 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability,
 	k.OneCGNAT.Store(oneCGNAT)
 	k.ForceBackgroundSTUN.Store(forceBackgroundSTUN)
 	k.DisableDeltaUpdates.Store(disableDeltaUpdates)
+	k.PeerMTUEnable.Store(peerMTUEnable)
 }
 
 // AsDebugJSON returns k as something that can be marshalled with json.Marshal
@@ -96,5 +101,6 @@ func (k *Knobs) AsDebugJSON() map[string]any {
 		"OneCGNAT":            k.OneCGNAT.Load(),
 		"ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(),
 		"DisableDeltaUpdates": k.DisableDeltaUpdates.Load(),
+		"PeerMTUEnable":       k.PeerMTUEnable.Load(),
 	}
 }
diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go
index a1d8befa3..180f588a3 100644
--- a/tailcfg/tailcfg.go
+++ b/tailcfg/tailcfg.go
@@ -2133,6 +2133,10 @@ const (
 	// rather than one big /10 CGNAT route. At most one of this or
 	// NodeAttrOneCGNATEnable may be set; if neither are, it's automatic.
 	NodeAttrOneCGNATDisable NodeCapability = "one-cgnat?v=false"
+
+	// NodeAttrPeerMTUEnable makes the client do path MTU discovery to its
+	// peers. If it isn't set, it defaults to the client default.
+	NodeAttrPeerMTUEnable NodeCapability = "peer-mtu-enable"
 )
 
 // SetDNSRequest is a request to add a DNS record.
diff --git a/wgengine/magicsock/peermtu.go b/wgengine/magicsock/peermtu.go
index 12ea8ea4e..177f5d4fc 100644
--- a/wgengine/magicsock/peermtu.go
+++ b/wgengine/magicsock/peermtu.go
@@ -34,6 +34,14 @@ func (c *Conn) ShouldPMTUD() bool {
 		}
 		return v
 	}
+	if c.controlKnobs != nil {
+		if v := c.controlKnobs.PeerMTUEnable.Load(); v {
+			if debugPMTUD() {
+				c.logf("magicsock: peermtu: peer path MTU discovery enabled by control")
+			}
+			return v
+		}
+	}
 	if debugPMTUD() {
 		c.logf("magicsock: peermtu: peer path MTU discovery set by default to false")
 	}
diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go
index c59bc8253..4a9b11587 100644
--- a/wgengine/userspace_test.go
+++ b/wgengine/userspace_test.go
@@ -6,12 +6,15 @@ package wgengine
 import (
 	"fmt"
 	"net/netip"
+	"os"
 	"reflect"
+	"runtime"
 	"testing"
 
 	"go4.org/mem"
 	"tailscale.com/cmd/testwrapper/flakytest"
 	"tailscale.com/control/controlknobs"
+	"tailscale.com/envknob"
 	"tailscale.com/net/dns"
 	"tailscale.com/net/netaddr"
 	"tailscale.com/net/tstun"
@@ -20,6 +23,7 @@ import (
 	"tailscale.com/tstime/mono"
 	"tailscale.com/types/key"
 	"tailscale.com/types/netmap"
+	"tailscale.com/types/opt"
 	"tailscale.com/wgengine/router"
 	"tailscale.com/wgengine/wgcfg"
 )
@@ -227,6 +231,86 @@ func TestUserspaceEnginePortReconfig(t *testing.T) {
 	}
 }
 
+// Test that enabling and disabling peer path MTU discovery works correctly.
+func TestUserspaceEnginePeerMTUReconfig(t *testing.T) {
+	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
+		t.Skipf("skipping on %q; peer MTU not supported", runtime.GOOS)
+	}
+
+	defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD"))
+	envknob.Setenv("TS_DEBUG_ENABLE_PMTUD", "")
+	// Turn on debugging to help diagnose problems.
+	defer os.Setenv("TS_DEBUG_PMTUD", os.Getenv("TS_DEBUG_PMTUD"))
+	envknob.Setenv("TS_DEBUG_PMTUD", "true")
+
+	var knobs controlknobs.Knobs
+
+	e, err := NewFakeUserspaceEngine(t.Logf, 0, &knobs)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(e.Close)
+	ue := e.(*userspaceEngine)
+
+	if ue.magicConn.PeerMTUEnabled() != false {
+		t.Error("peer MTU enabled by default, should not be")
+	}
+	osDefaultDF, err := ue.magicConn.DontFragSetting()
+	if err != nil {
+		t.Errorf("get don't fragment bit failed: %v", err)
+	}
+	t.Logf("Info: OS default don't fragment bit(s) setting: %v", osDefaultDF)
+
+	// Build a set of configs to use as we change the peer MTU settings.
+	nodeKey, err := key.ParseNodePublicUntyped(mem.S("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	cfg := &wgcfg.Config{
+		Peers: []wgcfg.Peer{
+			{
+				PublicKey: nodeKey,
+				AllowedIPs: []netip.Prefix{
+					netip.PrefixFrom(netaddr.IPv4(100, 100, 99, 1), 32),
+				},
+			},
+		},
+	}
+	routerCfg := &router.Config{}
+
+	tests := []struct {
+		desc    string   // test description
+		wantP   bool     // desired value of PMTUD setting
+		wantDF  bool     // desired value of don't fragment bits
+		shouldP opt.Bool // if set, force peer MTU to this value
+	}{
+		{desc: "after_first_reconfig", wantP: false, wantDF: osDefaultDF, shouldP: ""},
+		{desc: "enabling_PMTUD_first_time", wantP: true, wantDF: true, shouldP: "true"},
+		{desc: "disabling_PMTUD", wantP: false, wantDF: false, shouldP: "false"},
+		{desc: "enabling_PMTUD_second_time", wantP: true, wantDF: true, shouldP: "true"},
+		{desc: "returning_to_default_PMTUD", wantP: false, wantDF: false, shouldP: ""},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			if v, ok := tt.shouldP.Get(); ok {
+				knobs.PeerMTUEnable.Store(v)
+			} else {
+				knobs.PeerMTUEnable.Store(false)
+			}
+			if err := ue.Reconfig(cfg, routerCfg, &dns.Config{}); err != nil {
+				t.Fatal(err)
+			}
+			if v := ue.magicConn.PeerMTUEnabled(); v != tt.wantP {
+				t.Errorf("peer MTU set to %v, want %v", v, tt.wantP)
+			}
+			if v, err := ue.magicConn.DontFragSetting(); v != tt.wantDF || err != nil {
+				t.Errorf("don't fragment bit set to %v, want %v, err %v", v, tt.wantDF, err)
+			}
+		})
+	}
+}
+
 func nkFromHex(hex string) key.NodePublic {
 	if len(hex) != 64 {
 		panic(fmt.Sprintf("%q is len %d; want 64", hex, len(hex)))