receiver: only permit one notify callback

There is no point spamming a list for every function call, there is no
use case where multiple notify callbacks would be useful.
wip-fakessh-exit-status
David Wilson 7 years ago
parent 265d9f0293
commit 849ccebe04

@ -283,9 +283,10 @@ def _queue_interruptible_get(queue, timeout=None, block=True):
class Receiver(object): class Receiver(object):
notify = None
def __init__(self, router, handle=None, persist=True, respondent=None): def __init__(self, router, handle=None, persist=True, respondent=None):
self.router = router self.router = router
self.notify = []
self.handle = handle # Avoid __repr__ crash in add_handler() self.handle = handle # Avoid __repr__ crash in add_handler()
self.handle = router.add_handler(self._on_receive, handle, self.handle = router.add_handler(self._on_receive, handle,
persist, respondent) persist, respondent)
@ -298,8 +299,8 @@ class Receiver(object):
"""Callback from the Stream; appends data to the internal queue.""" """Callback from the Stream; appends data to the internal queue."""
IOLOG.debug('%r._on_receive(%r)', self, msg) IOLOG.debug('%r._on_receive(%r)', self, msg)
self._queue.put(msg) self._queue.put(msg)
for func in self.notify: if self.notify:
func(self) self.notify(self)
def close(self): def close(self):
self._queue.put(_DEAD) self._queue.put(_DEAD)

@ -237,18 +237,19 @@ class SelectError(mitogen.core.Error):
class Select(object): class Select(object):
notify = None
def __init__(self, receivers=(), oneshot=True): def __init__(self, receivers=(), oneshot=True):
self._receivers = [] self._receivers = []
self._oneshot = oneshot self._oneshot = oneshot
self._queue = Queue.Queue() self._queue = Queue.Queue()
self.notify = []
for recv in receivers: for recv in receivers:
self.add(recv) self.add(recv)
def _put(self, value): def _put(self, value):
self._queue.put(value) self._queue.put(value)
for func in self.notify: if self.notify:
func(self) self.notify(self)
def __bool__(self): def __bool__(self):
return bool(self._receivers) return bool(self._receivers)
@ -276,12 +277,17 @@ class Select(object):
if isinstance(recv_, Select): if isinstance(recv_, Select):
recv_._check_no_loop(recv) recv_._check_no_loop(recv)
owned_msg = 'Cannot add: Receiver is already owned by another Select'
def add(self, recv): def add(self, recv):
if isinstance(recv, Select): if isinstance(recv, Select):
recv._check_no_loop(self) recv._check_no_loop(self)
self._receivers.append(recv) self._receivers.append(recv)
recv.notify.append(self._put) if recv.notify is not None:
raise SelectError(self.owned_msg)
recv.notify = self._put
# Avoid race by polling once after installation. # Avoid race by polling once after installation.
if not recv.empty(): if not recv.empty():
self._put(recv) self._put(recv)
@ -290,8 +296,10 @@ class Select(object):
def remove(self, recv): def remove(self, recv):
try: try:
if recv.notify != self._put:
raise ValueError
self._receivers.remove(recv) self._receivers.remove(recv)
recv.notify.remove(self._put) recv.notify = None
except (IndexError, ValueError): except (IndexError, ValueError):
raise SelectError(self.not_present_msg) raise SelectError(self.not_present_msg)

@ -14,8 +14,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(recv) select.add(recv)
self.assertEquals(1, len(select._receivers)) self.assertEquals(1, len(select._receivers))
self.assertEquals(recv, select._receivers[0]) self.assertEquals(recv, select._receivers[0])
self.assertEquals(1, len(recv.notify)) self.assertEquals(select._put, recv.notify)
self.assertEquals(select._put, recv.notify[0])
def test_channel(self): def test_channel(self):
context = self.router.local() context = self.router.local()
@ -24,8 +23,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(chan) select.add(chan)
self.assertEquals(1, len(select._receivers)) self.assertEquals(1, len(select._receivers))
self.assertEquals(chan, select._receivers[0]) self.assertEquals(chan, select._receivers[0])
self.assertEquals(1, len(chan.notify)) self.assertEquals(select._put, chan.notify)
self.assertEquals(select._put, chan.notify[0])
def test_subselect_empty(self): def test_subselect_empty(self):
select = self.klass() select = self.klass()
@ -33,8 +31,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(subselect) select.add(subselect)
self.assertEquals(1, len(select._receivers)) self.assertEquals(1, len(select._receivers))
self.assertEquals(subselect, select._receivers[0]) self.assertEquals(subselect, select._receivers[0])
self.assertEquals(1, len(subselect.notify)) self.assertEquals(select._put, subselect.notify)
self.assertEquals(select._put, subselect.notify[0])
def test_subselect_nonempty(self): def test_subselect_nonempty(self):
recv = mitogen.core.Receiver(self.router) recv = mitogen.core.Receiver(self.router)
@ -45,8 +42,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(subselect) select.add(subselect)
self.assertEquals(1, len(select._receivers)) self.assertEquals(1, len(select._receivers))
self.assertEquals(subselect, select._receivers[0]) self.assertEquals(subselect, select._receivers[0])
self.assertEquals(1, len(subselect.notify)) self.assertEquals(select._put, subselect.notify)
self.assertEquals(select._put, subselect.notify[0])
def test_subselect_loop_direct(self): def test_subselect_loop_direct(self):
select = self.klass() select = self.klass()
@ -65,6 +61,22 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
lambda: s2.add(s0)) lambda: s2.add(s0))
self.assertEquals(str(exc), self.klass.loop_msg) self.assertEquals(str(exc), self.klass.loop_msg)
def test_double_add_receiver(self):
select = self.klass()
recv = mitogen.core.Receiver(self.router)
select.add(recv)
exc = self.assertRaises(mitogen.master.SelectError,
lambda: select.add(recv))
self.assertEquals(str(exc), self.klass.owned_msg)
def test_double_add_subselect(self):
select = self.klass()
select2 = self.klass()
select.add(select2)
exc = self.assertRaises(mitogen.master.SelectError,
lambda: select.add(select2))
self.assertEquals(str(exc), self.klass.owned_msg)
class RemoveTest(testlib.RouterMixin, testlib.TestCase): class RemoveTest(testlib.RouterMixin, testlib.TestCase):
klass = mitogen.master.Select klass = mitogen.master.Select
@ -91,7 +103,7 @@ class RemoveTest(testlib.RouterMixin, testlib.TestCase):
select.add(recv) select.add(recv)
select.remove(recv) select.remove(recv)
self.assertEquals(0, len(select._receivers)) self.assertEquals(0, len(select._receivers))
self.assertEquals(0, len(recv.notify)) self.assertEquals(None, recv.notify)
class CloseTest(testlib.RouterMixin, testlib.TestCase): class CloseTest(testlib.RouterMixin, testlib.TestCase):
@ -107,12 +119,11 @@ class CloseTest(testlib.RouterMixin, testlib.TestCase):
select.add(recv) select.add(recv)
self.assertEquals(1, len(select._receivers)) self.assertEquals(1, len(select._receivers))
self.assertEquals(1, len(recv.notify)) self.assertEquals(select._put, recv.notify)
self.assertEquals(select._put, recv.notify[0])
select.close() select.close()
self.assertEquals(0, len(select._receivers)) self.assertEquals(0, len(select._receivers))
self.assertEquals(0, len(recv.notify)) self.assertEquals(None, recv.notify)
def test_one_subselect(self): def test_one_subselect(self):
select = self.klass() select = self.klass()
@ -123,16 +134,15 @@ class CloseTest(testlib.RouterMixin, testlib.TestCase):
subselect.add(recv) subselect.add(recv)
self.assertEquals(1, len(select._receivers)) self.assertEquals(1, len(select._receivers))
self.assertEquals(1, len(recv.notify)) self.assertEquals(subselect._put, recv.notify)
self.assertEquals(subselect._put, recv.notify[0])
select.close() select.close()
self.assertEquals(0, len(select._receivers)) self.assertEquals(0, len(select._receivers))
self.assertEquals(1, len(recv.notify)) self.assertEquals(subselect._put, recv.notify)
subselect.close() subselect.close()
self.assertEquals(0, len(recv.notify)) self.assertEquals(None, recv.notify)
class EmptyTest(testlib.RouterMixin, testlib.TestCase): class EmptyTest(testlib.RouterMixin, testlib.TestCase):
@ -186,7 +196,7 @@ class OneShotTest(testlib.RouterMixin, testlib.TestCase):
recv, (msg_, data) = select.get() recv, (msg_, data) = select.get()
self.assertEquals(msg, msg_) self.assertEquals(msg, msg_)
self.assertEquals(0, len(select._receivers)) self.assertEquals(0, len(select._receivers))
self.assertEquals(0, len(recv.notify)) self.assertEquals(None, recv.notify)
def test_false_persists_after_get(self): def test_false_persists_after_get(self):
recv = mitogen.core.Receiver(self.router) recv = mitogen.core.Receiver(self.router)
@ -196,8 +206,7 @@ class OneShotTest(testlib.RouterMixin, testlib.TestCase):
self.assertEquals((recv, (msg, '123')), select.get()) self.assertEquals((recv, (msg, '123')), select.get())
self.assertEquals(1, len(select._receivers)) self.assertEquals(1, len(select._receivers))
self.assertEquals(recv, select._receivers[0]) self.assertEquals(recv, select._receivers[0])
self.assertEquals(1, len(recv.notify)) self.assertEquals(select._put, recv.notify)
self.assertEquals(select._put, recv.notify[0])
class GetTest(testlib.RouterMixin, testlib.TestCase): class GetTest(testlib.RouterMixin, testlib.TestCase):

Loading…
Cancel
Save