diff options
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 57 |
1 files changed, 42 insertions, 15 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 3f06535de..371c6ec4b 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,10 +1,39 @@ from collections import deque +import re +from sqlalchemy import util + +class VisitableType(type): + def __init__(cls, clsname, bases, dict): + if not '__visit_name__' in cls.__dict__: + m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname) + x = m.group(1) + x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x) + cls.__visit_name__ = x.lower() + + # set up an optimized visit dispatch function + # for use by the compiler + visit_name = cls.__dict__["__visit_name__"] + if isinstance(visit_name, str): + func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\ + " return visitor.visit_%s(self, **kw)" % visit_name + else: + func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\ + " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)" + + env = locals().copy() + exec func_text in env + cls._compiler_dispatch = env['_compiler_dispatch'] + + super(VisitableType, cls).__init__(clsname, bases, dict) + +class Visitable(object): + __metaclass__ = VisitableType class ClauseVisitor(object): __traverse_options__ = {} def traverse_single(self, obj): - for v in self._iterate_visitors: + for v in self._visitor_iterator: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: return meth(obj) @@ -17,29 +46,33 @@ class ClauseVisitor(object): def traverse(self, obj): """traverse and visit the given expression structure.""" + return traverse(obj, self.__traverse_options__, self._visitor_dict) + + @util.memoized_property + def _visitor_dict(self): visitors = {} for name in dir(self): if name.startswith('visit_'): visitors[name[6:]] = getattr(self, name) - - return traverse(obj, self.__traverse_options__, visitors) - - def _iterate_visitors(self): + return visitors + + @property + def _visitor_iterator(self): """iterate through this visitor and each 'chained' visitor.""" v = self while v: yield v v = getattr(v, '_next', None) - _iterate_visitors = property(_iterate_visitors) def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. the chained visitor will receive all visit events after this one. + """ - tail = list(self._iterate_visitors)[-1] + tail = list(self._visitor_iterator)[-1] tail._next = visitor return self @@ -52,13 +85,7 @@ class CloningVisitor(ClauseVisitor): def traverse(self, obj): """traverse and visit the given expression structure.""" - visitors = {} - - for name in dir(self): - if name.startswith('visit_'): - visitors[name[6:]] = getattr(self, name) - - return cloned_traverse(obj, self.__traverse_options__, visitors) + return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict) class ReplacingCloningVisitor(CloningVisitor): def replace(self, elem): @@ -74,7 +101,7 @@ class ReplacingCloningVisitor(CloningVisitor): """traverse and visit the given expression structure.""" def replace(elem): - for v in self._iterate_visitors: + for v in self._visitor_iterator: e = v.replace(elem) if e: return e |
