summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/interfaces.py25
-rw-r--r--lib/sqlalchemy/orm/mapper.py9
-rw-r--r--lib/sqlalchemy/orm/properties.py106
-rw-r--r--lib/sqlalchemy/orm/query.py45
4 files changed, 118 insertions, 67 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 2c3a98c88..f510a3ffa 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -435,8 +435,29 @@ class PropComparator(expression.ColumnOperators):
has_op = staticmethod(has_op)
def __init__(self, prop):
- self.prop = prop
-
+ self.prop = self.property = prop
+
+ def of_type_op(a, class_):
+ return a.of_type(class_)
+ of_type_op = staticmethod(of_type_op)
+
+ def of_type(self, class_):
+ """Redefine this object in terms of a polymorphic subclass.
+
+ Returns a new PropComparator from which further criterion can be evaulated.
+
+ class_
+ a class or mapper indicating that criterion will be against
+ this specific subclass.
+
+ e.g.::
+ query.join(Company.employees.of_type(Engineer)).\
+ filter(Engineer.name=='foo')
+
+ """
+
+ return self.operate(PropComparator.of_type_op, class_)
+
def contains(self, other):
"""Return true if this collection contains other"""
return self.operate(PropComparator.contains_op, other)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 85aec2f44..c0a15c427 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1628,3 +1628,12 @@ def class_mapper(class_, entity_name=None, compile=True, raiseerror=True):
return mapper.compile()
else:
return mapper
+
+def _class_to_mapper(class_or_mapper, entity_name=None, compile=True):
+ if isinstance(class_or_mapper, type):
+ return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile)
+ else:
+ if compile:
+ return class_or_mapper.compile()
+ else:
+ return class_or_mapper
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 6339ec575..74d4c04ca 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -15,6 +15,7 @@ from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause
from sqlalchemy.sql import visitors, operators, ColumnElement
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
from sqlalchemy.orm import session as sessionlib
+from sqlalchemy.orm.mapper import _class_to_mapper
from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
from sqlalchemy.exceptions import ArgumentError
@@ -244,6 +245,14 @@ class PropertyLoader(StrategizedProperty):
self.is_backref = is_backref
class Comparator(PropComparator):
+ def __init__(self, prop, of_type=None):
+ self.prop = self.property = prop
+ if of_type:
+ self._of_type = _class_to_mapper(of_type)
+
+ def of_type(self, cls):
+ return PropertyLoader.Comparator(self.prop, cls)
+
def __eq__(self, other):
if other is None:
if self.prop.uselist:
@@ -267,17 +276,28 @@ class PropertyLoader(StrategizedProperty):
return self.prop._optimized_compare(other)
def _join_and_criterion(self, criterion=None, **kwargs):
+ adapt_against = None
+
+ if getattr(self, '_of_type', None):
+ target_mapper = self._of_type
+ to_selectable = target_mapper.select_table
+ adapt_against = to_selectable
+ else:
+ target_mapper = self.prop.mapper
+ to_selectable = None
+ if target_mapper.select_table is not target_mapper.mapped_table:
+ adapt_against = target_mapper.select_table
+
if self.prop._is_self_referential():
- pac = PropertyAliasedClauses(self.prop,
- self.prop.primaryjoin,
- self.prop.secondaryjoin)
+ pj = self.prop.primary_join_against(self.prop.parent, None)
+ sj = self.prop.secondary_join_against(self.prop.parent, toselectable=to_selectable)
+
+ pac = PropertyAliasedClauses(self.prop, pj, sj)
j = pac.primaryjoin
if pac.secondaryjoin:
j = j & pac.secondaryjoin
else:
- j = self.prop.primaryjoin
- if self.prop.secondaryjoin:
- j = j & self.prop.secondaryjoin
+ j = self.prop.full_join_against(self.prop.parent, None, toselectable=to_selectable)
for k in kwargs:
crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
@@ -285,25 +305,28 @@ class PropertyLoader(StrategizedProperty):
criterion = crit
else:
criterion = criterion & crit
-
- if criterion and self.prop._is_self_referential():
- criterion = pac.adapt_clause(criterion)
- return j, criterion
+ if criterion:
+ if adapt_against:
+ criterion = ClauseAdapter(adapt_against).traverse(criterion)
+ if self.prop._is_self_referential():
+ criterion = pac.adapt_clause(criterion)
+
+ return j, criterion, to_selectable
def any(self, criterion=None, **kwargs):
if not self.prop.uselist:
raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
- j, criterion = self._join_and_criterion(criterion, **kwargs)
+ j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
- return sql.exists([1], j & criterion)
+ return sql.exists([1], j & criterion, from_obj=from_obj)
def has(self, criterion=None, **kwargs):
if self.prop.uselist:
raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
- j, criterion = self._join_and_criterion(criterion, **kwargs)
+ j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
- return sql.exists([1], j & criterion)
+ return sql.exists([1], j & criterion, from_obj=from_obj)
def contains(self, other):
if not self.prop.uselist:
@@ -322,9 +345,9 @@ class PropertyLoader(StrategizedProperty):
raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
- j, criterion = self._join_and_criterion(criterion)
+ j, criterion, from_obj = self._join_and_criterion(criterion)
- return ~sql.exists([1], j & criterion)
+ return ~sql.exists([1], j & criterion, from_obj=from_obj)
def compare(self, op, value, value_is_parent=False):
if op == operators.eq:
@@ -700,50 +723,59 @@ class PropertyLoader(StrategizedProperty):
def _is_self_referential(self):
return self.parent.mapped_table is self.target or self.parent.select_table is self.target
- def primary_join_against(self, mapper, selectable=None):
- return self.__cached_join_against(mapper, selectable, True, False)
+ def primary_join_against(self, mapper, selectable=None, toselectable=None):
+ return self.__cached_join_against(mapper, selectable, toselectable, True, False)
- def secondary_join_against(self, mapper):
- return self.__cached_join_against(mapper, None, False, True)
+ def secondary_join_against(self, mapper, toselectable=None):
+ return self.__cached_join_against(mapper, None, toselectable, False, True)
- def full_join_against(self, mapper, selectable=None):
- return self.__cached_join_against(mapper, selectable, True, True)
+ def full_join_against(self, mapper, selectable=None, toselectable=None):
+ return self.__cached_join_against(mapper, selectable, toselectable, True, True)
- def __cached_join_against(self, mapper, selectable, primary, secondary):
- if selectable is None:
- selectable = mapper.local_table
+ def __cached_join_against(self, frommapper, fromselectable, toselectable, primary, secondary):
+ if fromselectable is None:
+ fromselectable = frommapper.local_table
try:
- rec = self.__parent_join_cache[selectable]
+ rec = self.__parent_join_cache[fromselectable]
except KeyError:
- self.__parent_join_cache[selectable] = rec = {}
+ self.__parent_join_cache[fromselectable] = rec = {}
- key = (mapper, primary, secondary)
+ key = (frommapper, primary, secondary, toselectable)
if key in rec:
return rec[key]
- parent_equivalents = mapper._equivalent_columns
+ parent_equivalents = frommapper._equivalent_columns
if primary:
- if selectable is not mapper.local_table:
+ if toselectable:
+ primaryjoin = self.primaryjoin
+ else:
+ primaryjoin = self.polymorphic_primaryjoin
+
+ if fromselectable is not frommapper.local_table:
if self.direction is sync.ONETOMANY:
- primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
+ primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
elif self.direction is sync.MANYTOONE:
- primaryjoin = ClauseAdapter(selectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
+ primaryjoin = ClauseAdapter(fromselectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
elif self.secondaryjoin:
- primaryjoin = ClauseAdapter(selectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin)
- else:
- primaryjoin = self.polymorphic_primaryjoin
+ primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
if secondary:
- secondaryjoin = self.polymorphic_secondaryjoin
+ if toselectable:
+ secondaryjoin = self.secondaryjoin
+ else:
+ secondaryjoin = self.polymorphic_secondaryjoin
rec[key] = ret = primaryjoin & secondaryjoin
else:
rec[key] = ret = primaryjoin
return ret
elif secondary:
- rec[key] = ret = self.polymorphic_secondaryjoin
+ if toselectable:
+ rec[key] = ret = self.secondaryjoin
+ else:
+ rec[key] = ret = self.polymorphic_secondaryjoin
return ret
else:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 9c917ec2d..46f986d14 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -22,21 +22,18 @@ from sqlalchemy import sql, util, exceptions, logging
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import expression, visitors, operators
from sqlalchemy.orm import mapper, object_mapper
-from sqlalchemy.orm.mapper import _state_mapper
+from sqlalchemy.orm.mapper import _state_mapper, _class_to_mapper
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm import interfaces
__all__ = ['Query', 'QueryContext']
-
+
class Query(object):
"""Encapsulates the object-fetching operations provided by Mappers."""
def __init__(self, class_or_mapper, session=None, entity_name=None):
- if isinstance(class_or_mapper, type):
- self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
- else:
- self.mapper = class_or_mapper.compile()
+ self.mapper = _class_to_mapper(class_or_mapper, entity_name=entity_name)
self.select_mapper = self.mapper.get_select_mapper().compile()
self._session = session
@@ -422,15 +419,23 @@ class Query(object):
mapper = start
alias = self._aliases
+
if not isinstance(keys, list):
keys = [keys]
for key in keys:
use_selectable = None
+ of_type = None
+
if isinstance(key, tuple):
key, use_selectable = key
if isinstance(key, interfaces.PropComparator):
prop = key.property
+ if getattr(key, '_of_type', None):
+ if use_selectable:
+ raise exceptions.InvalidRequestError("Can't specify use_selectable along with polymorphic property created via of_type().")
+ of_type = key._of_type
+ use_selectable = key._of_type.select_table
else:
prop = mapper.get_property(key, resolve_synonyms=True)
@@ -445,45 +450,29 @@ class Query(object):
if prop.select_table not in currenttables or create_aliases or use_selectable:
if prop.secondary:
- if use_selectable:
+ if use_selectable or create_aliases:
alias = mapperutil.PropertyAliasedClauses(prop,
prop.primary_join_against(mapper, adapt_against),
- prop.secondary_join_against(mapper),
+ prop.secondary_join_against(mapper, toselectable=use_selectable),
alias,
alias=use_selectable
)
crit = alias.primaryjoin
clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
- elif create_aliases:
- alias = mapperutil.PropertyAliasedClauses(prop,
- prop.primary_join_against(mapper, adapt_against),
- prop.secondary_join_against(mapper),
- alias
- )
- crit = alias.primaryjoin
- clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
else:
crit = prop.primary_join_against(mapper, adapt_against)
clause = clause.join(prop.secondary, crit, isouter=outerjoin)
clause = clause.join(prop.select_table, prop.secondary_join_against(mapper), isouter=outerjoin)
else:
- if use_selectable:
+ if use_selectable or create_aliases:
alias = mapperutil.PropertyAliasedClauses(prop,
- prop.primary_join_against(mapper, adapt_against),
+ prop.primary_join_against(mapper, adapt_against, toselectable=use_selectable),
None,
alias,
alias=use_selectable
)
crit = alias.primaryjoin
clause = clause.join(alias.alias, crit, isouter=outerjoin)
- elif create_aliases:
- alias = mapperutil.PropertyAliasedClauses(prop,
- prop.primary_join_against(mapper, adapt_against),
- None,
- alias
- )
- crit = alias.primaryjoin
- clause = clause.join(alias.alias, crit, isouter=outerjoin)
else:
crit = prop.primary_join_against(mapper, adapt_against)
clause = clause.join(prop.select_table, crit, isouter=outerjoin)
@@ -492,7 +481,7 @@ class Query(object):
# does not use secondary tables
raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key)
- mapper = prop.mapper
+ mapper = of_type or prop.mapper
if use_selectable:
adapt_against = use_selectable
@@ -707,7 +696,7 @@ class Query(object):
q._adapter = sql_util.ClauseAdapter(q._from_obj, equivalents=q.mapper._equivalent_columns)
return q
-
+
def select_from(self, from_obj):
"""Set the `from_obj` parameter of the query and return the newly
resulting ``Query``. This replaces the table which this Query selects