control/controlhttp: simplify, fix race dialing, remove priority concept

Fixes tailscale/corp#32534

Co-authored-by: James Tucker <james@tailscale.com>
Change-Id: I4eb57f046d8b40403220e40eb67a31c41adb3a38
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
jamesbrad/controlhttp-race-dial
Brad Fitzpatrick 2 months ago
parent 1b6bc37f28
commit 17643b05eb

@ -27,14 +27,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math" "log"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/netip" "net/netip"
"net/url" "net/url"
"runtime" "runtime"
"sort"
"sync/atomic" "sync/atomic"
"time" "time"
@ -53,7 +52,6 @@ import (
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime" "tailscale.com/tstime"
"tailscale.com/util/multierr"
) )
var stdDialer net.Dialer var stdDialer net.Dialer
@ -110,18 +108,8 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
} }
candidates := a.DialPlan.Candidates candidates := a.DialPlan.Candidates
// Otherwise, we try dialing per the plan. Store the highest priority // Create a context to be canceled as we return, so once we get a good connection,
// in the list, so that if we get a connection to one of those // we can drop all the other ones.
// candidates we can return quickly.
var highestPriority int = math.MinInt
for _, c := range candidates {
if c.Priority > highestPriority {
highestPriority = c.Priority
}
}
// This context allows us to cancel in-flight connections if we get a
// highest-priority connection before we're all done.
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@ -129,142 +117,61 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
type dialResult struct { type dialResult struct {
conn *ClientConn conn *ClientConn
err error err error
cand tailcfg.ControlIPCandidate
} }
resultsCh := make(chan dialResult, len(candidates)) resultsCh := make(chan dialResult) // unbuffered, never closed
var pending atomic.Int32 dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) {
pending.Store(int32(len(candidates))) a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %v", cand.DialStartDelaySec, a.Hostname, cand.IP)
for _, c := range candidates {
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
var (
conn *ClientConn
err error
)
// Always send results back to our channel. ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second)))
defer func() {
resultsCh <- dialResult{conn, err, c}
if pending.Add(-1) == 0 {
close(resultsCh)
}
}()
// If non-zero, wait the configured start timeout
// before we do anything.
if c.DialStartDelaySec > 0 {
a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
defer tmr.Stop()
select {
case <-ctx.Done():
err = ctx.Err()
return
case <-tmrChannel:
}
}
// Now, create a sub-context with the given timeout and
// try dialing the provided host.
ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
defer cancel() defer cancel()
if c.IP.IsValid() { if cand.IP.IsValid() {
a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP) a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, cand.IP)
} else if c.ACEHost != "" { } else if cand.ACEHost != "" {
a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, c.ACEHost) a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, cand.ACEHost)
} }
// This will dial, and the defer above sends it back to our parent. // This will dial, and the defer above sends it back to our parent.
conn, err = a.dialHostOpt(ctx, c.IP, c.ACEHost) return a.dialHostOpt(ctx, cand.IP, cand.ACEHost)
}(ctx, c)
} }
var results []dialResult for _, cand := range candidates {
for res := range resultsCh { timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() {
// If we get a response that has the highest priority, we don't
// need to wait for any of the other connections to finish; we
// can just return this connection.
//
// TODO(andrew): we could make this better by keeping track of
// the highest remaining priority dynamically, instead of just
// checking for the highest total
if res.cand.Priority == highestPriority && res.conn != nil {
a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, cmp.Or(res.cand.ACEHost, res.cand.IP.String()))
// Drain the channel and any existing connections in
// the background.
go func() { go func() {
for _, res := range results { conn, err := dialCand(cand)
if res.conn != nil { select {
res.conn.Close() case resultsCh <- dialResult{conn, err}:
} if err == nil {
} a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String()))
for res := range resultsCh {
if res.conn != nil {
res.conn.Close()
}
}
if a.drainFinished != nil {
close(a.drainFinished)
}
}()
return res.conn, nil
}
// This isn't a highest-priority result, so just store it until
// we're done.
results = append(results, res)
}
// After we finish this function, close any remaining open connections.
defer func() {
for _, result := range results {
// Note: below, we nil out the returned connection (if
// any) in the slice so we don't close it.
if result.conn != nil {
result.conn.Close()
} }
case <-ctx.Done():
if conn != nil {
conn.Close()
} }
// We don't drain asynchronously after this point, so notify our
// channel when we return.
if a.drainFinished != nil {
close(a.drainFinished)
} }
}() }()
// Sort by priority, then take the first non-error response.
sort.Slice(results, func(i, j int) bool {
// NOTE: intentionally inverted so that the highest priority
// item comes first
return results[i].cand.Priority > results[j].cand.Priority
}) })
defer timer.Stop()
var (
conn *ClientConn
errs []error
)
for i, result := range results {
if result.err != nil {
errs = append(errs, result.err)
continue
} }
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(result.cand.ACEHost, result.cand.IP.String())) var errs []error
conn = result.conn for {
results[i].conn = nil // so we don't close it in the defer select {
return conn, nil case res := <-resultsCh:
if res.err == nil {
return res.conn, nil
}
errs = append(errs, res.err)
if len(errs) == len(candidates) {
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...))
return a.dialHost(ctx)
} }
if ctx.Err() != nil { case <-ctx.Done():
a.logf("controlhttp: context aborted dialing") a.logf("controlhttp: context aborted dialing")
return nil, ctx.Err() return nil, ctx.Err()
} }
}
merr := multierr.New(errs...)
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
return a.dialHost(ctx)
} }
// The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to // The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to
@ -422,6 +329,11 @@ func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost
go try(u443) go try(u443)
} // else we lost the race and it started already which is what we want } // else we lost the race and it started already which is what we want
case u443: case u443:
if u80 == nil {
log.Printf("XXXX no port 80 so returning error: %v", res.err)
// We never started a port 80 dial, so just return the port 443 error.
return nil, res.err
}
err443 = res.err err443 = res.err
default: default:
panic("invalid") panic("invalid")

@ -15,17 +15,19 @@ import (
"net/http/httputil" "net/http/httputil"
"net/netip" "net/netip"
"net/url" "net/url"
"runtime"
"slices" "slices"
"strconv" "strconv"
"strings"
"sync" "sync"
"testing" "testing"
"testing/synctest"
"time" "time"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/control/controlhttp/controlhttpcommon"
"tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/health" "tailscale.com/health"
"tailscale.com/net/memnet"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/net/netx" "tailscale.com/net/netx"
"tailscale.com/net/socks5" "tailscale.com/net/socks5"
@ -545,35 +547,13 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
} }
func TestDialPlan(t *testing.T) { func TestDialPlan(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("only works on Linux due to multiple localhost addresses")
}
client, server := key.NewMachine(), key.NewMachine() client, server := key.NewMachine(), key.NewMachine()
const ( const (
testProtocolVersion = 1 testProtocolVersion = 1
) )
getRandomPort := func() string { makeHandler := func(t *testing.T, memNet *memnet.Network, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("net.Listen: %v", err)
}
defer ln.Close()
_, port, err := net.SplitHostPort(ln.Addr().String())
if err != nil {
t.Fatal(err)
}
return port
}
// We need consistent ports for each address; these are chosen
// randomly and we hope that they won't conflict during this test.
httpPort := getRandomPort()
httpsPort := getRandomPort()
makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
done := make(chan struct{}) done := make(chan struct{})
t.Cleanup(func() { t.Cleanup(func() {
close(done) close(done)
@ -592,11 +572,11 @@ func TestDialPlan(t *testing.T) {
handler = wrap(handler) handler = wrap(handler)
} }
httpLn, err := net.Listen("tcp", host.String()+":"+httpPort) httpLn, err := memNet.Listen("tcp", host.String()+":80")
if err != nil { if err != nil {
t.Fatalf("HTTP listen: %v", err) t.Fatalf("HTTP listen: %v", err)
} }
httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort) httpsLn, err := memNet.Listen("tcp", host.String()+":443")
if err != nil { if err != nil {
t.Fatalf("HTTPS listen: %v", err) t.Fatalf("HTTPS listen: %v", err)
} }
@ -616,7 +596,6 @@ func TestDialPlan(t *testing.T) {
t.Cleanup(func() { t.Cleanup(func() {
httpsServer.Close() httpsServer.Close()
}) })
return
} }
fallbackAddr := netip.MustParseAddr("127.0.0.1") fallbackAddr := netip.MustParseAddr("127.0.0.1")
@ -686,20 +665,27 @@ func TestDialPlan(t *testing.T) {
} }
for _, tt := range testCases { for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// TODO(awly): replace this with tstest.NewClock and update the synctest.Test(t, func(t *testing.T) {
// test to advance the clock correctly.
// Get the synctest clock way out to 2025 at least so the
// net/http/httptest TLS client certs are valid?
// TODO(bradfitz): this might not be necessary. Still debugging.
time.Sleep(26 * 365 * 24 * time.Hour)
var memNet memnet.Network
clock := tstime.StdClock{} clock := tstime.StdClock{}
makeHandler(t, "fallback", fallbackAddr, nil) makeHandler(t, &memNet, "fallback", fallbackAddr, nil)
makeHandler(t, "good", goodAddr, nil) makeHandler(t, &memNet, "good", goodAddr, nil)
makeHandler(t, "other", otherAddr, nil) makeHandler(t, &memNet, "other", otherAddr, nil)
makeHandler(t, "other2", other2Addr, nil) makeHandler(t, &memNet, "other2", other2Addr, nil)
makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler { makeHandler(t, &memNet, "broken", brokenAddr, func(h http.Handler) http.Handler {
return brokenMITMHandler(clock) return brokenMITMHandler(clock)
}) })
dialer := closeTrackDialer{ dialer := closeTrackDialer{
t: t, t: t,
inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial, inner: memNet.Dial,
conns: make(map[*closeTrackConn]bool), conns: make(map[*closeTrackConn]bool),
} }
defer dialer.Done() defer dialer.Done()
@ -715,11 +701,8 @@ func TestDialPlan(t *testing.T) {
host = "localhost" host = "localhost"
} }
drained := make(chan struct{})
a := &Dialer{ a := &Dialer{
Hostname: host, Hostname: host,
HTTPPort: httpPort,
HTTPSPort: httpsPort,
MachineKey: client, MachineKey: client,
ControlKey: server.Public(), ControlKey: server.Public(),
ProtocolVersion: testProtocolVersion, ProtocolVersion: testProtocolVersion,
@ -727,7 +710,6 @@ func TestDialPlan(t *testing.T) {
Logf: t.Logf, Logf: t.Logf,
DialPlan: tt.plan, DialPlan: tt.plan,
proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
drainFinished: drained,
omitCertErrorLogging: true, omitCertErrorLogging: true,
testFallbackDelay: 50 * time.Millisecond, testFallbackDelay: 50 * time.Millisecond,
Clock: clock, Clock: clock,
@ -740,20 +722,25 @@ func TestDialPlan(t *testing.T) {
} }
defer conn.Close() defer conn.Close()
raddr := conn.RemoteAddr().(*net.TCPAddr) raddrStr := conn.RemoteAddr().String()
got, ok := netip.AddrFromSlice(raddr.IP) raddrStr = strings.TrimSuffix(raddrStr, "|1") // memnet noise
if !ok { raddrPort, err := netip.ParseAddrPort(raddrStr)
t.Errorf("invalid remote IP: %v", raddr.IP) if err != nil {
} else if got != tt.want { t.Fatalf("parsing remote addr %q: %v", raddrStr, err)
}
got := raddrPort.Addr()
if got != tt.want {
t.Errorf("got connection from %q; want %q", got, tt.want) t.Errorf("got connection from %q; want %q", got, tt.want)
} else { } else {
t.Logf("successfully connected to %q", raddr.String()) t.Logf("successfully connected to %q", got)
} }
// Wait until our dialer drains so we can verify that // Wait until our dialer drains so we can verify that
// all connections are closed. // all connections are closed.
<-drained synctest.Wait()
})
}) })
} }
} }

Loading…
Cancel
Save