summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py2
-rw-r--r--lib/sqlalchemy/engine/base.py105
-rw-r--r--lib/sqlalchemy/engine/create.py13
-rw-r--r--lib/sqlalchemy/engine/cursor.py89
-rw-r--r--lib/sqlalchemy/engine/default.py15
-rw-r--r--lib/sqlalchemy/ext/baked.py1
-rw-r--r--lib/sqlalchemy/future/selectable.py1
-rw-r--r--lib/sqlalchemy/orm/attributes.py30
-rw-r--r--lib/sqlalchemy/orm/context.py37
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py1
-rw-r--r--lib/sqlalchemy/orm/interfaces.py32
-rw-r--r--lib/sqlalchemy/orm/loading.py11
-rw-r--r--lib/sqlalchemy/orm/mapper.py2
-rw-r--r--lib/sqlalchemy/orm/path_registry.py8
-rw-r--r--lib/sqlalchemy/orm/persistence.py9
-rw-r--r--lib/sqlalchemy/orm/properties.py1
-rw-r--r--lib/sqlalchemy/orm/query.py20
-rw-r--r--lib/sqlalchemy/orm/relationships.py1
-rw-r--r--lib/sqlalchemy/orm/strategies.py500
-rw-r--r--lib/sqlalchemy/orm/util.py8
-rw-r--r--lib/sqlalchemy/sql/__init__.py3
-rw-r--r--lib/sqlalchemy/sql/annotation.py9
-rw-r--r--lib/sqlalchemy/sql/base.py7
-rw-r--r--lib/sqlalchemy/sql/compiler.py54
-rw-r--r--lib/sqlalchemy/sql/ddl.py3
-rw-r--r--lib/sqlalchemy/sql/dml.py9
-rw-r--r--lib/sqlalchemy/sql/elements.py138
-rw-r--r--lib/sqlalchemy/sql/functions.py39
-rw-r--r--lib/sqlalchemy/sql/schema.py2
-rw-r--r--lib/sqlalchemy/sql/selectable.py32
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py24
-rw-r--r--lib/sqlalchemy/sql/traversals.py129
-rw-r--r--lib/sqlalchemy/sql/type_api.py7
-rw-r--r--lib/sqlalchemy/sql/visitors.py19
-rw-r--r--lib/sqlalchemy/testing/assertsql.py6
35 files changed, 920 insertions, 447 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 441e77a37..24e2d13d8 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -3681,7 +3681,7 @@ class PGDialect(default.DefaultDialect):
WHERE t.typtype = 'd'
"""
- s = sql.text(SQL_DOMAINS).columns(attname=sqltypes.Unicode)
+ s = sql.text(SQL_DOMAINS)
c = connection.execution_options(future_result=True).execute(s)
domains = {}
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index a36f4eee2..3e02a29fe 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1175,46 +1175,17 @@ class Connection(Connectable):
)
compiled_cache = execution_options.get(
- "compiled_cache", self.dialect._compiled_cache
+ "compiled_cache", self.engine._compiled_cache
)
- if compiled_cache is not None:
- elem_cache_key = elem._generate_cache_key()
- else:
- elem_cache_key = None
-
- if elem_cache_key:
- cache_key, extracted_params = elem_cache_key
- key = (
- dialect,
- cache_key,
- tuple(keys),
- bool(schema_translate_map),
- inline,
- )
- compiled_sql = compiled_cache.get(key)
-
- if compiled_sql is None:
- compiled_sql = elem.compile(
- dialect=dialect,
- cache_key=elem_cache_key,
- column_keys=keys,
- inline=inline,
- schema_translate_map=schema_translate_map,
- linting=self.dialect.compiler_linting
- | compiler.WARN_LINTING,
- )
- compiled_cache[key] = compiled_sql
- else:
- extracted_params = None
- compiled_sql = elem.compile(
- dialect=dialect,
- column_keys=keys,
- inline=inline,
- schema_translate_map=schema_translate_map,
- linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
- )
-
+ compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
+ dialect=dialect,
+ compiled_cache=compiled_cache,
+ column_keys=keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+ )
ret = self._execute_context(
dialect,
dialect.execution_ctx_cls._init_compiled,
@@ -1225,6 +1196,7 @@ class Connection(Connectable):
distilled_params,
elem,
extracted_params,
+ cache_hit=cache_hit,
)
if has_events:
self.dispatch.after_execute(
@@ -1389,7 +1361,8 @@ class Connection(Connectable):
statement,
parameters,
execution_options,
- *args
+ *args,
+ **kw
):
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`_engine.CursorResult`."""
@@ -1407,7 +1380,7 @@ class Connection(Connectable):
conn = self._revalidate_connection()
context = constructor(
- dialect, self, conn, execution_options, *args
+ dialect, self, conn, execution_options, *args, **kw
)
except (exc.PendingRollbackError, exc.ResourceClosedError):
raise
@@ -1455,32 +1428,21 @@ class Connection(Connectable):
self.engine.logger.info(statement)
- # stats = context._get_cache_stats()
+ stats = context._get_cache_stats()
if not self.engine.hide_parameters:
- # TODO: I love the stats but a ton of tests that are hardcoded.
- # to certain log output are failing.
self.engine.logger.info(
- "%r",
+ "[%s] %r",
+ stats,
sql_util._repr_params(
parameters, batches=10, ismulti=context.executemany
),
)
- # self.engine.logger.info(
- # "[%s] %r",
- # stats,
- # sql_util._repr_params(
- # parameters, batches=10, ismulti=context.executemany
- # ),
- # )
else:
self.engine.logger.info(
- "[SQL parameters hidden due to hide_parameters=True]"
+ "[%s] [SQL parameters hidden due to hide_parameters=True]"
+ % (stats,)
)
- # self.engine.logger.info(
- # "[%s] [SQL parameters hidden due to hide_parameters=True]"
- # % (stats,)
- # )
evt_handled = False
try:
@@ -2369,6 +2331,7 @@ class Engine(Connectable, log.Identified):
url,
logging_name=None,
echo=None,
+ query_cache_size=500,
execution_options=None,
hide_parameters=False,
):
@@ -2379,14 +2342,43 @@ class Engine(Connectable, log.Identified):
self.logging_name = logging_name
self.echo = echo
self.hide_parameters = hide_parameters
+ if query_cache_size != 0:
+ self._compiled_cache = util.LRUCache(
+ query_cache_size, size_alert=self._lru_size_alert
+ )
+ else:
+ self._compiled_cache = None
log.instance_logger(self, echoflag=echo)
if execution_options:
self.update_execution_options(**execution_options)
+ def _lru_size_alert(self, cache):
+ if self._should_log_info:
+ self.logger.info(
+ "Compiled cache size pruning from %d items to %d. "
+ "Increase cache size to reduce the frequency of pruning.",
+ len(cache),
+ cache.capacity,
+ )
+
@property
def engine(self):
return self
+ def clear_compiled_cache(self):
+ """Clear the compiled cache associated with the dialect.
+
+ This applies **only** to the built-in cache that is established
+ via the :paramref:`.create_engine.query_cache_size` parameter.
+ It will not impact any dictionary caches that were passed via the
+ :paramref:`.Connection.execution_options.query_cache` parameter.
+
+ .. versionadded:: 1.4
+
+ """
+ if self._compiled_cache:
+ self._compiled_cache.clear()
+
def update_execution_options(self, **opt):
r"""Update the default execution_options dictionary
of this :class:`_engine.Engine`.
@@ -2874,6 +2866,7 @@ class OptionEngineMixin(object):
self.dialect = proxied.dialect
self.logging_name = proxied.logging_name
self.echo = proxied.echo
+ self._compiled_cache = proxied._compiled_cache
self.hide_parameters = proxied.hide_parameters
log.instance_logger(self, echoflag=self.echo)
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
index 4c912349e..9bf72eb06 100644
--- a/lib/sqlalchemy/engine/create.py
+++ b/lib/sqlalchemy/engine/create.py
@@ -436,7 +436,13 @@ def create_engine(url, **kwargs):
.. versionadded:: 1.2.3
:param query_cache_size: size of the cache used to cache the SQL string
- form of queries. Defaults to zero, which disables caching.
+ form of queries. Set to zero to disable caching.
+
+ The cache is pruned of its least recently used items when its size reaches
+ N * 1.5. Defaults to 500, meaning the cache will always store at least
+ 500 SQL statements when filled, and will grow up to 750 items at which
+ point it is pruned back down to 500 by removing the 250 least recently
+ used items.
Caching is accomplished on a per-statement basis by generating a
cache key that represents the statement's structure, then generating
@@ -446,6 +452,11 @@ def create_engine(url, **kwargs):
bypass the cache. SQL logging will indicate statistics for each
statement whether or not it were pull from the cache.
+ .. note:: some ORM functions related to unit-of-work persistence as well
+ as some attribute loading strategies will make use of individual
+ per-mapper caches outside of the main cache.
+
+
.. seealso::
``engine_caching`` - TODO: this will be an upcoming section describing
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index d03d79df7..abffe0d1f 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -51,6 +51,7 @@ class CursorResultMetaData(ResultMetaData):
"_keys",
"_tuplefilter",
"_translated_indexes",
+ "_safe_for_cache"
# don't need _unique_filters support here for now. Can be added
# if a need arises.
)
@@ -104,11 +105,11 @@ class CursorResultMetaData(ResultMetaData):
return new_metadata
def _adapt_to_context(self, context):
- """When using a cached result metadata against a new context,
- we need to rewrite the _keymap so that it has the specific
- Column objects in the new context inside of it. this accommodates
- for select() constructs that contain anonymized columns and
- are cached.
+ """When using a cached Compiled construct that has a _result_map,
+ for a new statement that used the cached Compiled, we need to ensure
+ the keymap has the Column objects from our new statement as keys.
+ So here we rewrite keymap with new entries for the new columns
+ as matched to those of the cached statement.
"""
if not context.compiled._result_columns:
@@ -124,14 +125,15 @@ class CursorResultMetaData(ResultMetaData):
# to the result map.
md = self.__class__.__new__(self.__class__)
- md._keymap = self._keymap.copy()
+ md._keymap = dict(self._keymap)
# match up new columns positionally to the result columns
for existing, new in zip(
context.compiled._result_columns,
invoked_statement._exported_columns_iterator(),
):
- md._keymap[new] = md._keymap[existing[RM_NAME]]
+ if existing[RM_NAME] in md._keymap:
+ md._keymap[new] = md._keymap[existing[RM_NAME]]
md.case_sensitive = self.case_sensitive
md._processors = self._processors
@@ -147,6 +149,7 @@ class CursorResultMetaData(ResultMetaData):
self._tuplefilter = None
self._translated_indexes = None
self.case_sensitive = dialect.case_sensitive
+ self._safe_for_cache = False
if context.result_column_struct:
(
@@ -341,6 +344,10 @@ class CursorResultMetaData(ResultMetaData):
self._keys = [elem[0] for elem in result_columns]
# pure positional 1-1 case; doesn't need to read
# the names from cursor.description
+
+ # this metadata is safe to cache because we are guaranteed
+ # to have the columns in the same order for new executions
+ self._safe_for_cache = True
return [
(
idx,
@@ -359,9 +366,12 @@ class CursorResultMetaData(ResultMetaData):
for idx, rmap_entry in enumerate(result_columns)
]
else:
+
# name-based or text-positional cases, where we need
# to read cursor.description names
+
if textual_ordered:
+ self._safe_for_cache = True
# textual positional case
raw_iterator = self._merge_textual_cols_by_position(
context, cursor_description, result_columns
@@ -369,6 +379,9 @@ class CursorResultMetaData(ResultMetaData):
elif num_ctx_cols:
# compiled SQL with a mismatch of description cols
# vs. compiled cols, or textual w/ unordered columns
+ # the order of columns can change if the query is
+ # against a "select *", so not safe to cache
+ self._safe_for_cache = False
raw_iterator = self._merge_cols_by_name(
context,
cursor_description,
@@ -376,7 +389,9 @@ class CursorResultMetaData(ResultMetaData):
loose_column_name_matching,
)
else:
- # no compiled SQL, just a raw string
+ # no compiled SQL, just a raw string, order of columns
+ # can change for "select *"
+ self._safe_for_cache = False
raw_iterator = self._merge_cols_by_none(
context, cursor_description
)
@@ -1152,7 +1167,6 @@ class BaseCursorResult(object):
out_parameters = None
_metadata = None
- _metadata_from_cache = False
_soft_closed = False
closed = False
@@ -1209,33 +1223,38 @@ class BaseCursorResult(object):
def _init_metadata(self, context, cursor_description):
if context.compiled:
if context.compiled._cached_metadata:
- cached_md = self.context.compiled._cached_metadata
- self._metadata_from_cache = True
-
- # result rewrite/ adapt step. two translations can occur here.
- # one is if we are invoked against a cached statement, we want
- # to rewrite the ResultMetaData to reflect the column objects
- # that are in our current selectable, not the cached one. the
- # other is, the CompileState can return an alternative Result
- # object. Finally, CompileState might want to tell us to not
- # actually do the ResultMetaData adapt step if it in fact has
- # changed the selected columns in any case.
- compiled = context.compiled
- if (
- compiled
- and not compiled._rewrites_selected_columns
- and compiled.statement is not context.invoked_statement
- ):
- cached_md = cached_md._adapt_to_context(context)
-
- self._metadata = metadata = cached_md
-
+ metadata = self.context.compiled._cached_metadata
else:
- self._metadata = (
- metadata
- ) = context.compiled._cached_metadata = self._cursor_metadata(
- self, cursor_description
- )
+ metadata = self._cursor_metadata(self, cursor_description)
+ if metadata._safe_for_cache:
+ context.compiled._cached_metadata = metadata
+
+ # result rewrite/ adapt step. this is to suit the case
+ # when we are invoked against a cached Compiled object, we want
+ # to rewrite the ResultMetaData to reflect the Column objects
+ # that are in our current SQL statement object, not the one
+ # that is associated with the cached Compiled object.
+ # the Compiled object may also tell us to not
+ # actually do this step; this is to support the ORM where
+ # it is to produce a new Result object in any case, and will
+ # be using the cached Column objects against this database result
+ # so we don't want to rewrite them.
+ #
+ # Basically this step suits the use case where the end user
+ # is using Core SQL expressions and is accessing columns in the
+ # result row using row._mapping[table.c.column].
+ compiled = context.compiled
+ if (
+ compiled
+ and compiled._result_columns
+ and context.cache_hit
+ and not compiled._rewrites_selected_columns
+ and compiled.statement is not context.invoked_statement
+ ):
+ metadata = metadata._adapt_to_context(context)
+
+ self._metadata = metadata
+
else:
self._metadata = metadata = self._cursor_metadata(
self, cursor_description
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index c682a8ee1..4d516e97c 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -230,7 +230,6 @@ class DefaultDialect(interfaces.Dialect):
supports_native_boolean=None,
max_identifier_length=None,
label_length=None,
- query_cache_size=0,
# int() is because the @deprecated_params decorator cannot accommodate
# the direct reference to the "NO_LINTING" object
compiler_linting=int(compiler.NO_LINTING),
@@ -262,10 +261,6 @@ class DefaultDialect(interfaces.Dialect):
if supports_native_boolean is not None:
self.supports_native_boolean = supports_native_boolean
self.case_sensitive = case_sensitive
- if query_cache_size != 0:
- self._compiled_cache = util.LRUCache(query_cache_size)
- else:
- self._compiled_cache = None
self._user_defined_max_identifier_length = max_identifier_length
if self._user_defined_max_identifier_length:
@@ -794,6 +789,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
parameters,
invoked_statement,
extracted_parameters,
+ cache_hit=False,
):
"""Initialize execution context for a Compiled construct."""
@@ -804,6 +800,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.extracted_parameters = extracted_parameters
self.invoked_statement = invoked_statement
self.compiled = compiled
+ self.cache_hit = cache_hit
self.execution_options = execution_options
@@ -1027,13 +1024,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
def _get_cache_stats(self):
if self.compiled is None:
- return "raw SQL"
+ return "raw sql"
now = time.time()
if self.compiled.cache_key is None:
- return "gen %.5fs" % (now - self.compiled._gen_time,)
+ return "no key %.5fs" % (now - self.compiled._gen_time,)
+ elif self.cache_hit:
+ return "cached for %.4gs" % (now - self.compiled._gen_time,)
else:
- return "cached %.5fs" % (now - self.compiled._gen_time,)
+ return "generated in %.5fs" % (now - self.compiled._gen_time,)
@util.memoized_property
def engine(self):
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
index f95a30fda..4f40637c5 100644
--- a/lib/sqlalchemy/ext/baked.py
+++ b/lib/sqlalchemy/ext/baked.py
@@ -412,7 +412,6 @@ class Result(object):
result = self.session.execute(
statement, params, execution_options=execution_options
)
-
if result._attributes.get("is_single_entity", False):
result = result.scalars()
diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py
index 407ec9633..53fc7c107 100644
--- a/lib/sqlalchemy/future/selectable.py
+++ b/lib/sqlalchemy/future/selectable.py
@@ -11,6 +11,7 @@ class Select(_LegacySelect):
_is_future = True
_setup_joins = ()
_legacy_setup_joins = ()
+ inherit_cache = True
@classmethod
def _create_select(cls, *entities):
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 262a1efc9..bf07061c6 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -85,16 +85,16 @@ class QueryableAttribute(
self,
class_,
key,
+ parententity,
impl=None,
comparator=None,
- parententity=None,
of_type=None,
):
self.class_ = class_
self.key = key
+ self._parententity = parententity
self.impl = impl
self.comparator = comparator
- self._parententity = parententity
self._of_type = of_type
manager = manager_of_class(class_)
@@ -197,10 +197,14 @@ class QueryableAttribute(
@util.memoized_property
def expression(self):
return self.comparator.__clause_element__()._annotate(
- {"orm_key": self.key}
+ {"orm_key": self.key, "entity_namespace": self._entity_namespace}
)
@property
+ def _entity_namespace(self):
+ return self._parententity
+
+ @property
def _annotations(self):
return self.__clause_element__()._annotations
@@ -230,9 +234,9 @@ class QueryableAttribute(
return QueryableAttribute(
self.class_,
self.key,
- self.impl,
- self.comparator.of_type(entity),
self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.of_type(entity),
of_type=inspection.inspect(entity),
)
@@ -301,6 +305,8 @@ class InstrumentedAttribute(QueryableAttribute):
"""
+ inherit_cache = True
+
def __set__(self, instance, value):
self.impl.set(
instance_state(instance), instance_dict(instance), value, None
@@ -320,6 +326,11 @@ class InstrumentedAttribute(QueryableAttribute):
return self.impl.get(instance_state(instance), dict_)
+HasEntityNamespace = util.namedtuple(
+ "HasEntityNamespace", ["entity_namespace"]
+)
+
+
def create_proxied_attribute(descriptor):
"""Create an QueryableAttribute / user descriptor hybrid.
@@ -365,6 +376,15 @@ def create_proxied_attribute(descriptor):
)
@property
+ def _entity_namespace(self):
+ if hasattr(self._comparator, "_parententity"):
+ return self._comparator._parententity
+ else:
+ # used by hybrid attributes which try to remain
+ # agnostic of any ORM concepts like mappers
+ return HasEntityNamespace(self.class_)
+
+ @property
def property(self):
return self.comparator.property
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index a16db66f6..588b83571 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -63,6 +63,8 @@ class QueryContext(object):
"post_load_paths",
"identity_token",
"yield_per",
+ "loaders_require_buffering",
+ "loaders_require_uniquing",
)
class default_load_options(Options):
@@ -80,21 +82,23 @@ class QueryContext(object):
def __init__(
self,
compile_state,
+ statement,
session,
load_options,
execution_options=None,
bind_arguments=None,
):
-
self.load_options = load_options
self.execution_options = execution_options or _EMPTY_DICT
self.bind_arguments = bind_arguments or _EMPTY_DICT
self.compile_state = compile_state
- self.query = query = compile_state.select_statement
+ self.query = statement
self.session = session
+ self.loaders_require_buffering = False
+ self.loaders_require_uniquing = False
self.propagated_loader_options = {
- o for o in query._with_options if o.propagate_to_loaders
+ o for o in statement._with_options if o.propagate_to_loaders
}
self.attributes = dict(compile_state.attributes)
@@ -237,6 +241,7 @@ class ORMCompileState(CompileState):
)
querycontext = QueryContext(
compile_state,
+ statement,
session,
load_options,
execution_options,
@@ -278,8 +283,6 @@ class ORMFromStatementCompileState(ORMCompileState):
_has_orm_entities = False
multi_row_eager_loaders = False
compound_eager_adapter = None
- loaders_require_buffering = False
- loaders_require_uniquing = False
@classmethod
def create_for_statement(cls, statement_container, compiler, **kw):
@@ -386,8 +389,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
_has_orm_entities = False
multi_row_eager_loaders = False
compound_eager_adapter = None
- loaders_require_buffering = False
- loaders_require_uniquing = False
correlate = None
_where_criteria = ()
@@ -416,7 +417,14 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self = cls.__new__(cls)
- self.select_statement = select_statement
+ if select_statement._execution_options:
+ # execution options should not impact the compilation of a
+ # query, and at the moment subqueryloader is putting some things
+ # in here that we explicitly don't want stuck in a cache.
+ self.select_statement = select_statement._clone()
+ self.select_statement._execution_options = util.immutabledict()
+ else:
+ self.select_statement = select_statement
# indicates this select() came from Query.statement
self.for_statement = (
@@ -654,6 +662,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
self._setup_with_polymorphics()
+ # entities will also set up polymorphic adapters for mappers
+ # that have with_polymorphic configured
_QueryEntity.to_compile_state(self, query._raw_columns)
return self
@@ -1810,10 +1820,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self._where_criteria += (single_crit,)
-def _column_descriptions(query_or_select_stmt):
- ctx = ORMSelectCompileState._create_entities_collection(
- query_or_select_stmt
- )
+def _column_descriptions(query_or_select_stmt, compile_state=None):
+ if compile_state is None:
+ compile_state = ORMSelectCompileState._create_entities_collection(
+ query_or_select_stmt
+ )
+ ctx = compile_state
return [
{
"name": ent._label_name,
@@ -2097,6 +2109,7 @@ class _MapperEntity(_QueryEntity):
only_load_props = refresh_state = None
_instance = loading._instance_processor(
+ self,
self.mapper,
context,
result,
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 027f2521b..39cf86e34 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -411,7 +411,6 @@ class CompositeProperty(DescriptorProperty):
def expression(self):
clauses = self.clauses._annotate(
{
- "bundle": True,
"parententity": self._parententity,
"parentmapper": self._parententity,
"orm_key": self.prop.key,
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 6c0f5d3ef..9782d92b7 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -158,7 +158,7 @@ class MapperProperty(
"""
def create_row_processor(
- self, context, path, mapper, result, adapter, populators
+ self, context, query_entity, path, mapper, result, adapter, populators
):
"""Produce row processing functions and append to the given
set of populators lists.
@@ -539,7 +539,7 @@ class StrategizedProperty(MapperProperty):
"_wildcard_token",
"_default_path_loader_key",
)
-
+ inherit_cache = True
strategy_wildcard_key = None
def _memoized_attr__wildcard_token(self):
@@ -600,7 +600,7 @@ class StrategizedProperty(MapperProperty):
)
def create_row_processor(
- self, context, path, mapper, result, adapter, populators
+ self, context, query_entity, path, mapper, result, adapter, populators
):
loader = self._get_context_loader(context, path)
if loader and loader.strategy:
@@ -608,7 +608,14 @@ class StrategizedProperty(MapperProperty):
else:
strat = self.strategy
strat.create_row_processor(
- context, path, loader, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loader,
+ mapper,
+ result,
+ adapter,
+ populators,
)
def do_init(self):
@@ -668,7 +675,7 @@ class StrategizedProperty(MapperProperty):
)
-class ORMOption(object):
+class ORMOption(HasCacheKey):
"""Base class for option objects that are passed to ORM queries.
These options may be consumed by :meth:`.Query.options`,
@@ -696,7 +703,7 @@ class ORMOption(object):
_is_compile_state = False
-class LoaderOption(HasCacheKey, ORMOption):
+class LoaderOption(ORMOption):
"""Describe a loader modification to an ORM statement at compilation time.
.. versionadded:: 1.4
@@ -736,9 +743,6 @@ class UserDefinedOption(ORMOption):
def __init__(self, payload=None):
self.payload = payload
- def _gen_cache_key(self, *arg, **kw):
- return ()
-
@util.deprecated_cls(
"1.4",
@@ -855,7 +859,15 @@ class LoaderStrategy(object):
"""
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
"""Establish row processing functions for a given QueryContext.
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 424ed5dfe..a33e1b77d 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -72,8 +72,8 @@ def instances(cursor, context):
)
if context.yield_per and (
- context.compile_state.loaders_require_buffering
- or context.compile_state.loaders_require_uniquing
+ context.loaders_require_buffering
+ or context.loaders_require_uniquing
):
raise sa_exc.InvalidRequestError(
"Can't use yield_per with eager loaders that require uniquing "
@@ -545,6 +545,7 @@ def _warn_for_runid_changed(state):
def _instance_processor(
+ query_entity,
mapper,
context,
result,
@@ -648,6 +649,7 @@ def _instance_processor(
# to see if one fits
prop.create_row_processor(
context,
+ query_entity,
path,
mapper,
result,
@@ -667,7 +669,7 @@ def _instance_processor(
populators = {key: list(value) for key, value in cached_populators.items()}
for prop in getters["todo"]:
prop.create_row_processor(
- context, path, mapper, result, adapter, populators
+ context, query_entity, path, mapper, result, adapter, populators
)
propagated_loader_options = context.propagated_loader_options
@@ -925,6 +927,7 @@ def _instance_processor(
_instance = _decorate_polymorphic_switch(
_instance,
context,
+ query_entity,
mapper,
result,
path,
@@ -1081,6 +1084,7 @@ def _validate_version_id(mapper, state, dict_, row, getter):
def _decorate_polymorphic_switch(
instance_fn,
context,
+ query_entity,
mapper,
result,
path,
@@ -1112,6 +1116,7 @@ def _decorate_polymorphic_switch(
return False
return _instance_processor(
+ query_entity,
sub_mapper,
context,
result,
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index c4cb89c03..bec6da74d 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -720,7 +720,7 @@ class Mapper(
return self
_cache_key_traversal = [
- ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj)
+ ("mapper", visitors.ExtendedInternalTraversal.dp_plain_obj),
]
@property
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index 2e5941713..ac7a64c30 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -216,6 +216,8 @@ class RootRegistry(PathRegistry):
"""
+ inherit_cache = True
+
path = natural_path = ()
has_entity = False
is_aliased_class = False
@@ -248,6 +250,8 @@ class PathToken(HasCacheKey, str):
class TokenRegistry(PathRegistry):
__slots__ = ("token", "parent", "path", "natural_path")
+ inherit_cache = True
+
def __init__(self, parent, token):
token = PathToken.intern(token)
@@ -280,6 +284,7 @@ class TokenRegistry(PathRegistry):
class PropRegistry(PathRegistry):
is_unnatural = False
+ inherit_cache = True
def __init__(self, parent, prop):
# restate this path in terms of the
@@ -439,6 +444,7 @@ class AbstractEntityRegistry(PathRegistry):
class SlotsEntityRegistry(AbstractEntityRegistry):
# for aliased class, return lightweight, no-cycles created
# version
+ inherit_cache = True
__slots__ = (
"key",
@@ -454,6 +460,8 @@ class CachingEntityRegistry(AbstractEntityRegistry, dict):
# for long lived mapper, return dict based caching
# version that creates reference cycles
+ inherit_cache = True
+
def __getitem__(self, entity):
if isinstance(entity, (int, slice)):
return self.path[entity]
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 19d43d354..8393eaf74 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -38,6 +38,7 @@ from ..sql.base import Options
from ..sql.dml import DeleteDMLState
from ..sql.dml import UpdateDMLState
from ..sql.elements import BooleanClauseList
+from ..sql.util import _entity_namespace_key
def _bulk_insert(
@@ -1820,8 +1821,12 @@ class BulkUDCompileState(CompileState):
if isinstance(k, util.string_types):
desc = sql.util._entity_namespace_key(mapper, k)
values.extend(desc._bulk_update_tuples(v))
- elif isinstance(k, attributes.QueryableAttribute):
- values.extend(k._bulk_update_tuples(v))
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["orm_key"]
+ )
+ values.extend(attr._bulk_update_tuples(v))
else:
values.append((k, v))
else:
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 02f0752a5..5fb3beca3 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -45,6 +45,7 @@ class ColumnProperty(StrategizedProperty):
"""
strategy_wildcard_key = "column"
+ inherit_cache = True
__slots__ = (
"_orig_columns",
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 284ea9d72..cdad55320 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -61,6 +61,7 @@ from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import SelectStatementGrouping
from ..sql.util import _entity_namespace_key
+from ..sql.visitors import InternalTraversal
from ..util import collections_abc
__all__ = ["Query", "QueryContext", "aliased"]
@@ -423,6 +424,7 @@ class Query(
_label_style=self._label_style,
compile_options=compile_options,
)
+ stmt.__dict__.pop("session", None)
stmt._propagate_attrs = self._propagate_attrs
return stmt
@@ -1725,7 +1727,6 @@ class Query(
"""
from_entity = self._filter_by_zero()
-
if from_entity is None:
raise sa_exc.InvalidRequestError(
"Can't use filter_by when the first entity '%s' of a query "
@@ -2900,7 +2901,10 @@ class Query(
compile_state = self._compile_state(for_statement=False)
context = QueryContext(
- compile_state, self.session, self.load_options
+ compile_state,
+ compile_state.statement,
+ self.session,
+ self.load_options,
)
result = loading.instances(result_proxy, context)
@@ -3376,7 +3380,12 @@ class Query(
def _compile_context(self, for_statement=False):
compile_state = self._compile_state(for_statement=for_statement)
- context = QueryContext(compile_state, self.session, self.load_options)
+ context = QueryContext(
+ compile_state,
+ compile_state.statement,
+ self.session,
+ self.load_options,
+ )
return context
@@ -3397,6 +3406,11 @@ class FromStatement(SelectStatementGrouping, Executable):
_for_update_arg = None
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("element", InternalTraversal.dp_clauseelement),
+ ] + Executable._executable_traverse_internals
+
def __init__(self, entities, element):
self._raw_columns = [
coercions.expect(
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 683f2b978..bedc54153 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -107,6 +107,7 @@ class RelationshipProperty(StrategizedProperty):
"""
strategy_wildcard_key = "relationship"
+ inherit_cache = True
_persistence_only = dict(
passive_deletes=False,
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index f67c23aab..5f039aff7 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -25,6 +25,7 @@ from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
from .context import _column_descriptions
+from .context import ORMCompileState
from .interfaces import LoaderStrategy
from .interfaces import StrategizedProperty
from .session import _state_session
@@ -156,7 +157,15 @@ class UninstrumentedColumnLoader(LoaderStrategy):
column_collection.append(c)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
pass
@@ -224,7 +233,15 @@ class ColumnLoader(LoaderStrategy):
)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
# look through list of columns represented here
# to see which, if any, is present in the row.
@@ -281,7 +298,15 @@ class ExpressionColumnLoader(ColumnLoader):
memoized_populators[self.parent_property] = fetch
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
# look through list of columns represented here
# to see which, if any, is present in the row.
@@ -332,7 +357,15 @@ class DeferredColumnLoader(LoaderStrategy):
self.group = self.parent_property.group
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
# for a DeferredColumnLoader, this method is only used during a
@@ -542,7 +575,15 @@ class NoLoader(AbstractRelationshipLoader):
)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
def invoke_no_load(state, dict_, row):
if self.uselist:
@@ -985,7 +1026,15 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
return None
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
key = self.key
@@ -1039,12 +1088,27 @@ class PostLoader(AbstractRelationshipLoader):
"""A relationship loader that emits a second SELECT statement."""
def _immediateload_create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
return self.parent_property._get_strategy(
(("lazy", "immediate"),)
).create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
@@ -1057,21 +1121,16 @@ class ImmediateLoader(PostLoader):
(("lazy", "select"),)
).init_class_attribute(mapper)
- def setup_query(
+ def create_row_processor(
self,
- compile_state,
- entity,
+ context,
+ query_entity,
path,
loadopt,
+ mapper,
+ result,
adapter,
- column_collection=None,
- parentmapper=None,
- **kwargs
- ):
- pass
-
- def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ populators,
):
def load_immediate(state, dict_, row):
state.get_impl(self.key).get(state, dict_)
@@ -1093,120 +1152,6 @@ class SubqueryLoader(PostLoader):
(("lazy", "select"),)
).init_class_attribute(mapper)
- def setup_query(
- self,
- compile_state,
- entity,
- path,
- loadopt,
- adapter,
- column_collection=None,
- parentmapper=None,
- **kwargs
- ):
- if (
- not compile_state.compile_options._enable_eagerloads
- or compile_state.compile_options._for_refresh_state
- ):
- return
-
- compile_state.loaders_require_buffering = True
-
- path = path[self.parent_property]
-
- # build up a path indicating the path from the leftmost
- # entity to the thing we're subquery loading.
- with_poly_entity = path.get(
- compile_state.attributes, "path_with_polymorphic", None
- )
- if with_poly_entity is not None:
- effective_entity = with_poly_entity
- else:
- effective_entity = self.entity
-
- subq_path = compile_state.attributes.get(
- ("subquery_path", None), orm_util.PathRegistry.root
- )
-
- subq_path = subq_path + path
-
- # if not via query option, check for
- # a cycle
- if not path.contains(compile_state.attributes, "loader"):
- if self.join_depth:
- if (
- (
- compile_state.current_path.length
- if compile_state.current_path
- else 0
- )
- + path.length
- ) / 2 > self.join_depth:
- return
- elif subq_path.contains_mapper(self.mapper):
- return
-
- (
- leftmost_mapper,
- leftmost_attr,
- leftmost_relationship,
- ) = self._get_leftmost(subq_path)
-
- orig_query = compile_state.attributes.get(
- ("orig_query", SubqueryLoader), compile_state.select_statement
- )
-
- # generate a new Query from the original, then
- # produce a subquery from it.
- left_alias = self._generate_from_original_query(
- compile_state,
- orig_query,
- leftmost_mapper,
- leftmost_attr,
- leftmost_relationship,
- entity.entity_zero,
- )
-
- # generate another Query that will join the
- # left alias to the target relationships.
- # basically doing a longhand
- # "from_self()". (from_self() itself not quite industrial
- # strength enough for all contingencies...but very close)
-
- q = query.Query(effective_entity)
-
- def set_state_options(compile_state):
- compile_state.attributes.update(
- {
- ("orig_query", SubqueryLoader): orig_query,
- ("subquery_path", None): subq_path,
- }
- )
-
- q = q._add_context_option(set_state_options, None)._disable_caching()
-
- q = q._set_enable_single_crit(False)
- to_join, local_attr, parent_alias = self._prep_for_joins(
- left_alias, subq_path
- )
-
- q = q.add_columns(*local_attr)
- q = self._apply_joins(
- q, to_join, left_alias, parent_alias, effective_entity
- )
-
- q = self._setup_options(q, subq_path, orig_query, effective_entity)
- q = self._setup_outermost_orderby(q)
-
- # add new query to attributes to be picked up
- # by create_row_processor
- # NOTE: be sure to consult baked.py for some hardcoded logic
- # about this structure as well
- assert q.session is None
- path.set(
- compile_state.attributes, "subqueryload_data", {"query": q},
- )
-
def _get_leftmost(self, subq_path):
subq_path = subq_path.path
subq_mapper = orm_util._class_to_mapper(subq_path[0])
@@ -1267,27 +1212,34 @@ class SubqueryLoader(PostLoader):
q,
*{
ent["entity"]
- for ent in _column_descriptions(orig_query)
+ for ent in _column_descriptions(
+ orig_query, compile_state=orig_compile_state
+ )
if ent["entity"] is not None
}
)
- # for column information, look to the compile state that is
- # already being passed through
- compile_state = orig_compile_state
-
# select from the identity columns of the outer (specifically, these
- # are the 'local_cols' of the property). This will remove
- # other columns from the query that might suggest the right entity
- # which is why we do _set_select_from above.
- target_cols = compile_state._adapt_col_list(
+ # are the 'local_cols' of the property). This will remove other
+ # columns from the query that might suggest the right entity which is
+ # why we do set select_from above. The attributes we have are
+ # coerced and adapted using the original query's adapter, which is
+ # needed only for the case of adapting a subclass column to
+ # that of a polymorphic selectable, e.g. we have
+ # Engineer.primary_language and the entity is Person. All other
+ # adaptations, e.g. from_self, select_entity_from(), will occur
+ # within the new query when it compiles, as the compile_state we are
+ # using here is only a partial one. If the subqueryload is from a
+ # with_polymorphic() or other aliased() object, left_attr will already
+ # be the correct attributes so no adaptation is needed.
+ target_cols = orig_compile_state._adapt_col_list(
[
- sql.coercions.expect(sql.roles.ByOfRole, o)
+ sql.coercions.expect(sql.roles.ColumnsClauseRole, o)
for o in leftmost_attr
],
- compile_state._get_current_adapter(),
+ orig_compile_state._get_current_adapter(),
)
- q._set_entities(target_cols)
+ q._raw_columns = target_cols
distinct_target_key = leftmost_relationship.distinct_target_key
@@ -1461,13 +1413,13 @@ class SubqueryLoader(PostLoader):
"_data",
)
- def __init__(self, context, subq_info):
+ def __init__(self, context, subq):
# avoid creating a cycle by storing context
# even though that's preferable
self.session = context.session
self.execution_options = context.execution_options
self.load_options = context.load_options
- self.subq = subq_info["query"]
+ self.subq = subq
self._data = None
def get(self, key, default):
@@ -1499,12 +1451,148 @@ class SubqueryLoader(PostLoader):
if self._data is None:
self._load()
+ def _setup_query_from_rowproc(
+ self, context, path, entity, loadopt, adapter,
+ ):
+ compile_state = context.compile_state
+ if (
+ not compile_state.compile_options._enable_eagerloads
+ or compile_state.compile_options._for_refresh_state
+ ):
+ return
+
+ context.loaders_require_buffering = True
+
+ path = path[self.parent_property]
+
+ # build up a path indicating the path from the leftmost
+ # entity to the thing we're subquery loading.
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ effective_entity = with_poly_entity
+ else:
+ effective_entity = self.entity
+
+ subq_path = context.query._execution_options.get(
+ ("subquery_path", None), orm_util.PathRegistry.root
+ )
+
+ subq_path = subq_path + path
+
+ # if not via query option, check for
+ # a cycle
+ if not path.contains(compile_state.attributes, "loader"):
+ if self.join_depth:
+ if (
+ (
+ compile_state.current_path.length
+ if compile_state.current_path
+ else 0
+ )
+ + path.length
+ ) / 2 > self.join_depth:
+ return
+ elif subq_path.contains_mapper(self.mapper):
+ return
+
+ (
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ ) = self._get_leftmost(subq_path)
+
+ # use the current query being invoked, not the compile state
+ # one. this is so that we get the current parameters. however,
+ # it means we can't use the existing compile state, we have to make
+ # a new one. other approaches include possibly using the
+ # compiled query but swapping the params, seems only marginally
+ # less time spent but more complicated
+ orig_query = context.query._execution_options.get(
+ ("orig_query", SubqueryLoader), context.query
+ )
+
+ # make a new compile_state for the query that's probably cached, but
+ # we're sort of undoing a bit of that caching :(
+ compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
+ orig_query, "orm"
+ )
+
+ # this would create the full blown compile state, which we don't
+ # need
+ # orig_compile_state = compile_state_cls.create_for_statement(
+ # orig_query, None)
+
+ # this is the more "quick" version, however it's not clear how
+ # much of this we need. in particular I can't get a test to
+ # fail if the "set_base_alias" is missing and not sure why that is.
+ orig_compile_state = compile_state_cls._create_entities_collection(
+ orig_query
+ )
+
+ # generate a new Query from the original, then
+ # produce a subquery from it.
+ left_alias = self._generate_from_original_query(
+ orig_compile_state,
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ entity,
+ )
+
+ # generate another Query that will join the
+ # left alias to the target relationships.
+ # basically doing a longhand
+ # "from_self()". (from_self() itself not quite industrial
+ # strength enough for all contingencies...but very close)
+
+ q = query.Query(effective_entity)
+
+ q._execution_options = q._execution_options.union(
+ {
+ ("orig_query", SubqueryLoader): orig_query,
+ ("subquery_path", None): subq_path,
+ }
+ )
+
+ q = q._set_enable_single_crit(False)
+ to_join, local_attr, parent_alias = self._prep_for_joins(
+ left_alias, subq_path
+ )
+
+ q = q.add_columns(*local_attr)
+ q = self._apply_joins(
+ q, to_join, left_alias, parent_alias, effective_entity
+ )
+
+ q = self._setup_options(q, subq_path, orig_query, effective_entity)
+ q = self._setup_outermost_orderby(q)
+
+ return q
+
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
if context.refresh_state:
return self._immediateload_create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
if not self.parent.class_manager[self.key].impl.supports_population:
@@ -1513,16 +1601,27 @@ class SubqueryLoader(PostLoader):
"population - eager loading cannot be applied." % self
)
- path = path[self.parent_property]
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+ # the root.. the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
- subq_info = path.get(context.attributes, "subqueryload_data")
+ subq = self._setup_query_from_rowproc(
+ context, path, path[-1], loadopt, adapter,
+ )
- if subq_info is None:
+ if subq is None:
return
- subq = subq_info["query"]
-
assert subq.session is None
+
+ path = path[self.parent_property]
+
local_cols = self.parent_property.local_columns
# cache the loaded collections in the context
@@ -1530,7 +1629,7 @@ class SubqueryLoader(PostLoader):
# call upon create_row_processor again
collections = path.get(context.attributes, "collections")
if collections is None:
- collections = self._SubqCollections(context, subq_info)
+ collections = self._SubqCollections(context, subq)
path.set(context.attributes, "collections", collections)
if adapter:
@@ -1634,7 +1733,6 @@ class JoinedLoader(AbstractRelationshipLoader):
if not compile_state.compile_options._enable_eagerloads:
return
elif self.uselist:
- compile_state.loaders_require_uniquing = True
compile_state.multi_row_eager_loaders = True
path = path[self.parent_property]
@@ -2142,7 +2240,15 @@ class JoinedLoader(AbstractRelationshipLoader):
return False
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
if not self.parent.class_manager[self.key].impl.supports_population:
raise sa_exc.InvalidRequestError(
@@ -2150,6 +2256,9 @@ class JoinedLoader(AbstractRelationshipLoader):
"population - eager loading cannot be applied." % self
)
+ if self.uselist:
+ context.loaders_require_uniquing = True
+
our_path = path[self.parent_property]
eager_adapter = self._create_eager_adapter(
@@ -2160,6 +2269,7 @@ class JoinedLoader(AbstractRelationshipLoader):
key = self.key
_instance = loading._instance_processor(
+ query_entity,
self.mapper,
context,
result,
@@ -2177,7 +2287,14 @@ class JoinedLoader(AbstractRelationshipLoader):
self.parent_property._get_strategy(
(("lazy", "select"),)
).create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
def _create_collection_loader(self, context, key, _instance, populators):
@@ -2382,11 +2499,26 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
return util.preloaded.ext_baked.bakery(size=50)
def create_row_processor(
- self, context, path, loadopt, mapper, result, adapter, populators
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
):
if context.refresh_state:
return self._immediateload_create_row_processor(
- context, path, loadopt, mapper, result, adapter, populators
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
)
if not self.parent.class_manager[self.key].impl.supports_population:
@@ -2395,13 +2527,20 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
"population - eager loading cannot be applied." % self
)
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+ # the root.. the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
+
selectin_path = (
context.compile_state.current_path or orm_util.PathRegistry.root
) + path
- if not orm_util._entity_isa(path[-1], self.parent):
- return
-
if loading.PostLoad.path_exists(
context, selectin_path, self.parent_property
):
@@ -2427,7 +2566,6 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
return
elif selectin_path_w_prop.contains_mapper(self.mapper):
return
-
loading.PostLoad.callable_for_path(
context,
selectin_path,
@@ -2543,7 +2681,39 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
)
)
- orig_query = context.query
+ # a test which exercises what these comments talk about is
+ # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic
+ #
+ # effective_entity above is given to us in terms of the cached
+ # statement, namely this one:
+ orig_query = context.compile_state.select_statement
+
+ # the actual statement that was requested is this one:
+ # context_query = context.query
+ #
+ # that's not the cached one, however. So while it is of the identical
+ # structure, if it has entities like AliasedInsp, which we get from
+ # aliased() or with_polymorphic(), the AliasedInsp will likely be a
+ # different object identity each time, and will not match up
+ # hashing-wise to the corresponding AliasedInsp that's in the
+ # cached query, meaning it won't match on paths and loader lookups
+ # and loaders like this one will be skipped if it is used in options.
+ #
+ # Now we want to transfer loader options from the parent query to the
+ # "selectinload" query we're about to run. Which query do we transfer
+ # the options from? We use the cached query, because the options in
+ # that query will be in terms of the effective entity we were just
+ # handed.
+ #
+ # But now the selectinload/ baked query we are running is *also*
+ # cached. What if it's cached and running from some previous iteration
+ # of that AliasedInsp? Well in that case it will also use the previous
+ # iteration of the loader options. If the baked query expires and
+ # gets generated again, it will be handed the current effective_entity
+ # and the current _with_options, again in terms of whatever
+ # compile_state.select_statement happens to be right now, so the
+ # query will still be internally consistent and loader callables
+ # will be correctly invoked.
q._add_lazyload_options(
orig_query._with_options, path[self.parent_property]
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 85f4f85d1..f7a97bfe5 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1187,9 +1187,9 @@ class Bundle(ORMColumnsClauseRole, SupportsCloneAnnotations, InspectionAttr):
return cloned
def __clause_element__(self):
- annotations = self._annotations.union(
- {"bundle": self, "entity_namespace": self}
- )
+ # ensure existing entity_namespace remains
+ annotations = {"bundle": self, "entity_namespace": self}
+ annotations.update(self._annotations)
return expression.ClauseList(
_literal_as_text_role=roles.ColumnsClauseRole,
group=False,
@@ -1258,6 +1258,8 @@ class _ORMJoin(expression.Join):
__visit_name__ = expression.Join.__visit_name__
+ inherit_cache = True
+
def __init__(
self,
left,
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 78de80734..a25c1b083 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -100,6 +100,7 @@ def __go(lcls):
from .elements import AnnotatedColumnElement
from .elements import ClauseList # noqa
from .selectable import AnnotatedFromClause # noqa
+ from .traversals import _preconfigure_traversals
from . import base
from . import coercions
@@ -122,6 +123,8 @@ def __go(lcls):
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
+ _preconfigure_traversals(ClauseElement)
+
_sa_util.preloaded.import_prefix("sqlalchemy.sql")
from . import naming # noqa
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 08ed121d3..8a0d6ec28 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -338,6 +338,15 @@ def _new_annotation_type(cls, base_cls):
anno_cls._traverse_internals = list(cls._traverse_internals) + [
("_annotations", InternalTraversal.dp_annotations_key)
]
+ elif cls.__dict__.get("inherit_cache", False):
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+
+ # some classes include this even if they have traverse_internals
+ # e.g. BindParameter, add it if present.
+ if cls.__dict__.get("inherit_cache", False):
+ anno_cls.inherit_cache = True
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 5dd3b519a..5f2ce8f14 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -624,19 +624,14 @@ class Executable(Generative):
_bind = None
_with_options = ()
_with_context_options = ()
- _cache_enable = True
_executable_traverse_internals = [
("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list),
("_with_context_options", ExtendedInternalTraversal.dp_plain_obj),
- ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj),
+ ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
]
@_generative
- def _disable_caching(self):
- self._cache_enable = HasCacheKey()
-
- @_generative
def options(self, *options):
"""Apply options to this statement.
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2519438d1..61178291a 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -373,6 +373,8 @@ class Compiled(object):
_cached_metadata = None
+ _result_columns = None
+
schema_translate_map = None
execution_options = util.immutabledict()
@@ -433,7 +435,6 @@ class Compiled(object):
self,
dialect,
statement,
- bind=None,
schema_translate_map=None,
render_schema_translate=False,
compile_kwargs=util.immutabledict(),
@@ -463,7 +464,6 @@ class Compiled(object):
"""
self.dialect = dialect
- self.bind = bind
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
self.schema_translate_map = schema_translate_map
@@ -527,24 +527,6 @@ class Compiled(object):
"""Return the bind params for this compiled object."""
return self.construct_params()
- def execute(self, *multiparams, **params):
- """Execute this compiled object."""
-
- e = self.bind
- if e is None:
- raise exc.UnboundExecutionError(
- "This Compiled object is not bound to any Engine "
- "or Connection.",
- code="2afi",
- )
- return e._execute_compiled(self, multiparams, params)
-
- def scalar(self, *multiparams, **params):
- """Execute this compiled object and return the result's
- scalar value."""
-
- return self.execute(*multiparams, **params).scalar()
-
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
@@ -687,6 +669,13 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
+ _cache_key_bind_match = None
+ """a mapping that will relate the BindParameter object we compile
+ to those that are part of the extracted collection of parameters
+ in the cache key, if we were given a cache key.
+
+ """
+
def __init__(
self,
dialect,
@@ -717,6 +706,9 @@ class SQLCompiler(Compiled):
self.cache_key = cache_key
+ if cache_key:
+ self._cache_key_bind_match = {b: b for b in cache_key[1]}
+
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
self.inline = inline or getattr(statement, "_inline", False)
@@ -875,8 +867,9 @@ class SQLCompiler(Compiled):
replace_context=err,
)
+ ckbm = self._cache_key_bind_match
resolved_extracted = {
- b.key: extracted
+ ckbm[b]: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
}
else:
@@ -907,7 +900,7 @@ class SQLCompiler(Compiled):
else:
if resolved_extracted:
value_param = resolved_extracted.get(
- bindparam.key, bindparam
+ bindparam, bindparam
)
else:
value_param = bindparam
@@ -936,9 +929,7 @@ class SQLCompiler(Compiled):
)
if resolved_extracted:
- value_param = resolved_extracted.get(
- bindparam.key, bindparam
- )
+ value_param = resolved_extracted.get(bindparam, bindparam)
else:
value_param = bindparam
@@ -2021,6 +2012,19 @@ class SQLCompiler(Compiled):
)
self.binds[bindparam.key] = self.binds[name] = bindparam
+
+ # if we are given a cache key that we're going to match against,
+ # relate the bindparam here to one that is most likely present
+ # in the "extracted params" portion of the cache key. this is used
+ # to set up a positional mapping that is used to determine the
+ # correct parameters for a subsequent use of this compiled with
+ # a different set of parameter values. here, we accommodate for
+ # parameters that may have been cloned both before and after the cache
+ # key was been generated.
+ ckbm = self._cache_key_bind_match
+ if ckbm:
+ ckbm.update({bp: bindparam for bp in bindparam._cloned_set})
+
if bindparam.isoutparam:
self.has_out_parameters = True
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 569030651..d3730b124 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -28,6 +28,9 @@ class _DDLCompiles(ClauseElement):
return dialect.ddl_compiler(dialect, self, **kw)
+ def _compile_w_cache(self, *arg, **kw):
+ raise NotImplementedError()
+
class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
"""Base class for DDL expression constructs.
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index a82641d77..50b2a935a 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -641,7 +641,7 @@ class ValuesBase(UpdateBase):
if self._preserve_parameter_order:
arg = [
(
- k,
+ coercions.expect(roles.DMLColumnRole, k),
coercions.expect(
roles.ExpressionElementRole,
v,
@@ -654,7 +654,7 @@ class ValuesBase(UpdateBase):
self._ordered_values = arg
else:
arg = {
- k: coercions.expect(
+ coercions.expect(roles.DMLColumnRole, k): coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
@@ -772,6 +772,7 @@ class Insert(ValuesBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
@@ -997,6 +998,7 @@ class Update(DMLWhereBase, ValuesBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
@@ -1187,7 +1189,7 @@ class Update(DMLWhereBase, ValuesBase):
)
arg = [
(
- k,
+ coercions.expect(roles.DMLColumnRole, k),
coercions.expect(
roles.ExpressionElementRole,
v,
@@ -1238,6 +1240,7 @@ class Delete(DMLWhereBase, UpdateBase):
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
)
@ValuesBase._constructor_20_deprecations(
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8e1b623a7..60c816ee6 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -381,6 +381,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return
for attrname, obj, meth in _copy_internals.run_generated_dispatch(
@@ -410,6 +411,7 @@ class ClauseElement(
try:
traverse_internals = self._traverse_internals
except AttributeError:
+ # user-defined classes may not have a _traverse_internals
return []
return itertools.chain.from_iterable(
@@ -516,10 +518,62 @@ class ClauseElement(
dialect = bind.dialect
elif self.bind:
dialect = self.bind.dialect
- bind = self.bind
else:
dialect = default.StrCompileDialect()
- return self._compiler(dialect, bind=bind, **kw)
+
+ return self._compiler(dialect, **kw)
+
+ def _compile_w_cache(
+ self,
+ dialect,
+ compiled_cache=None,
+ column_keys=None,
+ inline=False,
+ schema_translate_map=None,
+ **kw
+ ):
+ if compiled_cache is not None:
+ elem_cache_key = self._generate_cache_key()
+ else:
+ elem_cache_key = None
+
+ cache_hit = False
+
+ if elem_cache_key:
+ cache_key, extracted_params = elem_cache_key
+ key = (
+ dialect,
+ cache_key,
+ tuple(column_keys),
+ bool(schema_translate_map),
+ inline,
+ )
+ compiled_sql = compiled_cache.get(key)
+
+ if compiled_sql is None:
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+ compiled_cache[key] = compiled_sql
+ else:
+ cache_hit = True
+ else:
+ extracted_params = None
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ inline=inline,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+
+ return compiled_sql, extracted_params, cache_hit
def _compiler(self, dialect, **kw):
"""Return a compiler appropriate for this ClauseElement, given a
@@ -1035,6 +1089,10 @@ class BindParameter(roles.InElementRole, ColumnElement):
_is_bind_parameter = True
_key_is_anon = False
+ # bindparam implements its own _gen_cache_key() method however
+ # we check subclasses for this flag, else no cache key is generated
+ inherit_cache = True
+
def __init__(
self,
key,
@@ -1396,6 +1454,13 @@ class BindParameter(roles.InElementRole, ColumnElement):
return c
def _gen_cache_key(self, anon_map, bindparams):
+ _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False)
+
+ if not _gen_cache_ok:
+ if anon_map is not None:
+ anon_map[NO_CACHE] = True
+ return None
+
idself = id(self)
if idself in anon_map:
return (anon_map[idself], self.__class__)
@@ -2082,6 +2147,7 @@ class ClauseList(
roles.InElementRole,
roles.OrderByRole,
roles.ColumnsClauseRole,
+ roles.DMLColumnRole,
ClauseElement,
):
"""Describe a list of clauses, separated by an operator.
@@ -2174,6 +2240,7 @@ class ClauseList(
class BooleanClauseList(ClauseList, ColumnElement):
__visit_name__ = "clauselist"
+ inherit_cache = True
_tuple_values = False
@@ -3428,6 +3495,8 @@ class CollectionAggregate(UnaryExpression):
class AsBoolean(WrapsColumnExpression, UnaryExpression):
+ inherit_cache = True
+
def __init__(self, element, operator, negate):
self.element = element
self.type = type_api.BOOLEANTYPE
@@ -3474,6 +3543,7 @@ class BinaryExpression(ColumnElement):
("operator", InternalTraversal.dp_operator),
("negate", InternalTraversal.dp_operator),
("modifiers", InternalTraversal.dp_plain_dict),
+ ("type", InternalTraversal.dp_type,), # affects JSON CAST operators
]
_is_implicitly_boolean = True
@@ -3482,41 +3552,6 @@ class BinaryExpression(ColumnElement):
"""
- def _gen_cache_key(self, anon_map, bindparams):
- # inlined for performance
-
- idself = id(self)
-
- if idself in anon_map:
- return (anon_map[idself], self.__class__)
- else:
- # inline of
- # id_ = anon_map[idself]
- anon_map[idself] = id_ = str(anon_map.index)
- anon_map.index += 1
-
- if self._cache_key_traversal is NO_CACHE:
- anon_map[NO_CACHE] = True
- return None
-
- result = (id_, self.__class__)
-
- return result + (
- ("left", self.left._gen_cache_key(anon_map, bindparams)),
- ("right", self.right._gen_cache_key(anon_map, bindparams)),
- ("operator", self.operator),
- ("negate", self.negate),
- (
- "modifiers",
- tuple(
- (key, self.modifiers[key])
- for key in sorted(self.modifiers)
- )
- if self.modifiers
- else None,
- ),
- )
-
def __init__(
self, left, right, operator, type_=None, negate=None, modifiers=None
):
@@ -3587,15 +3622,30 @@ class Slice(ColumnElement):
__visit_name__ = "slice"
_traverse_internals = [
- ("start", InternalTraversal.dp_plain_obj),
- ("stop", InternalTraversal.dp_plain_obj),
- ("step", InternalTraversal.dp_plain_obj),
+ ("start", InternalTraversal.dp_clauseelement),
+ ("stop", InternalTraversal.dp_clauseelement),
+ ("step", InternalTraversal.dp_clauseelement),
]
- def __init__(self, start, stop, step):
- self.start = start
- self.stop = stop
- self.step = step
+ def __init__(self, start, stop, step, _name=None):
+ self.start = coercions.expect(
+ roles.ExpressionElementRole,
+ start,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.stop = coercions.expect(
+ roles.ExpressionElementRole,
+ stop,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.step = coercions.expect(
+ roles.ExpressionElementRole,
+ step,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
self.type = type_api.NULLTYPE
def self_group(self, against=None):
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 6b1172eba..7b723f371 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -744,6 +744,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
coerce_arguments = True
_register = False
+ inherit_cache = True
def __init__(self, *args, **kwargs):
parsed_args = kwargs.pop("_parsed_args", None)
@@ -808,6 +809,8 @@ class next_value(GenericFunction):
class AnsiFunction(GenericFunction):
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
GenericFunction.__init__(self, *args, **kwargs)
@@ -815,6 +818,8 @@ class AnsiFunction(GenericFunction):
class ReturnTypeFromArgs(GenericFunction):
"""Define a function whose return type is the same as its arguments."""
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
args = [
coercions.expect(
@@ -832,30 +837,34 @@ class ReturnTypeFromArgs(GenericFunction):
class coalesce(ReturnTypeFromArgs):
_has_args = True
+ inherit_cache = True
class max(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class min(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class sum(ReturnTypeFromArgs): # noqa
- pass
+ inherit_cache = True
class now(GenericFunction): # noqa
type = sqltypes.DateTime
+ inherit_cache = True
class concat(GenericFunction):
type = sqltypes.String
+ inherit_cache = True
class char_length(GenericFunction):
type = sqltypes.Integer
+ inherit_cache = True
def __init__(self, arg, **kwargs):
GenericFunction.__init__(self, arg, **kwargs)
@@ -863,6 +872,7 @@ class char_length(GenericFunction):
class random(GenericFunction):
_has_args = True
+ inherit_cache = True
class count(GenericFunction):
@@ -887,6 +897,7 @@ class count(GenericFunction):
"""
type = sqltypes.Integer
+ inherit_cache = True
def __init__(self, expression=None, **kwargs):
if expression is None:
@@ -896,38 +907,47 @@ class count(GenericFunction):
class current_date(AnsiFunction):
type = sqltypes.Date
+ inherit_cache = True
class current_time(AnsiFunction):
type = sqltypes.Time
+ inherit_cache = True
class current_timestamp(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class current_user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class localtime(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class localtimestamp(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class session_user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class sysdate(AnsiFunction):
type = sqltypes.DateTime
+ inherit_cache = True
class user(AnsiFunction):
type = sqltypes.String
+ inherit_cache = True
class array_agg(GenericFunction):
@@ -951,6 +971,7 @@ class array_agg(GenericFunction):
"""
type = sqltypes.ARRAY
+ inherit_cache = True
def __init__(self, *args, **kwargs):
args = [
@@ -978,6 +999,7 @@ class OrderedSetAgg(GenericFunction):
:meth:`.FunctionElement.within_group` method."""
array_for_multi_clause = False
+ inherit_cache = True
def within_group_type(self, within_group):
func_clauses = self.clause_expr.element
@@ -1000,6 +1022,8 @@ class mode(OrderedSetAgg):
"""
+ inherit_cache = True
+
class percentile_cont(OrderedSetAgg):
"""implement the ``percentile_cont`` ordered-set aggregate function.
@@ -1016,6 +1040,7 @@ class percentile_cont(OrderedSetAgg):
"""
array_for_multi_clause = True
+ inherit_cache = True
class percentile_disc(OrderedSetAgg):
@@ -1033,6 +1058,7 @@ class percentile_disc(OrderedSetAgg):
"""
array_for_multi_clause = True
+ inherit_cache = True
class rank(GenericFunction):
@@ -1048,6 +1074,7 @@ class rank(GenericFunction):
"""
type = sqltypes.Integer()
+ inherit_cache = True
class dense_rank(GenericFunction):
@@ -1063,6 +1090,7 @@ class dense_rank(GenericFunction):
"""
type = sqltypes.Integer()
+ inherit_cache = True
class percent_rank(GenericFunction):
@@ -1078,6 +1106,7 @@ class percent_rank(GenericFunction):
"""
type = sqltypes.Numeric()
+ inherit_cache = True
class cume_dist(GenericFunction):
@@ -1093,6 +1122,7 @@ class cume_dist(GenericFunction):
"""
type = sqltypes.Numeric()
+ inherit_cache = True
class cube(GenericFunction):
@@ -1109,6 +1139,7 @@ class cube(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
class rollup(GenericFunction):
@@ -1125,6 +1156,7 @@ class rollup(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
class grouping_sets(GenericFunction):
@@ -1158,3 +1190,4 @@ class grouping_sets(GenericFunction):
"""
_has_args = True
+ inherit_cache = True
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index ee411174c..29ca81d26 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -1013,6 +1013,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
__visit_name__ = "column"
+ inherit_cache = True
+
def __init__(self, *args, **kwargs):
r"""
Construct a new ``Column`` object.
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index a95fc561a..54f293967 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -61,6 +61,8 @@ if util.TYPE_CHECKING:
class _OffsetLimitParam(BindParameter):
+ inherit_cache = True
+
@property
def _limit_offset_value(self):
return self.effective_value
@@ -1426,6 +1428,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
__visit_name__ = "alias"
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None, flat=False):
"""Return an :class:`_expression.Alias` object.
@@ -1500,6 +1504,8 @@ class Lateral(AliasedReturnsRows):
__visit_name__ = "lateral"
_is_lateral = True
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`_expression.Lateral` object.
@@ -1626,7 +1632,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
AliasedReturnsRows._traverse_internals
+ [
("_cte_alias", InternalTraversal.dp_clauseelement),
- ("_restates", InternalTraversal.dp_clauseelement_unordered_set),
+ ("_restates", InternalTraversal.dp_clauseelement_list),
("recursive", InternalTraversal.dp_boolean),
]
+ HasPrefixes._has_prefixes_traverse_internals
@@ -1651,7 +1657,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
name=None,
recursive=False,
_cte_alias=None,
- _restates=frozenset(),
+ _restates=(),
_prefixes=None,
_suffixes=None,
):
@@ -1692,7 +1698,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
self.element.union(other),
name=self.name,
recursive=self.recursive,
- _restates=self._restates.union([self]),
+ _restates=self._restates + (self,),
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
@@ -1702,7 +1708,7 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
self.element.union_all(other),
name=self.name,
recursive=self.recursive,
- _restates=self._restates.union([self]),
+ _restates=self._restates + (self,),
_prefixes=self._prefixes,
_suffixes=self._suffixes,
)
@@ -1918,6 +1924,8 @@ class Subquery(AliasedReturnsRows):
_is_subquery = True
+ inherit_cache = True
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`.Subquery` object.
@@ -3783,15 +3791,15 @@ class Select(
("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
- ("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
- (
- "_correlate_except",
- InternalTraversal.dp_clauseelement_unordered_set,
- ),
+ ("_correlate", InternalTraversal.dp_clauseelement_list),
+ ("_correlate_except", InternalTraversal.dp_clauseelement_list,),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
("_for_update_arg", InternalTraversal.dp_clauseelement),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
("_label_style", InternalTraversal.dp_plain_obj),
+ ("_is_future", InternalTraversal.dp_boolean),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
@@ -4522,7 +4530,7 @@ class Select(
if fromclauses and fromclauses[0] is None:
self._correlate = ()
else:
- self._correlate = set(self._correlate).union(
+ self._correlate = self._correlate + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
@@ -4560,7 +4568,7 @@ class Select(
if fromclauses and fromclauses[0] is None:
self._correlate_except = ()
else:
- self._correlate_except = set(self._correlate_except or ()).union(
+ self._correlate_except = (self._correlate_except or ()) + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
@@ -4866,6 +4874,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
_from_objects = []
_is_from_container = True
_is_implicitly_boolean = False
+ inherit_cache = True
def __init__(self, element):
self.element = element
@@ -4899,6 +4908,7 @@ class Exists(UnaryExpression):
"""
_from_objects = []
+ inherit_cache = True
def __init__(self, *args, **kwargs):
"""Construct a new :class:`_expression.Exists` against an existing
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 732b775f6..9cd9d5058 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2616,26 +2616,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
return_type = self.type
if self.type.zero_indexes:
index = slice(index.start + 1, index.stop + 1, index.step)
- index = Slice(
- coercions.expect(
- roles.ExpressionElementRole,
- index.start,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
- coercions.expect(
- roles.ExpressionElementRole,
- index.stop,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
- coercions.expect(
- roles.ExpressionElementRole,
- index.step,
- name=self.expr.key,
- type_=type_api.INTEGERTYPE,
- ),
+ slice_ = Slice(
+ index.start, index.stop, index.step, _name=self.expr.key
)
+ return operators.getitem, slice_, return_type
else:
if self.type.zero_indexes:
index += 1
@@ -2647,7 +2631,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
self.type.__class__, **adapt_kw
)
- return operators.getitem, index, return_type
+ return operators.getitem, index, return_type
def contains(self, *arg, **kw):
raise NotImplementedError(
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 68281f33d..ed0bfa27a 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -19,6 +19,7 @@ NO_CACHE = util.symbol("no_cache")
CACHE_IN_PLACE = util.symbol("cache_in_place")
CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
STATIC_CACHE_KEY = util.symbol("static_cache_key")
+PROPAGATE_ATTRS = util.symbol("propagate_attrs")
ANON_NAME = util.symbol("anon_name")
@@ -31,10 +32,74 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
+def _preconfigure_traversals(target_hierarchy):
+
+ stack = [target_hierarchy]
+ while stack:
+ cls = stack.pop()
+ stack.extend(cls.__subclasses__())
+
+ if hasattr(cls, "_traverse_internals"):
+ cls._generate_cache_attrs()
+ _copy_internals.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_copy_internals_traversal",
+ )
+ _get_children.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_get_children_traversal",
+ )
+
+
class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
__slots__ = ()
+ @classmethod
+ def _generate_cache_attrs(cls):
+ """generate cache key dispatcher for a new class.
+
+ This sets the _generated_cache_key_traversal attribute once called
+ so should only be called once per class.
+
+ """
+ inherit = cls.__dict__.get("inherit_cache", False)
+
+ if inherit:
+ _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
+ if _cache_key_traversal is None:
+ try:
+ _cache_key_traversal = cls._traverse_internals
+ except AttributeError:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ # TODO: wouldn't we instead get this from our superclass?
+ # also, our superclass may not have this yet, but in any case,
+ # we'd generate for the superclass that has it. this is a little
+ # more complicated, so for the moment this is a little less
+ # efficient on startup but simpler.
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+ else:
+ _cache_key_traversal = cls.__dict__.get(
+ "_cache_key_traversal", None
+ )
+ if _cache_key_traversal is None:
+ _cache_key_traversal = cls.__dict__.get(
+ "_traverse_internals", None
+ )
+ if _cache_key_traversal is None:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+
@util.preload_module("sqlalchemy.sql.elements")
def _gen_cache_key(self, anon_map, bindparams):
"""return an optional cache key.
@@ -72,14 +137,18 @@ class HasCacheKey(object):
else:
id_ = None
- _cache_key_traversal = self._cache_key_traversal
- if _cache_key_traversal is None:
- try:
- _cache_key_traversal = self._traverse_internals
- except AttributeError:
- _cache_key_traversal = NO_CACHE
+ try:
+ dispatcher = self.__class__.__dict__[
+ "_generated_cache_key_traversal"
+ ]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.__class__._generate_cache_attrs()
- if _cache_key_traversal is NO_CACHE:
+ if dispatcher is NO_CACHE:
if anon_map is not None:
anon_map[NO_CACHE] = True
return None
@@ -87,19 +156,13 @@ class HasCacheKey(object):
result = (id_, self.__class__)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
- try:
- dispatcher = self.__class__.__dict__[
- "_generated_cache_key_traversal"
- ]
- except KeyError:
- dispatcher = _cache_key_traversal_visitor.generate_dispatch(
- self, _cache_key_traversal, "_generated_cache_key_traversal"
- )
for attrname, obj, meth in dispatcher(
self, _cache_key_traversal_visitor
):
if obj is not None:
+ # TODO: see if C code can help here as Python lacks an
+ # efficient switch construct
if meth is CACHE_IN_PLACE:
# cache in place is always going to be a Python
# tuple, dict, list, etc. so we can do a boolean check
@@ -116,6 +179,15 @@ class HasCacheKey(object):
attrname,
obj._gen_cache_key(anon_map, bindparams),
)
+ elif meth is PROPAGATE_ATTRS:
+ if obj:
+ result += (
+ attrname,
+ obj["compile_state_plugin"],
+ obj["plugin_subject"]._gen_cache_key(
+ anon_map, bindparams
+ ),
+ )
elif meth is InternalTraversal.dp_annotations_key:
# obj is here is the _annotations dict. however,
# we want to use the memoized cache key version of it.
@@ -332,6 +404,8 @@ class _CacheKey(ExtendedInternalTraversal):
visit_type = STATIC_CACHE_KEY
visit_anon_name = ANON_NAME
+ visit_propagate_attrs = PROPAGATE_ATTRS
+
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
@@ -445,10 +519,16 @@ class _CacheKey(ExtendedInternalTraversal):
def visit_setup_join_tuple(
self, attrname, obj, parent, anon_map, bindparams
):
+ is_legacy = "legacy" in attrname
+
return tuple(
(
- target._gen_cache_key(anon_map, bindparams),
- onclause._gen_cache_key(anon_map, bindparams)
+ target
+ if is_legacy and isinstance(target, str)
+ else target._gen_cache_key(anon_map, bindparams),
+ onclause
+ if is_legacy and isinstance(onclause, str)
+ else onclause._gen_cache_key(anon_map, bindparams)
if onclause is not None
else None,
from_._gen_cache_key(anon_map, bindparams)
@@ -711,6 +791,11 @@ class _CopyInternals(InternalTraversal):
for sequence in element
]
+ def visit_propagate_attrs(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return element
+
_copy_internals = _CopyInternals()
@@ -782,6 +867,9 @@ class _GetChildren(InternalTraversal):
def visit_dml_multi_values(self, element, **kw):
return ()
+ def visit_propagate_attrs(self, element, **kw):
+ return ()
+
_get_children = _GetChildren()
@@ -916,6 +1004,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
):
return COMPARE_FAILED
+ def visit_propagate_attrs(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.compare_inner(
+ left.get("plugin_subject", None), right.get("plugin_subject", None)
+ )
+
def visit_has_cache_key_list(
self, attrname, left_parent, left, right_parent, right, **kw
):
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index ccda21e11..fe3634bad 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -555,7 +555,12 @@ class TypeEngine(Traversible):
def _static_cache_key(self):
names = util.get_cls_kwargs(self.__class__)
return (self.__class__,) + tuple(
- (k, self.__dict__[k])
+ (
+ k,
+ self.__dict__[k]._static_cache_key
+ if isinstance(self.__dict__[k], TypeEngine)
+ else self.__dict__[k],
+ )
for k in names
if k in self.__dict__ and not k.startswith("_")
)
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 5de68f504..904702003 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -217,18 +217,23 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
try:
dispatcher = target.__class__.__dict__[generate_dispatcher_name]
except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
dispatcher = self.generate_dispatch(
- target, internal_dispatch, generate_dispatcher_name
+ target.__class__, internal_dispatch, generate_dispatcher_name
)
return dispatcher(target, self)
def generate_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
+ self, target_cls, internal_dispatch, generate_dispatcher_name
):
dispatcher = _generate_dispatcher(
self, internal_dispatch, generate_dispatcher_name
)
- setattr(target.__class__, generate_dispatcher_name, dispatcher)
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
return dispatcher
dp_has_cache_key = symbol("HC")
@@ -263,10 +268,6 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
- dp_clauseelement_unordered_set = symbol("CU")
- """Visit an unordered set of :class:`_expression.ClauseElement`
- objects. """
-
dp_fromclause_ordered_set = symbol("CO")
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
@@ -414,6 +415,10 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_propagate_attrs = symbol("PA")
+ """Visit the propagate attrs dict. this hardcodes to the particular
+ elements we care about right now."""
+
class ExtendedInternalTraversal(InternalTraversal):
"""defines additional symbols that are useful in caching applications.
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 7988b4ec9..48cbb4694 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -97,13 +97,13 @@ class CompiledSQL(SQLMatchRule):
else:
map_ = None
- if isinstance(context.compiled.statement, _DDLCompiles):
+ if isinstance(execute_observed.clauseelement, _DDLCompiles):
- compiled = context.compiled.statement.compile(
+ compiled = execute_observed.clauseelement.compile(
dialect=compare_dialect, schema_translate_map=map_
)
else:
- compiled = context.compiled.statement.compile(
+ compiled = execute_observed.clauseelement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
inline=context.compiled.inline,