diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 9ac6a2022..2254b3229 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -63,6 +63,15 @@ type Server struct { mu sync.Mutex listeners map[listenKey]*listener + dialer *tsdial.Dialer +} + +// Dial connects to the address on the tailnet. +func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if err := s.init(); err != nil { + return nil, err + } + return s.dialer.UserDial(ctx, network, address) } func (s *Server) doInit() { @@ -71,6 +80,11 @@ func (s *Server) doInit() { } } +func (s *Server) init() error { + s.initOnce.Do(s.doInit) + return s.initErr +} + func (s *Server) start() error { if v, _ := strconv.ParseBool(os.Getenv("TAILSCALE_USE_WIP_CODE")); !v { return errors.New("code disabled without environment variable TAILSCALE_USE_WIP_CODE set true") @@ -117,11 +131,11 @@ func (s *Server) start() error { return err } - dialer := new(tsdial.Dialer) // mutated below (before used) + s.dialer = new(tsdial.Dialer) // mutated below (before used) eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ ListenPort: 0, LinkMonitor: linkMon, - Dialer: dialer, + Dialer: s.dialer, }) if err != nil { return err @@ -132,7 +146,7 @@ func (s *Server) start() error { return fmt.Errorf("%T is not a wgengine.InternalsGetter", eng) } - ns, err := netstack.Create(logf, tunDev, eng, magicConn, dialer) + ns, err := netstack.Create(logf, tunDev, eng, magicConn, s.dialer) if err != nil { return fmt.Errorf("netstack.Create: %w", err) } @@ -141,11 +155,11 @@ func (s *Server) start() error { if err := ns.Start(); err != nil { return fmt.Errorf("failed to start netstack: %w", err) } - dialer.UseNetstackForIP = func(ip netaddr.IP) bool { + s.dialer.UseNetstackForIP = func(ip netaddr.IP) bool { _, ok := eng.PeerForIP(ip) return ok } - dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) { + s.dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) { return ns.DialContextTCP(ctx, dst) } @@ -156,7 +170,7 @@ func (s *Server) start() error { } logid := "tslib-TODO" - lb, err := ipnlocal.NewLocalBackend(logf, logid, store, dialer, eng) + lb, err := ipnlocal.NewLocalBackend(logf, logid, store, s.dialer, eng) if err != nil { return fmt.Errorf("NewLocalBackend: %v", err) } @@ -217,15 +231,15 @@ func (s *Server) forwardTCP(c net.Conn, port uint16) { } } +// Listen announces only on the Tailscale network. func (s *Server) Listen(network, addr string) (net.Listener, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("tsnet: %w", err) } - s.initOnce.Do(s.doInit) - if s.initErr != nil { - return nil, s.initErr + if err := s.init(); err != nil { + return nil, err } key := listenKey{network, host, port}