From 4b6a0c42c89a6a004686a13e16d6a0821680d74d Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Mon, 10 Jun 2024 19:38:10 -0700 Subject: [PATCH] safesocket: add ConnectContext This adds a variant for Connect that takes in a context.Context which allows passing through cancellation etc by the caller. Updates tailscale/corp#18266 Signed-off-by: Maisem Ali --- drive/driveimpl/remote_impl.go | 7 ++++--- logpolicy/logpolicy.go | 2 +- safesocket/basic_test.go | 3 ++- safesocket/pipe_windows.go | 5 ++--- safesocket/safesocket.go | 16 +++++++++++++--- safesocket/safesocket_js.go | 5 +++-- safesocket/safesocket_plan9.go | 3 ++- safesocket/unixsocket.go | 6 ++++-- tstest/integration/integration_test.go | 2 +- 9 files changed, 32 insertions(+), 17 deletions(-) diff --git a/drive/driveimpl/remote_impl.go b/drive/driveimpl/remote_impl.go index c579c0f39..debdf8a36 100644 --- a/drive/driveimpl/remote_impl.go +++ b/drive/driveimpl/remote_impl.go @@ -167,7 +167,7 @@ func (s *FileSystemForRemote) buildChild(share *drive.Share) *compositedav.Child return fmt.Sprintf("http://%s/%s/%s", hex.EncodeToString([]byte(share.Name)), secretToken, url.PathEscape(share.Name)), nil }, Transport: &http.Transport{ - Dial: func(_, shareAddr string) (net.Conn, error) { + DialContext: func(ctx context.Context, _, shareAddr string) (net.Conn, error) { shareNameHex, _, err := net.SplitHostPort(shareAddr) if err != nil { return nil, fmt.Errorf("unable to parse share address %v: %w", shareAddr, err) @@ -188,10 +188,11 @@ func (s *FileSystemForRemote) buildChild(share *drive.Share) *compositedav.Child _, err = netip.ParseAddrPort(addr) if err == nil { // this is a regular network address, dial normally - return net.Dial("tcp", addr) + var std net.Dialer + return std.DialContext(ctx, "tcp", addr) } // assume this is a safesocket address - return safesocket.Connect(addr) + return safesocket.ConnectContext(ctx, addr) }, }, } diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index 2595a7da1..71093dd3c 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -709,7 +709,7 @@ func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor, } if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") { - if c, err := safesocket.Connect(""); err == nil { + if c, err := safesocket.ConnectContext(ctx, ""); err == nil { fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr) br := bufio.NewReader(c) res, err := http.ReadResponse(br, nil) diff --git a/safesocket/basic_test.go b/safesocket/basic_test.go index 71422b8bc..292a3438a 100644 --- a/safesocket/basic_test.go +++ b/safesocket/basic_test.go @@ -4,6 +4,7 @@ package safesocket import ( + "context" "fmt" "path/filepath" "runtime" @@ -57,7 +58,7 @@ func TestBasics(t *testing.T) { }() go func() { - c, err := Connect(sock) + c, err := ConnectContext(context.Background(), sock) if err != nil { errs <- err return diff --git a/safesocket/pipe_windows.go b/safesocket/pipe_windows.go index eb78b2bc1..582834165 100644 --- a/safesocket/pipe_windows.go +++ b/safesocket/pipe_windows.go @@ -16,9 +16,8 @@ import ( "golang.org/x/sys/windows" ) -func connect(path string) (net.Conn, error) { - dl := time.Now().Add(20 * time.Second) - ctx, cancel := context.WithDeadline(context.Background(), dl) +func connect(ctx context.Context, path string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, 20*time.Second) defer cancel() // We use the identification impersonation level so that tailscaled may // obtain information about our token for access control purposes. diff --git a/safesocket/safesocket.go b/safesocket/safesocket.go index 379b934ca..991fddf5f 100644 --- a/safesocket/safesocket.go +++ b/safesocket/safesocket.go @@ -6,6 +6,7 @@ package safesocket import ( + "context" "errors" "net" "runtime" @@ -52,11 +53,14 @@ func tailscaledStillStarting() bool { return tailscaledProcExists() } -// Connect connects to tailscaled using a unix socket or named pipe. -func Connect(path string) (net.Conn, error) { +// ConnectContext connects to tailscaled using a unix socket or named pipe. +func ConnectContext(ctx context.Context, path string) (net.Conn, error) { for { - c, err := connect(path) + c, err := connect(ctx, path) if err != nil && tailscaledStillStarting() { + if ctx.Err() != nil { + return nil, ctx.Err() + } time.Sleep(250 * time.Millisecond) continue } @@ -64,6 +68,12 @@ func Connect(path string) (net.Conn, error) { } } +// Connect connects to tailscaled using a unix socket or named pipe. +// Deprecated: use ConnectContext instead. +func Connect(path string) (net.Conn, error) { + return ConnectContext(context.Background(), path) +} + // Listen returns a listener either on Unix socket path (on Unix), or // the NamedPipe path (on Windows). func Listen(path string) (net.Listener, error) { diff --git a/safesocket/safesocket_js.go b/safesocket/safesocket_js.go index 396938355..38e615da4 100644 --- a/safesocket/safesocket_js.go +++ b/safesocket/safesocket_js.go @@ -4,6 +4,7 @@ package safesocket import ( + "context" "net" "github.com/akutz/memconn" @@ -15,6 +16,6 @@ func listen(path string) (net.Listener, error) { return memconn.Listen("memu", memName) } -func connect(_ string) (net.Conn, error) { - return memconn.Dial("memu", memName) +func connect(ctx context.Context, _ string) (net.Conn, error) { + return memconn.DialContext(ctx, "memu", memName) } diff --git a/safesocket/safesocket_plan9.go b/safesocket/safesocket_plan9.go index 32095bbd2..196c1df9c 100644 --- a/safesocket/safesocket_plan9.go +++ b/safesocket/safesocket_plan9.go @@ -6,6 +6,7 @@ package safesocket import ( + "context" "fmt" "net" "os" @@ -85,7 +86,7 @@ func (fc plan9FileConn) SetWriteDeadline(t time.Time) error { return syscall.EPLAN9 } -func connect(path string) (net.Conn, error) { +func connect(_ context.Context, path string) (net.Conn, error) { f, err := os.OpenFile(path, os.O_RDWR, 0666) if err != nil { return nil, err diff --git a/safesocket/unixsocket.go b/safesocket/unixsocket.go index 8a7162408..ef22263aa 100644 --- a/safesocket/unixsocket.go +++ b/safesocket/unixsocket.go @@ -6,6 +6,7 @@ package safesocket import ( + "context" "errors" "fmt" "log" @@ -16,11 +17,12 @@ import ( "runtime" ) -func connect(path string) (net.Conn, error) { +func connect(ctx context.Context, path string) (net.Conn, error) { if runtime.GOOS == "js" { return nil, errors.New("safesocket.Connect not yet implemented on js/wasm") } - return net.Dial("unix", path) + var std net.Dialer + return std.DialContext(ctx, "unix", path) } func listen(path string) (net.Listener, error) { diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 1ec5e6390..236c02e32 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -1497,7 +1497,7 @@ func (n *testNode) Ping(otherNode *testNode) error { func (n *testNode) AwaitListening() { t := n.env.t if err := tstest.WaitFor(20*time.Second, func() (err error) { - c, err := safesocket.Connect(n.sockFile) + c, err := safesocket.ConnectContext(context.Background(), n.sockFile) if err == nil { c.Close() }