cmd/tta: pull out test driver dialing into a type, fix bugs

There were a few places it could get wedged (notably the dial without
a timeout).

And add a knob for verbose debug logs.

And keep two idle connections always.

Updates #13038

Change-Id: I952ad182d7111481d97a83c12aa2ff4bfdc55fe8
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/13267/head
Brad Fitzpatrick 3 months ago committed by Brad Fitzpatrick
parent 6dd1af0d1e
commit 2636a83d0e

@ -24,6 +24,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"regexp" "regexp"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -31,6 +32,7 @@ import (
"tailscale.com/atomicfile" "tailscale.com/atomicfile"
"tailscale.com/client/tailscale" "tailscale.com/client/tailscale"
"tailscale.com/hostinfo" "tailscale.com/hostinfo"
"tailscale.com/util/mak"
"tailscale.com/util/must" "tailscale.com/util/must"
"tailscale.com/util/set" "tailscale.com/util/set"
"tailscale.com/version/distro" "tailscale.com/version/distro"
@ -85,6 +87,7 @@ func main() {
} }
flag.Parse() flag.Parse()
debug := false
if distro.Get() == distro.Gokrazy { if distro.Get() == distro.Gokrazy {
cmdLine, _ := os.ReadFile("/proc/cmdline") cmdLine, _ := os.ReadFile("/proc/cmdline")
explicitNS := false explicitNS := false
@ -93,7 +96,11 @@ func main() {
err := atomicfile.WriteFile("/tmp/resolv.conf", []byte("nameserver "+ns+"\n"), 0644) err := atomicfile.WriteFile("/tmp/resolv.conf", []byte("nameserver "+ns+"\n"), 0644)
log.Printf("Wrote /tmp/resolv.conf: %v", err) log.Printf("Wrote /tmp/resolv.conf: %v", err)
explicitNS = true explicitNS = true
break continue
}
if v, ok := strings.CutPrefix(s, "tta.debug="); ok {
debug, _ = strconv.ParseBool(v)
continue
} }
} }
if !explicitNS { if !explicitNS {
@ -134,28 +141,11 @@ func main() {
}) })
var hs http.Server var hs http.Server
hs.Handler = &serveMux hs.Handler = &serveMux
var ( revSt := revDialState{
stMu sync.Mutex needConnCh: make(chan bool, 1),
newSet = set.Set[net.Conn]{} // conns in StateNew debug: debug,
)
needConnCh := make(chan bool, 1)
hs.ConnState = func(c net.Conn, s http.ConnState) {
stMu.Lock()
defer stMu.Unlock()
oldLen := len(newSet)
switch s {
case http.StateNew:
newSet.Add(c)
default:
newSet.Delete(c)
}
if oldLen != 0 && len(newSet) == 0 {
select {
case needConnCh <- true:
default:
}
}
} }
hs.ConnState = revSt.connState
conns := make(chan net.Conn, 1) conns := make(chan net.Conn, 1)
lcRP := httputil.NewSingleHostReverseProxy(must.Get(url.Parse("http://local-tailscaled.sock"))) lcRP := httputil.NewSingleHostReverseProxy(must.Get(url.Parse("http://local-tailscaled.sock")))
@ -193,26 +183,14 @@ func main() {
} }
}() }()
var lastErr string revSt.runDialOutLoop(conns)
needConnCh <- true
for {
<-needConnCh
c, err := connect()
if err != nil {
s := err.Error()
if s != lastErr {
log.Printf("Connect failure: %v", s)
}
lastErr = s
time.Sleep(time.Second)
continue
}
conns <- c
}
} }
func connect() (net.Conn, error) { func connect() (net.Conn, error) {
c, err := net.Dial("tcp", *driverAddr) var d net.Dialer
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
c, err := d.DialContext(ctx, "tcp", *driverAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -240,6 +218,100 @@ func (cl chanListener) Addr() net.Addr {
} }
} }
type revDialState struct {
needConnCh chan bool
debug bool
mu sync.Mutex
newSet set.Set[net.Conn] // conns in StateNew
onNew map[net.Conn]func()
}
func (s *revDialState) connState(c net.Conn, cs http.ConnState) {
s.mu.Lock()
defer s.mu.Unlock()
oldLen := len(s.newSet)
switch cs {
case http.StateNew:
if f, ok := s.onNew[c]; ok {
f()
delete(s.onNew, c)
}
s.newSet.Make()
s.newSet.Add(c)
default:
s.newSet.Delete(c)
}
s.vlogf("ConnState: %p now %v; newSet %v=>%v", c, s, oldLen, len(s.newSet))
if len(s.newSet) < 2 {
select {
case s.needConnCh <- true:
default:
}
}
}
func (s *revDialState) waitNeedConnect() {
for {
s.mu.Lock()
need := len(s.newSet) < 2
s.mu.Unlock()
if need {
return
}
<-s.needConnCh
}
}
func (s *revDialState) vlogf(format string, arg ...any) {
if !s.debug {
return
}
log.Printf(format, arg...)
}
func (s *revDialState) runDialOutLoop(conns chan<- net.Conn) {
var lastErr string
connected := false
for {
s.vlogf("[dial-driver] waiting need connect...")
s.waitNeedConnect()
s.vlogf("[dial-driver] connecting...")
t0 := time.Now()
c, err := connect()
if err != nil {
s := err.Error()
if s != lastErr {
log.Printf("[dial-driver] connect failure: %v", s)
}
lastErr = s
time.Sleep(time.Second)
continue
}
if !connected {
connected = true
log.Printf("Connected to %v", *driverAddr)
}
s.vlogf("[dial-driver] connected %v => %v after %v", c.LocalAddr(), c.RemoteAddr(), time.Since(t0))
inHTTP := make(chan struct{})
s.mu.Lock()
mak.Set(&s.onNew, c, func() { close(inHTTP) })
s.mu.Unlock()
s.vlogf("[dial-driver] sending...")
conns <- c
s.vlogf("[dial-driver] sent; waiting")
select {
case <-inHTTP:
s.vlogf("[dial-driver] conn in HTTP")
case <-time.After(2 * time.Second):
s.vlogf("[dial-driver] timeout waiting for conn to be accepted into HTTP")
}
}
}
func addFirewallHandler(w http.ResponseWriter, r *http.Request) { func addFirewallHandler(w http.ResponseWriter, r *http.Request) {
if addFirewall == nil { if addFirewall == nil {
http.Error(w, "firewall not supported", 500) http.Error(w, "firewall not supported", 500)

Loading…
Cancel
Save