From c0158bcd0bd9b78cef5056b8088a739aa2f117a1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 11 May 2021 21:57:25 -0700 Subject: [PATCH] tstest/integration{,/testcontrol}: add testcontrol.RequireAuth mode, new test Signed-off-by: Brad Fitzpatrick --- tstest/integration/integration_test.go | 230 ++++++++++++++---- tstest/integration/testcontrol/testcontrol.go | 136 ++++++++++- 2 files changed, 306 insertions(+), 60 deletions(-) diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 922d6db3f..3b1eecfec 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -10,6 +10,7 @@ import ( crand "crypto/rand" "crypto/tls" "encoding/json" + "flag" "fmt" "io" "io/ioutil" @@ -19,8 +20,8 @@ import ( "net/http/httptest" "os" "os/exec" - "path" "path/filepath" + "regexp" "runtime" "strings" "sync" @@ -43,6 +44,8 @@ import ( "tailscale.com/types/nettype" ) +var verbose = flag.Bool("verbose", false, "verbose debug logs") + var mainError atomic.Value // of error func TestMain(m *testing.M) { @@ -57,11 +60,8 @@ func TestMain(m *testing.M) { os.Exit(0) } -func TestIntegration(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("not tested/working on Windows yet") - } - +func TestOneNodeUp_NoAuth(t *testing.T) { + t.Parallel() bins := buildTestBinaries(t) env := newTestEnv(t, bins) @@ -69,8 +69,8 @@ func TestIntegration(t *testing.T) { n1 := newTestNode(t, env) - dcmd := n1.StartDaemon(t) - defer dcmd.Process.Kill() + d1 := n1.StartDaemon(t) + defer d1.Kill() n1.AwaitListening(t) @@ -97,34 +97,60 @@ func TestIntegration(t *testing.T) { time.Sleep(d) } - var ip string - if err := tstest.WaitFor(20*time.Second, func() error { - out, err := n1.Tailscale("ip").Output() - if err != nil { - return err + t.Logf("Got IP: %v", n1.AwaitIP(t)) + n1.AwaitRunning(t) + + d1.MustCleanShutdown(t) + + t.Logf("number of HTTP logcatcher requests: %v", env.LogCatcher.numRequests()) +} + +func TestOneNodeUp_Auth(t *testing.T) { + t.Parallel() + bins := buildTestBinaries(t) + + env := newTestEnv(t, bins) + defer env.Close() + env.Control.RequireAuth = true + + n1 := newTestNode(t, env) + + d1 := n1.StartDaemon(t) + defer d1.Kill() + + n1.AwaitListening(t) + + st := n1.MustStatus(t) + t.Logf("Status: %s", st.BackendState) + + t.Logf("Running up --login-server=%s ...", env.ControlServer.URL) + + cmd := n1.Tailscale("up", "--login-server="+env.ControlServer.URL) + var authCountAtomic int32 + cmd.Stdout = &authURLParserWriter{fn: func(urlStr string) error { + if env.Control.CompleteAuth(urlStr) { + atomic.AddInt32(&authCountAtomic, 1) + t.Logf("completed auth path %s", urlStr) + return nil } - ip = string(out) - return nil - }); err != nil { - t.Error(err) + err := fmt.Errorf("Failed to complete auth path to %q", urlStr) + t.Log(err) + return err + }} + cmd.Stderr = cmd.Stdout + if err := cmd.Run(); err != nil { + t.Fatalf("up: %v", err) } - t.Logf("Got IP: %v", ip) + t.Logf("Got IP: %v", n1.AwaitIP(t)) - dcmd.Process.Signal(os.Interrupt) + n1.AwaitRunning(t) - ps, err := dcmd.Process.Wait() - if err != nil { - t.Fatalf("tailscaled Wait: %v", err) - } - if ps.ExitCode() != 0 { - t.Errorf("tailscaled ExitCode = %d; want 0", ps.ExitCode()) + if n := atomic.LoadInt32(&authCountAtomic); n != 1 { + t.Errorf("Auth URLs completed = %d; want 1", n) } - t.Logf("number of HTTP logcatcher requests: %v", env.LogCatcher.numRequests()) - if err := env.TrafficTrap.Err(); err != nil { - t.Errorf("traffic trap: %v", err) - t.Logf("logs: %s", env.LogCatcher.logsString()) - } + d1.MustCleanShutdown(t) + } // testBinaries are the paths to a tailscaled and tailscale binary. @@ -139,16 +165,18 @@ type testBinaries struct { // if they fail to compile. func buildTestBinaries(t testing.TB) *testBinaries { td := t.TempDir() + build(t, td, "tailscale.com/cmd/tailscaled", "tailscale.com/cmd/tailscale") return &testBinaries{ dir: td, - daemon: build(t, td, "tailscale.com/cmd/tailscaled"), - cli: build(t, td, "tailscale.com/cmd/tailscale"), + daemon: filepath.Join(td, "tailscaled"+exe()), + cli: filepath.Join(td, "tailscale"+exe()), } } // testEnv contains the test environment (set of servers) used by one // or more nodes. type testEnv struct { + t testing.TB Binaries *testBinaries LogCatcher *logCatcher @@ -168,6 +196,9 @@ type testEnv struct { // // Call Close to shut everything down. func newTestEnv(t testing.TB, bins *testBinaries) *testEnv { + if runtime.GOOS == "windows" { + t.Skip("not tested/working on Windows yet") + } derpMap, derpShutdown := runDERPAndStun(t, logger.Discard) logc := new(logCatcher) control := &testcontrol.Server{ @@ -175,6 +206,7 @@ func newTestEnv(t testing.TB, bins *testBinaries) *testEnv { } trafficTrap := new(trafficTrap) e := &testEnv{ + t: t, Binaries: bins, LogCatcher: logc, LogCatcherServer: httptest.NewServer(logc), @@ -184,10 +216,16 @@ func newTestEnv(t testing.TB, bins *testBinaries) *testEnv { TrafficTrapServer: httptest.NewServer(trafficTrap), derpShutdown: derpShutdown, } + e.Control.BaseURL = e.ControlServer.URL return e } func (e *testEnv) Close() error { + if err := e.TrafficTrap.Err(); err != nil { + e.t.Errorf("traffic trap: %v", err) + e.t.Logf("logs: %s", e.LogCatcher.logsString()) + } + e.LogCatcherServer.Close() e.TrafficTrapServer.Close() e.ControlServer.Close() @@ -218,9 +256,28 @@ func newTestNode(t *testing.T, env *testEnv) *testNode { } } +type Daemon struct { + Process *os.Process +} + +func (d *Daemon) Kill() { + d.Process.Kill() +} + +func (d *Daemon) MustCleanShutdown(t testing.TB) { + d.Process.Signal(os.Interrupt) + ps, err := d.Process.Wait() + if err != nil { + t.Fatalf("tailscaled Wait: %v", err) + } + if ps.ExitCode() != 0 { + t.Errorf("tailscaled ExitCode = %d; want 0", ps.ExitCode()) + } +} + // StartDaemon starts the node's tailscaled, failing if it fails to // start. -func (n *testNode) StartDaemon(t testing.TB) *exec.Cmd { +func (n *testNode) StartDaemon(t testing.TB) *Daemon { cmd := exec.Command(n.env.Binaries.daemon, "--tun=userspace-networking", "--state="+n.stateFile, @@ -234,7 +291,9 @@ func (n *testNode) StartDaemon(t testing.TB) *exec.Cmd { if err := cmd.Start(); err != nil { t.Fatalf("starting tailscaled: %v", err) } - return cmd + return &Daemon{ + Process: cmd.Process, + } } // AwaitListening waits for the tailscaled to be serving local clients @@ -252,6 +311,40 @@ func (n *testNode) AwaitListening(t testing.TB) { } } +func (n *testNode) AwaitIP(t testing.TB) (ips string) { + t.Helper() + if err := tstest.WaitFor(20*time.Second, func() error { + out, err := n.Tailscale("ip").Output() + if err != nil { + return err + } + ips = string(out) + return nil + }); err != nil { + t.Fatalf("awaiting an IP address: %v", err) + } + if ips == "" { + t.Fatalf("returned IP address was blank") + } + return ips +} + +func (n *testNode) AwaitRunning(t testing.TB) { + t.Helper() + if err := tstest.WaitFor(20*time.Second, func() error { + st, err := n.Status() + if err != nil { + return err + } + if st.BackendState != "Running" { + return fmt.Errorf("in state %q", st.BackendState) + } + return nil + }); err != nil { + t.Fatalf("failure/timeout waiting for transition to Running status: %v", err) + } +} + // Tailscale returns a command that runs the tailscale CLI with the provided arguments. // It does not start the process. func (n *testNode) Tailscale(arg ...string) *exec.Cmd { @@ -261,15 +354,23 @@ func (n *testNode) Tailscale(arg ...string) *exec.Cmd { return cmd } -func (n *testNode) MustStatus(tb testing.TB) *ipnstate.Status { - tb.Helper() +func (n *testNode) Status() (*ipnstate.Status, error) { out, err := n.Tailscale("status", "--json").CombinedOutput() if err != nil { - tb.Fatalf("getting status: %v, %s", err, out) + return nil, fmt.Errorf("running tailscale status: %v, %s", err, out) } st := new(ipnstate.Status) if err := json.Unmarshal(out, st); err != nil { - tb.Fatalf("parsing status json: %v, from: %s", err, out) + return nil, fmt.Errorf("decoding tailscale status JSON: %w", err) + } + return st, nil +} + +func (n *testNode) MustStatus(tb testing.TB) *ipnstate.Status { + tb.Helper() + st, err := n.Status() + if err != nil { + tb.Fatal(err) } return st } @@ -291,21 +392,31 @@ func findGo(t testing.TB) string { } else if !fi.Mode().IsRegular() { t.Fatalf("%v is unexpected %v", goBin, fi.Mode()) } - t.Logf("using go binary %v", goBin) return goBin } -func build(t testing.TB, outDir, target string) string { - exe := "" - if runtime.GOOS == "windows" { - exe = ".exe" - } - bin := filepath.Join(outDir, path.Base(target)) + exe - errOut, err := exec.Command(findGo(t), "build", "-o", bin, target).CombinedOutput() +// buildMu limits our use of "go build" to one at a time, so we don't +// fight Go's built-in caching trying to do the same build concurrently. +var buildMu sync.Mutex + +func build(t testing.TB, outDir string, targets ...string) { + buildMu.Lock() + defer buildMu.Unlock() + + t0 := time.Now() + defer func() { t.Logf("built %s in %v", targets, time.Since(t0).Round(time.Millisecond)) }() + + // TODO(bradfitz): add -race to the built binaries if our + // current binary is a race binary. + + goBin := findGo(t) + cmd := exec.Command(goBin, "install") + cmd.Args = append(cmd.Args, targets...) + cmd.Env = append(os.Environ(), "GOBIN="+outDir) + errOut, err := cmd.CombinedOutput() if err != nil { - t.Fatalf("failed to build %v: %v, %s", target, err, errOut) + t.Fatalf("failed to build %v with %v: %v, %s", targets, goBin, err, errOut) } - return bin } // logCatcher is a minimal logcatcher for the logtail upload client. @@ -378,6 +489,9 @@ func (lc *logCatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { for _, ent := range jreq { fmt.Fprintf(&lc.buf, "%s\n", strings.TrimSpace(ent.Text)) + if *verbose { + fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(ent.Text)) + } } } w.WriteHeader(200) // must have no content, but not a 204 @@ -454,3 +568,23 @@ func runDERPAndStun(t testing.TB, logf logger.Logf) (derpMap *tailcfg.DERPMap, c return m, cleanup } + +type authURLParserWriter struct { + buf bytes.Buffer + fn func(urlStr string) error +} + +var authURLRx = regexp.MustCompile(`(https?://\S+/auth/\S+)`) + +func (w *authURLParserWriter) Write(p []byte) (n int, err error) { + n, err = w.buf.Write(p) + m := authURLRx.FindSubmatch(w.buf.Bytes()) + if m != nil { + urlStr := string(m[1]) + w.buf.Reset() // so it's not matched again + if err := w.fn(urlStr); err != nil { + return 0, err + } + } + return n, err +} diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index fe3a12911..cd1051c58 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -17,6 +17,7 @@ import ( "log" "math/rand" "net/http" + "net/url" "strings" "sync" "time" @@ -34,19 +35,43 @@ import ( // Server is a control plane server. Its zero value is ready for use. // Everything is stored in-memory in one tailnet. type Server struct { - Logf logger.Logf // nil means to use the log package - DERPMap *tailcfg.DERPMap // nil means to use prod DERP map + Logf logger.Logf // nil means to use the log package + DERPMap *tailcfg.DERPMap // nil means to use prod DERP map + RequireAuth bool + BaseURL string // must be set to e.g. "http://127.0.0.1:1234" with no trailing URL + Verbose bool initMuxOnce sync.Once mux *http.ServeMux - mu sync.Mutex - pubKey wgkey.Key - privKey wgkey.Private - nodes map[tailcfg.NodeKey]*tailcfg.Node - users map[tailcfg.NodeKey]*tailcfg.User - logins map[tailcfg.NodeKey]*tailcfg.Login - updates map[tailcfg.NodeID]chan updateType + mu sync.Mutex + pubKey wgkey.Key + privKey wgkey.Private + nodes map[tailcfg.NodeKey]*tailcfg.Node + users map[tailcfg.NodeKey]*tailcfg.User + logins map[tailcfg.NodeKey]*tailcfg.Login + updates map[tailcfg.NodeID]chan updateType + authPath map[string]*AuthPath + nodeKeyAuthed map[tailcfg.NodeKey]bool // key => true once authenticated +} + +type AuthPath struct { + nodeKey tailcfg.NodeKey + + closeOnce sync.Once + ch chan struct{} + success bool +} + +func (ap *AuthPath) completeSuccessfully() { + ap.success = true + close(ap.ch) +} + +// CompleteSuccessfully completes the login path successfully, as if +// the user did the whole auth dance. +func (ap *AuthPath) CompleteSuccessfully() { + ap.closeOnce.Do(ap.completeSuccessfully) } func (s *Server) logf(format string, a ...interface{}) { @@ -178,6 +203,56 @@ func (s *Server) getUser(nodeKey tailcfg.NodeKey) (*tailcfg.User, *tailcfg.Login return user, login } +// authPathDone returns a close-only struct that's closed when the +// authPath ("/auth/XXXXXX") has authenticated. +func (s *Server) authPathDone(authPath string) <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + if a, ok := s.authPath[authPath]; ok { + return a.ch + } + return nil +} + +func (s *Server) addAuthPath(authPath string, nodeKey tailcfg.NodeKey) { + s.mu.Lock() + defer s.mu.Unlock() + if s.authPath == nil { + s.authPath = map[string]*AuthPath{} + } + s.authPath[authPath] = &AuthPath{ + ch: make(chan struct{}), + nodeKey: nodeKey, + } +} + +// CompleteAuth marks the provided path or URL (containing +// "/auth/...") as successfully authenticated, unblocking any +// requests blocked on that in serveRegister. +func (s *Server) CompleteAuth(authPathOrURL string) bool { + i := strings.Index(authPathOrURL, "/auth/") + if i == -1 { + return false + } + authPath := authPathOrURL[i:] + + s.mu.Lock() + defer s.mu.Unlock() + ap, ok := s.authPath[authPath] + if !ok { + return false + } + if ap.nodeKey.IsZero() { + panic("zero AuthPath.NodeKey") + } + if s.nodeKeyAuthed == nil { + s.nodeKeyAuthed = map[tailcfg.NodeKey]bool{} + } + s.nodeKeyAuthed[ap.nodeKey] = true + ap.CompleteSuccessfully() + return true +} + func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) { var req tailcfg.RegisterRequest if err := s.decode(mkey, r.Body, &req); err != nil { @@ -189,28 +264,65 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey tail if req.NodeKey.IsZero() { panic("serveRegister: request has zero node key") } + if s.Verbose { + j, _ := json.MarshalIndent(req, "", "\t") + log.Printf("Got %T: %s", req, j) + } + + // If this is a followup request, wait until interactive followup URL visit complete. + if req.Followup != "" { + followupURL, err := url.Parse(req.Followup) + if err != nil { + panic(err) + } + doneCh := s.authPathDone(followupURL.Path) + select { + case <-r.Context().Done(): + return + case <-doneCh: + } + // TODO(bradfitz): support a side test API to mark an + // auth as failued so we can send an error response in + // some follow-ups? For now all are successes. + } user, login := s.getUser(req.NodeKey) s.mu.Lock() if s.nodes == nil { s.nodes = map[tailcfg.NodeKey]*tailcfg.Node{} } + + machineAuthorized := true // TODO: add Server.RequireMachineAuth + s.nodes[req.NodeKey] = &tailcfg.Node{ ID: tailcfg.NodeID(user.ID), StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))), User: user.ID, Machine: mkey, Key: req.NodeKey, - MachineAuthorized: true, + MachineAuthorized: machineAuthorized, + } + requireAuth := s.RequireAuth + if requireAuth && s.nodeKeyAuthed[req.NodeKey] { + requireAuth = false } s.mu.Unlock() + authURL := "" + if requireAuth { + randHex := make([]byte, 10) + crand.Read(randHex) + authPath := fmt.Sprintf("/auth/%x", randHex) + s.addAuthPath(authPath, req.NodeKey) + authURL = s.BaseURL + authPath + } + res, err := s.encode(mkey, false, tailcfg.RegisterResponse{ User: *user, Login: *login, NodeKeyExpired: false, - MachineAuthorized: true, - AuthURL: "", // all good; TODO(bradfitz): add ways to not start all good. + MachineAuthorized: machineAuthorized, + AuthURL: authURL, }) if err != nil { go panic(fmt.Sprintf("serveRegister: encode: %v", err))