diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index a042b689e..fd2f80681 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -3079,10 +3079,9 @@ type discoEndpoint struct { lastFullPing time.Time // last time we pinged all endpoints derpAddr netaddr.IPPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) - bestAddr netaddr.IPPort // best non-DERP path; zero if none - bestAddrLatency time.Duration - bestAddrAt time.Time // time best address re-confirmed - trustBestAddrUntil time.Time // time when bestAddr expires + bestAddr addrLatency // best non-DERP path; zero if none + bestAddrAt time.Time // time best address re-confirmed + trustBestAddrUntil time.Time // time when bestAddr expires sentPing map[stun.TxID]sentPing endpointState map[netaddr.IPPort]*endpointState isCallMeMaybeEP map[netaddr.IPPort]bool @@ -3187,8 +3186,8 @@ func (st *endpointState) shouldDeleteLocked() bool { func (de *discoEndpoint) deleteEndpointLocked(ep netaddr.IPPort) { delete(de.endpointState, ep) - if de.bestAddr == ep { - de.bestAddr = netaddr.IPPort{} + if de.bestAddr.IPPort == ep { + de.bestAddr = addrLatency{} } } @@ -3256,7 +3255,7 @@ func (de *discoEndpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) // // de.mu must be held. func (de *discoEndpoint) addrForSendLocked(now time.Time) (udpAddr, derpAddr netaddr.IPPort) { - udpAddr = de.bestAddr + udpAddr = de.bestAddr.IPPort if udpAddr.IsZero() || now.After(de.trustBestAddrUntil) { // We had a bestAddr but it expired so send both to it // and DERP. @@ -3309,7 +3308,7 @@ func (de *discoEndpoint) wantFullPingLocked(now time.Time) bool { if now.After(de.trustBestAddrUntil) { return true } - if de.bestAddrLatency <= goodEnoughLatency { + if de.bestAddr.latency <= goodEnoughLatency { return false } if now.Sub(de.lastFullPing) >= upgradeInterval { @@ -3641,20 +3640,39 @@ func (de *discoEndpoint) handlePongConnLocked(m *disco.Pong, src netaddr.IPPort) // Promote this pong response to our current best address if it's lower latency. // TODO(bradfitz): decide how latency vs. preference order affects decision if !isDerp { - if de.bestAddr.IsZero() || latency < de.bestAddrLatency { - if de.bestAddr != sp.to { - de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort, sp.to) - de.bestAddr = sp.to - } + thisPong := addrLatency{sp.to, latency} + if betterAddr(thisPong, de.bestAddr) { + de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort, sp.to) + de.bestAddr = thisPong } - if de.bestAddr == sp.to { - de.bestAddrLatency = latency + if de.bestAddr.IPPort == thisPong.IPPort { + de.bestAddr.latency = latency de.bestAddrAt = now de.trustBestAddrUntil = now.Add(trustUDPAddrDuration) } } } +// addrLatency is an IPPort with an associated latency. +type addrLatency struct { + netaddr.IPPort + latency time.Duration +} + +// betterAddr reports whether a is a better addr to use than b. +func betterAddr(a, b addrLatency) bool { + if a.IPPort == b.IPPort { + return false + } + if b.IsZero() { + return true + } + if a.IsZero() { + return false + } + return a.latency < b.latency +} + // discoEndpoint.mu must be held. func (st *endpointState) addPongReplyLocked(r pongReply) { if n := len(st.recentPongs); n < pongHistoryCount { @@ -3761,8 +3779,7 @@ func (de *discoEndpoint) stopAndReset() { // state isn't a mix of before & after two sessions. de.lastSend = time.Time{} de.lastFullPing = time.Time{} - de.bestAddr = netaddr.IPPort{} - de.bestAddrLatency = 0 + de.bestAddr = addrLatency{} de.bestAddrAt = time.Time{} de.trustBestAddrUntil = time.Time{} for _, es := range de.endpointState { diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 1c9dfd1d7..a48676e30 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1851,3 +1851,33 @@ func TestStringSetsEqual(t *testing.T) { } } + +func TestBetterAddr(t *testing.T) { + const ms = time.Millisecond + al := func(ipps string, d time.Duration) addrLatency { + return addrLatency{netaddr.MustParseIPPort(ipps), d} + } + zero := addrLatency{} + tests := []struct { + a, b addrLatency + want bool + }{ + {a: zero, b: zero, want: false}, + {a: al("10.0.0.2:123", 5*ms), b: zero, want: true}, + {a: zero, b: al("10.0.0.2:123", 5*ms), want: false}, + {a: al("10.0.0.2:123", 5*ms), b: al("1.2.3.4:555", 6*ms), want: true}, + {a: al("10.0.0.2:123", 5*ms), b: al("10.0.0.2:123", 10*ms), want: false}, // same IPPort + } + for _, tt := range tests { + got := betterAddr(tt.a, tt.b) + if got != tt.want { + t.Errorf("betterAddr(%+v, %+v) = %v; want %v", tt.a, tt.b, got, tt.want) + continue + } + gotBack := betterAddr(tt.b, tt.a) + if got && gotBack { + t.Errorf("betterAddr(%+v, %+v) and betterAddr(%+v, %+v) both unexpectedly true", tt.a, tt.b, tt.b, tt.a) + } + } + +}