From e1738ea78edde07ff45a3ea8a234d1ea3d772c57 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Sat, 27 Aug 2022 20:49:31 -0400 Subject: [PATCH] chirp: add a 10s timeout when communicating with BIRD (#5444) Prior to this change, if BIRD stops responding wgengine.watchdogEngine will crash tailscaled. This happens because in wgengine.userspaceEngine, we end up blocking forever trying to write a request to or read a response from BIRD with wgLock held, and then future watchdog'd calls will block on acquiring that mutex until the watchdog kills the process. With the timeout, we at least get the chance to print an error message and decide whether we want to crash or not. Updates tailscale/coral#72 Signed-off-by: Andrew Dunham Signed-off-by: Andrew Dunham --- chirp/chirp.go | 41 +++++++++++++++++++---- chirp/chirp_test.go | 82 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 6 deletions(-) diff --git a/chirp/chirp.go b/chirp/chirp.go index 3c4342281..eb879df1d 100644 --- a/chirp/chirp.go +++ b/chirp/chirp.go @@ -11,15 +11,31 @@ import ( "fmt" "net" "strings" + "time" +) + +const ( + // Maximum amount of time we should wait when reading a response from BIRD. + responseTimeout = 10 * time.Second ) // New creates a BIRDClient. func New(socket string) (*BIRDClient, error) { + return newWithTimeout(socket, responseTimeout) +} + +func newWithTimeout(socket string, timeout time.Duration) (*BIRDClient, error) { conn, err := net.Dial("unix", socket) if err != nil { return nil, fmt.Errorf("failed to connect to BIRD: %w", err) } - b := &BIRDClient{socket: socket, conn: conn, scanner: bufio.NewScanner(conn)} + b := &BIRDClient{ + socket: socket, + conn: conn, + scanner: bufio.NewScanner(conn), + timeNow: time.Now, + timeout: timeout, + } // Read and discard the first line as that is the welcome message. if _, err := b.readResponse(); err != nil { return nil, err @@ -32,6 +48,8 @@ type BIRDClient struct { socket string conn net.Conn scanner *bufio.Scanner + timeNow func() time.Time + timeout time.Duration } // Close closes the underlying connection to BIRD. @@ -81,10 +99,15 @@ func (b *BIRDClient) EnableProtocol(protocol string) error { // 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { + if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { return "", err } - fmt.Fprintln(b.conn) + if _, err := fmt.Fprintln(b.conn); err != nil { + return "", err + } return b.readResponse() } @@ -105,14 +128,20 @@ func hasResponseCode(s []byte) bool { } func (b *BIRDClient) readResponse() (string, error) { + // Set the read timeout before we start reading anything. + if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + var resp strings.Builder var done bool for !done { if !b.scanner.Scan() { - return "", fmt.Errorf("reading response from bird failed: %q", resp.String()) - } - if err := b.scanner.Err(); err != nil { - return "", err + if err := b.scanner.Err(); err != nil { + return "", err + } + + return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) } out := b.scanner.Bytes() if _, err := resp.Write(out); err != nil { diff --git a/chirp/chirp_test.go b/chirp/chirp_test.go index c29f36073..48c4b45fe 100644 --- a/chirp/chirp_test.go +++ b/chirp/chirp_test.go @@ -8,9 +8,12 @@ import ( "errors" "fmt" "net" + "os" "path/filepath" "strings" + "sync" "testing" + "time" ) type fakeBIRD struct { @@ -109,3 +112,82 @@ func TestChirp(t *testing.T) { t.Fatalf("disabling %q succeded", "rando") } } + +type hangingListener struct { + net.Listener + t *testing.T + done chan struct{} + wg sync.WaitGroup + sock string +} + +func newHangingListener(t *testing.T) *hangingListener { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + return &hangingListener{ + Listener: l, + t: t, + done: make(chan struct{}), + sock: sock, + } +} + +func (hl *hangingListener) Stop() { + hl.Close() + close(hl.done) + hl.wg.Wait() +} + +func (hl *hangingListener) listen() error { + for { + c, err := hl.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + hl.wg.Add(1) + go hl.handle(c) + } +} + +func (hl *hangingListener) handle(c net.Conn) { + defer hl.wg.Done() + + // Write our fake first line of response so that we get into the read loop + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + hl.t.Logf("connection still hanging") + case <-hl.done: + return + } + } +} + +func TestChirpTimeout(t *testing.T) { + fb := newHangingListener(t) + defer fb.Stop() + go fb.listen() + + c, err := newWithTimeout(fb.sock, 500*time.Millisecond) + if err != nil { + t.Fatal(err) + } + + err = c.EnableProtocol("tailscale") + if err == nil { + t.Fatal("got err=nil, want timeout") + } + if !os.IsTimeout(err) { + t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) + } +}