wgengine/magicsock: refactor twoDevicePing to make stack construction cleaner.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/602/head
David Anderson 4 years ago committed by Dave Anderson
parent 0a42b0a726
commit 3669296cef

@ -60,6 +60,136 @@ func (c *Conn) WaitReady(t *testing.T) {
} }
} }
func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, stunIP netaddr.IP) (derpMap *tailcfg.DERPMap, cleanup func()) {
var serverPrivateKey key.Private
if _, err := crand.Read(serverPrivateKey[:]); err != nil {
t.Fatal(err)
}
d := derp.NewServer(serverPrivateKey, logf)
if l != (nettype.Std{}) {
// When using virtual networking, only allow DERP to forward
// discovery traffic, not actual packets.
d.OnlyDisco = true
}
httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d))
httpsrv.Config.ErrorLog = logger.StdLogger(logf)
httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
httpsrv.StartTLS()
stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, l)
m := &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: &tailcfg.DERPRegion{
RegionID: 1,
RegionCode: "test",
Nodes: []*tailcfg.DERPNode{
{
Name: "t1",
RegionID: 1,
HostName: "test-node.unused",
IPv4: "127.0.0.1",
IPv6: "none",
STUNPort: stunAddr.Port,
DERPTestPort: httpsrv.Listener.Addr().(*net.TCPAddr).Port,
STUNTestIP: stunIP.String(),
},
},
},
},
}
cleanup = func() {
httpsrv.CloseClientConnections()
httpsrv.Close()
d.Close()
stunCleanup()
}
return m, cleanup
}
// magicStack is a magicsock, plus all the stuff around it that's
// necessary to send and receive packets to test e2e wireguard
// happiness.
type magicStack struct {
privateKey wgcfg.PrivateKey
epCh chan []string // endpoint updates produced by this peer
conn *Conn // the magicsock itself
tun *tuntest.ChannelTUN // tuntap device to send/receive packets
tsTun *tstun.TUN // wrapped tun that implements filtering and wgengine hooks
dev *device.Device // the wireguard-go Device that connects the previous things
}
// newMagicStack builds and initializes an idle magicsock and
// friends. You need to call conn.SetNetworkMap and dev.Reconfig
// before anything interesting happens.
func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, derpMap *tailcfg.DERPMap) *magicStack {
t.Helper()
privateKey, err := wgcfg.NewPrivateKey()
if err != nil {
t.Fatalf("generating private key: %v", err)
}
epCh := make(chan []string, 100) // arbitrary
conn, err := NewConn(Options{
Logf: logf,
PacketListener: l,
EndpointsFunc: func(eps []string) {
epCh <- eps
},
})
if err != nil {
t.Fatalf("constructing magicsock: %v", err)
}
conn.Start()
conn.SetDERPMap(derpMap)
if err := conn.SetPrivateKey(privateKey); err != nil {
t.Fatalf("setting private key in magicsock: %v", err)
}
tun := tuntest.NewChannelTUN()
tsTun := tstun.WrapTUN(logf, tun.TUN())
tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf))
dev := device.NewDevice(tsTun, &device.DeviceOptions{
Logger: &device.Logger{
Debug: logger.StdLogger(logf),
Info: logger.StdLogger(logf),
Error: logger.StdLogger(logf),
},
CreateEndpoint: conn.CreateEndpoint,
CreateBind: conn.CreateBind,
SkipBindUpdate: true,
})
dev.Up()
// Wait for magicsock to connect up to DERP.
conn.WaitReady(t)
// Wait for first endpoint update to be available
deadline := time.Now().Add(2 * time.Second)
for len(epCh) == 0 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
return &magicStack{
privateKey: privateKey,
epCh: epCh,
conn: conn,
tun: tun,
tsTun: tsTun,
dev: dev,
}
}
func (s *magicStack) Close() {
s.dev.Close()
s.conn.Close()
}
func TestNewConn(t *testing.T) { func TestNewConn(t *testing.T) {
tstest.PanicOnLog() tstest.PanicOnLog()
rc := tstest.NewResourceCheck() rc := tstest.NewResourceCheck()
@ -243,45 +373,6 @@ func parseCIDR(t *testing.T, addr string) wgcfg.CIDR {
return cidr return cidr
} }
func runDERP(t *testing.T, logf logger.Logf, onlyDisco bool) (s *derp.Server, addr *net.TCPAddr, cleanupFn func()) {
var serverPrivateKey key.Private
if _, err := crand.Read(serverPrivateKey[:]); err != nil {
t.Fatal(err)
}
s = derp.NewServer(serverPrivateKey, logf)
s.OnlyDisco = onlyDisco
httpsrv := httptest.NewUnstartedServer(derphttp.Handler(s))
httpsrv.Config.ErrorLog = logger.StdLogger(logf)
httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
httpsrv.StartTLS()
logf("DERP server URL: %s (onlyDisco=%v)", httpsrv.URL, onlyDisco)
cleanupFn = func() {
httpsrv.CloseClientConnections()
httpsrv.Close()
s.Close()
}
return s, httpsrv.Listener.Addr().(*net.TCPAddr), cleanupFn
}
// devLogger returns a wireguard-go device.Logger that writes
// wireguard logs to the test logger.
func devLogger(t *testing.T, prefix string, logfx logger.Logf) *device.Logger {
pfx := []interface{}{prefix}
logf := func(format string, args ...interface{}) {
t.Helper()
logfx("%s: "+format, append(pfx, args...)...)
}
return &device.Logger{
Debug: logger.StdLogger(logf),
Info: logger.StdLogger(logf),
Error: logger.StdLogger(logf),
}
}
// TestDeviceStartStop exercises the startup and shutdown logic of // TestDeviceStartStop exercises the startup and shutdown logic of
// wireguard-go, which is intimately intertwined with magicsock's own // wireguard-go, which is intimately intertwined with magicsock's own
// lifecycle. We seem to be good at generating deadlocks here, so if // lifecycle. We seem to be good at generating deadlocks here, so if
@ -305,7 +396,11 @@ func TestDeviceStartStop(t *testing.T) {
tun := tuntest.NewChannelTUN() tun := tuntest.NewChannelTUN()
dev := device.NewDevice(tun.TUN(), &device.DeviceOptions{ dev := device.NewDevice(tun.TUN(), &device.DeviceOptions{
Logger: devLogger(t, "dev", t.Logf), Logger: &device.Logger{
Debug: logger.StdLogger(t.Logf),
Info: logger.StdLogger(t.Logf),
Error: logger.StdLogger(t.Logf),
},
CreateEndpoint: conn.CreateEndpoint, CreateEndpoint: conn.CreateEndpoint,
CreateBind: conn.CreateBind, CreateBind: conn.CreateBind,
SkipBindUpdate: true, SkipBindUpdate: true,
@ -414,127 +509,37 @@ func testTwoDevicePing(t *testing.T, d *devices) {
rc := tstest.NewResourceCheck() rc := tstest.NewResourceCheck()
defer rc.Assert(t) defer rc.Assert(t)
usingNatLab := d.m1 != (nettype.Std{})
// This gets reassigned inside every test, so that the connections // This gets reassigned inside every test, so that the connections
// all log using the "current" t.Logf function. Sigh. // all log using the "current" t.Logf function. Sigh.
logf, setT := makeNestable(t) logf, setT := makeNestable(t)
derpServer, derpAddr, derpCleanupFn := runDERP(t, logf, usingNatLab) derpMap, cleanup := runDERPAndStun(t, logf, d.stun, d.stunIP)
defer derpCleanupFn() defer cleanup()
stunAddr, stunCleanupFn := stuntest.ServeWithPacketListener(t, d.stun) m1 := newMagicStack(t, logf, d.m1, derpMap)
defer stunCleanupFn() defer m1.Close()
m2 := newMagicStack(t, logf, d.m2, derpMap)
derpMap := &tailcfg.DERPMap{ defer m2.Close()
Regions: map[int]*tailcfg.DERPRegion{
1: &tailcfg.DERPRegion{
RegionID: 1,
RegionCode: "test",
Nodes: []*tailcfg.DERPNode{
{
Name: "t1",
RegionID: 1,
HostName: "test-node.unused",
IPv4: "127.0.0.1",
IPv6: "none",
STUNPort: stunAddr.Port,
DERPTestPort: derpAddr.Port,
STUNTestIP: d.stunIP.String(),
},
},
},
},
}
epCh1 := make(chan []string, 16)
conn1, err := NewConn(Options{
Logf: logger.WithPrefix(logf, "conn1: "),
PacketListener: d.m1,
EndpointsFunc: func(eps []string) {
epCh1 <- eps
},
})
if err != nil {
t.Fatal(err)
}
defer conn1.Close()
conn1.Start()
conn1.SetDERPMap(derpMap)
epCh2 := make(chan []string, 16)
conn2, err := NewConn(Options{
Logf: logger.WithPrefix(logf, "conn2: "),
PacketListener: d.m2,
EndpointsFunc: func(eps []string) {
epCh2 <- eps
},
})
if err != nil {
t.Fatal(err)
}
defer conn2.Close()
conn2.Start()
conn2.SetDERPMap(derpMap)
addrs := []netaddr.IPPort{ addrs := []netaddr.IPPort{
{IP: d.m1IP, Port: conn1.LocalPort()}, {IP: d.m1IP, Port: m1.conn.LocalPort()},
{IP: d.m2IP, Port: conn2.LocalPort()}, {IP: d.m2IP, Port: m2.conn.LocalPort()},
} }
cfgs := makeConfigs(t, addrs) cfgs := makeConfigs(t, addrs)
if err := conn1.SetPrivateKey(cfgs[0].PrivateKey); err != nil { if err := m1.dev.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := conn2.SetPrivateKey(cfgs[1].PrivateKey); err != nil { if err := m2.dev.Reconfig(&cfgs[1]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
//uapi1, _ := cfgs[0].ToUAPI()
//logf("cfg0: %v", uapi1)
//uapi2, _ := cfgs[1].ToUAPI()
//logf("cfg1: %v", uapi2)
tun1 := tuntest.NewChannelTUN()
tstun1 := tstun.WrapTUN(logf, tun1.TUN())
tstun1.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf))
dev1 := device.NewDevice(tstun1, &device.DeviceOptions{
Logger: devLogger(t, "dev1", logf),
CreateEndpoint: conn1.CreateEndpoint,
CreateBind: conn1.CreateBind,
SkipBindUpdate: true,
})
dev1.Up()
if err := dev1.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err)
}
defer dev1.Close()
tun2 := tuntest.NewChannelTUN()
tstun2 := tstun.WrapTUN(logf, tun2.TUN())
tstun2.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf))
dev2 := device.NewDevice(tstun2, &device.DeviceOptions{
Logger: devLogger(t, "dev2", logf),
CreateEndpoint: conn2.CreateEndpoint,
CreateBind: conn2.CreateBind,
SkipBindUpdate: true,
})
dev2.Up()
defer dev2.Close()
if err := dev2.Reconfig(&cfgs[1]); err != nil {
t.Fatal(err)
}
conn1.WaitReady(t)
conn2.WaitReady(t)
ping1 := func(t *testing.T) { ping1 := func(t *testing.T) {
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1 m2.tun.Outbound <- msg2to1
t.Log("ping1 sent") t.Log("ping1 sent")
select { select {
case msgRecv := <-tun1.Inbound: case msgRecv := <-m1.tun.Inbound:
if !bytes.Equal(msg2to1, msgRecv) { if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly") t.Error("ping did not transit correctly")
} }
@ -544,10 +549,10 @@ func testTwoDevicePing(t *testing.T, d *devices) {
} }
ping2 := func(t *testing.T) { ping2 := func(t *testing.T) {
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2 m1.tun.Outbound <- msg1to2
t.Log("ping2 sent") t.Log("ping2 sent")
select { select {
case msgRecv := <-tun2.Inbound: case msgRecv := <-m2.tun.Inbound:
if !bytes.Equal(msg1to2, msgRecv) { if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly") t.Error("return ping did not transit correctly")
} }
@ -573,12 +578,12 @@ func testTwoDevicePing(t *testing.T, d *devices) {
setT(t) setT(t)
defer setT(outerT) defer setT(outerT)
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
if err := tstun1.InjectOutbound(msg1to2); err != nil { if err := m1.tsTun.InjectOutbound(msg1to2); err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log("SendPacket sent") t.Log("SendPacket sent")
select { select {
case msgRecv := <-tun2.Inbound: case msgRecv := <-m2.tun.Inbound:
if !bytes.Equal(msg1to2, msgRecv) { if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly") t.Error("return ping did not transit correctly")
} }
@ -590,7 +595,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
t.Run("no-op dev1 reconfig", func(t *testing.T) { t.Run("no-op dev1 reconfig", func(t *testing.T) {
setT(t) setT(t)
defer setT(outerT) defer setT(outerT)
if err := dev1.Reconfig(&cfgs[0]); err != nil { if err := m1.dev.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
ping1(t) ping1(t)
@ -632,14 +637,14 @@ func testTwoDevicePing(t *testing.T, d *devices) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
b := msg(i) b := msg(i)
tun1.Outbound <- b m1.tun.Outbound <- b
time.Sleep(interPacketGap) time.Sleep(interPacketGap)
} }
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
b := msg(i) b := msg(i)
select { select {
case msgRecv := <-tun2.Inbound: case msgRecv := <-m2.tun.Inbound:
if !bytes.Equal(b, msgRecv) { if !bytes.Equal(b, msgRecv) {
if strict { if strict {
t.Errorf("return ping %d did not transit correctly: %s", i, cmp.Diff(b, msgRecv)) t.Errorf("return ping %d did not transit correctly: %s", i, cmp.Diff(b, msgRecv))
@ -651,7 +656,6 @@ func testTwoDevicePing(t *testing.T, d *devices) {
} }
} }
} }
} }
t.Run("ping 1.0.0.1 x50", func(t *testing.T) { t.Run("ping 1.0.0.1 x50", func(t *testing.T) {
@ -668,29 +672,26 @@ func testTwoDevicePing(t *testing.T, d *devices) {
ep1 := cfgs[1].Peers[0].Endpoints ep1 := cfgs[1].Peers[0].Endpoints
ep1 = append([]wgcfg.Endpoint{derpEp}, ep1...) ep1 = append([]wgcfg.Endpoint{derpEp}, ep1...)
cfgs[1].Peers[0].Endpoints = ep1 cfgs[1].Peers[0].Endpoints = ep1
if err := dev1.Reconfig(&cfgs[0]); err != nil { if err := m1.dev.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := dev2.Reconfig(&cfgs[1]); err != nil { if err := m2.dev.Reconfig(&cfgs[1]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Run("add DERP", func(t *testing.T) { t.Run("add DERP", func(t *testing.T) {
setT(t) setT(t)
defer setT(outerT) defer setT(outerT)
defer func() {
logf("DERP vars: %s", derpServer.ExpVar().String())
}()
pingSeq(t, 20, 0, true) pingSeq(t, 20, 0, true)
}) })
// Disable real route. // Disable real route.
cfgs[0].Peers[0].Endpoints = []wgcfg.Endpoint{derpEp} cfgs[0].Peers[0].Endpoints = []wgcfg.Endpoint{derpEp}
cfgs[1].Peers[0].Endpoints = []wgcfg.Endpoint{derpEp} cfgs[1].Peers[0].Endpoints = []wgcfg.Endpoint{derpEp}
if err := dev1.Reconfig(&cfgs[0]); err != nil { if err := m1.dev.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := dev2.Reconfig(&cfgs[1]); err != nil { if err := m2.dev.Reconfig(&cfgs[1]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
time.Sleep(250 * time.Millisecond) // TODO remove time.Sleep(250 * time.Millisecond) // TODO remove
@ -699,7 +700,6 @@ func testTwoDevicePing(t *testing.T, d *devices) {
setT(t) setT(t)
defer setT(outerT) defer setT(outerT)
defer func() { defer func() {
logf("DERP vars: %s", derpServer.ExpVar().String())
if t.Failed() || true { if t.Failed() || true {
uapi1, _ := cfgs[0].ToUAPI() uapi1, _ := cfgs[0].ToUAPI()
logf("cfg0: %v", uapi1) logf("cfg0: %v", uapi1)
@ -710,8 +710,8 @@ func testTwoDevicePing(t *testing.T, d *devices) {
pingSeq(t, 20, 0, true) pingSeq(t, 20, 0, true)
}) })
dev1.RemoveAllPeers() m1.dev.RemoveAllPeers()
dev2.RemoveAllPeers() m2.dev.RemoveAllPeers()
// Give one peer a non-DERP endpoint. We expect the other to // Give one peer a non-DERP endpoint. We expect the other to
// accept it via roamAddr. // accept it via roamAddr.
@ -719,10 +719,10 @@ func testTwoDevicePing(t *testing.T, d *devices) {
if ep2 := cfgs[1].Peers[0].Endpoints; len(ep2) != 1 { if ep2 := cfgs[1].Peers[0].Endpoints; len(ep2) != 1 {
t.Errorf("unexpected peer endpoints in dev2: %v", ep2) t.Errorf("unexpected peer endpoints in dev2: %v", ep2)
} }
if err := dev2.Reconfig(&cfgs[1]); err != nil { if err := m2.dev.Reconfig(&cfgs[1]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := dev1.Reconfig(&cfgs[0]); err != nil { if err := m1.dev.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Dear future human debugging a test failure here: this test is // Dear future human debugging a test failure here: this test is
@ -736,7 +736,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
defer setT(outerT) defer setT(outerT)
pingSeq(t, 50, 700*time.Millisecond, false) pingSeq(t, 50, 700*time.Millisecond, false)
ep2 := dev2.Config().Peers[0].Endpoints ep2 := m2.dev.Config().Peers[0].Endpoints
if len(ep2) != 2 { if len(ep2) != 2 {
t.Error("handshake spray failed to find real route") t.Error("handshake spray failed to find real route")
} }

Loading…
Cancel
Save