Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pyodbc.py          |    2
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/provision.py  |    4
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/base.py           |    5
-rw-r--r--  lib/sqlalchemy/engine/cursor.py                  |  390
-rw-r--r--  lib/sqlalchemy/engine/default.py                 |   51
-rw-r--r--  lib/sqlalchemy/engine/result.py                  |   22
-rw-r--r--  lib/sqlalchemy/orm/bulk_persistence.py           | 1459
-rw-r--r--  lib/sqlalchemy/orm/context.py                    |  173
-rw-r--r--  lib/sqlalchemy/orm/descriptor_props.py           |   26
-rw-r--r--  lib/sqlalchemy/orm/evaluator.py                  |   76
-rw-r--r--  lib/sqlalchemy/orm/identity.py                   |   10
-rw-r--r--  lib/sqlalchemy/orm/loading.py                    |    5
-rw-r--r--  lib/sqlalchemy/orm/mapper.py                     |   15
-rw-r--r--  lib/sqlalchemy/orm/persistence.py                |  138
-rw-r--r--  lib/sqlalchemy/orm/query.py                      |    4
-rw-r--r--  lib/sqlalchemy/orm/session.py                    |   25
-rw-r--r--  lib/sqlalchemy/orm/strategies.py                 |   11
-rw-r--r--  lib/sqlalchemy/sql/annotation.py                 |    2
-rw-r--r--  lib/sqlalchemy/sql/compiler.py                   |   24
-rw-r--r--  lib/sqlalchemy/sql/crud.py                       |  119
-rw-r--r--  lib/sqlalchemy/sql/dml.py                        |  409
-rw-r--r--  lib/sqlalchemy/testing/assertsql.py              |   13
-rw-r--r--  lib/sqlalchemy/testing/fixtures.py               |   22
-rw-r--r--  lib/sqlalchemy/testing/suite/test_rowcount.py    |    7
-rw-r--r--  lib/sqlalchemy/util/_py_collections.py           |    5
25 files changed, 2317 insertions, 700 deletions
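
Taken together, the changes below route ORM-level INSERT and UPDATE statements through the bulk persistence machinery, including RETURNING support that is spliced back together across the per-table statements. A minimal sketch of the kind of usage this targets, assuming a mapped class User with id and name columns (names are illustrative only):

    from sqlalchemy import insert

    # executemany-style parameter list passed to Session.execute() along
    # with RETURNING; the ORM assembles a single result from the
    # statements it emits
    result = session.execute(
        insert(User).returning(User.id, User.name),
        [{"name": "sandy"}, {"name": "spongebob"}],
    )
    for user_id, name in result:
        print(user_id, name)
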
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 2eef971cc..5eb6b9528 100644
--- a/lib/sqlalchemy/dialects/mssql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -301,7 +301,7 @@ Fast Executemany Mode
The SQL Server ``fast_executemany`` parameter may be used at the same time
as ``insertmanyvalues`` is enabled; however, the parameter will not be used
in as many cases as INSERT statements that are invoked using Core
- :class:`.Insert` constructs as well as all ORM use no longer use the
+ :class:`_dml.Insert` constructs as well as all ORM use no longer use the
``.executemany()`` DBAPI cursor method.
The PyODBC driver includes support for a "fast executemany" mode of execution
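
For reference, a hedged sketch of enabling the pyodbc fast_executemany flag at the engine level; whether it actually takes effect depends on whether the statement still reaches the ``.executemany()`` DBAPI method as described above (the URL is illustrative only):

    from sqlalchemy import create_engine

    # fast_executemany is a documented mssql+pyodbc dialect argument
    engine = create_engine(
        "mssql+pyodbc://scott:tiger@my_dsn",
        fast_executemany=True,
    )
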
diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py
index 8dd8a4995..4609701a2 100644
--- a/lib/sqlalchemy/dialects/postgresql/provision.py
+++ b/lib/sqlalchemy/dialects/postgresql/provision.py
@@ -134,9 +134,11 @@ def _upsert(cfg, table, returning, set_lambda=None):
stmt = insert(table)
+ table_pk = inspect(table).selectable
+
if set_lambda:
stmt = stmt.on_conflict_do_update(
- index_elements=table.primary_key, set_=set_lambda(stmt.excluded)
+ index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
)
else:
stmt = stmt.on_conflict_do_nothing()
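
The ``_upsert()`` test helper above builds a PostgreSQL upsert; the same pattern written directly against a Core Table looks roughly like this (table and column names are assumptions, not part of the change):

    from sqlalchemy.dialects.postgresql import insert

    stmt = insert(user_table).values(id=1, name="sandy")
    stmt = stmt.on_conflict_do_update(
        index_elements=[user_table.c.id],
        set_={"name": stmt.excluded.name},
    )
    connection.execute(stmt)
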
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index e57a84fe0..5f468edbe 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -1466,11 +1466,6 @@ class SQLiteCompiler(compiler.SQLCompiler):
return target_text
- def visit_insert(self, insert_stmt, **kw):
- if insert_stmt._post_values_clause is not None:
- kw["disable_implicit_returning"] = True
- return super().visit_insert(insert_stmt, **kw)
-
def visit_on_conflict_do_nothing(self, on_conflict, **kw):
target_text = self._on_conflict_target(on_conflict, **kw)
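
With the ``visit_insert()`` hook removed, an INSERT that carries an ON CONFLICT clause no longer has implicit RETURNING disabled on SQLite. A hedged sketch of ON CONFLICT combined with explicit RETURNING, assuming a SQLite version new enough to support RETURNING (3.35+) and an illustrative ``user_table``:

    from sqlalchemy.dialects.sqlite import insert

    stmt = (
        insert(user_table)
        .values(id=1, name="sandy")
        .on_conflict_do_update(
            index_elements=[user_table.c.id],
            set_={"name": "sandy"},
        )
        .returning(user_table.c.id, user_table.c.name)
    )
    row = connection.execute(stmt).one()
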
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index 8840b5916..07e782296 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -23,12 +23,14 @@ from typing import Iterator
from typing import List
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
+from .result import IteratorResult
from .result import MergedResult
from .result import Result
from .result import ResultMetaData
@@ -62,36 +64,80 @@ if typing.TYPE_CHECKING:
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
+ from .result import _KeyMapType
from .result import _KeyType
from .result import _ProcessorsType
+ from .result import _TupleGetterType
from ..sql.type_api import _ResultProcessorType
_T = TypeVar("_T", bound=Any)
+
# metadata entry tuple indexes.
# using raw tuple is faster than namedtuple.
-MD_INDEX: Literal[0] = 0 # integer index in cursor.description
-MD_RESULT_MAP_INDEX: Literal[
- 1
-] = 1 # integer index in compiled._result_columns
-MD_OBJECTS: Literal[
- 2
-] = 2 # other string keys and ColumnElement obj that can match
-MD_LOOKUP_KEY: Literal[
- 3
-] = 3 # string key we usually expect for key-based lookup
-MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description
-MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row
-MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description
+# these match up to the positions in
+# _CursorKeyMapRecType
+MD_INDEX: Literal[0] = 0
+"""integer index in cursor.description
+
+"""
+
+MD_RESULT_MAP_INDEX: Literal[1] = 1
+"""integer index in compiled._result_columns"""
+
+MD_OBJECTS: Literal[2] = 2
+"""other string keys and ColumnElement obj that can match.
+
+This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects
+
+"""
+
+MD_LOOKUP_KEY: Literal[3] = 3
+"""string key we usually expect for key-based lookup
+
+this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name
+"""
+
+
+MD_RENDERED_NAME: Literal[4] = 4
+"""name that is usually in cursor.description
+
+this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname
+"""
+
+
+MD_PROCESSOR: Literal[5] = 5
+"""callable to process a result value into a row"""
+
+MD_UNTRANSLATED: Literal[6] = 6
+"""raw name from cursor.description"""
_CursorKeyMapRecType = Tuple[
- int, int, List[Any], str, str, Optional["_ResultProcessorType"], str
+ Optional[int], # MD_INDEX, None means the record is ambiguously named
+ int, # MD_RESULT_MAP_INDEX
+ List[Any], # MD_OBJECTS
+ str, # MD_LOOKUP_KEY
+ str, # MD_RENDERED_NAME
+ Optional["_ResultProcessorType"], # MD_PROCESSOR
+ Optional[str], # MD_UNTRANSLATED
]
_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType]
+# same as _CursorKeyMapRecType except the MD_INDEX value is definitely
+# not None
+_NonAmbigCursorKeyMapRecType = Tuple[
+ int,
+ int,
+ List[Any],
+ str,
+ str,
+ Optional["_ResultProcessorType"],
+ str,
+]
+
class CursorResultMetaData(ResultMetaData):
"""Result metadata for DBAPI cursors."""
@@ -127,38 +173,112 @@ class CursorResultMetaData(ResultMetaData):
extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
)
- def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
- recs = cast(
- "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys))
+ def _make_new_metadata(
+ self,
+ *,
+ unpickled: bool,
+ processors: _ProcessorsType,
+ keys: Sequence[str],
+ keymap: _KeyMapType,
+ tuplefilter: Optional[_TupleGetterType],
+ translated_indexes: Optional[List[int]],
+ safe_for_cache: bool,
+ keymap_by_result_column_idx: Any,
+ ) -> CursorResultMetaData:
+ new_obj = self.__class__.__new__(self.__class__)
+ new_obj._unpickled = unpickled
+ new_obj._processors = processors
+ new_obj._keys = keys
+ new_obj._keymap = keymap
+ new_obj._tuplefilter = tuplefilter
+ new_obj._translated_indexes = translated_indexes
+ new_obj._safe_for_cache = safe_for_cache
+ new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx
+ return new_obj
+
+ def _remove_processors(self) -> CursorResultMetaData:
+ assert not self._tuplefilter
+ return self._make_new_metadata(
+ unpickled=self._unpickled,
+ processors=[None] * len(self._processors),
+ tuplefilter=None,
+ translated_indexes=None,
+ keymap={
+ key: value[0:5] + (None,) + value[6:]
+ for key, value in self._keymap.items()
+ },
+ keys=self._keys,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx=self._keymap_by_result_column_idx,
)
+ def _splice_horizontally(
+ self, other: CursorResultMetaData
+ ) -> CursorResultMetaData:
+
+ assert not self._tuplefilter
+
+ keymap = self._keymap.copy()
+ offset = len(self._keys)
+ keymap.update(
+ {
+ key: (
+ # int index should be None for ambiguous key
+ value[0] + offset
+ if value[0] is not None and key not in keymap
+ else None,
+ value[1] + offset,
+ *value[2:],
+ )
+ for key, value in other._keymap.items()
+ }
+ )
+
+ return self._make_new_metadata(
+ unpickled=self._unpickled,
+ processors=self._processors + other._processors, # type: ignore
+ tuplefilter=None,
+ translated_indexes=None,
+ keys=self._keys + other._keys, # type: ignore
+ keymap=keymap,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx={
+ metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
+ for metadata_entry in keymap.values()
+ },
+ )
+
+ def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
+ recs = list(self._metadata_for_keys(keys))
+
indexes = [rec[MD_INDEX] for rec in recs]
new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs]
if self._translated_indexes:
indexes = [self._translated_indexes[idx] for idx in indexes]
tup = tuplegetter(*indexes)
-
- new_metadata = self.__class__.__new__(self.__class__)
- new_metadata._unpickled = self._unpickled
- new_metadata._processors = self._processors
- new_metadata._keys = new_keys
- new_metadata._tuplefilter = tup
- new_metadata._translated_indexes = indexes
-
new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)]
- new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
+ keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
# TODO: need unit test for:
# result = connection.execute("raw sql, no columns").scalars()
# without the "or ()" it's failing because MD_OBJECTS is None
- new_metadata._keymap.update(
+ keymap.update(
(e, new_rec)
for new_rec in new_recs
for e in new_rec[MD_OBJECTS] or ()
)
- return new_metadata
+ return self._make_new_metadata(
+ unpickled=self._unpickled,
+ processors=self._processors,
+ keys=new_keys,
+ tuplefilter=tup,
+ translated_indexes=indexes,
+ keymap=keymap,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx=self._keymap_by_result_column_idx,
+ )
def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData:
"""When using a cached Compiled construct that has a _result_map,
@@ -168,6 +288,7 @@ class CursorResultMetaData(ResultMetaData):
as matched to those of the cached statement.
"""
+
if not context.compiled or not context.compiled._result_columns:
return self
@@ -189,7 +310,6 @@ class CursorResultMetaData(ResultMetaData):
# make a copy and add the columns from the invoked statement
# to the result map.
- md = self.__class__.__new__(self.__class__)
keymap_by_position = self._keymap_by_result_column_idx
@@ -201,26 +321,26 @@ class CursorResultMetaData(ResultMetaData):
for metadata_entry in self._keymap.values()
}
- md._keymap = compat.dict_union(
- self._keymap,
- {
- new: keymap_by_position[idx]
- for idx, new in enumerate(
- invoked_statement._all_selected_columns
- )
- if idx in keymap_by_position
- },
- )
-
- md._unpickled = self._unpickled
- md._processors = self._processors
assert not self._tuplefilter
- md._tuplefilter = None
- md._translated_indexes = None
- md._keys = self._keys
- md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
- md._safe_for_cache = self._safe_for_cache
- return md
+ return self._make_new_metadata(
+ keymap=compat.dict_union(
+ self._keymap,
+ {
+ new: keymap_by_position[idx]
+ for idx, new in enumerate(
+ invoked_statement._all_selected_columns
+ )
+ if idx in keymap_by_position
+ },
+ ),
+ unpickled=self._unpickled,
+ processors=self._processors,
+ tuplefilter=None,
+ translated_indexes=None,
+ keys=self._keys,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx=self._keymap_by_result_column_idx,
+ )
def __init__(
self,
@@ -683,7 +803,27 @@ class CursorResultMetaData(ResultMetaData):
untranslated,
)
- def _key_fallback(self, key, err, raiseerr=True):
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: Literal[True] = ...
+ ) -> NoReturn:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: Literal[False] = ...
+ ) -> None:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = ...
+ ) -> Optional[NoReturn]:
+ ...
+
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = True
+ ) -> Optional[NoReturn]:
if raiseerr:
if self._unpickled and isinstance(key, elements.ColumnElement):
@@ -714,9 +854,9 @@ class CursorResultMetaData(ResultMetaData):
try:
rec = self._keymap[key]
except KeyError as ke:
- rec = self._key_fallback(key, ke, raiseerr)
- if rec is None:
- return None
+ x = self._key_fallback(key, ke, raiseerr)
+ assert x is None
+ return None
index = rec[0]
@@ -734,7 +874,7 @@ class CursorResultMetaData(ResultMetaData):
def _metadata_for_keys(
self, keys: Sequence[Any]
- ) -> Iterator[_CursorKeyMapRecType]:
+ ) -> Iterator[_NonAmbigCursorKeyMapRecType]:
for key in keys:
if int in key.__class__.__mro__:
key = self._keys[key]
@@ -750,7 +890,7 @@ class CursorResultMetaData(ResultMetaData):
if index is None:
self._raise_for_ambiguous_column_name(rec)
- yield rec
+ yield cast(_NonAmbigCursorKeyMapRecType, rec)
def __getstate__(self):
return {
@@ -1237,6 +1377,12 @@ _NO_RESULT_METADATA = _NoResultMetaData()
SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
+def null_dml_result() -> IteratorResult[Any]:
+ it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([]))
+ it._soft_close()
+ return it
+
+
class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
@@ -1586,6 +1732,142 @@ class CursorResult(Result[_T]):
"""
return self.context.returned_default_rows
+ def splice_horizontally(self, other):
+ """Return a new :class:`.CursorResult` that "horizontally splices"
+ together the rows of this :class:`.CursorResult` with that of another
+ :class:`.CursorResult`.
+
+ .. tip:: This method is for the benefit of the SQLAlchemy ORM and is
+ not intended for general use.
+
+ "horizontally splices" means that for each row in the first and second
+ result sets, a new row that concatenates the two rows together is
+ produced, which then becomes the new row. The incoming
+ :class:`.CursorResult` must have the identical number of rows. It is
+ typically expected that the two result sets come from the same sort
+ order as well, as the result rows are spliced together based on their
+ position in the result.
+
+ The expected use case here is so that multiple INSERT..RETURNING
+ statements against different tables can produce a single result
+ that looks like a JOIN of those two tables.
+
+ E.g.::
+
+ r1 = connection.execute(
+ users.insert().returning(users.c.user_name, users.c.user_id),
+ user_values
+ )
+
+ r2 = connection.execute(
+ addresses.insert().returning(
+ addresses.c.address_id,
+ addresses.c.address,
+ addresses.c.user_id,
+ ),
+ address_values
+ )
+
+ rows = r1.splice_horizontally(r2).all()
+ assert (
+ rows ==
+ [
+ ("john", 1, 1, "foo@bar.com", 1),
+ ("jack", 2, 2, "bar@bat.com", 2),
+ ]
+ )
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.CursorResult.splice_vertically`
+
+
+ """
+
+ clone = self._generate()
+ total_rows = [
+ tuple(r1) + tuple(r2)
+ for r1, r2 in zip(
+ list(self._raw_row_iterator()),
+ list(other._raw_row_iterator()),
+ )
+ ]
+
+ clone._metadata = clone._metadata._splice_horizontally(other._metadata)
+
+ clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
+ None,
+ initial_buffer=total_rows,
+ )
+ clone._reset_memoizations()
+ return clone
+
+ def splice_vertically(self, other):
+ """Return a new :class:`.CursorResult` that "vertically splices",
+ i.e. "extends", the rows of this :class:`.CursorResult` with that of
+ another :class:`.CursorResult`.
+
+ .. tip:: This method is for the benefit of the SQLAlchemy ORM and is
+ not intended for general use.
+
+ "vertically splices" means the rows of the given result are appended to
+ the rows of this cursor result. The incoming :class:`.CursorResult`
+ must have rows that represent the identical list of columns in the
+ identical order as they are in this :class:`.CursorResult`.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+        :meth:`.CursorResult.splice_horizontally`
+
+ """
+ clone = self._generate()
+ total_rows = list(self._raw_row_iterator()) + list(
+ other._raw_row_iterator()
+ )
+
+ clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
+ None,
+ initial_buffer=total_rows,
+ )
+ clone._reset_memoizations()
+ return clone
+
+ def _rewind(self, rows):
+ """rewind this result back to the given rowset.
+
+ this is used internally for the case where an :class:`.Insert`
+ construct combines the use of
+ :meth:`.Insert.return_defaults` along with the
+ "supplemental columns" feature.
+
+ """
+
+ if self._echo:
+ self.context.connection._log_debug(
+ "CursorResult rewound %d row(s)", len(rows)
+ )
+
+ # the rows given are expected to be Row objects, so we
+ # have to clear out processors which have already run on these
+ # rows
+ self._metadata = cast(
+ CursorResultMetaData, self._metadata
+ )._remove_processors()
+
+ self.cursor_strategy = FullyBufferedCursorFetchStrategy(
+ None,
+ # TODO: if these are Row objects, can we save on not having to
+ # re-make new Row objects out of them a second time? is that
+ # what's actually happening right now? maybe look into this
+ initial_buffer=rows,
+ )
+ self._reset_memoizations()
+ return self
+
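
Like splice_horizontally(), the new splice_vertically() and _rewind() methods above are aimed at internal ORM use. A minimal sketch of vertical splicing, assuming ``r1`` and ``r2`` are :class:`.CursorResult` objects whose statements return the identical columns in the identical order:

    # rows of r2 are appended after the rows of r1; result metadata is
    # taken from r1
    combined = r1.splice_vertically(r2)
    rows = combined.all()
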
@property
def returned_defaults(self):
"""Return the values of default columns that were fetched using
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 11ab713d0..cb3d0528f 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1007,6 +1007,7 @@ class DefaultExecutionContext(ExecutionContext):
_is_implicit_returning = False
_is_explicit_returning = False
+ _is_supplemental_returning = False
_is_server_side = False
_soft_closed = False
@@ -1125,18 +1126,19 @@ class DefaultExecutionContext(ExecutionContext):
self.is_text = compiled.isplaintext
if ii or iu or id_:
+ dml_statement = compiled.compile_state.statement # type: ignore
if TYPE_CHECKING:
- assert isinstance(compiled.statement, UpdateBase)
+ assert isinstance(dml_statement, UpdateBase)
self.is_crud = True
- self._is_explicit_returning = ier = bool(
- compiled.statement._returning
- )
- self._is_implicit_returning = iir = is_implicit_returning = bool(
+ self._is_explicit_returning = ier = bool(dml_statement._returning)
+ self._is_implicit_returning = iir = bool(
compiled.implicit_returning
)
- assert not (
- is_implicit_returning and compiled.statement._returning
- )
+ if iir and dml_statement._supplemental_returning:
+ self._is_supplemental_returning = True
+
+            # don't mix implicit and explicit returning
+ assert not (iir and ier)
if (ier or iir) and compiled.for_executemany:
if ii and not self.dialect.insert_executemany_returning:
@@ -1711,7 +1713,14 @@ class DefaultExecutionContext(ExecutionContext):
# are that the result has only one row, until executemany()
# support is added here.
assert result._metadata.returns_rows
- result._soft_close()
+
+ # Insert statement has both return_defaults() and
+ # returning(). rewind the result on the list of rows
+ # we just used.
+ if self._is_supplemental_returning:
+ result._rewind(rows)
+ else:
+ result._soft_close()
elif not self._is_explicit_returning:
result._soft_close()
@@ -1721,21 +1730,18 @@ class DefaultExecutionContext(ExecutionContext):
# function so this is not necessarily true.
# assert not result.returns_rows
- elif self.isupdate and self._is_implicit_returning:
- # get rowcount
- # (which requires open cursor on some drivers)
- # we were not doing this in 1.4, however
- # test_rowcount -> test_update_rowcount_return_defaults
- # is testing this, and psycopg will no longer return
- # rowcount after cursor is closed.
- result.rowcount
- self._has_rowcount = True
+ elif self._is_implicit_returning:
+ rows = result.all()
- row = result.fetchone()
- if row is not None:
- self.returned_default_rows = [row]
+ if rows:
+ self.returned_default_rows = rows
+ result.rowcount = len(rows)
+ self._has_rowcount = True
- result._soft_close()
+ if self._is_supplemental_returning:
+ result._rewind(rows)
+ else:
+ result._soft_close()
# test that it has a cursor metadata that is accurate.
# the rows have all been fetched however.
@@ -1750,7 +1756,6 @@ class DefaultExecutionContext(ExecutionContext):
elif self.isupdate or self.isdelete:
result.rowcount
self._has_rowcount = True
-
return result
@util.memoized_property
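
The ``_is_supplemental_returning`` / ``_rewind()`` path above covers an INSERT that uses ``return_defaults()`` together with additional RETURNING columns. A hedged Core-level sketch, assuming the ``supplemental_cols`` parameter added to ``return_defaults()`` elsewhere in this changeset and an illustrative ``user_table``:

    from sqlalchemy import insert

    stmt = insert(user_table).return_defaults(
        supplemental_cols=[user_table.c.name]
    )
    result = connection.execute(stmt, {"name": "sandy"})

    # server-generated defaults are available as usual, while the
    # supplemental columns remain fetchable because the result was rewound
    print(result.returned_defaults)
    print(result.all())
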
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index df5a8199c..05ca17063 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -109,9 +109,27 @@ class ResultMetaData:
def _for_freeze(self) -> ResultMetaData:
raise NotImplementedError()
+ @overload
def _key_fallback(
- self, key: _KeyType, err: Exception, raiseerr: bool = True
+ self, key: Any, err: Exception, raiseerr: Literal[True] = ...
) -> NoReturn:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: Literal[False] = ...
+ ) -> None:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = ...
+ ) -> Optional[NoReturn]:
+ ...
+
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = True
+ ) -> Optional[NoReturn]:
assert raiseerr
raise KeyError(key) from err
@@ -2148,6 +2166,7 @@ class IteratorResult(Result[_TP]):
"""
_hard_closed = False
+ _soft_closed = False
def __init__(
self,
@@ -2168,6 +2187,7 @@ class IteratorResult(Result[_TP]):
self.raw._soft_close(hard=hard, **kw)
self.iterator = iter([])
self._reset_memoizations()
+ self._soft_closed = True
def _raise_hard_closed(self) -> NoReturn:
raise exc.ResourceClosedError("This result object is closed.")
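
The ``_key_fallback()`` overloads use a ``Literal``-typed ``raiseerr`` argument so type checkers can distinguish the raising call form (``NoReturn``) from the non-raising one (``None``). A self-contained sketch of the same pattern, using a hypothetical ``fallback()`` function:

    from typing import Literal, NoReturn, Optional, overload

    @overload
    def fallback(raiseerr: Literal[True] = ...) -> NoReturn: ...
    @overload
    def fallback(raiseerr: Literal[False] = ...) -> None: ...
    @overload
    def fallback(raiseerr: bool = ...) -> Optional[NoReturn]: ...

    def fallback(raiseerr: bool = True) -> Optional[NoReturn]:
        # raising form: callers that pass (or default to) True get NoReturn
        if raiseerr:
            raise KeyError("not found")
        return None
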
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
index 225292d17..3ed34a57a 100644
--- a/lib/sqlalchemy/orm/bulk_persistence.py
+++ b/lib/sqlalchemy/orm/bulk_persistence.py
@@ -15,24 +15,32 @@ specifically outside of the flush() process.
from __future__ import annotations
from typing import Any
+from typing import cast
from typing import Dict
from typing import Iterable
+from typing import Optional
+from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import attributes
+from . import context
from . import evaluator
from . import exc as orm_exc
+from . import loading
from . import persistence
from .base import NO_VALUE
from .context import AbstractORMCompileState
+from .context import FromStatement
+from .context import ORMFromStatementCompileState
+from .context import QueryContext
from .. import exc as sa_exc
-from .. import sql
from .. import util
from ..engine import Dialect
from ..engine import result as _result
from ..sql import coercions
+from ..sql import dml
from ..sql import expression
from ..sql import roles
from ..sql import select
@@ -48,16 +56,24 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from .mapper import Mapper
+ from .session import _BindArguments
from .session import ORMExecuteState
+ from .session import Session
from .session import SessionTransaction
from .state import InstanceState
+ from ..engine import Connection
+ from ..engine import cursor
+ from ..engine.interfaces import _CoreAnyExecuteParams
+ from ..engine.interfaces import _ExecuteOptionsParameter
_O = TypeVar("_O", bound=object)
-_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
+_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"]
+_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"]
+@overload
def _bulk_insert(
mapper: Mapper[_O],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
@@ -65,7 +81,36 @@ def _bulk_insert(
isstates: bool,
return_defaults: bool,
render_nulls: bool,
+ use_orm_insert_stmt: Literal[None] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
+) -> cursor.CursorResult[Any]:
+ ...
+
+
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+) -> Optional[cursor.CursorResult[Any]]:
base_mapper = mapper.base_mapper
if session_transaction.session.connection_callable:
@@ -81,13 +126,27 @@ def _bulk_insert(
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
connection = session_transaction.connection(base_mapper)
+
+ return_result: Optional[cursor.CursorResult[Any]] = None
+
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
+ is_joined_inh_supertable = super_mapper is not mapper
+ bookkeeping = (
+ is_joined_inh_supertable
+ or return_defaults
+ or (
+ use_orm_insert_stmt is not None
+ and bool(use_orm_insert_stmt._returning)
+ )
+ )
+
records = (
(
None,
@@ -112,18 +171,25 @@ def _bulk_insert(
table,
((None, mapping, mapper, connection) for mapping in mappings),
bulk=True,
- return_defaults=return_defaults,
+ return_defaults=bookkeeping,
render_nulls=render_nulls,
)
)
- persistence._emit_insert_statements(
+ result = persistence._emit_insert_statements(
base_mapper,
None,
super_mapper,
table,
records,
- bookkeeping=return_defaults,
+ bookkeeping=bookkeeping,
+ use_orm_insert_stmt=use_orm_insert_stmt,
+ execution_options=execution_options,
)
+ if use_orm_insert_stmt is not None:
+ if not use_orm_insert_stmt._returning or return_result is None:
+ return_result = result
+ elif result.returns_rows:
+ return_result = return_result.splice_horizontally(result)
if return_defaults and isstates:
identity_cls = mapper._identity_class
@@ -134,14 +200,43 @@ def _bulk_insert(
tuple([dict_[key] for key in identity_props]),
)
+ if use_orm_insert_stmt is not None:
+ assert return_result is not None
+ return return_result
+
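
``_bulk_insert()`` above is also what ``Session.bulk_insert_mappings()`` ultimately calls into; a hedged usage sketch of that public entry point, assuming a mapped class ``User``:

    session.bulk_insert_mappings(
        User,
        [{"name": "sandy"}, {"name": "spongebob"}],
        return_defaults=True,
    )
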
+@overload
def _bulk_update(
mapper: Mapper[Any],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
session_transaction: SessionTransaction,
isstates: bool,
update_changed_only: bool,
+ use_orm_update_stmt: Literal[None] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = ...,
+) -> _result.Result[Any]:
+ ...
+
+
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = None,
+) -> Optional[_result.Result[Any]]:
base_mapper = mapper.base_mapper
search_keys = mapper._primary_key_propkeys
@@ -161,7 +256,8 @@ def _bulk_update(
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
if session_transaction.session.connection_callable:
raise NotImplementedError(
@@ -172,7 +268,7 @@ def _bulk_update(
connection = session_transaction.connection(base_mapper)
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
records = persistence._collect_update_commands(
@@ -193,8 +289,8 @@ def _bulk_update(
for mapping in mappings
),
bulk=True,
+ use_orm_update_stmt=use_orm_update_stmt,
)
-
persistence._emit_update_statements(
base_mapper,
None,
@@ -202,10 +298,125 @@ def _bulk_update(
table,
records,
bookkeeping=False,
+ use_orm_update_stmt=use_orm_update_stmt,
)
+ if use_orm_update_stmt is not None:
+ return _result.null_result()
+
+
+def _expand_composites(mapper, mappings):
+ composite_attrs = mapper.composites
+ if not composite_attrs:
+ return
+
+ composite_keys = set(composite_attrs.keys())
+ populators = {
+ key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
+ for key in composite_keys
+ }
+ for mapping in mappings:
+ for key in composite_keys.intersection(mapping):
+ populators[key](mapping)
+
class ORMDMLState(AbstractORMCompileState):
+ is_dml_returning = True
+ from_statement_ctx: Optional[ORMFromStatementCompileState] = None
+
+ @classmethod
+ def _get_orm_crud_kv_pairs(
+ cls, mapper, statement, kv_iterator, needs_to_be_cacheable
+ ):
+
+ core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+ for k, v in kv_iterator:
+ k = coercions.expect(roles.DMLColumnRole, k)
+
+ if isinstance(k, str):
+ desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+ if desc is NO_VALUE:
+ yield (
+ coercions.expect(roles.DMLColumnRole, k),
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ )
+ if needs_to_be_cacheable
+ else v,
+ )
+ else:
+ yield from core_get_crud_kv_pairs(
+ statement,
+ desc._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["proxy_key"]
+ )
+ yield from core_get_crud_kv_pairs(
+ statement,
+ attr._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ else:
+ yield (
+ k,
+ v
+ if not needs_to_be_cacheable
+ else coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+
+ @classmethod
+ def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_multi_crud_kv_pairs(
+ statement, kv_iterator
+ )
+
+ return [
+ dict(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper, statement, value_dict.items(), False
+ )
+ )
+ for value_dict in kv_iterator
+ ]
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
+ assert (
+ needs_to_be_cacheable
+ ), "no test coverage for needs_to_be_cacheable=False"
+
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_crud_kv_pairs(
+ statement, kv_iterator, needs_to_be_cacheable
+ )
+
+ return list(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper,
+ statement,
+ kv_iterator,
+ needs_to_be_cacheable,
+ )
+ )
+
@classmethod
def get_entity_description(cls, statement):
ext_info = statement.table._annotations["parententity"]
@@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState):
]
]
+ def _setup_orm_returning(
+ self,
+ compiler,
+ orm_level_statement,
+ dml_level_statement,
+ use_supplemental_cols=True,
+ dml_mapper=None,
+ ):
+ """establish ORM column handlers for an INSERT, UPDATE, or DELETE
+ which uses explicit returning().
+
+ called within compilation level create_for_statement.
+
+ The _return_orm_returning() method then receives the Result
+ after the statement was executed, and applies ORM loading to the
+ state that we first established here.
+
+ """
+
+ if orm_level_statement._returning:
+
+ fs = FromStatement(
+ orm_level_statement._returning, dml_level_statement
+ )
+ fs = fs.options(*orm_level_statement._with_options)
+ self.select_statement = fs
+ self.from_statement_ctx = (
+ fsc
+ ) = ORMFromStatementCompileState.create_for_statement(fs, compiler)
+ fsc.setup_dml_returning_compile_state(dml_mapper)
+
+ dml_level_statement = dml_level_statement._generate()
+ dml_level_statement._returning = ()
+
+ cols_to_return = [c for c in fsc.primary_columns if c is not None]
+
+ # since we are splicing result sets together, make sure there
+ # are columns of some kind returned in each result set
+ if not cols_to_return:
+ cols_to_return.extend(dml_mapper.primary_key)
+
+ if use_supplemental_cols:
+ dml_level_statement = dml_level_statement.return_defaults(
+ supplemental_cols=cols_to_return
+ )
+ else:
+ dml_level_statement = dml_level_statement.returning(
+ *cols_to_return
+ )
+
+ return dml_level_statement
+
+ @classmethod
+ def _return_orm_returning(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+
+ execution_context = result.context
+ compile_state = execution_context.compiled.compile_state
+
+ if compile_state.from_statement_ctx:
+ load_options = execution_options.get(
+ "_sa_orm_load_options", QueryContext.default_load_options
+ )
+ querycontext = QueryContext(
+ compile_state.from_statement_ctx,
+ compile_state.select_statement,
+ params,
+ session,
+ load_options,
+ execution_options,
+ bind_arguments,
+ )
+ return loading.instances(result, querycontext)
+ else:
+ return result
+
class BulkUDCompileState(ORMDMLState):
class default_update_options(Options):
- _synchronize_session: _SynchronizeSessionArgument = "evaluate"
- _is_delete_using = False
- _is_update_from = False
- _autoflush = True
- _subject_mapper = None
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _synchronize_session: _SynchronizeSessionArgument = "auto"
+ _can_use_returning: bool = False
+ _is_delete_using: bool = False
+ _is_update_from: bool = False
+ _autoflush: bool = True
+ _subject_mapper: Optional[Mapper[Any]] = None
_resolved_values = EMPTY_DICT
- _resolved_keys_as_propnames = EMPTY_DICT
- _value_evaluators = EMPTY_DICT
- _matched_objects = None
+ _eval_condition = None
_matched_rows = None
_refresh_identity_token = None
@@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState):
execution_options,
) = BulkUDCompileState.default_update_options.from_execution_options(
"_sa_orm_update_options",
- {"synchronize_session", "is_delete_using", "is_update_from"},
+ {
+ "synchronize_session",
+ "is_delete_using",
+ "is_update_from",
+ "dml_strategy",
+ },
execution_options,
statement._execution_options,
)
- sync = update_options._synchronize_session
- if sync is not None:
- if sync not in ("evaluate", "fetch", False):
- raise sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are 'evaluate', 'fetch', False"
- )
-
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState):
update_options += {"_subject_mapper": plugin_subject.mapper}
+ if not isinstance(params, list):
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "orm"}
+ elif update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "bulk" ORM insert strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "bulk"}
+ elif update_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "orm" ORM insert strategy with a '
+ "separate parameter list"
+ )
+
+ sync = update_options._synchronize_session
+ if sync is not None:
+ if sync not in ("auto", "evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'auto', 'evaluate', 'fetch', False"
+ )
+ if update_options._dml_strategy == "bulk" and sync == "fetch":
+ raise sa_exc.InvalidRequestError(
+ "The 'fetch' synchronization strategy is not available "
+ "for 'bulk' ORM updates (i.e. multiple parameter sets)"
+ )
+
if update_options._autoflush:
session._autoflush()
+ if update_options._dml_strategy == "orm":
+
+ if update_options._synchronize_session == "auto":
+ update_options = cls._do_pre_synchronize_auto(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "auto":
+ update_options += {"_synchronize_session": "evaluate"}
+
+ # indicators from the "pre exec" step that are then
+ # added to the DML statement, which will also be part of the cache
+ # key. The compile level create_for_statement() method will then
+ # consume these at compiler time.
statement = statement._annotate(
{
"synchronize_session": update_options._synchronize_session,
"is_delete_using": update_options._is_delete_using,
"is_update_from": update_options._is_update_from,
+ "dml_strategy": update_options._dml_strategy,
+ "can_use_returning": update_options._can_use_returning,
}
)
- # this stage of the execution is called before the do_orm_execute event
- # hook. meaning for an extension like horizontal sharding, this step
- # happens before the extension splits out into multiple backends and
- # runs only once. if we do pre_sync_fetch, we execute a SELECT
- # statement, which the horizontal sharding extension splits amongst the
- # shards and combines the results together.
-
- if update_options._synchronize_session == "evaluate":
- update_options = cls._do_pre_synchronize_evaluate(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._synchronize_session == "fetch":
- update_options = cls._do_pre_synchronize_fetch(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
-
return (
statement,
util.immutabledict(execution_options).union(
@@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState):
# individual ones we return here.
update_options = execution_options["_sa_orm_update_options"]
- if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_evaluate(session, result, update_options)
- elif update_options._synchronize_session == "fetch":
- cls._do_post_synchronize_fetch(session, result, update_options)
+ if update_options._dml_strategy == "orm":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(
+ session, statement, result, update_options
+ )
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(
+ session, statement, result, update_options
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_bulk_evaluate(
+ session, params, result, update_options
+ )
+ return result
- return result
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
@classmethod
def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
@@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState):
primary_key_convert = [
lookup[bpk] for bpk in mapper.base_mapper.primary_key
]
-
return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
@classmethod
- def _do_pre_synchronize_evaluate(
+ def _get_matched_objects_on_criteria(cls, update_options, states):
+ mapper = update_options._subject_mapper
+ eval_condition = update_options._eval_condition
+
+ raw_data = [
+ (state.obj(), state, state.dict)
+ for state in states
+ if state.mapper.isa(mapper) and not state.expired
+ ]
+
+ identity_token = update_options._refresh_identity_token
+ if identity_token is not None:
+ raw_data = [
+ (obj, state, dict_)
+ for obj, state, dict_ in raw_data
+ if state.identity_token == identity_token
+ ]
+
+ result = []
+ for obj, state, dict_ in raw_data:
+ evaled_condition = eval_condition(obj)
+
+            # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
+ # evaluates as True for all comparisons
+ if (
+ evaled_condition is True
+ or evaled_condition is evaluator._EXPIRED_OBJECT
+ ):
+ result.append(
+ (
+ obj,
+ state,
+ dict_,
+ evaled_condition is evaluator._EXPIRED_OBJECT,
+ )
+ )
+ return result
+
+ @classmethod
+ def _eval_condition_from_statement(cls, update_options, statement):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
+ if statement._where_criteria:
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
+ else:
+
+ def eval_condition(obj):
+ return True
+
+ return eval_condition
+
+ @classmethod
+ def _do_pre_synchronize_auto(
cls,
session,
statement,
@@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState):
bind_arguments,
update_options,
):
- mapper = update_options._subject_mapper
- target_cls = mapper.class_
+ """setup auto sync strategy
+
+
+ "auto" checks if we can use "evaluate" first, then falls back
+ to "fetch"
+
+ evaluate is vastly more efficient for the common case
+ where session is empty, only has a few objects, and the UPDATE
+ statement can potentially match thousands/millions of rows.
- value_evaluators = resolved_keys_as_propnames = EMPTY_DICT
+        OTOH, more complex criteria that fail to work with "evaluate"
+        will, we hope, usually correlate with fewer net rows.
+
+ """
try:
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- crit = ()
- if statement._where_criteria:
- crit += statement._where_criteria
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
- global_attributes = {}
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(global_attributes)
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ return update_options + {
+ "_eval_condition": eval_condition,
+ "_synchronize_session": "evaluate",
+ }
- if global_attributes:
- crit += cls._adjust_for_extra_criteria(
- global_attributes, mapper
- )
+ update_options += {"_synchronize_session": "fetch"}
+ return cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
- if crit:
- eval_condition = evaluator_compiler.process(*crit)
- else:
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
- def eval_condition(obj):
- return True
+ try:
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
except evaluator.UnevaluatableError as err:
raise sa_exc.InvalidRequestError(
@@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState):
"synchronize_session execution option." % err
) from err
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- value_evaluators = {}
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- # TODO: detect when the where clause is a trivial primary key match.
- matched_objects = [
- state.obj()
- for state in session.identity_map.all_states()
- if state.mapper.isa(mapper)
- and not state.expired
- and eval_condition(state.obj())
- and (
- update_options._refresh_identity_token is None
- # TODO: coverage for the case where horizontal sharding
- # invokes an update() or delete() given an explicit identity
- # token up front
- or state.identity_token
- == update_options._refresh_identity_token
- )
- ]
return update_options + {
- "_matched_objects": matched_objects,
- "_value_evaluators": value_evaluators,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_eval_condition": eval_condition,
}
@classmethod
@@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState):
def _resolved_keys_as_propnames(cls, mapper, resolved_values):
values = []
for k, v in resolved_values:
- if isinstance(k, attributes.QueryableAttribute):
- values.append((k.key, v))
- continue
- elif hasattr(k, "__clause_element__"):
- k = k.__clause_element__()
-
if mapper and isinstance(k, expression.ColumnElement):
try:
attr = mapper._columntoproperty[k]
@@ -599,7 +992,8 @@ class BulkUDCompileState(ORMDMLState):
values.append((attr.key, v))
else:
raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k
+ "Attribute name not found, can't be "
+ "synchronized back to objects: %r" % k
)
return values
@@ -622,14 +1016,43 @@ class BulkUDCompileState(ORMDMLState):
)
select_stmt._where_criteria = statement._where_criteria
+ # conditionally run the SELECT statement for pre-fetch, testing the
+ # "bind" for if we can use RETURNING or not using the do_orm_execute
+ # event. If RETURNING is available, the do_orm_execute event
+ # will cancel the SELECT from being actually run.
+ #
+ # The way this is organized seems strange, why don't we just
+        # call can_use_returning() before invoking the statement and get the
+        # answer? Why does this go through the whole execute phase using an
+ # event? Answer: because we are integrating with extensions such
+        # as the horizontal sharding extension that "multiplexes" an individual
+ # statement run through multiple engines, and it uses
+ # do_orm_execute() to do that.
+
+ can_use_returning = None
+
def skip_for_returning(orm_context: ORMExecuteState) -> Any:
bind = orm_context.session.get_bind(**orm_context.bind_arguments)
- if cls.can_use_returning(
+ nonlocal can_use_returning
+
+ per_bind_result = cls.can_use_returning(
bind.dialect,
mapper,
is_update_from=update_options._is_update_from,
is_delete_using=update_options._is_delete_using,
- ):
+ )
+
+ if can_use_returning is not None:
+ if can_use_returning != per_bind_result:
+ raise sa_exc.InvalidRequestError(
+ "For synchronize_session='fetch', can't mix multiple "
+ "backends where some support RETURNING and others "
+ "don't"
+ )
+ else:
+ can_use_returning = per_bind_result
+
+ if per_bind_result:
return _result.null_result()
else:
return None
@@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState):
)
matched_rows = result.fetchall()
- value_evaluators = EMPTY_DICT
-
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- target_cls = mapper.class_
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
-
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- value_evaluators = {}
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- else:
- resolved_keys_as_propnames = EMPTY_DICT
-
return update_options + {
- "_value_evaluators": value_evaluators,
"_matched_rows": matched_rows,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_can_use_returning": can_use_returning,
}
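
The pre- and post-synchronization hooks above are selected via the ``synchronize_session`` execution option; ``"auto"`` tries ``"evaluate"`` first and falls back to ``"fetch"`` as described in ``_do_pre_synchronize_auto()``. A hedged usage sketch, assuming a mapped class ``User``:

    from sqlalchemy import update

    session.execute(
        update(User)
        .where(User.name == "sandy")
        .values(name="squidward"),
        execution_options={"synchronize_session": "fetch"},
    )
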
@CompileState.plugin_for("orm", "insert")
-class ORMInsert(ORMDMLState, InsertDMLState):
+class BulkORMInsert(ORMDMLState, InsertDMLState):
+ class default_insert_options(Options):
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _render_nulls: bool = False
+ _return_defaults: bool = False
+ _subject_mapper: Optional[Mapper[Any]] = None
+
+ select_statement: Optional[FromStatement] = None
+
@classmethod
def orm_pre_session_exec(
cls,
@@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState):
bind_arguments,
is_reentrant_invoke,
):
+
+ (
+ insert_options,
+ execution_options,
+ ) = BulkORMInsert.default_insert_options.from_execution_options(
+ "_sa_orm_insert_options",
+ {"dml_strategy"},
+ execution_options,
+ statement._execution_options,
+ )
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState):
else:
bind_arguments["mapper"] = plugin_subject.mapper
+ insert_options += {"_subject_mapper": plugin_subject.mapper}
+
+ if not params:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "orm"}
+ elif insert_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "bulk" ORM insert strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "bulk"}
+ elif insert_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "orm" ORM insert strategy with a '
+ "separate parameter list"
+ )
+
+ if insert_options._dml_strategy != "raw":
+ # for ORM object loading, like ORMContext, we have to disable
+ # result set adapt_to_context, because we will be generating a
+ # new statement with specific columns that's cached inside of
+ # an ORMFromStatementCompileState, which we will re-use for
+ # each result.
+ if not execution_options:
+ execution_options = context._orm_load_exec_options
+ else:
+ execution_options = execution_options.union(
+ context._orm_load_exec_options
+ )
+
+ statement = statement._annotate(
+ {"dml_strategy": insert_options._dml_strategy}
+ )
+
return (
statement,
- util.immutabledict(execution_options),
+ util.immutabledict(execution_options).union(
+ {"_sa_orm_insert_options": insert_options}
+ ),
)
@classmethod
- def orm_setup_cursor_result(
+ def orm_execute_statement(
cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- ):
- return result
+ session: Session,
+ statement: dml.Insert,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ insert_options = execution_options.get(
+ "_sa_orm_insert_options", cls.default_insert_options
+ )
+
+ if insert_options._dml_strategy not in (
+ "raw",
+ "bulk",
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM insert strategy "
+                "are 'raw', 'orm', 'bulk', 'auto'"
+ )
+
+ result: _result.Result[Any]
+
+ if insert_options._dml_strategy == "raw":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return result
+
+ if insert_options._dml_strategy == "bulk":
+ mapper = insert_options._subject_mapper
+
+ if (
+ statement._post_values_clause is not None
+ and mapper._multiple_persistence_tables
+ ):
+ raise sa_exc.InvalidRequestError(
+ "bulk INSERT with a 'post values' clause "
+ "(typically upsert) not supported for multi-table "
+ f"mapper {mapper}"
+ )
+
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_insert(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ return_defaults=insert_options._return_defaults,
+ render_nulls=insert_options._render_nulls,
+ use_orm_insert_stmt=statement,
+ execution_options=execution_options,
+ )
+ elif insert_options._dml_strategy == "orm":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ else:
+ raise AssertionError()
+
+ if not bool(statement._returning):
+ return result
+
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
+
+ self = cast(
+ BulkORMInsert,
+ super().create_for_statement(statement, compiler, **kw),
+ )
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ else:
+ toplevel = True
+ if not toplevel:
+ return self
+
+ mapper = statement._propagate_attrs["plugin_subject"]
+ dml_strategy = statement._annotations.get("dml_strategy", "raw")
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_insert(compiler)
+ elif dml_strategy == "orm":
+ self._setup_for_orm_insert(compiler, mapper)
+
+ return self
+
+ @classmethod
+ def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
+ return {
+ col.key if col is not None else k: v
+ for col, k, v in (
+ (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
+ )
+ }
+
+ def _setup_for_orm_insert(self, compiler, mapper):
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=False,
+ )
+ self.statement = statement
+
+ def _setup_for_bulk_insert(self, compiler):
+ """establish an INSERT statement within the context of
+ bulk insert.
+
+ This method will be within the "conn.execute()" call that is invoked
+ by persistence._emit_insert_statement().
+
+ """
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+ an = statement._annotations
+
+ emit_insert_table, emit_insert_mapper = (
+ an["_emit_insert_table"],
+ an["_emit_insert_mapper"],
+ )
+
+ statement = statement._clone()
+
+ statement.table = emit_insert_table
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_insert_table
+ }
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=True,
+ dml_mapper=emit_insert_mapper,
+ )
+
+ self.statement = statement
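
``BulkORMInsert`` chooses between the "raw", "bulk" and "orm" code paths via the ``dml_strategy`` execution option introduced here; ``"auto"`` normally picks based on whether a separate parameter list was passed. A hedged sketch forcing the raw path, assuming a mapped class ``User``:

    from sqlalchemy import insert

    session.execute(
        insert(User),
        [{"name": "sandy"}, {"name": "spongebob"}],
        execution_options={"dml_strategy": "raw"},
    )
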
@CompileState.plugin_for("orm", "update")
@@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
self = cls.__new__(cls)
+ dml_strategy = statement._annotations.get(
+ "dml_strategy", "unspecified"
+ )
+
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_update(statement, compiler)
+ elif dml_strategy in ("orm", "unspecified"):
+ self._setup_for_orm_update(statement, compiler)
+
+ return self
+
+ def _setup_for_orm_update(self, statement, compiler, **kw):
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
self.extra_criteria_entities = {}
- self._resolved_values = cls._get_resolved_values(mapper, statement)
+ self._resolved_values = self._get_resolved_values(mapper, statement)
extra_criteria_attributes = {}
@@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
if statement._values:
self._resolved_values = dict(self._resolved_values)
- new_stmt = sql.Update.__new__(sql.Update)
- new_stmt.__dict__.update(statement.__dict__)
+ new_stmt = statement._clone()
new_stmt.table = mapper.local_table
# note if the statement has _multi_values, these
@@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
elif statement._values:
new_stmt._values = self._resolved_values
- new_crit = cls._adjust_for_extra_criteria(
+ new_crit = self._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
@@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
- if compiler._annotations.get(
+ use_supplemental_cols = False
+
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect, mapper, is_multitable=self.is_multitable
- ):
- if new_stmt._returning:
- raise sa_exc.InvalidRequestError(
- "Can't use synchronize_session='fetch' "
- "with explicit returning()"
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally because is_delete_using was not known
+ # at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect, mapper, is_multitable=self.is_multitable
)
- self.statement = self.statement.returning(
- *mapper.local_table.primary_key
)
- return self
+ if synchronize_session == "fetch" and can_use_returning:
+ use_supplemental_cols = True
+
+ # NOTE: we might want to RETURN the actual columns to be
+ # synchronized also. however this is complicated and difficult
+ # to align against the behavior of "evaluate". Additionally,
+ # in a large number (if not the majority) of cases, we have the
+ # "evaluate" answer, usually a fixed value, in memory already and
+ # there's no need to re-fetch the same value
+ # over and over again. so perhaps this could RETURN just
+ # the elements that were based on a SQL expression and not
+ # a constant. For now it doesn't quite seem worth it
+ new_stmt = new_stmt.return_defaults(
+ *(list(mapper.local_table.primary_key))
+ )
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
+
+ def _setup_for_bulk_update(self, statement, compiler, **kw):
+ """establish an UPDATE statement within the context of
+ bulk update.
+
+ This method is called within the "conn.execute()" call that is invoked
+ by persistence._emit_update_statement().
+
+ """
+ statement = cast(dml.Update, statement)
+ an = statement._annotations
+
+ emit_update_table, _ = (
+ an["_emit_update_table"],
+ an["_emit_update_mapper"],
+ )
+
+ statement = statement._clone()
+ statement.table = emit_update_table
+
+ UpdateDMLState.__init__(self, statement, compiler, **kw)
+
+ if self._ordered_values:
+ raise sa_exc.InvalidRequestError(
+ "bulk ORM UPDATE does not support ordered_values() for "
+ "custom UPDATE statements with bulk parameter sets. Use a "
+ "non-bulk UPDATE statement or use values()."
+ )
+
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_update_table
+ }
+ self.statement = statement
+
+ @classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Update,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy not in ("orm", "auto", "bulk"):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM UPDATE strategy "
+ "are 'orm', 'auto', 'bulk'"
+ )
+
+ result: _result.Result[Any]
+
+ if update_options._dml_strategy == "bulk":
+ if statement._where_criteria:
+ raise sa_exc.InvalidRequestError(
+ "WHERE clause with bulk ORM UPDATE not "
+ "supported right now. Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+ mapper = update_options._subject_mapper
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_update(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ update_changed_only=False,
+ use_orm_update_stmt=statement,
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+ else:
+ return super().orm_execute_statement(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ )
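
A minimal sketch of the ``bulk`` branch above (entity and attribute names are
assumptions for illustration); each dictionary must carry the primary key, and
a WHERE clause is rejected as shown in the check above::

    from sqlalchemy import update

    # bulk UPDATE by primary key; routed to _bulk_update() above
    session.execute(
        update(User),
        [
            {"id": 1, "fullname": "Spongebob Squarepants"},
            {"id": 2, "fullname": "Sandy Cheeks"},
        ],
    )
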
@classmethod
def can_use_returning(
@@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
return True
@classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator):
- plugin_subject = statement._propagate_attrs["plugin_subject"]
-
- core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
-
- if not plugin_subject or not plugin_subject.mapper:
- return core_get_crud_kv_pairs(statement, kv_iterator)
-
- mapper = plugin_subject.mapper
-
- values = []
-
- for k, v in kv_iterator:
- k = coercions.expect(roles.DMLColumnRole, k)
+ def _do_post_synchronize_bulk_evaluate(
+ cls, session, params, result, update_options
+ ):
+ if not params:
+ return
- if isinstance(k, str):
- desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
- if desc is NO_VALUE:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- else:
- values.extend(
- core_get_crud_kv_pairs(
- statement, desc._bulk_update_tuples(v)
- )
- )
- elif "entity_namespace" in k._annotations:
- k_anno = k._annotations
- attr = _entity_namespace_key(
- k_anno["entity_namespace"], k_anno["proxy_key"]
- )
- values.extend(
- core_get_crud_kv_pairs(
- statement, attr._bulk_update_tuples(v)
- )
- )
- else:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- return values
+ mapper = update_options._subject_mapper
+ pk_keys = [prop.key for prop in mapper._identity_key_props]
- @classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
+ identity_map = session.identity_map
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
- for obj in update_options._matched_objects:
-
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
+ for param in params:
+ identity_key = mapper.identity_key_from_primary_key(
+ (param[key] for key in pk_keys),
+ update_options._refresh_identity_token,
)
-
- # the evaluated states were gathered across all identity tokens.
- # however the post_sync events are called per identity token,
- # so filter.
- if (
- update_options._refresh_identity_token is not None
- and state.identity_token
- != update_options._refresh_identity_token
- ):
+ state = identity_map.fast_get_state(identity_key)
+ if not state:
continue
+ evaluated_keys = set(param).difference(pk_keys)
+
+ dict_ = state.dict
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ dict_[key] = param[key]
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush was set to False.
+ to_expire = evaluated_keys.intersection(dict_).difference(
+ to_evaluate
+ )
if to_expire:
state._expire_attributes(dict_, to_expire)
- states.add(state)
- session._register_altered(states)
+ @classmethod
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
+ )
+
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
+ )
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
-
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows
@@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
if identity_key in session.identity_map
]
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
+ if not objs:
+ return
- for obj in objs:
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [
+ (
+ obj,
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+ for obj in objs
+ ],
+ )
+
+ @classmethod
+ def _apply_update_set_values_to_objects(
+ cls, session, update_options, statement, matched_objects
+ ):
+ """apply values to objects derived from an update statement, e.g.
+ UPDATE..SET <values>
+
+ """
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ value_evaluators = {}
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ evaluated_keys = list(value_evaluators.keys())
+ attrib = set(k for k, v in resolved_keys_as_propnames)
+
+ states = set()
+ for obj, state, dict_ in matched_objects:
to_evaluate = state.unmodified.intersection(evaluated_keys)
+
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ # only run eval for attributes that are present.
+ dict_[key] = value_evaluators[key](obj)
+
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush was set to False.
to_expire = attrib.intersection(dict_).difference(to_evaluate)
if to_expire:
state._expire_attributes(dict_, to_expire)
@@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
@@ -1002,31 +1743,97 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
if opt._is_criteria_option:
opt.get_global_criteria(extra_criteria_attributes)
+ new_stmt = statement._clone()
+ new_stmt.table = mapper.local_table
+
new_crit = cls._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
- statement = statement.where(*new_crit)
+ new_stmt = new_stmt.where(*new_crit)
# do this first as we need to determine if there is
# DELETE..FROM
- DeleteDMLState.__init__(self, statement, compiler, **kw)
+ DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
+
+ use_supplemental_cols = False
- if compiler._annotations.get(
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect,
- mapper,
- is_multitable=self.is_multitable,
- is_delete_using=compiler._annotations.get(
- "is_delete_using", False
- ),
- ):
- self.statement = statement.returning(*statement.table.primary_key)
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally because is_delete_using was not known
+ # at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect,
+ mapper,
+ is_multitable=self.is_multitable,
+ is_delete_using=compiler._annotations.get(
+ "is_delete_using", False
+ ),
+ )
+ )
+
+ if can_use_returning:
+ use_supplemental_cols = True
+
+ new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
return self
@classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Delete,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ "Bulk ORM DELETE not supported right now. "
+ "Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+
+ if update_options._dml_strategy not in (
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
+ )
+
+ return super().orm_execute_statement(
+ session, statement, params, execution_options, bind_arguments, conn
+ )
+
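For illustration (names assumed), the supported ``orm`` / ``auto`` strategies
cover criteria-based DELETE, while a parameter-list delete is issued at the
Core level as the error message above suggests::

    from sqlalchemy import bindparam, delete

    # ORM-enabled DELETE with criteria ("orm" / "auto")
    session.execute(delete(User).where(User.name == "patrick"))

    # a bulk, parameter-list DELETE goes through the Core connection instead
    session.connection().execute(
        delete(user_table).where(user_table.c.id == bindparam("id")),
        [{"id": 1}, {"id": 2}],
    )
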
+ @classmethod
def can_use_returning(
cls,
dialect: Dialect,
@@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
return True
@classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
- session._remove_newly_deleted(
- [
- attributes.instance_state(obj)
- for obj in update_options._matched_objects
- ]
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
)
+ to_delete = []
+
+ for _, state, dict_, is_partially_expired in matched_objects:
+ if is_partially_expired:
+ state._expire(dict_, session.identity_map._modified)
+ else:
+ to_delete.append(state)
+
+ if to_delete:
+ session._remove_newly_deleted(to_delete)
+
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index dc96f8c3c..f8c7ba714 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from .query import Query
from .session import _BindArguments
from .session import Session
+ from ..engine import Result
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptionsParameter
from ..sql._typing import _ColumnsClauseArgument
@@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict(
class AbstractORMCompileState(CompileState):
+ is_dml_returning = False
+
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> AbstractORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
+
This method is always invoked in the context of SQLCompiler.process().
+
For a Select object, this would be invoked from
SQLCompiler.visit_select(). For the special FromStatement object used
by Query to indicate "Query.from_statement()", this is called by
@@ -233,6 +238,28 @@ class AbstractORMCompileState(CompileState):
raise NotImplementedError()
@classmethod
+ def orm_execute_statement(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ ) -> Result:
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ @classmethod
def orm_setup_cursor_result(
cls,
session,
@@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState):
def __init__(self, *arg, **kw):
raise NotImplementedError()
+ if TYPE_CHECKING:
+
+ @classmethod
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
+ ...
+
def _append_dedupe_col_collection(self, obj, col_collection):
dedupe = self.dedupe_columns
if obj not in dedupe:
@@ -333,26 +371,6 @@ class ORMCompileState(AbstractORMCompileState):
return SelectState._column_naming_convention(label_style)
@classmethod
- def create_for_statement(
- cls,
- statement: Union[Select, FromStatement],
- compiler: Optional[SQLCompiler],
- **kw: Any,
- ) -> ORMCompileState:
- """Create a context for a statement given a :class:`.Compiler`.
-
- This method is always invoked in the context of SQLCompiler.process().
-
- For a Select object, this would be invoked from
- SQLCompiler.visit_select(). For the special FromStatement object used
- by Query to indicate "Query.from_statement()", this is called by
- FromStatement._compiler_dispatch() that would be called by
- SQLCompiler.process().
-
- """
- raise NotImplementedError()
-
- @classmethod
def get_column_descriptions(cls, statement):
return _column_descriptions(statement)
@@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState):
)
+class DMLReturningColFilter:
+ """an adapter used for the DML RETURNING case.
+
+ Has a subset of the interface used by
+ :class:`.ORMAdapter` and is used for :class:`._QueryEntity`
+ instances to set up their columns as used in RETURNING for a
+ DML statement.
+
+ """
+
+ __slots__ = ("mapper", "columns", "__weakref__")
+
+ def __init__(self, target_mapper, immediate_dml_mapper):
+ if (
+ immediate_dml_mapper is not None
+ and target_mapper.local_table
+ is not immediate_dml_mapper.local_table
+ ):
+ # joined inh, or in theory other kinds of multi-table mappings
+ self.mapper = immediate_dml_mapper
+ else:
+ # single inh, normal mappings, etc.
+ self.mapper = target_mapper
+ self.columns = util.WeakPopulateDict(
+ self.adapt_check_present # type: ignore
+ )
+
+ def __call__(self, col, as_filter):
+ for cc in sql_util._find_columns(col):
+ c2 = self.adapt_check_present(cc)
+ if c2 is not None:
+ return col
+ else:
+ return None
+
+ def adapt_check_present(self, col):
+ mapper = self.mapper
+ prop = mapper._columntoproperty.get(col, None)
+ if prop is None:
+ return None
+ return mapper.local_table.c.corresponding_column(col)
+
+
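A rough sketch of the adapter's contract (the mapper and column objects here
are assumed for illustration only)::

    adapter = DMLReturningColFilter(target_mapper, dml_mapper)

    # returns the column itself if it corresponds to a mapped attribute on
    # the chosen mapper's local table, else None so RETURNING skips it
    col = adapter(User.__table__.c.name, True)
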
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
class ORMFromStatementCompileState(ORMCompileState):
_from_obj_alias = None
@@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: FromStatement
requested_statement: Union[SelectBase, TextClause, UpdateBase]
- dml_table: _DMLTableElement
+ dml_table: Optional[_DMLTableElement] = None
_has_orm_entities = False
multi_row_eager_loaders = False
@@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMFromStatementCompileState:
if compiler is not None:
toplevel = not compiler.stack
@@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState):
if statement.is_dml:
self.dml_table = statement.table
+ self.is_dml_returning = True
self._entities = []
self._polymorphic_adapters = {}
@@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState):
def _get_current_adapter(self):
return None
+ def setup_dml_returning_compile_state(self, dml_mapper):
+ """used by BulkORMInsert (and Update / Delete?) to set up a handler
+ for RETURNING to return ORM objects and expressions
+
+ """
+ target_mapper = self.statement._propagate_attrs.get(
+ "plugin_subject", None
+ )
+ adapter = DMLReturningColFilter(target_mapper, dml_mapper)
+ for entity in self._entities:
+ entity.setup_dml_returning_compile_state(self, adapter)
+
class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
@@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMSelectCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
@@ -2312,6 +2386,13 @@ class _QueryEntity:
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ raise NotImplementedError()
+
def row_processor(self, context, result):
raise NotImplementedError()
@@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity):
return _instance, self._label_name, self._extra_entities
- def setup_compile_state(self, compile_state):
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ self,
+ self.path,
+ adapter,
+ compile_state.primary_columns,
+ with_polymorphic=self._with_polymorphic_mappers,
+ only_load_props=compile_state.compile_options._only_load_props,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+ def setup_compile_state(self, compile_state):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
@@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity):
only_load_props=compile_state.compile_options._only_load_props,
polymorphic_discriminator=self._polymorphic_discriminator,
)
-
compile_state._fallback_from_clauses.append(self.selectable)
@@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity):
getter, label_name, extra_entities = self._row_processor
if self.translate_raw_column:
extra_entities += (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, label_name, extra_entities
@@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity):
if self.translate_raw_column:
extra_entities = self._extra_entities + (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, self._label_name, extra_entities
else:
@@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ return
else:
column = self.column
@@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity):
self.entity_zero
) and entity.common_parent(self.entity_zero)
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ self._fetch_column = self.column
+ column = adapter(self.column, False)
+ if column is not None:
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+
def setup_compile_state(self, compile_state):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ assert compile_state.is_dml_returning
+ self._fetch_column = self.column
+ return
else:
column = self.column
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 52b70b9d4..13d3b70fe 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -19,6 +19,7 @@ import operator
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import List
from typing import NoReturn
from typing import Optional
@@ -602,6 +603,31 @@ class Composite(
def _attribute_keys(self) -> Sequence[str]:
return [prop.key for prop in self.props]
+ def _populate_composite_bulk_save_mappings_fn(
+ self,
+ ) -> Callable[[Dict[str, Any]], None]:
+
+ if self._generated_composite_accessor:
+ get_values = self._generated_composite_accessor
+ else:
+
+ def get_values(val: Any) -> Tuple[Any]:
+ return val.__composite_values__() # type: ignore
+
+ attrs = [prop.key for prop in self.props]
+
+ def populate(dest_dict: Dict[str, Any]) -> None:
+ dest_dict.update(
+ {
+ key: val
+ for key, val in zip(
+ attrs, get_values(dest_dict.pop(self.key))
+ )
+ }
+ )
+
+ return populate
+
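A hedged sketch of what the returned callable does (the ``Vertex`` / ``Point``
composite and the ``x1`` / ``y1`` columns are assumptions for illustration)::

    from sqlalchemy import inspect

    populate = inspect(Vertex).attrs[
        "start"
    ]._populate_composite_bulk_save_mappings_fn()

    row = {"id": 1, "start": Point(3, 4)}
    populate(row)
    # the composite key is replaced by its column-level values:
    # row == {"id": 1, "x1": 3, "y1": 4}
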
def get_history(
self,
state: InstanceState[Any],
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
index b3129afdd..5af14cc00 100644
--- a/lib/sqlalchemy/orm/evaluator.py
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -9,8 +9,8 @@
from __future__ import annotations
-import operator
-
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
from .. import exc
from .. import inspect
from .. import util
@@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators):
return None
+class _ExpiredObject(operators.ColumnOperators):
+ def operate(self, *arg, **kw):
+ return self
+
+ def reverse_operate(self, *arg, **kw):
+ return self
+
+
_NO_OBJECT = _NoObject()
+_EXPIRED_OBJECT = _ExpiredObject()
class EvaluatorCompiler:
@@ -73,6 +82,24 @@ class EvaluatorCompiler:
f"alternate class {parentmapper.class_}"
)
key = parentmapper._columntoproperty[clause].key
+ impl = parentmapper.class_manager[key].impl
+
+ if impl is not None:
+
+ def get_corresponding_attr(obj):
+ if obj is None:
+ return _NO_OBJECT
+ state = inspect(obj)
+ dict_ = state.dict
+
+ value = impl.get(
+ state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
+ )
+ if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
+ return _EXPIRED_OBJECT
+ return value
+
+ return get_corresponding_attr
else:
key = clause.key
if (
@@ -85,15 +112,16 @@ class EvaluatorCompiler:
"make use of the actual mapped columns in ORM-evaluated "
"UPDATE / DELETE expressions."
)
+
else:
raise UnevaluatableError(f"Cannot evaluate column: {clause}")
- get_corresponding_attr = operator.attrgetter(key)
- return (
- lambda obj: get_corresponding_attr(obj)
- if obj is not None
- else _NO_OBJECT
- )
+ def get_corresponding_attr(obj):
+ if obj is None:
+ return _NO_OBJECT
+ return getattr(obj, key, _EXPIRED_OBJECT)
+
+ return get_corresponding_attr
def visit_tuple(self, clause):
return self.visit_clauselist(clause)
@@ -134,7 +162,9 @@ class EvaluatorCompiler:
has_null = False
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
- if value:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value:
return True
has_null = has_null or value is None
if has_null:
@@ -147,6 +177,9 @@ class EvaluatorCompiler:
def evaluate(obj):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+
if not value:
if value is None or value is _NO_OBJECT:
return None
@@ -160,7 +193,9 @@ class EvaluatorCompiler:
values = []
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
- if value is None or value is _NO_OBJECT:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value is None or value is _NO_OBJECT:
return None
values.append(value)
return tuple(values)
@@ -183,13 +218,21 @@ class EvaluatorCompiler:
def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
- return eval_left(obj) == eval_right(obj)
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ return left_val == right_val
return evaluate
def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
- return eval_left(obj) != eval_right(obj)
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ return left_val != right_val
return evaluate
@@ -197,8 +240,11 @@ class EvaluatorCompiler:
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
- if left_val is None or right_val is None:
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif left_val is None or right_val is None:
return None
+
return operator(eval_left(obj), eval_right(obj))
return evaluate
@@ -274,7 +320,9 @@ class EvaluatorCompiler:
def evaluate(obj):
value = eval_inner(obj)
- if value is None:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value is None:
return None
return not value
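
In user-facing terms these evaluator changes affect
``synchronize_session="evaluate"``; a minimal sketch (model names assumed)::

    from sqlalchemy import update

    # the WHERE criteria and SET values are also evaluated in Python against
    # objects already present in the Session; attributes that are expired on
    # an object are no longer loaded just to perform that evaluation
    session.execute(
        update(User).where(User.name == "sandy").values(name="squirrel"),
        execution_options={"synchronize_session": "evaluate"},
    )
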
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
index 63b131a78..4848f73f1 100644
--- a/lib/sqlalchemy/orm/identity.py
+++ b/lib/sqlalchemy/orm/identity.py
@@ -68,6 +68,11 @@ class IdentityMap:
) -> Optional[_O]:
raise NotImplementedError()
+ def fast_get_state(
+ self, key: _IdentityKeyType[_O]
+ ) -> Optional[InstanceState[_O]]:
+ raise NotImplementedError()
+
def keys(self) -> Iterable[_IdentityKeyType[Any]]:
return self._dict.keys()
@@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap):
self._dict[key] = state
state._instance_dict = self._wr
+ def fast_get_state(
+ self, key: _IdentityKeyType[_O]
+ ) -> Optional[InstanceState[_O]]:
+ return self._dict.get(key)
+
def get(
self, key: _IdentityKeyType[_O], default: Optional[_O] = None
) -> Optional[_O]:
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 7317d48be..64f2542fd 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -29,7 +29,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from sqlalchemy.orm.context import FromStatement
from . import attributes
from . import exc as orm_exc
from . import path_registry
@@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
from .base import PassiveFlag
+from .context import FromStatement
from .util import _none_set
from .util import state_str
from .. import exc as sa_exc
@@ -50,6 +50,7 @@ from ..sql import util as sql_util
from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import SelectState
+from ..util import EMPTY_DICT
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
@@ -764,7 +765,7 @@ def _instance_processor(
)
quick_populators = path.get(
- context.attributes, "memoized_setups", _none_set
+ context.attributes, "memoized_setups", EMPTY_DICT
)
todo = []
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index c8df51b06..c9cf8f49b 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -854,6 +854,7 @@ class Mapper(
_memoized_values: Dict[Any, Callable[[], Any]]
_inheriting_mappers: util.WeakSequence[Mapper[Any]]
_all_tables: Set[Table]
+ _polymorphic_attr_key: Optional[str]
_pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
_cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
@@ -1653,6 +1654,7 @@ class Mapper(
"""
setter = False
+ polymorphic_key: Optional[str] = None
if self.polymorphic_on is not None:
setter = True
@@ -1772,17 +1774,23 @@ class Mapper(
self._set_polymorphic_identity = (
mapper._set_polymorphic_identity
)
+ self._polymorphic_attr_key = (
+ mapper._polymorphic_attr_key
+ )
self._validate_polymorphic_identity = (
mapper._validate_polymorphic_identity
)
else:
self._set_polymorphic_identity = None
+ self._polymorphic_attr_key = None
return
if setter:
def _set_polymorphic_identity(state):
dict_ = state.dict
+ # TODO: what happens if polymorphic_on column attribute name
+ # does not match .key?
state.get_impl(polymorphic_key).set(
state,
dict_,
@@ -1790,6 +1798,8 @@ class Mapper(
None,
)
+ self._polymorphic_attr_key = polymorphic_key
+
def _validate_polymorphic_identity(mapper, state, dict_):
if (
polymorphic_key in dict_
@@ -1808,6 +1818,7 @@ class Mapper(
_validate_polymorphic_identity
)
else:
+ self._polymorphic_attr_key = None
self._set_polymorphic_identity = None
_validate_polymorphic_identity = None
@@ -3562,6 +3573,10 @@ class Mapper(
return util.LRUCache(self._compiled_cache_size)
@HasMemoized.memoized_attribute
+ def _multiple_persistence_tables(self):
+ return len(self.tables) > 1
+
+ @HasMemoized.memoized_attribute
def _sorted_tables(self):
table_to_mapper: Dict[Table, Mapper[Any]] = {}
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index abd528986..dfb61c28a 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -31,6 +31,7 @@ from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
+from ..engine import cursor as _cursor
from ..sql import operators
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -398,6 +399,11 @@ def _collect_insert_commands(
None
)
+ if bulk and mapper._set_polymorphic_identity:
+ params.setdefault(
+ mapper._polymorphic_attr_key, mapper.polymorphic_identity
+ )
+
yield (
state,
state_dict,
@@ -411,7 +417,11 @@ def _collect_insert_commands(
def _collect_update_commands(
- uowtransaction, table, states_to_update, bulk=False
+ uowtransaction,
+ table,
+ states_to_update,
+ bulk=False,
+ use_orm_update_stmt=None,
):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -437,7 +447,11 @@ def _collect_update_commands(
pks = mapper._pks_by_table[table]
- value_params = {}
+ if use_orm_update_stmt is not None:
+ # TODO: ordered values, etc
+ value_params = use_orm_update_stmt._values
+ else:
+ value_params = {}
propkey_to_col = mapper._propkey_to_col[table]
@@ -697,6 +711,7 @@ def _emit_update_statements(
table,
update,
bookkeeping=True,
+ use_orm_update_stmt=None,
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
@@ -708,7 +723,7 @@ def _emit_update_statements(
execution_options = {"compiled_cache": base_mapper._compiled_cache}
- def update_stmt():
+ def update_stmt(existing_stmt=None):
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
@@ -725,10 +740,17 @@ def _emit_update_statements(
)
)
- stmt = table.update().where(clauses)
+ if existing_stmt is not None:
+ stmt = existing_stmt.where(clauses)
+ else:
+ stmt = table.update().where(clauses)
return stmt
- cached_stmt = base_mapper._memo(("update", table), update_stmt)
+ if use_orm_update_stmt is not None:
+ cached_stmt = update_stmt(use_orm_update_stmt)
+
+ else:
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
for (
(connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
@@ -747,6 +769,15 @@ def _emit_update_statements(
records = list(records)
statement = cached_stmt
+
+ if use_orm_update_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_update_table": table,
+ "_emit_update_mapper": mapper,
+ }
+ )
+
return_defaults = False
if not has_all_pks:
@@ -904,16 +935,35 @@ def _emit_insert_statements(
table,
insert,
bookkeeping=True,
+ use_orm_insert_stmt=None,
+ execution_options=None,
):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
- cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ if use_orm_insert_stmt is not None:
+ cached_stmt = use_orm_insert_stmt
+ exec_opt = util.EMPTY_DICT
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
+ # if a user query with RETURNING was passed, we definitely need
+ # to use RETURNING.
+ returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
+ else:
+ returning_is_required_anyway = False
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ exec_opt = {"compiled_cache": base_mapper._compiled_cache}
+
+ if execution_options:
+ execution_options = util.EMPTY_DICT.merge_with(
+ exec_opt, execution_options
+ )
+ else:
+ execution_options = exec_opt
+
+ return_result = None
for (
- (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ (connection, _, hasvalue, has_all_pks, has_all_defaults),
records,
) in groupby(
insert,
@@ -928,17 +978,29 @@ def _emit_insert_statements(
statement = cached_stmt
+ if use_orm_insert_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_insert_table": table,
+ "_emit_insert_mapper": mapper,
+ }
+ )
+
if (
- not bookkeeping
- or (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not base_mapper.local_table.implicit_returning
- or not connection.dialect.insert_returning
+ (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not base_mapper.local_table.implicit_returning
+ or not connection.dialect.insert_returning
+ )
)
+ and not returning_is_required_anyway
and has_all_pks
and not hasvalue
):
+
# the "we don't need newly generated values back" section.
# here we have all the PKs, all the defaults or we don't want
# to fetch them, or the dialect doesn't support RETURNING at all
@@ -946,7 +1008,7 @@ def _emit_insert_statements(
records = list(records)
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
if bookkeeping:
@@ -962,7 +1024,7 @@ def _emit_insert_statements(
has_all_defaults,
),
last_inserted_params,
- ) in zip(records, c.context.compiled_parameters):
+ ) in zip(records, result.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
@@ -970,19 +1032,20 @@ def _emit_insert_statements(
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
- c.returned_defaults
- if not c.context.executemany
+ result.returned_defaults
+ if not result.context.executemany
else None,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
- # here, we need defaults and/or pk values back.
+ # here, we need defaults and/or pk values back or we otherwise
+ # know that we are using RETURNING in any case
records = list(records)
if (
@@ -991,6 +1054,16 @@ def _emit_insert_statements(
and len(records) > 1
):
do_executemany = True
+ elif returning_is_required_anyway:
+ if connection.dialect.insert_executemany_returning:
+ do_executemany = True
+ else:
+ raise sa_exc.InvalidRequestError(
+ f"Can't use explicit RETURNING for bulk INSERT "
+ f"operation with "
+ f"{connection.dialect.dialect_description} backend; "
+ f"executemany is not supported with RETURNING"
+ )
else:
do_executemany = False
@@ -998,6 +1071,7 @@ def _emit_insert_statements(
statement = statement.return_defaults(
*mapper._server_default_cols[table]
)
+
if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
elif do_executemany:
@@ -1006,10 +1080,16 @@ def _emit_insert_statements(
if do_executemany:
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return_result = result
+ else:
+ return_result = return_result.splice_vertically(result)
+
if bookkeeping:
for (
(
@@ -1027,9 +1107,9 @@ def _emit_insert_statements(
returned_defaults,
) in zip_longest(
records,
- c.context.compiled_parameters,
- c.inserted_primary_key_rows,
- c.returned_defaults_rows or (),
+ result.context.compiled_parameters,
+ result.inserted_primary_key_rows,
+ result.returned_defaults_rows or (),
):
if inserted_primary_key is None:
# this is a real problem and means that we didn't
@@ -1062,7 +1142,7 @@ def _emit_insert_statements(
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
@@ -1071,6 +1151,8 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
+ assert not returning_is_required_anyway
+
for (
state,
state_dict,
@@ -1132,6 +1214,12 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return _cursor.null_dml_result()
+ else:
+ return return_result
+
def _emit_post_update_statements(
base_mapper, uowtransaction, mapper, table, update
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 6d0f055e4..4d5a98fcf 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -2978,7 +2978,7 @@ class Query(
)
def delete(
- self, synchronize_session: _SynchronizeSessionArgument = "evaluate"
+ self, synchronize_session: _SynchronizeSessionArgument = "auto"
) -> int:
r"""Perform a DELETE with an arbitrary WHERE clause.
@@ -3042,7 +3042,7 @@ class Query(
def update(
self,
values: Dict[_DMLColumnArgument, Any],
- synchronize_session: _SynchronizeSessionArgument = "evaluate",
+ synchronize_session: _SynchronizeSessionArgument = "auto",
update_args: Optional[Dict[Any, Any]] = None,
) -> int:
r"""Perform an UPDATE with an arbitrary WHERE clause.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index a690da0d5..64c013306 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget):
statement._propagate_attrs.get("compile_state_plugin", None)
== "orm"
):
- # note that even without "future" mode, we need
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
if TYPE_CHECKING:
- assert isinstance(compile_state_cls, ORMCompileState)
+ assert isinstance(
+ compile_state_cls, context.AbstractORMCompileState
+ )
else:
compile_state_cls = None
@@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget):
statement, params or {}, execution_options=execution_options
)
- result: Result[Any] = conn.execute(
- statement, params or {}, execution_options=execution_options
- )
-
if compile_state_cls:
- result = compile_state_cls.orm_setup_cursor_result(
+ result: Result[Any] = compile_state_cls.orm_execute_statement(
self,
statement,
- params,
+ params or {},
execution_options,
bind_arguments,
- result,
+ conn,
+ )
+ else:
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
)
if _scalar_result:
@@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
@@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: Executable,
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
@@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: Executable,
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 19c6493db..8652591c8 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy):
fetch = self.columns[0]
if adapter:
fetch = adapter.columns[fetch]
+ if fetch is None:
+ # None happens here only for dml bulk_persistence cases
+ # when context.DMLReturningColFilter is used
+ return
+
memoized_populators[self.parent_property] = fetch
def init_class_attribute(self, mapper):
@@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader):
fetch = columns[0]
if adapter:
fetch = adapter.columns[fetch]
+ if fetch is None:
+ # None is not expected to be the result of any
+ # adapter implementation here, however there may be theoretical
+ # usages of returning() with context.DMLReturningColFilter
+ return
+
memoized_populators[self.parent_property] = fetch
def create_row_processor(
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 86b2952cb..262048bd1 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -552,6 +552,8 @@ def _new_annotation_type(
# e.g. BindParameter, add it if present.
if cls.__dict__.get("inherit_cache", False):
anno_cls.inherit_cache = True # type: ignore
+ elif "inherit_cache" in cls.__dict__:
+ anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 201324a2a..c7e226fcc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -5166,6 +5166,8 @@ class SQLCompiler(Compiled):
delete_stmt, delete_stmt.table, extra_froms
)
+ crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
+
if delete_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
delete_stmt, table_text
@@ -5178,13 +5180,14 @@ class SQLCompiler(Compiled):
text += table_text
- if delete_stmt._returning:
- if self.returning_precedes_values:
- text += " " + self.returning_clause(
- delete_stmt,
- delete_stmt._returning,
- populate_result_map=toplevel,
- )
+ if (
+ self.implicit_returning or delete_stmt._returning
+ ) and self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ delete_stmt,
+ self.implicit_returning or delete_stmt._returning,
+ populate_result_map=toplevel,
+ )
if extra_froms:
extra_from_text = self.delete_extra_from_clause(
@@ -5204,10 +5207,12 @@ class SQLCompiler(Compiled):
if t:
text += " WHERE " + t
- if delete_stmt._returning and not self.returning_precedes_values:
+ if (
+ self.implicit_returning or delete_stmt._returning
+ ) and not self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt,
- delete_stmt._returning,
+ self.implicit_returning or delete_stmt._returning,
populate_result_map=toplevel,
)
@@ -5297,7 +5302,6 @@ class StrSQLCompiler(SQLCompiler):
self._label_select_column(None, c, True, False, {})
for c in base._select_iterables(returning_cols)
]
-
return "RETURNING " + ", ".join(columns)
def update_from_clause(
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index b13377a59..22fffb73a 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -150,6 +150,22 @@ def _get_crud_params(
"return_defaults() simultaneously"
)
+ if compile_state.isdelete:
+ _setup_delete_return_defaults(
+ compiler,
+ stmt,
+ compile_state,
+ (),
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ (),
+ (),
+ toplevel,
+ kw,
+ )
+ return _CrudParams([], [])
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
@@ -466,13 +482,6 @@ def _scan_insert_from_select_cols(
kw,
):
- (
- need_pks,
- implicit_returning,
- implicit_return_defaults,
- postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
-
cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
assert compiler.stack[-1]["selectable"] is stmt
@@ -537,6 +546,8 @@ def _scan_cols(
postfetch_lastrowid,
) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
+ assert compile_state.isupdate or compile_state.isinsert
+
if compile_state._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in compile_state._parameter_ordering
@@ -563,6 +574,13 @@ def _scan_cols(
else:
autoincrement_col = insert_null_pk_still_autoincrements = None
+ if stmt._supplemental_returning:
+ supplemental_returning = set(stmt._supplemental_returning)
+ else:
+ supplemental_returning = set()
+
+ compiler_implicit_returning = compiler.implicit_returning
+
for c in cols:
# scan through every column in the target table
@@ -627,11 +645,13 @@ def _scan_cols(
# column has a DDL-level default, and is either not a pk
# column or we don't need the pk.
if implicit_return_defaults and c in implicit_return_defaults:
- compiler.implicit_returning.append(c)
+ compiler_implicit_returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
+
elif implicit_return_defaults and c in implicit_return_defaults:
- compiler.implicit_returning.append(c)
+ compiler_implicit_returning.append(c)
+
elif (
c.primary_key
and c is not stmt.table._autoincrement_column
@@ -652,6 +672,59 @@ def _scan_cols(
kw,
)
+ # adding supplemental cols to implicit_returning in table
+ # order so that order is maintained between multiple INSERT
+ # statements which may have different parameters included, but all
+ # have the same RETURNING clause
+ if (
+ c in supplemental_returning
+ and c not in compiler_implicit_returning
+ ):
+ compiler_implicit_returning.append(c)
+
+ if supplemental_returning:
+ # we should have gotten every col into implicit_returning,
+ # however supplemental returning can also have SQL functions etc.
+ # in it
+ remaining_supplemental = supplemental_returning.difference(
+ compiler_implicit_returning
+ )
+ compiler_implicit_returning.extend(
+ c
+ for c in stmt._supplemental_returning
+ if c in remaining_supplemental
+ )
+
+
+def _setup_delete_return_defaults(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ toplevel,
+ kw,
+):
+ (_, _, implicit_return_defaults, _) = _get_returning_modifiers(
+ compiler, stmt, compile_state, toplevel
+ )
+
+ if not implicit_return_defaults:
+ return
+
+ if stmt._return_defaults_columns:
+ compiler.implicit_returning.extend(implicit_return_defaults)
+
+ if stmt._supplemental_returning:
+ ir_set = set(compiler.implicit_returning)
+ compiler.implicit_returning.extend(
+ c for c in stmt._supplemental_returning if c not in ir_set
+ )
+
def _append_param_parameter(
compiler,
@@ -743,7 +816,7 @@ def _append_param_parameter(
elif compiler.dialect.postfetch_lastrowid:
compiler.postfetch_lastrowid = True
- elif implicit_return_defaults and c in implicit_return_defaults:
+ elif implicit_return_defaults and (c in implicit_return_defaults):
compiler.implicit_returning.append(c)
else:
@@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
INSERT or UPDATE statement after it's invoked.
"""
+
need_pks = (
toplevel
and _compile_state_isinsert(compile_state)
@@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
)
)
and not stmt._returning
+ # and (not stmt._returning or stmt._return_defaults)
and not compile_state._has_multi_parameters
)
@@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
or stmt._return_defaults
)
)
-
if implicit_returning:
postfetch_lastrowid = False
if _compile_state_isinsert(compile_state):
- implicit_return_defaults = implicit_returning and stmt._return_defaults
+ should_implicit_return_defaults = (
+ implicit_returning and stmt._return_defaults
+ )
elif compile_state.isupdate:
- implicit_return_defaults = (
+ should_implicit_return_defaults = (
stmt._return_defaults
and compile_state._primary_table.implicit_returning
and compile_state._supports_implicit_returning
and compiler.dialect.update_returning
)
+ elif compile_state.isdelete:
+ should_implicit_return_defaults = (
+ stmt._return_defaults
+ and compile_state._primary_table.implicit_returning
+ and compile_state._supports_implicit_returning
+ and compiler.dialect.delete_returning
+ )
else:
- # this line is unused, currently we are always
- # isinsert or isupdate
- implicit_return_defaults = False # pragma: no cover
+ should_implicit_return_defaults = False # pragma: no cover
- if implicit_return_defaults:
+ if should_implicit_return_defaults:
if not stmt._return_defaults_columns:
implicit_return_defaults = set(stmt.table.c)
else:
implicit_return_defaults = set(stmt._return_defaults_columns)
+ else:
+ implicit_return_defaults = None
return (
need_pks,
- implicit_returning,
+ implicit_returning or should_implicit_return_defaults,
implicit_return_defaults,
postfetch_lastrowid,
)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index a08e38800..5145a4a16 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -165,15 +165,32 @@ class DMLState(CompileState):
...
@classmethod
+ def _get_multi_crud_kv_pairs(
+ cls,
+ statement: UpdateBase,
+ multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]],
+ ) -> List[Dict[_DMLColumnElement, Any]]:
+ return [
+ {
+ coercions.expect(roles.DMLColumnRole, k): v
+ for k, v in mapping.items()
+ }
+ for mapping in multi_kv_iterator
+ ]
+
+ @classmethod
def _get_crud_kv_pairs(
cls,
statement: UpdateBase,
kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]],
+ needs_to_be_cacheable: bool,
) -> List[Tuple[_DMLColumnElement, Any]]:
return [
(
coercions.expect(roles.DMLColumnRole, k),
- coercions.expect(
+ v
+ if not needs_to_be_cacheable
+ else coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
@@ -269,7 +286,7 @@ class InsertDMLState(DMLState):
def _insert_col_keys(self) -> List[str]:
# this is also done in crud.py -> _key_getters_for_crud_column
return [
- coercions.expect_as_key(roles.DMLColumnRole, col)
+ coercions.expect(roles.DMLColumnRole, col, as_key=True)
for col in self._dict_parameters or ()
]
@@ -326,7 +343,6 @@ class UpdateDMLState(DMLState):
self._extra_froms = ef
self.is_multitable = mt = ef
-
self.include_table_with_column_exprs = bool(
mt and compiler.render_table_with_column_in_update_from
)
@@ -389,6 +405,7 @@ class UpdateBase(
_return_defaults_columns: Optional[
Tuple[_ColumnsClauseElement, ...]
] = None
+ _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None
_returning: Tuple[_ColumnsClauseElement, ...] = ()
is_dml = True
@@ -435,6 +452,215 @@ class UpdateBase(
return self
@_generative
+ def return_defaults(
+ self: SelfUpdateBase,
+ *cols: _DMLColumnArgument,
+ supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None,
+ ) -> SelfUpdateBase:
+ """Make use of a :term:`RETURNING` clause for the purpose
+ of fetching server-side expressions and defaults, for supporting
+ backends only.
+
+ .. deepalchemy::
+
+ The :meth:`.UpdateBase.return_defaults` method is used by the ORM
+ for its internal work in fetching newly generated primary key
+ and server default values, in particular to provide the underlying
+ implementation of the :paramref:`_orm.Mapper.eager_defaults`
+ ORM feature as well as to allow RETURNING support with bulk
+ ORM inserts. Its behavior is fairly idiosyncratic
+ and is not really intended for general use. End users should
+ stick with using :meth:`.UpdateBase.returning` in order to
+ add RETURNING clauses to their INSERT, UPDATE and DELETE
+ statements.
+
+ Normally, a single row INSERT statement will automatically populate the
+ :attr:`.CursorResult.inserted_primary_key` attribute when executed,
+ which stores the primary key of the row that was just inserted in the
+ form of a :class:`.Row` object with column names as named tuple keys
+ (and the :attr:`.Row._mapping` view fully populated as well). The
+ dialect in use chooses the strategy to use in order to populate this
+ data; if it was generated using server-side defaults and / or SQL
+ expressions, dialect-specific approaches such as ``cursor.lastrowid``
+ or ``RETURNING`` are typically used to acquire the new primary key
+ value.
+
+ However, when the statement is modified by calling
+ :meth:`.UpdateBase.return_defaults` before executing the statement,
+ additional behaviors take place **only** for backends that support
+ RETURNING and for :class:`.Table` objects that maintain the
+ :paramref:`.Table.implicit_returning` parameter at its default value of
+ ``True``. In these cases, when the :class:`.CursorResult` is returned
+ from the statement's execution, not only will
+ :attr:`.CursorResult.inserted_primary_key` be populated as always, the
+ :attr:`.CursorResult.returned_defaults` attribute will also be
+ populated with a :class:`.Row` named-tuple representing the full range
+ of server generated
+ values from that single row, including values for any columns that
+ specify :paramref:`_schema.Column.server_default` or which make use of
+ :paramref:`_schema.Column.default` using a SQL expression.
+
+ When invoking INSERT statements with multiple rows using
+ :ref:`insertmanyvalues <engine_insertmanyvalues>`, the
+ :meth:`.UpdateBase.return_defaults` modifier will have the effect of
+ the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
+ :attr:`_engine.CursorResult.returned_defaults_rows` attributes being
+ fully populated with lists of :class:`.Row` objects representing newly
+ inserted primary key values as well as newly inserted server generated
+ values for each row inserted. The
+ :attr:`.CursorResult.inserted_primary_key` and
+ :attr:`.CursorResult.returned_defaults` attributes will also continue
+ to be populated with the first row of these two collections.
+
+ If the backend does not support RETURNING or the :class:`.Table` in use
+ has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
+        clause is added and no additional data is fetched; however, the
+        INSERT, UPDATE or DELETE statement proceeds normally.
+
+ E.g.::
+
+ stmt = table.insert().values(data='newdata').return_defaults()
+
+ result = connection.execute(stmt)
+
+ server_created_at = result.returned_defaults['created_at']
+
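+        When the same INSERT is executed with a list of parameter
+        dictionaries (i.e. "executemany") on an
+        :ref:`insertmanyvalues <engine_insertmanyvalues>`-capable backend,
+        the per-row collections described above may be consumed as in the
+        following sketch, which assumes the same hypothetical ``table``
+        and ``created_at`` column as the example above::
+
+            stmt = table.insert().return_defaults()
+
+            result = connection.execute(
+                stmt, [{'data': 'd1'}, {'data': 'd2'}, {'data': 'd3'}]
+            )
+
+            # one Row of server-generated values per inserted row
+            for pk_row, defaults_row in zip(
+                result.inserted_primary_key_rows,
+                result.returned_defaults_rows,
+            ):
+                print(pk_row, defaults_row._mapping['created_at'])
+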
+ When used against an UPDATE statement
+ :meth:`.UpdateBase.return_defaults` instead looks for columns that
+ include :paramref:`_schema.Column.onupdate` or
+ :paramref:`_schema.Column.server_onupdate` parameters assigned, when
+ constructing the columns that will be included in the RETURNING clause
+ by default if explicit columns were not specified. When used against a
+        DELETE statement, no columns are included in RETURNING by default;
+        they must instead be specified explicitly, as there are no columns
+        that normally change values when a DELETE statement proceeds.
+
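+        For example, given a hypothetical ``updated_at`` column configured
+        with :paramref:`_schema.Column.onupdate`, an UPDATE will fetch the
+        newly generated value for that column::
+
+            stmt = (
+                table.update()
+                .where(table.c.data == 'olddata')
+                .values(data='newdata')
+                .return_defaults()
+            )
+
+            result = connection.execute(stmt)
+
+            server_updated_at = result.returned_defaults['updated_at']
+
+        For a DELETE, columns would be named explicitly, e.g.
+        ``table.delete().return_defaults(table.c.data)``.
+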
+ .. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported
+ for DELETE statements also and has been moved from
+ :class:`.ValuesBase` to :class:`.UpdateBase`.
+
+ The :meth:`.UpdateBase.return_defaults` method is mutually exclusive
+ against the :meth:`.UpdateBase.returning` method and errors will be
+ raised during the SQL compilation process if both are used at the same
+ time on one statement. The RETURNING clause of the INSERT, UPDATE or
+ DELETE statement is therefore controlled by only one of these methods
+ at a time.
+
+ The :meth:`.UpdateBase.return_defaults` method differs from
+ :meth:`.UpdateBase.returning` in these ways:
+
+        1. The :meth:`.UpdateBase.return_defaults` method causes the
+ :attr:`.CursorResult.returned_defaults` collection to be populated
+ with the first row from the RETURNING result. This attribute is not
+ populated when using :meth:`.UpdateBase.returning`.
+
+ 2. :meth:`.UpdateBase.return_defaults` is compatible with existing
+ logic used to fetch auto-generated primary key values that are then
+ populated into the :attr:`.CursorResult.inserted_primary_key`
+ attribute. By contrast, using :meth:`.UpdateBase.returning` will
+ have the effect of the :attr:`.CursorResult.inserted_primary_key`
+ attribute being left unpopulated.
+
+ 3. :meth:`.UpdateBase.return_defaults` can be called against any
+ backend. Backends that don't support RETURNING will skip the usage
+ of the feature, rather than raising an exception. The return value
+ of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
+ for backends that don't support RETURNING or for which the target
+ :class:`.Table` sets :paramref:`.Table.implicit_returning` to
+ ``False``.
+
+ 4. An INSERT statement invoked with executemany() is supported if the
+ backend database driver supports the
+ :ref:`insertmanyvalues <engine_insertmanyvalues>`
+           feature, which is now supported by most SQLAlchemy-included
+           backends.
+ When executemany is used, the
+ :attr:`_engine.CursorResult.returned_defaults_rows` and
+ :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
+ will return the inserted defaults and primary keys.
+
+ .. versionadded:: 1.4 Added
+ :attr:`_engine.CursorResult.returned_defaults_rows` and
+ :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
+ In version 2.0, the underlying implementation which fetches and
+ populates the data for these attributes was generalized to be
+ supported by most backends, whereas in 1.4 they were only
+ supported by the ``psycopg2`` driver.
+
+
+ :param cols: optional list of column key names or
+ :class:`_schema.Column` that acts as a filter for those columns that
+ will be fetched.
+ :param supplemental_cols: optional list of RETURNING expressions,
+ in the same form as one would pass to the
+ :meth:`.UpdateBase.returning` method. When present, the additional
+ columns will be included in the RETURNING clause, and the
+ :class:`.CursorResult` object will be "rewound" when returned, so
+ that methods like :meth:`.CursorResult.all` will return new rows
+ mostly as though the statement used :meth:`.UpdateBase.returning`
+ directly. However, unlike when using :meth:`.UpdateBase.returning`
+          directly, the **order of the columns is undefined**, so the
+          columns can only be targeted using names or
+          :attr:`.Row._mapping` keys; they cannot reliably be targeted
+          positionally.
+
+ .. versionadded:: 2.0
+
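+        An illustration of ``supplemental_cols``, again assuming the
+        hypothetical ``table`` used in the previous examples::
+
+            stmt = (
+                table.insert()
+                .values(data='newdata')
+                .return_defaults(supplemental_cols=[table.c.data])
+            )
+
+            result = connection.execute(stmt)
+
+            # rows are delivered as with .returning(), but should be
+            # targeted by name rather than by position
+            for row in result.mappings():
+                print(row['data'])
+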
+ .. seealso::
+
+ :meth:`.UpdateBase.returning`
+
+ :attr:`_engine.CursorResult.returned_defaults`
+
+ :attr:`_engine.CursorResult.returned_defaults_rows`
+
+ :attr:`_engine.CursorResult.inserted_primary_key`
+
+ :attr:`_engine.CursorResult.inserted_primary_key_rows`
+
+ """
+
+ if self._return_defaults:
+ # note _return_defaults_columns = () means return all columns,
+ # so if we have been here before, only update collection if there
+ # are columns in the collection
+ if self._return_defaults_columns and cols:
+ self._return_defaults_columns = tuple(
+ util.OrderedSet(self._return_defaults_columns).union(
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in cols
+ )
+ )
+ else:
+ # set for all columns
+ self._return_defaults_columns = ()
+ else:
+ self._return_defaults_columns = tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
+ self._return_defaults = True
+
+ if supplemental_cols:
+            # uniquify while also maintaining order (the ordering is
+            # relied upon by test suites and also for vertical splicing)
+ supplemental_col_tup = (
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in supplemental_cols
+ )
+
+ if self._supplemental_returning is None:
+ self._supplemental_returning = tuple(
+ util.unique_list(supplemental_col_tup)
+ )
+ else:
+ self._supplemental_returning = tuple(
+ util.unique_list(
+ self._supplemental_returning
+ + tuple(supplemental_col_tup)
+ )
+ )
+
+ return self
+
+ @_generative
def returning(
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
) -> UpdateBase:
@@ -500,7 +726,7 @@ class UpdateBase(
.. seealso::
- :meth:`.ValuesBase.return_defaults` - an alternative method tailored
+ :meth:`.UpdateBase.return_defaults` - an alternative method tailored
towards efficient fetching of server-side defaults and triggers
for single-row INSERTs or UPDATEs.
@@ -703,7 +929,6 @@ class ValuesBase(UpdateBase):
_select_names: Optional[List[str]] = None
_inline: bool = False
- _returning: Tuple[_ColumnsClauseElement, ...] = ()
def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
@@ -859,7 +1084,15 @@ class ValuesBase(UpdateBase):
)
elif isinstance(arg, collections_abc.Sequence):
- if arg and isinstance(arg[0], (list, dict, tuple)):
+
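+            # a list of dictionaries, e.g.
+            # table.insert().values([{'data': 'd1'}, {'data': 'd2'}]);
+            # coerce the keys of each mapping via the plugin-specific hook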
+ if arg and isinstance(arg[0], dict):
+ multi_kv_generator = DMLState.get_plugin_class(
+ self
+ )._get_multi_crud_kv_pairs
+ self._multi_values += (multi_kv_generator(self, arg),)
+ return self
+
+ if arg and isinstance(arg[0], (list, tuple)):
self._multi_values += (arg,)
return self
@@ -888,173 +1121,13 @@ class ValuesBase(UpdateBase):
# and ensures they get the "crud"-style name when rendered.
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
- coerced_arg = {k: v for k, v in kv_generator(self, arg.items())}
+ coerced_arg = dict(kv_generator(self, arg.items(), True))
if self._values:
self._values = self._values.union(coerced_arg)
else:
self._values = util.immutabledict(coerced_arg)
return self
- @_generative
- def return_defaults(
- self: SelfValuesBase, *cols: _DMLColumnArgument
- ) -> SelfValuesBase:
- """Make use of a :term:`RETURNING` clause for the purpose
- of fetching server-side expressions and defaults, for supporting
- backends only.
-
- .. tip::
-
- The :meth:`.ValuesBase.return_defaults` method is used by the ORM
- for its internal work in fetching newly generated primary key
- and server default values, in particular to provide the underyling
- implementation of the :paramref:`_orm.Mapper.eager_defaults`
- ORM feature. Its behavior is fairly idiosyncratic
- and is not really intended for general use. End users should
- stick with using :meth:`.UpdateBase.returning` in order to
- add RETURNING clauses to their INSERT, UPDATE and DELETE
- statements.
-
- Normally, a single row INSERT statement will automatically populate the
- :attr:`.CursorResult.inserted_primary_key` attribute when executed,
- which stores the primary key of the row that was just inserted in the
- form of a :class:`.Row` object with column names as named tuple keys
- (and the :attr:`.Row._mapping` view fully populated as well). The
- dialect in use chooses the strategy to use in order to populate this
- data; if it was generated using server-side defaults and / or SQL
- expressions, dialect-specific approaches such as ``cursor.lastrowid``
- or ``RETURNING`` are typically used to acquire the new primary key
- value.
-
- However, when the statement is modified by calling
- :meth:`.ValuesBase.return_defaults` before executing the statement,
- additional behaviors take place **only** for backends that support
- RETURNING and for :class:`.Table` objects that maintain the
- :paramref:`.Table.implicit_returning` parameter at its default value of
- ``True``. In these cases, when the :class:`.CursorResult` is returned
- from the statement's execution, not only will
- :attr:`.CursorResult.inserted_primary_key` be populated as always, the
- :attr:`.CursorResult.returned_defaults` attribute will also be
- populated with a :class:`.Row` named-tuple representing the full range
- of server generated
- values from that single row, including values for any columns that
- specify :paramref:`_schema.Column.server_default` or which make use of
- :paramref:`_schema.Column.default` using a SQL expression.
-
- When invoking INSERT statements with multiple rows using
- :ref:`insertmanyvalues <engine_insertmanyvalues>`, the
- :meth:`.ValuesBase.return_defaults` modifier will have the effect of
- the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
- :attr:`_engine.CursorResult.returned_defaults_rows` attributes being
- fully populated with lists of :class:`.Row` objects representing newly
- inserted primary key values as well as newly inserted server generated
- values for each row inserted. The
- :attr:`.CursorResult.inserted_primary_key` and
- :attr:`.CursorResult.returned_defaults` attributes will also continue
- to be populated with the first row of these two collections.
-
- If the backend does not support RETURNING or the :class:`.Table` in use
- has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
- clause is added and no additional data is fetched, however the
- INSERT or UPDATE statement proceeds normally.
-
-
- E.g.::
-
- stmt = table.insert().values(data='newdata').return_defaults()
-
- result = connection.execute(stmt)
-
- server_created_at = result.returned_defaults['created_at']
-
-
- The :meth:`.ValuesBase.return_defaults` method is mutually exclusive
- against the :meth:`.UpdateBase.returning` method and errors will be
- raised during the SQL compilation process if both are used at the same
- time on one statement. The RETURNING clause of the INSERT or UPDATE
- statement is therefore controlled by only one of these methods at a
- time.
-
- The :meth:`.ValuesBase.return_defaults` method differs from
- :meth:`.UpdateBase.returning` in these ways:
-
- 1. :meth:`.ValuesBase.return_defaults` method causes the
- :attr:`.CursorResult.returned_defaults` collection to be populated
- with the first row from the RETURNING result. This attribute is not
- populated when using :meth:`.UpdateBase.returning`.
-
- 2. :meth:`.ValuesBase.return_defaults` is compatible with existing
- logic used to fetch auto-generated primary key values that are then
- populated into the :attr:`.CursorResult.inserted_primary_key`
- attribute. By contrast, using :meth:`.UpdateBase.returning` will
- have the effect of the :attr:`.CursorResult.inserted_primary_key`
- attribute being left unpopulated.
-
- 3. :meth:`.ValuesBase.return_defaults` can be called against any
- backend. Backends that don't support RETURNING will skip the usage
- of the feature, rather than raising an exception. The return value
- of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
- for backends that don't support RETURNING or for which the target
- :class:`.Table` sets :paramref:`.Table.implicit_returning` to
- ``False``.
-
- 4. An INSERT statement invoked with executemany() is supported if the
- backend database driver supports the
- :ref:`insertmanyvalues <engine_insertmanyvalues>`
- feature which is now supported by most SQLAlchemy-included backends.
- When executemany is used, the
- :attr:`_engine.CursorResult.returned_defaults_rows` and
- :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
- will return the inserted defaults and primary keys.
-
- .. versionadded:: 1.4 Added
- :attr:`_engine.CursorResult.returned_defaults_rows` and
- :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
- In version 2.0, the underlying implementation which fetches and
- populates the data for these attributes was generalized to be
- supported by most backends, whereas in 1.4 they were only
- supported by the ``psycopg2`` driver.
-
-
- :param cols: optional list of column key names or
- :class:`_schema.Column` that acts as a filter for those columns that
- will be fetched.
-
- .. seealso::
-
- :meth:`.UpdateBase.returning`
-
- :attr:`_engine.CursorResult.returned_defaults`
-
- :attr:`_engine.CursorResult.returned_defaults_rows`
-
- :attr:`_engine.CursorResult.inserted_primary_key`
-
- :attr:`_engine.CursorResult.inserted_primary_key_rows`
-
- """
-
- if self._return_defaults:
- # note _return_defaults_columns = () means return all columns,
- # so if we have been here before, only update collection if there
- # are columns in the collection
- if self._return_defaults_columns and cols:
- self._return_defaults_columns = tuple(
- set(self._return_defaults_columns).union(
- coercions.expect(roles.ColumnsClauseRole, c)
- for c in cols
- )
- )
- else:
- # set for all columns
- self._return_defaults_columns = ()
- else:
- self._return_defaults_columns = tuple(
- coercions.expect(roles.ColumnsClauseRole, c) for c in cols
- )
- self._return_defaults = True
- return self
-
SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
@@ -1459,7 +1532,7 @@ class Update(DMLWhereBase, ValuesBase):
)
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
- self._ordered_values = kv_generator(self, args)
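+        # passing True ensures the values are coerced to cacheable
+        # SQL expressions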
+ self._ordered_values = kv_generator(self, args, True)
return self
@_generative
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 4416fe630..b3a71dbff 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -68,7 +68,7 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
def __init__(
- self, statement, params=None, dialect="default", enable_returning=False
+ self, statement, params=None, dialect="default", enable_returning=True
):
self.statement = statement
self.params = params
@@ -90,6 +90,17 @@ class CompiledSQL(SQLMatchRule):
dialect.insert_returning = (
dialect.update_returning
) = dialect.delete_returning = True
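+            # have the "default" dialect used for comparison advertise the
+            # same RETURNING-related capabilities as the included dialects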
+ dialect.use_insertmanyvalues = True
+ dialect.supports_multivalues_insert = True
+ dialect.update_returning_multifrom = True
+ dialect.delete_returning_multifrom = True
+ # dialect.favor_returning_over_lastrowid = True
+ # dialect.insert_null_pk_still_autoincrements = True
+
+ # this is calculated but we need it to be True for this
+ # to look like all the current RETURNING dialects
+ assert dialect.insert_executemany_returning
+
return dialect
else:
return url.URL.create(self.dialect).get_dialect()()
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 20dee5273..ef284babc 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -23,7 +23,6 @@ from .util import adict
from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
-from ..orm import declarative_base
from ..orm import DeclarativeBase
from ..orm import MappedAsDataclass
from ..orm import registry
@@ -117,7 +116,7 @@ class TestBase:
metadata=metadata,
type_annotation_map={
str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb"
+ sa.String(50), "mysql", "mariadb", "oracle"
)
},
)
@@ -132,7 +131,7 @@ class TestBase:
metadata = _md
type_annotation_map = {
str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb"
+ sa.String(50), "mysql", "mariadb", "oracle"
)
}
@@ -780,18 +779,19 @@ class DeclarativeMappedTest(MappedTest):
def _with_register_classes(cls, fn):
cls_registry = cls.classes
- class DeclarativeBasic:
+ class _DeclBase(DeclarativeBase):
__table_cls__ = schema.Table
+ metadata = cls._tables_metadata
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb", "oracle"
+ )
+ }
- def __init_subclass__(cls) -> None:
+ def __init_subclass__(cls, **kw) -> None:
assert cls_registry is not None
cls_registry[cls.__name__] = cls
- super().__init_subclass__()
-
- _DeclBase = declarative_base(
- metadata=cls._tables_metadata,
- cls=DeclarativeBasic,
- )
+ super().__init_subclass__(**kw)
cls.DeclarativeBasic = _DeclBase
diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py
index b7d4b7452..8e19a24a8 100644
--- a/lib/sqlalchemy/testing/suite/test_rowcount.py
+++ b/lib/sqlalchemy/testing/suite/test_rowcount.py
@@ -89,8 +89,13 @@ class RowCountTest(fixtures.TablesTest):
eq_(r.rowcount, 3)
@testing.requires.update_returning
- @testing.requires.sane_rowcount_w_returning
def test_update_rowcount_return_defaults(self, connection):
+ """note this test should succeed for all RETURNING backends
+ as of 2.0. In
+ Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
+ len(rows) when we have implicit returning
+
+ """
employees_table = self.tables.employees
department = employees_table.c.department
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index f8348714c..488229abb 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -11,6 +11,7 @@ from __future__ import annotations
from itertools import filterfalse
from typing import AbstractSet
from typing import Any
+from typing import Callable
from typing import cast
from typing import Collection
from typing import Dict
@@ -481,7 +482,9 @@ class IdentitySet:
return "%s(%r)" % (type(self).__name__, list(self._members.values()))
-def unique_list(seq, hashfunc=None):
+def unique_list(
+ seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
+) -> List[_T]:
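+    # de-duplicate seq while preserving the order of first occurrence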
seen: Set[Any] = set()
seen_add = seen.add
if not hashfunc: