diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-09 16:34:10 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-09 16:34:10 +0000 |
| commit | 4a6afd469fad170868554bf28578849bf3dfd5dd (patch) | |
| tree | b396edc33d567ae19dd244e87137296450467725 /lib/sqlalchemy/sql | |
| parent | 46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff) | |
| download | sqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz | |
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 101 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 490 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 283 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 282 |
6 files changed, 579 insertions, 581 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index c966f396a..5ea9eb1e6 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -1,2 +1,2 @@ from sqlalchemy.sql.expression import * -from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor +from sqlalchemy.sql.visitors import ClauseVisitor diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1fe9ef062..78bb4e31c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -19,7 +19,7 @@ is otherwise internal to SQLAlchemy. """ import string, re, itertools -from sqlalchemy import schema, engine, util, exceptions +from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions from sqlalchemy.sql import expression as sql @@ -115,8 +115,6 @@ class DefaultCompiler(engine.Compiled): paradigm as visitors.ClauseVisitor but implements its own traversal. """ - __traverse_options__ = {'column_collections':False, 'entry':True} - operators = OPERATORS functions = FUNCTIONS @@ -162,17 +160,12 @@ class DefaultCompiler(engine.Compiled): # for aliases self.generated_ids = {} - # paramstyle from the dialect (comes from DB-API) - self.paramstyle = self.dialect.paramstyle - # true if the paramstyle is positional self.positional = self.dialect.positional + if self.positional: + self.positiontup = [] - self.bindtemplate = BIND_TEMPLATES[self.paramstyle] - - # a list of the compiled's bind parameter names, used to help - # formulate a positional argument list - self.positiontup = [] + self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle] # an IdentifierPreparer that formats the quoting of identifiers self.preparer = self.dialect.identifier_preparer @@ -230,15 +223,18 @@ class DefaultCompiler(engine.Compiled): return "" def visit_grouping(self, grouping, **kwargs): - return "(" + self.process(grouping.elem) + ")" + return "(" + self.process(grouping.element) + ")" - def visit_label(self, label, result_map=None): + def visit_label(self, label, result_map=None, render_labels=False): + if not render_labels: + return self.process(label.element) + labelname = self._truncated_identifier("colident", label.name) if result_map is not None: - result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type) + result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type) - return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) + return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) def visit_column(self, column, result_map=None, **kwargs): @@ -261,16 +257,16 @@ class DefaultCompiler(engine.Compiled): if getattr(column, "is_literal", False): name = self.escape_literal_column(name) else: - name = self.preparer.quote(column, name) + name = self.preparer.quote(name, column.quote) if column.table is None or not column.table.named_with_column: return name else: if getattr(column.table, 'schema', None): - schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.' + schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.' else: schema_prefix = '' - return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name + return schema_prefix + self.preparer.quote(ANONYMOUS_LABEL.sub(self._process_anon, column.table.name), column.table.quote) + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -387,7 +383,7 @@ class DefaultCompiler(engine.Compiled): if name in self.binds: existing = self.binds[name] if existing is not bindparam and (existing.unique or bindparam.unique): - raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -418,7 +414,7 @@ class DefaultCompiler(engine.Compiled): return truncname def _process_anon(self, match): - (ident, derived) = match.group(1,2) + (ident, derived) = match.group(1, 2) key = ('anonymous', ident) if key in self.generated_ids: @@ -436,8 +432,9 @@ class DefaultCompiler(engine.Compiled): def bindparam_string(self, name): if self.positional: self.positiontup.append(name) - - return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + else: + return self.bindtemplate % {'name':name} def visit_alias(self, alias, asfrom=False, **kwargs): if asfrom: @@ -490,7 +487,7 @@ class DefaultCompiler(engine.Compiled): froms = select._get_display_froms(existingfroms) - correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms]))) + correlate_froms = util.Set(sql._from_objects(*froms)) # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost @@ -504,6 +501,7 @@ class DefaultCompiler(engine.Compiled): [c for c in [ self.process( self.label_select_column(select, co, asfrom=asfrom), + render_labels=True, **column_clause_args) for co in select.inner_columns ] @@ -580,9 +578,9 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, **kwargs): if asfrom: if getattr(table, "schema", None): - return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name) + return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) else: - return self.preparer.quote(table, table.name) + return self.preparer.quote(table.name, table.quote) else: return "" @@ -603,7 +601,7 @@ class DefaultCompiler(engine.Compiled): return (insert + " INTO %s (%s) VALUES (%s)" % (preparer.format_table(insert_stmt.table), - ', '.join([preparer.quote(c[0], c[0].name) + ', '.join([preparer.quote(c[0].name, c[0].quote) for c in colparams]), ', '.join([c[1] for c in colparams]))) @@ -613,7 +611,7 @@ class DefaultCompiler(engine.Compiled): self.isupdate = True colparams = self._get_colparams(update_stmt) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ') + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0].name, c[0].quote), c[1]) for c in colparams], ', ') if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) @@ -837,7 +835,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint])) + self.append("(%s)" % ', '.join([self.preparer.quote(c.name, c.quote) for c in constraint])) self.define_constraint_deferrability(constraint) def visit_foreign_key_constraint(self, constraint): @@ -858,9 +856,9 @@ class SchemaGenerator(DDLBase): preparer.format_constraint(constraint)) table = list(constraint.elements)[0].column.table self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]), + ', '.join([preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements]), preparer.format_table(table), - ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements]) + ', '.join([preparer.quote(f.column.name, f.column.quote) for f in constraint.elements]) )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -873,7 +871,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint]))) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c.name, c.quote) for c in constraint]))) self.define_constraint_deferrability(constraint) def define_constraint_deferrability(self, constraint): @@ -896,7 +894,7 @@ class SchemaGenerator(DDLBase): self.append("INDEX %s ON %s (%s)" \ % (preparer.format_index(index), preparer.format_table(index.table), - string.join([preparer.quote(c, c.name) for c in index.columns], ', '))) + string.join([preparer.quote(c.name, c.quote) for c in index.columns], ', '))) self.execute() @@ -1005,9 +1003,12 @@ class IdentifierPreparer(object): or not self.legal_characters.match(unicode(value)) or (lc_value != value)) - def quote(self, obj, ident): - if getattr(obj, 'quote', False): + def quote(self, ident, force): + if force: return self.quote_identifier(ident) + elif force is False: + return ident + if ident in self.__strings: return self.__strings[ident] else: @@ -1017,53 +1018,47 @@ class IdentifierPreparer(object): self.__strings[ident] = ident return self.__strings[ident] - def should_quote(self, object): - return object.quote or self._requires_quotes(object.name) - def format_sequence(self, sequence, use_schema=True): - name = self.quote(sequence, sequence.name) + name = self.quote(sequence.name, sequence.quote) if not self.omit_schema and use_schema and sequence.schema is not None: - name = self.quote(sequence, sequence.schema) + "." + name + name = self.quote(sequence.schema, sequence.quote) + "." + name return name def format_label(self, label, name=None): - return self.quote(label, name or label.name) + return self.quote(name or label.name, label.quote) def format_alias(self, alias, name=None): - return self.quote(alias, name or alias.name) + return self.quote(name or alias.name, alias.quote) def format_savepoint(self, savepoint, name=None): - return self.quote(savepoint, name or savepoint.ident) + return self.quote(name or savepoint.ident, savepoint.quote) def format_constraint(self, constraint): - return self.quote(constraint, constraint.name) + return self.quote(constraint.name, constraint.quote) def format_index(self, index): - return self.quote(index, index.name) + return self.quote(index.name, index.quote) def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name - result = self.quote(table, name) + result = self.quote(name, table.quote) if not self.omit_schema and use_schema and getattr(table, "schema", None): - result = self.quote(table, table.schema) + "." + result + result = self.quote(table.schema, table.quote_schema) + "." + result return result def format_column(self, column, use_table=False, name=None, table_name=None): - """Prepare a quoted column name. - - deprecated. use preparer.quote(col, column.name) or combine with format_table() - """ + """Prepare a quoted column name.""" if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name) + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote) else: - return self.quote(column, name) + return self.quote(name, column.quote) else: # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted if use_table: @@ -1079,7 +1074,7 @@ class IdentifierPreparer(object): # a longer sequence. if not self.omit_schema and use_schema and getattr(table, 'schema', None): - return (self.quote_identifier(table.schema), + return (self.quote(table.schema, table.quote_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 867fdd69c..7ce637701 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,12 +26,12 @@ to stay the same in future releases. """ import itertools, re -from sqlalchemy import util, exceptions +from sqlalchemy import util, exc from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes functions, schema, sql_util = None, None, None -DefaultDialect, ClauseAdapter = None, None +DefaultDialect, ClauseAdapter, Annotated = None, None, None __all__ = [ 'Alias', 'ClauseElement', @@ -503,15 +503,21 @@ def collate(expression, collation): def exists(*args, **kwargs): """Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object. + + Calling styles are of the following forms:: + + # use on an existing select() + s = select([<columns>]).where(<criterion>) + s = exists(s) + + # construct a select() at once + exists(['*'], **select_arguments).where(<criterion>) + + # columns argument is optional, generates "EXISTS (SELECT *)" + # by default. + exists().where(<criterion>) - The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by - itself or used as a subquery within an enclosing select. - - \*args, \**kwargs - all arguments are sent directly to the [sqlalchemy.sql.expression#select()] - function to produce a ``SELECT`` statement. """ - return _Exists(*args, **kwargs) def union(*selects, **kwargs): @@ -872,27 +878,36 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, ClauseElement) + return not isinstance(element, (ClauseElement, Operators)) + +def _from_objects(*elements, **kwargs): + return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements]) +def _labeled(element): + if not hasattr(element, 'name'): + return element.label(None) + else: + return element + def _literal_as_text(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return _TextClause(unicode(element)) else: return element def _literal_as_column(element): - if isinstance(element, Operators): - return element.clause_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return literal_column(str(element)) else: return element def _literal_as_binds(element, name=None, type_=None): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): if element is None: return null() @@ -902,17 +917,17 @@ def _literal_as_binds(element, name=None, type_=None): return element def _no_literals(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): - raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) else: return element def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) return c def _selectable(element): @@ -921,9 +936,8 @@ def _selectable(element): elif isinstance(element, Selectable): return element else: - raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) + raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) - def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) @@ -941,7 +955,9 @@ class _FigureVisitName(type): class ClauseElement(object): """Base class for elements of a programmatically constructed SQL expression.""" __metaclass__ = _FigureVisitName - + _annotations = {} + supports_execution = False + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -976,6 +992,14 @@ class ClauseElement(object): """ raise NotImplementedError(repr(self)) + + def _annotate(self, values): + """return a copy of this ClauseElement with the given annotations dictionary.""" + + global Annotated + if Annotated is None: + from sqlalchemy.sql.util import Annotated + return Annotated(self, values) def unique_params(self, *optionaldict, **kwargs): """Return a copy with ``bindparam()`` elments replaced. @@ -1006,14 +1030,14 @@ class ClauseElement(object): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: - raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument") + raise exc.ArgumentError("params() takes zero or one positional dictionary argument") def visit_bindparam(bind): if bind.key in kwargs: bind.value = kwargs[bind.key] if unique: bind._convert_to_unique() - return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True) + return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam}) def compare(self, other): """Compare this ClauseElement to the given ClauseElement. @@ -1049,11 +1073,6 @@ class ClauseElement(object): def self_group(self, against=None): return self - def supports_execution(self): - """Return True if this clause element represents a complete executable statement.""" - - return False - def bind(self): """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" @@ -1062,7 +1081,7 @@ class ClauseElement(object): return self._bind except AttributeError: pass - for f in self._get_from_objects(): + for f in _from_objects(self): if f is self: continue engine = f.bind @@ -1083,7 +1102,7 @@ class ClauseElement(object): 'Engine for execution. Or, assign a bind to the statement ' 'or the Metadata of its underlying tables to enable ' 'implicit execution via this method.' % label) - raise exceptions.UnboundExecutionError(msg) + raise exc.UnboundExecutionError(msg) return e.execute_clauseelement(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -1159,6 +1178,12 @@ class ClauseElement(object): self.__module__, self.__class__.__name__, id(self), friendly) +class _Immutable(object): + """mark a ClauseElement as 'immutable' when expressions are cloned.""" + + def _clone(self): + return self + class Operators(object): def __and__(self, other): return self.operate(operators.and_, other) @@ -1174,9 +1199,6 @@ class Operators(object): return self.operate(operators.op, opstring, b) return op - def clause_element(self): - raise NotImplementedError() - def operate(self, op, *other, **kwargs): raise NotImplementedError() @@ -1216,7 +1238,7 @@ class ColumnOperators(Operators): def ilike(self, other, escape=None): return self.operate(operators.ilike_op, other, escape=escape) - def in_(self, *other): + def in_(self, other): return self.operate(operators.in_op, other) def startswith(self, other, **kwargs): @@ -1279,18 +1301,18 @@ class _CompareMixin(ColumnOperators): def __compare(self, op, obj, negate=None, reverse=False, **kwargs): if obj is None or isinstance(obj, _Null): if op == operators.eq: - return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) + return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot) elif op == operators.ne: - return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_) + return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_) else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) else: - return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): obj = self._check_literal(obj) @@ -1298,9 +1320,9 @@ class _CompareMixin(ColumnOperators): type_ = self._compare_type(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_) + return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_) else: - return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) + return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_) # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1329,17 +1351,10 @@ class _CompareMixin(ColumnOperators): o = _CompareMixin.operators[op] return o[0](self, op, other, reverse=True, *o[1:], **kwargs) - def in_(self, *other): - return self._in_impl(operators.in_op, operators.notin_op, *other) - - def _in_impl(self, op, negate_op, *other): - # Handle old style *args argument passing - if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)): - util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable') - seq_or_selectable = other - else: - seq_or_selectable = other[0] + def in_(self, other): + return self._in_impl(operators.in_op, operators.notin_op, other) + def _in_impl(self, op, negate_op, seq_or_selectable): if isinstance(seq_or_selectable, Selectable): return self.__compare( op, seq_or_selectable, negate=negate_op) @@ -1348,7 +1363,7 @@ class _CompareMixin(ColumnOperators): for o in seq_or_selectable: if not _is_literal(o): if not isinstance( o, _CompareMixin): - raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) + raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) else: o = self._bind_param(o) args.append(o) @@ -1433,22 +1448,13 @@ class _CompareMixin(ColumnOperators): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): other.type = self.type return other - elif isinstance(other, Operators): - return other.expression_element() + elif hasattr(other, '__clause_element__'): + return other.__clause_element__() elif _is_literal(other): return self._bind_param(other) else: return other - def clause_element(self): - """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" - return self - - def expression_element(self): - """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" - - return self - def _compare_type(self, obj): """Allow subclasses to override the type used in constructing ``_BinaryExpression`` objects. @@ -1480,23 +1486,22 @@ class ColumnElement(ClauseElement, _CompareMixin): primary_key = False foreign_keys = [] - + quote = None + def base_columns(self): - if hasattr(self, '_base_columns'): - return self._base_columns - self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) + if not hasattr(self, '_base_columns'): + self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) return self._base_columns base_columns = property(base_columns) def proxy_set(self): - if hasattr(self, '_proxy_set'): - return self._proxy_set - s = util.Set([self]) - if hasattr(self, 'proxies'): - for c in self.proxies: - s = s.union(c.proxy_set) - self._proxy_set = s - return s + if not hasattr(self, '_proxy_set'): + s = util.Set([self]) + if hasattr(self, 'proxies'): + for c in self.proxies: + s.update(c.proxy_set) + self._proxy_set = s + return self._proxy_set proxy_set = property(proxy_set) def shares_lineage(self, othercolumn): @@ -1518,7 +1523,7 @@ class ColumnElement(ClauseElement, _CompareMixin): co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) co.proxies = [self] - selectable.columns[name]= co + selectable.columns[name] = co return co def anon_label(self): @@ -1613,7 +1618,7 @@ class ColumnCollection(util.OrderedProperties): def __contains__(self, other): if not isinstance(other, basestring): - raise exceptions.ArgumentError("__contains__ requires a string argument") + raise exc.ArgumentError("__contains__ requires a string argument") return util.OrderedProperties.__contains__(self, other) def contains_column(self, col): @@ -1641,6 +1646,9 @@ class ColumnSet(util.OrderedSet): l.append(c==local) return and_(*l) + def __hash__(self): + return hash(tuple(self._list)) + class Selectable(ClauseElement): """mark a class as being selectable""" @@ -1648,8 +1656,9 @@ class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" __visit_name__ = 'fromclause' - named_with_column=False + named_with_column = False _hide_froms = [] + quote = None def _get_from_objects(self, **modifiers): return [] @@ -1694,12 +1703,12 @@ class FromClause(Selectable): return fromclause in util.Set(self._cloned_set) def replace_selectable(self, old, alias): - """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" + """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - global ClauseAdapter - if ClauseAdapter is None: - from sqlalchemy.sql.util import ClauseAdapter - return ClauseAdapter(alias).traverse(self, clone=True) + global ClauseAdapter + if ClauseAdapter is None: + from sqlalchemy.sql.util import ClauseAdapter + return ClauseAdapter(alias).traverse(self) def correspond_on_equivalents(self, column, equivalents): col = self.corresponding_column(column, require_embedded=True) @@ -1859,7 +1868,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): def _convert_to_unique(self): if not self.unique: - self.unique=True + self.unique = True self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param') def _get_from_objects(self, **modifiers): @@ -1910,6 +1919,7 @@ class _TextClause(ClauseElement): __visit_name__ = 'textclause' _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) + supports_execution = True _hide_froms = [] oid_column = None @@ -1950,12 +1960,6 @@ class _TextClause(ClauseElement): def _get_from_objects(self, **modifiers): return [] - def supports_execution(self): - return True - - def _table_iterator(self): - return iter([]) - class _Null(ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -2042,6 +2046,7 @@ class _CalculatedClause(ColumnElement): __visit_name__ = 'calculatedclause' def __init__(self, name, *clauses, **kwargs): + ColumnElement.__init__(self) self.name = name self.type = sqltypes.to_instance(kwargs.get('type_', None)) self._bind = kwargs.get('bind', None) @@ -2061,7 +2066,7 @@ class _CalculatedClause(ColumnElement): def clauses(self): if isinstance(self.clause_expr, _Grouping): - return self.clause_expr.elem + return self.clause_expr.element else: return self.clause_expr clauses = property(clauses) @@ -2239,8 +2244,13 @@ class _Exists(_UnaryExpression): __visit_name__ = _UnaryExpression.__visit_name__ def __init__(self, *args, **kwargs): - kwargs['correlate'] = True - s = select(*args, **kwargs).as_scalar().self_group() + if args and isinstance(args[0], _SelectBaseMixin): + s = args[0] + else: + if not args: + args = ([literal_column('*')],) + s = select(*args, **kwargs).as_scalar().self_group() + _UnaryExpression.__init__(self, s, operator=operators.exists) def select(self, whereclause=None, **params): @@ -2272,7 +2282,7 @@ class Join(FromClause): self.right = _selectable(right).self_group() if onclause is None: - self.onclause = self.__match_primaries(self.left, self.right) + self.onclause = self._match_primaries(self.left, self.right) else: self.onclause = onclause @@ -2310,7 +2320,7 @@ class Join(FromClause): def get_children(self, **kwargs): return self.left, self.right, self.onclause - def __match_primaries(self, primary, secondary): + def _match_primaries(self, primary, secondary): global sql_util if not sql_util: from sqlalchemy.sql import util as sql_util @@ -2359,7 +2369,7 @@ class Join(FromClause): return self.select(use_labels=True, correlate=False).alias(name) def _hide_froms(self): - return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set]) + return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set]) _hide_froms = property(_hide_froms) def _get_from_objects(self, **modifiers): @@ -2382,9 +2392,10 @@ class Alias(FromClause): def __init__(self, selectable, alias=None): baseselectable = selectable while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable + baseselectable = baseselectable.element self.original = baseselectable - self.selectable = selectable + self.supports_execution = baseselectable.supports_execution + self.element = selectable if alias is None: if self.original.named_with_column: alias = getattr(self.original, 'name', None) @@ -2398,112 +2409,100 @@ class Alias(FromClause): def is_derived_from(self, fromclause): if fromclause in util.Set(self._cloned_set): return True - return self.selectable.is_derived_from(fromclause) - - def supports_execution(self): - return self.original.supports_execution() - - def _table_iterator(self): - return self.original._table_iterator() + return self.element.is_derived_from(fromclause) def _populate_column_collection(self): - for col in self.selectable.columns: + for col in self.element.columns: col._make_proxy(self) - if self.selectable.oid_column is not None: - self._oid_column = self.selectable.oid_column._make_proxy(self) + if self.element.oid_column is not None: + self._oid_column = self.element.oid_column._make_proxy(self) def _copy_internals(self, clone=_clone): - self._reset_exported() - self.selectable = _clone(self.selectable) - baseselectable = self.selectable - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable - self.original = baseselectable + self._reset_exported() + self.element = _clone(self.element) + baseselectable = self.element + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable def get_children(self, column_collections=True, aliased_selectables=True, **kwargs): if column_collections: for c in self.c: yield c if aliased_selectables: - yield self.selectable + yield self.element def _get_from_objects(self, **modifiers): return [self] def bind(self): - return self.selectable.bind + return self.element.bind bind = property(bind) -class _ColumnElementAdapter(ColumnElement): - """Adapts a ClauseElement which may or may not be a - ColumnElement subclass itself into an object which - acts like a ColumnElement. - """ +class _Grouping(ColumnElement): + """Represent a grouping within a column expression""" - def __init__(self, elem): - self.elem = elem - self.type = getattr(elem, 'type', None) + def __init__(self, element): + ColumnElement.__init__(self) + self.element = element + self.type = getattr(element, 'type', None) def key(self): - return self.elem.key + return self.element.key key = property(key) def _label(self): try: - return self.elem._label + return self.element._label except AttributeError: return self.anon_label _label = property(_label) def _copy_internals(self, clone=_clone): - self.elem = clone(self.elem) + self.element = clone(self.element) def get_children(self, **kwargs): - return self.elem, + return self.element, def _get_from_objects(self, **modifiers): - return self.elem._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def __getattr__(self, attr): - return getattr(self.elem, attr) + return getattr(self.element, attr) def __getstate__(self): - return {'elem':self.elem, 'type':self.type} + return {'element':self.element, 'type':self.type} def __setstate__(self, state): - self.elem = state['elem'] + self.element = state['element'] self.type = state['type'] -class _Grouping(_ColumnElementAdapter): - """Represent a grouping within a column expression""" - pass - class _FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' - def __init__(self, elem): - self.elem = elem + def __init__(self, element): + self.element = element def columns(self): - return self.elem.columns + return self.element.columns columns = c = property(columns) def _hide_froms(self): - return self.elem._hide_froms + return self.element._hide_froms _hide_froms = property(_hide_froms) def get_children(self, **kwargs): - return self.elem, + return self.element, def _copy_internals(self, clone=_clone): - self.elem = clone(self.elem) + self.element = clone(self.element) def _get_from_objects(self, **modifiers): - return self.elem._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def __getattr__(self, attr): - return getattr(self.elem, attr) + return getattr(self.element, attr) class _Label(ColumnElement): """Represents a column label (AS). @@ -2516,12 +2515,12 @@ class _Label(ColumnElement): ``ColumnElement`` subclasses. """ - def __init__(self, name, obj, type_=None): - while isinstance(obj, _Label): - obj = obj.obj - self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - self.obj = obj.self_group(against=operators.as_) - self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) + def __init__(self, name, element, type_=None): + while isinstance(element, _Label): + element = element.element + self.name = name or "{ANON %d %s}" % (id(self), getattr(element, 'name', 'anon')) + self.element = element.self_group(against=operators.as_) + self.type = sqltypes.to_instance(type_ or getattr(element, 'type', None)) def key(self): return self.name @@ -2532,8 +2531,9 @@ class _Label(ColumnElement): _label = property(_label) def _proxy_attr(name): + get = util.attrgetter(name) def attr(self): - return getattr(self.obj, name) + return get(self.element) return property(attr) proxies = _proxy_attr('proxies') @@ -2542,27 +2542,24 @@ class _Label(ColumnElement): primary_key = _proxy_attr('primary_key') foreign_keys = _proxy_attr('foreign_keys') - def expression_element(self): - return self.obj - def get_children(self, **kwargs): - return self.obj, + return self.element, def _copy_internals(self, clone=_clone): - self.obj = clone(self.obj) + self.element = clone(self.element) def _get_from_objects(self, **modifiers): - return self.obj._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def _make_proxy(self, selectable, name = None): - if isinstance(self.obj, (Selectable, ColumnElement)): - e = self.obj._make_proxy(selectable, name=self.name) + if isinstance(self.element, (Selectable, ColumnElement)): + e = self.element._make_proxy(selectable, name=self.name) else: e = column(self.name)._make_proxy(selectable=selectable) e.proxies.append(self) return e -class _ColumnClause(ColumnElement): +class _ColumnClause(_Immutable, ColumnElement): """Represents a generic column expression from any textual string. This includes columns associated with tables, aliases and select @@ -2602,16 +2599,7 @@ class _ColumnClause(ColumnElement): return self.name.encode('ascii', 'backslashreplace') description = property(description) - def _clone(self): - # ColumnClause is immutable - return self - def _label(self): - """Generate a 'label' string for this column. - """ - - # for a "literal" column, we've no idea what the text is - # therefore no 'label' can be automatically generated if self.is_literal: return None if not self.__label: @@ -2626,24 +2614,21 @@ class _ColumnClause(ColumnElement): counter = 1 while label in self.table.c: label = self.__label + "_" + str(counter) - counter +=1 + counter += 1 self.__label = label else: self.__label = self.name return self.__label - _label = property(_label) def label(self, name): - # if going off the "__label" property and its None, we have - # no label; return self if name is None: return self else: return super(_ColumnClause, self).label(name) def _get_from_objects(self, **modifiers): - if self.table is not None: + if self.table: return [self.table] else: return [] @@ -2651,20 +2636,20 @@ class _ColumnClause(ColumnElement): def _bind_param(self, obj): return _BindParamClause(self.name, obj, type_=self.type, unique=True) - def _make_proxy(self, selectable, name = None): + def _make_proxy(self, selectable, name=None, attach=True): # propigate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) c.proxies = [self] - if not self._is_oid: + if attach and not self._is_oid: selectable.columns[c.name] = c return c def _compare_type(self, obj): return self.type -class TableClause(FromClause): +class TableClause(_Immutable, FromClause): """Represents a "table" construct. Note that this represents tables only as another syntactical @@ -2691,10 +2676,6 @@ class TableClause(FromClause): return self.name.encode('ascii', 'backslashreplace') description = property(description) - def _clone(self): - # TableClause is immutable - return self - def append_column(self, c): self._columns[c.name] = c c.table = self @@ -2724,10 +2705,11 @@ class TableClause(FromClause): def _get_from_objects(self, **modifiers): return [self] - class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" + supports_execution = True + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False): self.use_labels = use_labels self.for_update = for_update @@ -2773,11 +2755,6 @@ class _SelectBaseMixin(object): """ return self.as_scalar().label(name) - def supports_execution(self): - """part of the ClauseElement contract; returns ``True`` in all cases for this class.""" - - return True - def autocommit(self): """return a new selectable with the 'autocommit' flag set to True.""" @@ -2860,15 +2837,15 @@ class _SelectBaseMixin(object): class _ScalarSelect(_Grouping): __visit_name__ = 'grouping' - def __init__(self, elem): - self.elem = elem - cols = list(elem.inner_columns) + def __init__(self, element): + self.element = element + cols = list(element.inner_columns) if len(cols) != 1: - raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") + raise exc.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") self.type = cols[0].type def columns(self): - raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") + raise exc.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") columns = c = property(columns) def self_group(self, **kwargs): @@ -2893,7 +2870,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): if not numcols: numcols = len(s.c) elif len(s.c) != numcols: - raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % + raise exc.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % (1, len(self.selects[0].c), n+1, len(s.c)) ) if s._order_by_clause: @@ -2936,11 +2913,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause): return (column_collections and list(self.c) or []) + \ [self._order_by_clause, self._group_by_clause] + list(self.selects) - def _table_iterator(self): - for s in self.selects: - for t in s._table_iterator(): - yield t - def bind(self): if self._bind: return self._bind @@ -2976,6 +2948,7 @@ class Select(_SelectBaseMixin, FromClause): self._distinct = distinct self._correlate = util.Set() + self._froms = util.OrderedSet() if columns: self._raw_columns = [ @@ -2983,22 +2956,23 @@ class Select(_SelectBaseMixin, FromClause): for c in [_literal_as_column(c) for c in columns] ] + + self._froms.update(_from_objects(*self._raw_columns)) else: self._raw_columns = [] - - if from_obj: - self._froms = util.Set([ - _is_literal(f) and _TextClause(f) or f - for f in util.to_list(from_obj) - ]) - else: - self._froms = util.Set() - + if whereclause: self._whereclause = _literal_as_text(whereclause) + self._froms.update(_from_objects(self._whereclause, is_where=True)) else: self._whereclause = None + if from_obj: + self._froms.update([ + _is_literal(f) and _TextClause(f) or f + for f in util.to_list(from_obj) + ]) + if having: self._having = _literal_as_text(having) else: @@ -3020,36 +2994,28 @@ class Select(_SelectBaseMixin, FromClause): correlating. """ - froms = util.OrderedSet() - - for col in self._raw_columns: - froms.update(col._get_from_objects()) - - if self._whereclause is not None: - froms.update(self._whereclause._get_from_objects(is_where=True)) - - if self._froms: - froms.update(self._froms) + froms = self._froms toremove = itertools.chain(*[f._hide_froms for f in froms]) - froms.difference_update(toremove) + if toremove: + froms = froms.difference(toremove) if len(froms) > 1 or self._correlate: if self._correlate: - froms.difference_update(_cloned_intersection(froms, self._correlate)) + froms = froms.difference(_cloned_intersection(froms, self._correlate)) if self._should_correlate and existing_froms: - froms.difference_update(_cloned_intersection(froms, existing_froms)) + froms = froms.difference(_cloned_intersection(froms, existing_froms)) if not len(froms): - raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self) + raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self) return froms froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") def type(self): - raise exceptions.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") + raise exc.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") type = property(type) def locate_all_froms(self): @@ -3059,22 +3025,10 @@ class Select(_SelectBaseMixin, FromClause): is specifically for those FromClause elements that would actually be rendered. """ - if hasattr(self, '_all_froms'): - return self._all_froms - - froms = util.Set( - itertools.chain(* - [self._froms] + - [f._get_from_objects() for f in self._froms] + - [col._get_from_objects() for col in self._raw_columns] - ) - ) + if not hasattr(self, '_all_froms'): + self._all_froms = self._froms.union(_from_objects(*list(self._froms))) - if self._whereclause: - froms.update(self._whereclause._get_from_objects(is_where=True)) - - self._all_froms = froms - return froms + return self._all_froms def inner_columns(self): """an iteratorof all ColumnElement expressions which would @@ -3092,7 +3046,7 @@ class Select(_SelectBaseMixin, FromClause): def is_derived_from(self, fromclause): if self in util.Set(fromclause._cloned_set): return True - + for f in self.locate_all_froms(): if f.is_derived_from(fromclause): return True @@ -3112,7 +3066,7 @@ class Select(_SelectBaseMixin, FromClause): """return child elements as per the ClauseElement specification.""" return (column_collections and list(self.columns) or []) + \ - list(self.locate_all_froms()) + \ + list(self._froms) + \ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] def column(self, column): @@ -3125,6 +3079,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) s._raw_columns = s._raw_columns + [column] + s._froms = s._froms.union(_from_objects(column)) return s def where(self, whereclause): @@ -3185,7 +3140,7 @@ class Select(_SelectBaseMixin, FromClause): """ s = self._generate() - s._should_correlate=False + s._should_correlate = False if fromclauses == (None,): s._correlate = util.Set() else: @@ -3195,7 +3150,7 @@ class Select(_SelectBaseMixin, FromClause): def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" - self._should_correlate=False + self._should_correlate = False self._correlate = self._correlate.union([fromclause]) def append_column(self, column): @@ -3207,6 +3162,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) self._raw_columns = self._raw_columns + [column] + self._froms = self._froms.union(_from_objects(column)) self._reset_exported() def append_prefix(self, clause): @@ -3221,10 +3177,13 @@ class Select(_SelectBaseMixin, FromClause): The expression will be joined to existing WHERE criterion via AND. """ + whereclause = _literal_as_text(whereclause) + self._froms = self._froms.union(_from_objects(whereclause, is_where=True)) + if self._whereclause is not None: - self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + self._whereclause = and_(self._whereclause, whereclause) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = whereclause def append_having(self, having): """append the given expression to this select() construct's HAVING criterion. @@ -3311,31 +3270,23 @@ class Select(_SelectBaseMixin, FromClause): return intersect_all(self, other, **kwargs) - def _table_iterator(self): - for t in visitors.NoColumnVisitor().iterate(self): - if isinstance(t, TableClause): - yield t - def bind(self): if self._bind: return self._bind - for f in self._froms: - if f is self: - continue - e = f.bind - if e: - self._bind = e - return e - # look through the columns (largely synomous with looking - # through the FROMs except in the case of _CalculatedClause/_Function) - for c in self._raw_columns: - if getattr(c, 'table', None) is self: - continue - e = c.bind + if not self._froms: + for c in self._raw_columns: + e = c.bind + if e: + self._bind = e + return e + else: + e = list(self._froms)[0].bind if e: self._bind = e return e + return None + def _set_bind(self, bind): self._bind = bind bind = property(bind, _set_bind) @@ -3343,11 +3294,7 @@ class Select(_SelectBaseMixin, FromClause): class _UpdateBase(ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" - def supports_execution(self): - return True - - def _table_iterator(self): - return iter([self.table]) + supports_execution = True def _generate(self): s = self.__class__.__new__(self.__class__) @@ -3407,7 +3354,7 @@ class Insert(_ValuesBase): self._bind = bind self.table = table self.select = None - self.inline=inline + self.inline = inline if prefixes: self._prefixes = [_literal_as_text(p) for p in prefixes] else: @@ -3502,10 +3449,11 @@ class Delete(_UpdateBase): self._whereclause = clone(self._whereclause) class _IdentifiedClause(ClauseElement): + supports_execution = True + quote = None + def __init__(self, ident): self.ident = ident - def supports_execution(self): - return True class SavepointClause(_IdentifiedClause): pass diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index dfd638ecb..46dcaba66 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -44,7 +44,7 @@ def between_op(a, b, c): return a.between(b, c) def in_op(a, b): - return a.in_(*b) + return a.in_(b) def notin_op(a, b): raise NotImplementedError() diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d299982cf..944a68def 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,4 +1,4 @@ -from sqlalchemy import exceptions, schema, topological, util, sql +from sqlalchemy import exc, schema, topological, util, sql from sqlalchemy.sql import expression, operators, visitors from itertools import chain @@ -8,43 +8,57 @@ def sort_tables(tables, reverse=False): """sort a collection of Table objects in order of their foreign-key dependency.""" tuples = [] - class TVisitor(schema.SchemaVisitor): - def visit_foreign_key(_self, fkey): - if fkey.use_alter: - return - parent_table = fkey.column.table - if parent_table in tables: - child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) - vis = TVisitor() + def visit_foreign_key(fkey): + if fkey.use_alter: + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + tuples.append( ( parent_table, child_table ) ) + for table in tables: - vis.traverse(table) + visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) sequence = topological.sort(tuples, tables) if reverse: return util.reversed(sequence) else: return sequence -def find_tables(clause, check_columns=False, include_aliases=False): +def search(clause, target): + if not clause: + return False + for elem in visitors.iterate(clause, {'column_collections':False}): + if elem is target: + return True + else: + return False + +def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False): """locate Table objects within the given expression.""" tables = [] - kwargs = {} + _visitors = {} + + def visit_something(elem): + tables.append(elem) + + if include_selects: + _visitors['select'] = _visitors['compound_select'] = visit_something + + if include_joins: + _visitors['join'] = visit_something + if include_aliases: - def visit_alias(alias): - tables.append(alias) - kwargs['visit_alias'] = visit_alias + _visitors['alias'] = visit_something if check_columns: def visit_column(column): tables.append(column.table) - kwargs['visit_column'] = visit_column + _visitors['column'] = visit_column - def visit_table(table): - tables.append(table) - kwargs['visit_table'] = visit_table + _visitors['table'] = visit_something - visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) + visitors.traverse(clause, {'column_collections':False}, _visitors) return tables def find_columns(clause): @@ -53,7 +67,7 @@ def find_columns(clause): cols = util.Set() def visit_column(col): cols.add(col) - visitors.traverse(clause, visit_column=visit_column) + visitors.traverse(clause, {}, {'column':visit_column}) return cols def join_condition(a, b, ignore_nonexistent_tables=False): @@ -72,7 +86,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False): for fk in b.foreign_keys: try: col = fk.get_referent(a) - except exceptions.NoReferencedTableError: + except exc.NoReferencedTableError: if ignore_nonexistent_tables: continue else: @@ -81,27 +95,26 @@ def join_condition(a, b, ignore_nonexistent_tables=False): if col: crit.append(col == fk.parent) constraints.add(fk.constraint) - if a is not b: for fk in a.foreign_keys: try: col = fk.get_referent(b) - except exceptions.NoReferencedTableError: + except exc.NoReferencedTableError: if ignore_nonexistent_tables: continue else: raise - + if col: crit.append(col == fk.parent) constraints.add(fk.constraint) if len(crit) == 0: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Can't find any foreign key relationships " "between '%s' and '%s'" % (a.description, b.description)) elif len(constraints) > 1: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Can't determine join between '%s' and '%s'; " "tables have more than one foreign key " "constraint relationship between them. " @@ -111,7 +124,70 @@ def join_condition(a, b, ignore_nonexistent_tables=False): return (crit[0]) else: return sql.and_(*crit) + +class Annotated(object): + """clones a ClauseElement and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __cmp__() of the original element so that it takes its place + in hashed collections. + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + """ + def __new__(cls, *args): + if not args: + return object.__new__(cls) + else: + element, values = args + return object.__new__( + type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {}) + ) + + def __init__(self, element, values): + self.__dict__ = element.__dict__.copy() + self.__element = element + self._annotations = values + + def _annotate(self, values): + _values = self._annotations.copy() + _values.update(values) + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone._annotations = _values + return clone + + def __hash__(self): + return hash(self.__element) + + def __cmp__(self, other): + return cmp(hash(self.__element), hash(other)) + +def splice_joins(left, right, stop_on=None): + if left is None: + return right + + stack = [(right, None)] + + adapter = ClauseAdapter(left) + ret = None + while stack: + (right, prevright) = stack.pop() + if isinstance(right, expression.Join) and right is not stop_on: + right = right._clone() + right._reset_exported() + right.onclause = adapter.traverse(right.onclause) + stack.append((right.left, right)) + else: + right = adapter.traverse(right) + if prevright: + prevright.left = right + if not ret: + ret = right + + return ret def reduce_columns(columns, *clauses): """given a list of columns, return a 'reduced' set based on natural equivalents. @@ -151,7 +227,7 @@ def reduce_columns(columns, *clauses): omit.add(c) break for clause in clauses: - visitors.traverse(clause, visit_binary=visit_binary) + visitors.traverse(clause, {}, {'binary':visit_binary}) return expression.ColumnSet(columns.difference(omit)) @@ -159,7 +235,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: - raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'") + raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'") def visit_binary(binary): if not any_operator and binary.operator != operators.eq: @@ -184,7 +260,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) pairs = [] - visitors.traverse(expression, visit_binary=visit_binary) + visitors.traverse(expression, {}, {'binary':visit_binary}) return pairs def folded_equivalents(join, equivs=None): @@ -195,15 +271,15 @@ def folded_equivalents(join, equivs=None): This function is used by Join.select(fold_equivalents=True). TODO: deprecate ? - """ + """ if equivs is None: equivs = util.Set() def visit_binary(binary): if binary.operator == operators.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) - visitors.traverse(join.onclause, visit_binary=visit_binary) + visitors.traverse(join.onclause, {}, {'binary':visit_binary}) collist = [] if isinstance(join.left, expression.Join): left = folded_equivalents(join.left, equivs) @@ -246,43 +322,8 @@ class AliasedRow(object): def keys(self): return self.row.keys() -def row_adapter(from_, equivalent_columns=None): - """create a row adapter callable against a selectable.""" - - if equivalent_columns is None: - equivalent_columns = {} - - def locate_col(col): - c = from_.corresponding_column(col) - if c: - return c - elif col in equivalent_columns: - for c2 in equivalent_columns[col]: - corr = from_.corresponding_column(c2) - if corr: - return corr - return col - - map = util.PopulateDict(locate_col) - - def adapt(row): - return AliasedRow(row, map) - return adapt - -class ColumnsInClause(visitors.ClauseVisitor): - """Given a selectable, visit clauses and determine if any columns - from the clause are in the selectable. - """ - - def __init__(self, selectable): - self.selectable = selectable - self.result = False - - def visit_column(self, column): - if self.selectable.c.get(column.key) is column: - self.result = True -class ClauseAdapter(visitors.ClauseVisitor): +class ClauseAdapter(visitors.ReplacingCloningVisitor): """Given a clause (like as in a WHERE criterion), locate columns which are embedded within a given selectable, and changes those columns to be that of the selectable. @@ -308,58 +349,76 @@ class ClauseAdapter(visitors.ClauseVisitor): condition to read:: s.c.col1 == table2.c.col1 - """ - - __traverse_options__ = {'column_collections':False} - def __init__(self, selectable, include=None, exclude=None, equivalents=None): - self.__traverse_options__ = self.__traverse_options__.copy() - self.__traverse_options__['stop_on'] = [selectable] + """ + def __init__(self, selectable, equivalents=None, include=None, exclude=None): + self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]} self.selectable = selectable self.include = include self.exclude = exclude - self.equivalents = equivalents - - def traverse(self, obj, clone=True): - if not clone: - raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True") - return visitors.ClauseVisitor.traverse(self, obj, clone=True) - - def copy_and_chain(self, adapter): - """create a copy of this adapter and chain to the given adapter. - - currently this adapter must be unchained to start, raises - an exception if it's already chained. - - Does not modify the given adapter. - """ + self.equivalents = equivalents or {} - if adapter is None: - return self + def _corresponding_column(self, col, require_embedded): + newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) - if hasattr(self, '_next'): - raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") - - ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) - ca._next = adapter - return ca + if not newcol and col in self.equivalents: + for equiv in self.equivalents[col]: + newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded) + if newcol: + return newcol + return newcol - def before_clone(self, col): + def replace(self, col): if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): return self.selectable + if not isinstance(col, expression.ColumnElement): return None - if self.include is not None: - if col not in self.include: - return None - if self.exclude is not None: - if col in self.exclude: - return None - newcol = self.selectable.corresponding_column(col, require_embedded=True) - if newcol is None and self.equivalents is not None and col in self.equivalents: - for equiv in self.equivalents[col]: - newcol = self.selectable.corresponding_column(equiv, require_embedded=True) - if newcol: - return newcol - return newcol + + if self.include and col not in self.include: + return None + elif self.exclude and col in self.exclude: + return None + + return self._corresponding_column(col, True) + +class ColumnAdapter(ClauseAdapter): + + def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None): + ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) + if chain_to: + self.chain(chain_to) + self.columns = util.PopulateDict(self._locate_col) + + def wrap(self, adapter): + ac = self.__class__.__new__(self.__class__) + ac.__dict__ = self.__dict__.copy() + ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col) + ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) + ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) + ac.columns = util.PopulateDict(ac._locate_col) + return ac + + adapt_clause = ClauseAdapter.traverse + adapt_list = ClauseAdapter.copy_and_process + + def _wrap(self, local, wrapped): + def locate(col): + col = local(col) + return wrapped(col) + return locate + + def _locate_col(self, col): + c = self._corresponding_column(col, False) + if not c: + c = self.adapt_clause(col) + + # anonymize labels in case they have a hardcoded name + if isinstance(c, expression._Label): + c = c.label(None) + return c + + def adapted_row(self, row): + return AliasedRow(row, self.columns) + diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 9888a228a..738dae9c7 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,138 +1,29 @@ from sqlalchemy import util class ClauseVisitor(object): - """Traverses and visits ``ClauseElement`` structures. - - Calls visit_XXX() methods for each particular - ``ClauseElement`` subclass encountered. Traversal of a - hierarchy of ``ClauseElements`` is achieved via the - ``traverse()`` method, which is passed the lead - ``ClauseElement``. - - By default, ``ClauseVisitor`` traverses all elements - fully. Options can be specified at the class level via the - ``__traverse_options__`` dictionary which will be passed - to the ``get_children()`` method of each ``ClauseElement``; - these options can indicate modifications to the set of - elements returned, such as to not return column collections - (column_collections=False) or to return Schema-level items - (schema_visitor=True). - - ``ClauseVisitor`` also supports a simultaneous copy-and-traverse - operation, which will produce a copy of a given ``ClauseElement`` - structure while at the same time allowing ``ClauseVisitor`` subclasses - to modify the new structure in-place. - - """ __traverse_options__ = {} - def traverse_single(self, obj, **kwargs): - """visit a single element, without traversing its child elements.""" - + def traverse_single(self, obj): for v in self._iterate_visitors: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: - return meth(obj, **kwargs) + return meth(obj) - traverse_chained = traverse_single - def iterate(self, obj): """traverse the given expression structure, returning an iterator of all elements.""" - - stack = [obj] - traversal = util.deque() - while stack: - t = stack.pop() - traversal.appendleft(t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - return iter(traversal) - - def traverse(self, obj, clone=False): - """traverse and visit the given expression structure. - - Returns the structure given, or a copy of the structure if - clone=True. - - When the copy operation takes place, the before_clone() method - will receive each element before it is copied. If the method - returns a non-None value, the return value is taken as the - "copied" element and traversal will not descend further. - - The visit_XXX() methods receive the element *after* it's been - copied. To compare an element to another regardless of - one element being a cloned copy of the original, the - '_cloned_set' attribute of ClauseElement can be used for the compare, - i.e.:: - - original in copied._cloned_set - - - """ - if clone: - return self._cloned_traversal(obj) - else: - return self._non_cloned_traversal(obj) - - def copy_and_process(self, list_): - """Apply cloned traversal to the given list of elements, and return the new list.""" - - return [self._cloned_traversal(x) for x in list_] - def before_clone(self, elem): - """receive pre-copied elements during a cloning traversal. - - If the method returns a new element, the element is used - instead of creating a simple copy of the element. Traversal - will halt on the newly returned element if it is re-encountered. - """ - return None - - def _clone_element(self, elem, stop_on, cloned): - for v in self._iterate_visitors: - newelem = v.before_clone(elem) - if newelem: - stop_on.add(newelem) - return newelem - - if elem not in cloned: - # the full traversal will only make a clone of a particular element - # once. - cloned[elem] = elem._clone() - return cloned[elem] - - def _cloned_traversal(self, obj): - """a recursive traversal which creates copies of elements, returning the new structure.""" - - stop_on = self.__traverse_options__.get('stop_on', []) - return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True) - - def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False): - if elem in stop_on: - return elem - - if _clone_toplevel: - elem = self._clone_element(elem, stop_on, cloned) - if elem in stop_on: - return elem - - def clone(element): - return self._clone_element(element, stop_on, cloned) - elem._copy_internals(clone=clone) + return iterate(obj, self.__traverse_options__) - self.traverse_single(elem) + def traverse(self, obj): + """traverse and visit the given expression structure.""" - for e in elem.get_children(**self.__traverse_options__): - if e not in stop_on: - self._cloned_traversal_impl(e, stop_on, cloned) - return elem + visitors = {} - def _non_cloned_traversal(self, obj): - """a non-recursive, non-cloning traversal.""" - - for target in self.iterate(obj): - self.traverse_single(target) - return obj + for name in dir(self): + if name.startswith('visit_'): + visitors[name[6:]] = getattr(self, name) + + return traverse(obj, self.__traverse_options__, visitors) def _iterate_visitors(self): """iterate through this visitor and each 'chained' visitor.""" @@ -152,31 +43,136 @@ class ClauseVisitor(object): tail._next = visitor return self -class NoColumnVisitor(ClauseVisitor): - """ClauseVisitor with 'column_collections' set to False; will not - traverse the front-facing Column collections on Table, Alias, Select, - and CompoundSelect objects. +class CloningVisitor(ClauseVisitor): + def copy_and_process(self, list_): + """Apply cloned traversal to the given list of elements, and return the new list.""" + + return [self.traverse(x) for x in list_] + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + visitors = {} + + for name in dir(self): + if name.startswith('visit_'): + visitors[name[6:]] = getattr(self, name) + + return cloned_traverse(obj, self.__traverse_options__, visitors) + +class ReplacingCloningVisitor(CloningVisitor): + def replace(self, elem): + """receive pre-copied elements during a cloning traversal. + + If the method returns a new element, the element is used + instead of creating a simple copy of the element. Traversal + will halt on the newly returned element if it is re-encountered. + """ + return None + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + def replace(elem): + for v in self._iterate_visitors: + e = v.replace(elem) + if e: + return e + return replacement_traverse(obj, self.__traverse_options__, replace) + +def iterate(obj, opts): + """traverse the given expression structure, returning an iterator. + + traversal is configured to be breadth-first. """ + stack = util.deque([obj]) + while stack: + t = stack.popleft() + yield t + for c in t.get_children(**opts): + stack.append(c) + +def iterate_depthfirst(obj, opts): + """traverse the given expression structure, returning an iterator. - __traverse_options__ = {'column_collections':False} - -class NullVisitor(ClauseVisitor): - def traverse(self, obj, clone=False): - next = getattr(self, '_next', None) - if next: - return next.traverse(obj, clone=clone) - else: - return obj - -def traverse(clause, **kwargs): - """traverse the given clause, applying visit functions passed in as keyword arguments.""" + traversal is configured to be depth-first. + + """ + stack = util.deque([obj]) + traversal = util.deque() + while stack: + t = stack.pop() + traversal.appendleft(t) + for c in t.get_children(**opts): + stack.append(c) + return iter(traversal) + +def traverse_using(iterator, obj, visitors): + """visit the given expression structure using the given iterator of objects.""" + + for target in iterator: + meth = visitors.get(target.__visit_name__, None) + if meth: + meth(target) + return obj - clone = kwargs.pop('clone', False) - class Vis(ClauseVisitor): - __traverse_options__ = kwargs.pop('traverse_options', {}) - vis = Vis() - for key in kwargs: - setattr(vis, key, kwargs[key]) - return vis.traverse(clause, clone=clone) +def traverse(obj, opts, visitors): + """traverse and visit the given expression structure using the default iterator.""" + + return traverse_using(iterate(obj, opts), obj, visitors) + +def traverse_depthfirst(obj, opts, visitors): + """traverse and visit the given expression structure using the depth-first iterator.""" + + return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) + +def cloned_traverse(obj, opts, visitors): + cloned = {} + + def clone(element): + if element not in cloned: + cloned[element] = element._clone() + return cloned[element] + + obj = clone(obj) + stack = [obj] + + while stack: + t = stack.pop() + if t in cloned: + continue + t._copy_internals(clone=clone) + + meth = visitors.get(t.__visit_name__, None) + if meth: + meth(t) + + for c in t.get_children(**opts): + stack.append(c) + return obj + +def replacement_traverse(obj, opts, replace): + cloned = {} + stop_on = util.Set(opts.get('stop_on', [])) + + def clone(element): + newelem = replace(element) + if newelem: + stop_on.add(newelem) + return newelem + + if element not in cloned: + cloned[element] = element._clone() + return cloned[element] + obj = clone(obj) + stack = [obj] + while stack: + t = stack.pop() + if t in stop_on: + continue + t._copy_internals(clone=clone) + for c in t.get_children(**opts): + stack.append(c) + return obj |
