summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-12-16 18:32:25 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-12-16 18:32:25 +0000
commitabc33bd32d6fd11f46bdc3e65ce97b606ce1cb89 (patch)
treefcd07dee8fe2a2dc4bf84dee1deed28cca6d9c8c /lib
parent8ce3f5d6997be2d28e88f2ed982454e7b4d6e3fa (diff)
downloadsqlalchemy-abc33bd32d6fd11f46bdc3e65ce97b606ce1cb89.tar.gz
- more fixes to the LIMIT/OFFSET aliasing applied with Query + eagerloads,
in this case when mapped against a select statement [ticket:904] - _hide_froms logic in expression totally localized to Join class, including search through previous clone sources - removed "stop_on" from main visitors, not used - "stop_on" in AbstractClauseProcessor part of constructor, ClauseAdapter sets it up based on given clause - fixes to is_derived_from() to take previous clone sources into account, Alias takes self + cloned sources into account. this is ultimately what the #904 bug was.
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/query.py2
-rw-r--r--lib/sqlalchemy/sql/expression.py58
-rw-r--r--lib/sqlalchemy/sql/util.py17
-rw-r--r--lib/sqlalchemy/sql/visitors.py25
4 files changed, 49 insertions, 53 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index d75a9b8b2..902a4fd3b 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -919,7 +919,7 @@ class Query(object):
adapt_criterion = self.table not in self._get_joinable_tables()
if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper):
- whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause, stop_on=util.Set([from_obj]))
+ whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause)
# TODO: mappers added via add_entity(), adapt their queries also,
# if those mappers are polymorphic
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index a4d8fa6a0..ddeaaf8ad 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -28,6 +28,7 @@ to stay the same in future releases.
import re
import datetime
import warnings
+from itertools import chain
from sqlalchemy import util, exceptions
from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
@@ -864,6 +865,13 @@ class ClauseElement(object):
return c
+ def _cloned_set(self):
+ f = self
+ while f is not None:
+ yield f
+ f = getattr(f, '_is_clone_of', None)
+ _cloned_set = property(_cloned_set)
+
def _get_from_objects(self, **modifiers):
"""Return objects represented in this ``ClauseElement`` that
should be added to the ``FROM`` list of a query, when this
@@ -1543,7 +1551,8 @@ class FromClause(Selectable):
__visit_name__ = 'fromclause'
named_with_column=False
-
+ _hide_froms = []
+
def __init__(self):
self.oid_column = None
@@ -1588,7 +1597,7 @@ class FromClause(Selectable):
An example would be an Alias of a Table is derived from that Table.
"""
- return fromclause is self
+ return fromclause in util.Set(self._cloned_set)
def replace_selectable(self, old, alias):
"""replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
@@ -1649,22 +1658,6 @@ class FromClause(Selectable):
return getattr(self, 'name', self.__class__.__name__ + " object")
description = property(description)
- def _aggregate_hide_froms(self, **modifiers):
- """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces, taking into account
- the element which this element was cloned from (and so on until the orginal is reached).
- """
-
- s = self
- while s is not None:
- for h in s._hide_froms(**modifiers):
- yield h
- s = getattr(s, '_is_clone_of', None)
-
- def _hide_froms(self, **modifiers):
- """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces."""
-
- return []
-
def _clone_from_clause(self):
# delete all the "generated" collections of columns for a
# newly cloned FromClause, so that they will be re-derived
@@ -2230,6 +2223,7 @@ class Join(FromClause):
def __init__(self, left, right, onclause=None, isouter = False):
self.left = _selectable(left)
self.right = _selectable(right).self_group()
+
self.oid_column = self.left.oid_column
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
@@ -2303,7 +2297,7 @@ class Join(FromClause):
self.right = clone(self.right)
self.onclause = clone(self.onclause)
self.__folded_equivalents = None
-
+
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
@@ -2409,9 +2403,10 @@ class Join(FromClause):
return self.select(use_labels=True, correlate=False).alias(name)
- def _hide_froms(self, **modifiers):
- return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
-
+ def _hide_froms(self):
+ return chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set])
+ _hide_froms = property(_hide_froms)
+
def _get_from_objects(self, **modifiers):
return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
@@ -2450,6 +2445,8 @@ class Alias(FromClause):
description = property(description)
def is_derived_from(self, fromclause):
+ if fromclause in util.Set(self._cloned_set):
+ return True
return self.selectable.is_derived_from(fromclause)
def supports_execution(self):
@@ -2527,13 +2524,11 @@ class _FromGrouping(FromClause):
self.elem = elem
columns = c = property(lambda s:s.elem.columns)
-
+ _hide_froms = property(lambda s:s.elem._hide_froms)
+
def get_children(self, **kwargs):
return self.elem,
- def _hide_froms(self, **modifiers):
- return self.elem._hide_froms(**modifiers)
-
def _copy_internals(self, clone=_clone):
self.elem = clone(self.elem)
@@ -3066,7 +3061,6 @@ class Select(_SelectBaseMixin, FromClause):
"""
froms = util.OrderedSet()
- hide_froms = util.Set()
for col in self._raw_columns:
froms.update(col._get_from_objects())
@@ -3078,14 +3072,13 @@ class Select(_SelectBaseMixin, FromClause):
froms.update(self._froms)
for f in froms:
- hide_froms.update(f._aggregate_hide_froms())
- froms = froms.difference(hide_froms)
+ froms.difference_update(f._hide_froms)
if len(froms) > 1:
if self.__correlate:
- froms = froms.difference(self.__correlate)
+ froms.difference_update(self.__correlate)
if self._should_correlate and existing_froms is not None:
- froms = froms.difference(existing_froms)
+ froms.difference_update(existing_froms)
if not froms:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
@@ -3129,6 +3122,9 @@ class Select(_SelectBaseMixin, FromClause):
inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""")
def is_derived_from(self, fromclause):
+ if self in util.Set(fromclause._cloned_set):
+ return True
+
for f in self.locate_all_froms():
if f.is_derived_from(fromclause):
return True
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 5aa985f47..d6b10a78a 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -71,6 +71,9 @@ class AbstractClauseProcessor(object):
__traverse_options__ = {'column_collections':False}
+ def __init__(self, stop_on=None):
+ self.stop_on = stop_on
+
def convert_element(self, elem):
"""Define the *conversion* method for this ``AbstractClauseProcessor``."""
@@ -92,13 +95,14 @@ class AbstractClauseProcessor(object):
setattr(tail, attr, visitor)
return self
- def copy_and_process(self, list_, stop_on=None):
+ def copy_and_process(self, list_):
"""Copy the given list to a new list, with each element traversed individually."""
list_ = list(list_)
- stop_on = util.Set()
+ stop_on = util.Set(self.stop_on or [])
+ cloned = {}
for i in range(0, len(list_)):
- list_[i] = self.traverse(list_[i], stop_on=stop_on)
+ list_[i] = self._traverse(list_[i], stop_on, cloned, _clone_toplevel=True)
return list_
def _convert_element(self, elem, stop_on, cloned):
@@ -116,13 +120,11 @@ class AbstractClauseProcessor(object):
cloned[elem] = elem._clone()
return cloned[elem]
- def traverse(self, elem, clone=True, stop_on=None):
+ def traverse(self, elem, clone=True):
if not clone:
raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
- if stop_on is None:
- stop_on = util.Set()
- return self._traverse(elem, stop_on, {}, _clone_toplevel=True)
+ return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True)
def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False):
if elem in stop_on:
@@ -178,6 +180,7 @@ class ClauseAdapter(AbstractClauseProcessor):
"""
def __init__(self, selectable, include=None, exclude=None, equivalents=None):
+ AbstractClauseProcessor.__init__(self, [selectable])
self.selectable = selectable
self.include = include
self.exclude = exclude
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 150ee9cc7..bb63ab09c 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -37,18 +37,17 @@ class ClauseVisitor(object):
meth(obj, **kwargs)
v = getattr(v, '_next', None)
- def iterate(self, obj, stop_on=None):
+ def iterate(self, obj):
stack = [obj]
traversal = []
while len(stack) > 0:
t = stack.pop()
- if stop_on is None or t not in stop_on:
- yield t
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
+ yield t
+ traversal.insert(0, t)
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
- def traverse(self, obj, stop_on=None, clone=False):
+ def traverse(self, obj, clone=False):
if clone:
cloned = {}
@@ -60,17 +59,15 @@ class ClauseVisitor(object):
return cloned[obj]
obj = do_clone(obj)
-
stack = [obj]
traversal = []
while len(stack) > 0:
t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- if clone:
- t._copy_internals(clone=do_clone)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
+ traversal.insert(0, t)
+ if clone:
+ t._copy_internals(clone=do_clone)
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
for target in traversal:
v = self
while v is not None: