From 44e027abcaee9c89bc0a9812a1f512f8aeb966d0 Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Wed, 22 Feb 2023 19:52:17 -0800 Subject: [PATCH] tsnet: add data transfer test Signed-off-by: David Crawshaw --- tsnet/tsnet.go | 30 ++++++++--- tsnet/tsnet_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 8 deletions(-) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 20f0feaf6..0c33ee1c3 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -204,11 +204,15 @@ func (s *Server) Close() error { go func() { defer wg.Done() // Perform a best-effort final flush. - s.logtail.Shutdown(ctx) - s.logbuffer.Close() + if s.logtail != nil { + s.logtail.Shutdown(ctx) + } + if s.logbuffer != nil { + s.logbuffer.Close() + } }() - if _, isMemStore := s.Store.(*mem.Store); isMemStore && s.Ephemeral { + if _, isMemStore := s.Store.(*mem.Store); isMemStore && s.Ephemeral && s.lb != nil { wg.Add(1) go func() { defer wg.Done() @@ -221,11 +225,21 @@ func (s *Server) Close() error { s.netstack.Close() s.netstack = nil } - s.shutdownCancel() - s.lb.Shutdown() - s.linkMon.Close() - s.dialer.Close() - s.localAPIListener.Close() + if s.shutdownCancel != nil { + s.shutdownCancel() + } + if s.lb != nil { + s.lb.Shutdown() + } + if s.linkMon != nil { + s.linkMon.Close() + } + if s.dialer != nil { + s.dialer.Close() + } + if s.localAPIListener != nil { + s.localAPIListener.Close() + } s.mu.Lock() defer s.mu.Unlock() diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 26f7cf117..9d0fa8700 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -4,8 +4,23 @@ package tsnet import ( + "context" "errors" + "flag" + "fmt" + "io" + "path/filepath" + "os" + "net/http/httptest" "testing" + "time" + + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration" + "tailscale.com/net/netns" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/logger" ) // TestListener_Server ensures that the listener type always keeps the Server @@ -44,3 +59,111 @@ func TestListenerPort(t *testing.T) { } } } + +var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs") +var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs") + +func TestConn(t *testing.T) { + // Corp#4520: don't use netns for tests. + netns.SetEnabled(false) + t.Cleanup(func() { + netns.SetEnabled(true) + }) + + derpLogf := logger.Discard + if *verboseDERP { + derpLogf = t.Logf + } + derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1") + control := &testcontrol.Server{ + DERPMap: derpMap, + } + control.HTTPTestServer = httptest.NewUnstartedServer(control) + control.HTTPTestServer.Start() + t.Cleanup(control.HTTPTestServer.Close) + controlURL := control.HTTPTestServer.URL + t.Logf("testcontrol listening on %s", controlURL) + + tmp := t.TempDir() + tmps1 := filepath.Join(tmp, "s1") + os.MkdirAll(tmps1, 0755) + s1 := &Server{ + Dir: tmps1, + ControlURL: controlURL, + Hostname: "s1", + Store: new(mem.Store), + Ephemeral: true, + } + defer s1.Close() + + tmps2 := filepath.Join(tmp, "s1") + os.MkdirAll(tmps2, 0755) + s2 := &Server{ + Dir: tmps2, + ControlURL: controlURL, + Hostname: "s2", + Store: new(mem.Store), + Ephemeral: true, + } + defer s2.Close() + + if !*verboseNodes { + s1.Logf = logger.Discard + s2.Logf = logger.Discard + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + s1status, err := s1.Up(ctx) + if err != nil { + t.Fatal(err) + } + s1ip := s1status.TailscaleIPs[0] + if _, err := s2.Up(ctx); err != nil { + t.Fatal(err) + } + + lc2, err := s2.LocalClient() + if err != nil { + t.Fatal(err) + } + + // ping to make sure the connection is up. + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + if err != nil { + t.Fatal(err) + } + t.Logf("ping success: %#+v", res) + + // pass some data through TCP. + ln, err := s1.Listen("tcp", ":8081") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) + if err != nil { + t.Fatal(err) + } + + r, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + + want := "hello" + if _, err := io.WriteString(w, want); err != nil { + t.Fatal(err) + } + + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(r, got, len(got)); err != nil { + t.Fatal(err) + } + t.Logf("got: %q", got) + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } +}