control/controlhttp: simplify, fix race dialing, remove priority concept

Fixes tailscale/corp#32534

Co-authored-by: James Tucker <james@tailscale.com>
Change-Id: I4eb57f046d8b40403220e40eb67a31c41adb3a38
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
jamesbrad/controlhttp-race-dial
Brad Fitzpatrick 2 months ago
parent 1b6bc37f28
commit 17643b05eb

@ -27,14 +27,13 @@ import (
"errors"
"fmt"
"io"
"math"
"log"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"net/url"
"runtime"
"sort"
"sync/atomic"
"time"
@ -53,7 +52,6 @@ import (
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/util/multierr"
)
var stdDialer net.Dialer
@ -110,18 +108,8 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
}
candidates := a.DialPlan.Candidates
// Otherwise, we try dialing per the plan. Store the highest priority
// in the list, so that if we get a connection to one of those
// candidates we can return quickly.
var highestPriority int = math.MinInt
for _, c := range candidates {
if c.Priority > highestPriority {
highestPriority = c.Priority
}
}
// This context allows us to cancel in-flight connections if we get a
// highest-priority connection before we're all done.
// Create a context to be canceled as we return, so once we get a good connection,
// we can drop all the other ones.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@ -129,142 +117,61 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
type dialResult struct {
conn *ClientConn
err error
cand tailcfg.ControlIPCandidate
}
resultsCh := make(chan dialResult, len(candidates))
resultsCh := make(chan dialResult) // unbuffered, never closed
var pending atomic.Int32
pending.Store(int32(len(candidates)))
for _, c := range candidates {
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
var (
conn *ClientConn
err error
)
dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) {
a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %v", cand.DialStartDelaySec, a.Hostname, cand.IP)
// Always send results back to our channel.
defer func() {
resultsCh <- dialResult{conn, err, c}
if pending.Add(-1) == 0 {
close(resultsCh)
}
}()
// If non-zero, wait the configured start timeout
// before we do anything.
if c.DialStartDelaySec > 0 {
a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
defer tmr.Stop()
select {
case <-ctx.Done():
err = ctx.Err()
return
case <-tmrChannel:
}
}
// Now, create a sub-context with the given timeout and
// try dialing the provided host.
ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second)))
defer cancel()
if c.IP.IsValid() {
a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
} else if c.ACEHost != "" {
a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, c.ACEHost)
if cand.IP.IsValid() {
a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, cand.IP)
} else if cand.ACEHost != "" {
a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, cand.ACEHost)
}
// This will dial, and the defer above sends it back to our parent.
conn, err = a.dialHostOpt(ctx, c.IP, c.ACEHost)
}(ctx, c)
return a.dialHostOpt(ctx, cand.IP, cand.ACEHost)
}
var results []dialResult
for res := range resultsCh {
// If we get a response that has the highest priority, we don't
// need to wait for any of the other connections to finish; we
// can just return this connection.
//
// TODO(andrew): we could make this better by keeping track of
// the highest remaining priority dynamically, instead of just
// checking for the highest total
if res.cand.Priority == highestPriority && res.conn != nil {
a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, cmp.Or(res.cand.ACEHost, res.cand.IP.String()))
// Drain the channel and any existing connections in
// the background.
for _, cand := range candidates {
timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() {
go func() {
for _, res := range results {
if res.conn != nil {
res.conn.Close()
}
}
for res := range resultsCh {
if res.conn != nil {
res.conn.Close()
}
}
if a.drainFinished != nil {
close(a.drainFinished)
}
}()
return res.conn, nil
}
// This isn't a highest-priority result, so just store it until
// we're done.
results = append(results, res)
}
// After we finish this function, close any remaining open connections.
defer func() {
for _, result := range results {
// Note: below, we nil out the returned connection (if
// any) in the slice so we don't close it.
if result.conn != nil {
result.conn.Close()
conn, err := dialCand(cand)
select {
case resultsCh <- dialResult{conn, err}:
if err == nil {
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String()))
}
case <-ctx.Done():
if conn != nil {
conn.Close()
}
// We don't drain asynchronously after this point, so notify our
// channel when we return.
if a.drainFinished != nil {
close(a.drainFinished)
}
}()
// Sort by priority, then take the first non-error response.
sort.Slice(results, func(i, j int) bool {
// NOTE: intentionally inverted so that the highest priority
// item comes first
return results[i].cand.Priority > results[j].cand.Priority
})
var (
conn *ClientConn
errs []error
)
for i, result := range results {
if result.err != nil {
errs = append(errs, result.err)
continue
defer timer.Stop()
}
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(result.cand.ACEHost, result.cand.IP.String()))
conn = result.conn
results[i].conn = nil // so we don't close it in the defer
return conn, nil
var errs []error
for {
select {
case res := <-resultsCh:
if res.err == nil {
return res.conn, nil
}
errs = append(errs, res.err)
if len(errs) == len(candidates) {
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...))
return a.dialHost(ctx)
}
if ctx.Err() != nil {
case <-ctx.Done():
a.logf("controlhttp: context aborted dialing")
return nil, ctx.Err()
}
merr := multierr.New(errs...)
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
return a.dialHost(ctx)
}
}
// The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to
@ -422,6 +329,11 @@ func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost
go try(u443)
} // else we lost the race and it started already which is what we want
case u443:
if u80 == nil {
log.Printf("XXXX no port 80 so returning error: %v", res.err)
// We never started a port 80 dial, so just return the port 443 error.
return nil, res.err
}
err443 = res.err
default:
panic("invalid")

@ -15,17 +15,19 @@ import (
"net/http/httputil"
"net/netip"
"net/url"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"testing"
"testing/synctest"
"time"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpcommon"
"tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/health"
"tailscale.com/net/memnet"
"tailscale.com/net/netmon"
"tailscale.com/net/netx"
"tailscale.com/net/socks5"
@ -545,35 +547,13 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
}
func TestDialPlan(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("only works on Linux due to multiple localhost addresses")
}
client, server := key.NewMachine(), key.NewMachine()
const (
testProtocolVersion = 1
)
getRandomPort := func() string {
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("net.Listen: %v", err)
}
defer ln.Close()
_, port, err := net.SplitHostPort(ln.Addr().String())
if err != nil {
t.Fatal(err)
}
return port
}
// We need consistent ports for each address; these are chosen
// randomly and we hope that they won't conflict during this test.
httpPort := getRandomPort()
httpsPort := getRandomPort()
makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
makeHandler := func(t *testing.T, memNet *memnet.Network, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
done := make(chan struct{})
t.Cleanup(func() {
close(done)
@ -592,11 +572,11 @@ func TestDialPlan(t *testing.T) {
handler = wrap(handler)
}
httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
httpLn, err := memNet.Listen("tcp", host.String()+":80")
if err != nil {
t.Fatalf("HTTP listen: %v", err)
}
httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
httpsLn, err := memNet.Listen("tcp", host.String()+":443")
if err != nil {
t.Fatalf("HTTPS listen: %v", err)
}
@ -616,7 +596,6 @@ func TestDialPlan(t *testing.T) {
t.Cleanup(func() {
httpsServer.Close()
})
return
}
fallbackAddr := netip.MustParseAddr("127.0.0.1")
@ -686,20 +665,27 @@ func TestDialPlan(t *testing.T) {
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
// TODO(awly): replace this with tstest.NewClock and update the
// test to advance the clock correctly.
synctest.Test(t, func(t *testing.T) {
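// The synctest bubble's fake clock starts at midnight UTC on 2000-01-01.
// Sleep roughly 26 years so the bubble's clock reads late 2025, on the
// theory that the net/http/httptest TLS client certs are only considered
// valid from around then.
// TODO(bradfitz): this might not be necessary. Still debugging.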
time.Sleep(26 * 365 * 24 * time.Hour)
var memNet memnet.Network
clock := tstime.StdClock{}
makeHandler(t, "fallback", fallbackAddr, nil)
makeHandler(t, "good", goodAddr, nil)
makeHandler(t, "other", otherAddr, nil)
makeHandler(t, "other2", other2Addr, nil)
makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
makeHandler(t, &memNet, "fallback", fallbackAddr, nil)
makeHandler(t, &memNet, "good", goodAddr, nil)
makeHandler(t, &memNet, "other", otherAddr, nil)
makeHandler(t, &memNet, "other2", other2Addr, nil)
makeHandler(t, &memNet, "broken", brokenAddr, func(h http.Handler) http.Handler {
return brokenMITMHandler(clock)
})
dialer := closeTrackDialer{
t: t,
inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial,
inner: memNet.Dial,
conns: make(map[*closeTrackConn]bool),
}
defer dialer.Done()
@ -715,11 +701,8 @@ func TestDialPlan(t *testing.T) {
host = "localhost"
}
drained := make(chan struct{})
a := &Dialer{
Hostname: host,
HTTPPort: httpPort,
HTTPSPort: httpsPort,
MachineKey: client,
ControlKey: server.Public(),
ProtocolVersion: testProtocolVersion,
@ -727,7 +710,6 @@ func TestDialPlan(t *testing.T) {
Logf: t.Logf,
DialPlan: tt.plan,
proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
drainFinished: drained,
omitCertErrorLogging: true,
testFallbackDelay: 50 * time.Millisecond,
Clock: clock,
@ -740,20 +722,25 @@ func TestDialPlan(t *testing.T) {
}
defer conn.Close()
raddr := conn.RemoteAddr().(*net.TCPAddr)
raddrStr := conn.RemoteAddr().String()
got, ok := netip.AddrFromSlice(raddr.IP)
if !ok {
t.Errorf("invalid remote IP: %v", raddr.IP)
} else if got != tt.want {
raddrStr = strings.TrimSuffix(raddrStr, "|1") // memnet noise
raddrPort, err := netip.ParseAddrPort(raddrStr)
if err != nil {
t.Fatalf("parsing remote addr %q: %v", raddrStr, err)
}
got := raddrPort.Addr()
if got != tt.want {
t.Errorf("got connection from %q; want %q", got, tt.want)
} else {
t.Logf("successfully connected to %q", raddr.String())
t.Logf("successfully connected to %q", got)
}
// Wait until our dialer drains so we can verify that
// all connections are closed.
<-drained
synctest.Wait()
})
})
}
}

Loading…
Cancel
Save