diff options
-rw-r--r-- | lib/git/async/pool.py | 55 | ||||
-rw-r--r-- | lib/git/async/task.py | 7 | ||||
-rw-r--r-- | test/git/async/test_pool.py | 159 |
3 files changed, 172 insertions, 49 deletions
diff --git a/lib/git/async/pool.py b/lib/git/async/pool.py index f7c1cfe0..2ec18f1a 100644 --- a/lib/git/async/pool.py +++ b/lib/git/async/pool.py @@ -25,6 +25,7 @@ from channel import ( ) import sys +import weakref from time import sleep @@ -33,25 +34,37 @@ class RPoolChannel(CallbackRChannel): before and after an item is to be read. It acts like a handle to the underlying task in the pool.""" - __slots__ = ('_task', '_pool') + __slots__ = ('_task_ref', '_pool_ref') def __init__(self, wchannel, task, pool): CallbackRChannel.__init__(self, wchannel) - self._task = task - self._pool = pool + self._task_ref = weakref.ref(task) + self._pool_ref = weakref.ref(pool) def __del__(self): """Assures that our task will be deleted if we were the last reader""" - del(self._wc) # decrement ref-count early - # now, if this is the last reader to the wc we just handled, there + task = self._task_ref() + if task is None: + return + + pool = self._pool_ref() + if pool is None: + return + + # if this is the last reader to the wc we just handled, there # is no way anyone will ever read from the task again. If so, # delete the task in question, it will take care of itself and orphans # it might leave # 1 is ourselves, + 1 for the call + 1, and 3 magical ones which # I can't explain, but appears to be normal in the destructor # On the caller side, getrefcount returns 2, as expected + # When just calling remove_task, + # it has no way of knowing that the write channel is about to diminsh. + # which is why we pass the info as a private kwarg - not nice, but + # okay for now + # TODO: Fix this - private/public method if sys.getrefcount(self) < 6: - self._pool.remove_task(self._task) + pool.remove_task(task, _from_destructor_=True) # END handle refcount based removal of task def read(self, count=0, block=True, timeout=None): @@ -72,11 +85,16 @@ class RPoolChannel(CallbackRChannel): # if the user tries to use us to read from a done task, we will never # compute as all produced items are already in the channel - skip_compute = self._task.is_done() or self._task.error() + task = self._task_ref() + if task is None: + return list() + # END abort if task was deleted + + skip_compute = task.is_done() or task.error() ########## prepare ############################## if not skip_compute: - self._pool._prepare_channel_read(self._task, count) + self._pool_ref()._prepare_channel_read(task, count) # END prepare pool scheduling @@ -261,11 +279,16 @@ class Pool(object): # END for each task to process - def _remove_task_if_orphaned(self, task): + def _remove_task_if_orphaned(self, task, from_destructor): """Check the task, and delete it if it is orphaned""" # 1 as its stored on the task, 1 for the getrefcount call - if sys.getrefcount(task._out_wc) < 3: - self.remove_task(task) + # If we are getting here from the destructor of an RPool channel, + # its totally valid to virtually decrement the refcount by 1 as + # we can expect it to drop once the destructor completes, which is when + # we finish all recursive calls + max_ref_count = 3 + from_destructor + if sys.getrefcount(task.wchannel()) < max_ref_count: + self.remove_task(task, from_destructor) #} END internal #{ Interface @@ -335,7 +358,7 @@ class Pool(object): finally: self._taskgraph_lock.release() - def remove_task(self, task): + def remove_task(self, task, _from_destructor_=False): """Delete the task Additionally we will remove orphaned tasks, which can be identified if their output channel is only held by themselves, so no one will ever consume @@ -370,7 +393,7 @@ class Pool(object): # END locked deletion for t in in_tasks: - self._remove_task_if_orphaned(t) + self._remove_task_if_orphaned(t, _from_destructor_) # END handle orphans recursively return self @@ -409,11 +432,11 @@ class Pool(object): # If the input channel is one of our read channels, we add the relation if isinstance(task, InputChannelTask): - ic = task.in_rc - if isinstance(ic, RPoolChannel) and ic._pool is self: + ic = task.rchannel() + if isinstance(ic, RPoolChannel) and ic._pool_ref() is self: self._taskgraph_lock.acquire() try: - self._tasks.add_edge(ic._task, task) + self._tasks.add_edge(ic._task_ref(), task) finally: self._taskgraph_lock.release() # END handle edge-adding diff --git a/lib/git/async/task.py b/lib/git/async/task.py index f98336b2..03b40492 100644 --- a/lib/git/async/task.py +++ b/lib/git/async/task.py @@ -208,5 +208,8 @@ class InputChannelTask(OutputChannelTask): OutputChannelTask.__init__(self, *args, **kwargs) self._read = in_rc.read - #{ Configuration - + def rchannel(self): + """:return: input channel from which we read""" + # the instance is bound in its instance method - lets use this to keep + # the refcount at one ( per consumer ) + return self._read.im_self diff --git a/test/git/async/test_pool.py b/test/git/async/test_pool.py index 202fdb66..2a5e4647 100644 --- a/test/git/async/test_pool.py +++ b/test/git/async/test_pool.py @@ -8,15 +8,14 @@ import threading import time import sys -class TestThreadTaskNode(InputIteratorThreadTask): +class _TestTaskBase(object): def __init__(self, *args, **kwargs): - super(TestThreadTaskNode, self).__init__(*args, **kwargs) + super(_TestTaskBase, self).__init__(*args, **kwargs) self.should_fail = False self.lock = threading.Lock() # yes, can't safely do x = x + 1 :) self.plock = threading.Lock() self.item_count = 0 self.process_count = 0 - self._scheduled_items = 0 def do_fun(self, item): self.lock.acquire() @@ -32,44 +31,118 @@ class TestThreadTaskNode(InputIteratorThreadTask): self.plock.acquire() self.process_count += 1 self.plock.release() - super(TestThreadTaskNode, self).process(count) + super(_TestTaskBase, self).process(count) def _assert(self, pc, fc, check_scheduled=False): """Assert for num process counts (pc) and num function counts (fc) :return: self""" - # TODO: fixme - return self - self.plock.acquire() - if self.process_count != pc: - print self.process_count, pc - assert self.process_count == pc - self.plock.release() self.lock.acquire() if self.item_count != fc: print self.item_count, fc assert self.item_count == fc self.lock.release() - # if we read all, we can't really use scheduled items - if check_scheduled: - assert self._scheduled_items == 0 - assert not self.error() return self + +class TestThreadTaskNode(_TestTaskBase, InputIteratorThreadTask): + pass class TestThreadFailureNode(TestThreadTaskNode): """Fails after X items""" + def __init__(self, *args, **kwargs): + self.fail_after = kwargs.pop('fail_after') + super(TestThreadFailureNode, self).__init__(*args, **kwargs) + def do_fun(self, item): + item = TestThreadTaskNode.do_fun(self, item) + if self.item_count > self.fail_after: + raise AssertionError("Simulated failure after processing %i items" % self.fail_after) + return item + + +class TestThreadInputChannelTaskNode(_TestTaskBase, InputChannelTask): + """Apply a transformation on items read from an input channel""" + + def do_fun(self, item): + """return tuple(i, i*2)""" + item = super(TestThreadInputChannelTaskNode, self).do_fun(item) + if isinstance(item, tuple): + i = item[0] + return item + (i * self.id, ) + else: + return (item, item * self.id) + # END handle tuple + + +class TestThreadInputChannelVerifyTaskNode(_TestTaskBase, InputChannelTask): + """An input channel task, which verifies the result of its input channels, + should be last in the chain. + Id must be int""" + + def do_fun(self, item): + """return tuple(i, i*2)""" + item = super(TestThreadInputChannelTaskNode, self).do_fun(item) + + # make sure the computation order matches + assert isinstance(item, tuple) + + base = item[0] + for num in item[1:]: + assert num == base * 2 + base = num + # END verify order + + return item + + class TestThreadPool(TestBase): max_threads = cpu_count() - def _add_triple_task(self, p): - """Add a triplet of feeder, transformer and finalizer to the pool, like - t1 -> t2 -> t3, return all 3 return channels in order""" - # t1 = TestThreadTaskNode(make_task(), 'iterator', None) - # TODO: + def _add_task_chain(self, p, ni, count=1): + """Create a task chain of feeder, count transformers and order verifcator + to the pool p, like t1 -> t2 -> t3 + :return: tuple(list(task1, taskN, ...), list(rc1, rcN, ...))""" + nt = p.num_tasks() + + feeder = self._make_iterator_task(ni) + frc = p.add_task(feeder) + + assert p.num_tasks() == nt + 1 + + rcs = [frc] + tasks = [feeder] + + inrc = frc + for tc in xrange(count): + t = TestThreadInputChannelTaskNode(inrc, tc, None) + t.fun = t.do_fun + inrc = p.add_task(t) + + tasks.append(t) + rcs.append(inrc) + assert p.num_tasks() == nt + 2 + tc + # END create count transformers + + verifier = TestThreadInputChannelVerifyTaskNode(inrc, 'verifier', None) + verifier.fun = verifier.do_fun + vrc = p.add_task(verifier) + + assert p.num_tasks() == nt + tc + 3 + + tasks.append(verifier) + rcs.append(vrc) + return tasks, rcs + + def _make_iterator_task(self, ni, taskcls=TestThreadTaskNode, **kwargs): + """:return: task which yields ni items + :param taskcls: the actual iterator type to use + :param **kwargs: additional kwargs to be passed to the task""" + t = taskcls(iter(range(ni)), 'iterator', None, **kwargs) + t.fun = t.do_fun + return t def _assert_single_task(self, p, async=False): """Performs testing in a synchronized environment""" @@ -82,11 +155,7 @@ class TestThreadPool(TestBase): assert ni % 2 == 0, "ni needs to be dividable by 2" assert ni % 4 == 0, "ni needs to be dividable by 4" - def make_task(): - t = TestThreadTaskNode(iter(range(ni)), 'iterator', None) - t.fun = t.do_fun - return t - # END utility + make_task = lambda *args, **kwargs: self._make_iterator_task(ni, *args, **kwargs) task = make_task() @@ -252,15 +321,44 @@ class TestThreadPool(TestBase): # test failure after ni / 2 items # This makes sure it correctly closes the channel on failure to prevent blocking + nri = ni/2 + task = make_task(TestThreadFailureNode, fail_after=ni/2) + rc = p.add_task(task) + assert len(rc.read()) == nri + assert task.is_done() + assert isinstance(task.error(), AssertionError) - def _assert_async_dependent_tasks(self, p): + def _assert_async_dependent_tasks(self, pool): # includes failure in center task, 'recursive' orphan cleanup # This will also verify that the channel-close mechanism works # t1 -> t2 -> t3 # t1 -> x -> t3 - pass + null_tasks = pool.num_tasks() + ni = 100 + count = 1 + make_task = lambda *args, **kwargs: self._add_task_chain(pool, ni, count, *args, **kwargs) + + ts, rcs = make_task() + assert len(ts) == count + 2 + assert len(rcs) == count + 2 + assert pool.num_tasks() == null_tasks + len(ts) + print pool._tasks.nodes + + + # in the end, we expect all tasks to be gone, automatically + + + + # order of deletion matters - just keep the end, then delete + final_rc = rcs[-1] + del(ts) + del(rcs) + del(final_rc) + assert pool.num_tasks() == null_tasks + + @terminate_threads def test_base(self): @@ -301,8 +399,8 @@ class TestThreadPool(TestBase): assert p.num_tasks() == 0 - # DEPENDENT TASKS SERIAL - ######################## + # DEPENDENT TASKS SYNC MODE + ########################### self._assert_async_dependent_tasks(p) @@ -311,12 +409,11 @@ class TestThreadPool(TestBase): # step one gear up - just one thread for now. p.set_size(1) assert p.size() == 1 - print len(threading.enumerate()), num_threads assert len(threading.enumerate()) == num_threads + 1 # deleting the pool stops its threads - just to be sure ;) # Its not synchronized, hence we wait a moment del(p) - time.sleep(0.25) + time.sleep(0.05) assert len(threading.enumerate()) == num_threads p = ThreadPool(1) |