summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/visitors.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r--lib/sqlalchemy/sql/visitors.py57
1 files changed, 42 insertions, 15 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 3f06535de..371c6ec4b 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -1,10 +1,39 @@
from collections import deque
+import re
+from sqlalchemy import util
+
+class VisitableType(type):
+ def __init__(cls, clsname, bases, dict):
+ if not '__visit_name__' in cls.__dict__:
+ m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+ x = m.group(1)
+ x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+ cls.__visit_name__ = x.lower()
+
+ # set up an optimized visit dispatch function
+ # for use by the compiler
+ visit_name = cls.__dict__["__visit_name__"]
+ if isinstance(visit_name, str):
+ func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+ " return visitor.visit_%s(self, **kw)" % visit_name
+ else:
+ func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+ " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
+
+ env = locals().copy()
+ exec func_text in env
+ cls._compiler_dispatch = env['_compiler_dispatch']
+
+ super(VisitableType, cls).__init__(clsname, bases, dict)
+
+class Visitable(object):
+ __metaclass__ = VisitableType
class ClauseVisitor(object):
__traverse_options__ = {}
def traverse_single(self, obj):
- for v in self._iterate_visitors:
+ for v in self._visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
return meth(obj)
@@ -17,29 +46,33 @@ class ClauseVisitor(object):
def traverse(self, obj):
"""traverse and visit the given expression structure."""
+ return traverse(obj, self.__traverse_options__, self._visitor_dict)
+
+ @util.memoized_property
+ def _visitor_dict(self):
visitors = {}
for name in dir(self):
if name.startswith('visit_'):
visitors[name[6:]] = getattr(self, name)
-
- return traverse(obj, self.__traverse_options__, visitors)
-
- def _iterate_visitors(self):
+ return visitors
+
+ @property
+ def _visitor_iterator(self):
"""iterate through this visitor and each 'chained' visitor."""
v = self
while v:
yield v
v = getattr(v, '_next', None)
- _iterate_visitors = property(_iterate_visitors)
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
the chained visitor will receive all visit events after this one.
+
"""
- tail = list(self._iterate_visitors)[-1]
+ tail = list(self._visitor_iterator)[-1]
tail._next = visitor
return self
@@ -52,13 +85,7 @@ class CloningVisitor(ClauseVisitor):
def traverse(self, obj):
"""traverse and visit the given expression structure."""
- visitors = {}
-
- for name in dir(self):
- if name.startswith('visit_'):
- visitors[name[6:]] = getattr(self, name)
-
- return cloned_traverse(obj, self.__traverse_options__, visitors)
+ return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor):
def replace(self, elem):
@@ -74,7 +101,7 @@ class ReplacingCloningVisitor(CloningVisitor):
"""traverse and visit the given expression structure."""
def replace(elem):
- for v in self._iterate_visitors:
+ for v in self._visitor_iterator:
e = v.replace(elem)
if e:
return e