summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-03-11 20:52:02 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-03-11 20:52:02 +0000
commit6a3c374b955299f0065356ef1de6cc0920d5382e (patch)
tree1ec2c2fddcc2d3c8b8f350fb42f86a84918c6fe1 /lib/sqlalchemy/sql.py
parent320cb9b75f763355ed798c80d245998ce57e21cc (diff)
downloadsqlalchemy-6a3c374b955299f0065356ef1de6cc0920d5382e.tar.gz
- for hackers, refactored the "visitor" system of ClauseElement and
SchemaItem so that the traversal of items is controlled by the ClauseVisitor itself, using the method visitor.traverse(item). accept_visitor() methods can still be called directly but will not do any traversal of child items. ClauseElement/SchemaItem now have a configurable get_children() method to return the collection of child elements for each parent object. This allows the full traversal of items to be clear and unambiguous (as well as loggable), with an easy method of limiting a traversal (just pass flags which are picked up by appropriate get_children() methods). [ticket:501] - accept_schema_visitor() methods removed, replaced with get_children(schema_visitor=True) - various docstring/changelog cleanup/reformatting
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r--lib/sqlalchemy/sql.py302
1 files changed, 206 insertions, 96 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 073277d57..190ec29d4 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -5,7 +5,7 @@
"""Define the base components of SQL expression trees."""
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exceptions, logging
from sqlalchemy import types as sqltypes
import string, re, random, sets
@@ -485,44 +485,103 @@ class ClauseParameters(dict):
return d
class ClauseVisitor(object):
- """Define the visiting of ``ClauseElements``."""
-
- def visit_column(self, column):pass
- def visit_table(self, column):pass
- def visit_fromclause(self, fromclause):pass
- def visit_bindparam(self, bindparam):pass
- def visit_textclause(self, textclause):pass
- def visit_compound(self, compound):pass
- def visit_compound_select(self, compound):pass
- def visit_binary(self, binary):pass
- def visit_alias(self, alias):pass
- def visit_select(self, select):pass
- def visit_join(self, join):pass
- def visit_null(self, null):pass
- def visit_clauselist(self, list):pass
- def visit_calculatedclause(self, calcclause):pass
- def visit_function(self, func):pass
- def visit_cast(self, cast):pass
- def visit_label(self, label):pass
- def visit_typeclause(self, typeclause):pass
-
-class VisitColumnMixin(object):
- """a mixin that adds Column traversal to a ClauseVisitor"""
+ """A class that knows how to traverse and visit
+ ``ClauseElements``.
+
+ Each ``ClauseElement``'s accept_visitor() method will call a
+ corresponding visit_XXXX() method here. Traversal of a
+ hierarchy of ``ClauseElements`` is achieved via the
+ ``traverse()`` method, which is passed the lead
+ ``ClauseElement``.
+
+ By default, ``ClauseVisitor`` traverses all elements
+ fully. Options can be specified at the class level via the
+ ``__traverse_options__`` dictionary which will be passed
+ to the ``get_children()`` method of each ``ClauseElement``;
+ these options can indicate modifications to the set of
+ elements returned, such as to not return column collections
+ (column_collections=False) or to return Schema-level items
+ (schema_visitor=True)."""
+ __traverse_options__ = {}
+ def traverse(self, obj):
+ for n in obj.get_children(**self.__traverse_options__):
+ self.traverse(n)
+ obj.accept_visitor(self)
+ def visit_column(self, column):
+ pass
def visit_table(self, table):
- for c in table.c:
- c.accept_visitor(self)
- def visit_select(self, select):
- for c in select.c:
- c.accept_visitor(self)
- def visit_compound_select(self, select):
- for c in select.c:
- c.accept_visitor(self)
+ pass
+ def visit_fromclause(self, fromclause):
+ pass
+ def visit_bindparam(self, bindparam):
+ pass
+ def visit_textclause(self, textclause):
+ pass
+ def visit_compound(self, compound):
+ pass
+ def visit_compound_select(self, compound):
+ pass
+ def visit_binary(self, binary):
+ pass
def visit_alias(self, alias):
- for c in alias.c:
- c.accept_visitor(self)
-
+ pass
+ def visit_select(self, select):
+ pass
+ def visit_join(self, join):
+ pass
+ def visit_null(self, null):
+ pass
+ def visit_clauselist(self, list):
+ pass
+ def visit_calculatedclause(self, calcclause):
+ pass
+ def visit_function(self, func):
+ pass
+ def visit_cast(self, cast):
+ pass
+ def visit_label(self, label):
+ pass
+ def visit_typeclause(self, typeclause):
+ pass
+
+class LoggingClauseVisitor(ClauseVisitor):
+ """extends ClauseVisitor to include debug logging of all traversal.
+
+ To install this visitor, set logging.DEBUG for
+ 'sqlalchemy.sql.ClauseVisitor' **before** you import the
+ sqlalchemy.sql module.
+ """
+
+ def traverse(self, obj):
+ indent = getattr(self, '_indent', "")
+ self.logger.debug(indent + "START " + repr(obj))
+ setattr(self, "_indent", indent + " ")
+ for n in obj.get_children(**self.__traverse_options__):
+ self.traverse(n)
+ obj.accept_visitor(self)
+ setattr(self, "_indent", indent)
+ self.logger.debug(indent+ "END " + repr(obj))
+
+LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor)
+
+if logging.is_debug_enabled(LoggingClauseVisitor.logger):
+ ClauseVisitor=LoggingClauseVisitor
+
+class NoColumnVisitor(ClauseVisitor):
+ """a ClauseVisitor that will not traverse the exported Column
+ collections on Table, Alias, Select, and CompoundSelect objects
+ (i.e. their 'columns' or 'c' attribute).
+
+ this is useful because most traversals don't need those columns, or
+ in the case of ANSICompiler it traverses them explicitly; so
+ skipping their traversal here greatly cuts down on method call overhead.
+ """
+
+ __traverse_options__ = {'column_collections':False}
+
class Executor(object):
- """Represent a *thing that can produce Compiled objects and execute them*."""
+ """Interface representing a *thing that can produce Compiled objects
+ and execute them*."""
def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
"""Execute a Compiled object."""
@@ -539,7 +598,7 @@ class Compiled(ClauseVisitor):
The ``__str__`` method of the ``Compiled`` object should produce
the actual text of the statement. ``Compiled`` objects are
- specific to the database library that created them, and also may
+ specific to their underlying database dialect, and also may
or may not be specific to the columns referenced within a
particular set of bind parameters. In no case should the
``Compiled`` object be dependent on the actual values of those
@@ -547,7 +606,7 @@ class Compiled(ClauseVisitor):
defaults.
"""
- def __init__(self, dialect, statement, parameters, engine=None):
+ def __init__(self, dialect, statement, parameters, engine=None, traversal=None):
"""Construct a new Compiled object.
statement
@@ -570,7 +629,7 @@ class Compiled(ClauseVisitor):
engine
Optional Engine to compile this statement against.
"""
-
+ ClauseVisitor.__init__(self, traversal=traversal)
self.dialect = dialect
self.statement = statement
self.parameters = parameters
@@ -578,7 +637,7 @@ class Compiled(ClauseVisitor):
self.can_execute = statement.supports_execution()
def compile(self):
- self.statement.accept_visitor(self)
+ self.traverse(self.statement)
self.after_compile()
def __str__(self):
@@ -649,7 +708,19 @@ class ClauseElement(object):
"""
raise NotImplementedError(repr(self))
-
+
+ def get_children(self, **kwargs):
+ """return immediate child elements of this ``ClauseElement``.
+
+ this is used for visit traversal.
+
+ **kwargs may contain flags that change the collection
+ that is returned, for example to return a subset of items
+ in order to cut down on larger traversals, or to return
+ child items from a different context (such as schema-level
+ collections instead of clause-level)."""
+ return []
+
def supports_execution(self):
"""Return True if this clause element represents a complete
executable statement.
@@ -1058,16 +1129,38 @@ class FromClause(Selectable):
def _get_all_embedded_columns(self):
ret = []
- class FindCols(VisitColumnMixin, ClauseVisitor):
+ class FindCols(ClauseVisitor):
def visit_column(self, col):
ret.append(col)
- self.accept_visitor(FindCols())
+ FindCols().traverse(self)
return ret
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
- """Given a ``ColumnElement``, return the ``ColumnElement``
- object from this ``Selectable`` which corresponds to that
- original ``Column`` via a proxy relationship.
+ """Given a ``ColumnElement``, return the exported
+ ``ColumnElement`` object from this ``Selectable`` which
+ corresponds to that original ``Column`` via a common
+ anscestor column.
+
+ column
+ the target ``ColumnElement`` to be matched
+
+ raiseerr
+ if True, raise an error if the given ``ColumnElement``
+ could not be matched. if False, non-matches will
+ return None.
+
+ keys_ok
+ if the ``ColumnElement`` cannot be matched, attempt to
+ match based on the string "key" property of the column
+ alone. This makes the search much more liberal.
+
+ require_embedded
+ only return corresponding columns for the given
+ ``ColumnElement``, if the given ``ColumnElement`` is
+ actually present within a sub-element of this
+ ``FromClause``. Normally the column will match if
+ it merely shares a common anscestor with one of
+ the exported columns of this ``FromClause``.
"""
if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
@@ -1258,11 +1351,14 @@ class _TextClause(ClauseElement):
if bindparams is not None:
for b in bindparams:
self.bindparams[b.key] = b
- columns = property(lambda s:[])
- def accept_visitor(self, visitor):
- for item in self.bindparams.values():
- item.accept_visitor(visitor)
+ columns = property(lambda s:[])
+
+ def get_children(self, **kwargs):
+ return self.bindparams.values()
+
+ def accept_visitor(self, visitor):
visitor.visit_textclause(self)
+
def _get_from_objects(self):
return []
def supports_execution(self):
@@ -1296,9 +1392,9 @@ class ClauseList(ClauseElement):
if _is_literal(clause):
clause = _TextClause(str(clause))
self.clauses.append(clause)
+ def get_children(self, **kwargs):
+ return self.clauses
def accept_visitor(self, visitor):
- for c in self.clauses:
- c.accept_visitor(visitor)
visitor.visit_clauselist(self)
def _get_from_objects(self):
f = []
@@ -1338,9 +1434,9 @@ class _CompoundClause(ClauseList):
clause.parens = True
ClauseList.append(self, clause)
+ def get_children(self, **kwargs):
+ return self.clauses
def accept_visitor(self, visitor):
- for c in self.clauses:
- c.accept_visitor(visitor)
visitor.visit_compound(self)
def _get_from_objects(self):
@@ -1384,9 +1480,9 @@ class _CalculatedClause(ClauseList, ColumnElement):
clauses = [clause.copy_container() for clause in self.clauses]
return _CalculatedClause(type=self.type, engine=self._engine, *clauses)
+ def get_children(self, **kwargs):
+ return self.clauses
def accept_visitor(self, visitor):
- for c in self.clauses:
- c.accept_visitor(visitor)
visitor.visit_calculatedclause(self)
def _bind_param(self, obj):
@@ -1432,9 +1528,9 @@ class _Function(_CalculatedClause, FromClause):
clauses = [clause.copy_container() for clause in self.clauses]
return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
+ def get_children(self, **kwargs):
+ return self.clauses
def accept_visitor(self, visitor):
- for c in self.clauses:
- c.accept_visitor(visitor)
visitor.visit_function(self)
class _Cast(ColumnElement):
@@ -1445,9 +1541,9 @@ class _Cast(ColumnElement):
self.clause = clause
self.typeclause = _TypeClause(self.type)
+ def get_children(self, **kwargs):
+ return self.clause, self.typeclause
def accept_visitor(self, visitor):
- self.clause.accept_visitor(visitor)
- self.typeclause.accept_visitor(visitor)
visitor.visit_cast(self)
def _get_from_objects(self):
@@ -1494,9 +1590,9 @@ class _BinaryClause(ClauseElement):
return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
def _get_from_objects(self):
return self.left._get_from_objects() + self.right._get_from_objects()
+ def get_children(self, **kwargs):
+ return self.left, self.right
def accept_visitor(self, visitor):
- self.left.accept_visitor(visitor)
- self.right.accept_visitor(visitor)
visitor.visit_binary(self)
def swap(self):
c = self.left
@@ -1589,12 +1685,12 @@ class Join(FromClause):
def _get_folded_equivalents(self, equivs=None):
if equivs is None:
equivs = util.Set()
- class LocateEquivs(ClauseVisitor):
+ class LocateEquivs(NoColumnVisitor):
def visit_binary(self, binary):
if binary.operator == '=' and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
- self.onclause.accept_visitor(LocateEquivs())
+ LocateEquivs().traverse(self.onclause)
collist = []
if isinstance(self.left, Join):
left = self.left._get_folded_equivalents(equivs)
@@ -1636,10 +1732,9 @@ class Join(FromClause):
return select(collist, whereclause, from_obj=[self], **kwargs)
+ def get_children(self, **kwargs):
+ return self.left, self.right, self.onclause
def accept_visitor(self, visitor):
- self.left.accept_visitor(visitor)
- self.right.accept_visitor(visitor)
- self.onclause.accept_visitor(visitor)
visitor.visit_join(self)
engine = property(lambda s:s.left.engine or s.right.engine)
@@ -1692,8 +1787,11 @@ class Alias(FromClause):
#return self.selectable._exportable_columns()
return self.selectable.columns
+ def get_children(self, **kwargs):
+ for c in self.c:
+ yield c
+ yield self.selectable
def accept_visitor(self, visitor):
- self.selectable.accept_visitor(visitor)
visitor.visit_alias(self)
def _get_from_objects(self):
@@ -1717,9 +1815,10 @@ class _Label(ColumnElement):
key = property(lambda s: s.name)
_label = property(lambda s: s.name)
orig_set = property(lambda s:s.obj.orig_set)
-
+
+ def get_children(self, **kwargs):
+ return self.obj,
def accept_visitor(self, visitor):
- self.obj.accept_visitor(visitor)
visitor.visit_label(self)
def _get_from_objects(self):
@@ -1841,6 +1940,11 @@ class TableClause(FromClause):
original_columns = property(_orig_columns)
+ def get_children(self, column_collections=True, **kwargs):
+ if column_collections:
+ return [c for c in self.c]
+ else:
+ return []
def accept_visitor(self, visitor):
visitor.visit_table(self)
@@ -1964,11 +2068,10 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
col.orig_set = colset
return col
+ def get_children(self, column_collections=True, **kwargs):
+ return (column_collections and list(self.c) or []) + \
+ [self.order_by_clause, self.group_by_clause] + list(self.selects)
def accept_visitor(self, visitor):
- self.order_by_clause.accept_visitor(visitor)
- self.group_by_clause.accept_visitor(visitor)
- for s in self.selects:
- s.accept_visitor(visitor)
visitor.visit_compound_select(self)
def _find_engine(self):
@@ -2028,9 +2131,9 @@ class Select(_SelectBaseMixin, FromClause):
self.order_by(*(order_by or [None]))
self.group_by(*(group_by or [None]))
for c in self.order_by_clause:
- c.accept_visitor(self.__correlator)
+ self.__correlator.traverse(c)
for c in self.group_by_clause:
- c.accept_visitor(self.__correlator)
+ self.__correlator.traverse(c)
for f in from_obj:
self.append_from(f)
@@ -2044,13 +2147,14 @@ class Select(_SelectBaseMixin, FromClause):
self.append_having(having)
- class _CorrelatedVisitor(ClauseVisitor):
+ class _CorrelatedVisitor(NoColumnVisitor):
"""Visit a clause, locate any ``Select`` clauses, and tell
them that they should correlate their ``FROM`` list to that of
their parent.
"""
def __init__(self, select, is_where):
+ NoColumnVisitor.__init__(self)
self.select = select
self.is_where = is_where
@@ -2084,12 +2188,12 @@ class Select(_SelectBaseMixin, FromClause):
# if the column is a Select statement itself,
# accept visitor
- column.accept_visitor(self.__correlator)
+ self.__correlator.traverse(column)
# visit the FROM objects of the column looking for more Selects
for f in column._get_from_objects():
if f is not self:
- f.accept_visitor(self.__correlator)
+ self.__correlator.traverse(f)
self._process_froms(column, False)
def _make_proxy(self, selectable, name):
if self.is_scalar:
@@ -2127,7 +2231,7 @@ class Select(_SelectBaseMixin, FromClause):
def _append_condition(self, attribute, condition):
if type(condition) == str:
condition = _TextClause(condition)
- condition.accept_visitor(self.__wherecorrelator)
+ self.__wherecorrelator.traverse(condition)
self._process_froms(condition, False)
if getattr(self, attribute) is not None:
setattr(self, attribute, and_(getattr(self, attribute), condition))
@@ -2146,7 +2250,7 @@ class Select(_SelectBaseMixin, FromClause):
def append_from(self, fromclause):
if type(fromclause) == str:
fromclause = FromClause(fromclause)
- fromclause.accept_visitor(self.__correlator)
+ self.__correlator.traverse(fromclause)
self._process_froms(fromclause, True)
def _locate_oid_column(self):
@@ -2169,16 +2273,14 @@ class Select(_SelectBaseMixin, FromClause):
return f
froms = property(_calc_froms, doc="""A collection containing all elements of the FROM clause""")
+
+ def get_children(self, column_collections=True, **kwargs):
+ return (column_collections and list(self.columns) or []) + \
+ list(self.froms) + \
+ [x for x in (self.whereclause, self.having) if x is not None] + \
+ [self.order_by_clause, self.group_by_clause]
def accept_visitor(self, visitor):
- for f in self.froms:
- f.accept_visitor(visitor)
- if self.whereclause is not None:
- self.whereclause.accept_visitor(visitor)
- if self.having is not None:
- self.having.accept_visitor(visitor)
- self.order_by_clause.accept_visitor(visitor)
- self.group_by_clause.accept_visitor(visitor)
visitor.visit_select(self)
def union(self, other, **kwargs):
@@ -2259,10 +2361,12 @@ class _Insert(_UpdateBase):
self.select = None
self.parameters = self._process_colparams(values)
- def accept_visitor(self, visitor):
+ def get_children(self, **kwargs):
if self.select is not None:
- self.select.accept_visitor(visitor)
-
+ return self.select,
+ else:
+ return ()
+ def accept_visitor(self, visitor):
visitor.visit_insert(self)
class _Update(_UpdateBase):
@@ -2271,9 +2375,12 @@ class _Update(_UpdateBase):
self.whereclause = whereclause
self.parameters = self._process_colparams(values)
- def accept_visitor(self, visitor):
+ def get_children(self, **kwargs):
if self.whereclause is not None:
- self.whereclause.accept_visitor(visitor)
+ return self.whereclause,
+ else:
+ return ()
+ def accept_visitor(self, visitor):
visitor.visit_update(self)
class _Delete(_UpdateBase):
@@ -2281,7 +2388,10 @@ class _Delete(_UpdateBase):
self.table = table
self.whereclause = whereclause
- def accept_visitor(self, visitor):
+ def get_children(self, **kwargs):
if self.whereclause is not None:
- self.whereclause.accept_visitor(visitor)
+ return self.whereclause,
+ else:
+ return ()
+ def accept_visitor(self, visitor):
visitor.visit_delete(self)