summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-06-15 15:13:34 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-06-17 09:48:52 -0400
commit5b3e887f46afdbee312d5efd2a14f7c9b7eeac65 (patch)
tree7c12dd2686dc3d26222383d39527b24613e49da3 /lib/sqlalchemy/sql
parent29fbbd9cebf5d4a4f21d01a74bcfb6dce923fe1b (diff)
downloadsqlalchemy-5b3e887f46afdbee312d5efd2a14f7c9b7eeac65.tar.gz
memoize current options and joins w with_entities/with_only_cols
Fixed further regressions in the same area as that of :ticket:`6052` where loader options as well as invocations of methods like :meth:`_orm.Query.join` would fail if the left side of the statement for which the option/join depends upon were replaced by using the :meth:`_orm.Query.with_entities` method, or when using 2.0 style queries when using the :meth:`_sql.Select.with_only_columns` method. A new set of state has been added to the objects which tracks the "left" entities that the options / join were made against which is memoized when the lead entities are changed. Fixes: #6503 Fixes: #6253 Change-Id: I211b2af98b0b20d1263fb15dc513884dcc5de6a4
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/elements.py28
-rw-r--r--lib/sqlalchemy/sql/selectable.py86
-rw-r--r--lib/sqlalchemy/sql/traversals.py18
-rw-r--r--lib/sqlalchemy/sql/visitors.py35
4 files changed, 121 insertions, 46 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 213f47c40..709106b6b 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -32,7 +32,6 @@ from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .base import SingletonConstant
from .coercions import _document_text_coercion
-from .traversals import _get_children
from .traversals import HasCopyInternals
from .traversals import MemoizedHasCacheKey
from .traversals import NO_CACHE
@@ -389,33 +388,6 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def get_children(self, omit_attrs=(), **kw):
- r"""Return immediate child :class:`.visitors.Traversible`
- elements of this :class:`.visitors.Traversible`.
-
- This is used for visit traversal.
-
- \**kw may contain flags that change the collection that is
- returned, for example to return a subset of items in order to
- cut down on larger traversals, or to return child items from a
- different context (such as schema-level collections instead of
- clause-level).
-
- """
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return []
-
- return itertools.chain.from_iterable(
- meth(obj, **kw)
- for attrname, obj, meth in _get_children.run_generated_dispatch(
- self, traverse_internals, "_generated_get_children_traversal"
- )
- if attrname not in omit_attrs and obj is not None
- )
-
def self_group(self, against=None):
"""Apply a 'grouping' to this :class:`_expression.ClauseElement`.
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 1610191d1..e1dee091b 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -18,7 +18,9 @@ from operator import attrgetter
from . import coercions
from . import operators
from . import roles
+from . import traversals
from . import type_api
+from . import visitors
from .annotation import Annotated
from .annotation import SupportsCloneAnnotations
from .base import _clone
@@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState):
self.statement = statement
self.from_clauses = statement._from_obj
+ for memoized_entities in statement._memoized_select_entities:
+ self._setup_joins(
+ memoized_entities._setup_joins, memoized_entities._raw_columns
+ )
+
if statement._setup_joins:
- self._setup_joins(statement._setup_joins)
+ self._setup_joins(statement._setup_joins, statement._raw_columns)
self.froms = self._get_froms(statement)
@@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState):
def all_selected_columns(cls, statement):
return [c for c in _select_iterables(statement._raw_columns)]
- def _setup_joins(self, args):
+ def _setup_joins(self, args, raw_columns):
for (right, onclause, left, flags) in args:
isouter = flags["isouter"]
full = flags["full"]
@@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState):
left,
replace_from_obj_index,
) = self._join_determine_implicit_left_side(
- left, right, onclause
+ raw_columns, left, right, onclause
)
else:
(replace_from_obj_index) = self._join_place_explicit_left_side(
@@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState):
)
@util.preload_module("sqlalchemy.sql.util")
- def _join_determine_implicit_left_side(self, left, right, onclause):
+ def _join_determine_implicit_left_side(
+ self, raw_columns, left, right, onclause
+ ):
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState):
for from_clause in itertools.chain(
itertools.chain.from_iterable(
- [
- element._from_objects
- for element in statement._raw_columns
- ]
+ [element._from_objects for element in raw_columns]
),
itertools.chain.from_iterable(
[
@@ -4531,6 +4537,47 @@ class _SelectFromElements(object):
yield element
+class _MemoizedSelectEntities(
+ traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+):
+ __visit_name__ = "memoized_select_entities"
+
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_with_options", InternalTraversal.dp_executable_options),
+ ]
+
+ _annotations = util.EMPTY_DICT
+
+ def _clone(self, **kw):
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = {k: v for k, v in self.__dict__.items()}
+ c._is_clone_of = self
+ return c
+
+ @classmethod
+ def _generate_for_statement(cls, select_stmt):
+ if (
+ select_stmt._setup_joins
+ or select_stmt._legacy_setup_joins
+ or select_stmt._with_options
+ ):
+ self = _MemoizedSelectEntities()
+ self._raw_columns = select_stmt._raw_columns
+ self._setup_joins = select_stmt._setup_joins
+ self._legacy_setup_joins = select_stmt._legacy_setup_joins
+ self._with_options = select_stmt._with_options
+
+ select_stmt._memoized_select_entities += (self,)
+ select_stmt._raw_columns = (
+ select_stmt._setup_joins
+ ) = (
+ select_stmt._legacy_setup_joins
+ ) = select_stmt._with_options = ()
+
+
class Select(
HasPrefixes,
HasSuffixes,
@@ -4559,6 +4606,7 @@ class Select(
_setup_joins = ()
_legacy_setup_joins = ()
+ _memoized_select_entities = ()
_distinct = False
_distinct_on = ()
@@ -4574,6 +4622,10 @@ class Select(
_traverse_internals = (
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ (
+ "_memoized_select_entities",
+ InternalTraversal.dp_memoized_select_entities,
+ ),
("_from_obj", InternalTraversal.dp_clauseelement_list),
("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
@@ -5461,16 +5513,14 @@ class Select(
# is the case for now.
self._assert_no_memoizations()
- rc = []
- for c in coercions._expression_collection_was_a_list(
- "columns", "Select.with_only_columns", columns
- ):
- c = coercions.expect(roles.ColumnsClauseRole, c)
- # TODO: why are we doing this here?
- if isinstance(c, ScalarSelect):
- c = c.self_group(against=operators.comma_op)
- rc.append(c)
- self._raw_columns = rc
+ _MemoizedSelectEntities._generate_for_statement(self)
+
+ self._raw_columns = [
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in coercions._expression_collection_was_a_list(
+ "columns", "Select.with_only_columns", columns
+ )
+ ]
@property
def whereclause(self):
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 35f2bd62f..a86d16ef4 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -194,6 +194,8 @@ class HasCacheKey(object):
elif (
meth is InternalTraversal.dp_clauseelement_list
or meth is InternalTraversal.dp_clauseelement_tuple
+ or meth
+ is InternalTraversal.dp_memoized_select_entities
):
result += (
attrname,
@@ -409,6 +411,9 @@ class _CacheKey(ExtendedInternalTraversal):
visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
visit_annotations_key = InternalTraversal.dp_annotations_key
visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
+ visit_memoized_select_entities = (
+ InternalTraversal.dp_memoized_select_entities
+ )
visit_string = (
visit_boolean
@@ -799,6 +804,9 @@ class _CopyInternals(InternalTraversal):
for (target, onclause, from_, flags) in element
)
+ def visit_memoized_select_entities(self, attrname, parent, element, **kw):
+ return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
+
def visit_dml_ordered_values(
self, attrname, parent, element, clone=_clone, **kw
):
@@ -919,6 +927,9 @@ class _GetChildren(InternalTraversal):
if onclause is not None and not isinstance(onclause, str):
yield _flatten_clauseelement(onclause)
+ def visit_memoized_select_entities(self, element, **kw):
+ return self.visit_clauseelement_tuple(element, **kw)
+
def visit_dml_ordered_values(self, element, **kw):
for k, v in element:
if hasattr(k, "__clause_element__"):
@@ -1265,6 +1276,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((l_onclause, r_onclause))
self.stack.append((l_from, r_from))
+ def visit_memoized_select_entities(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.visit_clauseelement_tuple(
+ attrname, left_parent, left, right_parent, right, **kw
+ )
+
def visit_table_hint_list(
self, attrname, left_parent, left, right_parent, right, **kw
):
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 93ee8eb1c..c750c546a 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -24,6 +24,7 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
"""
from collections import deque
+import itertools
import operator
from .. import exc
@@ -119,6 +120,38 @@ class Traversible(util.with_metaclass(TraversibleType)):
"""
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(self, omit_attrs=(), **kw):
+ r"""Return immediate child :class:`.visitors.Traversible`
+ elements of this :class:`.visitors.Traversible`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
class _InternalTraversalType(type):
def __init__(cls, clsname, bases, clsdict):
@@ -393,6 +426,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
dp_setup_join_tuple = symbol("SJ")
+ dp_memoized_select_entities = symbol("ME")
+
dp_statement_hint_list = symbol("SH")
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`