diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-03-11 20:52:02 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-03-11 20:52:02 +0000 |
| commit | 6a3c374b955299f0065356ef1de6cc0920d5382e (patch) | |
| tree | 1ec2c2fddcc2d3c8b8f350fb42f86a84918c6fe1 /lib/sqlalchemy/sql.py | |
| parent | 320cb9b75f763355ed798c80d245998ce57e21cc (diff) | |
| download | sqlalchemy-6a3c374b955299f0065356ef1de6cc0920d5382e.tar.gz | |
- for hackers, refactored the "visitor" system of ClauseElement and
SchemaItem so that the traversal of items is controlled by the
ClauseVisitor itself, using the method visitor.traverse(item).
accept_visitor() methods can still be called directly but will
not do any traversal of child items. ClauseElement/SchemaItem now
have a configurable get_children() method to return the collection
of child elements for each parent object. This allows the full
traversal of items to be clear and unambiguous (as well as loggable),
with an easy method of limiting a traversal (just pass flags which
are picked up by appropriate get_children() methods). [ticket:501]
- accept_schema_visitor() methods removed, replaced with
get_children(schema_visitor=True)
- various docstring/changelog cleanup/reformatting
Diffstat (limited to 'lib/sqlalchemy/sql.py')
| -rw-r--r-- | lib/sqlalchemy/sql.py | 302 |
1 files changed, 206 insertions, 96 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 073277d57..190ec29d4 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -5,7 +5,7 @@ """Define the base components of SQL expression trees.""" -from sqlalchemy import util, exceptions +from sqlalchemy import util, exceptions, logging from sqlalchemy import types as sqltypes import string, re, random, sets @@ -485,44 +485,103 @@ class ClauseParameters(dict): return d class ClauseVisitor(object): - """Define the visiting of ``ClauseElements``.""" - - def visit_column(self, column):pass - def visit_table(self, column):pass - def visit_fromclause(self, fromclause):pass - def visit_bindparam(self, bindparam):pass - def visit_textclause(self, textclause):pass - def visit_compound(self, compound):pass - def visit_compound_select(self, compound):pass - def visit_binary(self, binary):pass - def visit_alias(self, alias):pass - def visit_select(self, select):pass - def visit_join(self, join):pass - def visit_null(self, null):pass - def visit_clauselist(self, list):pass - def visit_calculatedclause(self, calcclause):pass - def visit_function(self, func):pass - def visit_cast(self, cast):pass - def visit_label(self, label):pass - def visit_typeclause(self, typeclause):pass - -class VisitColumnMixin(object): - """a mixin that adds Column traversal to a ClauseVisitor""" + """A class that knows how to traverse and visit + ``ClauseElements``. + + Each ``ClauseElement``'s accept_visitor() method will call a + corresponding visit_XXXX() method here. Traversal of a + hierarchy of ``ClauseElements`` is achieved via the + ``traverse()`` method, which is passed the lead + ``ClauseElement``. + + By default, ``ClauseVisitor`` traverses all elements + fully. Options can be specified at the class level via the + ``__traverse_options__`` dictionary which will be passed + to the ``get_children()`` method of each ``ClauseElement``; + these options can indicate modifications to the set of + elements returned, such as to not return column collections + (column_collections=False) or to return Schema-level items + (schema_visitor=True).""" + __traverse_options__ = {} + def traverse(self, obj): + for n in obj.get_children(**self.__traverse_options__): + self.traverse(n) + obj.accept_visitor(self) + def visit_column(self, column): + pass def visit_table(self, table): - for c in table.c: - c.accept_visitor(self) - def visit_select(self, select): - for c in select.c: - c.accept_visitor(self) - def visit_compound_select(self, select): - for c in select.c: - c.accept_visitor(self) + pass + def visit_fromclause(self, fromclause): + pass + def visit_bindparam(self, bindparam): + pass + def visit_textclause(self, textclause): + pass + def visit_compound(self, compound): + pass + def visit_compound_select(self, compound): + pass + def visit_binary(self, binary): + pass def visit_alias(self, alias): - for c in alias.c: - c.accept_visitor(self) - + pass + def visit_select(self, select): + pass + def visit_join(self, join): + pass + def visit_null(self, null): + pass + def visit_clauselist(self, list): + pass + def visit_calculatedclause(self, calcclause): + pass + def visit_function(self, func): + pass + def visit_cast(self, cast): + pass + def visit_label(self, label): + pass + def visit_typeclause(self, typeclause): + pass + +class LoggingClauseVisitor(ClauseVisitor): + """extends ClauseVisitor to include debug logging of all traversal. + + To install this visitor, set logging.DEBUG for + 'sqlalchemy.sql.ClauseVisitor' **before** you import the + sqlalchemy.sql module. + """ + + def traverse(self, obj): + indent = getattr(self, '_indent', "") + self.logger.debug(indent + "START " + repr(obj)) + setattr(self, "_indent", indent + " ") + for n in obj.get_children(**self.__traverse_options__): + self.traverse(n) + obj.accept_visitor(self) + setattr(self, "_indent", indent) + self.logger.debug(indent+ "END " + repr(obj)) + +LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor) + +if logging.is_debug_enabled(LoggingClauseVisitor.logger): + ClauseVisitor=LoggingClauseVisitor + +class NoColumnVisitor(ClauseVisitor): + """a ClauseVisitor that will not traverse the exported Column + collections on Table, Alias, Select, and CompoundSelect objects + (i.e. their 'columns' or 'c' attribute). + + this is useful because most traversals don't need those columns, or + in the case of ANSICompiler it traverses them explicitly; so + skipping their traversal here greatly cuts down on method call overhead. + """ + + __traverse_options__ = {'column_collections':False} + class Executor(object): - """Represent a *thing that can produce Compiled objects and execute them*.""" + """Interface representing a *thing that can produce Compiled objects + and execute them*.""" def execute_compiled(self, compiled, parameters, echo=None, **kwargs): """Execute a Compiled object.""" @@ -539,7 +598,7 @@ class Compiled(ClauseVisitor): The ``__str__`` method of the ``Compiled`` object should produce the actual text of the statement. ``Compiled`` objects are - specific to the database library that created them, and also may + specific to their underlying database dialect, and also may or may not be specific to the columns referenced within a particular set of bind parameters. In no case should the ``Compiled`` object be dependent on the actual values of those @@ -547,7 +606,7 @@ class Compiled(ClauseVisitor): defaults. """ - def __init__(self, dialect, statement, parameters, engine=None): + def __init__(self, dialect, statement, parameters, engine=None, traversal=None): """Construct a new Compiled object. statement @@ -570,7 +629,7 @@ class Compiled(ClauseVisitor): engine Optional Engine to compile this statement against. """ - + ClauseVisitor.__init__(self, traversal=traversal) self.dialect = dialect self.statement = statement self.parameters = parameters @@ -578,7 +637,7 @@ class Compiled(ClauseVisitor): self.can_execute = statement.supports_execution() def compile(self): - self.statement.accept_visitor(self) + self.traverse(self.statement) self.after_compile() def __str__(self): @@ -649,7 +708,19 @@ class ClauseElement(object): """ raise NotImplementedError(repr(self)) - + + def get_children(self, **kwargs): + """return immediate child elements of this ``ClauseElement``. + + this is used for visit traversal. + + **kwargs may contain flags that change the collection + that is returned, for example to return a subset of items + in order to cut down on larger traversals, or to return + child items from a different context (such as schema-level + collections instead of clause-level).""" + return [] + def supports_execution(self): """Return True if this clause element represents a complete executable statement. @@ -1058,16 +1129,38 @@ class FromClause(Selectable): def _get_all_embedded_columns(self): ret = [] - class FindCols(VisitColumnMixin, ClauseVisitor): + class FindCols(ClauseVisitor): def visit_column(self, col): ret.append(col) - self.accept_visitor(FindCols()) + FindCols().traverse(self) return ret def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): - """Given a ``ColumnElement``, return the ``ColumnElement`` - object from this ``Selectable`` which corresponds to that - original ``Column`` via a proxy relationship. + """Given a ``ColumnElement``, return the exported + ``ColumnElement`` object from this ``Selectable`` which + corresponds to that original ``Column`` via a common + anscestor column. + + column + the target ``ColumnElement`` to be matched + + raiseerr + if True, raise an error if the given ``ColumnElement`` + could not be matched. if False, non-matches will + return None. + + keys_ok + if the ``ColumnElement`` cannot be matched, attempt to + match based on the string "key" property of the column + alone. This makes the search much more liberal. + + require_embedded + only return corresponding columns for the given + ``ColumnElement``, if the given ``ColumnElement`` is + actually present within a sub-element of this + ``FromClause``. Normally the column will match if + it merely shares a common anscestor with one of + the exported columns of this ``FromClause``. """ if require_embedded and column not in util.Set(self._get_all_embedded_columns()): @@ -1258,11 +1351,14 @@ class _TextClause(ClauseElement): if bindparams is not None: for b in bindparams: self.bindparams[b.key] = b - columns = property(lambda s:[]) - def accept_visitor(self, visitor): - for item in self.bindparams.values(): - item.accept_visitor(visitor) + columns = property(lambda s:[]) + + def get_children(self, **kwargs): + return self.bindparams.values() + + def accept_visitor(self, visitor): visitor.visit_textclause(self) + def _get_from_objects(self): return [] def supports_execution(self): @@ -1296,9 +1392,9 @@ class ClauseList(ClauseElement): if _is_literal(clause): clause = _TextClause(str(clause)) self.clauses.append(clause) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_clauselist(self) def _get_from_objects(self): f = [] @@ -1338,9 +1434,9 @@ class _CompoundClause(ClauseList): clause.parens = True ClauseList.append(self, clause) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_compound(self) def _get_from_objects(self): @@ -1384,9 +1480,9 @@ class _CalculatedClause(ClauseList, ColumnElement): clauses = [clause.copy_container() for clause in self.clauses] return _CalculatedClause(type=self.type, engine=self._engine, *clauses) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_calculatedclause(self) def _bind_param(self, obj): @@ -1432,9 +1528,9 @@ class _Function(_CalculatedClause, FromClause): clauses = [clause.copy_container() for clause in self.clauses] return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_function(self) class _Cast(ColumnElement): @@ -1445,9 +1541,9 @@ class _Cast(ColumnElement): self.clause = clause self.typeclause = _TypeClause(self.type) + def get_children(self, **kwargs): + return self.clause, self.typeclause def accept_visitor(self, visitor): - self.clause.accept_visitor(visitor) - self.typeclause.accept_visitor(visitor) visitor.visit_cast(self) def _get_from_objects(self): @@ -1494,9 +1590,9 @@ class _BinaryClause(ClauseElement): return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) def _get_from_objects(self): return self.left._get_from_objects() + self.right._get_from_objects() + def get_children(self, **kwargs): + return self.left, self.right def accept_visitor(self, visitor): - self.left.accept_visitor(visitor) - self.right.accept_visitor(visitor) visitor.visit_binary(self) def swap(self): c = self.left @@ -1589,12 +1685,12 @@ class Join(FromClause): def _get_folded_equivalents(self, equivs=None): if equivs is None: equivs = util.Set() - class LocateEquivs(ClauseVisitor): + class LocateEquivs(NoColumnVisitor): def visit_binary(self, binary): if binary.operator == '=' and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) - self.onclause.accept_visitor(LocateEquivs()) + LocateEquivs().traverse(self.onclause) collist = [] if isinstance(self.left, Join): left = self.left._get_folded_equivalents(equivs) @@ -1636,10 +1732,9 @@ class Join(FromClause): return select(collist, whereclause, from_obj=[self], **kwargs) + def get_children(self, **kwargs): + return self.left, self.right, self.onclause def accept_visitor(self, visitor): - self.left.accept_visitor(visitor) - self.right.accept_visitor(visitor) - self.onclause.accept_visitor(visitor) visitor.visit_join(self) engine = property(lambda s:s.left.engine or s.right.engine) @@ -1692,8 +1787,11 @@ class Alias(FromClause): #return self.selectable._exportable_columns() return self.selectable.columns + def get_children(self, **kwargs): + for c in self.c: + yield c + yield self.selectable def accept_visitor(self, visitor): - self.selectable.accept_visitor(visitor) visitor.visit_alias(self) def _get_from_objects(self): @@ -1717,9 +1815,10 @@ class _Label(ColumnElement): key = property(lambda s: s.name) _label = property(lambda s: s.name) orig_set = property(lambda s:s.obj.orig_set) - + + def get_children(self, **kwargs): + return self.obj, def accept_visitor(self, visitor): - self.obj.accept_visitor(visitor) visitor.visit_label(self) def _get_from_objects(self): @@ -1841,6 +1940,11 @@ class TableClause(FromClause): original_columns = property(_orig_columns) + def get_children(self, column_collections=True, **kwargs): + if column_collections: + return [c for c in self.c] + else: + return [] def accept_visitor(self, visitor): visitor.visit_table(self) @@ -1964,11 +2068,10 @@ class CompoundSelect(_SelectBaseMixin, FromClause): col.orig_set = colset return col + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.c) or []) + \ + [self.order_by_clause, self.group_by_clause] + list(self.selects) def accept_visitor(self, visitor): - self.order_by_clause.accept_visitor(visitor) - self.group_by_clause.accept_visitor(visitor) - for s in self.selects: - s.accept_visitor(visitor) visitor.visit_compound_select(self) def _find_engine(self): @@ -2028,9 +2131,9 @@ class Select(_SelectBaseMixin, FromClause): self.order_by(*(order_by or [None])) self.group_by(*(group_by or [None])) for c in self.order_by_clause: - c.accept_visitor(self.__correlator) + self.__correlator.traverse(c) for c in self.group_by_clause: - c.accept_visitor(self.__correlator) + self.__correlator.traverse(c) for f in from_obj: self.append_from(f) @@ -2044,13 +2147,14 @@ class Select(_SelectBaseMixin, FromClause): self.append_having(having) - class _CorrelatedVisitor(ClauseVisitor): + class _CorrelatedVisitor(NoColumnVisitor): """Visit a clause, locate any ``Select`` clauses, and tell them that they should correlate their ``FROM`` list to that of their parent. """ def __init__(self, select, is_where): + NoColumnVisitor.__init__(self) self.select = select self.is_where = is_where @@ -2084,12 +2188,12 @@ class Select(_SelectBaseMixin, FromClause): # if the column is a Select statement itself, # accept visitor - column.accept_visitor(self.__correlator) + self.__correlator.traverse(column) # visit the FROM objects of the column looking for more Selects for f in column._get_from_objects(): if f is not self: - f.accept_visitor(self.__correlator) + self.__correlator.traverse(f) self._process_froms(column, False) def _make_proxy(self, selectable, name): if self.is_scalar: @@ -2127,7 +2231,7 @@ class Select(_SelectBaseMixin, FromClause): def _append_condition(self, attribute, condition): if type(condition) == str: condition = _TextClause(condition) - condition.accept_visitor(self.__wherecorrelator) + self.__wherecorrelator.traverse(condition) self._process_froms(condition, False) if getattr(self, attribute) is not None: setattr(self, attribute, and_(getattr(self, attribute), condition)) @@ -2146,7 +2250,7 @@ class Select(_SelectBaseMixin, FromClause): def append_from(self, fromclause): if type(fromclause) == str: fromclause = FromClause(fromclause) - fromclause.accept_visitor(self.__correlator) + self.__correlator.traverse(fromclause) self._process_froms(fromclause, True) def _locate_oid_column(self): @@ -2169,16 +2273,14 @@ class Select(_SelectBaseMixin, FromClause): return f froms = property(_calc_froms, doc="""A collection containing all elements of the FROM clause""") + + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.columns) or []) + \ + list(self.froms) + \ + [x for x in (self.whereclause, self.having) if x is not None] + \ + [self.order_by_clause, self.group_by_clause] def accept_visitor(self, visitor): - for f in self.froms: - f.accept_visitor(visitor) - if self.whereclause is not None: - self.whereclause.accept_visitor(visitor) - if self.having is not None: - self.having.accept_visitor(visitor) - self.order_by_clause.accept_visitor(visitor) - self.group_by_clause.accept_visitor(visitor) visitor.visit_select(self) def union(self, other, **kwargs): @@ -2259,10 +2361,12 @@ class _Insert(_UpdateBase): self.select = None self.parameters = self._process_colparams(values) - def accept_visitor(self, visitor): + def get_children(self, **kwargs): if self.select is not None: - self.select.accept_visitor(visitor) - + return self.select, + else: + return () + def accept_visitor(self, visitor): visitor.visit_insert(self) class _Update(_UpdateBase): @@ -2271,9 +2375,12 @@ class _Update(_UpdateBase): self.whereclause = whereclause self.parameters = self._process_colparams(values) - def accept_visitor(self, visitor): + def get_children(self, **kwargs): if self.whereclause is not None: - self.whereclause.accept_visitor(visitor) + return self.whereclause, + else: + return () + def accept_visitor(self, visitor): visitor.visit_update(self) class _Delete(_UpdateBase): @@ -2281,7 +2388,10 @@ class _Delete(_UpdateBase): self.table = table self.whereclause = whereclause - def accept_visitor(self, visitor): + def get_children(self, **kwargs): if self.whereclause is not None: - self.whereclause.accept_visitor(visitor) + return self.whereclause, + else: + return () + def accept_visitor(self, visitor): visitor.visit_delete(self) |
