diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-06-15 15:13:34 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-06-17 09:48:52 -0400 |
| commit | 5b3e887f46afdbee312d5efd2a14f7c9b7eeac65 (patch) | |
| tree | 7c12dd2686dc3d26222383d39527b24613e49da3 /lib/sqlalchemy/sql | |
| parent | 29fbbd9cebf5d4a4f21d01a74bcfb6dce923fe1b (diff) | |
| download | sqlalchemy-5b3e887f46afdbee312d5efd2a14f7c9b7eeac65.tar.gz | |
memoize current options and joins w with_entities/with_only_cols
Fixed further regressions in the same area as that of :ticket:`6052` where
loader options as well as invocations of methods like
:meth:`_orm.Query.join` would fail if the left side of the statement for
which the option/join depends upon were replaced by using the
:meth:`_orm.Query.with_entities` method, or when using 2.0 style queries
when using the :meth:`_sql.Select.with_only_columns` method. A new set of
state has been added to the objects which tracks the "left" entities that
the options / join were made against which is memoized when the lead
entities are changed.
Fixes: #6503
Fixes: #6253
Change-Id: I211b2af98b0b20d1263fb15dc513884dcc5de6a4
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 86 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 35 |
4 files changed, 121 insertions, 46 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 213f47c40..709106b6b 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -32,7 +32,6 @@ from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .base import SingletonConstant from .coercions import _document_text_coercion -from .traversals import _get_children from .traversals import HasCopyInternals from .traversals import MemoizedHasCacheKey from .traversals import NO_CACHE @@ -389,33 +388,6 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def get_children(self, omit_attrs=(), **kw): - r"""Return immediate child :class:`.visitors.Traversible` - elements of this :class:`.visitors.Traversible`. - - This is used for visit traversal. - - \**kw may contain flags that change the collection that is - returned, for example to return a subset of items in order to - cut down on larger traversals, or to return child items from a - different context (such as schema-level collections instead of - clause-level). - - """ - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return [] - - return itertools.chain.from_iterable( - meth(obj, **kw) - for attrname, obj, meth in _get_children.run_generated_dispatch( - self, traverse_internals, "_generated_get_children_traversal" - ) - if attrname not in omit_attrs and obj is not None - ) - def self_group(self, against=None): """Apply a 'grouping' to this :class:`_expression.ClauseElement`. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 1610191d1..e1dee091b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -18,7 +18,9 @@ from operator import attrgetter from . import coercions from . import operators from . import roles +from . import traversals from . import type_api +from . import visitors from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState): self.statement = statement self.from_clauses = statement._from_obj + for memoized_entities in statement._memoized_select_entities: + self._setup_joins( + memoized_entities._setup_joins, memoized_entities._raw_columns + ) + if statement._setup_joins: - self._setup_joins(statement._setup_joins) + self._setup_joins(statement._setup_joins, statement._raw_columns) self.froms = self._get_froms(statement) @@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState): def all_selected_columns(cls, statement): return [c for c in _select_iterables(statement._raw_columns)] - def _setup_joins(self, args): + def _setup_joins(self, args, raw_columns): for (right, onclause, left, flags) in args: isouter = flags["isouter"] full = flags["full"] @@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState): left, replace_from_obj_index, ) = self._join_determine_implicit_left_side( - left, right, onclause + raw_columns, left, right, onclause ) else: (replace_from_obj_index) = self._join_place_explicit_left_side( @@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState): ) @util.preload_module("sqlalchemy.sql.util") - def _join_determine_implicit_left_side(self, left, right, onclause): + def _join_determine_implicit_left_side( + self, raw_columns, left, right, onclause + ): """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState): for from_clause in itertools.chain( itertools.chain.from_iterable( - [ - element._from_objects - for element in statement._raw_columns - ] + [element._from_objects for element in raw_columns] ), itertools.chain.from_iterable( [ @@ -4531,6 +4537,47 @@ class _SelectFromElements(object): yield element +class _MemoizedSelectEntities( + traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible +): + __visit_name__ = "memoized_select_entities" + + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_with_options", InternalTraversal.dp_executable_options), + ] + + _annotations = util.EMPTY_DICT + + def _clone(self, **kw): + c = self.__class__.__new__(self.__class__) + c.__dict__ = {k: v for k, v in self.__dict__.items()} + c._is_clone_of = self + return c + + @classmethod + def _generate_for_statement(cls, select_stmt): + if ( + select_stmt._setup_joins + or select_stmt._legacy_setup_joins + or select_stmt._with_options + ): + self = _MemoizedSelectEntities() + self._raw_columns = select_stmt._raw_columns + self._setup_joins = select_stmt._setup_joins + self._legacy_setup_joins = select_stmt._legacy_setup_joins + self._with_options = select_stmt._with_options + + select_stmt._memoized_select_entities += (self,) + select_stmt._raw_columns = ( + select_stmt._setup_joins + ) = ( + select_stmt._legacy_setup_joins + ) = select_stmt._with_options = () + + class Select( HasPrefixes, HasSuffixes, @@ -4559,6 +4606,7 @@ class Select( _setup_joins = () _legacy_setup_joins = () + _memoized_select_entities = () _distinct = False _distinct_on = () @@ -4574,6 +4622,10 @@ class Select( _traverse_internals = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ( + "_memoized_select_entities", + InternalTraversal.dp_memoized_select_entities, + ), ("_from_obj", InternalTraversal.dp_clauseelement_list), ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), ("_having_criteria", InternalTraversal.dp_clauseelement_tuple), @@ -5461,16 +5513,14 @@ class Select( # is the case for now. self._assert_no_memoizations() - rc = [] - for c in coercions._expression_collection_was_a_list( - "columns", "Select.with_only_columns", columns - ): - c = coercions.expect(roles.ColumnsClauseRole, c) - # TODO: why are we doing this here? - if isinstance(c, ScalarSelect): - c = c.self_group(against=operators.comma_op) - rc.append(c) - self._raw_columns = rc + _MemoizedSelectEntities._generate_for_statement(self) + + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, c) + for c in coercions._expression_collection_was_a_list( + "columns", "Select.with_only_columns", columns + ) + ] @property def whereclause(self): diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 35f2bd62f..a86d16ef4 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -194,6 +194,8 @@ class HasCacheKey(object): elif ( meth is InternalTraversal.dp_clauseelement_list or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities ): result += ( attrname, @@ -409,6 +411,9 @@ class _CacheKey(ExtendedInternalTraversal): visit_clauseelement_list = InternalTraversal.dp_clauseelement_list visit_annotations_key = InternalTraversal.dp_annotations_key visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) visit_string = ( visit_boolean @@ -799,6 +804,9 @@ class _CopyInternals(InternalTraversal): for (target, onclause, from_, flags) in element ) + def visit_memoized_select_entities(self, attrname, parent, element, **kw): + return self.visit_clauseelement_tuple(attrname, parent, element, **kw) + def visit_dml_ordered_values( self, attrname, parent, element, clone=_clone, **kw ): @@ -919,6 +927,9 @@ class _GetChildren(InternalTraversal): if onclause is not None and not isinstance(onclause, str): yield _flatten_clauseelement(onclause) + def visit_memoized_select_entities(self, element, **kw): + return self.visit_clauseelement_tuple(element, **kw) + def visit_dml_ordered_values(self, element, **kw): for k, v in element: if hasattr(k, "__clause_element__"): @@ -1265,6 +1276,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((l_onclause, r_onclause)) self.stack.append((l_from, r_from)) + def visit_memoized_select_entities( + self, attrname, left_parent, left, right_parent, right, **kw + ): + return self.visit_clauseelement_tuple( + attrname, left_parent, left, right_parent, right, **kw + ) + def visit_table_hint_list( self, attrname, left_parent, left, right_parent, right, **kw ): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 93ee8eb1c..c750c546a 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -24,6 +24,7 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ . """ from collections import deque +import itertools import operator from .. import exc @@ -119,6 +120,38 @@ class Traversible(util.with_metaclass(TraversibleType)): """ + @util.preload_module("sqlalchemy.sql.traversals") + def get_children(self, omit_attrs=(), **kw): + r"""Return immediate child :class:`.visitors.Traversible` + elements of this :class:`.visitors.Traversible`. + + This is used for visit traversal. + + \**kw may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + + traversals = util.preloaded.sql_traversals + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return [] + + dispatch = traversals._get_children.run_generated_dispatch + return itertools.chain.from_iterable( + meth(obj, **kw) + for attrname, obj, meth in dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ) + if attrname not in omit_attrs and obj is not None + ) + class _InternalTraversalType(type): def __init__(cls, clsname, bases, clsdict): @@ -393,6 +426,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): dp_setup_join_tuple = symbol("SJ") + dp_memoized_select_entities = symbol("ME") + dp_statement_hint_list = symbol("SH") """Visit the ``_statement_hints`` collection of a :class:`_expression.Select` |
