diff --git a/mitogen/master.py b/mitogen/master.py index 5c1f40e8..e9cad282 100644 --- a/mitogen/master.py +++ b/mitogen/master.py @@ -157,14 +157,25 @@ def tty_create_child(*args): return pid, master_fd -def write_all(fd, s): +def write_all(fd, s, deadline=None): + timeout = None written = 0 + while written < len(s): - rc = os.write(fd, buffer(s, written)) - if not rc: - raise IOError('short write') - written += rc - return written + if deadline is not None: + timeout = max(0, deadline - time.time()) + if timeout == 0: + raise mitogen.core.TimeoutError('write timed out') + + _, wfds, _ = select.select([], [fd], [], timeout) + if not wfds: + continue + + n, disconnected = mitogen.core.io_op(os.write, fd, buffer(s, written)) + if disconnected: + raise mitogen.core.StreamError('EOF on stream during write') + + written += n def iter_read(fd, deadline=None): diff --git a/tests/data/write_all_consumer.sh b/tests/data/write_all_consumer.sh new file mode 100755 index 00000000..e6aaaf72 --- /dev/null +++ b/tests/data/write_all_consumer.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# I consume 65535 bytes every 10ms, for testing mitogen.core.write_all() + +while :; do + read -n 65535 + sleep 0.01 +done diff --git a/tests/master_test.py b/tests/master_test.py index 72318f46..2bea6193 100644 --- a/tests/master_test.py +++ b/tests/master_test.py @@ -12,7 +12,9 @@ class IterReadTest(unittest.TestCase): def make_proc(self): args = [testlib.data_path('iter_read_generator.sh')] - return subprocess.Popen(args, stdout=subprocess.PIPE) + proc = subprocess.Popen(args, stdout=subprocess.PIPE) + mitogen.core.set_nonblock(proc.stdout.fileno()) + return proc def test_no_deadline(self): proc = self.make_proc() @@ -54,3 +56,43 @@ class IterReadTest(unittest.TestCase): assert 3 < len(got) < 5 finally: proc.terminate() + + +class WriteAllTest(unittest.TestCase): + func = staticmethod(mitogen.master.write_all) + + def make_proc(self): + args = [testlib.data_path('write_all_consumer.sh')] + proc = subprocess.Popen(args, stdin=subprocess.PIPE) + mitogen.core.set_nonblock(proc.stdin.fileno()) + return proc + + ten_ms_chunk = ('x' * 65535) + + def test_no_deadline(self): + proc = self.make_proc() + try: + self.func(proc.stdin.fileno(), self.ten_ms_chunk) + finally: + proc.terminate() + + def test_deadline_exceeded_before_call(self): + proc = self.make_proc() + try: + self.assertRaises(mitogen.core.TimeoutError, ( + lambda: self.func(proc.stdin.fileno(), self.ten_ms_chunk, 0) + )) + finally: + proc.terminate() + + def test_deadline_exceeded_during_call(self): + proc = self.make_proc() + try: + deadline = time.time() + 0.1 # 100ms deadline + self.assertRaises(mitogen.core.TimeoutError, ( + lambda: self.func(proc.stdin.fileno(), + self.ten_ms_chunk * 100, # 1s of data + deadline) + )) + finally: + proc.terminate()