diff options
| author | Michele Simionato <michele.simionato@gmail.com> | 2015-07-23 07:03:48 +0200 |
|---|---|---|
| committer | Michele Simionato <michele.simionato@gmail.com> | 2015-07-23 07:03:48 +0200 |
| commit | 80072c50a765547736c03810cfb3e4f5afcb928d (patch) | |
| tree | 72e254b8b8b8f542286c4f30d7dd121f9512eca9 /src | |
| parent | 14b603738eed475281a38350bfaf0df5bca5f3c8 (diff) | |
| download | python-decorator-git-80072c50a765547736c03810cfb3e4f5afcb928d.tar.gz | |
Changed the implementation of generic functions
Diffstat (limited to 'src')
| -rw-r--r-- | src/decorator.py | 113 | ||||
| -rw-r--r-- | src/tests/documentation.py | 80 | ||||
| -rw-r--r-- | src/tests/test.py | 45 |
3 files changed, 101 insertions, 137 deletions
diff --git a/src/decorator.py b/src/decorator.py index a7a5439..d71ed60 100644 --- a/src/decorator.py +++ b/src/decorator.py @@ -40,6 +40,7 @@ import re import sys import inspect import itertools +import collections if sys.version >= '3': from inspect import getfullargspec @@ -281,49 +282,27 @@ contextmanager = decorator(ContextManager) # ############################ dispatch_on ############################ # -class _VAManager(object): +def unique(classes): """ - Manage a list of virtual ancestors for each dispatch type. - The list is partially ordered by the `issubclass` comparison operator. + Return a tuple of unique classes by preserving the original order. """ - def __init__(self, n): - self.indices = range(n) - self.vancestors = [[] for _ in self.indices] + known = set([object]) + outlist = [] + for cl in classes: + if cl not in known: + outlist.append(cl) + known.add(cl) + return tuple(outlist) - def insert(self, i, a): - """ - For each index `i` insert a virtual ancestor `a` in the corresponding - list, by keeping the partial ordering. - """ - vancestors = self.vancestors[i] - for j, va in enumerate(vancestors): - if issubclass(a, va) and a is not va: - vancestors.insert(j, a) - break - else: # less specialized - if a not in vancestors: - vancestors.append(a) - - def get_vancestors(self, types): - """ - For each type get the most specialized VA available; return a tuple - """ - class Sentinel(object): - pass - valist = [Sentinel for _ in self.indices] - for i, t, vancestors in zip(self.indices, types, self.vancestors): - for new in vancestors: - if issubclass(t, new): - old = valist[i] - if old is Sentinel or issubclass(new, old): - valist[i] = new - elif issubclass(old, new): - pass - else: - raise RuntimeError( - 'Ambiguous dispatch for %s instance: %s or %s?' - % (t.__name__, old.__name__, new.__name__)) - return tuple(valist) + +def insert(a, vancestors): + for j, va in enumerate(vancestors): + if issubclass(a, va) and a is not va: + vancestors.insert(j, a) + break + else: # less specialized + if a not in vancestors: + vancestors.append(a) # inspired from simplegeneric by P.J. Eby and functools.singledispatch @@ -335,6 +314,12 @@ def dispatch_on(*dispatch_args): assert dispatch_args, 'No dispatch args passed' dispatch_str = '(%s,)' % ', '.join(dispatch_args) + def check(types): + """Make use one passes the expected number of types""" + if len(types) != len(dispatch_args): + raise TypeError('Expected %d types, got %d' % + (len(dispatch_args), len(types))) + def gen_func_dec(func): """Decorator turning a function into a generic function""" @@ -343,14 +328,38 @@ def dispatch_on(*dispatch_args): if not set(dispatch_args) <= argset: raise NameError('Unknown dispatch arguments %s' % dispatch_str) - typemap = {} - man = _VAManager(len(dispatch_args)) + typemap = collections.OrderedDict() + + def vancestors(*types): + """ + Get a list of lists of virtual ancestors for the given types + """ + check(types) + ras = [[] for _ in range(len(dispatch_args))] + for types_ in typemap: + for t, type_, ra in zip(types, types_, ras): + if issubclass(t, type_) and type_ not in t.__mro__: + insert(type_, ra) + return ras + + def mros(*types): + """ + Get a list of MROs, one for each type + """ + check(types) + lists = [] + for t, ancestors in zip(types, vancestors(*types)): + t_ancestors = unique(t.__bases__ + tuple(ancestors)) + if not t_ancestors: + mro = t.__mro__ + else: + mro = type(t.__name__, t_ancestors, {}).__mro__ + lists.append(mro[:-1]) # discard object + return lists def register(*types): "Decorator to register an implementation for the given types" - if len(types) != len(dispatch_args): - raise TypeError('Length mismatch: expected %d types, got %d' % - (len(dispatch_args), len(types))) + check(types) def dec(f): n_args = len(getfullargspec(f).args) @@ -358,9 +367,6 @@ def dispatch_on(*dispatch_args): raise TypeError( '%s has not enough arguments (got %d, expected %d)' % (f, n_args, len(dispatch_args))) - for i, t, va in zip(man.indices, types, man.vancestors): - if isinstance(t, ABCMeta): - man.insert(i, t) typemap[types] = f return f return dec @@ -374,24 +380,19 @@ def dispatch_on(*dispatch_args): pass else: return f(*args, **kw) - for types_ in itertools.product(*(t.__mro__ for t in types)): + for types_ in itertools.product(*mros(*types)): f = typemap.get(types_) if f is not None: return f(*args, **kw) - # else look at the virtual ancestors - if man.vancestors: - f = typemap.get(man.get_vancestors(types)) - if f is not None: - return f(*args, **kw) - # else call the default implementation return func(*args, **kw) return FunctionMaker.create( func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str, dict(_f_=dispatch), register=register, default=func, - typemap=typemap, vancestors=man.vancestors, __wrapped__=func) + typemap=typemap, vancestors=vancestors, mros=mros, + __wrapped__=func) gen_func_dec.__name__ = 'dispatch_on' + dispatch_str return gen_func_dec diff --git a/src/tests/documentation.py b/src/tests/documentation.py index f4aa995..6ce800b 100644 --- a/src/tests/documentation.py +++ b/src/tests/documentation.py @@ -782,6 +782,11 @@ then ``get_length`` must be defined on ``WithLength`` instances: >>> get_length(WithLength()) 0 +You can find the virtual ancestors of a given set of classes as follows: + + >> get_length.vancestors(WithLength,) + [[<class 'collections.abc.Sized'>]] + Of course this is a contrived example since you could just use the builtin ``len``, but you should get the idea. @@ -809,69 +814,43 @@ as a virtual ancestor): Now, let us define an implementation of ``get_length`` specific to set: -.. code-block:: python - - >>> @get_length.register(collections.Set) - ... def get_length_set(obj): - ... return 1 +$$get_length_set -The current implementation first check in the MRO and then look -for virtual ancestors; since ``SomeSet`` inherits directly -from ``collections.Sized`` that implementation is found first: +The current implementation, as the one used by ``functools.singledispatch``, +is able to discern that a ``Set`` is a ``Sized`` object, so the +implementation for ``Set`` is taken: .. code-block:: python >>> get_length(SomeSet()) - 0 - -Generic functions implemented via ``functools.singledispatch`` use -a more sophisticated lookup algorithm; in particular they are able -to discern that a ``Set`` is a ``Sized`` object, so the -implementation for ``Set`` is taken and the result is 1, not 0. -Still, the implementation in the decorator module is easy to -undestand, once one declare that real ancestors take the precedence -over virtual ancestors and the problem can be solved anyway by -subclassing. As a matter of fact, if we define a subclass - -$$SomeSet2 - -which inherits from ``collections.Set``, we get as expected - -.. code-block:: python - - >>> get_length(SomeSet2()) - 1 - -consistently with the method resolution order, with ``Set`` having the -precedence with respect to ``Sized``: - -.. code-block:: python - - >>> [c.__name__ for c in SomeSet2.mro()] - ['SomeSet2', 'SomeSet', 'Set', 'Sized', 'Iterable', 'Container', 'object'] + Traceback (most recent call last): + ... + TypeError: Cannot create a consistent method resolution + order (MRO) for bases Sized, Set -The functions implemented via ``functools.singledispatch`` -are smarter when there are conflicting implementations and are -able to solve more potential conflicts. Just to have an idea -of what I am talking about, here is a situation with a conflict: +Sometimes it is impossible to find the right implementation. Here is a +situation with a type conflict. First of all, let us register .. code-block:: python - >>> _ = collections.Iterable.register(WithLength) >>> @get_length.register(collections.Iterable) ... def get_length_iterable(obj): ... raise TypeError('Cannot get the length of an iterable') - >>> get_length(WithLength()) - Traceback (most recent call last): - ... - RuntimeError: Ambiguous dispatch for WithLength instance: Sized or Iterable? -Since ``WithLength`` is both a (virtual) subclass + +Since ``SomeSet`` is now both a (virtual) subclass of ``collections.Iterable`` and of ``collections.Sized``, which are not related by subclassing, it is impossible to decide which implementation should be taken. Consistently with the *refuse the temptation to guess* philosophy, an error is raised. -``functools.singledispatch`` would work exactly the same in this case. + + >>> get_length(SomeSet()) + Traceback (most recent call last): + ... + TypeError: Cannot create a consistent method resolution + order (MRO) for bases Iterable, Sized, Set + +``functools.singledispatch`` would raise a similar error in this case. Finally let me notice that the decorator module implementation does not use any cache, whereas the one in ``singledispatch`` has a cache. @@ -1474,9 +1453,6 @@ def get_length_sized(obj): return len(obj) -class SomeSet2(SomeSet, collections.Set): - def __contains__(self, a): - return True - - def __iter__(self): - yield 1 +@get_length.register(collections.Set) +def get_length_set(obj): + return 1 diff --git a/src/tests/test.py b/src/tests/test.py index cbb8373..af5998b 100644 --- a/src/tests/test.py +++ b/src/tests/test.py @@ -20,7 +20,7 @@ def assertRaises(etype): except etype: pass else: - raise Exception('Expected %s' % etype) + raise Exception('Expected %s' % etype.__name__) class DocumentationTestCase(unittest.TestCase): @@ -243,16 +243,6 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(s), "concrete-set") self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") - if hasattr(c, 'ChainMap'): - self.assertEqual( - [abc.__name__ for abc in g.vancestors[0]], - ['ChainMap', 'MutableMapping', 'MutableSequence', 'MutableSet', - 'Mapping', 'Sequence', 'Set', 'Sized']) - else: - self.assertEqual( - [abc.__name__ for abc in g.vancestors[0]], - ['MutableMapping', 'MutableSequence', 'MutableSet', - 'Mapping', 'Sequence', 'Set', 'Sized']) def test_mro_conflicts(self): c = collections @@ -272,13 +262,12 @@ class TestSingleDispatch(unittest.TestCase): g.register(c.Set)(lambda arg: "set") self.assertEqual(g(o), "sized") c.Iterable.register(O) - self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ + self.assertEqual(g(o), "sized") c.Container.register(O) - self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ - c.Set.register(O) self.assertEqual(g(o), "sized") - # could be set because c.Set is a subclass of - # c.Sized and c.Container + c.Set.register(O) + with assertRaises(TypeError): # was ok + self.assertEqual(g(o), "set") class P(object): pass @@ -288,8 +277,8 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(p), "iterable") c.Container.register(P) - with assertRaises(RuntimeError): - g(p) + #with assertRaises(RuntimeError): + self.assertEqual(g(p), "iterable") class Q(c.Sized): def __len__(self): @@ -297,9 +286,9 @@ class TestSingleDispatch(unittest.TestCase): q = Q() self.assertEqual(g(q), "sized") c.Iterable.register(Q) - self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ - c.Set.register(Q) self.assertEqual(g(q), "sized") + c.Set.register(Q) + # self.assertEqual(g(q), "sized") # could be because c.Set is a subclass of # c.Sized and c.Iterable @@ -318,8 +307,8 @@ class TestSingleDispatch(unittest.TestCase): # this ABC is implicitly registered on defaultdict which makes all of # MutableMapping's bases implicit as well from defaultdict's # perspective. - with assertRaises(RuntimeError): - h(c.defaultdict(lambda: 0)) + #with assertRaises(RuntimeError): + h(c.defaultdict(lambda: 0)) class R(c.defaultdict): pass @@ -337,10 +326,9 @@ class TestSingleDispatch(unittest.TestCase): def i_sequence(arg): return "sequence" r = R() - with assertRaises(RuntimeError): # not for standardlib - self.assertEqual(i(r), "sequence") + self.assertEqual(i(r), "mapping") # was sequence - class S: + class S(object): pass class T(S, c.Sized): @@ -351,7 +339,7 @@ class TestSingleDispatch(unittest.TestCase): c.Container.register(T) self.assertEqual(h(t), "sized") # because it's explicitly in the MRO - class U: + class U(object): def __len__(self): return 0 u = U() @@ -361,9 +349,8 @@ class TestSingleDispatch(unittest.TestCase): # from the existence of __len__() c.Container.register(U) - # There is no preference for registered versus inferred ABCs. - with assertRaises(RuntimeError): - h(u) + # There is preference for registered versus inferred ABCs. + self.assertEqual(h(u), "sized") # was conflict class V(c.Sized, S): def __len__(self): |
