summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm')
-rw-r--r--lib/sqlalchemy/orm/attributes.py2
-rw-r--r--lib/sqlalchemy/orm/context.py48
-rw-r--r--lib/sqlalchemy/orm/query.py12
-rw-r--r--lib/sqlalchemy/orm/strategies.py10
4 files changed, 56 insertions, 16 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index bf07061c6..6dd95a5a9 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -59,6 +59,8 @@ class QueryableAttribute(
interfaces.InspectionAttr,
interfaces.PropComparator,
roles.JoinTargetRole,
+ roles.OnClauseRole,
+ sql_base.Immutable,
sql_base.MemoizedHasCacheKey,
):
"""Base class for :term:`descriptor` objects that intercept
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 3a0cce609..09163d4e9 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
from . import attributes
from . import interfaces
from . import loading
@@ -664,10 +663,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self._aliased_generations = {}
self._polymorphic_adapters = {}
+ compile_options = cls.default_compile_options.safe_merge(
+ query.compile_options
+ )
# legacy: only for query.with_polymorphic()
- if query.compile_options._with_polymorphic_adapt_map:
+ if compile_options._with_polymorphic_adapt_map:
self._with_polymorphic_adapt_map = dict(
- query.compile_options._with_polymorphic_adapt_map
+ compile_options._with_polymorphic_adapt_map
)
self._setup_with_polymorphics()
@@ -1065,6 +1067,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# maybe?
self._reset_joinpoint()
+ right = inspect(right)
+ if onclause is not None:
+ onclause = inspect(onclause)
+
if onclause is None and isinstance(
right, interfaces.PropComparator
):
@@ -1084,23 +1090,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
onclause = right
right = None
elif "parententity" in right._annotations:
- right = right._annotations["parententity"].entity
+ right = right._annotations["parententity"]
if onclause is None:
- r_info = inspect(right)
- if not r_info.is_selectable and not hasattr(r_info, "mapper"):
+ if not right.is_selectable and not hasattr(right, "mapper"):
raise sa_exc.ArgumentError(
"Expected mapped entity or "
"selectable/table as join target"
)
- if isinstance(onclause, interfaces.PropComparator):
- of_type = getattr(onclause, "_of_type", None)
- else:
- of_type = None
+
+ of_type = None
if isinstance(onclause, interfaces.PropComparator):
# descriptor/property given (or determined); this tells us
# explicitly what the expected "left" side of the join is.
+
+ of_type = getattr(onclause, "_of_type", None)
+
if right is None:
if of_type:
right = of_type
@@ -1164,6 +1170,14 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
full = flags["full"]
aliased_generation = flags["aliased_generation"]
+ # do a quick inspect to accommodate for a lambda
+ if right is not None and not isinstance(right, util.string_types):
+ right = inspect(right)
+ if onclause is not None and not isinstance(
+ onclause, util.string_types
+ ):
+ onclause = inspect(onclause)
+
# legacy vvvvvvvvvvvvvvvvvvvvvvvvvv
if not from_joinpoint:
self._reset_joinpoint()
@@ -1190,11 +1204,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
onclause = right
right = None
elif "parententity" in right._annotations:
- right = right._annotations["parententity"].entity
+ right = right._annotations["parententity"]
if onclause is None:
- r_info = inspect(right)
- if not r_info.is_selectable and not hasattr(r_info, "mapper"):
+ if not right.is_selectable and not hasattr(right, "mapper"):
raise sa_exc.ArgumentError(
"Expected mapped entity or "
"selectable/table as join target"
@@ -1379,7 +1392,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.from_clauses = self.from_clauses + [
orm_join(
- left_clause, right, onclause, isouter=outerjoin, full=full
+ left_clause, r_info, onclause, isouter=outerjoin, full=full
)
]
@@ -1964,6 +1977,13 @@ class _QueryEntity(object):
@classmethod
def to_compile_state(cls, compile_state, entities):
for entity in entities:
+ if entity._is_lambda_element:
+ if entity._is_sequence:
+ cls.to_compile_state(compile_state, entity._resolved)
+ continue
+ else:
+ entity = entity._resolved
+
if entity.is_clause_element:
if entity.is_selectable:
if "parententity" in entity._annotations:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 336b7d9aa..1ca65c733 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -20,6 +20,7 @@ database to return iterable result sets.
"""
import itertools
import operator
+import types
from . import attributes
from . import exc as orm_exc
@@ -2229,7 +2230,8 @@ class Query(
# non legacy argument form
_props = [(target,)]
elif not legacy and isinstance(
- target, (expression.Selectable, type, AliasedClass,)
+ target,
+ (expression.Selectable, type, AliasedClass, types.FunctionType),
):
# non legacy argument form
_props = [(target, onclause)]
@@ -2284,7 +2286,13 @@ class Query(
legacy=True,
apply_propagate_attrs=self,
),
- prop[1] if len(prop) == 2 else None,
+ (
+ coercions.expect(roles.OnClauseRole, prop[1])
+ if not isinstance(prop[1], str)
+ else prop[1]
+ )
+ if len(prop) == 2
+ else None,
None,
{
"isouter": isouter,
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 5f039aff7..53cc99ccd 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -1524,6 +1524,16 @@ class SubqueryLoader(PostLoader):
# orig_compile_state = compile_state_cls.create_for_statement(
# orig_query, None)
+ if orig_query._is_lambda_element:
+ util.warn(
+ 'subqueryloader for "%s" must invoke lambda callable at %r in '
+ "order to produce a new query, decreasing the efficiency "
+ "of caching for this statement. Consider using "
+ "selectinload() for more effective full-lambda caching"
+ % (self, orig_query)
+ )
+ orig_query = orig_query._resolved
+
# this is the more "quick" version, however it's not clear how
# much of this we need. in particular I can't get a test to
# fail if the "set_base_alias" is missing and not sure why that is.