@ -60,7 +60,7 @@ var stdDialer net.Dialer
//
//
// The provided ctx is only used for the initial connection, until
// The provided ctx is only used for the initial connection, until
// Dial returns. It does not affect the connection once established.
// Dial returns. It does not affect the connection once established.
func ( a * Dialer ) Dial ( ctx context . Context ) ( * controlbase. Conn, error ) {
func ( a * Dialer ) Dial ( ctx context . Context ) ( * Client Conn, error ) {
if a . Hostname == "" {
if a . Hostname == "" {
return nil , errors . New ( "required Dialer.Hostname empty" )
return nil , errors . New ( "required Dialer.Hostname empty" )
}
}
@ -91,7 +91,7 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
var _ = envknob . RegisterBool ( "TS_USE_CONTROL_DIAL_PLAN" ) // to record at init time whether it's in use
var _ = envknob . RegisterBool ( "TS_USE_CONTROL_DIAL_PLAN" ) // to record at init time whether it's in use
func ( a * Dialer ) dial ( ctx context . Context ) ( * controlbase. Conn, error ) {
func ( a * Dialer ) dial ( ctx context . Context ) ( * Client Conn, error ) {
// If we don't have a dial plan, just fall back to dialing the single
// If we don't have a dial plan, just fall back to dialing the single
// host we know about.
// host we know about.
useDialPlan := envknob . BoolDefaultTrue ( "TS_USE_CONTROL_DIAL_PLAN" )
useDialPlan := envknob . BoolDefaultTrue ( "TS_USE_CONTROL_DIAL_PLAN" )
@ -117,7 +117,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
// Now, for each candidate, kick off a dial in parallel.
// Now, for each candidate, kick off a dial in parallel.
type dialResult struct {
type dialResult struct {
conn * controlbase. Conn
conn * Client Conn
err error
err error
addr netip . Addr
addr netip . Addr
priority int
priority int
@ -129,7 +129,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
for _ , c := range candidates {
for _ , c := range candidates {
go func ( ctx context . Context , c tailcfg . ControlIPCandidate ) {
go func ( ctx context . Context , c tailcfg . ControlIPCandidate ) {
var (
var (
conn * controlbase. Conn
conn * Client Conn
err error
err error
)
)
@ -228,7 +228,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
} )
} )
var (
var (
conn * controlbase. Conn
conn * Client Conn
errs [ ] error
errs [ ] error
)
)
for i , result := range results {
for i , result := range results {
@ -252,7 +252,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
// dialHost connects to the configured Dialer.Hostname and upgrades the
// dialHost connects to the configured Dialer.Hostname and upgrades the
// connection into a controlbase.Conn. If addr is valid, then no DNS is used
// connection into a controlbase.Conn. If addr is valid, then no DNS is used
// and the connection will be made to the provided address.
// and the connection will be made to the provided address.
func ( a * Dialer ) dialHost ( ctx context . Context , addr netip . Addr ) ( * controlbase. Conn, error ) {
func ( a * Dialer ) dialHost ( ctx context . Context , addr netip . Addr ) ( * Client Conn, error ) {
// Create one shared context used by both port 80 and port 443 dials.
// Create one shared context used by both port 80 and port 443 dials.
// If port 80 is still in flight when 443 returns, this deferred cancel
// If port 80 is still in flight when 443 returns, this deferred cancel
// will stop the port 80 dial.
// will stop the port 80 dial.
@ -275,7 +275,7 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Co
type tryURLRes struct {
type tryURLRes struct {
u * url . URL // input (the URL conn+err are for/from)
u * url . URL // input (the URL conn+err are for/from)
conn * controlbase. Conn // result (mutually exclusive with err)
conn * Client Conn // result (mutually exclusive with err)
err error
err error
}
}
ch := make ( chan tryURLRes ) // must be unbuffered
ch := make ( chan tryURLRes ) // must be unbuffered
@ -331,12 +331,12 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Co
}
}
// dialURL attempts to connect to the given URL.
// dialURL attempts to connect to the given URL.
func ( a * Dialer ) dialURL ( ctx context . Context , u * url . URL , addr netip . Addr ) ( * controlbase. Conn, error ) {
func ( a * Dialer ) dialURL ( ctx context . Context , u * url . URL , addr netip . Addr ) ( * Client Conn, error ) {
init , cont , err := controlbase . ClientDeferred ( a . MachineKey , a . ControlKey , a . ProtocolVersion )
init , cont , err := controlbase . ClientDeferred ( a . MachineKey , a . ControlKey , a . ProtocolVersion )
if err != nil {
if err != nil {
return nil , err
return nil , err
}
}
netConn , err := a . tryURLUpgrade ( ctx , u , addr , init )
netConn , untrustedUpgradeHeaders, err := a . tryURLUpgrade ( ctx , u , addr , init )
if err != nil {
if err != nil {
return nil , err
return nil , err
}
}
@ -345,7 +345,10 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con
netConn . Close ( )
netConn . Close ( )
return nil , err
return nil , err
}
}
return cbConn , nil
return & ClientConn {
Conn : cbConn ,
UntrustedUpgradeHeaders : untrustedUpgradeHeaders ,
} , nil
}
}
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
@ -353,7 +356,7 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con
// provided address.
// provided address.
//
//
// Only the provided ctx is used, not a.ctx.
// Only the provided ctx is used, not a.ctx.
func ( a * Dialer ) tryURLUpgrade ( ctx context . Context , u * url . URL , addr netip . Addr , init [ ] byte ) ( net. Conn , error ) {
func ( a * Dialer ) tryURLUpgrade ( ctx context . Context , u * url . URL , addr netip . Addr , init [ ] byte ) ( _ net. Conn , untrustedUpgradeHeaders http . Header , _ error ) {
var dns * dnscache . Resolver
var dns * dnscache . Resolver
// If we were provided an address to dial, then create a resolver that just
// If we were provided an address to dial, then create a resolver that just
@ -435,11 +438,11 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
resp , err := tr . RoundTrip ( req )
resp , err := tr . RoundTrip ( req )
if err != nil {
if err != nil {
return nil , err
return nil , nil , err
}
}
if resp . StatusCode != http . StatusSwitchingProtocols {
if resp . StatusCode != http . StatusSwitchingProtocols {
return nil , fmt . Errorf ( "unexpected HTTP response: %s" , resp . Status )
return nil , nil , fmt . Errorf ( "unexpected HTTP response: %s" , resp . Status )
}
}
// From here on, the underlying net.Conn is ours to use, but there
// From here on, the underlying net.Conn is ours to use, but there
@ -453,19 +456,19 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
}
}
if switchedConn == nil {
if switchedConn == nil {
resp . Body . Close ( )
resp . Body . Close ( )
return nil , fmt . Errorf ( "httptrace didn't provide a connection" )
return nil , nil , fmt . Errorf ( "httptrace didn't provide a connection" )
}
}
if next := resp . Header . Get ( "Upgrade" ) ; next != upgradeHeaderValue {
if next := resp . Header . Get ( "Upgrade" ) ; next != upgradeHeaderValue {
resp . Body . Close ( )
resp . Body . Close ( )
return nil , fmt . Errorf ( "server switched to unexpected protocol %q" , next )
return nil , nil , fmt . Errorf ( "server switched to unexpected protocol %q" , next )
}
}
rwc , ok := resp . Body . ( io . ReadWriteCloser )
rwc , ok := resp . Body . ( io . ReadWriteCloser )
if ! ok {
if ! ok {
resp . Body . Close ( )
resp . Body . Close ( )
return nil , errors . New ( "http Transport did not provide a writable body" )
return nil , nil , errors . New ( "http Transport did not provide a writable body" )
}
}
return netutil . NewAltReadWriteCloserConn ( rwc , switchedConn ) , nil
return netutil . NewAltReadWriteCloserConn ( rwc , switchedConn ) , resp . Header , nil
}
}