summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
commit4a6afd469fad170868554bf28578849bf3dfd5dd (patch)
treeb396edc33d567ae19dd244e87137296450467725 /lib/sqlalchemy/sql
parent46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff)
downloadsqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/__init__.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py101
-rw-r--r--lib/sqlalchemy/sql/expression.py490
-rw-r--r--lib/sqlalchemy/sql/operators.py2
-rw-r--r--lib/sqlalchemy/sql/util.py283
-rw-r--r--lib/sqlalchemy/sql/visitors.py282
6 files changed, 579 insertions, 581 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index c966f396a..5ea9eb1e6 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -1,2 +1,2 @@
from sqlalchemy.sql.expression import *
-from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor
+from sqlalchemy.sql.visitors import ClauseVisitor
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 1fe9ef062..78bb4e31c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -19,7 +19,7 @@ is otherwise internal to SQLAlchemy.
"""
import string, re, itertools
-from sqlalchemy import schema, engine, util, exceptions
+from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql
@@ -115,8 +115,6 @@ class DefaultCompiler(engine.Compiled):
paradigm as visitors.ClauseVisitor but implements its own traversal.
"""
- __traverse_options__ = {'column_collections':False, 'entry':True}
-
operators = OPERATORS
functions = FUNCTIONS
@@ -162,17 +160,12 @@ class DefaultCompiler(engine.Compiled):
# for aliases
self.generated_ids = {}
- # paramstyle from the dialect (comes from DB-API)
- self.paramstyle = self.dialect.paramstyle
-
# true if the paramstyle is positional
self.positional = self.dialect.positional
+ if self.positional:
+ self.positiontup = []
- self.bindtemplate = BIND_TEMPLATES[self.paramstyle]
-
- # a list of the compiled's bind parameter names, used to help
- # formulate a positional argument list
- self.positiontup = []
+ self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = self.dialect.identifier_preparer
@@ -230,15 +223,18 @@ class DefaultCompiler(engine.Compiled):
return ""
def visit_grouping(self, grouping, **kwargs):
- return "(" + self.process(grouping.elem) + ")"
+ return "(" + self.process(grouping.element) + ")"
- def visit_label(self, label, result_map=None):
+ def visit_label(self, label, result_map=None, render_labels=False):
+ if not render_labels:
+ return self.process(label.element)
+
labelname = self._truncated_identifier("colident", label.name)
if result_map is not None:
- result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type)
+ result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
- return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
+ return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
def visit_column(self, column, result_map=None, **kwargs):
@@ -261,16 +257,16 @@ class DefaultCompiler(engine.Compiled):
if getattr(column, "is_literal", False):
name = self.escape_literal_column(name)
else:
- name = self.preparer.quote(column, name)
+ name = self.preparer.quote(name, column.quote)
if column.table is None or not column.table.named_with_column:
return name
else:
if getattr(column.table, 'schema', None):
- schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.'
+ schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.'
else:
schema_prefix = ''
- return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name
+ return schema_prefix + self.preparer.quote(ANONYMOUS_LABEL.sub(self._process_anon, column.table.name), column.table.quote) + "." + name
def escape_literal_column(self, text):
"""provide escaping for the literal_column() construct."""
@@ -387,7 +383,7 @@ class DefaultCompiler(engine.Compiled):
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam and (existing.unique or bindparam.unique):
- raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+ raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
@@ -418,7 +414,7 @@ class DefaultCompiler(engine.Compiled):
return truncname
def _process_anon(self, match):
- (ident, derived) = match.group(1,2)
+ (ident, derived) = match.group(1, 2)
key = ('anonymous', ident)
if key in self.generated_ids:
@@ -436,8 +432,9 @@ class DefaultCompiler(engine.Compiled):
def bindparam_string(self, name):
if self.positional:
self.positiontup.append(name)
-
- return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+ return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+ else:
+ return self.bindtemplate % {'name':name}
def visit_alias(self, alias, asfrom=False, **kwargs):
if asfrom:
@@ -490,7 +487,7 @@ class DefaultCompiler(engine.Compiled):
froms = select._get_display_froms(existingfroms)
- correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms])))
+ correlate_froms = util.Set(sql._from_objects(*froms))
# TODO: might want to propigate existing froms for select(select(select))
# where innermost select should correlate to outermost
@@ -504,6 +501,7 @@ class DefaultCompiler(engine.Compiled):
[c for c in [
self.process(
self.label_select_column(select, co, asfrom=asfrom),
+ render_labels=True,
**column_clause_args)
for co in select.inner_columns
]
@@ -580,9 +578,9 @@ class DefaultCompiler(engine.Compiled):
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
if getattr(table, "schema", None):
- return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name)
+ return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
else:
- return self.preparer.quote(table, table.name)
+ return self.preparer.quote(table.name, table.quote)
else:
return ""
@@ -603,7 +601,7 @@ class DefaultCompiler(engine.Compiled):
return (insert + " INTO %s (%s) VALUES (%s)" %
(preparer.format_table(insert_stmt.table),
- ', '.join([preparer.quote(c[0], c[0].name)
+ ', '.join([preparer.quote(c[0].name, c[0].quote)
for c in colparams]),
', '.join([c[1] for c in colparams])))
@@ -613,7 +611,7 @@ class DefaultCompiler(engine.Compiled):
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0].name, c[0].quote), c[1]) for c in colparams], ', ')
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
@@ -837,7 +835,7 @@ class SchemaGenerator(DDLBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
+ self.append("(%s)" % ', '.join([self.preparer.quote(c.name, c.quote) for c in constraint]))
self.define_constraint_deferrability(constraint)
def visit_foreign_key_constraint(self, constraint):
@@ -858,9 +856,9 @@ class SchemaGenerator(DDLBase):
preparer.format_constraint(constraint))
table = list(constraint.elements)[0].column.table
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]),
+ ', '.join([preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements]),
preparer.format_table(table),
- ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements])
+ ', '.join([preparer.quote(f.column.name, f.column.quote) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
@@ -873,7 +871,7 @@ class SchemaGenerator(DDLBase):
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c.name, c.quote) for c in constraint])))
self.define_constraint_deferrability(constraint)
def define_constraint_deferrability(self, constraint):
@@ -896,7 +894,7 @@ class SchemaGenerator(DDLBase):
self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
- string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
+ string.join([preparer.quote(c.name, c.quote) for c in index.columns], ', ')))
self.execute()
@@ -1005,9 +1003,12 @@ class IdentifierPreparer(object):
or not self.legal_characters.match(unicode(value))
or (lc_value != value))
- def quote(self, obj, ident):
- if getattr(obj, 'quote', False):
+ def quote(self, ident, force):
+ if force:
return self.quote_identifier(ident)
+ elif force is False:
+ return ident
+
if ident in self.__strings:
return self.__strings[ident]
else:
@@ -1017,53 +1018,47 @@ class IdentifierPreparer(object):
self.__strings[ident] = ident
return self.__strings[ident]
- def should_quote(self, object):
- return object.quote or self._requires_quotes(object.name)
-
def format_sequence(self, sequence, use_schema=True):
- name = self.quote(sequence, sequence.name)
+ name = self.quote(sequence.name, sequence.quote)
if not self.omit_schema and use_schema and sequence.schema is not None:
- name = self.quote(sequence, sequence.schema) + "." + name
+ name = self.quote(sequence.schema, sequence.quote) + "." + name
return name
def format_label(self, label, name=None):
- return self.quote(label, name or label.name)
+ return self.quote(name or label.name, label.quote)
def format_alias(self, alias, name=None):
- return self.quote(alias, name or alias.name)
+ return self.quote(name or alias.name, alias.quote)
def format_savepoint(self, savepoint, name=None):
- return self.quote(savepoint, name or savepoint.ident)
+ return self.quote(name or savepoint.ident, savepoint.quote)
def format_constraint(self, constraint):
- return self.quote(constraint, constraint.name)
+ return self.quote(constraint.name, constraint.quote)
def format_index(self, index):
- return self.quote(index, index.name)
+ return self.quote(index.name, index.quote)
def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
if name is None:
name = table.name
- result = self.quote(table, name)
+ result = self.quote(name, table.quote)
if not self.omit_schema and use_schema and getattr(table, "schema", None):
- result = self.quote(table, table.schema) + "." + result
+ result = self.quote(table.schema, table.quote_schema) + "." + result
return result
def format_column(self, column, use_table=False, name=None, table_name=None):
- """Prepare a quoted column name.
-
- deprecated. use preparer.quote(col, column.name) or combine with format_table()
- """
+ """Prepare a quoted column name."""
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote)
else:
- return self.quote(column, name)
+ return self.quote(name, column.quote)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
if use_table:
@@ -1079,7 +1074,7 @@ class IdentifierPreparer(object):
# a longer sequence.
if not self.omit_schema and use_schema and getattr(table, 'schema', None):
- return (self.quote_identifier(table.schema),
+ return (self.quote(table.schema, table.quote_schema),
self.format_table(table, use_schema=False))
else:
return (self.format_table(table, use_schema=False), )
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 867fdd69c..7ce637701 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -26,12 +26,12 @@ to stay the same in future releases.
"""
import itertools, re
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
functions, schema, sql_util = None, None, None
-DefaultDialect, ClauseAdapter = None, None
+DefaultDialect, ClauseAdapter, Annotated = None, None, None
__all__ = [
'Alias', 'ClauseElement',
@@ -503,15 +503,21 @@ def collate(expression, collation):
def exists(*args, **kwargs):
"""Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object.
+
+ Calling styles are of the following forms::
+
+ # use on an existing select()
+ s = select([<columns>]).where(<criterion>)
+ s = exists(s)
+
+ # construct a select() at once
+ exists(['*'], **select_arguments).where(<criterion>)
+
+ # columns argument is optional, generates "EXISTS (SELECT *)"
+ # by default.
+ exists().where(<criterion>)
- The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by
- itself or used as a subquery within an enclosing select.
-
- \*args, \**kwargs
- all arguments are sent directly to the [sqlalchemy.sql.expression#select()]
- function to produce a ``SELECT`` statement.
"""
-
return _Exists(*args, **kwargs)
def union(*selects, **kwargs):
@@ -872,27 +878,36 @@ def _compound_select(keyword, *selects, **kwargs):
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, ClauseElement)
+ return not isinstance(element, (ClauseElement, Operators))
+
+def _from_objects(*elements, **kwargs):
+ return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
+def _labeled(element):
+ if not hasattr(element, 'name'):
+ return element.label(None)
+ else:
+ return element
+
def _literal_as_text(element):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
return _TextClause(unicode(element))
else:
return element
def _literal_as_column(element):
- if isinstance(element, Operators):
- return element.clause_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
return literal_column(str(element))
else:
return element
def _literal_as_binds(element, name=None, type_=None):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
if element is None:
return null()
@@ -902,17 +917,17 @@ def _literal_as_binds(element, name=None, type_=None):
return element
def _no_literals(element):
- if isinstance(element, Operators):
- return element.expression_element()
+ if hasattr(element, '__clause_element__'):
+ return element.__clause_element__()
elif _is_literal(element):
- raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+ raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
else:
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
c = fromclause.corresponding_column(column, require_embedded=require_embedded)
if not c:
- raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+ raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
return c
def _selectable(element):
@@ -921,9 +936,8 @@ def _selectable(element):
elif isinstance(element, Selectable):
return element
else:
- raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+ raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
-
def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
@@ -941,7 +955,9 @@ class _FigureVisitName(type):
class ClauseElement(object):
"""Base class for elements of a programmatically constructed SQL expression."""
__metaclass__ = _FigureVisitName
-
+ _annotations = {}
+ supports_execution = False
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
@@ -976,6 +992,14 @@ class ClauseElement(object):
"""
raise NotImplementedError(repr(self))
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with the given annotations dictionary."""
+
+ global Annotated
+ if Annotated is None:
+ from sqlalchemy.sql.util import Annotated
+ return Annotated(self, values)
def unique_params(self, *optionaldict, **kwargs):
"""Return a copy with ``bindparam()`` elments replaced.
@@ -1006,14 +1030,14 @@ class ClauseElement(object):
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
- raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument")
+ raise exc.ArgumentError("params() takes zero or one positional dictionary argument")
def visit_bindparam(bind):
if bind.key in kwargs:
bind.value = kwargs[bind.key]
if unique:
bind._convert_to_unique()
- return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True)
+ return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
def compare(self, other):
"""Compare this ClauseElement to the given ClauseElement.
@@ -1049,11 +1073,6 @@ class ClauseElement(object):
def self_group(self, against=None):
return self
- def supports_execution(self):
- """Return True if this clause element represents a complete executable statement."""
-
- return False
-
def bind(self):
"""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
@@ -1062,7 +1081,7 @@ class ClauseElement(object):
return self._bind
except AttributeError:
pass
- for f in self._get_from_objects():
+ for f in _from_objects(self):
if f is self:
continue
engine = f.bind
@@ -1083,7 +1102,7 @@ class ClauseElement(object):
'Engine for execution. Or, assign a bind to the statement '
'or the Metadata of its underlying tables to enable '
'implicit execution via this method.' % label)
- raise exceptions.UnboundExecutionError(msg)
+ raise exc.UnboundExecutionError(msg)
return e.execute_clauseelement(self, multiparams, params)
def scalar(self, *multiparams, **params):
@@ -1159,6 +1178,12 @@ class ClauseElement(object):
self.__module__, self.__class__.__name__, id(self), friendly)
+class _Immutable(object):
+ """mark a ClauseElement as 'immutable' when expressions are cloned."""
+
+ def _clone(self):
+ return self
+
class Operators(object):
def __and__(self, other):
return self.operate(operators.and_, other)
@@ -1174,9 +1199,6 @@ class Operators(object):
return self.operate(operators.op, opstring, b)
return op
- def clause_element(self):
- raise NotImplementedError()
-
def operate(self, op, *other, **kwargs):
raise NotImplementedError()
@@ -1216,7 +1238,7 @@ class ColumnOperators(Operators):
def ilike(self, other, escape=None):
return self.operate(operators.ilike_op, other, escape=escape)
- def in_(self, *other):
+ def in_(self, other):
return self.operate(operators.in_op, other)
def startswith(self, other, **kwargs):
@@ -1279,18 +1301,18 @@ class _CompareMixin(ColumnOperators):
def __compare(self, op, obj, negate=None, reverse=False, **kwargs):
if obj is None or isinstance(obj, _Null):
if op == operators.eq:
- return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot)
+ return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot)
elif op == operators.ne:
- return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_)
+ return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_)
else:
- raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+ raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL")
else:
obj = self._check_literal(obj)
if reverse:
- return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+ return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
else:
- return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+ return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
def __operate(self, op, obj, reverse=False):
obj = self._check_literal(obj)
@@ -1298,9 +1320,9 @@ class _CompareMixin(ColumnOperators):
type_ = self._compare_type(obj)
if reverse:
- return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_)
+ return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_)
else:
- return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
+ return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_)
# a mapping of operators with the method they use, along with their negated
# operator for comparison operators
@@ -1329,17 +1351,10 @@ class _CompareMixin(ColumnOperators):
o = _CompareMixin.operators[op]
return o[0](self, op, other, reverse=True, *o[1:], **kwargs)
- def in_(self, *other):
- return self._in_impl(operators.in_op, operators.notin_op, *other)
-
- def _in_impl(self, op, negate_op, *other):
- # Handle old style *args argument passing
- if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)):
- util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable')
- seq_or_selectable = other
- else:
- seq_or_selectable = other[0]
+ def in_(self, other):
+ return self._in_impl(operators.in_op, operators.notin_op, other)
+ def _in_impl(self, op, negate_op, seq_or_selectable):
if isinstance(seq_or_selectable, Selectable):
return self.__compare( op, seq_or_selectable, negate=negate_op)
@@ -1348,7 +1363,7 @@ class _CompareMixin(ColumnOperators):
for o in seq_or_selectable:
if not _is_literal(o):
if not isinstance( o, _CompareMixin):
- raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
+ raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
else:
o = self._bind_param(o)
args.append(o)
@@ -1433,22 +1448,13 @@ class _CompareMixin(ColumnOperators):
if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
other.type = self.type
return other
- elif isinstance(other, Operators):
- return other.expression_element()
+ elif hasattr(other, '__clause_element__'):
+ return other.__clause_element__()
elif _is_literal(other):
return self._bind_param(other)
else:
return other
- def clause_element(self):
- """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``."""
- return self
-
- def expression_element(self):
- """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
-
- return self
-
def _compare_type(self, obj):
"""Allow subclasses to override the type used in constructing
``_BinaryExpression`` objects.
@@ -1480,23 +1486,22 @@ class ColumnElement(ClauseElement, _CompareMixin):
primary_key = False
foreign_keys = []
-
+ quote = None
+
def base_columns(self):
- if hasattr(self, '_base_columns'):
- return self._base_columns
- self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
+ if not hasattr(self, '_base_columns'):
+ self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
return self._base_columns
base_columns = property(base_columns)
def proxy_set(self):
- if hasattr(self, '_proxy_set'):
- return self._proxy_set
- s = util.Set([self])
- if hasattr(self, 'proxies'):
- for c in self.proxies:
- s = s.union(c.proxy_set)
- self._proxy_set = s
- return s
+ if not hasattr(self, '_proxy_set'):
+ s = util.Set([self])
+ if hasattr(self, 'proxies'):
+ for c in self.proxies:
+ s.update(c.proxy_set)
+ self._proxy_set = s
+ return self._proxy_set
proxy_set = property(proxy_set)
def shares_lineage(self, othercolumn):
@@ -1518,7 +1523,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None))
co.proxies = [self]
- selectable.columns[name]= co
+ selectable.columns[name] = co
return co
def anon_label(self):
@@ -1613,7 +1618,7 @@ class ColumnCollection(util.OrderedProperties):
def __contains__(self, other):
if not isinstance(other, basestring):
- raise exceptions.ArgumentError("__contains__ requires a string argument")
+ raise exc.ArgumentError("__contains__ requires a string argument")
return util.OrderedProperties.__contains__(self, other)
def contains_column(self, col):
@@ -1641,6 +1646,9 @@ class ColumnSet(util.OrderedSet):
l.append(c==local)
return and_(*l)
+ def __hash__(self):
+ return hash(tuple(self._list))
+
class Selectable(ClauseElement):
"""mark a class as being selectable"""
@@ -1648,8 +1656,9 @@ class FromClause(Selectable):
"""Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
__visit_name__ = 'fromclause'
- named_with_column=False
+ named_with_column = False
_hide_froms = []
+ quote = None
def _get_from_objects(self, **modifiers):
return []
@@ -1694,12 +1703,12 @@ class FromClause(Selectable):
return fromclause in util.Set(self._cloned_set)
def replace_selectable(self, old, alias):
- """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
+ """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
- global ClauseAdapter
- if ClauseAdapter is None:
- from sqlalchemy.sql.util import ClauseAdapter
- return ClauseAdapter(alias).traverse(self, clone=True)
+ global ClauseAdapter
+ if ClauseAdapter is None:
+ from sqlalchemy.sql.util import ClauseAdapter
+ return ClauseAdapter(alias).traverse(self)
def correspond_on_equivalents(self, column, equivalents):
col = self.corresponding_column(column, require_embedded=True)
@@ -1859,7 +1868,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
def _convert_to_unique(self):
if not self.unique:
- self.unique=True
+ self.unique = True
self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param')
def _get_from_objects(self, **modifiers):
@@ -1910,6 +1919,7 @@ class _TextClause(ClauseElement):
__visit_name__ = 'textclause'
_bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+ supports_execution = True
_hide_froms = []
oid_column = None
@@ -1950,12 +1960,6 @@ class _TextClause(ClauseElement):
def _get_from_objects(self, **modifiers):
return []
- def supports_execution(self):
- return True
-
- def _table_iterator(self):
- return iter([])
-
class _Null(ColumnElement):
"""Represent the NULL keyword in a SQL statement.
@@ -2042,6 +2046,7 @@ class _CalculatedClause(ColumnElement):
__visit_name__ = 'calculatedclause'
def __init__(self, name, *clauses, **kwargs):
+ ColumnElement.__init__(self)
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type_', None))
self._bind = kwargs.get('bind', None)
@@ -2061,7 +2066,7 @@ class _CalculatedClause(ColumnElement):
def clauses(self):
if isinstance(self.clause_expr, _Grouping):
- return self.clause_expr.elem
+ return self.clause_expr.element
else:
return self.clause_expr
clauses = property(clauses)
@@ -2239,8 +2244,13 @@ class _Exists(_UnaryExpression):
__visit_name__ = _UnaryExpression.__visit_name__
def __init__(self, *args, **kwargs):
- kwargs['correlate'] = True
- s = select(*args, **kwargs).as_scalar().self_group()
+ if args and isinstance(args[0], _SelectBaseMixin):
+ s = args[0]
+ else:
+ if not args:
+ args = ([literal_column('*')],)
+ s = select(*args, **kwargs).as_scalar().self_group()
+
_UnaryExpression.__init__(self, s, operator=operators.exists)
def select(self, whereclause=None, **params):
@@ -2272,7 +2282,7 @@ class Join(FromClause):
self.right = _selectable(right).self_group()
if onclause is None:
- self.onclause = self.__match_primaries(self.left, self.right)
+ self.onclause = self._match_primaries(self.left, self.right)
else:
self.onclause = onclause
@@ -2310,7 +2320,7 @@ class Join(FromClause):
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
- def __match_primaries(self, primary, secondary):
+ def _match_primaries(self, primary, secondary):
global sql_util
if not sql_util:
from sqlalchemy.sql import util as sql_util
@@ -2359,7 +2369,7 @@ class Join(FromClause):
return self.select(use_labels=True, correlate=False).alias(name)
def _hide_froms(self):
- return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set])
+ return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set])
_hide_froms = property(_hide_froms)
def _get_from_objects(self, **modifiers):
@@ -2382,9 +2392,10 @@ class Alias(FromClause):
def __init__(self, selectable, alias=None):
baseselectable = selectable
while isinstance(baseselectable, Alias):
- baseselectable = baseselectable.selectable
+ baseselectable = baseselectable.element
self.original = baseselectable
- self.selectable = selectable
+ self.supports_execution = baseselectable.supports_execution
+ self.element = selectable
if alias is None:
if self.original.named_with_column:
alias = getattr(self.original, 'name', None)
@@ -2398,112 +2409,100 @@ class Alias(FromClause):
def is_derived_from(self, fromclause):
if fromclause in util.Set(self._cloned_set):
return True
- return self.selectable.is_derived_from(fromclause)
-
- def supports_execution(self):
- return self.original.supports_execution()
-
- def _table_iterator(self):
- return self.original._table_iterator()
+ return self.element.is_derived_from(fromclause)
def _populate_column_collection(self):
- for col in self.selectable.columns:
+ for col in self.element.columns:
col._make_proxy(self)
- if self.selectable.oid_column is not None:
- self._oid_column = self.selectable.oid_column._make_proxy(self)
+ if self.element.oid_column is not None:
+ self._oid_column = self.element.oid_column._make_proxy(self)
def _copy_internals(self, clone=_clone):
- self._reset_exported()
- self.selectable = _clone(self.selectable)
- baseselectable = self.selectable
- while isinstance(baseselectable, Alias):
- baseselectable = baseselectable.selectable
- self.original = baseselectable
+ self._reset_exported()
+ self.element = _clone(self.element)
+ baseselectable = self.element
+ while isinstance(baseselectable, Alias):
+ baseselectable = baseselectable.selectable
+ self.original = baseselectable
def get_children(self, column_collections=True, aliased_selectables=True, **kwargs):
if column_collections:
for c in self.c:
yield c
if aliased_selectables:
- yield self.selectable
+ yield self.element
def _get_from_objects(self, **modifiers):
return [self]
def bind(self):
- return self.selectable.bind
+ return self.element.bind
bind = property(bind)
-class _ColumnElementAdapter(ColumnElement):
- """Adapts a ClauseElement which may or may not be a
- ColumnElement subclass itself into an object which
- acts like a ColumnElement.
- """
+class _Grouping(ColumnElement):
+ """Represent a grouping within a column expression"""
- def __init__(self, elem):
- self.elem = elem
- self.type = getattr(elem, 'type', None)
+ def __init__(self, element):
+ ColumnElement.__init__(self)
+ self.element = element
+ self.type = getattr(element, 'type', None)
def key(self):
- return self.elem.key
+ return self.element.key
key = property(key)
def _label(self):
try:
- return self.elem._label
+ return self.element._label
except AttributeError:
return self.anon_label
_label = property(_label)
def _copy_internals(self, clone=_clone):
- self.elem = clone(self.elem)
+ self.element = clone(self.element)
def get_children(self, **kwargs):
- return self.elem,
+ return self.element,
def _get_from_objects(self, **modifiers):
- return self.elem._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def __getattr__(self, attr):
- return getattr(self.elem, attr)
+ return getattr(self.element, attr)
def __getstate__(self):
- return {'elem':self.elem, 'type':self.type}
+ return {'element':self.element, 'type':self.type}
def __setstate__(self, state):
- self.elem = state['elem']
+ self.element = state['element']
self.type = state['type']
-class _Grouping(_ColumnElementAdapter):
- """Represent a grouping within a column expression"""
- pass
-
class _FromGrouping(FromClause):
"""Represent a grouping of a FROM clause"""
__visit_name__ = 'grouping'
- def __init__(self, elem):
- self.elem = elem
+ def __init__(self, element):
+ self.element = element
def columns(self):
- return self.elem.columns
+ return self.element.columns
columns = c = property(columns)
def _hide_froms(self):
- return self.elem._hide_froms
+ return self.element._hide_froms
_hide_froms = property(_hide_froms)
def get_children(self, **kwargs):
- return self.elem,
+ return self.element,
def _copy_internals(self, clone=_clone):
- self.elem = clone(self.elem)
+ self.element = clone(self.element)
def _get_from_objects(self, **modifiers):
- return self.elem._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def __getattr__(self, attr):
- return getattr(self.elem, attr)
+ return getattr(self.element, attr)
class _Label(ColumnElement):
"""Represents a column label (AS).
@@ -2516,12 +2515,12 @@ class _Label(ColumnElement):
``ColumnElement`` subclasses.
"""
- def __init__(self, name, obj, type_=None):
- while isinstance(obj, _Label):
- obj = obj.obj
- self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
- self.obj = obj.self_group(against=operators.as_)
- self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
+ def __init__(self, name, element, type_=None):
+ while isinstance(element, _Label):
+ element = element.element
+ self.name = name or "{ANON %d %s}" % (id(self), getattr(element, 'name', 'anon'))
+ self.element = element.self_group(against=operators.as_)
+ self.type = sqltypes.to_instance(type_ or getattr(element, 'type', None))
def key(self):
return self.name
@@ -2532,8 +2531,9 @@ class _Label(ColumnElement):
_label = property(_label)
def _proxy_attr(name):
+ get = util.attrgetter(name)
def attr(self):
- return getattr(self.obj, name)
+ return get(self.element)
return property(attr)
proxies = _proxy_attr('proxies')
@@ -2542,27 +2542,24 @@ class _Label(ColumnElement):
primary_key = _proxy_attr('primary_key')
foreign_keys = _proxy_attr('foreign_keys')
- def expression_element(self):
- return self.obj
-
def get_children(self, **kwargs):
- return self.obj,
+ return self.element,
def _copy_internals(self, clone=_clone):
- self.obj = clone(self.obj)
+ self.element = clone(self.element)
def _get_from_objects(self, **modifiers):
- return self.obj._get_from_objects(**modifiers)
+ return self.element._get_from_objects(**modifiers)
def _make_proxy(self, selectable, name = None):
- if isinstance(self.obj, (Selectable, ColumnElement)):
- e = self.obj._make_proxy(selectable, name=self.name)
+ if isinstance(self.element, (Selectable, ColumnElement)):
+ e = self.element._make_proxy(selectable, name=self.name)
else:
e = column(self.name)._make_proxy(selectable=selectable)
e.proxies.append(self)
return e
-class _ColumnClause(ColumnElement):
+class _ColumnClause(_Immutable, ColumnElement):
"""Represents a generic column expression from any textual string.
This includes columns associated with tables, aliases and select
@@ -2602,16 +2599,7 @@ class _ColumnClause(ColumnElement):
return self.name.encode('ascii', 'backslashreplace')
description = property(description)
- def _clone(self):
- # ColumnClause is immutable
- return self
-
def _label(self):
- """Generate a 'label' string for this column.
- """
-
- # for a "literal" column, we've no idea what the text is
- # therefore no 'label' can be automatically generated
if self.is_literal:
return None
if not self.__label:
@@ -2626,24 +2614,21 @@ class _ColumnClause(ColumnElement):
counter = 1
while label in self.table.c:
label = self.__label + "_" + str(counter)
- counter +=1
+ counter += 1
self.__label = label
else:
self.__label = self.name
return self.__label
-
_label = property(_label)
def label(self, name):
- # if going off the "__label" property and its None, we have
- # no label; return self
if name is None:
return self
else:
return super(_ColumnClause, self).label(name)
def _get_from_objects(self, **modifiers):
- if self.table is not None:
+ if self.table:
return [self.table]
else:
return []
@@ -2651,20 +2636,20 @@ class _ColumnClause(ColumnElement):
def _bind_param(self, obj):
return _BindParamClause(self.name, obj, type_=self.type, unique=True)
- def _make_proxy(self, selectable, name = None):
+ def _make_proxy(self, selectable, name=None, attach=True):
# propigate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
c.proxies = [self]
- if not self._is_oid:
+ if attach and not self._is_oid:
selectable.columns[c.name] = c
return c
def _compare_type(self, obj):
return self.type
-class TableClause(FromClause):
+class TableClause(_Immutable, FromClause):
"""Represents a "table" construct.
Note that this represents tables only as another syntactical
@@ -2691,10 +2676,6 @@ class TableClause(FromClause):
return self.name.encode('ascii', 'backslashreplace')
description = property(description)
- def _clone(self):
- # TableClause is immutable
- return self
-
def append_column(self, c):
self._columns[c.name] = c
c.table = self
@@ -2724,10 +2705,11 @@ class TableClause(FromClause):
def _get_from_objects(self, **modifiers):
return [self]
-
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
+ supports_execution = True
+
def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False):
self.use_labels = use_labels
self.for_update = for_update
@@ -2773,11 +2755,6 @@ class _SelectBaseMixin(object):
"""
return self.as_scalar().label(name)
- def supports_execution(self):
- """part of the ClauseElement contract; returns ``True`` in all cases for this class."""
-
- return True
-
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to True."""
@@ -2860,15 +2837,15 @@ class _SelectBaseMixin(object):
class _ScalarSelect(_Grouping):
__visit_name__ = 'grouping'
- def __init__(self, elem):
- self.elem = elem
- cols = list(elem.inner_columns)
+ def __init__(self, element):
+ self.element = element
+ cols = list(element.inner_columns)
if len(cols) != 1:
- raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
+ raise exc.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
self.type = cols[0].type
def columns(self):
- raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
+ raise exc.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
columns = c = property(columns)
def self_group(self, **kwargs):
@@ -2893,7 +2870,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
if not numcols:
numcols = len(s.c)
elif len(s.c) != numcols:
- raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
+ raise exc.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
(1, len(self.selects[0].c), n+1, len(s.c))
)
if s._order_by_clause:
@@ -2936,11 +2913,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
return (column_collections and list(self.c) or []) + \
[self._order_by_clause, self._group_by_clause] + list(self.selects)
- def _table_iterator(self):
- for s in self.selects:
- for t in s._table_iterator():
- yield t
-
def bind(self):
if self._bind:
return self._bind
@@ -2976,6 +2948,7 @@ class Select(_SelectBaseMixin, FromClause):
self._distinct = distinct
self._correlate = util.Set()
+ self._froms = util.OrderedSet()
if columns:
self._raw_columns = [
@@ -2983,22 +2956,23 @@ class Select(_SelectBaseMixin, FromClause):
for c in
[_literal_as_column(c) for c in columns]
]
+
+ self._froms.update(_from_objects(*self._raw_columns))
else:
self._raw_columns = []
-
- if from_obj:
- self._froms = util.Set([
- _is_literal(f) and _TextClause(f) or f
- for f in util.to_list(from_obj)
- ])
- else:
- self._froms = util.Set()
-
+
if whereclause:
self._whereclause = _literal_as_text(whereclause)
+ self._froms.update(_from_objects(self._whereclause, is_where=True))
else:
self._whereclause = None
+ if from_obj:
+ self._froms.update([
+ _is_literal(f) and _TextClause(f) or f
+ for f in util.to_list(from_obj)
+ ])
+
if having:
self._having = _literal_as_text(having)
else:
@@ -3020,36 +2994,28 @@ class Select(_SelectBaseMixin, FromClause):
correlating.
"""
- froms = util.OrderedSet()
-
- for col in self._raw_columns:
- froms.update(col._get_from_objects())
-
- if self._whereclause is not None:
- froms.update(self._whereclause._get_from_objects(is_where=True))
-
- if self._froms:
- froms.update(self._froms)
+ froms = self._froms
toremove = itertools.chain(*[f._hide_froms for f in froms])
- froms.difference_update(toremove)
+ if toremove:
+ froms = froms.difference(toremove)
if len(froms) > 1 or self._correlate:
if self._correlate:
- froms.difference_update(_cloned_intersection(froms, self._correlate))
+ froms = froms.difference(_cloned_intersection(froms, self._correlate))
if self._should_correlate and existing_froms:
- froms.difference_update(_cloned_intersection(froms, existing_froms))
+ froms = froms.difference(_cloned_intersection(froms, existing_froms))
if not len(froms):
- raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
+ raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
return froms
froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
def type(self):
- raise exceptions.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
+ raise exc.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
type = property(type)
def locate_all_froms(self):
@@ -3059,22 +3025,10 @@ class Select(_SelectBaseMixin, FromClause):
is specifically for those FromClause elements that would actually be rendered.
"""
- if hasattr(self, '_all_froms'):
- return self._all_froms
-
- froms = util.Set(
- itertools.chain(*
- [self._froms] +
- [f._get_from_objects() for f in self._froms] +
- [col._get_from_objects() for col in self._raw_columns]
- )
- )
+ if not hasattr(self, '_all_froms'):
+ self._all_froms = self._froms.union(_from_objects(*list(self._froms)))
- if self._whereclause:
- froms.update(self._whereclause._get_from_objects(is_where=True))
-
- self._all_froms = froms
- return froms
+ return self._all_froms
def inner_columns(self):
"""an iteratorof all ColumnElement expressions which would
@@ -3092,7 +3046,7 @@ class Select(_SelectBaseMixin, FromClause):
def is_derived_from(self, fromclause):
if self in util.Set(fromclause._cloned_set):
return True
-
+
for f in self.locate_all_froms():
if f.is_derived_from(fromclause):
return True
@@ -3112,7 +3066,7 @@ class Select(_SelectBaseMixin, FromClause):
"""return child elements as per the ClauseElement specification."""
return (column_collections and list(self.columns) or []) + \
- list(self.locate_all_froms()) + \
+ list(self._froms) + \
[x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
def column(self, column):
@@ -3125,6 +3079,7 @@ class Select(_SelectBaseMixin, FromClause):
column = column.self_group(against=operators.comma_op)
s._raw_columns = s._raw_columns + [column]
+ s._froms = s._froms.union(_from_objects(column))
return s
def where(self, whereclause):
@@ -3185,7 +3140,7 @@ class Select(_SelectBaseMixin, FromClause):
"""
s = self._generate()
- s._should_correlate=False
+ s._should_correlate = False
if fromclauses == (None,):
s._correlate = util.Set()
else:
@@ -3195,7 +3150,7 @@ class Select(_SelectBaseMixin, FromClause):
def append_correlation(self, fromclause):
"""append the given correlation expression to this select() construct."""
- self._should_correlate=False
+ self._should_correlate = False
self._correlate = self._correlate.union([fromclause])
def append_column(self, column):
@@ -3207,6 +3162,7 @@ class Select(_SelectBaseMixin, FromClause):
column = column.self_group(against=operators.comma_op)
self._raw_columns = self._raw_columns + [column]
+ self._froms = self._froms.union(_from_objects(column))
self._reset_exported()
def append_prefix(self, clause):
@@ -3221,10 +3177,13 @@ class Select(_SelectBaseMixin, FromClause):
The expression will be joined to existing WHERE criterion via AND.
"""
+ whereclause = _literal_as_text(whereclause)
+ self._froms = self._froms.union(_from_objects(whereclause, is_where=True))
+
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
+ self._whereclause = and_(self._whereclause, whereclause)
else:
- self._whereclause = _literal_as_text(whereclause)
+ self._whereclause = whereclause
def append_having(self, having):
"""append the given expression to this select() construct's HAVING criterion.
@@ -3311,31 +3270,23 @@ class Select(_SelectBaseMixin, FromClause):
return intersect_all(self, other, **kwargs)
- def _table_iterator(self):
- for t in visitors.NoColumnVisitor().iterate(self):
- if isinstance(t, TableClause):
- yield t
-
def bind(self):
if self._bind:
return self._bind
- for f in self._froms:
- if f is self:
- continue
- e = f.bind
- if e:
- self._bind = e
- return e
- # look through the columns (largely synomous with looking
- # through the FROMs except in the case of _CalculatedClause/_Function)
- for c in self._raw_columns:
- if getattr(c, 'table', None) is self:
- continue
- e = c.bind
+ if not self._froms:
+ for c in self._raw_columns:
+ e = c.bind
+ if e:
+ self._bind = e
+ return e
+ else:
+ e = list(self._froms)[0].bind
if e:
self._bind = e
return e
+
return None
+
def _set_bind(self, bind):
self._bind = bind
bind = property(bind, _set_bind)
@@ -3343,11 +3294,7 @@ class Select(_SelectBaseMixin, FromClause):
class _UpdateBase(ClauseElement):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
- def supports_execution(self):
- return True
-
- def _table_iterator(self):
- return iter([self.table])
+ supports_execution = True
def _generate(self):
s = self.__class__.__new__(self.__class__)
@@ -3407,7 +3354,7 @@ class Insert(_ValuesBase):
self._bind = bind
self.table = table
self.select = None
- self.inline=inline
+ self.inline = inline
if prefixes:
self._prefixes = [_literal_as_text(p) for p in prefixes]
else:
@@ -3502,10 +3449,11 @@ class Delete(_UpdateBase):
self._whereclause = clone(self._whereclause)
class _IdentifiedClause(ClauseElement):
+ supports_execution = True
+ quote = None
+
def __init__(self, ident):
self.ident = ident
- def supports_execution(self):
- return True
class SavepointClause(_IdentifiedClause):
pass
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index dfd638ecb..46dcaba66 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -44,7 +44,7 @@ def between_op(a, b, c):
return a.between(b, c)
def in_op(a, b):
- return a.in_(*b)
+ return a.in_(b)
def notin_op(a, b):
raise NotImplementedError()
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d299982cf..944a68def 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -1,4 +1,4 @@
-from sqlalchemy import exceptions, schema, topological, util, sql
+from sqlalchemy import exc, schema, topological, util, sql
from sqlalchemy.sql import expression, operators, visitors
from itertools import chain
@@ -8,43 +8,57 @@ def sort_tables(tables, reverse=False):
"""sort a collection of Table objects in order of their foreign-key dependency."""
tuples = []
- class TVisitor(schema.SchemaVisitor):
- def visit_foreign_key(_self, fkey):
- if fkey.use_alter:
- return
- parent_table = fkey.column.table
- if parent_table in tables:
- child_table = fkey.parent.table
- tuples.append( ( parent_table, child_table ) )
- vis = TVisitor()
+ def visit_foreign_key(fkey):
+ if fkey.use_alter:
+ return
+ parent_table = fkey.column.table
+ if parent_table in tables:
+ child_table = fkey.parent.table
+ tuples.append( ( parent_table, child_table ) )
+
for table in tables:
- vis.traverse(table)
+ visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key})
sequence = topological.sort(tuples, tables)
if reverse:
return util.reversed(sequence)
else:
return sequence
-def find_tables(clause, check_columns=False, include_aliases=False):
+def search(clause, target):
+ if not clause:
+ return False
+ for elem in visitors.iterate(clause, {'column_collections':False}):
+ if elem is target:
+ return True
+ else:
+ return False
+
+def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
"""locate Table objects within the given expression."""
tables = []
- kwargs = {}
+ _visitors = {}
+
+ def visit_something(elem):
+ tables.append(elem)
+
+ if include_selects:
+ _visitors['select'] = _visitors['compound_select'] = visit_something
+
+ if include_joins:
+ _visitors['join'] = visit_something
+
if include_aliases:
- def visit_alias(alias):
- tables.append(alias)
- kwargs['visit_alias'] = visit_alias
+ _visitors['alias'] = visit_something
if check_columns:
def visit_column(column):
tables.append(column.table)
- kwargs['visit_column'] = visit_column
+ _visitors['column'] = visit_column
- def visit_table(table):
- tables.append(table)
- kwargs['visit_table'] = visit_table
+ _visitors['table'] = visit_something
- visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs)
+ visitors.traverse(clause, {'column_collections':False}, _visitors)
return tables
def find_columns(clause):
@@ -53,7 +67,7 @@ def find_columns(clause):
cols = util.Set()
def visit_column(col):
cols.add(col)
- visitors.traverse(clause, visit_column=visit_column)
+ visitors.traverse(clause, {}, {'column':visit_column})
return cols
def join_condition(a, b, ignore_nonexistent_tables=False):
@@ -72,7 +86,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
for fk in b.foreign_keys:
try:
col = fk.get_referent(a)
- except exceptions.NoReferencedTableError:
+ except exc.NoReferencedTableError:
if ignore_nonexistent_tables:
continue
else:
@@ -81,27 +95,26 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
if col:
crit.append(col == fk.parent)
constraints.add(fk.constraint)
-
if a is not b:
for fk in a.foreign_keys:
try:
col = fk.get_referent(b)
- except exceptions.NoReferencedTableError:
+ except exc.NoReferencedTableError:
if ignore_nonexistent_tables:
continue
else:
raise
-
+
if col:
crit.append(col == fk.parent)
constraints.add(fk.constraint)
if len(crit) == 0:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Can't find any foreign key relationships "
"between '%s' and '%s'" % (a.description, b.description))
elif len(constraints) > 1:
- raise exceptions.ArgumentError(
+ raise exc.ArgumentError(
"Can't determine join between '%s' and '%s'; "
"tables have more than one foreign key "
"constraint relationship between them. "
@@ -111,7 +124,70 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
return (crit[0])
else:
return sql.and_(*crit)
+
+class Annotated(object):
+ """clones a ClauseElement and applies an 'annotations' dictionary.
+
+ Unlike regular clones, this clone also mimics __hash__() and
+ __cmp__() of the original element so that it takes its place
+ in hashed collections.
+ A reference to the original element is maintained, for the important
+ reason of keeping its hash value current. When GC'ed, the
+ hash value may be reused, causing conflicts.
+
+ """
+ def __new__(cls, *args):
+ if not args:
+ return object.__new__(cls)
+ else:
+ element, values = args
+ return object.__new__(
+ type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {})
+ )
+
+ def __init__(self, element, values):
+ self.__dict__ = element.__dict__.copy()
+ self.__element = element
+ self._annotations = values
+
+ def _annotate(self, values):
+ _values = self._annotations.copy()
+ _values.update(values)
+ clone = self.__class__.__new__(self.__class__)
+ clone.__dict__ = self.__dict__.copy()
+ clone._annotations = _values
+ return clone
+
+ def __hash__(self):
+ return hash(self.__element)
+
+ def __cmp__(self, other):
+ return cmp(hash(self.__element), hash(other))
+
+def splice_joins(left, right, stop_on=None):
+ if left is None:
+ return right
+
+ stack = [(right, None)]
+
+ adapter = ClauseAdapter(left)
+ ret = None
+ while stack:
+ (right, prevright) = stack.pop()
+ if isinstance(right, expression.Join) and right is not stop_on:
+ right = right._clone()
+ right._reset_exported()
+ right.onclause = adapter.traverse(right.onclause)
+ stack.append((right.left, right))
+ else:
+ right = adapter.traverse(right)
+ if prevright:
+ prevright.left = right
+ if not ret:
+ ret = right
+
+ return ret
def reduce_columns(columns, *clauses):
"""given a list of columns, return a 'reduced' set based on natural equivalents.
@@ -151,7 +227,7 @@ def reduce_columns(columns, *clauses):
omit.add(c)
break
for clause in clauses:
- visitors.traverse(clause, visit_binary=visit_binary)
+ visitors.traverse(clause, {}, {'binary':visit_binary})
return expression.ColumnSet(columns.difference(omit))
@@ -159,7 +235,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
- raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
+ raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
def visit_binary(binary):
if not any_operator and binary.operator != operators.eq:
@@ -184,7 +260,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
pairs = []
- visitors.traverse(expression, visit_binary=visit_binary)
+ visitors.traverse(expression, {}, {'binary':visit_binary})
return pairs
def folded_equivalents(join, equivs=None):
@@ -195,15 +271,15 @@ def folded_equivalents(join, equivs=None):
This function is used by Join.select(fold_equivalents=True).
TODO: deprecate ?
- """
+ """
if equivs is None:
equivs = util.Set()
def visit_binary(binary):
if binary.operator == operators.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
- visitors.traverse(join.onclause, visit_binary=visit_binary)
+ visitors.traverse(join.onclause, {}, {'binary':visit_binary})
collist = []
if isinstance(join.left, expression.Join):
left = folded_equivalents(join.left, equivs)
@@ -246,43 +322,8 @@ class AliasedRow(object):
def keys(self):
return self.row.keys()
-def row_adapter(from_, equivalent_columns=None):
- """create a row adapter callable against a selectable."""
-
- if equivalent_columns is None:
- equivalent_columns = {}
-
- def locate_col(col):
- c = from_.corresponding_column(col)
- if c:
- return c
- elif col in equivalent_columns:
- for c2 in equivalent_columns[col]:
- corr = from_.corresponding_column(c2)
- if corr:
- return corr
- return col
-
- map = util.PopulateDict(locate_col)
-
- def adapt(row):
- return AliasedRow(row, map)
- return adapt
-
-class ColumnsInClause(visitors.ClauseVisitor):
- """Given a selectable, visit clauses and determine if any columns
- from the clause are in the selectable.
- """
-
- def __init__(self, selectable):
- self.selectable = selectable
- self.result = False
-
- def visit_column(self, column):
- if self.selectable.c.get(column.key) is column:
- self.result = True
-class ClauseAdapter(visitors.ClauseVisitor):
+class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""Given a clause (like as in a WHERE criterion), locate columns
which are embedded within a given selectable, and changes those
columns to be that of the selectable.
@@ -308,58 +349,76 @@ class ClauseAdapter(visitors.ClauseVisitor):
condition to read::
s.c.col1 == table2.c.col1
- """
-
- __traverse_options__ = {'column_collections':False}
- def __init__(self, selectable, include=None, exclude=None, equivalents=None):
- self.__traverse_options__ = self.__traverse_options__.copy()
- self.__traverse_options__['stop_on'] = [selectable]
+ """
+ def __init__(self, selectable, equivalents=None, include=None, exclude=None):
+ self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]}
self.selectable = selectable
self.include = include
self.exclude = exclude
- self.equivalents = equivalents
-
- def traverse(self, obj, clone=True):
- if not clone:
- raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
- return visitors.ClauseVisitor.traverse(self, obj, clone=True)
-
- def copy_and_chain(self, adapter):
- """create a copy of this adapter and chain to the given adapter.
-
- currently this adapter must be unchained to start, raises
- an exception if it's already chained.
-
- Does not modify the given adapter.
- """
+ self.equivalents = equivalents or {}
- if adapter is None:
- return self
+ def _corresponding_column(self, col, require_embedded):
+ newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
- if hasattr(self, '_next'):
- raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)")
-
- ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents)
- ca._next = adapter
- return ca
+ if not newcol and col in self.equivalents:
+ for equiv in self.equivalents[col]:
+ newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded)
+ if newcol:
+ return newcol
+ return newcol
- def before_clone(self, col):
+ def replace(self, col):
if isinstance(col, expression.FromClause):
if self.selectable.is_derived_from(col):
return self.selectable
+
if not isinstance(col, expression.ColumnElement):
return None
- if self.include is not None:
- if col not in self.include:
- return None
- if self.exclude is not None:
- if col in self.exclude:
- return None
- newcol = self.selectable.corresponding_column(col, require_embedded=True)
- if newcol is None and self.equivalents is not None and col in self.equivalents:
- for equiv in self.equivalents[col]:
- newcol = self.selectable.corresponding_column(equiv, require_embedded=True)
- if newcol:
- return newcol
- return newcol
+
+ if self.include and col not in self.include:
+ return None
+ elif self.exclude and col in self.exclude:
+ return None
+
+ return self._corresponding_column(col, True)
+
+class ColumnAdapter(ClauseAdapter):
+
+ def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None):
+ ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
+ if chain_to:
+ self.chain(chain_to)
+ self.columns = util.PopulateDict(self._locate_col)
+
+ def wrap(self, adapter):
+ ac = self.__class__.__new__(self.__class__)
+ ac.__dict__ = self.__dict__.copy()
+ ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
+ ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
+ ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
+ ac.columns = util.PopulateDict(ac._locate_col)
+ return ac
+
+ adapt_clause = ClauseAdapter.traverse
+ adapt_list = ClauseAdapter.copy_and_process
+
+ def _wrap(self, local, wrapped):
+ def locate(col):
+ col = local(col)
+ return wrapped(col)
+ return locate
+
+ def _locate_col(self, col):
+ c = self._corresponding_column(col, False)
+ if not c:
+ c = self.adapt_clause(col)
+
+ # anonymize labels in case they have a hardcoded name
+ if isinstance(c, expression._Label):
+ c = c.label(None)
+ return c
+
+ def adapted_row(self, row):
+ return AliasedRow(row, self.columns)
+
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 9888a228a..738dae9c7 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -1,138 +1,29 @@
from sqlalchemy import util
class ClauseVisitor(object):
- """Traverses and visits ``ClauseElement`` structures.
-
- Calls visit_XXX() methods for each particular
- ``ClauseElement`` subclass encountered. Traversal of a
- hierarchy of ``ClauseElements`` is achieved via the
- ``traverse()`` method, which is passed the lead
- ``ClauseElement``.
-
- By default, ``ClauseVisitor`` traverses all elements
- fully. Options can be specified at the class level via the
- ``__traverse_options__`` dictionary which will be passed
- to the ``get_children()`` method of each ``ClauseElement``;
- these options can indicate modifications to the set of
- elements returned, such as to not return column collections
- (column_collections=False) or to return Schema-level items
- (schema_visitor=True).
-
- ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
- operation, which will produce a copy of a given ``ClauseElement``
- structure while at the same time allowing ``ClauseVisitor`` subclasses
- to modify the new structure in-place.
-
- """
__traverse_options__ = {}
- def traverse_single(self, obj, **kwargs):
- """visit a single element, without traversing its child elements."""
-
+ def traverse_single(self, obj):
for v in self._iterate_visitors:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
- return meth(obj, **kwargs)
+ return meth(obj)
- traverse_chained = traverse_single
-
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator of all elements."""
-
- stack = [obj]
- traversal = util.deque()
- while stack:
- t = stack.pop()
- traversal.appendleft(t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- return iter(traversal)
-
- def traverse(self, obj, clone=False):
- """traverse and visit the given expression structure.
-
- Returns the structure given, or a copy of the structure if
- clone=True.
-
- When the copy operation takes place, the before_clone() method
- will receive each element before it is copied. If the method
- returns a non-None value, the return value is taken as the
- "copied" element and traversal will not descend further.
-
- The visit_XXX() methods receive the element *after* it's been
- copied. To compare an element to another regardless of
- one element being a cloned copy of the original, the
- '_cloned_set' attribute of ClauseElement can be used for the compare,
- i.e.::
-
- original in copied._cloned_set
-
-
- """
- if clone:
- return self._cloned_traversal(obj)
- else:
- return self._non_cloned_traversal(obj)
-
- def copy_and_process(self, list_):
- """Apply cloned traversal to the given list of elements, and return the new list."""
-
- return [self._cloned_traversal(x) for x in list_]
- def before_clone(self, elem):
- """receive pre-copied elements during a cloning traversal.
-
- If the method returns a new element, the element is used
- instead of creating a simple copy of the element. Traversal
- will halt on the newly returned element if it is re-encountered.
- """
- return None
-
- def _clone_element(self, elem, stop_on, cloned):
- for v in self._iterate_visitors:
- newelem = v.before_clone(elem)
- if newelem:
- stop_on.add(newelem)
- return newelem
-
- if elem not in cloned:
- # the full traversal will only make a clone of a particular element
- # once.
- cloned[elem] = elem._clone()
- return cloned[elem]
-
- def _cloned_traversal(self, obj):
- """a recursive traversal which creates copies of elements, returning the new structure."""
-
- stop_on = self.__traverse_options__.get('stop_on', [])
- return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True)
-
- def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False):
- if elem in stop_on:
- return elem
-
- if _clone_toplevel:
- elem = self._clone_element(elem, stop_on, cloned)
- if elem in stop_on:
- return elem
-
- def clone(element):
- return self._clone_element(element, stop_on, cloned)
- elem._copy_internals(clone=clone)
+ return iterate(obj, self.__traverse_options__)
- self.traverse_single(elem)
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
- for e in elem.get_children(**self.__traverse_options__):
- if e not in stop_on:
- self._cloned_traversal_impl(e, stop_on, cloned)
- return elem
+ visitors = {}
- def _non_cloned_traversal(self, obj):
- """a non-recursive, non-cloning traversal."""
-
- for target in self.iterate(obj):
- self.traverse_single(target)
- return obj
+ for name in dir(self):
+ if name.startswith('visit_'):
+ visitors[name[6:]] = getattr(self, name)
+
+ return traverse(obj, self.__traverse_options__, visitors)
def _iterate_visitors(self):
"""iterate through this visitor and each 'chained' visitor."""
@@ -152,31 +43,136 @@ class ClauseVisitor(object):
tail._next = visitor
return self
-class NoColumnVisitor(ClauseVisitor):
- """ClauseVisitor with 'column_collections' set to False; will not
- traverse the front-facing Column collections on Table, Alias, Select,
- and CompoundSelect objects.
+class CloningVisitor(ClauseVisitor):
+ def copy_and_process(self, list_):
+ """Apply cloned traversal to the given list of elements, and return the new list."""
+
+ return [self.traverse(x) for x in list_]
+
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
+
+ visitors = {}
+
+ for name in dir(self):
+ if name.startswith('visit_'):
+ visitors[name[6:]] = getattr(self, name)
+
+ return cloned_traverse(obj, self.__traverse_options__, visitors)
+
+class ReplacingCloningVisitor(CloningVisitor):
+ def replace(self, elem):
+ """receive pre-copied elements during a cloning traversal.
+
+ If the method returns a new element, the element is used
+ instead of creating a simple copy of the element. Traversal
+ will halt on the newly returned element if it is re-encountered.
+ """
+ return None
+
+ def traverse(self, obj):
+ """traverse and visit the given expression structure."""
+
+ def replace(elem):
+ for v in self._iterate_visitors:
+ e = v.replace(elem)
+ if e:
+ return e
+ return replacement_traverse(obj, self.__traverse_options__, replace)
+
+def iterate(obj, opts):
+ """traverse the given expression structure, returning an iterator.
+
+ traversal is configured to be breadth-first.
"""
+ stack = util.deque([obj])
+ while stack:
+ t = stack.popleft()
+ yield t
+ for c in t.get_children(**opts):
+ stack.append(c)
+
+def iterate_depthfirst(obj, opts):
+ """traverse the given expression structure, returning an iterator.
- __traverse_options__ = {'column_collections':False}
-
-class NullVisitor(ClauseVisitor):
- def traverse(self, obj, clone=False):
- next = getattr(self, '_next', None)
- if next:
- return next.traverse(obj, clone=clone)
- else:
- return obj
-
-def traverse(clause, **kwargs):
- """traverse the given clause, applying visit functions passed in as keyword arguments."""
+ traversal is configured to be depth-first.
+
+ """
+ stack = util.deque([obj])
+ traversal = util.deque()
+ while stack:
+ t = stack.pop()
+ traversal.appendleft(t)
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return iter(traversal)
+
+def traverse_using(iterator, obj, visitors):
+ """visit the given expression structure using the given iterator of objects."""
+
+ for target in iterator:
+ meth = visitors.get(target.__visit_name__, None)
+ if meth:
+ meth(target)
+ return obj
- clone = kwargs.pop('clone', False)
- class Vis(ClauseVisitor):
- __traverse_options__ = kwargs.pop('traverse_options', {})
- vis = Vis()
- for key in kwargs:
- setattr(vis, key, kwargs[key])
- return vis.traverse(clause, clone=clone)
+def traverse(obj, opts, visitors):
+ """traverse and visit the given expression structure using the default iterator."""
+
+ return traverse_using(iterate(obj, opts), obj, visitors)
+
+def traverse_depthfirst(obj, opts, visitors):
+ """traverse and visit the given expression structure using the depth-first iterator."""
+
+ return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
+
+def cloned_traverse(obj, opts, visitors):
+ cloned = {}
+
+ def clone(element):
+ if element not in cloned:
+ cloned[element] = element._clone()
+ return cloned[element]
+
+ obj = clone(obj)
+ stack = [obj]
+
+ while stack:
+ t = stack.pop()
+ if t in cloned:
+ continue
+ t._copy_internals(clone=clone)
+
+ meth = visitors.get(t.__visit_name__, None)
+ if meth:
+ meth(t)
+
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return obj
+
+def replacement_traverse(obj, opts, replace):
+ cloned = {}
+ stop_on = util.Set(opts.get('stop_on', []))
+
+ def clone(element):
+ newelem = replace(element)
+ if newelem:
+ stop_on.add(newelem)
+ return newelem
+
+ if element not in cloned:
+ cloned[element] = element._clone()
+ return cloned[element]
+ obj = clone(obj)
+ stack = [obj]
+ while stack:
+ t = stack.pop()
+ if t in stop_on:
+ continue
+ t._copy_internals(clone=clone)
+ for c in t.get_children(**opts):
+ stack.append(c)
+ return obj