diff --git a/client/tailscale/tailscale.go b/client/tailscale/tailscale.go index b3f4c2eb2..c0d4a9bf5 100644 --- a/client/tailscale/tailscale.go +++ b/client/tailscale/tailscale.go @@ -38,6 +38,9 @@ var ( // TailscaledSocket is the tailscaled Unix socket. It's used by the TailscaledDialer. TailscaledSocket = paths.DefaultTailscaledSocket() + // TailscaledSocketSetExplicitly reports whether the user explicitly set TailscaledSocket. + TailscaledSocketSetExplicitly bool + // TailscaledDialer is the DialContext func that connects to the local machine's // tailscaled or equivalent. TailscaledDialer = defaultDialer @@ -47,7 +50,8 @@ func defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) if addr != "local-tailscaled.sock:80" { return nil, fmt.Errorf("unexpected URL address %q", addr) } - if TailscaledSocket == paths.DefaultTailscaledSocket() { + // TODO: make this part of a safesocket.ConnectionStrategy + if !TailscaledSocketSetExplicitly { // On macOS, when dialing from non-sandboxed program to sandboxed GUI running // a TCP server on a random port, find the random port. For HTTP connections, // we don't send the token. It gets added in an HTTP Basic-Auth header. @@ -56,7 +60,11 @@ func defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) return d.DialContext(ctx, "tcp", "localhost:"+strconv.Itoa(port)) } } - return safesocket.Connect(TailscaledSocket, safesocket.WindowsLocalPort) + s := safesocket.DefaultConnectionStrategy(TailscaledSocket) + // The user provided a non-default tailscaled socket address. + // Connect only to exactly what they provided. + s.UseFallback(false) + return safesocket.Connect(s) } var ( diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index d139c63a3..294304786 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -164,6 +164,11 @@ change in the future. } tailscale.TailscaledSocket = rootArgs.socket + rootfs.Visit(func(f *flag.Flag) { + if f.Name == "socket" { + tailscale.TailscaledSocketSetExplicitly = true + } + }) err := rootCmd.Run(context.Background()) if errors.Is(err, flag.ErrHelp) { @@ -191,7 +196,8 @@ var rootArgs struct { var gotSignal syncs.AtomicBool func connect(ctx context.Context) (net.Conn, *ipn.BackendClient, context.Context, context.CancelFunc) { - c, err := safesocket.Connect(rootArgs.socket, safesocket.WindowsLocalPort) + s := safesocket.DefaultConnectionStrategy(rootArgs.socket) + c, err := safesocket.Connect(s) if err != nil { if runtime.GOOS != "windows" && rootArgs.socket == "" { fatalf("--socket cannot be empty") diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index aa66c7040..777e650c0 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -33,10 +33,11 @@ func TestRunMultipleAccepts(t *testing.T) { t.Logf(format, args...) } + s := safesocket.DefaultConnectionStrategy(socketPath) connect := func() { for i := 1; i <= 2; i++ { logf("connect %d ...", i) - c, err := safesocket.Connect(socketPath, 0) + c, err := safesocket.Connect(s) if err != nil { t.Fatalf("safesocket.Connect: %v\n", err) } diff --git a/safesocket/basic_test.go b/safesocket/basic_test.go index 4c61ec14d..f3057995f 100644 --- a/safesocket/basic_test.go +++ b/safesocket/basic_test.go @@ -48,7 +48,9 @@ func TestBasics(t *testing.T) { }() go func() { - c, err := Connect(sock, port) + s := DefaultConnectionStrategy(sock) + s.UsePort(port) + c, err := Connect(s) if err != nil { errs <- err return diff --git a/safesocket/pipe_windows.go b/safesocket/pipe_windows.go index d1c69d08d..9041cbefb 100644 --- a/safesocket/pipe_windows.go +++ b/safesocket/pipe_windows.go @@ -11,8 +11,8 @@ import ( "syscall" ) -func connect(path string, port uint16) (net.Conn, error) { - pipe, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) +func connect(s *ConnectionStrategy) (net.Conn, error) { + pipe, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", s.port)) if err != nil { return nil, err } diff --git a/safesocket/safesocket.go b/safesocket/safesocket.go index db51f0456..e5b7e80d4 100644 --- a/safesocket/safesocket.go +++ b/safesocket/safesocket.go @@ -57,10 +57,65 @@ func tailscaledStillStarting() bool { return tailscaledProcExists() } -// Connect connects to either path (on Unix) or the provided localhost port (on Windows). -func Connect(path string, port uint16) (net.Conn, error) { +// A ConnectionStrategy is a plan for how to connect to tailscaled or equivalent (e.g. IPNExtension on macOS). +type ConnectionStrategy struct { + // For now, a ConnectionStrategy is just a unix socket path, a TCP port, + // and a flag indicating whether to try fallback connections options. + path string + port uint16 + fallback bool + // Longer term, a ConnectionStrategy should be an ordered list of things to attempt, + // with just the information required to connection for each. + // + // We have at least these cases to consider (see issue 3530): + // + // tailscale sandbox | tailscaled sandbox | OS | connection + // ------------------|--------------------|---------|----------- + // no | no | unix | unix socket + // no | no | Windows | TCP/port + // no | no | wasm | memconn + // no | Network Extension | macOS | TCP/port/token, port/token from lsof + // no | System Extension | macOS | TCP/port/token, port/token from lsof + // yes | Network Extension | macOS | TCP/port/token, port/token from readdir + // yes | System Extension | macOS | TCP/port/token, port/token from readdir + // + // Note e.g. that port is only relevant as an input to Connect on Windows, + // that path is not relevant to Windows, and that neither matters to wasm. +} + +// DefaultConnectionStrategy returns a default connection strategy. +// The default strategy is to attempt to connect in as many ways as possible. +// It uses path as the unix socket path, when applicable, +// and defaults to WindowsLocalPort for the TCP port when applicable. +// It falls back to auto-discovery across sandbox boundaries on macOS. +// TODO: maybe take no arguments, since path is irrelevant on Windows? Discussion in PR 3499. +func DefaultConnectionStrategy(path string) *ConnectionStrategy { + return &ConnectionStrategy{path: path, port: WindowsLocalPort, fallback: true} +} + +// UsePort modifies s to use port for the TCP port when applicable. +// UsePort is only applicable on Windows, and only then +// when not using the default for Windows. +func (s *ConnectionStrategy) UsePort(port uint16) { + s.port = port +} + +// UseFallback modifies s to set whether it should fall back +// to connecting to the macOS GUI's tailscaled +// if the Unix socket path wasn't reachable. +func (s *ConnectionStrategy) UseFallback(b bool) { + s.fallback = b +} + +// ExactPath returns a connection strategy that only attempts to connect via path. +func ExactPath(path string) *ConnectionStrategy { + return &ConnectionStrategy{path: path, fallback: false} +} + +// Connect connects to tailscaled using s +func Connect(s *ConnectionStrategy) (net.Conn, error) { for { - c, err := connect(path, port) + c, err := connect(s) if err != nil && tailscaledStillStarting() { time.Sleep(250 * time.Millisecond) continue diff --git a/safesocket/safesocket_js.go b/safesocket/safesocket_js.go index d5e1616be..fadf81a43 100644 --- a/safesocket/safesocket_js.go +++ b/safesocket/safesocket_js.go @@ -17,6 +17,6 @@ func listen(path string, port uint16) (_ net.Listener, gotPort uint16, _ error) return ln, 1, err } -func connect(path string, port uint16) (net.Conn, error) { +func connect(_ *ConnectionStrategy) (net.Conn, error) { return memconn.Dial("memu", memName) } diff --git a/safesocket/unixsocket.go b/safesocket/unixsocket.go index 4117506ed..4efd05ccd 100644 --- a/safesocket/unixsocket.go +++ b/safesocket/unixsocket.go @@ -23,19 +23,19 @@ import ( ) // TODO(apenwarr): handle magic cookie auth -func connect(path string, port uint16) (net.Conn, error) { +func connect(s *ConnectionStrategy) (net.Conn, error) { if runtime.GOOS == "js" { return nil, errors.New("safesocket.Connect not yet implemented on js/wasm") } - if runtime.GOOS == "darwin" && path == "" && port == 0 { + if runtime.GOOS == "darwin" && s.fallback && s.path == "" && s.port == 0 { return connectMacOSAppSandbox() } - pipe, err := net.Dial("unix", path) + pipe, err := net.Dial("unix", s.path) if err != nil { - if runtime.GOOS == "darwin" { + if runtime.GOOS == "darwin" && s.fallback { extConn, extErr := connectMacOSAppSandbox() if extErr != nil { - return nil, fmt.Errorf("safesocket: failed to connect to %v: %v; failed to connect to Tailscale IPNExtension: %v", path, err, extErr) + return nil, fmt.Errorf("safesocket: failed to connect to %v: %v; failed to connect to Tailscale IPNExtension: %v", s.path, err, extErr) } return extConn, nil } diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 56d778801..38ed7b327 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -720,8 +720,10 @@ func (n *testNode) MustDown() { // AwaitListening waits for the tailscaled to be serving local clients // over its localhost IPC mechanism. (Unix socket, etc) func (n *testNode) AwaitListening(t testing.TB) { + s := safesocket.DefaultConnectionStrategy(n.sockFile) + s.UseFallback(false) // connect only to the tailscaled that we started if err := tstest.WaitFor(20*time.Second, func() (err error) { - c, err := safesocket.Connect(n.sockFile, safesocket.WindowsLocalPort) + c, err := safesocket.Connect(s) if err != nil { return err }