tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field (#5648)

* tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field

This field allows the control server to provide explicit information
about how to connect to it; useful if the client's link status can
change after the initial connection, or if the DNS settings pushed by
the control server break future connections.

Change-Id: I720afe6289ec27d40a41b3dcb310ec45bd7e5f3e
Signed-off-by: Andrew Dunham <andrew@tailscale.com>
pull/5728/head
Andrew Dunham 2 years ago committed by GitHub
parent acc7baac6d
commit e1bdbfe710
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -100,6 +100,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/util/groupmember from tailscale.com/cmd/tailscale/cli tailscale.com/util/groupmember from tailscale.com/cmd/tailscale/cli
tailscale.com/util/lineread from tailscale.com/net/interfaces+ tailscale.com/util/lineread from tailscale.com/net/interfaces+
tailscale.com/util/mak from tailscale.com/net/netcheck tailscale.com/util/mak from tailscale.com/net/netcheck
tailscale.com/util/multierr from tailscale.com/control/controlhttp
tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/singleflight from tailscale.com/net/dnscache
L tailscale.com/util/strs from tailscale.com/hostinfo L tailscale.com/util/strs from tailscale.com/hostinfo
W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+

@ -76,6 +76,8 @@ type Direct struct {
popBrowser func(url string) // or nil popBrowser func(url string) // or nil
c2nHandler http.Handler // or nil c2nHandler http.Handler // or nil
dialPlan ControlDialPlanner // can be nil
mu sync.Mutex // mutex guards the following fields mu sync.Mutex // mutex guards the following fields
serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key
serverNoiseKey key.MachinePublic serverNoiseKey key.MachinePublic
@ -133,6 +135,34 @@ type Options struct {
// MapResponse.PingRequest queries from the control plane. // MapResponse.PingRequest queries from the control plane.
// If nil, PingRequest queries are not answered. // If nil, PingRequest queries are not answered.
Pinger Pinger Pinger Pinger
// DialPlan contains and stores a previous dial plan that we received
// from the control server; if nil, we fall back to using DNS.
//
// If we receive a new DialPlan from the server, this value will be
// updated.
DialPlan ControlDialPlanner
}
// ControlDialPlanner is the interface optionally supplied when creating a
// control client to control exactly how TCP connections to the control plane
// are dialed.
//
// It is usually implemented by an atomic.Pointer.
type ControlDialPlanner interface {
// Load returns the current plan for how to connect to control.
//
// The returned plan can be nil. If so, connections should be made by
// resolving the control URL using DNS.
Load() *tailcfg.ControlDialPlan
// Store updates the dial plan with new directions from the control
// server.
//
// The dial plan can span multiple connections to the control server.
// That is, a dial plan received when connected over Wi-Fi is still
// valid for a subsequent connection over LTE after a network switch.
Store(*tailcfg.ControlDialPlan)
} }
// Pinger is the LocalBackend.Ping method. // Pinger is the LocalBackend.Ping method.
@ -216,6 +246,7 @@ func NewDirect(opts Options) (*Direct, error) {
popBrowser: opts.PopBrowserURL, popBrowser: opts.PopBrowserURL,
c2nHandler: opts.C2NHandler, c2nHandler: opts.C2NHandler,
dialer: opts.Dialer, dialer: opts.Dialer,
dialPlan: opts.DialPlan,
} }
if opts.Hostinfo == nil { if opts.Hostinfo == nil {
c.SetHostinfo(hostinfo.New()) c.SetHostinfo(hostinfo.New())
@ -915,6 +946,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
} else { } else {
vlogf("netmap: got new map") vlogf("netmap: got new map")
} }
if resp.ControlDialPlan != nil {
if c.dialPlan != nil {
c.logf("netmap: got new dial plan from control")
c.dialPlan.Store(resp.ControlDialPlan)
} else {
c.logf("netmap: [unexpected] new dial plan; nowhere to store it")
}
}
select { select {
case timeoutReset <- struct{}{}: case timeoutReset <- struct{}{}:
@ -1365,12 +1404,17 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) {
if nc != nil { if nc != nil {
return nc, nil return nc, nil
} }
var dp func() *tailcfg.ControlDialPlan
if c.dialPlan != nil {
dp = c.dialPlan.Load
}
nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*noiseClient, error) { nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*noiseClient, error) {
k, err := c.getMachinePrivKey() k, err := c.getMachinePrivKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer) c.logf("creating new noise client")
nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -53,6 +53,11 @@ type noiseClient struct {
httpPort string // the default port to call httpPort string // the default port to call
httpsPort string // the fallback Noise-over-https port httpsPort string // the fallback Noise-over-https port
// dialPlan optionally returns a ControlDialPlan previously received
// from the control server; either the function or the return value can
// be nil.
dialPlan func() *tailcfg.ControlDialPlan
// mu only protects the following variables. // mu only protects the following variables.
mu sync.Mutex mu sync.Mutex
nextID int nextID int
@ -61,7 +66,9 @@ type noiseClient struct {
// newNoiseClient returns a new noiseClient for the provided server and machine key. // newNoiseClient returns a new noiseClient for the provided server and machine key.
// serverURL is of the form https://<host>:<port> (no trailing slash). // serverURL is of the form https://<host>:<port> (no trailing slash).
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer) (*noiseClient, error) { //
// dialPlan may be nil
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) {
u, err := url.Parse(serverURL) u, err := url.Parse(serverURL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,6 +96,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s
httpPort: httpPort, httpPort: httpPort,
httpsPort: httpsPort, httpsPort: httpsPort,
dialer: dialer, dialer: dialer,
dialPlan: dialPlan,
} }
// Create the HTTP/2 Transport using a net/http.Transport // Create the HTTP/2 Transport using a net/http.Transport
@ -155,16 +163,51 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
nc.nextID++ nc.nextID++
nc.mu.Unlock() nc.mu.Unlock()
// Timeout is a little arbitrary, but plenty long enough for even the
// highest latency links.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if tailcfg.CurrentCapabilityVersion > math.MaxUint16 { if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
// Panic, because a test should have started failing several // Panic, because a test should have started failing several
// thousand version numbers before getting to this point. // thousand version numbers before getting to this point.
panic("capability version is too high to fit in the wire protocol") panic("capability version is too high to fit in the wire protocol")
} }
var dialPlan *tailcfg.ControlDialPlan
if nc.dialPlan != nil {
dialPlan = nc.dialPlan()
}
// If we have a dial plan, then set our timeout as slightly longer than
// the maximum amount of time contained therein; we assume that
// explicit instructions on timeouts are more useful than a single
// hard-coded timeout.
//
// The default value of 5 is chosen so that, when there's no dial plan,
// we retain the previous behaviour of 10 seconds end-to-end timeout.
timeoutSec := 5.0
if dialPlan != nil {
for _, c := range dialPlan.Candidates {
if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec {
timeoutSec = v
}
}
}
// After we establish a connection, we need some time to actually
// upgrade it into a Noise connection. With a ballpark worst-case RTT
// of 1000ms, give ourselves an extra 5 seconds to complete the
// handshake.
timeoutSec += 5
// Be extremely defensive and ensure that the timeout is in the range
// [5, 60] seconds (e.g. if we accidentally get a negative number).
if timeoutSec > 60 {
timeoutSec = 60
} else if timeoutSec < 5 {
timeoutSec = 5
}
timeout := time.Duration(timeoutSec * float64(time.Second))
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := (&controlhttp.Dialer{ conn, err := (&controlhttp.Dialer{
Hostname: nc.host, Hostname: nc.host,
HTTPPort: nc.httpPort, HTTPPort: nc.httpPort,
@ -173,6 +216,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
ControlKey: nc.serverPubKey, ControlKey: nc.serverPubKey,
ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion), ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion),
Dialer: nc.dialer.SystemDial, Dialer: nc.dialer.SystemDial,
DialPlan: dialPlan,
}).Dial(ctx) }).Dial(ctx)
if err != nil { if err != nil {
return nil, err return nil, err

@ -28,18 +28,25 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/netip"
"net/url" "net/url"
"sort"
"sync/atomic"
"time" "time"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/envknob"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/dnsfallback" "tailscale.com/net/dnsfallback"
"tailscale.com/net/netutil" "tailscale.com/net/netutil"
"tailscale.com/net/tlsdial" "tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy" "tailscale.com/net/tshttpproxy"
"tailscale.com/tailcfg"
"tailscale.com/util/multierr"
) )
var stdDialer net.Dialer var stdDialer net.Dialer
@ -82,7 +89,170 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
return 500 * time.Millisecond return 500 * time.Millisecond
} }
var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) { func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
// If we don't have a dial plan, just fall back to dialing the single
// host we know about.
useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 {
return a.dialHost(ctx, netip.Addr{})
}
candidates := a.DialPlan.Candidates
// Otherwise, we try dialing per the plan. Store the highest priority
// in the list, so that if we get a connection to one of those
// candidates we can return quickly.
var highestPriority int = math.MinInt
for _, c := range candidates {
if c.Priority > highestPriority {
highestPriority = c.Priority
}
}
// This context allows us to cancel in-flight connections if we get a
// highest-priority connection before we're all done.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Now, for each candidate, kick off a dial in parallel.
type dialResult struct {
conn *controlbase.Conn
err error
addr netip.Addr
priority int
}
resultsCh := make(chan dialResult, len(candidates))
var pending atomic.Int32
pending.Store(int32(len(candidates)))
for _, c := range candidates {
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
var (
conn *controlbase.Conn
err error
)
// Always send results back to our channel.
defer func() {
resultsCh <- dialResult{conn, err, c.IP, c.Priority}
if pending.Add(-1) == 0 {
close(resultsCh)
}
}()
// If non-zero, wait the configured start timeout
// before we do anything.
if c.DialStartDelaySec > 0 {
a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
defer tmr.Stop()
select {
case <-ctx.Done():
err = ctx.Err()
return
case <-tmr.C:
}
}
// Now, create a sub-context with the given timeout and
// try dialing the provided host.
ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
defer cancel()
// This will dial, and the defer above sends it back to our parent.
a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
conn, err = a.dialHost(ctx, c.IP)
}(ctx, c)
}
var results []dialResult
for res := range resultsCh {
// If we get a response that has the highest priority, we don't
// need to wait for any of the other connections to finish; we
// can just return this connection.
//
// TODO(andrew): we could make this better by keeping track of
// the highest remaining priority dynamically, instead of just
// checking for the highest total
if res.priority == highestPriority && res.conn != nil {
a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr)
// Drain the channel and any existing connections in
// the background.
go func() {
for _, res := range results {
if res.conn != nil {
res.conn.Close()
}
}
for res := range resultsCh {
if res.conn != nil {
res.conn.Close()
}
}
if a.drainFinished != nil {
close(a.drainFinished)
}
}()
return res.conn, nil
}
// This isn't a highest-priority result, so just store it until
// we're done.
results = append(results, res)
}
// After we finish this function, close any remaining open connections.
defer func() {
for _, result := range results {
// Note: below, we nil out the returned connection (if
// any) in the slice so we don't close it.
if result.conn != nil {
result.conn.Close()
}
}
// We don't drain asynchronously after this point, so notify our
// channel when we return.
if a.drainFinished != nil {
close(a.drainFinished)
}
}()
// Sort by priority, then take the first non-error response.
sort.Slice(results, func(i, j int) bool {
// NOTE: intentionally inverted so that the highest priority
// item comes first
return results[i].priority > results[j].priority
})
var (
conn *controlbase.Conn
errs []error
)
for i, result := range results {
if result.err != nil {
errs = append(errs, result.err)
continue
}
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr)
conn = result.conn
results[i].conn = nil // so we don't close it in the defer
return conn, nil
}
merr := multierr.New(errs...)
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
return a.dialHost(ctx, netip.Addr{})
}
// dialHost connects to the configured Dialer.Hostname and upgrades the
// connection into a controlbase.Conn. If addr is valid, then no DNS is used
// and the connection will be made to the provided address.
func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Conn, error) {
// Create one shared context used by both port 80 and port 443 dials. // Create one shared context used by both port 80 and port 443 dials.
// If port 80 is still in flight when 443 returns, this deferred cancel // If port 80 is still in flight when 443 returns, this deferred cancel
// will stop the port 80 dial. // will stop the port 80 dial.
@ -110,7 +280,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
} }
ch := make(chan tryURLRes) // must be unbuffered ch := make(chan tryURLRes) // must be unbuffered
try := func(u *url.URL) { try := func(u *url.URL) {
cbConn, err := a.dialURL(ctx, u) cbConn, err := a.dialURL(ctx, u, addr)
select { select {
case ch <- tryURLRes{u, cbConn, err}: case ch <- tryURLRes{u, cbConn, err}:
case <-ctx.Done(): case <-ctx.Done():
@ -161,12 +331,12 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
} }
// dialURL attempts to connect to the given URL. // dialURL attempts to connect to the given URL.
func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) { func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*controlbase.Conn, error) {
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion) init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
netConn, err := a.tryURLUpgrade(ctx, u, init) netConn, err := a.tryURLUpgrade(ctx, u, addr, init)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -178,15 +348,28 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, er
return cbConn, nil return cbConn, nil
} }
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
// is valid, then no DNS is used and the connection will be made to the
// provided address.
// //
// Only the provided ctx is used, not a.ctx. // Only the provided ctx is used, not a.ctx.
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) { func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) {
dns := &dnscache.Resolver{ var dns *dnscache.Resolver
// If we were provided an address to dial, then create a resolver that just
// returns that value; otherwise, fall back to DNS.
if addr.IsValid() {
dns = &dnscache.Resolver{
SingleHostStaticResult: []netip.Addr{addr},
SingleHost: u.Hostname(),
}
} else {
dns = &dnscache.Resolver{
Forward: dnscache.Get().Forward, Forward: dnscache.Get().Forward,
LookupIPFallback: dnsfallback.Lookup, LookupIPFallback: dnsfallback.Lookup,
UseLastGood: true, UseLastGood: true,
} }
}
var dialer dnscache.DialContextFunc var dialer dnscache.DialContextFunc
if a.Dialer != nil { if a.Dialer != nil {

@ -10,6 +10,7 @@ import (
"time" "time"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -70,9 +71,15 @@ type Dialer struct {
// dropped. // dropped.
Logf logger.Logf Logf logger.Logf
// DialPlan, if set, contains instructions from the control server on
// how to connect to it. If present, we will try the methods in this
// plan before falling back to DNS.
DialPlan *tailcfg.ControlDialPlan
proxyFunc func(*http.Request) (*url.URL, error) // or nil proxyFunc func(*http.Request) (*url.URL, error) // or nil
// For tests only // For tests only
drainFinished chan struct{}
insecureTLS bool insecureTLS bool
testFallbackDelay time.Duration testFallbackDelay time.Duration
} }

@ -13,16 +13,21 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/netip"
"net/url" "net/url"
"runtime"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/net/dnscache"
"tailscale.com/net/socks5" "tailscale.com/net/socks5"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger"
) )
type httpTestParam struct { type httpTestParam struct {
@ -444,3 +449,263 @@ func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush() w.(http.Flusher).Flush()
<-r.Context().Done() <-r.Context().Done()
} }
func TestDialPlan(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("only works on Linux due to multiple localhost addresses")
}
client, server := key.NewMachine(), key.NewMachine()
const (
testProtocolVersion = 1
// We need consistent ports for each address; these are chosen
// randomly and we hope that they won't conflict during this test.
httpPort = "40080"
httpsPort = "40443"
)
makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := AcceptHTTP(context.Background(), w, r, server)
if err != nil {
log.Print(err)
} else {
defer conn.Close()
}
w.Header().Set("X-Handler-Name", name)
<-done
})
if wrap != nil {
handler = wrap(handler)
}
httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
if err != nil {
t.Fatalf("HTTP listen: %v", err)
}
httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
if err != nil {
t.Fatalf("HTTPS listen: %v", err)
}
httpServer := &http.Server{Handler: handler}
go httpServer.Serve(httpLn)
t.Cleanup(func() {
httpServer.Close()
})
httpsServer := &http.Server{
Handler: handler,
TLSConfig: tlsConfig(t),
ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
}
go httpsServer.ServeTLS(httpsLn, "", "")
t.Cleanup(func() {
httpsServer.Close()
})
return
}
fallbackAddr := netip.MustParseAddr("127.0.0.1")
goodAddr := netip.MustParseAddr("127.0.0.2")
otherAddr := netip.MustParseAddr("127.0.0.3")
other2Addr := netip.MustParseAddr("127.0.0.4")
brokenAddr := netip.MustParseAddr("127.0.0.10")
testCases := []struct {
name string
plan *tailcfg.ControlDialPlan
wrap func(http.Handler) http.Handler
want netip.Addr
allowFallback bool
}{
{
name: "single",
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
}},
want: goodAddr,
},
{
name: "broken-then-good",
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
// Dials the broken one, which fails, and then
// eventually dials the good one and succeeds
{IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
}},
want: goodAddr,
},
{
name: "multiple-priority-fast-path",
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
// Dials some good IPs and our bad one (which
// hangs forever), which then hits the fast
// path where we bail without waiting.
{IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
{IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
}},
want: otherAddr,
},
{
name: "multiple-priority-slow-path",
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
// Our broken address is the highest priority,
// so we don't hit our fast path.
{IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
}},
want: otherAddr,
},
{
name: "fallback",
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
{IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
}},
want: fallbackAddr,
allowFallback: true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
makeHandler(t, "fallback", fallbackAddr, nil)
makeHandler(t, "good", goodAddr, nil)
makeHandler(t, "other", otherAddr, nil)
makeHandler(t, "other2", other2Addr, nil)
makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
return http.HandlerFunc(brokenMITMHandler)
})
dialer := closeTrackDialer{
t: t,
inner: new(tsdial.Dialer).SystemDial,
conns: make(map[*closeTrackConn]bool),
}
defer dialer.Done()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// By default, we intentionally point to something that
// we know won't connect, since we want a fallback to
// DNS to be an error.
host := "example.com"
if tt.allowFallback {
host = "localhost"
}
drained := make(chan struct{})
a := &Dialer{
Hostname: host,
HTTPPort: httpPort,
HTTPSPort: httpsPort,
MachineKey: client,
ControlKey: server.Public(),
ProtocolVersion: testProtocolVersion,
Dialer: dialer.Dial,
Logf: t.Logf,
DialPlan: tt.plan,
proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
drainFinished: drained,
insecureTLS: true,
testFallbackDelay: 50 * time.Millisecond,
}
conn, err := a.dial(ctx)
if err != nil {
t.Fatalf("dialing controlhttp: %v", err)
}
defer conn.Close()
raddr := conn.RemoteAddr().(*net.TCPAddr)
got, ok := netip.AddrFromSlice(raddr.IP)
if !ok {
t.Errorf("invalid remote IP: %v", raddr.IP)
} else if got != tt.want {
t.Errorf("got connection from %q; want %q", got, tt.want)
} else {
t.Logf("successfully connected to %q", raddr.String())
}
// Wait until our dialer drains so we can verify that
// all connections are closed.
<-drained
})
}
}
type closeTrackDialer struct {
t testing.TB
inner dnscache.DialContextFunc
mu sync.Mutex
conns map[*closeTrackConn]bool
}
func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := d.inner(ctx, network, addr)
if err != nil {
return nil, err
}
ct := &closeTrackConn{Conn: c, d: d}
d.mu.Lock()
d.conns[ct] = true
d.mu.Unlock()
return ct, nil
}
func (d *closeTrackDialer) Done() {
// Unfortunately, tsdial.Dialer.SystemDial closes connections
// asynchronously in a goroutine, so we can't assume that everything is
// closed by the time we get here.
//
// Sleep/wait a few times on the assumption that things will close
// "eventually".
const iters = 100
for i := 0; i < iters; i++ {
d.mu.Lock()
if len(d.conns) == 0 {
d.mu.Unlock()
return
}
// Only error on last iteration
if i != iters-1 {
d.mu.Unlock()
time.Sleep(100 * time.Millisecond)
continue
}
for conn := range d.conns {
d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
}
d.mu.Unlock()
}
}
func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
d.mu.Lock()
delete(d.conns, c) // safe if already deleted
d.mu.Unlock()
}
type closeTrackConn struct {
net.Conn
d *closeTrackDialer
}
func (c *closeTrackConn) Close() error {
c.d.noteClose(c)
return c.Conn.Close()
}

@ -189,6 +189,10 @@ type LocalBackend struct {
// statusChanged.Broadcast(). // statusChanged.Broadcast().
statusLock sync.Mutex statusLock sync.Mutex
statusChanged *sync.Cond statusChanged *sync.Cond
// dialPlan is any dial plan that we've received from the control
// server during a previous connection; it is cleared on logout.
dialPlan atomic.Pointer[tailcfg.ControlDialPlan]
} }
// clientGen is a func that creates a control plane client. // clientGen is a func that creates a control plane client.
@ -1087,6 +1091,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
Dialer: b.Dialer(), Dialer: b.Dialer(),
Status: b.setClientStatus, Status: b.setClientStatus,
C2NHandler: http.HandlerFunc(b.handleC2N), C2NHandler: http.HandlerFunc(b.handleC2N),
DialPlan: &b.dialPlan, // pointer because it can't be copied
// Don't warn about broken Linux IP forwarding when // Don't warn about broken Linux IP forwarding when
// netstack is being used. // netstack is being used.
@ -3112,6 +3117,9 @@ func (b *LocalBackend) logout(ctx context.Context, sync bool) error {
Prefs: ipn.Prefs{WantRunning: false, LoggedOut: true}, Prefs: ipn.Prefs{WantRunning: false, LoggedOut: true},
}) })
// Clear any previous dial plan(s), if set.
b.dialPlan.Store(nil)
if cc == nil { if cc == nil {
// Double Logout can happen via repeated IPN // Double Logout can happen via repeated IPN
// connections to ipnserver making it repeatedly // connections to ipnserver making it repeatedly

@ -80,7 +80,8 @@ type CapabilityVersion int
// - 41: 2022-08-30: uses 100.100.100.100 for route-less ExtraRecords if global nameservers is set // - 41: 2022-08-30: uses 100.100.100.100 for route-less ExtraRecords if global nameservers is set
// - 42: 2022-09-06: NextDNS DoH support; see https://github.com/tailscale/tailscale/pull/5556 // - 42: 2022-09-06: NextDNS DoH support; see https://github.com/tailscale/tailscale/pull/5556
// - 43: 2022-09-21: clients can return usernames for SSH // - 43: 2022-09-21: clients can return usernames for SSH
const CurrentCapabilityVersion CapabilityVersion = 43 // - 44: 2022-09-22: MapResponse.ControlDialPlan
const CurrentCapabilityVersion CapabilityVersion = 44
type StableID string type StableID string
@ -1383,6 +1384,40 @@ type MapResponse struct {
// Debug is normally nil, except for when the control server // Debug is normally nil, except for when the control server
// is setting debug settings on a node. // is setting debug settings on a node.
Debug *Debug `json:",omitempty"` Debug *Debug `json:",omitempty"`
// ControlDialPlan tells the client how to connect to the control
// server. An initial nil is equivalent to new(ControlDialPlan).
// A subsequent streamed nil means no change.
ControlDialPlan *ControlDialPlan `json:",omitempty"`
}
// ControlDialPlan is instructions from the control server to the client on how
// to connect to the control server; this is useful for maintaining connection
// if the client's network state changes after the initial connection, or due
// to the configuration that the control server pushes.
type ControlDialPlan struct {
// An empty list means the default: use DNS (unspecified which DNS).
Candidates []ControlIPCandidate
}
// ControlIPCandidate represents a single candidate address to use when
// connecting to the control server.
type ControlIPCandidate struct {
// IP is the address to attempt connecting to.
IP netip.Addr
// DialStartSec is the number of seconds after the beginning of the
// connection process to wait before trying this candidate.
DialStartDelaySec float64 `json:",omitempty"`
// DialTimeoutSec is the timeout for a connection to this candidate,
// starting after DialStartDelaySec.
DialTimeoutSec float64 `json:",omitempty"`
// Priority is the relative priority of this candidate; candidates with
// a higher priority are preferred over candidates with a lower
// priority.
Priority int `json:",omitempty"`
} }
// Debug are instructions from the control server to the client // Debug are instructions from the control server to the client

Loading…
Cancel
Save