diff --git a/docs/api.rst b/docs/api.rst index 3ea787dc..0b420749 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1104,6 +1104,11 @@ Exceptions Raised when a channel dies or has been closed. +.. class:: LatchError (fmt, \*args) + + Raised when an attempt is made to use a :py:class:`mitogen.core.Latch` that + has been marked closed. + .. class:: StreamError (fmt, \*args) Raised when a stream cannot be established. diff --git a/docs/internals.rst b/docs/internals.rst index 84f70521..71a03fb7 100644 --- a/docs/internals.rst +++ b/docs/internals.rst @@ -54,6 +54,9 @@ Latch Class If :py:data:`False`, immediately raise :py:class:`mitogen.core.TimeoutError` if the latch is empty. + :raises mitogen.core.LatchError: + :py:meth:`close` has been called, and the object is no longer valid. + :raises mitogen.core.TimeoutError: Timeout was reached. @@ -62,9 +65,17 @@ Latch Class .. method:: put (obj) - Enquue an object on this latch, waking the first thread that is asleep + Enqueue an object on this latch, waking the first thread that is asleep waiting for a result, if one exists. + :raises mitogen.core.LatchError: + :py:meth:`close` has been called, and the object is no longer valid. + + .. method:: close () + + Mark the latch as closed, and cause every sleeping thread to be woken, + with :py:class:`mitogen.core.LatchError` raised in each thread. + Side Class ---------- diff --git a/mitogen/core.py b/mitogen/core.py index ad299dea..6105e069 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -88,6 +88,10 @@ class Error(Exception): Exception.__init__(self, fmt) +class LatchError(Error): + pass + + class CallError(Error): def __init__(self, e): s = '%s.%s: %s' % (type(e).__module__, type(e).__name__, e) @@ -902,6 +906,8 @@ def _unpickle_context(router, context_id, name): class Latch(object): + closed = False + def __init__(self): self.lock = threading.Lock() self.queue = [] @@ -913,12 +919,23 @@ class Latch(object): set_cloexec(_tls.rsock.fileno()) set_cloexec(_tls.wsock.fileno()) + def close(self): + self.lock.acquire() + try: + self.closed = True + for sock in self.wake_socks: + self._wake(sock) + finally: + self.lock.release() + def empty(self): return len(self.queue) == 0 def get(self, timeout=None, block=True): self.lock.acquire() try: + if self.closed: + raise LatchError() if self.queue: return self.queue.pop(0) if not block: @@ -933,6 +950,8 @@ class Latch(object): self.lock.acquire() try: + if self.closed: + raise LatchError() if _tls.wsock in self.wake_socks: # Nothing woke us, remove stale entry. self.wake_socks.remove(_tls.wsock) @@ -948,6 +967,8 @@ class Latch(object): _vv and IOLOG.debug('%r.put(%r)', self, obj) self.lock.acquire() try: + if self.closed: + raise LatchError() self.queue.append(obj) woken = len(self.wake_socks) > 0 if woken: diff --git a/tests/latch_test.py b/tests/latch_test.py index 2922cd71..4ea2bc50 100644 --- a/tests/latch_test.py +++ b/tests/latch_test.py @@ -1,4 +1,6 @@ +import threading + import unittest2 import mitogen.core @@ -52,6 +54,51 @@ class GetTest(testlib.TestCase): self.assertEquals(obj, latch.get(timeout=0)) +class ThreadedGetTest(testlib.TestCase): + klass = mitogen.core.Latch + + def setUp(self): + super(ThreadedGetTest, self).setUp() + self.results = [] + self.excs = [] + self.threads = [] + + def _worker(self, func): + try: + self.results.append(func()) + except Exception, e: + self.results.append(None) + self.excs.append(e) + + def start_one(self, func): + thread = threading.Thread(target=self._worker, args=(func,)) + thread.start() + self.threads.append(thread) + + def join(self): + for th in self.threads: + th.join(3.0) + + def test_one_thread(self): + latch = self.klass() + self.start_one(lambda: latch.get(timeout=3.0)) + latch.put('test') + self.join() + self.assertEquals(self.results, ['test']) + self.assertEquals(self.excs, []) + + def test_five_threads(self): + latch = self.klass() + for x in xrange(5): + self.start_one(lambda: latch.get(timeout=3.0)) + for x in xrange(5): + latch.put(x) + self.join() + self.assertEquals(sorted(self.results), range(5)) + self.assertEquals(self.excs, []) + + + class PutTest(testlib.TestCase): klass = mitogen.core.Latch @@ -61,5 +108,102 @@ class PutTest(testlib.TestCase): self.assertEquals(None, latch.get()) +class CloseTest(testlib.TestCase): + klass = mitogen.core.Latch + + def test_empty_noblock(self): + latch = self.klass() + latch.close() + self.assertRaises(mitogen.core.LatchError, + lambda: latch.get(block=False)) + + def test_empty_zero_timeout(self): + latch = self.klass() + latch.close() + self.assertRaises(mitogen.core.LatchError, + lambda: latch.get(timeout=0)) + + def test_nonempty(self): + obj = object() + latch = self.klass() + latch.put(obj) + latch.close() + self.assertRaises(mitogen.core.LatchError, + lambda: latch.get()) + + def test_nonempty_noblock(self): + obj = object() + latch = self.klass() + latch.put(obj) + latch.close() + self.assertRaises(mitogen.core.LatchError, + lambda: latch.get(block=False)) + + def test_nonempty_zero_timeout(self): + obj = object() + latch = self.klass() + latch.put(obj) + latch.close() + self.assertRaises(mitogen.core.LatchError, + lambda: latch.get(timeout=0)) + + def test_put(self): + latch = self.klass() + latch.close() + self.assertRaises(mitogen.core.LatchError, + lambda: latch.put(None)) + + def test_double_close(self): + latch = self.klass() + latch.close() + latch.close() + + +class ThreadedCloseTest(testlib.TestCase): + klass = mitogen.core.Latch + + def setUp(self): + super(ThreadedCloseTest, self).setUp() + self.results = [] + self.excs = [] + self.threads = [] + + def _worker(self, func): + try: + self.results.append(func()) + except Exception, e: + self.results.append(None) + self.excs.append(e) + + def start_one(self, func): + thread = threading.Thread(target=self._worker, args=(func,)) + thread.start() + self.threads.append(thread) + + def join(self): + for th in self.threads: + th.join(3.0) + + def test_one_thread(self): + latch = self.klass() + self.start_one(lambda: latch.get(timeout=3.0)) + latch.close() + self.join() + self.assertEquals(self.results, [None]) + for exc in self.excs: + self.assertTrue(isinstance(exc, mitogen.core.LatchError)) + + def test_five_threads(self): + latch = self.klass() + for x in xrange(5): + self.start_one(lambda: latch.get(timeout=3.0)) + latch.close() + self.join() + self.assertEquals(self.results, [None]*5) + for exc in self.excs: + self.assertTrue(isinstance(exc, mitogen.core.LatchError)) + + + if __name__ == '__main__': unittest2.main()