summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMichele Simionato <michele.simionato@gmail.com>2015-07-23 07:03:48 +0200
committerMichele Simionato <michele.simionato@gmail.com>2015-07-23 07:03:48 +0200
commit80072c50a765547736c03810cfb3e4f5afcb928d (patch)
tree72e254b8b8b8f542286c4f30d7dd121f9512eca9 /src
parent14b603738eed475281a38350bfaf0df5bca5f3c8 (diff)
downloadpython-decorator-git-80072c50a765547736c03810cfb3e4f5afcb928d.tar.gz
Changed the implementation of generic functions
Diffstat (limited to 'src')
-rw-r--r--src/decorator.py113
-rw-r--r--src/tests/documentation.py80
-rw-r--r--src/tests/test.py45
3 files changed, 101 insertions, 137 deletions
diff --git a/src/decorator.py b/src/decorator.py
index a7a5439..d71ed60 100644
--- a/src/decorator.py
+++ b/src/decorator.py
@@ -40,6 +40,7 @@ import re
import sys
import inspect
import itertools
+import collections
if sys.version >= '3':
from inspect import getfullargspec
@@ -281,49 +282,27 @@ contextmanager = decorator(ContextManager)
# ############################ dispatch_on ############################ #
-class _VAManager(object):
+def unique(classes):
"""
- Manage a list of virtual ancestors for each dispatch type.
- The list is partially ordered by the `issubclass` comparison operator.
+ Return a tuple of unique classes by preserving the original order.
"""
- def __init__(self, n):
- self.indices = range(n)
- self.vancestors = [[] for _ in self.indices]
+ known = set([object])
+ outlist = []
+ for cl in classes:
+ if cl not in known:
+ outlist.append(cl)
+ known.add(cl)
+ return tuple(outlist)
- def insert(self, i, a):
- """
- For each index `i` insert a virtual ancestor `a` in the corresponding
- list, by keeping the partial ordering.
- """
- vancestors = self.vancestors[i]
- for j, va in enumerate(vancestors):
- if issubclass(a, va) and a is not va:
- vancestors.insert(j, a)
- break
- else: # less specialized
- if a not in vancestors:
- vancestors.append(a)
-
- def get_vancestors(self, types):
- """
- For each type get the most specialized VA available; return a tuple
- """
- class Sentinel(object):
- pass
- valist = [Sentinel for _ in self.indices]
- for i, t, vancestors in zip(self.indices, types, self.vancestors):
- for new in vancestors:
- if issubclass(t, new):
- old = valist[i]
- if old is Sentinel or issubclass(new, old):
- valist[i] = new
- elif issubclass(old, new):
- pass
- else:
- raise RuntimeError(
- 'Ambiguous dispatch for %s instance: %s or %s?'
- % (t.__name__, old.__name__, new.__name__))
- return tuple(valist)
+
+def insert(a, vancestors):
+ for j, va in enumerate(vancestors):
+ if issubclass(a, va) and a is not va:
+ vancestors.insert(j, a)
+ break
+ else: # less specialized
+ if a not in vancestors:
+ vancestors.append(a)
# inspired from simplegeneric by P.J. Eby and functools.singledispatch
@@ -335,6 +314,12 @@ def dispatch_on(*dispatch_args):
assert dispatch_args, 'No dispatch args passed'
dispatch_str = '(%s,)' % ', '.join(dispatch_args)
+ def check(types):
+ """Make use one passes the expected number of types"""
+ if len(types) != len(dispatch_args):
+ raise TypeError('Expected %d types, got %d' %
+ (len(dispatch_args), len(types)))
+
def gen_func_dec(func):
"""Decorator turning a function into a generic function"""
@@ -343,14 +328,38 @@ def dispatch_on(*dispatch_args):
if not set(dispatch_args) <= argset:
raise NameError('Unknown dispatch arguments %s' % dispatch_str)
- typemap = {}
- man = _VAManager(len(dispatch_args))
+ typemap = collections.OrderedDict()
+
+ def vancestors(*types):
+ """
+ Get a list of lists of virtual ancestors for the given types
+ """
+ check(types)
+ ras = [[] for _ in range(len(dispatch_args))]
+ for types_ in typemap:
+ for t, type_, ra in zip(types, types_, ras):
+ if issubclass(t, type_) and type_ not in t.__mro__:
+ insert(type_, ra)
+ return ras
+
+ def mros(*types):
+ """
+ Get a list of MROs, one for each type
+ """
+ check(types)
+ lists = []
+ for t, ancestors in zip(types, vancestors(*types)):
+ t_ancestors = unique(t.__bases__ + tuple(ancestors))
+ if not t_ancestors:
+ mro = t.__mro__
+ else:
+ mro = type(t.__name__, t_ancestors, {}).__mro__
+ lists.append(mro[:-1]) # discard object
+ return lists
def register(*types):
"Decorator to register an implementation for the given types"
- if len(types) != len(dispatch_args):
- raise TypeError('Length mismatch: expected %d types, got %d' %
- (len(dispatch_args), len(types)))
+ check(types)
def dec(f):
n_args = len(getfullargspec(f).args)
@@ -358,9 +367,6 @@ def dispatch_on(*dispatch_args):
raise TypeError(
'%s has not enough arguments (got %d, expected %d)' %
(f, n_args, len(dispatch_args)))
- for i, t, va in zip(man.indices, types, man.vancestors):
- if isinstance(t, ABCMeta):
- man.insert(i, t)
typemap[types] = f
return f
return dec
@@ -374,24 +380,19 @@ def dispatch_on(*dispatch_args):
pass
else:
return f(*args, **kw)
- for types_ in itertools.product(*(t.__mro__ for t in types)):
+ for types_ in itertools.product(*mros(*types)):
f = typemap.get(types_)
if f is not None:
return f(*args, **kw)
- # else look at the virtual ancestors
- if man.vancestors:
- f = typemap.get(man.get_vancestors(types))
- if f is not None:
- return f(*args, **kw)
-
# else call the default implementation
return func(*args, **kw)
return FunctionMaker.create(
func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str,
dict(_f_=dispatch), register=register, default=func,
- typemap=typemap, vancestors=man.vancestors, __wrapped__=func)
+ typemap=typemap, vancestors=vancestors, mros=mros,
+ __wrapped__=func)
gen_func_dec.__name__ = 'dispatch_on' + dispatch_str
return gen_func_dec
diff --git a/src/tests/documentation.py b/src/tests/documentation.py
index f4aa995..6ce800b 100644
--- a/src/tests/documentation.py
+++ b/src/tests/documentation.py
@@ -782,6 +782,11 @@ then ``get_length`` must be defined on ``WithLength`` instances:
>>> get_length(WithLength())
0
+You can find the virtual ancestors of a given set of classes as follows:
+
+ >> get_length.vancestors(WithLength,)
+ [[<class 'collections.abc.Sized'>]]
+
Of course this is a contrived example since you could just use the
builtin ``len``, but you should get the idea.
@@ -809,69 +814,43 @@ as a virtual ancestor):
Now, let us define an implementation of ``get_length`` specific to set:
-.. code-block:: python
-
- >>> @get_length.register(collections.Set)
- ... def get_length_set(obj):
- ... return 1
+$$get_length_set
-The current implementation first check in the MRO and then look
-for virtual ancestors; since ``SomeSet`` inherits directly
-from ``collections.Sized`` that implementation is found first:
+The current implementation, as the one used by ``functools.singledispatch``,
+is able to discern that a ``Set`` is a ``Sized`` object, so the
+implementation for ``Set`` is taken:
.. code-block:: python
>>> get_length(SomeSet())
- 0
-
-Generic functions implemented via ``functools.singledispatch`` use
-a more sophisticated lookup algorithm; in particular they are able
-to discern that a ``Set`` is a ``Sized`` object, so the
-implementation for ``Set`` is taken and the result is 1, not 0.
-Still, the implementation in the decorator module is easy to
-undestand, once one declare that real ancestors take the precedence
-over virtual ancestors and the problem can be solved anyway by
-subclassing. As a matter of fact, if we define a subclass
-
-$$SomeSet2
-
-which inherits from ``collections.Set``, we get as expected
-
-.. code-block:: python
-
- >>> get_length(SomeSet2())
- 1
-
-consistently with the method resolution order, with ``Set`` having the
-precedence with respect to ``Sized``:
-
-.. code-block:: python
-
- >>> [c.__name__ for c in SomeSet2.mro()]
- ['SomeSet2', 'SomeSet', 'Set', 'Sized', 'Iterable', 'Container', 'object']
+ Traceback (most recent call last):
+ ...
+ TypeError: Cannot create a consistent method resolution
+ order (MRO) for bases Sized, Set
-The functions implemented via ``functools.singledispatch``
-are smarter when there are conflicting implementations and are
-able to solve more potential conflicts. Just to have an idea
-of what I am talking about, here is a situation with a conflict:
+Sometimes it is impossible to find the right implementation. Here is a
+situation with a type conflict. First of all, let us register
.. code-block:: python
- >>> _ = collections.Iterable.register(WithLength)
>>> @get_length.register(collections.Iterable)
... def get_length_iterable(obj):
... raise TypeError('Cannot get the length of an iterable')
- >>> get_length(WithLength())
- Traceback (most recent call last):
- ...
- RuntimeError: Ambiguous dispatch for WithLength instance: Sized or Iterable?
-Since ``WithLength`` is both a (virtual) subclass
+
+Since ``SomeSet`` is now both a (virtual) subclass
of ``collections.Iterable`` and of ``collections.Sized``, which are
not related by subclassing, it is impossible
to decide which implementation should be taken. Consistently with
the *refuse the temptation to guess* philosophy, an error is raised.
-``functools.singledispatch`` would work exactly the same in this case.
+
+ >>> get_length(SomeSet())
+ Traceback (most recent call last):
+ ...
+ TypeError: Cannot create a consistent method resolution
+ order (MRO) for bases Iterable, Sized, Set
+
+``functools.singledispatch`` would raise a similar error in this case.
Finally let me notice that the decorator module implementation does
not use any cache, whereas the one in ``singledispatch`` has a cache.
@@ -1474,9 +1453,6 @@ def get_length_sized(obj):
return len(obj)
-class SomeSet2(SomeSet, collections.Set):
- def __contains__(self, a):
- return True
-
- def __iter__(self):
- yield 1
+@get_length.register(collections.Set)
+def get_length_set(obj):
+ return 1
diff --git a/src/tests/test.py b/src/tests/test.py
index cbb8373..af5998b 100644
--- a/src/tests/test.py
+++ b/src/tests/test.py
@@ -20,7 +20,7 @@ def assertRaises(etype):
except etype:
pass
else:
- raise Exception('Expected %s' % etype)
+ raise Exception('Expected %s' % etype.__name__)
class DocumentationTestCase(unittest.TestCase):
@@ -243,16 +243,6 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(s), "concrete-set")
self.assertEqual(g(f), "frozen-set")
self.assertEqual(g(t), "tuple")
- if hasattr(c, 'ChainMap'):
- self.assertEqual(
- [abc.__name__ for abc in g.vancestors[0]],
- ['ChainMap', 'MutableMapping', 'MutableSequence', 'MutableSet',
- 'Mapping', 'Sequence', 'Set', 'Sized'])
- else:
- self.assertEqual(
- [abc.__name__ for abc in g.vancestors[0]],
- ['MutableMapping', 'MutableSequence', 'MutableSet',
- 'Mapping', 'Sequence', 'Set', 'Sized'])
def test_mro_conflicts(self):
c = collections
@@ -272,13 +262,12 @@ class TestSingleDispatch(unittest.TestCase):
g.register(c.Set)(lambda arg: "set")
self.assertEqual(g(o), "sized")
c.Iterable.register(O)
- self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
+ self.assertEqual(g(o), "sized")
c.Container.register(O)
- self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
- c.Set.register(O)
self.assertEqual(g(o), "sized")
- # could be set because c.Set is a subclass of
- # c.Sized and c.Container
+ c.Set.register(O)
+ with assertRaises(TypeError): # was ok
+ self.assertEqual(g(o), "set")
class P(object):
pass
@@ -288,8 +277,8 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(g(p), "iterable")
c.Container.register(P)
- with assertRaises(RuntimeError):
- g(p)
+ #with assertRaises(RuntimeError):
+ self.assertEqual(g(p), "iterable")
class Q(c.Sized):
def __len__(self):
@@ -297,9 +286,9 @@ class TestSingleDispatch(unittest.TestCase):
q = Q()
self.assertEqual(g(q), "sized")
c.Iterable.register(Q)
- self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
- c.Set.register(Q)
self.assertEqual(g(q), "sized")
+ c.Set.register(Q)
+ # self.assertEqual(g(q), "sized")
# could be because c.Set is a subclass of
# c.Sized and c.Iterable
@@ -318,8 +307,8 @@ class TestSingleDispatch(unittest.TestCase):
# this ABC is implicitly registered on defaultdict which makes all of
# MutableMapping's bases implicit as well from defaultdict's
# perspective.
- with assertRaises(RuntimeError):
- h(c.defaultdict(lambda: 0))
+ #with assertRaises(RuntimeError):
+ h(c.defaultdict(lambda: 0))
class R(c.defaultdict):
pass
@@ -337,10 +326,9 @@ class TestSingleDispatch(unittest.TestCase):
def i_sequence(arg):
return "sequence"
r = R()
- with assertRaises(RuntimeError): # not for standardlib
- self.assertEqual(i(r), "sequence")
+ self.assertEqual(i(r), "mapping") # was sequence
- class S:
+ class S(object):
pass
class T(S, c.Sized):
@@ -351,7 +339,7 @@ class TestSingleDispatch(unittest.TestCase):
c.Container.register(T)
self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
- class U:
+ class U(object):
def __len__(self):
return 0
u = U()
@@ -361,9 +349,8 @@ class TestSingleDispatch(unittest.TestCase):
# from the existence of __len__()
c.Container.register(U)
- # There is no preference for registered versus inferred ABCs.
- with assertRaises(RuntimeError):
- h(u)
+ # There is preference for registered versus inferred ABCs.
+ self.assertEqual(h(u), "sized") # was conflict
class V(c.Sized, S):
def __len__(self):