diff --git a/chirp/chirp.go b/chirp/chirp.go index 3c4342281..eb879df1d 100644 --- a/chirp/chirp.go +++ b/chirp/chirp.go @@ -11,15 +11,31 @@ import ( "fmt" "net" "strings" + "time" +) + +const ( + // Maximum amount of time we should wait when reading a response from BIRD. + responseTimeout = 10 * time.Second ) // New creates a BIRDClient. func New(socket string) (*BIRDClient, error) { + return newWithTimeout(socket, responseTimeout) +} + +func newWithTimeout(socket string, timeout time.Duration) (*BIRDClient, error) { conn, err := net.Dial("unix", socket) if err != nil { return nil, fmt.Errorf("failed to connect to BIRD: %w", err) } - b := &BIRDClient{socket: socket, conn: conn, scanner: bufio.NewScanner(conn)} + b := &BIRDClient{ + socket: socket, + conn: conn, + scanner: bufio.NewScanner(conn), + timeNow: time.Now, + timeout: timeout, + } // Read and discard the first line as that is the welcome message. if _, err := b.readResponse(); err != nil { return nil, err @@ -32,6 +48,8 @@ type BIRDClient struct { socket string conn net.Conn scanner *bufio.Scanner + timeNow func() time.Time + timeout time.Duration } // Close closes the underlying connection to BIRD. @@ -81,10 +99,15 @@ func (b *BIRDClient) EnableProtocol(protocol string) error { // 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { + if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { return "", err } - fmt.Fprintln(b.conn) + if _, err := fmt.Fprintln(b.conn); err != nil { + return "", err + } return b.readResponse() } @@ -105,14 +128,20 @@ func hasResponseCode(s []byte) bool { } func (b *BIRDClient) readResponse() (string, error) { + // Set the read timeout before we start reading anything. + if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + var resp strings.Builder var done bool for !done { if !b.scanner.Scan() { - return "", fmt.Errorf("reading response from bird failed: %q", resp.String()) - } - if err := b.scanner.Err(); err != nil { - return "", err + if err := b.scanner.Err(); err != nil { + return "", err + } + + return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) } out := b.scanner.Bytes() if _, err := resp.Write(out); err != nil { diff --git a/chirp/chirp_test.go b/chirp/chirp_test.go index c29f36073..48c4b45fe 100644 --- a/chirp/chirp_test.go +++ b/chirp/chirp_test.go @@ -8,9 +8,12 @@ import ( "errors" "fmt" "net" + "os" "path/filepath" "strings" + "sync" "testing" + "time" ) type fakeBIRD struct { @@ -109,3 +112,82 @@ func TestChirp(t *testing.T) { t.Fatalf("disabling %q succeded", "rando") } } + +type hangingListener struct { + net.Listener + t *testing.T + done chan struct{} + wg sync.WaitGroup + sock string +} + +func newHangingListener(t *testing.T) *hangingListener { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + return &hangingListener{ + Listener: l, + t: t, + done: make(chan struct{}), + sock: sock, + } +} + +func (hl *hangingListener) Stop() { + hl.Close() + close(hl.done) + hl.wg.Wait() +} + +func (hl *hangingListener) listen() error { + for { + c, err := hl.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + hl.wg.Add(1) + go hl.handle(c) + } +} + +func (hl *hangingListener) handle(c net.Conn) { + defer hl.wg.Done() + + // Write our fake first line of response so that we get into the read loop + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + hl.t.Logf("connection still hanging") + case <-hl.done: + return + } + } +} + +func TestChirpTimeout(t *testing.T) { + fb := newHangingListener(t) + defer fb.Stop() + go fb.listen() + + c, err := newWithTimeout(fb.sock, 500*time.Millisecond) + if err != nil { + t.Fatal(err) + } + + err = c.EnableProtocol("tailscale") + if err == nil { + t.Fatal("got err=nil, want timeout") + } + if !os.IsTimeout(err) { + t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) + } +}