safesocket: add ConnectContext

This adds a variant for Connect that takes in a context.Context
which allows passing through cancellation etc by the caller.

Updates tailscale/corp#18266

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/12367/head
Maisem Ali 5 months ago committed by Maisem Ali
parent 3672f66c74
commit 4b6a0c42c8

@ -167,7 +167,7 @@ func (s *FileSystemForRemote) buildChild(share *drive.Share) *compositedav.Child
return fmt.Sprintf("http://%s/%s/%s", hex.EncodeToString([]byte(share.Name)), secretToken, url.PathEscape(share.Name)), nil return fmt.Sprintf("http://%s/%s/%s", hex.EncodeToString([]byte(share.Name)), secretToken, url.PathEscape(share.Name)), nil
}, },
Transport: &http.Transport{ Transport: &http.Transport{
Dial: func(_, shareAddr string) (net.Conn, error) { DialContext: func(ctx context.Context, _, shareAddr string) (net.Conn, error) {
shareNameHex, _, err := net.SplitHostPort(shareAddr) shareNameHex, _, err := net.SplitHostPort(shareAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse share address %v: %w", shareAddr, err) return nil, fmt.Errorf("unable to parse share address %v: %w", shareAddr, err)
@ -188,10 +188,11 @@ func (s *FileSystemForRemote) buildChild(share *drive.Share) *compositedav.Child
_, err = netip.ParseAddrPort(addr) _, err = netip.ParseAddrPort(addr)
if err == nil { if err == nil {
// this is a regular network address, dial normally // this is a regular network address, dial normally
return net.Dial("tcp", addr) var std net.Dialer
return std.DialContext(ctx, "tcp", addr)
} }
// assume this is a safesocket address // assume this is a safesocket address
return safesocket.Connect(addr) return safesocket.ConnectContext(ctx, addr)
}, },
}, },
} }

@ -709,7 +709,7 @@ func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor,
} }
if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") { if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") {
if c, err := safesocket.Connect(""); err == nil { if c, err := safesocket.ConnectContext(ctx, ""); err == nil {
fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr) fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr)
br := bufio.NewReader(c) br := bufio.NewReader(c)
res, err := http.ReadResponse(br, nil) res, err := http.ReadResponse(br, nil)

@ -4,6 +4,7 @@
package safesocket package safesocket
import ( import (
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -57,7 +58,7 @@ func TestBasics(t *testing.T) {
}() }()
go func() { go func() {
c, err := Connect(sock) c, err := ConnectContext(context.Background(), sock)
if err != nil { if err != nil {
errs <- err errs <- err
return return

@ -16,9 +16,8 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
func connect(path string) (net.Conn, error) { func connect(ctx context.Context, path string) (net.Conn, error) {
dl := time.Now().Add(20 * time.Second) ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
ctx, cancel := context.WithDeadline(context.Background(), dl)
defer cancel() defer cancel()
// We use the identification impersonation level so that tailscaled may // We use the identification impersonation level so that tailscaled may
// obtain information about our token for access control purposes. // obtain information about our token for access control purposes.

@ -6,6 +6,7 @@
package safesocket package safesocket
import ( import (
"context"
"errors" "errors"
"net" "net"
"runtime" "runtime"
@ -52,11 +53,14 @@ func tailscaledStillStarting() bool {
return tailscaledProcExists() return tailscaledProcExists()
} }
// Connect connects to tailscaled using a unix socket or named pipe. // ConnectContext connects to tailscaled using a unix socket or named pipe.
func Connect(path string) (net.Conn, error) { func ConnectContext(ctx context.Context, path string) (net.Conn, error) {
for { for {
c, err := connect(path) c, err := connect(ctx, path)
if err != nil && tailscaledStillStarting() { if err != nil && tailscaledStillStarting() {
if ctx.Err() != nil {
return nil, ctx.Err()
}
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
continue continue
} }
@ -64,6 +68,12 @@ func Connect(path string) (net.Conn, error) {
} }
} }
// Connect connects to tailscaled using a unix socket or named pipe.
// Deprecated: use ConnectContext instead.
func Connect(path string) (net.Conn, error) {
return ConnectContext(context.Background(), path)
}
// Listen returns a listener either on Unix socket path (on Unix), or // Listen returns a listener either on Unix socket path (on Unix), or
// the NamedPipe path (on Windows). // the NamedPipe path (on Windows).
func Listen(path string) (net.Listener, error) { func Listen(path string) (net.Listener, error) {

@ -4,6 +4,7 @@
package safesocket package safesocket
import ( import (
"context"
"net" "net"
"github.com/akutz/memconn" "github.com/akutz/memconn"
@ -15,6 +16,6 @@ func listen(path string) (net.Listener, error) {
return memconn.Listen("memu", memName) return memconn.Listen("memu", memName)
} }
func connect(_ string) (net.Conn, error) { func connect(ctx context.Context, _ string) (net.Conn, error) {
return memconn.Dial("memu", memName) return memconn.DialContext(ctx, "memu", memName)
} }

@ -6,6 +6,7 @@
package safesocket package safesocket
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -85,7 +86,7 @@ func (fc plan9FileConn) SetWriteDeadline(t time.Time) error {
return syscall.EPLAN9 return syscall.EPLAN9
} }
func connect(path string) (net.Conn, error) { func connect(_ context.Context, path string) (net.Conn, error) {
f, err := os.OpenFile(path, os.O_RDWR, 0666) f, err := os.OpenFile(path, os.O_RDWR, 0666)
if err != nil { if err != nil {
return nil, err return nil, err

@ -6,6 +6,7 @@
package safesocket package safesocket
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -16,11 +17,12 @@ import (
"runtime" "runtime"
) )
func connect(path string) (net.Conn, error) { func connect(ctx context.Context, path string) (net.Conn, error) {
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
return nil, errors.New("safesocket.Connect not yet implemented on js/wasm") return nil, errors.New("safesocket.Connect not yet implemented on js/wasm")
} }
return net.Dial("unix", path) var std net.Dialer
return std.DialContext(ctx, "unix", path)
} }
func listen(path string) (net.Listener, error) { func listen(path string) (net.Listener, error) {

@ -1497,7 +1497,7 @@ func (n *testNode) Ping(otherNode *testNode) error {
func (n *testNode) AwaitListening() { func (n *testNode) AwaitListening() {
t := n.env.t t := n.env.t
if err := tstest.WaitFor(20*time.Second, func() (err error) { if err := tstest.WaitFor(20*time.Second, func() (err error) {
c, err := safesocket.Connect(n.sockFile) c, err := safesocket.ConnectContext(context.Background(), n.sockFile)
if err == nil { if err == nil {
c.Close() c.Close()
} }

Loading…
Cancel
Save