diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 87061c310..5b1fed422 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -27,14 +27,13 @@ import ( "errors" "fmt" "io" - "math" + "log" "net" "net/http" "net/http/httptrace" "net/netip" "net/url" "runtime" - "sort" "sync/atomic" "time" @@ -53,7 +52,6 @@ import ( "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" - "tailscale.com/util/multierr" ) var stdDialer net.Dialer @@ -110,18 +108,8 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { } candidates := a.DialPlan.Candidates - // Otherwise, we try dialing per the plan. Store the highest priority - // in the list, so that if we get a connection to one of those - // candidates we can return quickly. - var highestPriority int = math.MinInt - for _, c := range candidates { - if c.Priority > highestPriority { - highestPriority = c.Priority - } - } - - // This context allows us to cancel in-flight connections if we get a - // highest-priority connection before we're all done. + // Create a context to be canceled as we return, so once we get a good connection, + // we can drop all the other ones. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -129,142 +117,61 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { type dialResult struct { conn *ClientConn err error - cand tailcfg.ControlIPCandidate - } - resultsCh := make(chan dialResult, len(candidates)) - - var pending atomic.Int32 - pending.Store(int32(len(candidates))) - for _, c := range candidates { - go func(ctx context.Context, c tailcfg.ControlIPCandidate) { - var ( - conn *ClientConn - err error - ) - - // Always send results back to our channel. - defer func() { - resultsCh <- dialResult{conn, err, c} - if pending.Add(-1) == 0 { - close(resultsCh) - } - }() + } + resultsCh := make(chan dialResult) // unbuffered, never closed - // If non-zero, wait the configured start timeout - // before we do anything. - if c.DialStartDelaySec > 0 { - a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP) - tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second))) - defer tmr.Stop() - select { - case <-ctx.Done(): - err = ctx.Err() - return - case <-tmrChannel: - } - } + dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) { + a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %v", cand.DialStartDelaySec, a.Hostname, cand.IP) - // Now, create a sub-context with the given timeout and - // try dialing the provided host. - ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second))) - defer cancel() + ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second))) + defer cancel() - if c.IP.IsValid() { - a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP) - } else if c.ACEHost != "" { - a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, c.ACEHost) - } - // This will dial, and the defer above sends it back to our parent. - conn, err = a.dialHostOpt(ctx, c.IP, c.ACEHost) - }(ctx, c) + if cand.IP.IsValid() { + a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, cand.IP) + } else if cand.ACEHost != "" { + a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, cand.ACEHost) + } + // This will dial, and the defer above sends it back to our parent. + return a.dialHostOpt(ctx, cand.IP, cand.ACEHost) } - var results []dialResult - for res := range resultsCh { - // If we get a response that has the highest priority, we don't - // need to wait for any of the other connections to finish; we - // can just return this connection. - // - // TODO(andrew): we could make this better by keeping track of - // the highest remaining priority dynamically, instead of just - // checking for the highest total - if res.cand.Priority == highestPriority && res.conn != nil { - a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, cmp.Or(res.cand.ACEHost, res.cand.IP.String())) - - // Drain the channel and any existing connections in - // the background. + for _, cand := range candidates { + timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() { go func() { - for _, res := range results { - if res.conn != nil { - res.conn.Close() + conn, err := dialCand(cand) + select { + case resultsCh <- dialResult{conn, err}: + if err == nil { + a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String())) } - } - for res := range resultsCh { - if res.conn != nil { - res.conn.Close() + case <-ctx.Done(): + if conn != nil { + conn.Close() } } - if a.drainFinished != nil { - close(a.drainFinished) - } }() - return res.conn, nil - } - - // This isn't a highest-priority result, so just store it until - // we're done. - results = append(results, res) + }) + defer timer.Stop() } - // After we finish this function, close any remaining open connections. - defer func() { - for _, result := range results { - // Note: below, we nil out the returned connection (if - // any) in the slice so we don't close it. - if result.conn != nil { - result.conn.Close() + var errs []error + for { + select { + case res := <-resultsCh: + if res.err == nil { + return res.conn, nil } + errs = append(errs, res.err) + if len(errs) == len(candidates) { + // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. + a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...)) + return a.dialHost(ctx) + } + case <-ctx.Done(): + a.logf("controlhttp: context aborted dialing") + return nil, ctx.Err() } - - // We don't drain asynchronously after this point, so notify our - // channel when we return. - if a.drainFinished != nil { - close(a.drainFinished) - } - }() - - // Sort by priority, then take the first non-error response. - sort.Slice(results, func(i, j int) bool { - // NOTE: intentionally inverted so that the highest priority - // item comes first - return results[i].cand.Priority > results[j].cand.Priority - }) - - var ( - conn *ClientConn - errs []error - ) - for i, result := range results { - if result.err != nil { - errs = append(errs, result.err) - continue - } - - a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(result.cand.ACEHost, result.cand.IP.String())) - conn = result.conn - results[i].conn = nil // so we don't close it in the defer - return conn, nil } - if ctx.Err() != nil { - a.logf("controlhttp: context aborted dialing") - return nil, ctx.Err() - } - - merr := multierr.New(errs...) - - // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. - a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error()) - return a.dialHost(ctx) } // The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to @@ -422,6 +329,11 @@ func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost go try(u443) } // else we lost the race and it started already which is what we want case u443: + if u80 == nil { + log.Printf("XXXX no port 80 so returning error: %v", res.err) + // We never started a port 80 dial, so just return the port 443 error. + return nil, res.err + } err443 = res.err default: panic("invalid") diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 0b4e117f9..ffae74a6f 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -15,17 +15,19 @@ import ( "net/http/httputil" "net/netip" "net/url" - "runtime" "slices" "strconv" + "strings" "sync" "testing" + "testing/synctest" "time" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/health" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" "tailscale.com/net/netx" "tailscale.com/net/socks5" @@ -545,35 +547,13 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { } func TestDialPlan(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skip("only works on Linux due to multiple localhost addresses") - } - client, server := key.NewMachine(), key.NewMachine() const ( testProtocolVersion = 1 ) - getRandomPort := func() string { - ln, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("net.Listen: %v", err) - } - defer ln.Close() - _, port, err := net.SplitHostPort(ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - return port - } - - // We need consistent ports for each address; these are chosen - // randomly and we hope that they won't conflict during this test. - httpPort := getRandomPort() - httpsPort := getRandomPort() - - makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) { + makeHandler := func(t *testing.T, memNet *memnet.Network, name string, host netip.Addr, wrap func(http.Handler) http.Handler) { done := make(chan struct{}) t.Cleanup(func() { close(done) @@ -592,11 +572,11 @@ func TestDialPlan(t *testing.T) { handler = wrap(handler) } - httpLn, err := net.Listen("tcp", host.String()+":"+httpPort) + httpLn, err := memNet.Listen("tcp", host.String()+":80") if err != nil { t.Fatalf("HTTP listen: %v", err) } - httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort) + httpsLn, err := memNet.Listen("tcp", host.String()+":443") if err != nil { t.Fatalf("HTTPS listen: %v", err) } @@ -616,7 +596,6 @@ func TestDialPlan(t *testing.T) { t.Cleanup(func() { httpsServer.Close() }) - return } fallbackAddr := netip.MustParseAddr("127.0.0.1") @@ -686,74 +665,82 @@ func TestDialPlan(t *testing.T) { } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { - // TODO(awly): replace this with tstest.NewClock and update the - // test to advance the clock correctly. - clock := tstime.StdClock{} - makeHandler(t, "fallback", fallbackAddr, nil) - makeHandler(t, "good", goodAddr, nil) - makeHandler(t, "other", otherAddr, nil) - makeHandler(t, "other2", other2Addr, nil) - makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler { - return brokenMITMHandler(clock) - }) + synctest.Test(t, func(t *testing.T) { + + // Get the synctest clock way out to 2025 at least so the + // net/http/httptest TLS client certs are valid? + // TODO(bradfitz): this might not be necessary. Still debugging. + time.Sleep(26 * 365 * 24 * time.Hour) + + var memNet memnet.Network + + clock := tstime.StdClock{} + makeHandler(t, &memNet, "fallback", fallbackAddr, nil) + makeHandler(t, &memNet, "good", goodAddr, nil) + makeHandler(t, &memNet, "other", otherAddr, nil) + makeHandler(t, &memNet, "other2", other2Addr, nil) + makeHandler(t, &memNet, "broken", brokenAddr, func(h http.Handler) http.Handler { + return brokenMITMHandler(clock) + }) + + dialer := closeTrackDialer{ + t: t, + inner: memNet.Dial, + conns: make(map[*closeTrackConn]bool), + } + defer dialer.Done() - dialer := closeTrackDialer{ - t: t, - inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial, - conns: make(map[*closeTrackConn]bool), - } - defer dialer.Done() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // By default, we intentionally point to something that + // we know won't connect, since we want a fallback to + // DNS to be an error. + host := "example.com" + if tt.allowFallback { + host = "localhost" + } - // By default, we intentionally point to something that - // we know won't connect, since we want a fallback to - // DNS to be an error. - host := "example.com" - if tt.allowFallback { - host = "localhost" - } + a := &Dialer{ + Hostname: host, + MachineKey: client, + ControlKey: server.Public(), + ProtocolVersion: testProtocolVersion, + Dialer: dialer.Dial, + Logf: t.Logf, + DialPlan: tt.plan, + proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, + omitCertErrorLogging: true, + testFallbackDelay: 50 * time.Millisecond, + Clock: clock, + HealthTracker: health.NewTracker(eventbustest.NewBus(t)), + } - drained := make(chan struct{}) - a := &Dialer{ - Hostname: host, - HTTPPort: httpPort, - HTTPSPort: httpsPort, - MachineKey: client, - ControlKey: server.Public(), - ProtocolVersion: testProtocolVersion, - Dialer: dialer.Dial, - Logf: t.Logf, - DialPlan: tt.plan, - proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, - drainFinished: drained, - omitCertErrorLogging: true, - testFallbackDelay: 50 * time.Millisecond, - Clock: clock, - HealthTracker: health.NewTracker(eventbustest.NewBus(t)), - } + conn, err := a.dial(ctx) + if err != nil { + t.Fatalf("dialing controlhttp: %v", err) + } + defer conn.Close() - conn, err := a.dial(ctx) - if err != nil { - t.Fatalf("dialing controlhttp: %v", err) - } - defer conn.Close() + raddrStr := conn.RemoteAddr().String() - raddr := conn.RemoteAddr().(*net.TCPAddr) + raddrStr = strings.TrimSuffix(raddrStr, "|1") // memnet noise + raddrPort, err := netip.ParseAddrPort(raddrStr) + if err != nil { + t.Fatalf("parsing remote addr %q: %v", raddrStr, err) + } - got, ok := netip.AddrFromSlice(raddr.IP) - if !ok { - t.Errorf("invalid remote IP: %v", raddr.IP) - } else if got != tt.want { - t.Errorf("got connection from %q; want %q", got, tt.want) - } else { - t.Logf("successfully connected to %q", raddr.String()) - } + got := raddrPort.Addr() + if got != tt.want { + t.Errorf("got connection from %q; want %q", got, tt.want) + } else { + t.Logf("successfully connected to %q", got) + } - // Wait until our dialer drains so we can verify that - // all connections are closed. - <-drained + // Wait until our dialer drains so we can verify that + // all connections are closed. + synctest.Wait() + }) }) } }