summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/interfaces.py16
-rw-r--r--lib/sqlalchemy/orm/properties.py87
-rw-r--r--lib/sqlalchemy/orm/strategies.py18
-rw-r--r--lib/sqlalchemy/orm/util.py57
-rw-r--r--lib/sqlalchemy/sql/expression.py13
-rw-r--r--lib/sqlalchemy/sql/util.py70
6 files changed, 195 insertions, 66 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 631d3f582..bd934ce13 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -438,9 +438,21 @@ class PropComparator(expression.ColumnOperators):
PropComparator.
"""
+ def __init__(self, prop, mapper, adapter=None):
+ self.prop = self.property = prop
+ self.mapper = mapper
+ self.adapter = adapter
+
def __clause_element__(self):
raise NotImplementedError("%r" % self)
+ def adapted(self, adapter):
+ """Return a copy of this PropComparator which will use the given adaption function
+ on the local side of generated expressions.
+
+ """
+ return self.__class__(self.prop, self.mapper, adapter)
+
@staticmethod
def any_op(a, b, **kwargs):
return a.any(b, **kwargs)
@@ -449,10 +461,6 @@ class PropComparator(expression.ColumnOperators):
def has_op(a, b, **kwargs):
return a.has(b, **kwargs)
- def __init__(self, prop, mapper):
- self.prop = self.property = prop
- self.mapper = mapper
-
@staticmethod
def of_type_op(a, class_):
return a.of_type(class_)
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 87e35eb83..2b860af37 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -18,7 +18,7 @@ from sqlalchemy.sql import operators, expression
from sqlalchemy.orm import (
attributes, dependency, mapper, object_mapper, strategies,
)
-from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate
+from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate, _orm_deannotate
from sqlalchemy.orm.interfaces import (
MANYTOMANY, MANYTOONE, MapperProperty, ONETOMANY, PropComparator,
StrategizedProperty,
@@ -85,8 +85,11 @@ class ColumnProperty(StrategizedProperty):
class ColumnComparator(PropComparator):
@util.memoized_instancemethod
def __clause_element__(self):
- return self.prop.columns[0]._annotate({"parententity": self.mapper})
-
+ if self.adapter:
+ return self.adapter(self.prop.columns[0])
+ else:
+ return self.prop.columns[0]._annotate({"parententity": self.mapper})
+
def operate(self, op, *other, **kwargs):
return op(self.__clause_element__(), *other, **kwargs)
@@ -147,7 +150,11 @@ class CompositeProperty(ColumnProperty):
class Comparator(PropComparator):
def __clause_element__(self):
- return expression.ClauseList(*self.prop.columns)
+ if self.adapter:
+ # TODO: test coverage for adapted composite comparison
+ return expression.ClauseList(*[self.adapter(x) for x in self.prop.columns])
+ else:
+ return expression.ClauseList(*self.prop.columns)
def __eq__(self, other):
if other is None:
@@ -318,18 +325,30 @@ class PropertyLoader(StrategizedProperty):
self._is_backref = _is_backref
class Comparator(PropComparator):
- def __init__(self, prop, mapper, of_type=None):
+ def __init__(self, prop, mapper, of_type=None, adapter=None):
self.prop = self.property = prop
self.mapper = mapper
+ self.adapter = adapter
if of_type:
self._of_type = _class_to_mapper(of_type)
+ def adapted(self, adapter):
+ """Return a copy of this PropComparator which will use the given adaption function
+ on the local side of generated expressions.
+
+ """
+ return PropertyLoader.Comparator(self.prop, self.mapper, getattr(self, '_of_type', None), adapter)
+
@property
def parententity(self):
return self.prop.parent
def __clause_element__(self):
- return self.prop.parent._with_polymorphic_selectable
+ elem = self.prop.parent._with_polymorphic_selectable
+ if self.adapter:
+ return self.adapter(elem)
+ else:
+ return elem
def operate(self, op, *other, **kwargs):
return op(self, *other, **kwargs)
@@ -343,13 +362,13 @@ class PropertyLoader(StrategizedProperty):
def __eq__(self, other):
if other is None:
if self.prop.direction in [ONETOMANY, MANYTOMANY]:
- return ~sql.exists([1], self.prop.primaryjoin)
+ return ~self._criterion_exists()
else:
- return self.prop._optimized_compare(None)
+ return self.prop._optimized_compare(None, adapt_source=self.adapter)
elif self.prop.uselist:
raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.")
else:
- return self.prop._optimized_compare(other)
+ return self.prop._optimized_compare(other, adapt_source=self.adapter)
def _criterion_exists(self, criterion=None, **kwargs):
if getattr(self, '_of_type', None):
@@ -360,7 +379,12 @@ class PropertyLoader(StrategizedProperty):
else:
to_selectable = None
- pj, sj, source, dest, secondary, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable)
+ if self.adapter:
+ source_selectable = self.__clause_element__()
+ else:
+ source_selectable = None
+ pj, sj, source, dest, secondary, target_adapter = \
+ self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable)
for k in kwargs:
crit = self.prop.mapper.class_manager.get_inst(k) == kwargs[k]
@@ -368,12 +392,15 @@ class PropertyLoader(StrategizedProperty):
criterion = crit
else:
criterion = criterion & crit
-
+
+ # annotate the *local* side of the join condition, in the case of pj + sj this
+ # is the full primaryjoin, in the case of just pj its the local side of
+ # the primaryjoin.
if sj:
j = _orm_annotate(pj) & sj
else:
j = _orm_annotate(pj, exclude=self.prop.remote_side)
-
+
if criterion and target_adapter:
# limit this adapter to annotated only?
criterion = target_adapter.traverse(criterion)
@@ -383,7 +410,10 @@ class PropertyLoader(StrategizedProperty):
# to anything in the enclosing query.
if criterion:
criterion = criterion._annotate({'_halt_adapt': True})
- return sql.exists([1], j & criterion, from_obj=dest).correlate(source)
+
+ crit = j & criterion
+
+ return sql.exists([1], crit, from_obj=dest).correlate(source)
def any(self, criterion=None, **kwargs):
if not self.prop.uselist:
@@ -399,7 +429,7 @@ class PropertyLoader(StrategizedProperty):
def contains(self, other, **kwargs):
if not self.prop.uselist:
raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==")
- clause = self.prop._optimized_compare(other)
+ clause = self.prop._optimized_compare(other, adapt_source=self.adapter)
if self.prop.secondaryjoin:
clause.negation_clause = self.__negated_contains_or_equals(other)
@@ -410,12 +440,22 @@ class PropertyLoader(StrategizedProperty):
if self.prop.direction == MANYTOONE:
state = attributes.instance_state(other)
strategy = self.prop._get_strategy(strategies.LazyLoader)
+
+ def state_bindparam(state, col):
+ o = state.obj() # strong ref
+ return lambda: self.prop.mapper._get_committed_attr_by_column(o, col)
+
+ def adapt(col):
+ if self.adapter:
+ return self.adapter(col)
+ else:
+ return col
+
if strategy.use_get:
return sql.and_(*[
sql.or_(
- x !=
- self.prop.mapper._get_committed_state_attr_by_column(state, y),
- x == None)
+ adapt(x) != state_bindparam(state, y),
+ adapt(x) == None)
for (x, y) in self.prop.local_remote_pairs])
criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
@@ -444,10 +484,11 @@ class PropertyLoader(StrategizedProperty):
else:
return op(self.comparator, value)
- def _optimized_compare(self, value, value_is_parent=False):
+ def _optimized_compare(self, value, value_is_parent=False, adapt_source=None):
if value is not None:
value = attributes.instance_state(value)
- return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent, alias_secondary=True)
+ return self._get_strategy(strategies.LazyLoader).\
+ lazy_clause(value, reverse_direction=not value_is_parent, alias_secondary=True, adapt_source=adapt_source)
def __str__(self):
return str(self.parent.class_.__name__) + "." + self.key
@@ -549,6 +590,14 @@ class PropertyLoader(StrategizedProperty):
for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'):
if callable(getattr(self, attr)):
setattr(self, attr, getattr(self, attr)())
+
+ # in the case that InstrumentedAttributes were used to construct
+ # primaryjoin or secondaryjoin, remove the "_orm_adapt" annotation so these
+ # interact with Query in the same way as the original Table-bound Column objects
+ for attr in ('primaryjoin', 'secondaryjoin'):
+ val = getattr(self, attr)
+ if val:
+ setattr(self, attr, _orm_deannotate(val))
if self.order_by:
self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)]
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 1962a7e2d..ba5541944 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -353,9 +353,9 @@ class LazyLoader(AbstractRelationLoader):
self.is_class_level = True
self._register_attribute(self.parent.class_, callable_=self.class_level_loader)
- def lazy_clause(self, state, reverse_direction=False, alias_secondary=False):
+ def lazy_clause(self, state, reverse_direction=False, alias_secondary=False, adapt_source=None):
if state is None:
- return self._lazy_none_clause(reverse_direction)
+ return self._lazy_none_clause(reverse_direction, adapt_source=adapt_source)
if not reverse_direction:
(criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
@@ -374,9 +374,12 @@ class LazyLoader(AbstractRelationLoader):
if self.parent_property.secondary and alias_secondary:
criterion = sql_util.ClauseAdapter(self.parent_property.secondary.alias()).traverse(criterion)
- return visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam})
-
- def _lazy_none_clause(self, reverse_direction=False):
+ criterion = visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam})
+ if adapt_source:
+ criterion = adapt_source(criterion)
+ return criterion
+
+ def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
if not reverse_direction:
(criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
else:
@@ -393,7 +396,10 @@ class LazyLoader(AbstractRelationLoader):
binary.right = expression.null()
binary.operator = operators.is_
- return visitors.cloned_traverse(criterion, {}, {'binary':visit_binary})
+ criterion = visitors.cloned_traverse(criterion, {}, {'binary':visit_binary})
+ if adapt_source:
+ criterion = adapt_source(criterion)
+ return criterion
def class_level_loader(self, state, options=None, path=None):
if not mapperutil._state_has_identity(state):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 264a4d212..fbc1acd5d 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -243,6 +243,12 @@ class ExtensionCarrier(dict):
return self.get(key, self._pass)
class ORMAdapter(sql_util.ColumnAdapter):
+ """Extends ColumnAdapter to accept ORM entities.
+
+ The selectable is extracted from the given entity,
+ and the AliasedClass if any is referenced.
+
+ """
def __init__(self, entity, equivalents=None, chain_to=None):
mapper, selectable, is_aliased_class = _entity_info(entity)
if is_aliased_class:
@@ -252,18 +258,36 @@ class ORMAdapter(sql_util.ColumnAdapter):
sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to)
class AliasedClass(object):
+ """Represents an 'alias'ed form of a mapped class for usage with Query.
+
+ The ORM equivalent of a sqlalchemy.sql.expression.Alias
+ object, this object mimics the mapped class using a
+ __getattr__ scheme and maintains a reference to a
+ real Alias object. It indicates to Query that the
+ selectable produced for this class should be aliased,
+ and also adapts PropComparators produced by the class'
+ InstrumentedAttributes so that they adapt the
+ "local" side of SQL expressions against the alias.
+
+ """
def __init__(self, cls, alias=None, name=None):
self.__mapper = _class_to_mapper(cls)
self.__target = self.__mapper.class_
alias = alias or self.__mapper._with_polymorphic_selectable.alias()
self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
self.__alias = alias
+ # used to assign a name to the RowTuple object
+ # returned by Query.
self._sa_label_name = name
self.__name__ = 'AliasedClass_' + str(self.__target)
+ def __adapt_element(self, elem):
+ return self.__adapter.traverse(elem)._annotate({'parententity': self})
+
def __adapt_prop(self, prop):
existing = getattr(self.__target, prop.key)
- comparator = AliasedComparator(self, self.__adapter, existing.comparator)
+ comparator = existing.comparator.adapted(self.__adapt_element)
+
queryattr = attributes.QueryableAttribute(
existing.impl, parententity=self, comparator=comparator)
setattr(self, prop.key, queryattr)
@@ -299,41 +323,16 @@ class AliasedClass(object):
return '<AliasedClass at 0x%x; %s>' % (
id(self), self.__target.__name__)
-class AliasedComparator(PropComparator):
- def __init__(self, aliasedclass, adapter, comparator):
- self.aliasedclass = aliasedclass
- self.comparator = comparator
- self.adapter = adapter
- self.__clause_element = self.adapter.traverse(self.comparator.__clause_element__())._annotate({'parententity': aliasedclass})
-
- def __clause_element__(self):
- return self.__clause_element
-
- def operate(self, op, *other, **kwargs):
- return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs))
-
- def reverse_operate(self, op, other, **kwargs):
- return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs))
-
def _orm_annotate(element, exclude=None):
"""Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag.
Elements within the exclude collection will be cloned but not annotated.
"""
- def clone(elem):
- if exclude and elem in exclude:
- elem = elem._clone()
- elif '_orm_adapt' not in elem._annotations:
- elem = elem._annotate({'_orm_adapt':True})
- elem._copy_internals(clone=clone)
- return elem
-
- if element is not None:
- element = clone(element)
- return element
-
+ return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude)
+_orm_deannotate = sql_util._deep_deannotate
+
class _ORMJoin(expression.Join):
"""Extend Join to support ORM constructs as input."""
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index a9ef45dc1..3b996d6cb 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -987,6 +987,13 @@ class ClauseElement(Visitable):
@property
def _cloned_set(self):
+ """Return the set consisting all cloned anscestors of this ClauseElement.
+
+ Includes this ClauseElement. This accessor tends to be used for
+ FromClause objects to identify 'equivalent' FROM clauses, regardless
+ of transformative operations.
+
+ """
f = self
while f is not None:
yield f
@@ -1008,7 +1015,11 @@ class ClauseElement(Visitable):
if Annotated is None:
from sqlalchemy.sql.util import Annotated
return Annotated(self, values)
-
+
+ def _deannotate(self):
+ """return a copy of this ClauseElement with an empty annotations dictionary."""
+ return self._clone()
+
def unique_params(self, *optionaldict, **kwargs):
"""Return a copy with ``bindparam()`` elments replaced.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d9c3ed899..2a510906b 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -159,6 +159,20 @@ class Annotated(object):
clone.__dict__ = self.__dict__.copy()
clone._annotations = _values
return clone
+
+ def _deannotate(self):
+ return self.__element
+
+ def _clone(self):
+ clone = self.__element._clone()
+ if clone is self.__element:
+ # detect immutable, don't change anything
+ return self
+ else:
+ # update the clone with any changes that have occured
+ # to this object's __dict__.
+ clone.__dict__.update(self.__dict__)
+ return Annotated(clone, self._annotations)
def __hash__(self):
return hash(self.__element)
@@ -166,6 +180,39 @@ class Annotated(object):
def __cmp__(self, other):
return cmp(hash(self.__element), hash(other))
+def _deep_annotate(element, annotations, exclude=None):
+ """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary.
+
+ Elements within the exclude collection will be cloned but not annotated.
+
+ """
+ def clone(elem):
+ # check if element is present in the exclude list.
+ # take into account proxying relationships.
+ if exclude and elem.proxy_set.intersection(exclude):
+ elem = elem._clone()
+ elif annotations != elem._annotations:
+ elem = elem._annotate(annotations.copy())
+ elem._copy_internals(clone=clone)
+ return elem
+
+ if element is not None:
+ element = clone(element)
+ return element
+
+def _deep_deannotate(element):
+ """Deep copy the given element, removing all annotations."""
+
+ def clone(elem):
+ elem = elem._deannotate()
+ elem._copy_internals(clone=clone)
+ return elem
+
+ if element is not None:
+ element = clone(element)
+ return element
+
+
def splice_joins(left, right, stop_on=None):
if left is None:
return right
@@ -208,7 +255,6 @@ def reduce_columns(columns, *clauses, **kw):
in the the selectable to just those that are not repeated.
"""
-
ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
columns = util.OrderedSet(columns)
@@ -317,7 +363,12 @@ def folded_equivalents(join, equivs=None):
return collist
class AliasedRow(object):
+ """Wrap a RowProxy with a translation map.
+
+ This object allows a set of keys to be translated
+ to those present in a RowProxy.
+ """
def __init__(self, row, map):
# AliasedRow objects don't nest, so un-nest
# if another AliasedRow was passed
@@ -341,10 +392,8 @@ class AliasedRow(object):
class ClauseAdapter(visitors.ReplacingCloningVisitor):
- """Given a clause (like as in a WHERE criterion), locate columns
- which are embedded within a given selectable, and changes those
- columns to be that of the selectable.
-
+ """Clones and modifies clauses based on column correspondence.
+
E.g.::
table1 = Table('sometable', metadata,
@@ -358,7 +407,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
condition = table1.c.col1 == table2.c.col1
- and make an alias of table1::
+ make an alias of table1::
s = table1.alias('foo')
@@ -401,7 +450,14 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return self._corresponding_column(col, True)
class ColumnAdapter(ClauseAdapter):
-
+ """Extends ClauseAdapter with extra utility functions.
+
+ Provides the ability to "wrap" this ClauseAdapter
+ around another, a columns dictionary which returns
+ cached, adapted elements given an original, and an
+ adapted_row() factory.
+
+ """
def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None):
ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
if chain_to: