safesocket: add ConnectionStrategy, provide control over fallbacks

fee2d9fad added support for cmd/tailscale to connect to IPNExtension.
It came in two parts: If no socket was provided, dial IPNExtension first,
and also, if dialing the socket failed, fall back to IPNExtension.

The second half of that support caused the integration tests to fail
when run on a machine that was also running IPNExtension.
The integration tests want to wait until the tailscaled instances
that they spun up are listening. They do that by dialing the new
instance. But when that dial failed, it was falling back to IPNExtension,
so it appeared (incorrectly) that tailscaled was running.
Hilarity predictably ensued.

If a user (or a test) explicitly provides a socket to dial,
it is a reasonable assumption that they have a specific tailscaled
in mind and don't want to fall back to IPNExtension.
It is certainly true of the integration tests.

Instead of adding a bool to Connect, split out the notion of a
connection strategy. For now, the implementation remains the same,
but with the details hidden a bit. Later, we can improve that.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
pull/3546/head
Josh Bleecher Snyder 3 years ago committed by Josh Bleecher Snyder
parent a5235e165c
commit 63cd581c3f

@ -38,6 +38,9 @@ var (
// TailscaledSocket is the tailscaled Unix socket. It's used by the TailscaledDialer. // TailscaledSocket is the tailscaled Unix socket. It's used by the TailscaledDialer.
TailscaledSocket = paths.DefaultTailscaledSocket() TailscaledSocket = paths.DefaultTailscaledSocket()
// TailscaledSocketSetExplicitly reports whether the user explicitly set TailscaledSocket.
TailscaledSocketSetExplicitly bool
// TailscaledDialer is the DialContext func that connects to the local machine's // TailscaledDialer is the DialContext func that connects to the local machine's
// tailscaled or equivalent. // tailscaled or equivalent.
TailscaledDialer = defaultDialer TailscaledDialer = defaultDialer
@ -47,7 +50,8 @@ func defaultDialer(ctx context.Context, network, addr string) (net.Conn, error)
if addr != "local-tailscaled.sock:80" { if addr != "local-tailscaled.sock:80" {
return nil, fmt.Errorf("unexpected URL address %q", addr) return nil, fmt.Errorf("unexpected URL address %q", addr)
} }
if TailscaledSocket == paths.DefaultTailscaledSocket() { // TODO: make this part of a safesocket.ConnectionStrategy
if !TailscaledSocketSetExplicitly {
// On macOS, when dialing from non-sandboxed program to sandboxed GUI running // On macOS, when dialing from non-sandboxed program to sandboxed GUI running
// a TCP server on a random port, find the random port. For HTTP connections, // a TCP server on a random port, find the random port. For HTTP connections,
// we don't send the token. It gets added in an HTTP Basic-Auth header. // we don't send the token. It gets added in an HTTP Basic-Auth header.
@ -56,7 +60,11 @@ func defaultDialer(ctx context.Context, network, addr string) (net.Conn, error)
return d.DialContext(ctx, "tcp", "localhost:"+strconv.Itoa(port)) return d.DialContext(ctx, "tcp", "localhost:"+strconv.Itoa(port))
} }
} }
return safesocket.Connect(TailscaledSocket, safesocket.WindowsLocalPort) s := safesocket.DefaultConnectionStrategy(TailscaledSocket)
// The user provided a non-default tailscaled socket address.
// Connect only to exactly what they provided.
s.UseFallback(false)
return safesocket.Connect(s)
} }
var ( var (

@ -164,6 +164,11 @@ change in the future.
} }
tailscale.TailscaledSocket = rootArgs.socket tailscale.TailscaledSocket = rootArgs.socket
rootfs.Visit(func(f *flag.Flag) {
if f.Name == "socket" {
tailscale.TailscaledSocketSetExplicitly = true
}
})
err := rootCmd.Run(context.Background()) err := rootCmd.Run(context.Background())
if errors.Is(err, flag.ErrHelp) { if errors.Is(err, flag.ErrHelp) {
@ -191,7 +196,8 @@ var rootArgs struct {
var gotSignal syncs.AtomicBool var gotSignal syncs.AtomicBool
func connect(ctx context.Context) (net.Conn, *ipn.BackendClient, context.Context, context.CancelFunc) { func connect(ctx context.Context) (net.Conn, *ipn.BackendClient, context.Context, context.CancelFunc) {
c, err := safesocket.Connect(rootArgs.socket, safesocket.WindowsLocalPort) s := safesocket.DefaultConnectionStrategy(rootArgs.socket)
c, err := safesocket.Connect(s)
if err != nil { if err != nil {
if runtime.GOOS != "windows" && rootArgs.socket == "" { if runtime.GOOS != "windows" && rootArgs.socket == "" {
fatalf("--socket cannot be empty") fatalf("--socket cannot be empty")

@ -33,10 +33,11 @@ func TestRunMultipleAccepts(t *testing.T) {
t.Logf(format, args...) t.Logf(format, args...)
} }
s := safesocket.DefaultConnectionStrategy(socketPath)
connect := func() { connect := func() {
for i := 1; i <= 2; i++ { for i := 1; i <= 2; i++ {
logf("connect %d ...", i) logf("connect %d ...", i)
c, err := safesocket.Connect(socketPath, 0) c, err := safesocket.Connect(s)
if err != nil { if err != nil {
t.Fatalf("safesocket.Connect: %v\n", err) t.Fatalf("safesocket.Connect: %v\n", err)
} }

@ -48,7 +48,9 @@ func TestBasics(t *testing.T) {
}() }()
go func() { go func() {
c, err := Connect(sock, port) s := DefaultConnectionStrategy(sock)
s.UsePort(port)
c, err := Connect(s)
if err != nil { if err != nil {
errs <- err errs <- err
return return

@ -11,8 +11,8 @@ import (
"syscall" "syscall"
) )
func connect(path string, port uint16) (net.Conn, error) { func connect(s *ConnectionStrategy) (net.Conn, error) {
pipe, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) pipe, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", s.port))
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -57,10 +57,65 @@ func tailscaledStillStarting() bool {
return tailscaledProcExists() return tailscaledProcExists()
} }
// Connect connects to either path (on Unix) or the provided localhost port (on Windows). // A ConnectionStrategy is a plan for how to connect to tailscaled or equivalent (e.g. IPNExtension on macOS).
func Connect(path string, port uint16) (net.Conn, error) { type ConnectionStrategy struct {
// For now, a ConnectionStrategy is just a unix socket path, a TCP port,
// and a flag indicating whether to try fallback connections options.
path string
port uint16
fallback bool
// Longer term, a ConnectionStrategy should be an ordered list of things to attempt,
// with just the information required to connection for each.
//
// We have at least these cases to consider (see issue 3530):
//
// tailscale sandbox | tailscaled sandbox | OS | connection
// ------------------|--------------------|---------|-----------
// no | no | unix | unix socket
// no | no | Windows | TCP/port
// no | no | wasm | memconn
// no | Network Extension | macOS | TCP/port/token, port/token from lsof
// no | System Extension | macOS | TCP/port/token, port/token from lsof
// yes | Network Extension | macOS | TCP/port/token, port/token from readdir
// yes | System Extension | macOS | TCP/port/token, port/token from readdir
//
// Note e.g. that port is only relevant as an input to Connect on Windows,
// that path is not relevant to Windows, and that neither matters to wasm.
}
// DefaultConnectionStrategy returns a default connection strategy.
// The default strategy is to attempt to connect in as many ways as possible.
// It uses path as the unix socket path, when applicable,
// and defaults to WindowsLocalPort for the TCP port when applicable.
// It falls back to auto-discovery across sandbox boundaries on macOS.
// TODO: maybe take no arguments, since path is irrelevant on Windows? Discussion in PR 3499.
func DefaultConnectionStrategy(path string) *ConnectionStrategy {
return &ConnectionStrategy{path: path, port: WindowsLocalPort, fallback: true}
}
// UsePort modifies s to use port for the TCP port when applicable.
// UsePort is only applicable on Windows, and only then
// when not using the default for Windows.
func (s *ConnectionStrategy) UsePort(port uint16) {
s.port = port
}
// UseFallback modifies s to set whether it should fall back
// to connecting to the macOS GUI's tailscaled
// if the Unix socket path wasn't reachable.
func (s *ConnectionStrategy) UseFallback(b bool) {
s.fallback = b
}
// ExactPath returns a connection strategy that only attempts to connect via path.
func ExactPath(path string) *ConnectionStrategy {
return &ConnectionStrategy{path: path, fallback: false}
}
// Connect connects to tailscaled using s
func Connect(s *ConnectionStrategy) (net.Conn, error) {
for { for {
c, err := connect(path, port) c, err := connect(s)
if err != nil && tailscaledStillStarting() { if err != nil && tailscaledStillStarting() {
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
continue continue

@ -17,6 +17,6 @@ func listen(path string, port uint16) (_ net.Listener, gotPort uint16, _ error)
return ln, 1, err return ln, 1, err
} }
func connect(path string, port uint16) (net.Conn, error) { func connect(_ *ConnectionStrategy) (net.Conn, error) {
return memconn.Dial("memu", memName) return memconn.Dial("memu", memName)
} }

@ -23,19 +23,19 @@ import (
) )
// TODO(apenwarr): handle magic cookie auth // TODO(apenwarr): handle magic cookie auth
func connect(path string, port uint16) (net.Conn, error) { func connect(s *ConnectionStrategy) (net.Conn, error) {
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
return nil, errors.New("safesocket.Connect not yet implemented on js/wasm") return nil, errors.New("safesocket.Connect not yet implemented on js/wasm")
} }
if runtime.GOOS == "darwin" && path == "" && port == 0 { if runtime.GOOS == "darwin" && s.fallback && s.path == "" && s.port == 0 {
return connectMacOSAppSandbox() return connectMacOSAppSandbox()
} }
pipe, err := net.Dial("unix", path) pipe, err := net.Dial("unix", s.path)
if err != nil { if err != nil {
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" && s.fallback {
extConn, extErr := connectMacOSAppSandbox() extConn, extErr := connectMacOSAppSandbox()
if extErr != nil { if extErr != nil {
return nil, fmt.Errorf("safesocket: failed to connect to %v: %v; failed to connect to Tailscale IPNExtension: %v", path, err, extErr) return nil, fmt.Errorf("safesocket: failed to connect to %v: %v; failed to connect to Tailscale IPNExtension: %v", s.path, err, extErr)
} }
return extConn, nil return extConn, nil
} }

@ -720,8 +720,10 @@ func (n *testNode) MustDown() {
// AwaitListening waits for the tailscaled to be serving local clients // AwaitListening waits for the tailscaled to be serving local clients
// over its localhost IPC mechanism. (Unix socket, etc) // over its localhost IPC mechanism. (Unix socket, etc)
func (n *testNode) AwaitListening(t testing.TB) { func (n *testNode) AwaitListening(t testing.TB) {
s := safesocket.DefaultConnectionStrategy(n.sockFile)
s.UseFallback(false) // connect only to the tailscaled that we started
if err := tstest.WaitFor(20*time.Second, func() (err error) { if err := tstest.WaitFor(20*time.Second, func() (err error) {
c, err := safesocket.Connect(n.sockFile, safesocket.WindowsLocalPort) c, err := safesocket.Connect(s)
if err != nil { if err != nil {
return err return err
} }

Loading…
Cancel
Save