diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-04-07 01:12:44 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-04-07 01:12:44 +0000 |
| commit | e3b2305d6721a1f1ed20f9c520765f7c33876f32 (patch) | |
| tree | 8fa4a5565f42dc836a22c44219762dee4b933fd7 /lib/sqlalchemy/sql | |
| parent | 5b3cddc48e5b436a0c46f0df3b016a837d823c92 (diff) | |
| download | sqlalchemy-e3b2305d6721a1f1ed20f9c520765f7c33876f32.tar.gz | |
- merged -r4458:4466 of query_columns branch
- this branch changes query.values() to immediately return an iterator, adds a new "aliased" construct which will be the primary method to get at aliased columns when using values()
- tentative ORM versions of _join and _outerjoin are not yet public, would like to integrate with Query better (work continues in the branch)
- lots of fixes to expressions regarding cloning and correlation. Some apparent ORM bug-workarounds removed.
- to fix a recursion issue with anonymous identifiers, bind parameters generated against columns now just use the name of the column instead of the tablename_columnname label (plus the unique integer counter). this way expensive recursive schemes aren't needed for the anon identifier logic. This, as usual, impacted a ton of compiler unit tests which needed a search-n-replace for the new bind names.
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 359 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 36 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 9 |
4 files changed, 183 insertions, 237 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 868904c21..47e5ec9c5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -18,7 +18,7 @@ creating database-specific compilers and schema generators, the module is otherwise internal to SQLAlchemy. """ -import string, re +import string, re, itertools from sqlalchemy import schema, engine, util, exceptions from sqlalchemy.sql import operators, functions from sqlalchemy.sql import expression as sql @@ -47,7 +47,7 @@ ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]') BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE) BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE) -ANONYMOUS_LABEL = re.compile(r'{ANON (-?\d+) (.*?)}') +ANONYMOUS_LABEL = re.compile(r'{ANON (-?\d+) ([^{}]+)}') BIND_TEMPLATES = { 'pyformat':"%%(%(name)s)s", @@ -404,7 +404,7 @@ class DefaultCompiler(engine.Compiled): def _truncated_identifier(self, ident_class, name): if (ident_class, name) in self.generated_ids: return self.generated_ids[(ident_class, name)] - + anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) if len(anonname) > self.dialect.max_identifier_length: @@ -415,9 +415,10 @@ class DefaultCompiler(engine.Compiled): truncname = anonname self.generated_ids[(ident_class, name)] = truncname return truncname - + def _process_anon(self, match): (ident, derived) = match.group(1,2) + key = ('anonymous', ident) if key in self.generated_ids: return self.generated_ids[key] @@ -460,7 +461,7 @@ class DefaultCompiler(engine.Compiled): not isinstance(column.table, sql.Select): return column.label(column.name) elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and (not hasattr(column, 'name') or isinstance(column, sql._Function)): - return column.anon_label + return column.label(column.anon_label) else: return column @@ -488,10 +489,7 @@ class DefaultCompiler(engine.Compiled): froms = select._get_display_froms(existingfroms) - correlate_froms = util.Set() - for f in froms: - correlate_froms.add(f) - correlate_froms.update(f._get_from_objects()) + correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms]))) # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b45fa4035..30f22e31f 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -65,7 +65,7 @@ def asc(column): """ return _UnaryExpression(column, modifier=operators.asc_op) -def outerjoin(left, right, onclause=None, **kwargs): +def outerjoin(left, right, onclause=None): """Return an ``OUTER JOIN`` clause element. The returned object is an instance of [sqlalchemy.sql.expression#Join]. @@ -88,9 +88,9 @@ def outerjoin(left, right, onclause=None, **kwargs): methods on the resulting ``Join`` object. """ - return Join(left, right, onclause, isouter = True, **kwargs) + return Join(left, right, onclause, isouter=True) -def join(left, right, onclause=None, **kwargs): +def join(left, right, onclause=None, isouter=False): """Return a ``JOIN`` clause element (regular inner join). The returned object is an instance of [sqlalchemy.sql.expression#Join]. @@ -113,7 +113,7 @@ def join(left, right, onclause=None, **kwargs): methods on the resulting ``Join`` object. """ - return Join(left, right, onclause, **kwargs) + return Join(left, right, onclause, isouter) def select(columns=None, whereclause=None, from_obj=[], **kwargs): """Returns a ``SELECT`` clause element. @@ -831,14 +831,35 @@ class _FunctionGenerator(object): return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) +# "func" global - i.e. func.count() func = _FunctionGenerator() +# "modifier" global - i.e. modifier.distinct # TODO: use UnaryExpression for this instead ? modifier = _FunctionGenerator(group=False) def _clone(element): return element._clone() +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' predecessors.""" + + return itertools.chain(*[x._cloned_set for x in elements]) + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the enties present within 'a'. + + """ + all_overlap = util.Set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return a.intersection( + [ + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ] + ) + def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) @@ -894,6 +915,7 @@ def _selectable(element): else: raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) + def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) @@ -1475,7 +1497,7 @@ class ColumnElement(ClauseElement, _CompareMixin): co = _ColumnClause(name, selectable, type_=getattr(self, 'type', None)) else: name = str(self) - co = _ColumnClause(self.anon_label.name, selectable, type_=getattr(self, 'type', None)) + co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) co.proxies = [self] selectable.columns[name]= co @@ -1495,7 +1517,7 @@ class ColumnElement(ClauseElement, _CompareMixin): """ if not hasattr(self, '_ColumnElement__anon_label'): - self.__anon_label = self.label(None) + self.__anon_label = "{ANON %d %s}" % (id(self), getattr(self, 'name', 'anon')) return self.__anon_label anon_label = property(anon_label) @@ -1626,20 +1648,20 @@ class FromClause(Selectable): col = list(self.columns)[0] return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) - def select(self, whereclauses = None, **params): + def select(self, whereclause=None, **params): """return a SELECT of this ``FromClause``.""" - return select([self], whereclauses, **params) + return select([self], whereclause, **params) - def join(self, right, *args, **kwargs): + def join(self, right, onclause=None, isouter=False): """return a join of this ``FromClause`` against another ``FromClause``.""" - return Join(self, right, *args, **kwargs) + return Join(self, right, onclause, isouter) - def outerjoin(self, right, *args, **kwargs): + def outerjoin(self, right, onclause=None): """return an outer join of this ``FromClause`` against another ``FromClause``.""" - return Join(self, right, isouter=True, *args, **kwargs) + return Join(self, right, onclause, True) def alias(self, name=None): """return an alias of this ``FromClause`` against another ``FromClause``.""" @@ -1709,7 +1731,7 @@ class FromClause(Selectable): return getattr(self, 'name', self.__class__.__name__ + " object") description = property(description) - def _clone_from_clause(self): + def _reset_exported(self): # delete all the "generated" collections of columns for a # newly cloned FromClause, so that they will be re-derived # from the item. this is because FromClause subclasses, when @@ -2075,7 +2097,7 @@ class _Function(_CalculatedClause, FromClause): def _copy_internals(self, clone=_clone): _CalculatedClause._copy_internals(self, clone=clone) - self._clone_from_clause() + self._reset_exported() def get_children(self, **kwargs): return _CalculatedClause.get_children(self, **kwargs) @@ -2206,8 +2228,8 @@ class _Exists(_UnaryExpression): s = select(*args, **kwargs).as_scalar().self_group() _UnaryExpression.__init__(self, s, operator=operators.exists) - def select(self, whereclauses = None, **params): - return select([self], whereclauses, **params) + def select(self, whereclause=None, **params): + return select([self], whereclause, **params) def correlate(self, fromclause): e = self._clone() @@ -2230,14 +2252,15 @@ class Join(FromClause): off all ``FromClause`` subclasses. """ - def __init__(self, left, right, onclause=None, isouter = False): + def __init__(self, left, right, onclause=None, isouter=False): self.left = _selectable(left) self.right = _selectable(right).self_group() if onclause is None: - self.onclause = self._match_primaries(self.left, self.right) + self.onclause = self.__match_primaries(self.left, self.right) else: self.onclause = onclause + self.isouter = isouter self.__folded_equivalents = None @@ -2263,7 +2286,7 @@ class Join(FromClause): self._oid_column = self.left.oid_column def _copy_internals(self, clone=_clone): - self._clone_from_clause() + self._reset_exported() self.left = clone(self.left) self.right = clone(self.right) self.onclause = clone(self.onclause) @@ -2272,7 +2295,7 @@ class Join(FromClause): def get_children(self, **kwargs): return self.left, self.right, self.onclause - def _match_primaries(self, primary, secondary): + def __match_primaries(self, primary, secondary): crit = [] constraints = util.Set() for fk in secondary.foreign_keys: @@ -2302,50 +2325,7 @@ class Join(FromClause): else: return and_(*crit) - def _folded_equivalents(self, equivs=None): - """Returns the column list of this Join with all equivalently-named, - equated columns folded into one column, where 'equated' means they are - equated to each other in the ON clause of this join. - - this method is used by select(fold_equivalents=True). - - The primary usage for this is when generating UNIONs so that - each selectable can have distinctly-named columns without the need - for use_labels=True. - """ - - if self.__folded_equivalents is not None: - return self.__folded_equivalents - if equivs is None: - equivs = util.Set() - class LocateEquivs(visitors.NoColumnVisitor): - def visit_binary(self, binary): - if binary.operator == operators.eq and binary.left.name == binary.right.name: - equivs.add(binary.right) - equivs.add(binary.left) - LocateEquivs().traverse(self.onclause) - collist = [] - if isinstance(self.left, Join): - left = self.left._folded_equivalents(equivs) - else: - left = list(self.left.columns) - if isinstance(self.right, Join): - right = self.right._folded_equivalents(equivs) - else: - right = list(self.right.columns) - used = util.Set() - for c in left + right: - if c in equivs: - if c.name not in used: - collist.append(c) - used.add(c.name) - else: - collist.append(c) - self.__folded_equivalents = collist - return self.__folded_equivalents - folded_equivalents = property(_folded_equivalents) - - def select(self, whereclause = None, fold_equivalents=False, **kwargs): + def select(self, whereclause=None, fold_equivalents=False, **kwargs): """Create a ``Select`` from this ``Join``. whereclause @@ -2366,7 +2346,10 @@ class Join(FromClause): """ if fold_equivalents: - collist = self.folded_equivalents + global sql_util + if not sql_util: + from sqlalchemy.sql import util as sql_util + collist = sql_util.folded_equivalents(self) else: collist = [self.left, self.right] @@ -2439,7 +2422,7 @@ class Alias(FromClause): self._oid_column = self.selectable.oid_column._make_proxy(self) def _copy_internals(self, clone=_clone): - self._clone_from_clause() + self._reset_exported() self.selectable = _clone(self.selectable) baseselectable = self.selectable while isinstance(baseselectable, Alias): @@ -2670,7 +2653,7 @@ class _ColumnClause(ColumnElement): return [] def _bind_param(self, obj): - return _BindParamClause(self._label, obj, type_=self.type, unique=True) + return _BindParamClause(self.name, obj, type_=self.type, unique=True) def _make_proxy(self, selectable, name = None): # propigate the "is_literal" flag only if we are keeping our name, @@ -2733,18 +2716,6 @@ class TableClause(FromClause): col = list(self.columns)[0] return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) - def join(self, right, *args, **kwargs): - return Join(self, right, *args, **kwargs) - - def outerjoin(self, right, *args, **kwargs): - return Join(self, right, isouter = True, *args, **kwargs) - - def alias(self, name=None): - return Alias(self, name) - - def select(self, whereclause = None, **params): - return select([self], whereclause, **params) - def insert(self, values=None, inline=False, **kwargs): return insert(self, values=values, inline=inline, **kwargs) @@ -2780,8 +2751,8 @@ class _SelectBaseMixin(object): is eligible to be used as a scalar expression. The returned object is an instance of [sqlalchemy.sql.expression#_ScalarSelect]. - """ + """ return _ScalarSelect(self) def apply_labels(self): @@ -2791,8 +2762,8 @@ class _SelectBaseMixin(object): name, such as "SELECT somecolumn AS tablename_somecolumn". This allows selectables which contain multiple FROM clauses to produce a unique set of column names regardless of name conflicts among the individual FROM clauses. - """ + """ s = self._generate() s.use_labels = True return s @@ -2802,8 +2773,8 @@ class _SelectBaseMixin(object): with a label. See also ``as_scalar()``. - """ + """ return self.as_scalar().label(name) def supports_execution(self): @@ -2819,8 +2790,9 @@ class _SelectBaseMixin(object): return s def _generate(self): - s = self._clone() - s._clone_from_clause() + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + s._reset_exported() return s def limit(self, limit): @@ -2841,8 +2813,8 @@ class _SelectBaseMixin(object): """return a new selectable with the given list of ORDER BY criterion applied. The criterion will be appended to any pre-existing ORDER BY criterion. - """ + """ s = self._generate() s.append_order_by(*clauses) return s @@ -2851,8 +2823,8 @@ class _SelectBaseMixin(object): """return a new selectable with the given list of GROUP BY criterion applied. The criterion will be appended to any pre-existing GROUP BY criterion. - """ + """ s = self._generate() s.append_group_by(*clauses) return s @@ -2862,12 +2834,7 @@ class _SelectBaseMixin(object): The criterion will be appended to any pre-existing ORDER BY criterion. - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. """ - if len(clauses) == 1 and clauses[0] is None: self._order_by_clause = ClauseList() else: @@ -2880,12 +2847,7 @@ class _SelectBaseMixin(object): The criterion will be appended to any pre-existing GROUP BY criterion. - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. """ - if len(clauses) == 1 and clauses[0] is None: self._group_by_clause = ClauseList() else: @@ -2893,14 +2855,6 @@ class _SelectBaseMixin(object): clauses = list(self._group_by_clause) + list(clauses) self._group_by_clause = ClauseList(*clauses) - def select(self, whereclauses = None, **params): - """return a SELECT of this selectable. - - This has the effect of embeddeding this select into a subquery that is selected - from. - """ - return select([self], whereclauses, **params) - def _get_from_objects(self, is_where=False, **modifiers): if is_where: return [] @@ -2974,7 +2928,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): self._oid_column = col def _copy_internals(self, clone=_clone): - self._clone_from_clause() + self._reset_exported() self.selects = [clone(s) for s in self.selects] if hasattr(self, '_col_map'): del self._col_map @@ -3025,14 +2979,7 @@ class Select(_SelectBaseMixin, FromClause): self._should_correlate = correlate self._distinct = distinct - # NOTE: the _generate() - # operation creates a *shallow* copy of the object, so append_XXX() methods, - # usually called via a generative method, create a copy of each collection - # by default - - self.__correlate = util.Set() - self._having = None - self._prefixes = [] + self._correlate = util.Set() if columns: self._raw_columns = [ @@ -3042,7 +2989,7 @@ class Select(_SelectBaseMixin, FromClause): ] else: self._raw_columns = [] - + if from_obj: self._froms = util.Set([ _is_literal(f) and _TextFromClause(f) or f @@ -3050,7 +2997,7 @@ class Select(_SelectBaseMixin, FromClause): ]) else: self._froms = util.Set() - + if whereclause: self._whereclause = _literal_as_text(whereclause) else: @@ -3075,8 +3022,8 @@ class Select(_SelectBaseMixin, FromClause): rendered in the FROM clause of enclosing selects; this Select may want to leave those absent if it is automatically correlating. + """ - froms = util.OrderedSet() for col in self._raw_columns: @@ -3091,12 +3038,16 @@ class Select(_SelectBaseMixin, FromClause): toremove = itertools.chain(*[f._hide_froms for f in froms]) froms.difference_update(toremove) - if len(froms) > 1 or self.__correlate: - if self.__correlate: - froms.difference_update(self.__correlate) - if self._should_correlate and existing_froms is not None: - froms.difference_update(existing_froms) - + if len(froms) > 1 or self._correlate: + if self._correlate: + froms.difference_update(_cloned_intersection(froms, self._correlate)) + + if self._should_correlate and existing_froms: + froms.difference_update(_cloned_intersection(froms, existing_froms)) + + if not len(froms): + raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self) + return froms froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") @@ -3110,31 +3061,30 @@ class Select(_SelectBaseMixin, FromClause): This set is a superset of that returned by the ``froms`` property, which is specifically for those FromClause elements that would actually be rendered. + """ if hasattr(self, '_all_froms'): return self._all_froms - froms = util.Set() - for col in self._raw_columns: - for f in col._get_from_objects(): - froms.add(f) + froms = util.Set( + itertools.chain(* + [self._froms] + + [f._get_from_objects() for f in self._froms] + + [col._get_from_objects() for col in self._raw_columns] + ) + ) - if self._whereclause is not None: - for f in self._whereclause._get_from_objects(is_where=True): - froms.add(f) + if self._whereclause: + froms.update(self._whereclause._get_from_objects(is_where=True)) - for elem in self._froms: - froms.add(elem) - for f in elem._get_from_objects(): - froms.add(f) self._all_froms = froms return froms def inner_columns(self): - """a collection of all ColumnElement expressions which would + """an iteratorof all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement. - """ + """ for c in self._raw_columns: if isinstance(c, Selectable): for co in c.columns: @@ -3153,8 +3103,10 @@ class Select(_SelectBaseMixin, FromClause): return False def _copy_internals(self, clone=_clone): - self._clone_from_clause() - self._recorrelate_froms([(f, clone(f)) for f in self._froms]) + self._reset_exported() + from_cloned = dict([(f, clone(f)) for f in self._froms.union(self._correlate)]) + self._froms = util.Set([from_cloned[f] for f in self._froms]) + self._correlate = util.Set([from_cloned[f] for f in self._correlate]) self._raw_columns = [clone(c) for c in self._raw_columns] for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): if getattr(self, attr) is not None: @@ -3167,25 +3119,17 @@ class Select(_SelectBaseMixin, FromClause): list(self.locate_all_froms()) + \ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] - def _recorrelate_froms(self, froms): - newcorrelate = util.Set() - newfroms = util.Set() - oldfroms = util.Set(self._froms) - for old, new in froms: - if old in self.__correlate: - newcorrelate.add(new) - self.__correlate.remove(old) - if old in oldfroms: - newfroms.add(new) - oldfroms.remove(old) - self.__correlate = self.__correlate.union(newcorrelate) - self._froms = [f for f in oldfroms.union(newfroms)] - def column(self, column): """return a new select() construct with the given column expression added to its columns clause.""" s = self._generate() - s.append_column(column) + column = _literal_as_column(column) + + if isinstance(column, _ScalarSelect): + column = column.self_group(against=operators.comma_op) + + s._raw_columns = s._raw_columns + [column] + return s def where(self, whereclause): @@ -3216,7 +3160,8 @@ class Select(_SelectBaseMixin, FromClause): columns clause, not using any commas.""" s = self._generate() - s.append_prefix(clause) + clause = _literal_as_text(clause) + s._prefixes = s._prefixes + [clause] return s def select_from(self, fromclause): @@ -3224,16 +3169,14 @@ class Select(_SelectBaseMixin, FromClause): FROM objects.""" s = self._generate() - s.append_from(fromclause) - return s + if _is_literal(fromclause): + fromclause = _TextFromClause(fromclause) - def __dont_correlate(self): - s = self._generate() - s._should_correlate = False + s._froms = s._froms.union([fromclause]) return s - def correlate(self, fromclause): - """return a new select() construct which will correlate the given FROM clause to that + def correlate(self, *fromclauses): + """return a new select() construct which will correlate the given FROM clauses to that of an enclosing select(), if a match is found. By "match", the given fromclause must be present in this select's list of FROM objects @@ -3243,77 +3186,47 @@ class Select(_SelectBaseMixin, FromClause): select() auto-correlates all of its FROM clauses to those of an embedded select when compiled. - If the fromclause is None, the select() will not correlate to anything. + If the fromclause is None, correlation is disabled for the returned select(). + """ - s = self._generate() s._should_correlate=False - if fromclause is None: - s.__correlate = util.Set() + if fromclauses == (None,): + s._correlate = util.Set() else: - s.append_correlation(fromclause) + s._correlate = s._correlate.union(fromclauses) return s - def append_correlation(self, fromclause, _copy_collection=True): - """append the given correlation expression to this select() construct. - - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. - """ - - if not _copy_collection: - self.__correlate.add(fromclause) - else: - self.__correlate = util.Set(list(self.__correlate) + [fromclause]) - - def append_column(self, column, _copy_collection=True): - """append the given column expression to the columns clause of this select() construct. + def append_correlation(self, fromclause): + """append the given correlation expression to this select() construct.""" + + self._should_correlate=False + self._correlate.add(fromclause) - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. - """ + def append_column(self, column): + """append the given column expression to the columns clause of this select() construct.""" column = _literal_as_column(column) if isinstance(column, _ScalarSelect): column = column.self_group(against=operators.comma_op) - if not _copy_collection: - self._raw_columns.append(column) - else: - self._raw_columns = self._raw_columns + [column] - - def append_prefix(self, clause, _copy_collection=True): - """append the given columns clause prefix expression to this select() construct. + self._raw_columns.append(column) + self._reset_exported() - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. - """ + def append_prefix(self, clause): + """append the given columns clause prefix expression to this select() construct.""" clause = _literal_as_text(clause) - if not _copy_collection: - self._prefixes.append(clause) - else: - self._prefixes = self._prefixes + [clause] + self._prefixes.append(clause) def append_whereclause(self, whereclause): """append the given expression to this select() construct's WHERE criterion. The expression will be joined to existing WHERE criterion via AND. - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. """ - - if self._whereclause is not None: + if self._whereclause is not None: self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) else: self._whereclause = _literal_as_text(whereclause) @@ -3323,33 +3236,20 @@ class Select(_SelectBaseMixin, FromClause): The expression will be joined to existing HAVING criterion via AND. - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. """ - if self._having is not None: self._having = and_(self._having, _literal_as_text(having)) else: self._having = _literal_as_text(having) - def append_from(self, fromclause, _copy_collection=True): + def append_from(self, fromclause): """append the given FromClause expression to this select() construct's FROM clause. - Note that this mutates the Select construct such that derived attributes, - such as the "primary_key", "oid_column", and child "froms" collection may - be invalid if they have already been initialized. Consider the generative - form of this method instead to prevent this issue. """ - if _is_literal(fromclause): fromclause = _TextFromClause(fromclause) - if not _copy_collection: - self._froms.add(fromclause) - else: - self._froms = util.Set(list(self._froms) + [fromclause]) + self._froms.add(fromclause) def __exportable_columns(self): for column in self._raw_columns: @@ -3380,8 +3280,8 @@ class Select(_SelectBaseMixin, FromClause): This produces an element that can be embedded in an expression. Note that this method is called automatically as needed when constructing expressions. - """ + """ if isinstance(against, CompoundSelect): return self return _FromGrouping(self) @@ -3454,6 +3354,11 @@ class _UpdateBase(ClauseElement): def _table_iterator(self): return iter([self.table]) + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + def _process_colparams(self, parameters): if parameters is None: @@ -3532,7 +3437,7 @@ class Insert(_ValuesBase): If multiple prefixes are supplied, they will be separated with spaces. """ - gen = self._clone() + gen = self._generate() clause = _literal_as_text(clause) gen._prefixes = self._prefixes + [clause] return gen @@ -3564,7 +3469,7 @@ class Update(_ValuesBase): """return a new update() construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any.""" - s = self._clone() + s = self._generate() if s._whereclause is not None: s._whereclause = and_(s._whereclause, _literal_as_text(whereclause)) else: @@ -3591,7 +3496,7 @@ class Delete(_UpdateBase): """return a new delete() construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any.""" - s = self._clone() + s = self._generate() if s._whereclause is not None: s._whereclause = and_(s._whereclause, _literal_as_text(whereclause)) else: diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 5b9ffd4fa..dd29cb42b 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -130,7 +130,43 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re pairs = [] visitors.traverse(expression, visit_binary=visit_binary) return pairs + +def folded_equivalents(join, equivs=None): + """Returns the column list of the given Join with all equivalently-named, + equated columns folded into one column, where 'equated' means they are + equated to each other in the ON clause of this join. + + This function is used by Join.select(fold_equivalents=True). + TODO: deprecate ? + """ + + if equivs is None: + equivs = util.Set() + def visit_binary(binary): + if binary.operator == operators.eq and binary.left.name == binary.right.name: + equivs.add(binary.right) + equivs.add(binary.left) + visitors.traverse(join.onclause, visit_binary=visit_binary) + collist = [] + if isinstance(join.left, expression.Join): + left = folded_equivalents(join.left, equivs) + else: + left = list(join.left.columns) + if isinstance(join.right, expression.Join): + right = folded_equivalents(join.right, equivs) + else: + right = list(join.right.columns) + used = util.Set() + for c in left + right: + if c in equivs: + if c.name not in used: + collist.append(c) + used.add(c.name) + else: + collist.append(c) + return collist + class AliasedRow(object): def __init__(self, row, map): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7eccc9b89..792391929 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -161,7 +161,14 @@ class NoColumnVisitor(ClauseVisitor): __traverse_options__ = {'column_collections':False} - +class NullVisitor(ClauseVisitor): + def traverse(self, obj, clone=False): + next = getattr(self, '_next', None) + if next: + return next.traverse(obj, clone=clone) + else: + return obj + def traverse(clause, **kwargs): """traverse the given clause, applying visit functions passed in as keyword arguments.""" |
