summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/git/async/pool.py55
-rw-r--r--lib/git/async/task.py7
-rw-r--r--test/git/async/test_pool.py159
3 files changed, 172 insertions, 49 deletions
diff --git a/lib/git/async/pool.py b/lib/git/async/pool.py
index f7c1cfe0..2ec18f1a 100644
--- a/lib/git/async/pool.py
+++ b/lib/git/async/pool.py
@@ -25,6 +25,7 @@ from channel import (
)
import sys
+import weakref
from time import sleep
@@ -33,25 +34,37 @@ class RPoolChannel(CallbackRChannel):
before and after an item is to be read.
It acts like a handle to the underlying task in the pool."""
- __slots__ = ('_task', '_pool')
+ __slots__ = ('_task_ref', '_pool_ref')
def __init__(self, wchannel, task, pool):
CallbackRChannel.__init__(self, wchannel)
- self._task = task
- self._pool = pool
+ self._task_ref = weakref.ref(task)
+ self._pool_ref = weakref.ref(pool)
def __del__(self):
"""Assures that our task will be deleted if we were the last reader"""
- del(self._wc) # decrement ref-count early
- # now, if this is the last reader to the wc we just handled, there
+ task = self._task_ref()
+ if task is None:
+ return
+
+ pool = self._pool_ref()
+ if pool is None:
+ return
+
+ # if this is the last reader to the wc we just handled, there
# is no way anyone will ever read from the task again. If so,
# delete the task in question, it will take care of itself and orphans
# it might leave
# 1 is ourselves, + 1 for the call + 1, and 3 magical ones which
# I can't explain, but appears to be normal in the destructor
# On the caller side, getrefcount returns 2, as expected
+ # When just calling remove_task,
+ # it has no way of knowing that the write channel is about to diminish,
+ # which is why we pass the info as a private kwarg - not nice, but
+ # okay for now
+ # TODO: Fix this - private/public method
if sys.getrefcount(self) < 6:
- self._pool.remove_task(self._task)
+ pool.remove_task(task, _from_destructor_=True)
# END handle refcount based removal of task
def read(self, count=0, block=True, timeout=None):
@@ -72,11 +85,16 @@ class RPoolChannel(CallbackRChannel):
# if the user tries to use us to read from a done task, we will never
# compute as all produced items are already in the channel
- skip_compute = self._task.is_done() or self._task.error()
+ task = self._task_ref()
+ if task is None:
+ return list()
+ # END abort if task was deleted
+
+ skip_compute = task.is_done() or task.error()
########## prepare ##############################
if not skip_compute:
- self._pool._prepare_channel_read(self._task, count)
+ self._pool_ref()._prepare_channel_read(task, count)
# END prepare pool scheduling
@@ -261,11 +279,16 @@ class Pool(object):
# END for each task to process
- def _remove_task_if_orphaned(self, task):
+ def _remove_task_if_orphaned(self, task, from_destructor):
"""Check the task, and delete it if it is orphaned"""
# 1 as its stored on the task, 1 for the getrefcount call
- if sys.getrefcount(task._out_wc) < 3:
- self.remove_task(task)
+ # If we are getting here from the destructor of an RPool channel,
+ # it's totally valid to virtually decrement the refcount by 1 as
+ # we can expect it to drop once the destructor completes, which is when
+ # we finish all recursive calls
+ max_ref_count = 3 + from_destructor
+ if sys.getrefcount(task.wchannel()) < max_ref_count:
+ self.remove_task(task, from_destructor)
#} END internal
#{ Interface
@@ -335,7 +358,7 @@ class Pool(object):
finally:
self._taskgraph_lock.release()
- def remove_task(self, task):
+ def remove_task(self, task, _from_destructor_=False):
"""Delete the task
Additionally we will remove orphaned tasks, which can be identified if their
output channel is only held by themselves, so no one will ever consume
@@ -370,7 +393,7 @@ class Pool(object):
# END locked deletion
for t in in_tasks:
- self._remove_task_if_orphaned(t)
+ self._remove_task_if_orphaned(t, _from_destructor_)
# END handle orphans recursively
return self
@@ -409,11 +432,11 @@ class Pool(object):
# If the input channel is one of our read channels, we add the relation
if isinstance(task, InputChannelTask):
- ic = task.in_rc
- if isinstance(ic, RPoolChannel) and ic._pool is self:
+ ic = task.rchannel()
+ if isinstance(ic, RPoolChannel) and ic._pool_ref() is self:
self._taskgraph_lock.acquire()
try:
- self._tasks.add_edge(ic._task, task)
+ self._tasks.add_edge(ic._task_ref(), task)
finally:
self._taskgraph_lock.release()
# END handle edge-adding
diff --git a/lib/git/async/task.py b/lib/git/async/task.py
index f98336b2..03b40492 100644
--- a/lib/git/async/task.py
+++ b/lib/git/async/task.py
@@ -208,5 +208,8 @@ class InputChannelTask(OutputChannelTask):
OutputChannelTask.__init__(self, *args, **kwargs)
self._read = in_rc.read
- #{ Configuration
-
+ def rchannel(self):
+ """:return: input channel from which we read"""
+ # the instance is bound in its instance method - let's use this to keep
+ # the refcount at one ( per consumer )
+ return self._read.im_self
diff --git a/test/git/async/test_pool.py b/test/git/async/test_pool.py
index 202fdb66..2a5e4647 100644
--- a/test/git/async/test_pool.py
+++ b/test/git/async/test_pool.py
@@ -8,15 +8,14 @@ import threading
import time
import sys
-class TestThreadTaskNode(InputIteratorThreadTask):
+class _TestTaskBase(object):
def __init__(self, *args, **kwargs):
- super(TestThreadTaskNode, self).__init__(*args, **kwargs)
+ super(_TestTaskBase, self).__init__(*args, **kwargs)
self.should_fail = False
self.lock = threading.Lock() # yes, can't safely do x = x + 1 :)
self.plock = threading.Lock()
self.item_count = 0
self.process_count = 0
- self._scheduled_items = 0
def do_fun(self, item):
self.lock.acquire()
@@ -32,44 +31,118 @@ class TestThreadTaskNode(InputIteratorThreadTask):
self.plock.acquire()
self.process_count += 1
self.plock.release()
- super(TestThreadTaskNode, self).process(count)
+ super(_TestTaskBase, self).process(count)
def _assert(self, pc, fc, check_scheduled=False):
"""Assert for num process counts (pc) and num function counts (fc)
:return: self"""
- # TODO: fixme
- return self
- self.plock.acquire()
- if self.process_count != pc:
- print self.process_count, pc
- assert self.process_count == pc
- self.plock.release()
self.lock.acquire()
if self.item_count != fc:
print self.item_count, fc
assert self.item_count == fc
self.lock.release()
- # if we read all, we can't really use scheduled items
- if check_scheduled:
- assert self._scheduled_items == 0
- assert not self.error()
return self
+
+class TestThreadTaskNode(_TestTaskBase, InputIteratorThreadTask):
+ pass
class TestThreadFailureNode(TestThreadTaskNode):
"""Fails after X items"""
+ def __init__(self, *args, **kwargs):
+ self.fail_after = kwargs.pop('fail_after')
+ super(TestThreadFailureNode, self).__init__(*args, **kwargs)
+ def do_fun(self, item):
+ item = TestThreadTaskNode.do_fun(self, item)
+ if self.item_count > self.fail_after:
+ raise AssertionError("Simulated failure after processing %i items" % self.fail_after)
+ return item
+
+
+class TestThreadInputChannelTaskNode(_TestTaskBase, InputChannelTask):
+ """Apply a transformation on items read from an input channel"""
+
+ def do_fun(self, item):
+ """return tuple(i, i * self.id), or the input tuple extended by item[0] * self.id"""
+ item = super(TestThreadInputChannelTaskNode, self).do_fun(item)
+ if isinstance(item, tuple):
+ i = item[0]
+ return item + (i * self.id, )
+ else:
+ return (item, item * self.id)
+ # END handle tuple
+
+
+class TestThreadInputChannelVerifyTaskNode(_TestTaskBase, InputChannelTask):
+ """An input channel task, which verifies the result of its input channels,
+ should be last in the chain.
+ Id must be int"""
+
+ def do_fun(self, item):
+ """return the item after asserting each entry is twice its predecessor"""
+ item = super(TestThreadInputChannelVerifyTaskNode, self).do_fun(item)
+
+ # make sure the computation order matches
+ assert isinstance(item, tuple)
+
+ base = item[0]
+ for num in item[1:]:
+ assert num == base * 2
+ base = num
+ # END verify order
+
+ return item
+
+
class TestThreadPool(TestBase):
max_threads = cpu_count()
- def _add_triple_task(self, p):
- """Add a triplet of feeder, transformer and finalizer to the pool, like
- t1 -> t2 -> t3, return all 3 return channels in order"""
- # t1 = TestThreadTaskNode(make_task(), 'iterator', None)
- # TODO:
+ def _add_task_chain(self, p, ni, count=1):
+ """Create a task chain of feeder, count transformers and order verifier
+ to the pool p, like t1 -> t2 -> t3
+ :return: tuple(list(task1, taskN, ...), list(rc1, rcN, ...))"""
+ nt = p.num_tasks()
+
+ feeder = self._make_iterator_task(ni)
+ frc = p.add_task(feeder)
+
+ assert p.num_tasks() == nt + 1
+
+ rcs = [frc]
+ tasks = [feeder]
+
+ inrc = frc
+ for tc in xrange(count):
+ t = TestThreadInputChannelTaskNode(inrc, tc, None)
+ t.fun = t.do_fun
+ inrc = p.add_task(t)
+
+ tasks.append(t)
+ rcs.append(inrc)
+ assert p.num_tasks() == nt + 2 + tc
+ # END create count transformers
+
+ verifier = TestThreadInputChannelVerifyTaskNode(inrc, 'verifier', None)
+ verifier.fun = verifier.do_fun
+ vrc = p.add_task(verifier)
+
+ assert p.num_tasks() == nt + tc + 3
+
+ tasks.append(verifier)
+ rcs.append(vrc)
+ return tasks, rcs
+
+ def _make_iterator_task(self, ni, taskcls=TestThreadTaskNode, **kwargs):
+ """:return: task which yields ni items
+ :param taskcls: the actual iterator type to use
+ :param **kwargs: additional kwargs to be passed to the task"""
+ t = taskcls(iter(range(ni)), 'iterator', None, **kwargs)
+ t.fun = t.do_fun
+ return t
def _assert_single_task(self, p, async=False):
"""Performs testing in a synchronized environment"""
@@ -82,11 +155,7 @@ class TestThreadPool(TestBase):
assert ni % 2 == 0, "ni needs to be dividable by 2"
assert ni % 4 == 0, "ni needs to be dividable by 4"
- def make_task():
- t = TestThreadTaskNode(iter(range(ni)), 'iterator', None)
- t.fun = t.do_fun
- return t
- # END utility
+ make_task = lambda *args, **kwargs: self._make_iterator_task(ni, *args, **kwargs)
task = make_task()
@@ -252,15 +321,44 @@ class TestThreadPool(TestBase):
# test failure after ni / 2 items
# This makes sure it correctly closes the channel on failure to prevent blocking
+ nri = ni/2
+ task = make_task(TestThreadFailureNode, fail_after=ni/2)
+ rc = p.add_task(task)
+ assert len(rc.read()) == nri
+ assert task.is_done()
+ assert isinstance(task.error(), AssertionError)
- def _assert_async_dependent_tasks(self, p):
+ def _assert_async_dependent_tasks(self, pool):
# includes failure in center task, 'recursive' orphan cleanup
# This will also verify that the channel-close mechanism works
# t1 -> t2 -> t3
# t1 -> x -> t3
- pass
+ null_tasks = pool.num_tasks()
+ ni = 100
+ count = 1
+ make_task = lambda *args, **kwargs: self._add_task_chain(pool, ni, count, *args, **kwargs)
+
+ ts, rcs = make_task()
+ assert len(ts) == count + 2
+ assert len(rcs) == count + 2
+ assert pool.num_tasks() == null_tasks + len(ts)
+ print pool._tasks.nodes
+
+
+ # in the end, we expect all tasks to be gone, automatically
+
+
+
+ # order of deletion matters - just keep the end, then delete
+ final_rc = rcs[-1]
+ del(ts)
+ del(rcs)
+ del(final_rc)
+ assert pool.num_tasks() == null_tasks
+
+
@terminate_threads
def test_base(self):
@@ -301,8 +399,8 @@ class TestThreadPool(TestBase):
assert p.num_tasks() == 0
- # DEPENDENT TASKS SERIAL
- ########################
+ # DEPENDENT TASKS SYNC MODE
+ ###########################
self._assert_async_dependent_tasks(p)
@@ -311,12 +409,11 @@ class TestThreadPool(TestBase):
# step one gear up - just one thread for now.
p.set_size(1)
assert p.size() == 1
- print len(threading.enumerate()), num_threads
assert len(threading.enumerate()) == num_threads + 1
# deleting the pool stops its threads - just to be sure ;)
# Its not synchronized, hence we wait a moment
del(p)
- time.sleep(0.25)
+ time.sleep(0.05)
assert len(threading.enumerate()) == num_threads
p = ThreadPool(1)