diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go index 6ae21b8fb..d40081839 100644 --- a/wgengine/magicsock/relaymanager_test.go +++ b/wgengine/magicsock/relaymanager_test.go @@ -80,40 +80,177 @@ func TestRelayManagerGetServers(t *testing.T) { } } -// Test for http://go/corp/32978 func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { - rm := relayManager{} - rm.init() - <-rm.runLoopStoppedCh // prevent runLoop() from starting, we will inject/handle events in the test - ep := &endpoint{} + wantHandshakeWorkCount := func(t *testing.T, rm *relayManager, n int) { + t.Helper() + byServerDiscoByEndpoint := 0 + for _, v := range rm.handshakeWorkByServerDiscoByEndpoint { + byServerDiscoByEndpoint += len(v) + } + byServerDiscoVNI := len(rm.handshakeWorkByServerDiscoVNI) + if byServerDiscoByEndpoint != n || + byServerDiscoVNI != n || + byServerDiscoByEndpoint != byServerDiscoVNI { + t.Fatalf("want handshake work count %d byServerDiscoByEndpoint=%d byServerDiscoVNI=%d", + n, + byServerDiscoByEndpoint, + byServerDiscoVNI, + ) + } + } + conn := newConn(t.Logf) - ep.c = conn - serverDisco := key.NewDisco().Public() - rm.handleNewServerEndpointRunLoop(newRelayServerEndpointEvent{ - wlb: endpointWithLastBest{ - ep: ep, + epA := &endpoint{c: conn} + epB := &endpoint{c: conn} + serverDiscoA := key.NewDisco().Public() + serverDiscoB := key.NewDisco().Public() + + serverAendpointALamport1VNI1 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 1, VNI: 1}, + } + serverAendpointALamport1VNI1LastBestMatching := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA, lastBestIsTrusted: true, lastBest: addrQuality{relayServerDisco: serverDiscoA}}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 1, VNI: 1}, + } + serverAendpointALamport2VNI1 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 2, VNI: 1}, + } + serverAendpointALamport2VNI2 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 2, VNI: 2}, + } + serverAendpointBLamport1VNI2 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epB}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 1, VNI: 2}, + } + serverBendpointALamport1VNI1 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoB, LamportID: 1, VNI: 1}, + } + + tests := []struct { + name string + events []newRelayServerEndpointEvent + want []newRelayServerEndpointEvent + }{ + { + // Test for http://go/corp/32978 + name: "eq server+ep neq VNI higher lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverAendpointALamport2VNI2, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+ep neq VNI lower lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + serverAendpointALamport1VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+vni neq ep lower lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + serverAendpointBLamport1VNI2, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+vni neq ep higher lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointBLamport1VNI2, + serverAendpointALamport2VNI2, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+endpoint+vni higher lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverAendpointALamport2VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI1, + }, }, - se: udprelay.ServerEndpoint{ - ServerDisco: serverDisco, - LamportID: 1, - VNI: 1, + { + name: "eq server+endpoint+vni lower lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI1, + serverAendpointALamport1VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI1, + }, }, - }) - rm.handleNewServerEndpointRunLoop(newRelayServerEndpointEvent{ - wlb: endpointWithLastBest{ - ep: ep, + { + name: "eq endpoint+vni+lamport neq server", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverBendpointALamport1VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverBendpointALamport1VNI1, + }, }, - se: udprelay.ServerEndpoint{ - ServerDisco: serverDisco, - LamportID: 2, - VNI: 2, + { + name: "trusted last best with matching server", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1LastBestMatching, + }, + want: []newRelayServerEndpointEvent{}, }, - }) - rm.stopWorkRunLoop(ep) - if len(rm.handshakeWorkByServerDiscoByEndpoint) != 0 || - len(rm.handshakeWorkByServerDiscoVNI) != 0 || - len(rm.handshakeWorkAwaitingPong) != 0 || - len(rm.addrPortVNIToHandshakeWork) != 0 { - t.Fatal("stranded relayHandshakeWork state") + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := &relayManager{} + rm.init() + <-rm.runLoopStoppedCh // prevent runLoop() from starting + + // feed events + for _, event := range tt.events { + rm.handleNewServerEndpointRunLoop(event) + } + + // validate state + wantHandshakeWorkCount(t, rm, len(tt.want)) + for _, want := range tt.want { + byServerDisco, ok := rm.handshakeWorkByServerDiscoByEndpoint[want.wlb.ep] + if !ok { + t.Fatal("work not found by endpoint") + } + workByServerDiscoByEndpoint, ok := byServerDisco[want.se.ServerDisco] + if !ok { + t.Fatal("work not found by server disco by endpoint") + } + workByServerDiscoVNI, ok := rm.handshakeWorkByServerDiscoVNI[serverDiscoVNI{want.se.ServerDisco, want.se.VNI}] + if !ok { + t.Fatal("work not found by server disco + VNI") + } + if workByServerDiscoByEndpoint != workByServerDiscoVNI { + t.Fatal("workByServerDiscoByEndpoint != workByServerDiscoVNI") + } + } + + // cleanup + for _, event := range tt.events { + rm.stopWorkRunLoop(event.wlb.ep) + } + wantHandshakeWorkCount(t, rm, 0) + }) } }