diff options
Diffstat (limited to 'lib/sqlalchemy/sql/selectable.py')
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 86 |
1 files changed, 68 insertions, 18 deletions
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 1610191d1..e1dee091b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -18,7 +18,9 @@ from operator import attrgetter from . import coercions from . import operators from . import roles +from . import traversals from . import type_api +from . import visitors from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState): self.statement = statement self.from_clauses = statement._from_obj + for memoized_entities in statement._memoized_select_entities: + self._setup_joins( + memoized_entities._setup_joins, memoized_entities._raw_columns + ) + if statement._setup_joins: - self._setup_joins(statement._setup_joins) + self._setup_joins(statement._setup_joins, statement._raw_columns) self.froms = self._get_froms(statement) @@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState): def all_selected_columns(cls, statement): return [c for c in _select_iterables(statement._raw_columns)] - def _setup_joins(self, args): + def _setup_joins(self, args, raw_columns): for (right, onclause, left, flags) in args: isouter = flags["isouter"] full = flags["full"] @@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState): left, replace_from_obj_index, ) = self._join_determine_implicit_left_side( - left, right, onclause + raw_columns, left, right, onclause ) else: (replace_from_obj_index) = self._join_place_explicit_left_side( @@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState): ) @util.preload_module("sqlalchemy.sql.util") - def _join_determine_implicit_left_side(self, left, right, onclause): + def _join_determine_implicit_left_side( + self, raw_columns, left, right, onclause + ): """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState): for from_clause in itertools.chain( itertools.chain.from_iterable( - [ - element._from_objects - for element in statement._raw_columns - ] + [element._from_objects for element in raw_columns] ), itertools.chain.from_iterable( [ @@ -4531,6 +4537,47 @@ class _SelectFromElements(object): yield element +class _MemoizedSelectEntities( + traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible +): + __visit_name__ = "memoized_select_entities" + + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_with_options", InternalTraversal.dp_executable_options), + ] + + _annotations = util.EMPTY_DICT + + def _clone(self, **kw): + c = self.__class__.__new__(self.__class__) + c.__dict__ = {k: v for k, v in self.__dict__.items()} + c._is_clone_of = self + return c + + @classmethod + def _generate_for_statement(cls, select_stmt): + if ( + select_stmt._setup_joins + or select_stmt._legacy_setup_joins + or select_stmt._with_options + ): + self = _MemoizedSelectEntities() + self._raw_columns = select_stmt._raw_columns + self._setup_joins = select_stmt._setup_joins + self._legacy_setup_joins = select_stmt._legacy_setup_joins + self._with_options = select_stmt._with_options + + select_stmt._memoized_select_entities += (self,) + select_stmt._raw_columns = ( + select_stmt._setup_joins + ) = ( + select_stmt._legacy_setup_joins + ) = select_stmt._with_options = () + + class Select( HasPrefixes, HasSuffixes, @@ -4559,6 +4606,7 @@ class Select( _setup_joins = () _legacy_setup_joins = () + _memoized_select_entities = () _distinct = False _distinct_on = () @@ -4574,6 +4622,10 @@ class Select( _traverse_internals = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ( + "_memoized_select_entities", + InternalTraversal.dp_memoized_select_entities, + ), ("_from_obj", InternalTraversal.dp_clauseelement_list), ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), ("_having_criteria", InternalTraversal.dp_clauseelement_tuple), @@ -5461,16 +5513,14 @@ class Select( # is the case for now. self._assert_no_memoizations() - rc = [] - for c in coercions._expression_collection_was_a_list( - "columns", "Select.with_only_columns", columns - ): - c = coercions.expect(roles.ColumnsClauseRole, c) - # TODO: why are we doing this here? - if isinstance(c, ScalarSelect): - c = c.self_group(against=operators.comma_op) - rc.append(c) - self._raw_columns = rc + _MemoizedSelectEntities._generate_for_statement(self) + + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, c) + for c in coercions._expression_collection_was_a_list( + "columns", "Select.with_only_columns", columns + ) + ] @property def whereclause(self): |
