diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-12-16 18:32:25 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-12-16 18:32:25 +0000 |
| commit | abc33bd32d6fd11f46bdc3e65ce97b606ce1cb89 (patch) | |
| tree | fcd07dee8fe2a2dc4bf84dee1deed28cca6d9c8c /lib | |
| parent | 8ce3f5d6997be2d28e88f2ed982454e7b4d6e3fa (diff) | |
| download | sqlalchemy-abc33bd32d6fd11f46bdc3e65ce97b606ce1cb89.tar.gz | |
- more fixes to the LIMIT/OFFSET aliasing applied with Query + eagerloads,
in this case when mapped against a select statement [ticket:904]
- _hide_froms logic in expression totally localized to Join class, including search through previous clone sources
- removed "stop_on" from main visitors, not used
- "stop_on" in AbstractClauseProcessor part of constructor, ClauseAdapter sets it up based on given clause
- fixes to is_derived_from() to take previous clone sources into account, Alias takes self + cloned sources into account. this is ultimately what the #904 bug was.
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 58 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 25 |
4 files changed, 49 insertions, 53 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d75a9b8b2..902a4fd3b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -919,7 +919,7 @@ class Query(object): adapt_criterion = self.table not in self._get_joinable_tables() if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper): - whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause, stop_on=util.Set([from_obj])) + whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause) # TODO: mappers added via add_entity(), adapt their queries also, # if those mappers are polymorphic diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index a4d8fa6a0..ddeaaf8ad 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -28,6 +28,7 @@ to stay the same in future releases. import re import datetime import warnings +from itertools import chain from sqlalchemy import util, exceptions from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes @@ -864,6 +865,13 @@ class ClauseElement(object): return c + def _cloned_set(self): + f = self + while f is not None: + yield f + f = getattr(f, '_is_clone_of', None) + _cloned_set = property(_cloned_set) + def _get_from_objects(self, **modifiers): """Return objects represented in this ``ClauseElement`` that should be added to the ``FROM`` list of a query, when this @@ -1543,7 +1551,8 @@ class FromClause(Selectable): __visit_name__ = 'fromclause' named_with_column=False - + _hide_froms = [] + def __init__(self): self.oid_column = None @@ -1588,7 +1597,7 @@ class FromClause(Selectable): An example would be an Alias of a Table is derived from that Table. """ - return fromclause is self + return fromclause in util.Set(self._cloned_set) def replace_selectable(self, old, alias): """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" @@ -1649,22 +1658,6 @@ class FromClause(Selectable): return getattr(self, 'name', self.__class__.__name__ + " object") description = property(description) - def _aggregate_hide_froms(self, **modifiers): - """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces, taking into account - the element which this element was cloned from (and so on until the orginal is reached). - """ - - s = self - while s is not None: - for h in s._hide_froms(**modifiers): - yield h - s = getattr(s, '_is_clone_of', None) - - def _hide_froms(self, **modifiers): - """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces.""" - - return [] - def _clone_from_clause(self): # delete all the "generated" collections of columns for a # newly cloned FromClause, so that they will be re-derived @@ -2230,6 +2223,7 @@ class Join(FromClause): def __init__(self, left, right, onclause=None, isouter = False): self.left = _selectable(left) self.right = _selectable(right).self_group() + self.oid_column = self.left.oid_column if onclause is None: self.onclause = self._match_primaries(self.left, self.right) @@ -2303,7 +2297,7 @@ class Join(FromClause): self.right = clone(self.right) self.onclause = clone(self.onclause) self.__folded_equivalents = None - + def get_children(self, **kwargs): return self.left, self.right, self.onclause @@ -2409,9 +2403,10 @@ class Join(FromClause): return self.select(use_labels=True, correlate=False).alias(name) - def _hide_froms(self, **modifiers): - return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - + def _hide_froms(self): + return chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set]) + _hide_froms = property(_hide_froms) + def _get_from_objects(self, **modifiers): return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) @@ -2450,6 +2445,8 @@ class Alias(FromClause): description = property(description) def is_derived_from(self, fromclause): + if fromclause in util.Set(self._cloned_set): + return True return self.selectable.is_derived_from(fromclause) def supports_execution(self): @@ -2527,13 +2524,11 @@ class _FromGrouping(FromClause): self.elem = elem columns = c = property(lambda s:s.elem.columns) - + _hide_froms = property(lambda s:s.elem._hide_froms) + def get_children(self, **kwargs): return self.elem, - def _hide_froms(self, **modifiers): - return self.elem._hide_froms(**modifiers) - def _copy_internals(self, clone=_clone): self.elem = clone(self.elem) @@ -3066,7 +3061,6 @@ class Select(_SelectBaseMixin, FromClause): """ froms = util.OrderedSet() - hide_froms = util.Set() for col in self._raw_columns: froms.update(col._get_from_objects()) @@ -3078,14 +3072,13 @@ class Select(_SelectBaseMixin, FromClause): froms.update(self._froms) for f in froms: - hide_froms.update(f._aggregate_hide_froms()) - froms = froms.difference(hide_froms) + froms.difference_update(f._hide_froms) if len(froms) > 1: if self.__correlate: - froms = froms.difference(self.__correlate) + froms.difference_update(self.__correlate) if self._should_correlate and existing_froms is not None: - froms = froms.difference(existing_froms) + froms.difference_update(existing_froms) if not froms: raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) @@ -3129,6 +3122,9 @@ class Select(_SelectBaseMixin, FromClause): inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""") def is_derived_from(self, fromclause): + if self in util.Set(fromclause._cloned_set): + return True + for f in self.locate_all_froms(): if f.is_derived_from(fromclause): return True diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 5aa985f47..d6b10a78a 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -71,6 +71,9 @@ class AbstractClauseProcessor(object): __traverse_options__ = {'column_collections':False} + def __init__(self, stop_on=None): + self.stop_on = stop_on + def convert_element(self, elem): """Define the *conversion* method for this ``AbstractClauseProcessor``.""" @@ -92,13 +95,14 @@ class AbstractClauseProcessor(object): setattr(tail, attr, visitor) return self - def copy_and_process(self, list_, stop_on=None): + def copy_and_process(self, list_): """Copy the given list to a new list, with each element traversed individually.""" list_ = list(list_) - stop_on = util.Set() + stop_on = util.Set(self.stop_on or []) + cloned = {} for i in range(0, len(list_)): - list_[i] = self.traverse(list_[i], stop_on=stop_on) + list_[i] = self._traverse(list_[i], stop_on, cloned, _clone_toplevel=True) return list_ def _convert_element(self, elem, stop_on, cloned): @@ -116,13 +120,11 @@ class AbstractClauseProcessor(object): cloned[elem] = elem._clone() return cloned[elem] - def traverse(self, elem, clone=True, stop_on=None): + def traverse(self, elem, clone=True): if not clone: raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True") - if stop_on is None: - stop_on = util.Set() - return self._traverse(elem, stop_on, {}, _clone_toplevel=True) + return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True) def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False): if elem in stop_on: @@ -178,6 +180,7 @@ class ClauseAdapter(AbstractClauseProcessor): """ def __init__(self, selectable, include=None, exclude=None, equivalents=None): + AbstractClauseProcessor.__init__(self, [selectable]) self.selectable = selectable self.include = include self.exclude = exclude diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 150ee9cc7..bb63ab09c 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -37,18 +37,17 @@ class ClauseVisitor(object): meth(obj, **kwargs) v = getattr(v, '_next', None) - def iterate(self, obj, stop_on=None): + def iterate(self, obj): stack = [obj] traversal = [] while len(stack) > 0: t = stack.pop() - if stop_on is None or t not in stop_on: - yield t - traversal.insert(0, t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) + yield t + traversal.insert(0, t) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) - def traverse(self, obj, stop_on=None, clone=False): + def traverse(self, obj, clone=False): if clone: cloned = {} @@ -60,17 +59,15 @@ class ClauseVisitor(object): return cloned[obj] obj = do_clone(obj) - stack = [obj] traversal = [] while len(stack) > 0: t = stack.pop() - if stop_on is None or t not in stop_on: - traversal.insert(0, t) - if clone: - t._copy_internals(clone=do_clone) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) + traversal.insert(0, t) + if clone: + t._copy_internals(clone=do_clone) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) for target in traversal: v = self while v is not None: |
