45 files changed, 4748 insertions, 1010 deletions
diff --git a/doc/build/orm/session_basics.rst b/doc/build/orm/session_basics.rst index 96b9d8b5c..6b7ef3299 100644 --- a/doc/build/orm/session_basics.rst +++ b/doc/build/orm/session_basics.rst @@ -660,6 +660,17 @@ Selecting a Synchronization Strategy With both the 1.x and 2.0 form of ORM-enabled updates and deletes, the following values for ``synchronize_session`` are supported: +* ``'auto'`` - this is the default. The ``'fetch'`` strategy will be used on + backends that support RETURNING, which includes all SQLAlchemy-native drivers + except for MySQL. If RETURNING is not supported, the ``'evaluate'`` + strategy will be used instead. + + .. versionchanged:: 2.0 Added the ``'auto'`` synchronization strategy. As + most backends now support RETURNING, selecting ``'fetch'`` for these + backends specifically is the more efficient and error-free default for + these backends. The MySQL backend as well as third party backends without + RETURNING support will continue to use ``'evaluate'`` by default. + * ``False`` - don't synchronize the session. This option is the most efficient and is reliable once the session is expired, which typically occurs after a commit(), or explicitly using diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 2eef971cc..5eb6b9528 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -301,7 +301,7 @@ Fast Executemany Mode The SQL Server ``fast_executemany`` parameter may be used at the same time as ``insertmanyvalues`` is enabled; however, the parameter will not be used in as many cases as INSERT statements that are invoked using Core - :class:`.Insert` constructs as well as all ORM use no longer use the + :class:`_dml.Insert` constructs as well as all ORM use no longer use the ``.executemany()`` DBAPI cursor method. 
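A minimal usage sketch of the new ``'auto'`` strategy described in the session_basics.rst hunk above (``User`` is an assumed mapped class with a ``name`` column, not something defined in this changeset); the strategy may also be given as a statement execution option, which is what the ORM consults per the hunks further below::

    from sqlalchemy import update

    stmt = (
        update(User)
        .where(User.name == "sandy")
        .values(name="squidward")
        .execution_options(synchronize_session="auto")
    )
    # on a RETURNING-capable backend "auto" behaves like "fetch";
    # otherwise it falls back to "evaluate"
    result = session.execute(stmt)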
The PyODBC driver includes support for a "fast executemany" mode of execution diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index 8dd8a4995..4609701a2 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -134,9 +134,11 @@ def _upsert(cfg, table, returning, set_lambda=None): stmt = insert(table) + table_pk = inspect(table).selectable + if set_lambda: stmt = stmt.on_conflict_do_update( - index_elements=table.primary_key, set_=set_lambda(stmt.excluded) + index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded) ) else: stmt = stmt.on_conflict_do_nothing() diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index e57a84fe0..5f468edbe 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1466,11 +1466,6 @@ class SQLiteCompiler(compiler.SQLCompiler): return target_text - def visit_insert(self, insert_stmt, **kw): - if insert_stmt._post_values_clause is not None: - kw["disable_implicit_returning"] = True - return super().visit_insert(insert_stmt, **kw) - def visit_on_conflict_do_nothing(self, on_conflict, **kw): target_text = self._on_conflict_target(on_conflict, **kw) diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 8840b5916..07e782296 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -23,12 +23,14 @@ from typing import Iterator from typing import List from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .result import IteratorResult from .result import MergedResult from .result import Result from .result import ResultMetaData @@ -62,36 +64,80 @@ if typing.TYPE_CHECKING: from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType + from .result import _KeyMapType from .result import _KeyType from .result import _ProcessorsType + from .result import _TupleGetterType from ..sql.type_api import _ResultProcessorType _T = TypeVar("_T", bound=Any) + # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. -MD_INDEX: Literal[0] = 0 # integer index in cursor.description -MD_RESULT_MAP_INDEX: Literal[ - 1 -] = 1 # integer index in compiled._result_columns -MD_OBJECTS: Literal[ - 2 -] = 2 # other string keys and ColumnElement obj that can match -MD_LOOKUP_KEY: Literal[ - 3 -] = 3 # string key we usually expect for key-based lookup -MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description -MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row -MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description +# these match up to the positions in +# _CursorKeyMapRecType +MD_INDEX: Literal[0] = 0 +"""integer index in cursor.description + +""" + +MD_RESULT_MAP_INDEX: Literal[1] = 1 +"""integer index in compiled._result_columns""" + +MD_OBJECTS: Literal[2] = 2 +"""other string keys and ColumnElement obj that can match. 
+ +This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects + +""" + +MD_LOOKUP_KEY: Literal[3] = 3 +"""string key we usually expect for key-based lookup + +this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name +""" + + +MD_RENDERED_NAME: Literal[4] = 4 +"""name that is usually in cursor.description + +this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname +""" + + +MD_PROCESSOR: Literal[5] = 5 +"""callable to process a result value into a row""" + +MD_UNTRANSLATED: Literal[6] = 6 +"""raw name from cursor.description""" _CursorKeyMapRecType = Tuple[ - int, int, List[Any], str, str, Optional["_ResultProcessorType"], str + Optional[int], # MD_INDEX, None means the record is ambiguously named + int, # MD_RESULT_MAP_INDEX + List[Any], # MD_OBJECTS + str, # MD_LOOKUP_KEY + str, # MD_RENDERED_NAME + Optional["_ResultProcessorType"], # MD_PROCESSOR + Optional[str], # MD_UNTRANSLATED ] _CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] +# same as _CursorKeyMapRecType except the MD_INDEX value is definitely +# not None +_NonAmbigCursorKeyMapRecType = Tuple[ + int, + int, + List[Any], + str, + str, + Optional["_ResultProcessorType"], + str, +] + class CursorResultMetaData(ResultMetaData): """Result metadata for DBAPI cursors.""" @@ -127,38 +173,112 @@ class CursorResultMetaData(ResultMetaData): extra=[self._keymap[key][MD_OBJECTS] for key in self._keys], ) - def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: - recs = cast( - "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys)) + def _make_new_metadata( + self, + *, + unpickled: bool, + processors: _ProcessorsType, + keys: Sequence[str], + keymap: _KeyMapType, + tuplefilter: Optional[_TupleGetterType], + translated_indexes: Optional[List[int]], + safe_for_cache: bool, + keymap_by_result_column_idx: Any, + ) -> CursorResultMetaData: + new_obj = self.__class__.__new__(self.__class__) + new_obj._unpickled = unpickled + new_obj._processors = processors + new_obj._keys = keys + new_obj._keymap = keymap + new_obj._tuplefilter = tuplefilter + new_obj._translated_indexes = translated_indexes + new_obj._safe_for_cache = safe_for_cache + new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx + return new_obj + + def _remove_processors(self) -> CursorResultMetaData: + assert not self._tuplefilter + return self._make_new_metadata( + unpickled=self._unpickled, + processors=[None] * len(self._processors), + tuplefilter=None, + translated_indexes=None, + keymap={ + key: value[0:5] + (None,) + value[6:] + for key, value in self._keymap.items() + }, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, ) + def _splice_horizontally( + self, other: CursorResultMetaData + ) -> CursorResultMetaData: + + assert not self._tuplefilter + + keymap = self._keymap.copy() + offset = len(self._keys) + keymap.update( + { + key: ( + # int index should be None for ambiguous key + value[0] + offset + if value[0] is not None and key not in keymap + else None, + value[1] + offset, + *value[2:], + ) + for key, value in other._keymap.items() + } + ) + + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors + other._processors, # type: ignore + tuplefilter=None, + translated_indexes=None, + keys=self._keys + other._keys, # type: ignore + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx={ + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for 
metadata_entry in keymap.values() + }, + ) + + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + recs = list(self._metadata_for_keys(keys)) + indexes = [rec[MD_INDEX] for rec in recs] new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs] if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] tup = tuplegetter(*indexes) - - new_metadata = self.__class__.__new__(self.__class__) - new_metadata._unpickled = self._unpickled - new_metadata._processors = self._processors - new_metadata._keys = new_keys - new_metadata._tuplefilter = tup - new_metadata._translated_indexes = indexes - new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)] - new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} + keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} # TODO: need unit test for: # result = connection.execute("raw sql, no columns").scalars() # without the "or ()" it's failing because MD_OBJECTS is None - new_metadata._keymap.update( + keymap.update( (e, new_rec) for new_rec in new_recs for e in new_rec[MD_OBJECTS] or () ) - return new_metadata + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors, + keys=new_keys, + tuplefilter=tup, + translated_indexes=indexes, + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: """When using a cached Compiled construct that has a _result_map, @@ -168,6 +288,7 @@ class CursorResultMetaData(ResultMetaData): as matched to those of the cached statement. """ + if not context.compiled or not context.compiled._result_columns: return self @@ -189,7 +310,6 @@ class CursorResultMetaData(ResultMetaData): # make a copy and add the columns from the invoked statement # to the result map. - md = self.__class__.__new__(self.__class__) keymap_by_position = self._keymap_by_result_column_idx @@ -201,26 +321,26 @@ class CursorResultMetaData(ResultMetaData): for metadata_entry in self._keymap.values() } - md._keymap = compat.dict_union( - self._keymap, - { - new: keymap_by_position[idx] - for idx, new in enumerate( - invoked_statement._all_selected_columns - ) - if idx in keymap_by_position - }, - ) - - md._unpickled = self._unpickled - md._processors = self._processors assert not self._tuplefilter - md._tuplefilter = None - md._translated_indexes = None - md._keys = self._keys - md._keymap_by_result_column_idx = self._keymap_by_result_column_idx - md._safe_for_cache = self._safe_for_cache - return md + return self._make_new_metadata( + keymap=compat.dict_union( + self._keymap, + { + new: keymap_by_position[idx] + for idx, new in enumerate( + invoked_statement._all_selected_columns + ) + if idx in keymap_by_position + }, + ), + unpickled=self._unpickled, + processors=self._processors, + tuplefilter=None, + translated_indexes=None, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) def __init__( self, @@ -683,7 +803,27 @@ class CursorResultMetaData(ResultMetaData): untranslated, ) - def _key_fallback(self, key, err, raiseerr=True): + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[True] = ... + ) -> NoReturn: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[False] = ... + ) -> None: + ... 
+ + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = ... + ) -> Optional[NoReturn]: + ... + + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = True + ) -> Optional[NoReturn]: if raiseerr: if self._unpickled and isinstance(key, elements.ColumnElement): @@ -714,9 +854,9 @@ class CursorResultMetaData(ResultMetaData): try: rec = self._keymap[key] except KeyError as ke: - rec = self._key_fallback(key, ke, raiseerr) - if rec is None: - return None + x = self._key_fallback(key, ke, raiseerr) + assert x is None + return None index = rec[0] @@ -734,7 +874,7 @@ class CursorResultMetaData(ResultMetaData): def _metadata_for_keys( self, keys: Sequence[Any] - ) -> Iterator[_CursorKeyMapRecType]: + ) -> Iterator[_NonAmbigCursorKeyMapRecType]: for key in keys: if int in key.__class__.__mro__: key = self._keys[key] @@ -750,7 +890,7 @@ class CursorResultMetaData(ResultMetaData): if index is None: self._raise_for_ambiguous_column_name(rec) - yield rec + yield cast(_NonAmbigCursorKeyMapRecType, rec) def __getstate__(self): return { @@ -1237,6 +1377,12 @@ _NO_RESULT_METADATA = _NoResultMetaData() SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]") +def null_dml_result() -> IteratorResult[Any]: + it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([])) + it._soft_close() + return it + + class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. @@ -1586,6 +1732,142 @@ class CursorResult(Result[_T]): """ return self.context.returned_default_rows + def splice_horizontally(self, other): + """Return a new :class:`.CursorResult` that "horizontally splices" + together the rows of this :class:`.CursorResult` with that of another + :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "horizontally splices" means that for each row in the first and second + result sets, a new row that concatenates the two rows together is + produced, which then becomes the new row. The incoming + :class:`.CursorResult` must have the identical number of rows. It is + typically expected that the two result sets come from the same sort + order as well, as the result rows are spliced together based on their + position in the result. + + The expected use case here is so that multiple INSERT..RETURNING + statements against different tables can produce a single result + that looks like a JOIN of those two tables. + + E.g.:: + + r1 = connection.execute( + users.insert().returning(users.c.user_name, users.c.user_id), + user_values + ) + + r2 = connection.execute( + addresses.insert().returning( + addresses.c.address_id, + addresses.c.address, + addresses.c.user_id, + ), + address_values + ) + + rows = r1.splice_horizontally(r2).all() + assert ( + rows == + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ] + ) + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.CursorResult.splice_vertically` + + + """ + + clone = self._generate() + total_rows = [ + tuple(r1) + tuple(r2) + for r1, r2 in zip( + list(self._raw_row_iterator()), + list(other._raw_row_iterator()), + ) + ] + + clone._metadata = clone._metadata._splice_horizontally(other._metadata) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def splice_vertically(self, other): + """Return a new :class:`.CursorResult` that "vertically splices", + i.e. 
"extends", the rows of this :class:`.CursorResult` with that of + another :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "vertically splices" means the rows of the given result are appended to + the rows of this cursor result. The incoming :class:`.CursorResult` + must have rows that represent the identical list of columns in the + identical order as they are in this :class:`.CursorResult`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`.CursorResult.splice_horizontally` + + """ + clone = self._generate() + total_rows = list(self._raw_row_iterator()) + list( + other._raw_row_iterator() + ) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def _rewind(self, rows): + """rewind this result back to the given rowset. + + this is used internally for the case where an :class:`.Insert` + construct combines the use of + :meth:`.Insert.return_defaults` along with the + "supplemental columns" feature. + + """ + + if self._echo: + self.context.connection._log_debug( + "CursorResult rewound %d row(s)", len(rows) + ) + + # the rows given are expected to be Row objects, so we + # have to clear out processors which have already run on these + # rows + self._metadata = cast( + CursorResultMetaData, self._metadata + )._remove_processors() + + self.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + # TODO: if these are Row objects, can we save on not having to + # re-make new Row objects out of them a second time? is that + # what's actually happening right now? maybe look into this + initial_buffer=rows, + ) + self._reset_memoizations() + return self + @property def returned_defaults(self): """Return the values of default columns that were fetched using diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 11ab713d0..cb3d0528f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1007,6 +1007,7 @@ class DefaultExecutionContext(ExecutionContext): _is_implicit_returning = False _is_explicit_returning = False + _is_supplemental_returning = False _is_server_side = False _soft_closed = False @@ -1125,18 +1126,19 @@ class DefaultExecutionContext(ExecutionContext): self.is_text = compiled.isplaintext if ii or iu or id_: + dml_statement = compiled.compile_state.statement # type: ignore if TYPE_CHECKING: - assert isinstance(compiled.statement, UpdateBase) + assert isinstance(dml_statement, UpdateBase) self.is_crud = True - self._is_explicit_returning = ier = bool( - compiled.statement._returning - ) - self._is_implicit_returning = iir = is_implicit_returning = bool( + self._is_explicit_returning = ier = bool(dml_statement._returning) + self._is_implicit_returning = iir = bool( compiled.implicit_returning ) - assert not ( - is_implicit_returning and compiled.statement._returning - ) + if iir and dml_statement._supplemental_returning: + self._is_supplemental_returning = True + + # dont mix implicit and explicit returning + assert not (iir and ier) if (ier or iir) and compiled.for_executemany: if ii and not self.dialect.insert_executemany_returning: @@ -1711,7 +1713,14 @@ class DefaultExecutionContext(ExecutionContext): # are that the result has only one row, until executemany() # support is added here. assert result._metadata.returns_rows - result._soft_close() + + # Insert statement has both return_defaults() and + # returning(). 
rewind the result on the list of rows + # we just used. + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() elif not self._is_explicit_returning: result._soft_close() @@ -1721,21 +1730,18 @@ class DefaultExecutionContext(ExecutionContext): # function so this is not necessarily true. # assert not result.returns_rows - elif self.isupdate and self._is_implicit_returning: - # get rowcount - # (which requires open cursor on some drivers) - # we were not doing this in 1.4, however - # test_rowcount -> test_update_rowcount_return_defaults - # is testing this, and psycopg will no longer return - # rowcount after cursor is closed. - result.rowcount - self._has_rowcount = True + elif self._is_implicit_returning: + rows = result.all() - row = result.fetchone() - if row is not None: - self.returned_default_rows = [row] + if rows: + self.returned_default_rows = rows + result.rowcount = len(rows) + self._has_rowcount = True - result._soft_close() + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() # test that it has a cursor metadata that is accurate. # the rows have all been fetched however. @@ -1750,7 +1756,6 @@ class DefaultExecutionContext(ExecutionContext): elif self.isupdate or self.isdelete: result.rowcount self._has_rowcount = True - return result @util.memoized_property diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index df5a8199c..05ca17063 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -109,9 +109,27 @@ class ResultMetaData: def _for_freeze(self) -> ResultMetaData: raise NotImplementedError() + @overload def _key_fallback( - self, key: _KeyType, err: Exception, raiseerr: bool = True + self, key: Any, err: Exception, raiseerr: Literal[True] = ... ) -> NoReturn: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[False] = ... + ) -> None: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = ... + ) -> Optional[NoReturn]: + ... + + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = True + ) -> Optional[NoReturn]: assert raiseerr raise KeyError(key) from err @@ -2148,6 +2166,7 @@ class IteratorResult(Result[_TP]): """ _hard_closed = False + _soft_closed = False def __init__( self, @@ -2168,6 +2187,7 @@ class IteratorResult(Result[_TP]): self.raw._soft_close(hard=hard, **kw) self.iterator = iter([]) self._reset_memoizations() + self._soft_closed = True def _raise_hard_closed(self) -> NoReturn: raise exc.ResourceClosedError("This result object is closed.") diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 225292d17..3ed34a57a 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -15,24 +15,32 @@ specifically outside of the flush() process. from __future__ import annotations from typing import Any +from typing import cast from typing import Dict from typing import Iterable +from typing import Optional +from typing import overload from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes +from . import context from . import evaluator from . import exc as orm_exc +from . import loading from . import persistence from .base import NO_VALUE from .context import AbstractORMCompileState +from .context import FromStatement +from .context import ORMFromStatementCompileState +from .context import QueryContext from .. 
import exc as sa_exc -from .. import sql from .. import util from ..engine import Dialect from ..engine import result as _result from ..sql import coercions +from ..sql import dml from ..sql import expression from ..sql import roles from ..sql import select @@ -48,16 +56,24 @@ from ..util.typing import Literal if TYPE_CHECKING: from .mapper import Mapper + from .session import _BindArguments from .session import ORMExecuteState + from .session import Session from .session import SessionTransaction from .state import InstanceState + from ..engine import Connection + from ..engine import cursor + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter _O = TypeVar("_O", bound=object) -_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"] +_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"] +_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"] +@overload def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], @@ -65,7 +81,36 @@ def _bulk_insert( isstates: bool, return_defaults: bool, render_nulls: bool, + use_orm_insert_stmt: Literal[None] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., ) -> None: + ... + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., +) -> cursor.CursorResult[Any]: + ... + + +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, +) -> Optional[cursor.CursorResult[Any]]: base_mapper = mapper.base_mapper if session_transaction.session.connection_callable: @@ -81,13 +126,27 @@ def _bulk_insert( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) + + return_result: Optional[cursor.CursorResult[Any]] = None + for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue + is_joined_inh_supertable = super_mapper is not mapper + bookkeeping = ( + is_joined_inh_supertable + or return_defaults + or ( + use_orm_insert_stmt is not None + and bool(use_orm_insert_stmt._returning) + ) + ) + records = ( ( None, @@ -112,18 +171,25 @@ def _bulk_insert( table, ((None, mapping, mapper, connection) for mapping in mappings), bulk=True, - return_defaults=return_defaults, + return_defaults=bookkeeping, render_nulls=render_nulls, ) ) - persistence._emit_insert_statements( + result = persistence._emit_insert_statements( base_mapper, None, super_mapper, table, records, - bookkeeping=return_defaults, + bookkeeping=bookkeeping, + use_orm_insert_stmt=use_orm_insert_stmt, + execution_options=execution_options, ) + if use_orm_insert_stmt is not None: + if not use_orm_insert_stmt._returning or return_result is None: + return_result = result + elif result.returns_rows: + 
return_result = return_result.splice_horizontally(result) if return_defaults and isstates: identity_cls = mapper._identity_class @@ -134,14 +200,43 @@ def _bulk_insert( tuple([dict_[key] for key in identity_props]), ) + if use_orm_insert_stmt is not None: + assert return_result is not None + return return_result + +@overload def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, update_changed_only: bool, + use_orm_update_stmt: Literal[None] = ..., ) -> None: + ... + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = ..., +) -> _result.Result[Any]: + ... + + +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = None, +) -> Optional[_result.Result[Any]]: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys @@ -161,7 +256,8 @@ def _bulk_update( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) if session_transaction.session.connection_callable: raise NotImplementedError( @@ -172,7 +268,7 @@ def _bulk_update( connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue records = persistence._collect_update_commands( @@ -193,8 +289,8 @@ def _bulk_update( for mapping in mappings ), bulk=True, + use_orm_update_stmt=use_orm_update_stmt, ) - persistence._emit_update_statements( base_mapper, None, @@ -202,10 +298,125 @@ def _bulk_update( table, records, bookkeeping=False, + use_orm_update_stmt=use_orm_update_stmt, ) + if use_orm_update_stmt is not None: + return _result.null_result() + + +def _expand_composites(mapper, mappings): + composite_attrs = mapper.composites + if not composite_attrs: + return + + composite_keys = set(composite_attrs.keys()) + populators = { + key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() + for key in composite_keys + } + for mapping in mappings: + for key in composite_keys.intersection(mapping): + populators[key](mapping) + class ORMDMLState(AbstractORMCompileState): + is_dml_returning = True + from_statement_ctx: Optional[ORMFromStatementCompileState] = None + + @classmethod + def _get_orm_crud_kv_pairs( + cls, mapper, statement, kv_iterator, needs_to_be_cacheable + ): + + core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, str): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + yield ( + coercions.expect(roles.DMLColumnRole, k), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v, + ) + else: + yield from core_get_crud_kv_pairs( + statement, + desc._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( 
+ k_anno["entity_namespace"], k_anno["proxy_key"] + ) + yield from core_get_crud_kv_pairs( + statement, + attr._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + else: + yield ( + k, + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + + @classmethod + def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_multi_crud_kv_pairs( + statement, kv_iterator + ) + + return [ + dict( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, statement, value_dict.items(), False + ) + ) + for value_dict in kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): + assert ( + needs_to_be_cacheable + ), "no test coverage for needs_to_be_cacheable=False" + + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_crud_kv_pairs( + statement, kv_iterator, needs_to_be_cacheable + ) + + return list( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, + statement, + kv_iterator, + needs_to_be_cacheable, + ) + ) + @classmethod def get_entity_description(cls, statement): ext_info = statement.table._annotations["parententity"] @@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState): ] ] + def _setup_orm_returning( + self, + compiler, + orm_level_statement, + dml_level_statement, + use_supplemental_cols=True, + dml_mapper=None, + ): + """establish ORM column handlers for an INSERT, UPDATE, or DELETE + which uses explicit returning(). + + called within compilation level create_for_statement. + + The _return_orm_returning() method then receives the Result + after the statement was executed, and applies ORM loading to the + state that we first established here. 
+ + """ + + if orm_level_statement._returning: + + fs = FromStatement( + orm_level_statement._returning, dml_level_statement + ) + fs = fs.options(*orm_level_statement._with_options) + self.select_statement = fs + self.from_statement_ctx = ( + fsc + ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + fsc.setup_dml_returning_compile_state(dml_mapper) + + dml_level_statement = dml_level_statement._generate() + dml_level_statement._returning = () + + cols_to_return = [c for c in fsc.primary_columns if c is not None] + + # since we are splicing result sets together, make sure there + # are columns of some kind returned in each result set + if not cols_to_return: + cols_to_return.extend(dml_mapper.primary_key) + + if use_supplemental_cols: + dml_level_statement = dml_level_statement.return_defaults( + supplemental_cols=cols_to_return + ) + else: + dml_level_statement = dml_level_statement.returning( + *cols_to_return + ) + + return dml_level_statement + + @classmethod + def _return_orm_returning( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + if compile_state.from_statement_ctx: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state.from_statement_ctx, + compile_state.select_statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + else: + return result + class BulkUDCompileState(ORMDMLState): class default_update_options(Options): - _synchronize_session: _SynchronizeSessionArgument = "evaluate" - _is_delete_using = False - _is_update_from = False - _autoflush = True - _subject_mapper = None + _dml_strategy: _DMLStrategyArgument = "auto" + _synchronize_session: _SynchronizeSessionArgument = "auto" + _can_use_returning: bool = False + _is_delete_using: bool = False + _is_update_from: bool = False + _autoflush: bool = True + _subject_mapper: Optional[Mapper[Any]] = None _resolved_values = EMPTY_DICT - _resolved_keys_as_propnames = EMPTY_DICT - _value_evaluators = EMPTY_DICT - _matched_objects = None + _eval_condition = None _matched_rows = None _refresh_identity_token = None @@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState): execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", - {"synchronize_session", "is_delete_using", "is_update_from"}, + { + "synchronize_session", + "is_delete_using", + "is_update_from", + "dml_strategy", + }, execution_options, statement._execution_options, ) - sync = update_options._synchronize_session - if sync is not None: - if sync not in ("evaluate", "fetch", False): - raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are 'evaluate', 'fetch', False" - ) - bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState): update_options += {"_subject_mapper": plugin_subject.mapper} + if not isinstance(params, list): + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "orm"} + elif update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if update_options._dml_strategy == "auto": + 
update_options += {"_dml_strategy": "bulk"} + elif update_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + sync = update_options._synchronize_session + if sync is not None: + if sync not in ("auto", "evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'auto', 'evaluate', 'fetch', False" + ) + if update_options._dml_strategy == "bulk" and sync == "fetch": + raise sa_exc.InvalidRequestError( + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates (i.e. multiple parameter sets)" + ) + if update_options._autoflush: session._autoflush() + if update_options._dml_strategy == "orm": + + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. statement = statement._annotate( { "synchronize_session": update_options._synchronize_session, "is_delete_using": update_options._is_delete_using, "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, } ) - # this stage of the execution is called before the do_orm_execute event - # hook. meaning for an extension like horizontal sharding, this step - # happens before the extension splits out into multiple backends and - # runs only once. if we do pre_sync_fetch, we execute a SELECT - # statement, which the horizontal sharding extension splits amongst the - # shards and combines the results together. - - if update_options._synchronize_session == "evaluate": - update_options = cls._do_pre_synchronize_evaluate( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "fetch": - update_options = cls._do_pre_synchronize_fetch( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - return ( statement, util.immutabledict(execution_options).union( @@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState): # individual ones we return here. 
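The ``_dml_strategy`` resolution above keys off whether a separate list of parameter dictionaries was passed to ``Session.execute()``. As a rough sketch of the two forms (``User`` again being an assumed mapped class; the strategy names come from the hunk above)::

    from sqlalchemy import update

    # no separate parameter list: "auto" resolves to the "orm" strategy
    session.execute(
        update(User).where(User.id == 5).values(name="patrick")
    )

    # a list of parameter dictionaries, each containing the primary key:
    # "auto" resolves to the "bulk" strategy; note that the "fetch"
    # synchronization strategy is rejected for this form
    session.execute(
        update(User),
        [
            {"id": 1, "name": "spongebob"},
            {"id": 2, "name": "sandy"},
        ],
    )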
update_options = execution_options["_sa_orm_update_options"] - if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, result, update_options) - elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, result, update_options) + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate( + session, statement, result, update_options + ) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch( + session, statement, result, update_options + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_bulk_evaluate( + session, params, result, update_options + ) + return result - return result + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) @classmethod def _adjust_for_extra_criteria(cls, global_attributes, ext_info): @@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState): primary_key_convert = [ lookup[bpk] for bpk in mapper.base_mapper.primary_key ] - return [tuple(row[idx] for idx in primary_key_convert) for row in rows] @classmethod - def _do_pre_synchronize_evaluate( + def _get_matched_objects_on_criteria(cls, update_options, states): + mapper = update_options._subject_mapper + eval_condition = update_options._eval_condition + + raw_data = [ + (state.obj(), state, state.dict) + for state in states + if state.mapper.isa(mapper) and not state.expired + ] + + identity_token = update_options._refresh_identity_token + if identity_token is not None: + raw_data = [ + (obj, state, dict_) + for obj, state, dict_ in raw_data + if state.identity_token == identity_token + ] + + result = [] + for obj, state, dict_ in raw_data: + evaled_condition = eval_condition(obj) + + # caution: don't use "in ()" or == here, _EXPIRE_OBJECT + # evaluates as True for all comparisons + if ( + evaled_condition is True + or evaled_condition is evaluator._EXPIRED_OBJECT + ): + result.append( + ( + obj, + state, + dict_, + evaled_condition is evaluator._EXPIRED_OBJECT, + ) + ) + return result + + @classmethod + def _eval_condition_from_statement(cls, update_options, statement): + mapper = update_options._subject_mapper + target_cls = mapper.class_ + + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + crit = () + if statement._where_criteria: + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria(global_attributes, mapper) + + if crit: + eval_condition = evaluator_compiler.process(*crit) + else: + + def eval_condition(obj): + return True + + return eval_condition + + @classmethod + def _do_pre_synchronize_auto( cls, session, statement, @@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState): bind_arguments, update_options, ): - mapper = update_options._subject_mapper - target_cls = mapper.class_ + """setup auto sync strategy + + + "auto" checks if we can use "evaluate" first, then falls back + to "fetch" + + evaluate is vastly more efficient for the common case + where session is empty, only has a few objects, and the UPDATE + statement can potentially match thousands/millions of rows. 
- value_evaluators = resolved_keys_as_propnames = EMPTY_DICT + OTOH more complex criteria that fails to work with "evaluate" + we would hope usually correlates with fewer net rows. + + """ try: - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - crit = () - if statement._where_criteria: - crit += statement._where_criteria + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) - global_attributes = {} - for opt in statement._with_options: - if opt._is_criteria_option: - opt.get_global_criteria(global_attributes) + except evaluator.UnevaluatableError: + pass + else: + return update_options + { + "_eval_condition": eval_condition, + "_synchronize_session": "evaluate", + } - if global_attributes: - crit += cls._adjust_for_extra_criteria( - global_attributes, mapper - ) + update_options += {"_synchronize_session": "fetch"} + return cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) - if crit: - eval_condition = evaluator_compiler.process(*crit) - else: + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): - def eval_condition(obj): - return True + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( @@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState): "synchronize_session execution option." % err ) from err - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - # TODO: detect when the where clause is a trivial primary key match. 
- matched_objects = [ - state.obj() - for state in session.identity_map.all_states() - if state.mapper.isa(mapper) - and not state.expired - and eval_condition(state.obj()) - and ( - update_options._refresh_identity_token is None - # TODO: coverage for the case where horizontal sharding - # invokes an update() or delete() given an explicit identity - # token up front - or state.identity_token - == update_options._refresh_identity_token - ) - ] return update_options + { - "_matched_objects": matched_objects, - "_value_evaluators": value_evaluators, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_eval_condition": eval_condition, } @classmethod @@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState): def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] for k, v in resolved_values: - if isinstance(k, attributes.QueryableAttribute): - values.append((k.key, v)) - continue - elif hasattr(k, "__clause_element__"): - k = k.__clause_element__() - if mapper and isinstance(k, expression.ColumnElement): try: attr = mapper._columntoproperty[k] @@ -599,7 +992,8 @@ class BulkUDCompileState(ORMDMLState): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k + "Attribute name not found, can't be " + "synchronized back to objects: %r" % k ) return values @@ -622,14 +1016,43 @@ class BulkUDCompileState(ORMDMLState): ) select_stmt._where_criteria = statement._where_criteria + # conditionally run the SELECT statement for pre-fetch, testing the + # "bind" for if we can use RETURNING or not using the do_orm_execute + # event. If RETURNING is available, the do_orm_execute event + # will cancel the SELECT from being actually run. + # + # The way this is organized seems strange, why don't we just + # call can_use_returning() before invoking the statement and get + # answer?, why does this go through the whole execute phase using an + # event? Answer: because we are integrating with extensions such + # as the horizontal sharding extention that "multiplexes" an individual + # statement run through multiple engines, and it uses + # do_orm_execute() to do that. 
+ + can_use_returning = None + def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - if cls.can_use_returning( + nonlocal can_use_returning + + per_bind_result = cls.can_use_returning( bind.dialect, mapper, is_update_from=update_options._is_update_from, is_delete_using=update_options._is_delete_using, - ): + ) + + if can_use_returning is not None: + if can_use_returning != per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't mix multiple " + "backends where some support RETURNING and others " + "don't" + ) + else: + can_use_returning = per_bind_result + + if per_bind_result: return _result.null_result() else: return None @@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState): ) matched_rows = result.fetchall() - value_evaluators = EMPTY_DICT - - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - target_cls = mapper.class_ - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - value_evaluators = {} - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - else: - resolved_keys_as_propnames = EMPTY_DICT - return update_options + { - "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_can_use_returning": can_use_returning, } @CompileState.plugin_for("orm", "insert") -class ORMInsert(ORMDMLState, InsertDMLState): +class BulkORMInsert(ORMDMLState, InsertDMLState): + class default_insert_options(Options): + _dml_strategy: _DMLStrategyArgument = "auto" + _render_nulls: bool = False + _return_defaults: bool = False + _subject_mapper: Optional[Mapper[Any]] = None + + select_statement: Optional[FromStatement] = None + @classmethod def orm_pre_session_exec( cls, @@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState): bind_arguments, is_reentrant_invoke, ): + + ( + insert_options, + execution_options, + ) = BulkORMInsert.default_insert_options.from_execution_options( + "_sa_orm_insert_options", + {"dml_strategy"}, + execution_options, + statement._execution_options, + ) bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState): else: bind_arguments["mapper"] = plugin_subject.mapper + insert_options += {"_subject_mapper": plugin_subject.mapper} + + if not params: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "orm"} + elif insert_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "bulk"} + elif 
insert_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + if insert_options._dml_strategy != "raw": + # for ORM object loading, like ORMContext, we have to disable + # result set adapt_to_context, because we will be generating a + # new statement with specific columns that's cached inside of + # an ORMFromStatementCompileState, which we will re-use for + # each result. + if not execution_options: + execution_options = context._orm_load_exec_options + else: + execution_options = execution_options.union( + context._orm_load_exec_options + ) + + statement = statement._annotate( + {"dml_strategy": insert_options._dml_strategy} + ) + return ( statement, - util.immutabledict(execution_options), + util.immutabledict(execution_options).union( + {"_sa_orm_insert_options": insert_options} + ), ) @classmethod - def orm_setup_cursor_result( + def orm_execute_statement( cls, - session, - statement, - params, - execution_options, - bind_arguments, - result, - ): - return result + session: Session, + statement: dml.Insert, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + insert_options = execution_options.get( + "_sa_orm_insert_options", cls.default_insert_options + ) + + if insert_options._dml_strategy not in ( + "raw", + "bulk", + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM insert strategy " + "are 'raw', 'orm', 'bulk', 'auto" + ) + + result: _result.Result[Any] + + if insert_options._dml_strategy == "raw": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return result + + if insert_options._dml_strategy == "bulk": + mapper = insert_options._subject_mapper + + if ( + statement._post_values_clause is not None + and mapper._multiple_persistence_tables + ): + raise sa_exc.InvalidRequestError( + "bulk INSERT with a 'post values' clause " + "(typically upsert) not supported for multi-table " + f"mapper {mapper}" + ) + + assert mapper is not None + assert session._transaction is not None + result = _bulk_insert( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + return_defaults=insert_options._return_defaults, + render_nulls=insert_options._render_nulls, + use_orm_insert_stmt=statement, + execution_options=execution_options, + ) + elif insert_options._dml_strategy == "orm": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + else: + raise AssertionError() + + if not bool(statement._returning): + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: + + self = cast( + BulkORMInsert, + super().create_for_statement(statement, compiler, **kw), + ) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + if not toplevel: + return self + + mapper = statement._propagate_attrs["plugin_subject"] + dml_strategy = statement._annotations.get("dml_strategy", "raw") + if dml_strategy == "bulk": + self._setup_for_bulk_insert(compiler) + elif dml_strategy == "orm": + self._setup_for_orm_insert(compiler, mapper) + + return self + + @classmethod + def _resolved_keys_as_col_keys(cls, mapper, 
resolved_value_dict): + return { + col.key if col is not None else k: v + for col, k, v in ( + (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() + ) + } + + def _setup_for_orm_insert(self, compiler, mapper): + statement = orm_level_statement = cast(dml.Insert, self.statement) + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=False, + ) + self.statement = statement + + def _setup_for_bulk_insert(self, compiler): + """establish an INSERT statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_insert_statement(). + + """ + statement = orm_level_statement = cast(dml.Insert, self.statement) + an = statement._annotations + + emit_insert_table, emit_insert_mapper = ( + an["_emit_insert_table"], + an["_emit_insert_mapper"], + ) + + statement = statement._clone() + + statement.table = emit_insert_table + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_insert_table + } + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=True, + dml_mapper=emit_insert_mapper, + ) + + self.statement = statement @CompileState.plugin_for("orm", "update") @@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self = cls.__new__(cls) + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if dml_strategy == "bulk": + self._setup_for_bulk_update(statement, compiler) + elif dml_strategy in ("orm", "unspecified"): + self._setup_for_orm_update(statement, compiler) + + return self + + def _setup_for_orm_update(self, statement, compiler, **kw): + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper self.extra_criteria_entities = {} - self._resolved_values = cls._get_resolved_values(mapper, statement) + self._resolved_values = self._get_resolved_values(mapper, statement) extra_criteria_attributes = {} @@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if statement._values: self._resolved_values = dict(self._resolved_values) - new_stmt = sql.Update.__new__(sql.Update) - new_stmt.__dict__.update(statement.__dict__) + new_stmt = statement._clone() new_stmt.table = mapper.local_table # note if the statement has _multi_values, these @@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): elif statement._values: new_stmt._values = self._resolved_values - new_crit = cls._adjust_for_extra_criteria( + new_crit = self._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: @@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): UpdateDMLState.__init__(self, new_stmt, compiler, **kw) - if compiler._annotations.get( + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, mapper, is_multitable=self.is_multitable - ): - if new_stmt._returning: - raise sa_exc.InvalidRequestError( - "Can't use synchronize_session='fetch' " - "with explicit returning()" + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need 
to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable ) - self.statement = self.statement.returning( - *mapper.local_table.primary_key ) - return self + if synchronize_session == "fetch" and can_use_returning: + use_supplemental_cols = True + + # NOTE: we might want to RETURNING the actual columns to be + # synchronized also. however this is complicated and difficult + # to align against the behavior of "evaluate". Additionally, + # in a large number (if not the majority) of cases, we have the + # "evaluate" answer, usually a fixed value, in memory already and + # there's no need to re-fetch the same value + # over and over again. so perhaps if it could be RETURNING just + # the elements that were based on a SQL expression and not + # a constant. For now it doesn't quite seem worth it + new_stmt = new_stmt.return_defaults( + *(list(mapper.local_table.primary_key)) + ) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + def _setup_for_bulk_update(self, statement, compiler, **kw): + """establish an UPDATE statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_update_statement(). + + """ + statement = cast(dml.Update, statement) + an = statement._annotations + + emit_update_table, _ = ( + an["_emit_update_table"], + an["_emit_update_mapper"], + ) + + statement = statement._clone() + statement.table = emit_update_table + + UpdateDMLState.__init__(self, statement, compiler, **kw) + + if self._ordered_values: + raise sa_exc.InvalidRequestError( + "bulk ORM UPDATE does not support ordered_values() for " + "custom UPDATE statements with bulk parameter sets. Use a " + "non-bulk UPDATE statement or use values()." + ) + + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_update_table + } + self.statement = statement + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Update, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy not in ("orm", "auto", "bulk"): + raise sa_exc.ArgumentError( + "Valid strategies for ORM UPDATE strategy " + "are 'orm', 'auto', 'bulk'" + ) + + result: _result.Result[Any] + + if update_options._dml_strategy == "bulk": + if statement._where_criteria: + raise sa_exc.InvalidRequestError( + "WHERE clause with bulk ORM UPDATE not " + "supported right now. 
Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + mapper = update_options._subject_mapper + assert mapper is not None + assert session._transaction is not None + result = _bulk_update( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + update_changed_only=False, + use_orm_update_stmt=statement, + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + else: + return super().orm_execute_statement( + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) @classmethod def can_use_returning( @@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): return True @classmethod - def _get_crud_kv_pairs(cls, statement, kv_iterator): - plugin_subject = statement._propagate_attrs["plugin_subject"] - - core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs - - if not plugin_subject or not plugin_subject.mapper: - return core_get_crud_kv_pairs(statement, kv_iterator) - - mapper = plugin_subject.mapper - - values = [] - - for k, v in kv_iterator: - k = coercions.expect(roles.DMLColumnRole, k) + def _do_post_synchronize_bulk_evaluate( + cls, session, params, result, update_options + ): + if not params: + return - if isinstance(k, str): - desc = _entity_namespace_key(mapper, k, default=NO_VALUE) - if desc is NO_VALUE: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - else: - values.extend( - core_get_crud_kv_pairs( - statement, desc._bulk_update_tuples(v) - ) - ) - elif "entity_namespace" in k._annotations: - k_anno = k._annotations - attr = _entity_namespace_key( - k_anno["entity_namespace"], k_anno["proxy_key"] - ) - values.extend( - core_get_crud_kv_pairs( - statement, attr._bulk_update_tuples(v) - ) - ) - else: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - return values + mapper = update_options._subject_mapper + pk_keys = [prop.key for prop in mapper._identity_key_props] - @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): + identity_map = session.identity_map - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) - for obj in update_options._matched_objects: - - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), + for param in params: + identity_key = mapper.identity_key_from_primary_key( + (param[key] for key in pk_keys), + update_options._refresh_identity_token, ) - - # the evaluated states were gathered across all identity tokens. - # however the post_sync events are called per identity token, - # so filter. 
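For orientation, a minimal sketch of the user-facing call that reaches the ``_bulk_update()`` path above, assuming an illustrative mapped class ``User`` with ``id`` / ``name`` columns and an already-configured ``engine`` (names are assumptions, not part of this patch); the statement carries no WHERE clause and each parameter dictionary includes the primary key::

    from sqlalchemy import update
    from sqlalchemy.orm import Session

    with Session(engine) as session:
        # ORM "bulk UPDATE by primary key": a plain update() construct plus a
        # list of parameter dictionaries selects the "bulk" dml strategy
        session.execute(
            update(User),
            [
                {"id": 1, "name": "spongebob squarepants"},
                {"id": 2, "name": "sandy cheeks"},
            ],
        )
        session.commit()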
- if ( - update_options._refresh_identity_token is not None - and state.identity_token - != update_options._refresh_identity_token - ): + state = identity_map.fast_get_state(identity_key) + if not state: continue + evaluated_keys = set(param).difference(pk_keys) + + dict_ = state.dict # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + dict_[key] = param[key] state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) - to_expire = attrib.intersection(dict_).difference(to_evaluate) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = evaluated_keys.intersection(dict_).difference( + to_evaluate + ) if to_expire: state._expire_attributes(dict_, to_expire) - states.add(state) - session._register_altered(states) + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + ) @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows @@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if identity_key in session.identity_map ] - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) + if not objs: + return - for obj in objs: - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), - ) + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [ + ( + obj, + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + for obj in objs + ], + ) + + @classmethod + def _apply_update_set_values_to_objects( + cls, session, update_options, statement, matched_objects + ): + """apply values to objects derived from an update statement, e.g. 
+ UPDATE..SET <values> + + """ + mapper = update_options._subject_mapper + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + evaluated_keys = list(value_evaluators.keys()) + attrib = set(k for k, v in resolved_keys_as_propnames) + + states = set() + for obj, state, dict_ in matched_objects: to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + # only run eval for attributes that are present. + dict_[key] = value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: state._expire_attributes(dict_, to_expire) @@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper @@ -1002,31 +1743,97 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): if opt._is_criteria_option: opt.get_global_criteria(extra_criteria_attributes) + new_stmt = statement._clone() + new_stmt.table = mapper.local_table + new_crit = cls._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: - statement = statement.where(*new_crit) + new_stmt = new_stmt.where(*new_crit) # do this first as we need to determine if there is # DELETE..FROM - DeleteDMLState.__init__(self, statement, compiler, **kw) + DeleteDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False - if compiler._annotations.get( + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, - mapper, - is_multitable=self.is_multitable, - is_delete_using=compiler._annotations.get( - "is_delete_using", False - ), - ): - self.statement = statement.returning(*statement.table.primary_key) + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ) + ) + + if can_use_returning: + use_supplemental_cols = True + + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + 
) + + self.statement = new_stmt return self @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Delete, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + "Bulk ORM DELETE not supported right now. " + "Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + + if update_options._dml_strategy not in ( + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM DELETE strategy are 'orm', 'auto'" + ) + + return super().orm_execute_statement( + session, statement, params, execution_options, bind_arguments, conn + ) + + @classmethod def can_use_returning( cls, dialect: Dialect, @@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): return True @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): - - session._remove_newly_deleted( - [ - attributes.instance_state(obj) - for obj in update_options._matched_objects - ] + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), ) + to_delete = [] + + for _, state, dict_, is_partially_expired in matched_objects: + if is_partially_expired: + state._expire(dict_, session.identity_map._modified) + else: + to_delete.append(state) + + if to_delete: + session._remove_newly_deleted(to_delete) + @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index dc96f8c3c..f8c7ba714 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from .query import Query from .session import _BindArguments from .session import Session + from ..engine import Result from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..sql._typing import _ColumnsClauseArgument @@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict( class AbstractORMCompileState(CompileState): + is_dml_returning = False + @classmethod def create_for_statement( cls, statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> AbstractORMCompileState: """Create a context for a statement given a :class:`.Compiler`. + This method is always invoked in the context of SQLCompiler.process(). + For a Select object, this would be invoked from SQLCompiler.visit_select(). 
For the special FromStatement object used by Query to indicate "Query.from_statement()", this is called by @@ -233,6 +238,28 @@ class AbstractORMCompileState(CompileState): raise NotImplementedError() @classmethod + def orm_execute_statement( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) -> Result: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod def orm_setup_cursor_result( cls, session, @@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() + if TYPE_CHECKING: + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + ... + def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns if obj not in dedupe: @@ -333,26 +371,6 @@ class ORMCompileState(AbstractORMCompileState): return SelectState._column_naming_convention(label_style) @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: - """Create a context for a statement given a :class:`.Compiler`. - - This method is always invoked in the context of SQLCompiler.process(). - - For a Select object, this would be invoked from - SQLCompiler.visit_select(). For the special FromStatement object used - by Query to indicate "Query.from_statement()", this is called by - FromStatement._compiler_dispatch() that would be called by - SQLCompiler.process(). - - """ - raise NotImplementedError() - - @classmethod def get_column_descriptions(cls, statement): return _column_descriptions(statement) @@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState): ) +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. + + Has a subset of the interface used by + :class:`.ORMAdapter` and is used for :class:`._QueryEntity` + instances to set up their columns as used in RETURNING for a + DML statement. + + """ + + __slots__ = ("mapper", "columns", "__weakref__") + + def __init__(self, target_mapper, immediate_dml_mapper): + if ( + immediate_dml_mapper is not None + and target_mapper.local_table + is not immediate_dml_mapper.local_table + ): + # joined inh, or in theory other kinds of multi-table mappings + self.mapper = immediate_dml_mapper + else: + # single inh, normal mappings, etc. 
+ self.mapper = target_mapper + self.columns = self.columns = util.WeakPopulateDict( + self.adapt_check_present # type: ignore + ) + + def __call__(self, col, as_filter): + for cc in sql_util._find_columns(col): + c2 = self.adapt_check_present(cc) + if c2 is not None: + return col + else: + return None + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is None: + return None + return mapper.local_table.c.corresponding_column(col) + + @sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None @@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: FromStatement requested_statement: Union[SelectBase, TextClause, UpdateBase] - dml_table: _DMLTableElement + dml_table: Optional[_DMLTableElement] = None _has_orm_entities = False multi_row_eager_loaders = False @@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMFromStatementCompileState: if compiler is not None: toplevel = not compiler.stack @@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState): if statement.is_dml: self.dml_table = statement.table + self.is_dml_returning = True self._entities = [] self._polymorphic_adapters = {} @@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState): def _get_current_adapter(self): return None + def setup_dml_returning_compile_state(self, dml_mapper): + """used by BulkORMInsert (and Update / Delete?) to set up a handler + for RETURNING to return ORM objects and expressions + + """ + target_mapper = self.statement._propagate_attrs.get( + "plugin_subject", None + ) + adapter = DMLReturningColFilter(target_mapper, dml_mapper) + for entity in self._entities: + entity.setup_dml_returning_compile_state(self, adapter) + class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): """Core construct that represents a load of ORM objects from various @@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMSelectCompileState: """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) @@ -2312,6 +2386,13 @@ class _QueryEntity: def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + raise NotImplementedError() + def row_processor(self, context, result): raise NotImplementedError() @@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity): return _instance, self._label_name, self._extra_entities - def setup_compile_state(self, compile_state): + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + def setup_compile_state(self, compile_state): adapter = self._get_entity_clauses(compile_state) single_table_crit = 
self.mapper._single_table_criterion @@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity): only_load_props=compile_state.compile_options._only_load_props, polymorphic_discriminator=self._polymorphic_discriminator, ) - compile_state._fallback_from_clauses.append(self.selectable) @@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity): getter, label_name, extra_entities = self._row_processor if self.translate_raw_column: extra_entities += ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, label_name, extra_entities @@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity): if self.translate_raw_column: extra_entities = self._extra_entities + ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, self._label_name, extra_entities else: @@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + return else: column = self.column @@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity): self.entity_zero ) and entity.common_parent(self.entity_zero) + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + self._fetch_column = self.column + column = adapter(self.column, False) + if column is not None: + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + assert compile_state.is_dml_returning + self._fetch_column = self.column + return else: column = self.column diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 52b70b9d4..13d3b70fe 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -19,6 +19,7 @@ import operator import typing from typing import Any from typing import Callable +from typing import Dict from typing import List from typing import NoReturn from typing import Optional @@ -602,6 +603,31 @@ class Composite( def _attribute_keys(self) -> Sequence[str]: return [prop.key for prop in self.props] + def _populate_composite_bulk_save_mappings_fn( + self, + ) -> Callable[[Dict[str, Any]], None]: + + if self._generated_composite_accessor: + get_values = self._generated_composite_accessor + else: + + def get_values(val: Any) -> Tuple[Any]: + return val.__composite_values__() # type: ignore + + attrs = [prop.key for prop in self.props] + + def populate(dest_dict: Dict[str, Any]) -> None: + dest_dict.update( + { + key: val + for key, val in zip( + attrs, get_values(dest_dict.pop(self.key)) + ) + } + ) + + return populate + def get_history( self, state: InstanceState[Any], diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index b3129afdd..5af14cc00 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -9,8 +9,8 @@ from __future__ import annotations -import operator - +from .base import LoaderCallableStatus +from .base import PassiveFlag from .. import exc from .. import inspect from .. 
import util @@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators): return None +class _ExpiredObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return self + + def reverse_operate(self, *arg, **kw): + return self + + _NO_OBJECT = _NoObject() +_EXPIRED_OBJECT = _ExpiredObject() class EvaluatorCompiler: @@ -73,6 +82,24 @@ class EvaluatorCompiler: f"alternate class {parentmapper.class_}" ) key = parentmapper._columntoproperty[clause].key + impl = parentmapper.class_manager[key].impl + + if impl is not None: + + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + state = inspect(obj) + dict_ = state.dict + + value = impl.get( + state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + if value is LoaderCallableStatus.PASSIVE_NO_RESULT: + return _EXPIRED_OBJECT + return value + + return get_corresponding_attr else: key = clause.key if ( @@ -85,15 +112,16 @@ class EvaluatorCompiler: "make use of the actual mapped columns in ORM-evaluated " "UPDATE / DELETE expressions." ) + else: raise UnevaluatableError(f"Cannot evaluate column: {clause}") - get_corresponding_attr = operator.attrgetter(key) - return ( - lambda obj: get_corresponding_attr(obj) - if obj is not None - else _NO_OBJECT - ) + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + return getattr(obj, key, _EXPIRED_OBJECT) + + return get_corresponding_attr def visit_tuple(self, clause): return self.visit_clauselist(clause) @@ -134,7 +162,9 @@ class EvaluatorCompiler: has_null = False for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value: return True has_null = has_null or value is None if has_null: @@ -147,6 +177,9 @@ class EvaluatorCompiler: def evaluate(obj): for sub_evaluate in evaluators: value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + if not value: if value is None or value is _NO_OBJECT: return None @@ -160,7 +193,9 @@ class EvaluatorCompiler: values = [] for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value is None or value is _NO_OBJECT: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None or value is _NO_OBJECT: return None values.append(value) return tuple(values) @@ -183,13 +218,21 @@ class EvaluatorCompiler: def visit_is_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) == eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val == right_val return evaluate def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) != eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val != right_val return evaluate @@ -197,8 +240,11 @@ class EvaluatorCompiler: def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) - if left_val is None or right_val is None: + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif left_val is None or right_val is None: return None + return operator(eval_left(obj), eval_right(obj)) return evaluate @@ -274,7 +320,9 @@ class EvaluatorCompiler: def evaluate(obj): value = eval_inner(obj) - if value is None: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value 
is None: return None return not value diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 63b131a78..4848f73f1 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -68,6 +68,11 @@ class IdentityMap: ) -> Optional[_O]: raise NotImplementedError() + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + def keys(self) -> Iterable[_IdentityKeyType[Any]]: return self._dict.keys() @@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap): self._dict[key] = state state._instance_dict = self._wr + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + return self._dict.get(key) + def get( self, key: _IdentityKeyType[_O], default: Optional[_O] = None ) -> Optional[_O]: diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 7317d48be..64f2542fd 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -29,7 +29,6 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.orm.context import FromStatement from . import attributes from . import exc as orm_exc from . import path_registry @@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED from .base import PassiveFlag +from .context import FromStatement from .util import _none_set from .util import state_str from .. import exc as sa_exc @@ -50,6 +50,7 @@ from ..sql import util as sql_util from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectState +from ..util import EMPTY_DICT if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -764,7 +765,7 @@ def _instance_processor( ) quick_populators = path.get( - context.attributes, "memoized_setups", _none_set + context.attributes, "memoized_setups", EMPTY_DICT ) todo = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c8df51b06..c9cf8f49b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -854,6 +854,7 @@ class Mapper( _memoized_values: Dict[Any, Callable[[], Any]] _inheriting_mappers: util.WeakSequence[Mapper[Any]] _all_tables: Set[Table] + _polymorphic_attr_key: Optional[str] _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] @@ -1653,6 +1654,7 @@ class Mapper( """ setter = False + polymorphic_key: Optional[str] = None if self.polymorphic_on is not None: setter = True @@ -1772,17 +1774,23 @@ class Mapper( self._set_polymorphic_identity = ( mapper._set_polymorphic_identity ) + self._polymorphic_attr_key = ( + mapper._polymorphic_attr_key + ) self._validate_polymorphic_identity = ( mapper._validate_polymorphic_identity ) else: self._set_polymorphic_identity = None + self._polymorphic_attr_key = None return if setter: def _set_polymorphic_identity(state): dict_ = state.dict + # TODO: what happens if polymorphic_on column attribute name + # does not match .key? 
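            # set this mapper's polymorphic_identity onto the instance's
            # discriminator attribute; the discriminator's attribute key is
            # also recorded as _polymorphic_attr_key (below) so that the ORM
            # bulk INSERT path can default the discriminator into parameter
            # dictionaries that omit it (see _collect_insert_commands)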
state.get_impl(polymorphic_key).set( state, dict_, @@ -1790,6 +1798,8 @@ class Mapper( None, ) + self._polymorphic_attr_key = polymorphic_key + def _validate_polymorphic_identity(mapper, state, dict_): if ( polymorphic_key in dict_ @@ -1808,6 +1818,7 @@ class Mapper( _validate_polymorphic_identity ) else: + self._polymorphic_attr_key = None self._set_polymorphic_identity = None _validate_polymorphic_identity = None @@ -3562,6 +3573,10 @@ class Mapper( return util.LRUCache(self._compiled_cache_size) @HasMemoized.memoized_attribute + def _multiple_persistence_tables(self): + return len(self.tables) > 1 + + @HasMemoized.memoized_attribute def _sorted_tables(self): table_to_mapper: Dict[Table, Mapper[Any]] = {} diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index abd528986..dfb61c28a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -31,6 +31,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..engine import cursor as _cursor from ..sql import operators from ..sql.elements import BooleanClauseList from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -398,6 +399,11 @@ def _collect_insert_commands( None ) + if bulk and mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + yield ( state, state_dict, @@ -411,7 +417,11 @@ def _collect_insert_commands( def _collect_update_commands( - uowtransaction, table, states_to_update, bulk=False + uowtransaction, + table, + states_to_update, + bulk=False, + use_orm_update_stmt=None, ): """Identify sets of values to use in UPDATE statements for a list of states. @@ -437,7 +447,11 @@ def _collect_update_commands( pks = mapper._pks_by_table[table] - value_params = {} + if use_orm_update_stmt is not None: + # TODO: ordered values, etc + value_params = use_orm_update_stmt._values + else: + value_params = {} propkey_to_col = mapper._propkey_to_col[table] @@ -697,6 +711,7 @@ def _emit_update_statements( table, update, bookkeeping=True, + use_orm_update_stmt=None, ): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" @@ -708,7 +723,7 @@ def _emit_update_statements( execution_options = {"compiled_cache": base_mapper._compiled_cache} - def update_stmt(): + def update_stmt(existing_stmt=None): clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: @@ -725,10 +740,17 @@ def _emit_update_statements( ) ) - stmt = table.update().where(clauses) + if existing_stmt is not None: + stmt = existing_stmt.where(clauses) + else: + stmt = table.update().where(clauses) return stmt - cached_stmt = base_mapper._memo(("update", table), update_stmt) + if use_orm_update_stmt is not None: + cached_stmt = update_stmt(use_orm_update_stmt) + + else: + cached_stmt = base_mapper._memo(("update", table), update_stmt) for ( (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), @@ -747,6 +769,15 @@ def _emit_update_statements( records = list(records) statement = cached_stmt + + if use_orm_update_stmt is not None: + statement = statement._annotate( + { + "_emit_update_table": table, + "_emit_update_mapper": mapper, + } + ) + return_defaults = False if not has_all_pks: @@ -904,16 +935,35 @@ def _emit_insert_statements( table, insert, bookkeeping=True, + use_orm_insert_stmt=None, + execution_options=None, ): """Emit INSERT statements corresponding to value lists collected by 
_collect_insert_commands().""" - cached_stmt = base_mapper._memo(("insert", table), table.insert) + if use_orm_insert_stmt is not None: + cached_stmt = use_orm_insert_stmt + exec_opt = util.EMPTY_DICT - execution_options = {"compiled_cache": base_mapper._compiled_cache} + # if a user query with RETURNING was passed, we definitely need + # to use RETURNING. + returning_is_required_anyway = bool(use_orm_insert_stmt._returning) + else: + returning_is_required_anyway = False + cached_stmt = base_mapper._memo(("insert", table), table.insert) + exec_opt = {"compiled_cache": base_mapper._compiled_cache} + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + exec_opt, execution_options + ) + else: + execution_options = exec_opt + + return_result = None for ( - (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + (connection, _, hasvalue, has_all_pks, has_all_defaults), records, ) in groupby( insert, @@ -928,17 +978,29 @@ def _emit_insert_statements( statement = cached_stmt + if use_orm_insert_stmt is not None: + statement = statement._annotate( + { + "_emit_insert_table": table, + "_emit_insert_mapper": mapper, + } + ) + if ( - not bookkeeping - or ( - has_all_defaults - or not base_mapper.eager_defaults - or not base_mapper.local_table.implicit_returning - or not connection.dialect.insert_returning + ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not base_mapper.local_table.implicit_returning + or not connection.dialect.insert_returning + ) ) + and not returning_is_required_anyway and has_all_pks and not hasvalue ): + # the "we don't need newly generated values back" section. # here we have all the PKs, all the defaults or we don't want # to fetch them, or the dialect doesn't support RETURNING at all @@ -946,7 +1008,7 @@ def _emit_insert_statements( records = list(records) multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) if bookkeeping: @@ -962,7 +1024,7 @@ def _emit_insert_statements( has_all_defaults, ), last_inserted_params, - ) in zip(records, c.context.compiled_parameters): + ) in zip(records, result.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -970,19 +1032,20 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, - c.returned_defaults - if not c.context.executemany + result.returned_defaults + if not result.context.executemany else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: - # here, we need defaults and/or pk values back. 
+ # here, we need defaults and/or pk values back or we otherwise + # know that we are using RETURNING in any case records = list(records) if ( @@ -991,6 +1054,16 @@ def _emit_insert_statements( and len(records) > 1 ): do_executemany = True + elif returning_is_required_anyway: + if connection.dialect.insert_executemany_returning: + do_executemany = True + else: + raise sa_exc.InvalidRequestError( + f"Can't use explicit RETURNING for bulk INSERT " + f"operation with " + f"{connection.dialect.dialect_description} backend; " + f"executemany is not supported with RETURNING" + ) else: do_executemany = False @@ -998,6 +1071,7 @@ def _emit_insert_statements( statement = statement.return_defaults( *mapper._server_default_cols[table] ) + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) elif do_executemany: @@ -1006,10 +1080,16 @@ def _emit_insert_statements( if do_executemany: multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) + if use_orm_insert_stmt is not None: + if return_result is None: + return_result = result + else: + return_result = return_result.splice_vertically(result) + if bookkeeping: for ( ( @@ -1027,9 +1107,9 @@ def _emit_insert_statements( returned_defaults, ) in zip_longest( records, - c.context.compiled_parameters, - c.inserted_primary_key_rows, - c.returned_defaults_rows or (), + result.context.compiled_parameters, + result.inserted_primary_key_rows, + result.returned_defaults_rows or (), ): if inserted_primary_key is None: # this is a real problem and means that we didn't @@ -1062,7 +1142,7 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, @@ -1071,6 +1151,8 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + assert not returning_is_required_anyway + for ( state, state_dict, @@ -1132,6 +1214,12 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) + if use_orm_insert_stmt is not None: + if return_result is None: + return _cursor.null_dml_result() + else: + return return_result + def _emit_post_update_statements( base_mapper, uowtransaction, mapper, table, update diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6d0f055e4..4d5a98fcf 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2978,7 +2978,7 @@ class Query( ) def delete( - self, synchronize_session: _SynchronizeSessionArgument = "evaluate" + self, synchronize_session: _SynchronizeSessionArgument = "auto" ) -> int: r"""Perform a DELETE with an arbitrary WHERE clause. @@ -3042,7 +3042,7 @@ class Query( def update( self, values: Dict[_DMLColumnArgument, Any], - synchronize_session: _SynchronizeSessionArgument = "evaluate", + synchronize_session: _SynchronizeSessionArgument = "auto", update_args: Optional[Dict[Any, Any]] = None, ) -> int: r"""Perform an UPDATE with an arbitrary WHERE clause. 
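With the legacy ``Query.update()`` / ``Query.delete()`` methods above now defaulting to ``'auto'``, a brief sketch of the equivalent 2.0-style statements, where the strategy can also be selected explicitly via execution options (an illustrative ``User`` mapping and ``engine`` are assumed)::

    from sqlalchemy import delete, update
    from sqlalchemy.orm import Session

    with Session(engine) as session:
        # ORM-enabled UPDATE with an arbitrary WHERE clause; the
        # synchronization strategy may be named explicitly if desired
        session.execute(
            update(User)
            .where(User.name == "sandy")
            .values(name="sandy cheeks")
            .execution_options(synchronize_session="auto")
        )
        # ORM-enabled DELETE, same execution-option mechanism
        session.execute(
            delete(User)
            .where(User.name == "patrick")
            .execution_options(synchronize_session="auto")
        )
        session.commit()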
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a690da0d5..64c013306 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget): statement._propagate_attrs.get("compile_state_plugin", None) == "orm" ): - # note that even without "future" mode, we need compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) if TYPE_CHECKING: - assert isinstance(compile_state_cls, ORMCompileState) + assert isinstance( + compile_state_cls, context.AbstractORMCompileState + ) else: compile_state_cls = None @@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget): statement, params or {}, execution_options=execution_options ) - result: Result[Any] = conn.execute( - statement, params or {}, execution_options=execution_options - ) - if compile_state_cls: - result = compile_state_cls.orm_setup_cursor_result( + result: Result[Any] = compile_state_cls.orm_execute_statement( self, statement, - params, + params or {}, execution_options, bind_arguments, - result, + conn, + ) + else: + result = conn.execute( + statement, params or {}, execution_options=execution_options ) if _scalar_result: @@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 19c6493db..8652591c8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy): fetch = self.columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None happens here only for dml bulk_persistence cases + # when context.DMLReturningColFilter is used + return + memoized_populators[self.parent_property] = fetch def init_class_attribute(self, mapper): @@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader): fetch = columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None is not expected to be the result of any + # adapter implementation here, however there may be theoretical + # usages of returning() with context.DMLReturningColFilter + return + memoized_populators[self.parent_property] = fetch def create_row_processor( diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 86b2952cb..262048bd1 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -552,6 +552,8 @@ def _new_annotation_type( # e.g. BindParameter, add it if present. 
if cls.__dict__.get("inherit_cache", False): anno_cls.inherit_cache = True # type: ignore + elif "inherit_cache" in cls.__dict__: + anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 201324a2a..c7e226fcc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -5166,6 +5166,8 @@ class SQLCompiler(Compiled): delete_stmt, delete_stmt.table, extra_froms ) + crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) + if delete_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( delete_stmt, table_text @@ -5178,13 +5180,14 @@ class SQLCompiler(Compiled): text += table_text - if delete_stmt._returning: - if self.returning_precedes_values: - text += " " + self.returning_clause( - delete_stmt, - delete_stmt._returning, - populate_result_map=toplevel, - ) + if ( + self.implicit_returning or delete_stmt._returning + ) and self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, + self.implicit_returning or delete_stmt._returning, + populate_result_map=toplevel, + ) if extra_froms: extra_from_text = self.delete_extra_from_clause( @@ -5204,10 +5207,12 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - if delete_stmt._returning and not self.returning_precedes_values: + if ( + self.implicit_returning or delete_stmt._returning + ) and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, - delete_stmt._returning, + self.implicit_returning or delete_stmt._returning, populate_result_map=toplevel, ) @@ -5297,7 +5302,6 @@ class StrSQLCompiler(SQLCompiler): self._label_select_column(None, c, True, False, {}) for c in base._select_iterables(returning_cols) ] - return "RETURNING " + ", ".join(columns) def update_from_clause( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index b13377a59..22fffb73a 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -150,6 +150,22 @@ def _get_crud_params( "return_defaults() simultaneously" ) + if compile_state.isdelete: + _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + (), + _getattr_col_key, + _column_as_key, + _col_bind_name, + (), + (), + toplevel, + kw, + ) + return _CrudParams([], []) + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: @@ -466,13 +482,6 @@ def _scan_insert_from_select_cols( kw, ): - ( - need_pks, - implicit_returning, - implicit_return_defaults, - postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) - cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] assert compiler.stack[-1]["selectable"] is stmt @@ -537,6 +546,8 @@ def _scan_cols( postfetch_lastrowid, ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) + assert compile_state.isupdate or compile_state.isinsert + if compile_state._parameter_ordering: parameter_ordering = [ _column_as_key(key) for key in compile_state._parameter_ordering @@ -563,6 +574,13 @@ def _scan_cols( else: autoincrement_col = insert_null_pk_still_autoincrements = None + if stmt._supplemental_returning: + supplemental_returning = set(stmt._supplemental_returning) + else: + supplemental_returning = set() + + compiler_implicit_returning = compiler.implicit_returning + for c 
in cols: # scan through every column in the target table @@ -627,11 +645,13 @@ def _scan_cols( # column has a DDL-level default, and is either not a pk # column or we don't need the pk. if implicit_return_defaults and c in implicit_return_defaults: - compiler.implicit_returning.append(c) + compiler_implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) + elif implicit_return_defaults and c in implicit_return_defaults: - compiler.implicit_returning.append(c) + compiler_implicit_returning.append(c) + elif ( c.primary_key and c is not stmt.table._autoincrement_column @@ -652,6 +672,59 @@ def _scan_cols( kw, ) + # adding supplemental cols to implicit_returning in table + # order so that order is maintained between multiple INSERT + # statements which may have different parameters included, but all + # have the same RETURNING clause + if ( + c in supplemental_returning + and c not in compiler_implicit_returning + ): + compiler_implicit_returning.append(c) + + if supplemental_returning: + # we should have gotten every col into implicit_returning, + # however supplemental returning can also have SQL functions etc. + # in it + remaining_supplemental = supplemental_returning.difference( + compiler_implicit_returning + ) + compiler_implicit_returning.extend( + c + for c in stmt._supplemental_returning + if c in remaining_supplemental + ) + + +def _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + (_, _, implicit_return_defaults, _) = _get_returning_modifiers( + compiler, stmt, compile_state, toplevel + ) + + if not implicit_return_defaults: + return + + if stmt._return_defaults_columns: + compiler.implicit_returning.extend(implicit_return_defaults) + + if stmt._supplemental_returning: + ir_set = set(compiler.implicit_returning) + compiler.implicit_returning.extend( + c for c in stmt._supplemental_returning if c not in ir_set + ) + def _append_param_parameter( compiler, @@ -743,7 +816,7 @@ def _append_param_parameter( elif compiler.dialect.postfetch_lastrowid: compiler.postfetch_lastrowid = True - elif implicit_return_defaults and c in implicit_return_defaults: + elif implicit_return_defaults and (c in implicit_return_defaults): compiler.implicit_returning.append(c) else: @@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): INSERT or UPDATE statement after it's invoked. 
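    This function also accommodates DELETE statements that make use of
    return_defaults(), for dialects that support RETURNING with DELETE
    (``dialect.delete_returning``); see the ``isdelete`` branch below.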
""" + need_pks = ( toplevel and _compile_state_isinsert(compile_state) @@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): ) ) and not stmt._returning + # and (not stmt._returning or stmt._return_defaults) and not compile_state._has_multi_parameters ) @@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): or stmt._return_defaults ) ) - if implicit_returning: postfetch_lastrowid = False if _compile_state_isinsert(compile_state): - implicit_return_defaults = implicit_returning and stmt._return_defaults + should_implicit_return_defaults = ( + implicit_returning and stmt._return_defaults + ) elif compile_state.isupdate: - implicit_return_defaults = ( + should_implicit_return_defaults = ( stmt._return_defaults and compile_state._primary_table.implicit_returning and compile_state._supports_implicit_returning and compiler.dialect.update_returning ) + elif compile_state.isdelete: + should_implicit_return_defaults = ( + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and compiler.dialect.delete_returning + ) else: - # this line is unused, currently we are always - # isinsert or isupdate - implicit_return_defaults = False # pragma: no cover + should_implicit_return_defaults = False # pragma: no cover - if implicit_return_defaults: + if should_implicit_return_defaults: if not stmt._return_defaults_columns: implicit_return_defaults = set(stmt.table.c) else: implicit_return_defaults = set(stmt._return_defaults_columns) + else: + implicit_return_defaults = None return ( need_pks, - implicit_returning, + implicit_returning or should_implicit_return_defaults, implicit_return_defaults, postfetch_lastrowid, ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index a08e38800..5145a4a16 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -165,15 +165,32 @@ class DMLState(CompileState): ... @classmethod + def _get_multi_crud_kv_pairs( + cls, + statement: UpdateBase, + multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]], + ) -> List[Dict[_DMLColumnElement, Any]]: + return [ + { + coercions.expect(roles.DMLColumnRole, k): v + for k, v in mapping.items() + } + for mapping in multi_kv_iterator + ] + + @classmethod def _get_crud_kv_pairs( cls, statement: UpdateBase, kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]], + needs_to_be_cacheable: bool, ) -> List[Tuple[_DMLColumnElement, Any]]: return [ ( coercions.expect(roles.DMLColumnRole, k), - coercions.expect( + v + if not needs_to_be_cacheable + else coercions.expect( roles.ExpressionElementRole, v, type_=NullType(), @@ -269,7 +286,7 @@ class InsertDMLState(DMLState): def _insert_col_keys(self) -> List[str]: # this is also done in crud.py -> _key_getters_for_crud_column return [ - coercions.expect_as_key(roles.DMLColumnRole, col) + coercions.expect(roles.DMLColumnRole, col, as_key=True) for col in self._dict_parameters or () ] @@ -326,7 +343,6 @@ class UpdateDMLState(DMLState): self._extra_froms = ef self.is_multitable = mt = ef - self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -389,6 +405,7 @@ class UpdateBase( _return_defaults_columns: Optional[ Tuple[_ColumnsClauseElement, ...] ] = None + _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None _returning: Tuple[_ColumnsClauseElement, ...] 
= () is_dml = True @@ -435,6 +452,215 @@ class UpdateBase( return self @_generative + def return_defaults( + self: SelfUpdateBase, + *cols: _DMLColumnArgument, + supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None, + ) -> SelfUpdateBase: + """Make use of a :term:`RETURNING` clause for the purpose + of fetching server-side expressions and defaults, for supporting + backends only. + + .. deepalchemy:: + + The :meth:`.UpdateBase.return_defaults` method is used by the ORM + for its internal work in fetching newly generated primary key + and server default values, in particular to provide the underyling + implementation of the :paramref:`_orm.Mapper.eager_defaults` + ORM feature as well as to allow RETURNING support with bulk + ORM inserts. Its behavior is fairly idiosyncratic + and is not really intended for general use. End users should + stick with using :meth:`.UpdateBase.returning` in order to + add RETURNING clauses to their INSERT, UPDATE and DELETE + statements. + + Normally, a single row INSERT statement will automatically populate the + :attr:`.CursorResult.inserted_primary_key` attribute when executed, + which stores the primary key of the row that was just inserted in the + form of a :class:`.Row` object with column names as named tuple keys + (and the :attr:`.Row._mapping` view fully populated as well). The + dialect in use chooses the strategy to use in order to populate this + data; if it was generated using server-side defaults and / or SQL + expressions, dialect-specific approaches such as ``cursor.lastrowid`` + or ``RETURNING`` are typically used to acquire the new primary key + value. + + However, when the statement is modified by calling + :meth:`.UpdateBase.return_defaults` before executing the statement, + additional behaviors take place **only** for backends that support + RETURNING and for :class:`.Table` objects that maintain the + :paramref:`.Table.implicit_returning` parameter at its default value of + ``True``. In these cases, when the :class:`.CursorResult` is returned + from the statement's execution, not only will + :attr:`.CursorResult.inserted_primary_key` be populated as always, the + :attr:`.CursorResult.returned_defaults` attribute will also be + populated with a :class:`.Row` named-tuple representing the full range + of server generated + values from that single row, including values for any columns that + specify :paramref:`_schema.Column.server_default` or which make use of + :paramref:`_schema.Column.default` using a SQL expression. + + When invoking INSERT statements with multiple rows using + :ref:`insertmanyvalues <engine_insertmanyvalues>`, the + :meth:`.UpdateBase.return_defaults` modifier will have the effect of + the :attr:`_engine.CursorResult.inserted_primary_key_rows` and + :attr:`_engine.CursorResult.returned_defaults_rows` attributes being + fully populated with lists of :class:`.Row` objects representing newly + inserted primary key values as well as newly inserted server generated + values for each row inserted. The + :attr:`.CursorResult.inserted_primary_key` and + :attr:`.CursorResult.returned_defaults` attributes will also continue + to be populated with the first row of these two collections. + + If the backend does not support RETURNING or the :class:`.Table` in use + has disabled :paramref:`.Table.implicit_returning`, then no RETURNING + clause is added and no additional data is fetched, however the + INSERT, UPDATE or DELETE statement proceeds normally. 
+ + E.g.:: + + stmt = table.insert().values(data='newdata').return_defaults() + + result = connection.execute(stmt) + + server_created_at = result.returned_defaults['created_at'] + + When used against an UPDATE statement + :meth:`.UpdateBase.return_defaults` instead looks for columns that + include :paramref:`_schema.Column.onupdate` or + :paramref:`_schema.Column.server_onupdate` parameters assigned, when + constructing the columns that will be included in the RETURNING clause + by default if explicit columns were not specified. When used against a + DELETE statement, no columns are included in RETURNING by default, they + instead must be specified explicitly as there are no columns that + normally change values when a DELETE statement proceeds. + + .. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported + for DELETE statements also and has been moved from + :class:`.ValuesBase` to :class:`.UpdateBase`. + + The :meth:`.UpdateBase.return_defaults` method is mutually exclusive + against the :meth:`.UpdateBase.returning` method and errors will be + raised during the SQL compilation process if both are used at the same + time on one statement. The RETURNING clause of the INSERT, UPDATE or + DELETE statement is therefore controlled by only one of these methods + at a time. + + The :meth:`.UpdateBase.return_defaults` method differs from + :meth:`.UpdateBase.returning` in these ways: + + 1. :meth:`.UpdateBase.return_defaults` method causes the + :attr:`.CursorResult.returned_defaults` collection to be populated + with the first row from the RETURNING result. This attribute is not + populated when using :meth:`.UpdateBase.returning`. + + 2. :meth:`.UpdateBase.return_defaults` is compatible with existing + logic used to fetch auto-generated primary key values that are then + populated into the :attr:`.CursorResult.inserted_primary_key` + attribute. By contrast, using :meth:`.UpdateBase.returning` will + have the effect of the :attr:`.CursorResult.inserted_primary_key` + attribute being left unpopulated. + + 3. :meth:`.UpdateBase.return_defaults` can be called against any + backend. Backends that don't support RETURNING will skip the usage + of the feature, rather than raising an exception. The return value + of :attr:`_engine.CursorResult.returned_defaults` will be ``None`` + for backends that don't support RETURNING or for which the target + :class:`.Table` sets :paramref:`.Table.implicit_returning` to + ``False``. + + 4. An INSERT statement invoked with executemany() is supported if the + backend database driver supports the + :ref:`insertmanyvalues <engine_insertmanyvalues>` + feature which is now supported by most SQLAlchemy-included backends. + When executemany is used, the + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors + will return the inserted defaults and primary keys. + + .. versionadded:: 1.4 Added + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors. + In version 2.0, the underlying implementation which fetches and + populates the data for these attributes was generalized to be + supported by most backends, whereas in 1.4 they were only + supported by the ``psycopg2`` driver. + + + :param cols: optional list of column key names or + :class:`_schema.Column` that acts as a filter for those columns that + will be fetched. 
+ :param supplemental_cols: optional list of RETURNING expressions, + in the same form as one would pass to the + :meth:`.UpdateBase.returning` method. When present, the additional + columns will be included in the RETURNING clause, and the + :class:`.CursorResult` object will be "rewound" when returned, so + that methods like :meth:`.CursorResult.all` will return new rows + mostly as though the statement used :meth:`.UpdateBase.returning` + directly. However, unlike when using :meth:`.UpdateBase.returning` + directly, the **order of the columns is undefined**, so can only be + targeted using names or :attr:`.Row._mapping` keys; they cannot + reliably be targeted positionally. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.UpdateBase.returning` + + :attr:`_engine.CursorResult.returned_defaults` + + :attr:`_engine.CursorResult.returned_defaults_rows` + + :attr:`_engine.CursorResult.inserted_primary_key` + + :attr:`_engine.CursorResult.inserted_primary_key_rows` + + """ + + if self._return_defaults: + # note _return_defaults_columns = () means return all columns, + # so if we have been here before, only update collection if there + # are columns in the collection + if self._return_defaults_columns and cols: + self._return_defaults_columns = tuple( + util.OrderedSet(self._return_defaults_columns).union( + coercions.expect(roles.ColumnsClauseRole, c) + for c in cols + ) + ) + else: + # set for all columns + self._return_defaults_columns = () + else: + self._return_defaults_columns = tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) + self._return_defaults = True + + if supplemental_cols: + # uniquifying while also maintaining order (the maintain of order + # is for test suites but also for vertical splicing + supplemental_col_tup = ( + coercions.expect(roles.ColumnsClauseRole, c) + for c in supplemental_cols + ) + + if self._supplemental_returning is None: + self._supplemental_returning = tuple( + util.unique_list(supplemental_col_tup) + ) + else: + self._supplemental_returning = tuple( + util.unique_list( + self._supplemental_returning + + tuple(supplemental_col_tup) + ) + ) + + return self + + @_generative def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any ) -> UpdateBase: @@ -500,7 +726,7 @@ class UpdateBase( .. seealso:: - :meth:`.ValuesBase.return_defaults` - an alternative method tailored + :meth:`.UpdateBase.return_defaults` - an alternative method tailored towards efficient fetching of server-side defaults and triggers for single-row INSERTs or UPDATEs. @@ -703,7 +929,6 @@ class ValuesBase(UpdateBase): _select_names: Optional[List[str]] = None _inline: bool = False - _returning: Tuple[_ColumnsClauseElement, ...] = () def __init__(self, table: _DMLTableArgument): self.table = coercions.expect( @@ -859,7 +1084,15 @@ class ValuesBase(UpdateBase): ) elif isinstance(arg, collections_abc.Sequence): - if arg and isinstance(arg[0], (list, dict, tuple)): + + if arg and isinstance(arg[0], dict): + multi_kv_generator = DMLState.get_plugin_class( + self + )._get_multi_crud_kv_pairs + self._multi_values += (multi_kv_generator(self, arg),) + return self + + if arg and isinstance(arg[0], (list, tuple)): self._multi_values += (arg,) return self @@ -888,173 +1121,13 @@ class ValuesBase(UpdateBase): # and ensures they get the "crud"-style name when rendered. 
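A minimal usage sketch of the new ``supplemental_cols`` parameter described in the hunk above, not part of the changeset itself: the ``some_table`` table, the ``engine`` URL, and the column names are hypothetical, and a backend with UPDATE..RETURNING support is assumed::

    from sqlalchemy import (
        Column,
        DateTime,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
        func,
    )

    metadata = MetaData()

    # hypothetical table with a column that has an onupdate default,
    # which is what return_defaults() targets for UPDATE statements
    some_table = Table(
        "some_table",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", String(50)),
        Column(
            "updated_at",
            DateTime,
            server_default=func.now(),
            onupdate=func.now(),
        ),
    )

    # placeholder URL; any RETURNING-capable backend is assumed
    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    metadata.create_all(engine)

    stmt = (
        some_table.update()
        .where(some_table.c.data == "old value")
        .values(data="new value")
        .return_defaults(supplemental_cols=[some_table.c.data])
    )

    with engine.begin() as conn:
        result = conn.execute(stmt)

        # the onupdate value generated for the first matched row; this is
        # populated by return_defaults() even without supplemental_cols
        updated_at = result.returned_defaults.updated_at

        # because supplemental_cols was passed, the result is "rewound" and
        # may also be iterated like a plain RETURNING result; columns are
        # addressed by name, since RETURNING column order is undefined here
        for row in result.mappings():
            print(row["data"], updated_at)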
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs - coerced_arg = {k: v for k, v in kv_generator(self, arg.items())} + coerced_arg = dict(kv_generator(self, arg.items(), True)) if self._values: self._values = self._values.union(coerced_arg) else: self._values = util.immutabledict(coerced_arg) return self - @_generative - def return_defaults( - self: SelfValuesBase, *cols: _DMLColumnArgument - ) -> SelfValuesBase: - """Make use of a :term:`RETURNING` clause for the purpose - of fetching server-side expressions and defaults, for supporting - backends only. - - .. tip:: - - The :meth:`.ValuesBase.return_defaults` method is used by the ORM - for its internal work in fetching newly generated primary key - and server default values, in particular to provide the underyling - implementation of the :paramref:`_orm.Mapper.eager_defaults` - ORM feature. Its behavior is fairly idiosyncratic - and is not really intended for general use. End users should - stick with using :meth:`.UpdateBase.returning` in order to - add RETURNING clauses to their INSERT, UPDATE and DELETE - statements. - - Normally, a single row INSERT statement will automatically populate the - :attr:`.CursorResult.inserted_primary_key` attribute when executed, - which stores the primary key of the row that was just inserted in the - form of a :class:`.Row` object with column names as named tuple keys - (and the :attr:`.Row._mapping` view fully populated as well). The - dialect in use chooses the strategy to use in order to populate this - data; if it was generated using server-side defaults and / or SQL - expressions, dialect-specific approaches such as ``cursor.lastrowid`` - or ``RETURNING`` are typically used to acquire the new primary key - value. - - However, when the statement is modified by calling - :meth:`.ValuesBase.return_defaults` before executing the statement, - additional behaviors take place **only** for backends that support - RETURNING and for :class:`.Table` objects that maintain the - :paramref:`.Table.implicit_returning` parameter at its default value of - ``True``. In these cases, when the :class:`.CursorResult` is returned - from the statement's execution, not only will - :attr:`.CursorResult.inserted_primary_key` be populated as always, the - :attr:`.CursorResult.returned_defaults` attribute will also be - populated with a :class:`.Row` named-tuple representing the full range - of server generated - values from that single row, including values for any columns that - specify :paramref:`_schema.Column.server_default` or which make use of - :paramref:`_schema.Column.default` using a SQL expression. - - When invoking INSERT statements with multiple rows using - :ref:`insertmanyvalues <engine_insertmanyvalues>`, the - :meth:`.ValuesBase.return_defaults` modifier will have the effect of - the :attr:`_engine.CursorResult.inserted_primary_key_rows` and - :attr:`_engine.CursorResult.returned_defaults_rows` attributes being - fully populated with lists of :class:`.Row` objects representing newly - inserted primary key values as well as newly inserted server generated - values for each row inserted. The - :attr:`.CursorResult.inserted_primary_key` and - :attr:`.CursorResult.returned_defaults` attributes will also continue - to be populated with the first row of these two collections. 
- - If the backend does not support RETURNING or the :class:`.Table` in use - has disabled :paramref:`.Table.implicit_returning`, then no RETURNING - clause is added and no additional data is fetched, however the - INSERT or UPDATE statement proceeds normally. - - - E.g.:: - - stmt = table.insert().values(data='newdata').return_defaults() - - result = connection.execute(stmt) - - server_created_at = result.returned_defaults['created_at'] - - - The :meth:`.ValuesBase.return_defaults` method is mutually exclusive - against the :meth:`.UpdateBase.returning` method and errors will be - raised during the SQL compilation process if both are used at the same - time on one statement. The RETURNING clause of the INSERT or UPDATE - statement is therefore controlled by only one of these methods at a - time. - - The :meth:`.ValuesBase.return_defaults` method differs from - :meth:`.UpdateBase.returning` in these ways: - - 1. :meth:`.ValuesBase.return_defaults` method causes the - :attr:`.CursorResult.returned_defaults` collection to be populated - with the first row from the RETURNING result. This attribute is not - populated when using :meth:`.UpdateBase.returning`. - - 2. :meth:`.ValuesBase.return_defaults` is compatible with existing - logic used to fetch auto-generated primary key values that are then - populated into the :attr:`.CursorResult.inserted_primary_key` - attribute. By contrast, using :meth:`.UpdateBase.returning` will - have the effect of the :attr:`.CursorResult.inserted_primary_key` - attribute being left unpopulated. - - 3. :meth:`.ValuesBase.return_defaults` can be called against any - backend. Backends that don't support RETURNING will skip the usage - of the feature, rather than raising an exception. The return value - of :attr:`_engine.CursorResult.returned_defaults` will be ``None`` - for backends that don't support RETURNING or for which the target - :class:`.Table` sets :paramref:`.Table.implicit_returning` to - ``False``. - - 4. An INSERT statement invoked with executemany() is supported if the - backend database driver supports the - :ref:`insertmanyvalues <engine_insertmanyvalues>` - feature which is now supported by most SQLAlchemy-included backends. - When executemany is used, the - :attr:`_engine.CursorResult.returned_defaults_rows` and - :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors - will return the inserted defaults and primary keys. - - .. versionadded:: 1.4 Added - :attr:`_engine.CursorResult.returned_defaults_rows` and - :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors. - In version 2.0, the underlying implementation which fetches and - populates the data for these attributes was generalized to be - supported by most backends, whereas in 1.4 they were only - supported by the ``psycopg2`` driver. - - - :param cols: optional list of column key names or - :class:`_schema.Column` that acts as a filter for those columns that - will be fetched. - - .. 
seealso:: - - :meth:`.UpdateBase.returning` - - :attr:`_engine.CursorResult.returned_defaults` - - :attr:`_engine.CursorResult.returned_defaults_rows` - - :attr:`_engine.CursorResult.inserted_primary_key` - - :attr:`_engine.CursorResult.inserted_primary_key_rows` - - """ - - if self._return_defaults: - # note _return_defaults_columns = () means return all columns, - # so if we have been here before, only update collection if there - # are columns in the collection - if self._return_defaults_columns and cols: - self._return_defaults_columns = tuple( - set(self._return_defaults_columns).union( - coercions.expect(roles.ColumnsClauseRole, c) - for c in cols - ) - ) - else: - # set for all columns - self._return_defaults_columns = () - else: - self._return_defaults_columns = tuple( - coercions.expect(roles.ColumnsClauseRole, c) for c in cols - ) - self._return_defaults = True - return self - SelfInsert = typing.TypeVar("SelfInsert", bound="Insert") @@ -1459,7 +1532,7 @@ class Update(DMLWhereBase, ValuesBase): ) kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs - self._ordered_values = kv_generator(self, args) + self._ordered_values = kv_generator(self, args, True) return self @_generative diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 4416fe630..b3a71dbff 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -68,7 +68,7 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): def __init__( - self, statement, params=None, dialect="default", enable_returning=False + self, statement, params=None, dialect="default", enable_returning=True ): self.statement = statement self.params = params @@ -90,6 +90,17 @@ class CompiledSQL(SQLMatchRule): dialect.insert_returning = ( dialect.update_returning ) = dialect.delete_returning = True + dialect.use_insertmanyvalues = True + dialect.supports_multivalues_insert = True + dialect.update_returning_multifrom = True + dialect.delete_returning_multifrom = True + # dialect.favor_returning_over_lastrowid = True + # dialect.insert_null_pk_still_autoincrements = True + + # this is calculated but we need it to be True for this + # to look like all the current RETURNING dialects + assert dialect.insert_executemany_returning + return dialect else: return url.URL.create(self.dialect).get_dialect()() diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 20dee5273..ef284babc 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -23,7 +23,6 @@ from .util import adict from .util import drop_all_tables_from_metadata from .. import event from .. 
import util -from ..orm import declarative_base from ..orm import DeclarativeBase from ..orm import MappedAsDataclass from ..orm import registry @@ -117,7 +116,7 @@ class TestBase: metadata=metadata, type_annotation_map={ str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb" + sa.String(50), "mysql", "mariadb", "oracle" ) }, ) @@ -132,7 +131,7 @@ class TestBase: metadata = _md type_annotation_map = { str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb" + sa.String(50), "mysql", "mariadb", "oracle" ) } @@ -780,18 +779,19 @@ class DeclarativeMappedTest(MappedTest): def _with_register_classes(cls, fn): cls_registry = cls.classes - class DeclarativeBasic: + class _DeclBase(DeclarativeBase): __table_cls__ = schema.Table + metadata = cls._tables_metadata + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } - def __init_subclass__(cls) -> None: + def __init_subclass__(cls, **kw) -> None: assert cls_registry is not None cls_registry[cls.__name__] = cls - super().__init_subclass__() - - _DeclBase = declarative_base( - metadata=cls._tables_metadata, - cls=DeclarativeBasic, - ) + super().__init_subclass__(**kw) cls.DeclarativeBasic = _DeclBase diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py index b7d4b7452..8e19a24a8 100644 --- a/lib/sqlalchemy/testing/suite/test_rowcount.py +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -89,8 +89,13 @@ class RowCountTest(fixtures.TablesTest): eq_(r.rowcount, 3) @testing.requires.update_returning - @testing.requires.sane_rowcount_w_returning def test_update_rowcount_return_defaults(self, connection): + """note this test should succeed for all RETURNING backends + as of 2.0. In + Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use + len(rows) when we have implicit returning + + """ employees_table = self.tables.employees department = employees_table.c.department diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index f8348714c..488229abb 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -11,6 +11,7 @@ from __future__ import annotations from itertools import filterfalse from typing import AbstractSet from typing import Any +from typing import Callable from typing import cast from typing import Collection from typing import Dict @@ -481,7 +482,9 @@ class IdentitySet: return "%s(%r)" % (type(self).__name__, list(self._members.values())) -def unique_list(seq, hashfunc=None): +def unique_list( + seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None +) -> List[_T]: seen: Set[Any] = set() seen_add = seen.add if not hashfunc: diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 7cc6a6f79..667f4bfb0 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -465,7 +465,11 @@ class ShardTest: t = get_tokyo(sess2) eq_(t.city, tokyo.city) - def test_bulk_update_synchronize_evaluate(self): + @testing.combinations( + "fetch", "evaluate", "auto", argnames="synchronize_session" + ) + @testing.combinations(True, False, argnames="legacy") + def test_orm_update_synchronize(self, synchronize_session, legacy): sess = self._fixture_data() eq_( @@ -476,33 +480,25 @@ class ShardTest: temps = sess.query(Report).all() eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - sess.query(Report).filter(Report.temperature >= 80).update( - {"temperature": 
Report.temperature + 6}, - synchronize_session="evaluate", - ) - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {86.0, 75.0, 91.0}, - ) - - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - - def test_bulk_update_synchronize_fetch(self): - sess = self._fixture_data() - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {80.0, 75.0, 85.0}, - ) + if legacy: + sess.query(Report).filter(Report.temperature >= 80).update( + {"temperature": Report.temperature + 6}, + synchronize_session=synchronize_session, + ) + else: + sess.execute( + update(Report) + .filter(Report.temperature >= 80) + .values(temperature=Report.temperature + 6) + .execution_options(synchronize_session=synchronize_session) + ) - temps = sess.query(Report).all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + # test synchronize session + def go(): + eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - sess.query(Report).filter(Report.temperature >= 80).update( - {"temperature": Report.temperature + 6}, - synchronize_session="fetch", + self.assert_sql_count( + sess._ShardedSession__binds["north_america"], go, 0 ) eq_( @@ -510,165 +506,41 @@ class ShardTest: {86.0, 75.0, 91.0}, ) - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - - def test_bulk_delete_synchronize_evaluate(self): - sess = self._fixture_data() - - temps = sess.query(Report).all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - sess.query(Report).filter(Report.temperature >= 80).delete( - synchronize_session="evaluate" - ) - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {75.0}, - ) - - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - - def test_bulk_delete_synchronize_fetch(self): + @testing.combinations( + "fetch", "evaluate", "auto", argnames="synchronize_session" + ) + @testing.combinations(True, False, argnames="legacy") + def test_orm_delete_synchronize(self, synchronize_session, legacy): sess = self._fixture_data() temps = sess.query(Report).all() eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - sess.query(Report).filter(Report.temperature >= 80).delete( - synchronize_session="fetch" - ) - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {75.0}, - ) - - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - - def test_bulk_update_future_synchronize_evaluate(self): - sess = self._fixture_data() - - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {80.0, 75.0, 85.0}, - ) - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - sess.execute( - update(Report) - .filter(Report.temperature >= 80) - .values( - {"temperature": Report.temperature + 6}, + if legacy: + sess.query(Report).filter(Report.temperature >= 80).delete( + synchronize_session=synchronize_session ) - .execution_options(synchronize_session="evaluate") - ) - - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {86.0, 75.0, 91.0}, - ) - - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - - def test_bulk_update_future_synchronize_fetch(self): - sess = self._fixture_data() - - eq_( - set( - row.temperature - for row in 
sess.execute(select(Report.temperature)) - ), - {80.0, 75.0, 85.0}, - ) - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - # MARKMARK - # omitting the criteria so that the UPDATE affects three out of - # four shards - sess.execute( - update(Report) - .values( - {"temperature": Report.temperature + 6}, + else: + sess.execute( + delete(Report) + .filter(Report.temperature >= 80) + .execution_options(synchronize_session=synchronize_session) ) - .execution_options(synchronize_session="fetch") - ) - - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {86.0, 81.0, 91.0}, - ) - - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0}) - - def test_bulk_delete_future_synchronize_evaluate(self): - sess = self._fixture_data() - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - sess.execute( - delete(Report) - .filter(Report.temperature >= 80) - .execution_options(synchronize_session="evaluate") - ) - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {75.0}, - ) - - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - - def test_bulk_delete_future_synchronize_fetch(self): - sess = self._fixture_data() - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + def go(): + # test synchronize session + for t in temps: + assert inspect(t).deleted is (t.temperature >= 80) - sess.execute( - delete(Report) - .filter(Report.temperature >= 80) - .execution_options(synchronize_session="fetch") + self.assert_sql_count( + sess._ShardedSession__binds["north_america"], go, 0 ) eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), + set(row.temperature for row in sess.query(Report.temperature)), {75.0}, ) - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - class DistinctEngineShardTest(ShardTest, fixtures.MappedTest): def _init_dbs(self): diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index de5f89b25..0cba8f3a1 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -3,6 +3,7 @@ from decimal import Decimal from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL @@ -1017,15 +1018,43 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): params={"first_name": "Dr."}, ) - def test_update_expr(self): + @testing.combinations("attr", "str", "kwarg", argnames="keytype") + def test_update_expr(self, keytype): Person = self.classes.Person - statement = update(Person).values({Person.name: "Dr. No"}) + if keytype == "attr": + statement = update(Person).values({Person.name: "Dr. No"}) + elif keytype == "str": + statement = update(Person).values({"name": "Dr. No"}) + elif keytype == "kwarg": + statement = update(Person).values(name="Dr. 
No") + else: + assert False self.assert_compile( statement, "UPDATE person SET first_name=:first_name, last_name=:last_name", - params={"first_name": "Dr.", "last_name": "No"}, + checkparams={"first_name": "Dr.", "last_name": "No"}, + ) + + @testing.combinations("attr", "str", "kwarg", argnames="keytype") + def test_insert_expr(self, keytype): + Person = self.classes.Person + + if keytype == "attr": + statement = insert(Person).values({Person.name: "Dr. No"}) + elif keytype == "str": + statement = insert(Person).values({"name": "Dr. No"}) + elif keytype == "kwarg": + statement = insert(Person).values(name="Dr. No") + else: + assert False + + self.assert_compile( + statement, + "INSERT INTO person (first_name, last_name) VALUES " + "(:first_name, :last_name)", + checkparams={"first_name": "Dr.", "last_name": "No"}, ) # these tests all run two UPDATES to assert that caching is not diff --git a/test/orm/dml/__init__.py b/test/orm/dml/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/test/orm/dml/__init__.py diff --git a/test/orm/test_bulk.py b/test/orm/dml/test_bulk.py index 802cdfac5..52db4247f 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/dml/test_bulk.py @@ -1,8 +1,11 @@ from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey +from sqlalchemy import Identity +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import update from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock @@ -20,6 +23,8 @@ class BulkTest(testing.AssertsExecutionResults): class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -73,6 +78,8 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): + __backend__ = True + @classmethod def setup_mappers(cls): User, Address, Order = cls.classes("User", "Address", "Order") @@ -82,22 +89,42 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): cls.mapper_registry.map_imperatively(Address, a) cls.mapper_registry.map_imperatively(Order, o) - def test_bulk_save_return_defaults(self): + @testing.combinations("save_objects", "insert_mappings", "insert_stmt") + def test_bulk_save_return_defaults(self, statement_type): (User,) = self.classes("User") s = fixture_session() - objects = [User(name="u1"), User(name="u2"), User(name="u3")] - assert "id" not in objects[0].__dict__ - with self.sql_execution_asserter() as asserter: - s.bulk_save_objects(objects, return_defaults=True) + if statement_type == "save_objects": + objects = [User(name="u1"), User(name="u2"), User(name="u3")] + assert "id" not in objects[0].__dict__ + + returning_users_id = " RETURNING users.id" + with self.sql_execution_asserter() as asserter: + s.bulk_save_objects(objects, return_defaults=True) + elif statement_type == "insert_mappings": + data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] + returning_users_id = " RETURNING users.id" + with self.sql_execution_asserter() as asserter: + s.bulk_insert_mappings(User, data, return_defaults=True) + elif statement_type == "insert_stmt": + data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] + + # for statement, "return_defaults" is heuristic on if we are + # a joined inh mapping if we don't otherwise include + # .returning() on the statement itself + returning_users_id = "" + with 
self.sql_execution_asserter() as asserter: + s.execute(insert(User), data) asserter.assert_( Conditional( - testing.db.dialect.insert_executemany_returning, + testing.db.dialect.insert_executemany_returning + or statement_type == "insert_stmt", [ CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", + "INSERT INTO users (name) " + f"VALUES (:name){returning_users_id}", [{"name": "u1"}, {"name": "u2"}, {"name": "u3"}], ), ], @@ -117,7 +144,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): ], ) ) - eq_(objects[0].__dict__["id"], 1) + if statement_type == "save_objects": + eq_(objects[0].__dict__["id"], 1) def test_bulk_save_mappings_preserve_order(self): (User,) = self.classes("User") @@ -219,8 +247,9 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): ) ) - def test_bulk_update(self): - (User,) = self.classes("User") + @testing.combinations("update_mappings", "update_stmt") + def test_bulk_update(self, statement_type): + User = self.classes.User s = fixture_session(expire_on_commit=False) objects = [User(name="u1"), User(name="u2"), User(name="u3")] @@ -228,15 +257,18 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): s.commit() s = fixture_session() - with self.sql_execution_asserter() as asserter: - s.bulk_update_mappings( - User, - [ - {"id": 1, "name": "u1new"}, - {"id": 2, "name": "u2"}, - {"id": 3, "name": "u3new"}, - ], - ) + data = [ + {"id": 1, "name": "u1new"}, + {"id": 2, "name": "u2"}, + {"id": 3, "name": "u3new"}, + ] + + if statement_type == "update_mappings": + with self.sql_execution_asserter() as asserter: + s.bulk_update_mappings(User, data) + elif statement_type == "update_stmt": + with self.sql_execution_asserter() as asserter: + s.execute(update(User), data) asserter.assert_( CompiledSQL( @@ -303,6 +335,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -360,6 +394,8 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -547,6 +583,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): class BulkInheritanceTest(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -643,6 +681,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): ) s = fixture_session() + objects = [ Manager(name="m1", status="s1", manager_name="mn1"), Engineer(name="e1", status="s2", primary_language="l1"), @@ -669,7 +708,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): [ CompiledSQL( "INSERT INTO people (name, type) " - "VALUES (:name, :type)", + "VALUES (:name, :type) RETURNING people.person_id", [ {"type": "engineer", "name": "e1"}, {"type": "engineer", "name": "e2"}, @@ -798,59 +837,74 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): ), ) - def test_bulk_insert_joined_inh_return_defaults(self): + @testing.combinations("insert_mappings", "insert_stmt") + def test_bulk_insert_joined_inh_return_defaults(self, statement_type): Person, Engineer, Manager, Boss = self.classes( "Person", "Engineer", "Manager", "Boss" ) s = fixture_session() - with self.sql_execution_asserter() as asserter: - s.bulk_insert_mappings( - Boss, - [ - dict( - name="b1", - status="s1", - manager_name="mn1", - golf_swing="g1", - ), - dict( - name="b2", - status="s2", - manager_name="mn2", - 
golf_swing="g2", - ), - dict( - name="b3", - status="s3", - manager_name="mn3", - golf_swing="g3", - ), - ], - return_defaults=True, - ) + data = [ + dict( + name="b1", + status="s1", + manager_name="mn1", + golf_swing="g1", + ), + dict( + name="b2", + status="s2", + manager_name="mn2", + golf_swing="g2", + ), + dict( + name="b3", + status="s3", + manager_name="mn3", + golf_swing="g3", + ), + ] + + if statement_type == "insert_mappings": + with self.sql_execution_asserter() as asserter: + s.bulk_insert_mappings( + Boss, + data, + return_defaults=True, + ) + elif statement_type == "insert_stmt": + with self.sql_execution_asserter() as asserter: + s.execute(insert(Boss), data) asserter.assert_( Conditional( testing.db.dialect.insert_executemany_returning, [ CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b1"}, {"name": "b2"}, {"name": "b3"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type) RETURNING people.person_id", + [ + {"name": "b1", "type": "boss"}, + {"name": "b2", "type": "boss"}, + {"name": "b3", "type": "boss"}, + ], ), ], [ CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b1"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type)", + [{"name": "b1", "type": "boss"}], ), CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b2"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type)", + [{"name": "b2", "type": "boss"}], ), CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b3"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type)", + [{"name": "b3", "type": "boss"}], ), ], ), @@ -874,15 +928,79 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): ), ) + @testing.combinations("update_mappings", "update_stmt") + def test_bulk_update(self, statement_type): + Person, Engineer, Manager, Boss = self.classes( + "Person", "Engineer", "Manager", "Boss" + ) + + s = fixture_session() + + b1, b2, b3 = ( + Boss(name="b1", status="s1", manager_name="mn1", golf_swing="g1"), + Boss(name="b2", status="s2", manager_name="mn2", golf_swing="g2"), + Boss(name="b3", status="s3", manager_name="mn3", golf_swing="g3"), + ) + s.add_all([b1, b2, b3]) + s.commit() + + # slight non-convenient thing. we have to fill in boss_id here + # for update, this is not sent along automatically. 
this is not a + # new behavior in bulk + new_data = [ + { + "person_id": b1.person_id, + "boss_id": b1.boss_id, + "name": "b1_updated", + "manager_name": "mn1_updated", + }, + { + "person_id": b3.person_id, + "boss_id": b3.boss_id, + "manager_name": "mn2_updated", + "golf_swing": "g1_updated", + }, + ] + + if statement_type == "update_mappings": + with self.sql_execution_asserter() as asserter: + s.bulk_update_mappings(Boss, new_data) + elif statement_type == "update_stmt": + with self.sql_execution_asserter() as asserter: + s.execute(update(Boss), new_data) + + asserter.assert_( + CompiledSQL( + "UPDATE people SET name=:name WHERE " + "people.person_id = :people_person_id", + [{"name": "b1_updated", "people_person_id": 1}], + ), + CompiledSQL( + "UPDATE managers SET manager_name=:manager_name WHERE " + "managers.person_id = :managers_person_id", + [ + {"manager_name": "mn1_updated", "managers_person_id": 1}, + {"manager_name": "mn2_updated", "managers_person_id": 3}, + ], + ), + CompiledSQL( + "UPDATE boss SET golf_swing=:golf_swing WHERE " + "boss.boss_id = :boss_boss_id", + [{"golf_swing": "g1_updated", "boss_boss_id": 3}], + ), + ) + class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest): + __backend__ = True + @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class User(Base): __tablename__ = "users" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) name = Column(String(255), nullable=False) def test_issue_6793(self): @@ -907,7 +1025,8 @@ class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest): [{"name": "A"}, {"name": "B"}], ), CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", + "INSERT INTO users (name) VALUES (:name) " + "RETURNING users.id", [{"name": "C"}, {"name": "D"}], ), ], diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py new file mode 100644 index 000000000..0cca9e6f5 --- /dev/null +++ b/test/orm/dml/test_bulk_statements.py @@ -0,0 +1,1199 @@ +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Optional +import uuid + +from sqlalchemy import exc +from sqlalchemy import ForeignKey +from sqlalchemy import func +from sqlalchemy import Identity +from sqlalchemy import insert +from sqlalchemy import inspect +from sqlalchemy import literal_column +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy import update +from sqlalchemy.orm import aliased +from sqlalchemy.orm import load_only +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.testing import config +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import mock +from sqlalchemy.testing import provision +from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session + + +class NoReturningTest(fixtures.TestBase): + def test_no_returning_error(self, decl_base): + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + if testing.requires.insert_executemany_returning.enabled: + result = s.scalars( + insert(A).returning(A), + [ + {"data": "d3", "x": 5}, + {"data": "d4", "x": 6}, + ], 
+ ) + eq_(result.all(), [A(data="d3", x=5), A(data="d4", x=6)]) + + else: + with expect_raises_message( + exc.InvalidRequestError, + "Can't use explicit RETURNING for bulk INSERT operation", + ): + s.scalars( + insert(A).returning(A), + [ + {"data": "d3", "x": 5}, + {"data": "d4", "x": 6}, + ], + ) + + def test_omit_returning_ok(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + s.execute( + insert(A), + [ + {"data": "d3", "x": 5}, + {"data": "d4", "x": 6}, + ], + ) + eq_( + s.execute(select(A.data, A.x).order_by(A.id)).all(), + [("d3", 5), ("d4", 6)], + ) + + +class BulkDMLReturningInhTest: + def test_insert_col_key_also_works_currently(self): + """using the column key, not mapped attr key. + + right now this passes through to the INSERT. when doing this with + an UPDATE, it tends to fail because the synchronize session + strategies can't match "xcol" back. however w/ INSERT we aren't + doing that, so there's no place this gets checked. UPDATE also + succeeds if synchronize_session is turned off. + + """ + A, B = self.classes("A", "B") + + s = fixture_session() + s.execute(insert(A).values(type="a", data="d", xcol=10)) + eq_(s.scalars(select(A.x)).all(), [10]) + + @testing.combinations(True, False, argnames="use_returning") + def test_heterogeneous_keys(self, use_returning): + A, B = self.classes("A", "B") + + values = [ + {"data": "d3", "x": 5, "type": "a"}, + {"data": "d4", "x": 6, "type": "a"}, + {"data": "d5", "type": "a"}, + {"data": "d6", "x": 8, "y": 9, "type": "a"}, + {"data": "d7", "x": 12, "y": 12, "type": "a"}, + {"data": "d8", "x": 7, "type": "a"}, + ] + + s = fixture_session() + + stmt = insert(A) + if use_returning: + stmt = stmt.returning(A) + + with self.sql_execution_asserter() as asserter: + result = s.execute(stmt, values) + + if inspect(B).single: + single_inh = ", a.bd, a.zcol, a.q" + else: + single_inh = "" + + if use_returning: + asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol) VALUES " + "(:type, :data, :xcol) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [ + {"type": "a", "data": "d3", "xcol": 5}, + {"type": "a", "data": "d4", "xcol": 6}, + ], + ), + CompiledSQL( + "INSERT INTO a (type, data) VALUES (:type, :data) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [{"type": "a", "data": "d5"}], + ), + CompiledSQL( + "INSERT INTO a (type, data, xcol, y) " + "VALUES (:type, :data, :xcol, :y) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [ + {"type": "a", "data": "d6", "xcol": 8, "y": 9}, + {"type": "a", "data": "d7", "xcol": 12, "y": 12}, + ], + ), + CompiledSQL( + "INSERT INTO a (type, data, xcol) " + "VALUES (:type, :data, :xcol) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [{"type": "a", "data": "d8", "xcol": 7}], + ), + ) + else: + asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol) VALUES " + "(:type, :data, :xcol)", + [ + {"type": "a", "data": "d3", "xcol": 5}, + {"type": "a", "data": "d4", "xcol": 6}, + ], + ), + CompiledSQL( + "INSERT INTO a (type, data) VALUES (:type, :data)", + [{"type": "a", "data": "d5"}], + ), + CompiledSQL( + "INSERT INTO a (type, data, xcol, y) " + "VALUES (:type, :data, :xcol, :y)", + [ + {"type": "a", "data": "d6", "xcol": 8, "y": 9}, + {"type": "a", "data": "d7", "xcol": 12, "y": 12}, + ], + ), + 
CompiledSQL( + "INSERT INTO a (type, data, xcol) " + "VALUES (:type, :data, :xcol)", + [{"type": "a", "data": "d8", "xcol": 7}], + ), + ) + + if use_returning: + eq_( + result.scalars().all(), + [ + A(data="d3", id=mock.ANY, type="a", x=5, y=None), + A(data="d4", id=mock.ANY, type="a", x=6, y=None), + A(data="d5", id=mock.ANY, type="a", x=None, y=None), + A(data="d6", id=mock.ANY, type="a", x=8, y=9), + A(data="d7", id=mock.ANY, type="a", x=12, y=12), + A(data="d8", id=mock.ANY, type="a", x=7, y=None), + ], + ) + + @testing.combinations( + "strings", + "cols", + "strings_w_exprs", + "cols_w_exprs", + argnames="paramstyle", + ) + @testing.combinations( + True, + (False, testing.requires.multivalues_inserts), + argnames="single_element", + ) + def test_single_values_returning_fn(self, paramstyle, single_element): + """test using insert().values(). + + these INSERT statements go straight in as a single execute without any + insertmanyreturning or bulk_insert_mappings thing going on. the + advantage here is that SQL expressions can be used in the values also. + Disadvantage is none of the automation for inheritance mappers. + + """ + A, B = self.classes("A", "B") + + if paramstyle == "strings": + values = [ + {"data": "d3", "x": 5, "y": 9, "type": "a"}, + {"data": "d4", "x": 10, "y": 8, "type": "a"}, + ] + elif paramstyle == "cols": + values = [ + {A.data: "d3", A.x: 5, A.y: 9, A.type: "a"}, + {A.data: "d4", A.x: 10, A.y: 8, A.type: "a"}, + ] + elif paramstyle == "strings_w_exprs": + values = [ + {"data": func.lower("D3"), "x": 5, "y": 9, "type": "a"}, + { + "data": "d4", + "x": literal_column("5") + 5, + "y": 8, + "type": "a", + }, + ] + elif paramstyle == "cols_w_exprs": + values = [ + {A.data: func.lower("D3"), A.x: 5, A.y: 9, A.type: "a"}, + { + A.data: "d4", + A.x: literal_column("5") + 5, + A.y: 8, + A.type: "a", + }, + ] + else: + assert False + + s = fixture_session() + + if single_element: + if paramstyle.startswith("strings"): + stmt = ( + insert(A) + .values(**values[0]) + .returning(A, func.upper(A.data, type_=String)) + ) + else: + stmt = ( + insert(A) + .values(values[0]) + .returning(A, func.upper(A.data, type_=String)) + ) + else: + stmt = ( + insert(A) + .values(values) + .returning(A, func.upper(A.data, type_=String)) + ) + + for i in range(3): + result = s.execute(stmt) + expected: List[Any] = [(A(data="d3", x=5, y=9), "D3")] + if not single_element: + expected.append((A(data="d4", x=10, y=8), "D4")) + eq_(result.all(), expected) + + def test_bulk_w_sql_expressions(self): + A, B = self.classes("A", "B") + + data = [ + {"x": 5, "y": 9, "type": "a"}, + { + "x": 10, + "y": 8, + "type": "a", + }, + ] + + s = fixture_session() + + stmt = ( + insert(A) + .values(data=func.lower("DD")) + .returning(A, func.upper(A.data, type_=String)) + ) + + for i in range(3): + result = s.execute(stmt, data) + expected: List[Any] = [ + (A(data="dd", x=5, y=9), "DD"), + (A(data="dd", x=10, y=8), "DD"), + ] + eq_(result.all(), expected) + + def test_bulk_w_sql_expressions_subclass(self): + A, B = self.classes("A", "B") + + data = [ + {"bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + s = fixture_session() + + stmt = ( + insert(B) + .values(data=func.lower("DD")) + .returning(B, func.upper(B.data, type_=String)) + ) + + for i in range(3): + result = s.execute(stmt, data) + expected: List[Any] = [ + (B(bd="bd1", data="dd", q=4, type="b", x=1, y=2, z=3), "DD"), + (B(bd="bd2", data="dd", q=8, type="b", x=5, y=6, z=7), "DD"), + ] + eq_(result.all(), 
expected) + + @testing.combinations(True, False, argnames="use_ordered") + def test_bulk_upd_w_sql_expressions_no_ordered_values(self, use_ordered): + A, B = self.classes("A", "B") + + s = fixture_session() + + stmt = update(B).ordered_values( + ("data", func.lower("DD_UPDATE")), + ("z", literal_column("3 + 12")), + ) + with expect_raises_message( + exc.InvalidRequestError, + r"bulk ORM UPDATE does not support ordered_values\(\) " + r"for custom UPDATE", + ): + s.execute( + stmt, + [ + {"id": 5, "bd": "bd1_updated"}, + {"id": 6, "bd": "bd2_updated"}, + ], + ) + + def test_bulk_upd_w_sql_expressions_subclass(self): + A, B = self.classes("A", "B") + + s = fixture_session() + + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + ids = s.scalars(insert(B).returning(B.id), data).all() + + stmt = update(B).values( + data=func.lower("DD_UPDATE"), z=literal_column("3 + 12") + ) + + result = s.execute( + stmt, + [ + {"id": ids[0], "bd": "bd1_updated"}, + {"id": ids[1], "bd": "bd2_updated"}, + ], + ) + + # this is a nullresult at the moment + assert result is not None + + eq_( + s.scalars(select(B)).all(), + [ + B( + bd="bd1_updated", + data="dd_update", + id=ids[0], + q=4, + type="b", + x=1, + y=2, + z=15, + ), + B( + bd="bd2_updated", + data="dd_update", + id=ids[1], + q=8, + type="b", + x=5, + y=6, + z=15, + ), + ], + ) + + def test_single_returning_fn(self): + A, B = self.classes("A", "B") + + s = fixture_session() + for i in range(3): + result = s.execute( + insert(A).returning(A, func.upper(A.data, type_=String)), + [{"data": "d3"}, {"data": "d4"}], + ) + eq_(result.all(), [(A(data="d3"), "D3"), (A(data="d4"), "D4")]) + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_no_returning(self, single_element): + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + result = s.execute(insert(B), data) + assert result._soft_closed + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_load_only(self, single_element): + """test that load_only() prevents additional attributes from being + populated. + + """ + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + for i in range(3): + # tests both caching and that the data dictionaries aren't + # mutated... 
+ result = s.execute( + insert(B).returning(B).options(load_only(B.data, B.y, B.q)), + data, + ) + objects = result.scalars().all() + for obj in objects: + assert "data" in obj.__dict__ + assert "q" in obj.__dict__ + assert "z" not in obj.__dict__ + assert "x" not in obj.__dict__ + + expected = [ + B(data="d3", bd="bd1", x=1, y=2, z=3, q=4), + ] + if not single_element: + expected.append(B(data="d4", bd="bd2", x=5, y=6, z=7, q=8)) + eq_(objects, expected) + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_load_only_doesnt_fetch_cols(self, single_element): + """test that when using load_only(), the actual INSERT statement + does not include the deferred columns + + """ + A, B = self.classes("A", "B") + + s = fixture_session() + + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + if single_element: + data = data[0] + + with self.sql_execution_asserter() as asserter: + + # tests both caching and that the data dictionaries aren't + # mutated... + + # note that if we don't put B.id here, accessing .id on the + # B object for joined inheritance is triggering a SELECT + # (and not for single inheritance). this seems not great, but is + # likely a different issue + result = s.execute( + insert(B) + .returning(B) + .options(load_only(B.id, B.data, B.y, B.q)), + data, + ) + objects = result.scalars().all() + if single_element: + id0 = objects[0].id + id1 = None + else: + id0, id1 = objects[0].id, objects[1].id + + if inspect(B).single or inspect(B).concrete: + expected_params = [ + { + "type": "b", + "data": "d3", + "xcol": 1, + "y": 2, + "bd": "bd1", + "zcol": 3, + "q": 4, + }, + { + "type": "b", + "data": "d4", + "xcol": 5, + "y": 6, + "bd": "bd2", + "zcol": 7, + "q": 8, + }, + ] + if single_element: + expected_params[1:] = [] + # RETURNING only includes PK, discriminator, then the cols + # we asked for data, y, q. xcol, z, bd are omitted + + if inspect(B).single: + asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol, y, bd, zcol, q) " + "VALUES " + "(:type, :data, :xcol, :y, :bd, :zcol, :q) " + "RETURNING a.id, a.type, a.data, a.y, a.q", + expected_params, + ), + ) + else: + asserter.assert_( + CompiledSQL( + "INSERT INTO b (type, data, xcol, y, bd, zcol, q) " + "VALUES " + "(:type, :data, :xcol, :y, :bd, :zcol, :q) " + "RETURNING b.id, b.type, b.data, b.y, b.q", + expected_params, + ), + ) + else: + a_data = [ + {"type": "b", "data": "d3", "xcol": 1, "y": 2}, + {"type": "b", "data": "d4", "xcol": 5, "y": 6}, + ] + b_data = [ + {"id": id0, "bd": "bd1", "zcol": 3, "q": 4}, + {"id": id1, "bd": "bd2", "zcol": 7, "q": 8}, + ] + if single_element: + a_data[1:] = [] + b_data[1:] = [] + # RETURNING only includes PK, discriminator, then the cols + # we asked for data, y, q. xcol, z, bd are omitted. plus they + # are broken out correctly in the two statements. 
+ asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol, y) VALUES " + "(:type, :data, :xcol, :y) " + "RETURNING a.id, a.type, a.data, a.y", + a_data, + ), + CompiledSQL( + "INSERT INTO b (id, bd, zcol, q) " + "VALUES (:id, :bd, :zcol, :q) " + "RETURNING b.id, b.q", + b_data, + ), + ) + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_returning_bind_expr(self, single_element): + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + # note there's a fix in compiler.py -> + # _deliver_insertmanyvalues_batches + # for this re: the parameter rendering that isn't tested anywhere + # else. two different versions of the bug for both positional + # and non + result = s.execute(insert(B).returning(B.data, B.y, B.q + 5), data) + if single_element: + eq_(result.all(), [("d3", 2, 9)]) + else: + eq_(result.all(), [("d3", 2, 9), ("d4", 6, 13)]) + + def test_subclass_bulk_update(self): + A, B = self.classes("A", "B") + + s = fixture_session() + + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + ids = s.scalars(insert(B).returning(B.id), data).all() + + result = s.execute( + update(B), + [ + {"id": ids[0], "data": "d3_updated", "bd": "bd1_updated"}, + {"id": ids[1], "data": "d4_updated", "bd": "bd2_updated"}, + ], + ) + + # this is a nullresult at the moment + assert result is not None + + eq_( + s.scalars(select(B)).all(), + [ + B( + bd="bd1_updated", + data="d3_updated", + id=ids[0], + q=4, + type="b", + x=1, + y=2, + z=3, + ), + B( + bd="bd2_updated", + data="d4_updated", + id=ids[1], + q=8, + type="b", + x=5, + y=6, + z=7, + ), + ], + ) + + @testing.combinations(True, False, argnames="single_element") + def test_subclass_return_just_subclass_ids(self, single_element): + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + ids = s.scalars(insert(B).returning(B.id), data).all() + actual_ids = s.scalars(select(B.id).order_by(B.data)).all() + + eq_(ids, actual_ids) + + @testing.combinations( + "orm", + "bulk", + argnames="insert_strategy", + ) + @testing.requires.provisioned_upsert + def test_base_class_upsert(self, insert_strategy): + """upsert is really tricky. if you dont have any data updated, + then you dont get the rows back and things dont work so well. + + so we need to be careful how much we document this because this is + still a thorny use case. 
+ + """ + A = self.classes.A + + s = fixture_session() + + initial_data = [ + {"data": "d3", "x": 1, "y": 2, "q": 4}, + {"data": "d4", "x": 5, "y": 6, "q": 8}, + ] + ids = s.scalars(insert(A).returning(A.id), initial_data).all() + + upsert_data = [ + { + "id": ids[0], + "type": "a", + "data": "d3", + "x": 1, + "y": 2, + }, + { + "id": 32, + "type": "a", + "data": "d32", + "x": 19, + "y": 5, + }, + { + "id": ids[1], + "type": "a", + "data": "d4", + "x": 5, + "y": 6, + }, + { + "id": 28, + "type": "a", + "data": "d28", + "x": 9, + "y": 15, + }, + ] + + stmt = provision.upsert( + config, + A, + (A,), + lambda inserted: {"data": inserted.data + " upserted"}, + ) + + if insert_strategy == "orm": + result = s.scalars(stmt.values(upsert_data)) + elif insert_strategy == "bulk": + result = s.scalars(stmt, upsert_data) + else: + assert False + + eq_( + result.all(), + [ + A(data="d3 upserted", id=ids[0], type="a", x=1, y=2), + A(data="d32", id=32, type="a", x=19, y=5), + A(data="d4 upserted", id=ids[1], type="a", x=5, y=6), + A(data="d28", id=28, type="a", x=9, y=15), + ], + ) + + @testing.combinations( + "orm", + "bulk", + argnames="insert_strategy", + ) + @testing.requires.provisioned_upsert + def test_subclass_upsert(self, insert_strategy): + """note this is overridden in the joined version to expect failure""" + + A, B = self.classes("A", "B") + + s = fixture_session() + + idd3 = 1 + idd4 = 2 + id32 = 32 + id28 = 28 + + initial_data = [ + { + "id": idd3, + "data": "d3", + "bd": "bd1", + "x": 1, + "y": 2, + "z": 3, + "q": 4, + }, + { + "id": idd4, + "data": "d4", + "bd": "bd2", + "x": 5, + "y": 6, + "z": 7, + "q": 8, + }, + ] + ids = s.scalars(insert(B).returning(B.id), initial_data).all() + + upsert_data = [ + { + "id": ids[0], + "type": "b", + "data": "d3", + "bd": "bd1_upserted", + "x": 1, + "y": 2, + "z": 33, + "q": 44, + }, + { + "id": id32, + "type": "b", + "data": "d32", + "bd": "bd 32", + "x": 19, + "y": 5, + "z": 20, + "q": 21, + }, + { + "id": ids[1], + "type": "b", + "bd": "bd2_upserted", + "data": "d4", + "x": 5, + "y": 6, + "z": 77, + "q": 88, + }, + { + "id": id28, + "type": "b", + "data": "d28", + "bd": "bd 28", + "x": 9, + "y": 15, + "z": 10, + "q": 11, + }, + ] + + stmt = provision.upsert( + config, + B, + (B,), + lambda inserted: { + "data": inserted.data + " upserted", + "bd": inserted.bd + " upserted", + }, + ) + result = s.scalars(stmt, upsert_data) + eq_( + result.all(), + [ + B( + bd="bd1_upserted upserted", + data="d3 upserted", + id=ids[0], + q=4, + type="b", + x=1, + y=2, + z=3, + ), + B( + bd="bd 32", + data="d32", + id=32, + q=21, + type="b", + x=19, + y=5, + z=20, + ), + B( + bd="bd2_upserted upserted", + data="d4 upserted", + id=ids[1], + q=8, + type="b", + x=5, + y=6, + z=7, + ), + B( + bd="bd 28", + data="d28", + id=28, + q=11, + type="b", + x=9, + y=15, + z=10, + ), + ], + ) + + +class BulkDMLReturningJoinedInhTest( + BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest +): + + __requires__ = ("insert_returning",) + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "a", + "polymorphic_on": "type", + } + + class B(A): + __tablename__ = "b" + id: Mapped[int] = mapped_column( + ForeignKey("a.id"), primary_key=True + ) 
+ bd: Mapped[str] + z: Mapped[Optional[int]] = mapped_column("zcol") + q: Mapped[Optional[int]] + + __mapper_args__ = {"polymorphic_identity": "b"} + + @testing.combinations( + "orm", + "bulk", + argnames="insert_strategy", + ) + @testing.combinations( + True, + False, + argnames="single_param", + ) + @testing.requires.provisioned_upsert + def test_subclass_upsert(self, insert_strategy, single_param): + A, B = self.classes("A", "B") + + s = fixture_session() + + initial_data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + ids = s.scalars(insert(B).returning(B.id), initial_data).all() + + upsert_data = [ + { + "id": ids[0], + "type": "b", + }, + { + "id": 32, + "type": "b", + }, + ] + if single_param: + upsert_data = upsert_data[0] + + stmt = provision.upsert( + config, + B, + (B,), + lambda inserted: { + "bd": inserted.bd + " upserted", + }, + ) + + with expect_raises_message( + exc.InvalidRequestError, + r"bulk INSERT with a 'post values' clause \(typically upsert\) " + r"not supported for multi-table mapper", + ): + s.scalars(stmt, upsert_data) + + +class BulkDMLReturningSingleInhTest( + BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest +): + __requires__ = ("insert_returning",) + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "a", + "polymorphic_on": "type", + } + + class B(A): + bd: Mapped[str] = mapped_column(nullable=True) + z: Mapped[Optional[int]] = mapped_column("zcol") + q: Mapped[Optional[int]] + + __mapper_args__ = {"polymorphic_identity": "b"} + + +class BulkDMLReturningConcreteInhTest( + BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest +): + __requires__ = ("insert_returning",) + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "a", + "polymorphic_on": "type", + } + + class B(A): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + bd: Mapped[str] = mapped_column(nullable=True) + z: Mapped[Optional[int]] = mapped_column("zcol") + q: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "b", + "concrete": True, + "polymorphic_on": "type", + } + + +class CTETest(fixtures.DeclarativeMappedTest): + __requires__ = ("insert_returning", "ctes_on_dml") + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class User(fixtures.ComparableEntity, decl_base): + __tablename__ = "users" + id: Mapped[uuid.UUID] = mapped_column(primary_key=True) + username: Mapped[str] + + @testing.combinations( + ("cte_aliased", True), + ("cte", False), + argnames="wrap_cte_in_aliased", + id_="ia", + ) + @testing.combinations( + ("use_union", True), + ("no_union", False), + argnames="use_a_union", 
+ id_="ia", + ) + @testing.combinations( + "from_statement", "aliased", "direct", argnames="fetch_entity_type" + ) + def test_select_from_insert_cte( + self, wrap_cte_in_aliased, use_a_union, fetch_entity_type + ): + """test the use case from #8544; SELECT that selects from a + CTE INSERT...RETURNING. + + """ + User = self.classes.User + + id_ = uuid.uuid4() + + cte = ( + insert(User) + .values(id=id_, username="some user") + .returning(User) + .cte() + ) + if wrap_cte_in_aliased: + cte = aliased(User, cte) + + if use_a_union: + stmt = select(User).where(User.id == id_).union(select(cte)) + else: + stmt = select(cte) + + if fetch_entity_type == "from_statement": + outer_stmt = select(User).from_statement(stmt) + expect_entity = True + elif fetch_entity_type == "aliased": + outer_stmt = select(aliased(User, stmt.subquery())) + expect_entity = True + elif fetch_entity_type == "direct": + outer_stmt = stmt + expect_entity = not use_a_union and wrap_cte_in_aliased + else: + assert False + + sess = fixture_session() + with self.sql_execution_asserter() as asserter: + + if not expect_entity: + row = sess.execute(outer_stmt).one() + eq_(row, (id_, "some user")) + else: + new_user = sess.scalars(outer_stmt).one() + eq_(new_user, User(id=id_, username="some user")) + + cte_sql = ( + "(INSERT INTO users (id, username) " + "VALUES (:param_1, :param_2) " + "RETURNING users.id, users.username)" + ) + + if fetch_entity_type == "aliased" and not use_a_union: + expected = ( + f"WITH anon_2 AS {cte_sql} " + "SELECT anon_1.id, anon_1.username " + "FROM (SELECT anon_2.id AS id, anon_2.username AS username " + "FROM anon_2) AS anon_1" + ) + elif not use_a_union: + expected = ( + f"WITH anon_1 AS {cte_sql} " + "SELECT anon_1.id, anon_1.username FROM anon_1" + ) + elif fetch_entity_type == "aliased": + expected = ( + f"WITH anon_2 AS {cte_sql} SELECT anon_1.id, anon_1.username " + "FROM (SELECT users.id AS id, users.username AS username " + "FROM users WHERE users.id = :id_1 " + "UNION SELECT anon_2.id AS id, anon_2.username AS username " + "FROM anon_2) AS anon_1" + ) + else: + expected = ( + f"WITH anon_1 AS {cte_sql} " + "SELECT users.id, users.username FROM users " + "WHERE users.id = :id_1 " + "UNION SELECT anon_1.id, anon_1.username FROM anon_1" + ) + + asserter.assert_( + CompiledSQL(expected, [{"param_1": id_, "param_2": "some user"}]) + ) diff --git a/test/orm/test_evaluator.py b/test/orm/dml/test_evaluator.py index ff40cd201..4b903b863 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/dml/test_evaluator.py @@ -324,7 +324,6 @@ class EvaluateTest(fixtures.MappedTest): """test #3162""" User = self.classes.User - with expect_raises_message( evaluator.UnevaluatableError, r"Custom operator '\^\^' can't be evaluated in " diff --git a/test/orm/test_update_delete.py b/test/orm/dml/test_update_delete_where.py index 1e93f88de..836feb659 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/dml/test_update_delete_where.py @@ -1,3 +1,4 @@ +from sqlalchemy import bindparam from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column @@ -7,6 +8,7 @@ from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import insert +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import lambda_stmt from sqlalchemy import MetaData @@ -17,6 +19,7 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import update from sqlalchemy.orm import backref +from sqlalchemy.orm import exc as orm_exc 
from sqlalchemy.orm import joinedload from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -26,6 +29,7 @@ from sqlalchemy.orm import with_loader_criteria from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in @@ -123,6 +127,25 @@ class UpdateDeleteTest(fixtures.MappedTest): }, ) + def test_update_dont_use_col_key(self): + User = self.classes.User + + s = fixture_session() + + # make sure objects are present to synchronize + _ = s.query(User).all() + + with expect_raises_message( + exc.InvalidRequestError, + "Attribute name not found, can't be synchronized back " + "to objects: 'age_int'", + ): + s.execute(update(User).values(age_int=5)) + + stmt = update(User).values(age=5) + s.execute(stmt) + eq_(s.scalars(select(User.age)).all(), [5, 5, 5, 5]) + @testing.combinations("table", "mapper", "both", argnames="bind_type") @testing.combinations( "update", "insert", "delete", argnames="statement_type" @@ -162,7 +185,7 @@ class UpdateDeleteTest(fixtures.MappedTest): assert_raises_message( exc.ArgumentError, "Valid strategies for session synchronization " - "are 'evaluate', 'fetch', False", + "are 'auto', 'evaluate', 'fetch', False", s.query(User).update, {}, synchronize_session="fake", @@ -351,6 +374,12 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_evaluate_dont_refresh_expired_objects( self, expire_jane_age, add_filter_criteria ): + """test #5664. + + approach is revised in SQLAlchemy 2.0 to not pre-emptively + unexpire the involved attributes + + """ User = self.classes.User sess = fixture_session() @@ -379,15 +408,10 @@ class UpdateDeleteTest(fixtures.MappedTest): if add_filter_criteria: if expire_jane_age: asserter.assert_( - # it has to unexpire jane.name, because jane is not fully - # expired and the criteria needs to look at this particular - # key - CompiledSQL( - "SELECT users.age_int AS users_age_int, " - "users.name AS users_name FROM users " - "WHERE users.id = :pk_1", - [{"pk_1": 4}], - ), + # previously, this would unexpire the attribute and + # cause an additional SELECT. The + # 2.0 approach is that if the object has expired attrs + # we just expire the whole thing, avoiding SQL up front CompiledSQL( "UPDATE users " "SET age_int=(users.age_int + :age_int_1) " @@ -397,14 +421,10 @@ class UpdateDeleteTest(fixtures.MappedTest): ) else: asserter.assert_( - # it has to unexpire jane.name, because jane is not fully - # expired and the criteria needs to look at this particular - # key - CompiledSQL( - "SELECT users.name AS users_name FROM users " - "WHERE users.id = :pk_1", - [{"pk_1": 4}], - ), + # previously, this would unexpire the attribute and + # cause an additional SELECT. 
The + # 2.0 approach is that if the object has expired attrs + # we just expire the whole thing, avoiding SQL up front CompiledSQL( "UPDATE users SET " "age_int=(users.age_int + :age_int_1) " @@ -443,9 +463,9 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ] - if expire_jane_age and not add_filter_criteria: + if expire_jane_age: to_assert.append( - # refresh jane + # refresh jane for partial attributes CompiledSQL( "SELECT users.age_int AS users_age_int, " "users.name AS users_name FROM users " @@ -455,6 +475,75 @@ class UpdateDeleteTest(fixtures.MappedTest): ) asserter.assert_(*to_assert) + @testing.combinations(True, False, argnames="is_evaluable") + def test_auto_synchronize(self, is_evaluable): + User = self.classes.User + + sess = fixture_session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + if is_evaluable: + crit = or_(User.name == "jack", User.name == "jane") + else: + crit = case((User.name.in_(["jack", "jane"]), True), else_=False) + + with self.sql_execution_asserter() as asserter: + sess.execute(update(User).where(crit).values(age=User.age + 10)) + + if is_evaluable: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + :age_int_1) " + "WHERE users.name = :name_1 OR users.name = :name_2", + [{"age_int_1": 10, "name_1": "jack", "name_2": "jane"}], + ), + ) + elif testing.db.dialect.update_returning: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + :age_int_1) " + "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) " + "THEN :param_1 ELSE :param_2 END = 1 RETURNING users.id", + [ + { + "age_int_1": 10, + "name_1": ["jack", "jane"], + "param_1": True, + "param_2": False, + } + ], + ), + ) + else: + asserter.assert_( + CompiledSQL( + "SELECT users.id FROM users WHERE CASE WHEN " + "(users.name IN (__[POSTCOMPILE_name_1])) " + "THEN :param_1 ELSE :param_2 END = 1", + [ + { + "name_1": ["jack", "jane"], + "param_1": True, + "param_2": False, + } + ], + ), + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + :age_int_1) " + "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) " + "THEN :param_1 ELSE :param_2 END = 1", + [ + { + "age_int_1": 10, + "name_1": ["jack", "jane"], + "param_1": True, + "param_2": False, + } + ], + ), + ) + def test_fetch_dont_refresh_expired_objects(self): User = self.classes.User @@ -518,17 +607,25 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ) - def test_delete(self): + @testing.combinations(False, None, "auto", "evaluate", "fetch") + def test_delete(self, synchronize_session): User = self.classes.User sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter( + + stmt = delete(User).filter( or_(User.name == "john", User.name == "jill") - ).delete() + ) + if synchronize_session is not None: + stmt = stmt.execution_options( + synchronize_session=synchronize_session + ) + sess.execute(stmt) - assert john not in sess and jill not in sess + if synchronize_session not in (False, None): + assert john not in sess and jill not in sess eq_(sess.query(User).order_by(User.id).all(), [jack, jane]) @@ -629,6 +726,33 @@ class UpdateDeleteTest(fixtures.MappedTest): eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane]) + def test_update_multirow_not_supported(self): + User = self.classes.User + + sess = fixture_session() + + with expect_raises_message( + exc.InvalidRequestError, + "WHERE clause with bulk ORM UPDATE not supported " "right now.", + ): + sess.execute( + 
update(User).where(User.id == bindparam("id")), + [{"id": 1, "age": 27}, {"id": 2, "age": 37}], + ) + + def test_delete_bulk_not_supported(self): + User = self.classes.User + + sess = fixture_session() + + with expect_raises_message( + exc.InvalidRequestError, "Bulk ORM DELETE not supported right now." + ): + sess.execute( + delete(User), + [{"id": 1}, {"id": 2}], + ) + def test_update(self): User, users = self.classes.User, self.tables.users @@ -640,6 +764,7 @@ class UpdateDeleteTest(fixtures.MappedTest): ) eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + eq_( sess.query(User.age).order_by(User.id).all(), list(zip([25, 37, 29, 27])), @@ -974,7 +1099,7 @@ class UpdateDeleteTest(fixtures.MappedTest): ) @testing.requires.update_returning - def test_update_explicit_returning(self): + def test_update_evaluate_w_explicit_returning(self): User = self.classes.User sess = fixture_session() @@ -987,6 +1112,7 @@ class UpdateDeleteTest(fixtures.MappedTest): .filter(User.age > 29) .values({"age": User.age - 10}) .returning(User.id) + .execution_options(synchronize_session="evaluate") ) rows = sess.execute(stmt).all() @@ -1006,24 +1132,41 @@ class UpdateDeleteTest(fixtures.MappedTest): ) @testing.requires.update_returning - def test_no_fetch_w_explicit_returning(self): + @testing.combinations("update", "delete", argnames="crud_type") + def test_fetch_w_explicit_returning(self, crud_type): User = self.classes.User sess = fixture_session() - stmt = ( - update(User) - .filter(User.age > 29) - .values({"age": User.age - 10}) - .execution_options(synchronize_session="fetch") - .returning(User.id) - ) - with expect_raises_message( - exc.InvalidRequestError, - r"Can't use synchronize_session='fetch' " - r"with explicit returning\(\)", - ): - sess.execute(stmt) + if crud_type == "update": + stmt = ( + update(User) + .filter(User.age > 29) + .values({"age": User.age - 10}) + .execution_options(synchronize_session="fetch") + .returning(User, User.name) + ) + expected = [ + (User(age=37), "jack"), + (User(age=27), "jane"), + ] + elif crud_type == "delete": + stmt = ( + delete(User) + .filter(User.age > 29) + .execution_options(synchronize_session="fetch") + .returning(User, User.name) + ) + expected = [ + (User(age=47), "jack"), + (User(age=37), "jane"), + ] + else: + assert False + + result = sess.execute(stmt) + + eq_(result.all(), expected) @testing.combinations(True, False, argnames="implicit_returning") def test_delete_fetch_returning(self, implicit_returning): @@ -1142,7 +1285,8 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([25, 47, 44, 37])), ) - def test_update_changes_resets_dirty(self): + @testing.combinations("orm", "bulk") + def test_update_changes_resets_dirty(self, update_type): User = self.classes.User sess = fixture_session(autoflush=False) @@ -1155,9 +1299,30 @@ class UpdateDeleteTest(fixtures.MappedTest): # autoflush is false. therefore our '50' and '37' are getting # blown away by this operation. 
- sess.query(User).filter(User.age > 29).update( - {"age": User.age - 10}, synchronize_session="evaluate" - ) + if update_type == "orm": + sess.execute( + update(User) + .filter(User.age > 29) + .values({"age": User.age - 10}), + execution_options=dict(synchronize_session="evaluate"), + ) + elif update_type == "bulk": + + data = [ + {"id": john.id, "age": 25}, + {"id": jack.id, "age": 37}, + {"id": jill.id, "age": 29}, + {"id": jane.id, "age": 27}, + ] + + sess.execute( + update(User), + data, + execution_options=dict(synchronize_session="evaluate"), + ) + + else: + assert False for x in (john, jack, jill, jane): assert not sess.is_modified(x) @@ -1171,6 +1336,93 @@ class UpdateDeleteTest(fixtures.MappedTest): assert not sess.is_modified(john) assert not sess.is_modified(jack) + @testing.combinations( + None, False, "evaluate", "fetch", argnames="synchronize_session" + ) + @testing.combinations(True, False, argnames="homogeneous_keys") + def test_bulk_update_synchronize_session( + self, synchronize_session, homogeneous_keys + ): + User = self.classes.User + + sess = fixture_session(expire_on_commit=False) + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + if homogeneous_keys: + data = [ + {"id": john.id, "age": 35}, + {"id": jack.id, "age": 27}, + {"id": jill.id, "age": 30}, + ] + else: + data = [ + {"id": john.id, "age": 35}, + {"id": jack.id, "name": "new jack"}, + {"id": jill.id, "age": 30, "name": "new jill"}, + ] + + with self.sql_execution_asserter() as asserter: + if synchronize_session is not None: + opts = {"synchronize_session": synchronize_session} + else: + opts = {} + + if synchronize_session == "fetch": + with expect_raises_message( + exc.InvalidRequestError, + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates", + ): + sess.execute(update(User), data, execution_options=opts) + return + else: + sess.execute(update(User), data, execution_options=opts) + + if homogeneous_keys: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=:age_int " + "WHERE users.id = :users_id", + [ + {"age_int": 35, "users_id": 1}, + {"age_int": 27, "users_id": 2}, + {"age_int": 30, "users_id": 3}, + ], + ) + ) + else: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=:age_int " + "WHERE users.id = :users_id", + [{"age_int": 35, "users_id": 1}], + ), + CompiledSQL( + "UPDATE users SET name=:name WHERE users.id = :users_id", + [{"name": "new jack", "users_id": 2}], + ), + CompiledSQL( + "UPDATE users SET name=:name, age_int=:age_int " + "WHERE users.id = :users_id", + [{"name": "new jill", "age_int": 30, "users_id": 3}], + ), + ) + + if synchronize_session is False: + eq_(jill.name, "jill") + eq_(jack.name, "jack") + eq_(jill.age, 29) + eq_(jack.age, 47) + else: + if not homogeneous_keys: + eq_(jill.name, "new jill") + eq_(jack.name, "new jack") + eq_(jack.age, 47) + else: + eq_(jack.age, 27) + eq_(jill.age, 30) + def test_update_changes_with_autoflush(self): User = self.classes.User @@ -1214,7 +1466,8 @@ class UpdateDeleteTest(fixtures.MappedTest): ) @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) - def test_update_returns_rowcount(self): + @testing.combinations("auto", "fetch", "evaluate") + def test_update_returns_rowcount(self, synchronize_session): User = self.classes.User sess = fixture_session() @@ -1222,20 +1475,25 @@ class UpdateDeleteTest(fixtures.MappedTest): rowcount = ( sess.query(User) .filter(User.age > 29) - .update({"age": User.age + 0}) + .update( + {"age": User.age + 0}, 
synchronize_session=synchronize_session + ) ) eq_(rowcount, 2) rowcount = ( sess.query(User) .filter(User.age > 29) - .update({"age": User.age - 10}) + .update( + {"age": User.age - 10}, synchronize_session=synchronize_session + ) ) eq_(rowcount, 2) # test future result = sess.execute( - update(User).where(User.age > 19).values({"age": User.age - 10}) + update(User).where(User.age > 19).values({"age": User.age - 10}), + execution_options={"synchronize_session": synchronize_session}, ) eq_(result.rowcount, 4) @@ -1327,12 +1585,17 @@ class UpdateDeleteTest(fixtures.MappedTest): ) assert john not in sess - def test_evaluate_before_update(self): + @testing.combinations(True, False) + def test_evaluate_before_update(self, full_expiration): User = self.classes.User sess = fixture_session() john = sess.query(User).filter_by(name="john").one() - sess.expire(john, ["age"]) + + if full_expiration: + sess.expire(john) + else: + sess.expire(john, ["age"]) # eval must be before the update. otherwise # we eval john, age has been expired and doesn't @@ -1356,17 +1619,47 @@ class UpdateDeleteTest(fixtures.MappedTest): eq_(john.name, "j2") eq_(john.age, 40) - def test_evaluate_before_delete(self): + @testing.combinations(True, False) + def test_evaluate_before_delete(self, full_expiration): User = self.classes.User sess = fixture_session() john = sess.query(User).filter_by(name="john").one() - sess.expire(john, ["age"]) + jill = sess.query(User).filter_by(name="jill").one() + jane = sess.query(User).filter_by(name="jane").one() - sess.query(User).filter_by(name="john").filter_by(age=25).delete( + if full_expiration: + sess.expire(jill) + sess.expire(john) + else: + sess.expire(jill, ["age"]) + sess.expire(john, ["age"]) + + sess.query(User).filter(or_(User.age == 25, User.age == 37)).delete( synchronize_session="evaluate" ) - assert john not in sess + + # was fully deleted + assert jane not in sess + + # deleted object was expired, but not otherwise affected + assert jill in sess + + # deleted object was expired, but not otherwise affected + assert john in sess + + # partially expired row fully expired + assert inspect(jill).expired + + # non-deleted row still present + eq_(jill.age, 29) + + # partially expired row fully expired + assert inspect(john).expired + + # is deleted + with expect_raises(orm_exc.ObjectDeletedError): + john.name def test_fetch_before_delete(self): User = self.classes.User @@ -1378,6 +1671,7 @@ class UpdateDeleteTest(fixtures.MappedTest): sess.query(User).filter_by(name="john").filter_by(age=25).delete( synchronize_session="fetch" ) + assert john not in sess def test_update_unordered_dict(self): @@ -1495,6 +1789,60 @@ class UpdateDeleteTest(fixtures.MappedTest): ] eq_(["name", "age_int"], cols) + @testing.requires.sqlite + def test_sharding_extension_returning_mismatch(self, testing_engine): + """test one horizontal shard case where the given binds don't match + for RETURNING support; we dont support this. 
+ + See test/ext/test_horizontal_shard.py for complete round trip + test cases for ORM update/delete + + """ + e1 = testing_engine("sqlite://") + e2 = testing_engine("sqlite://") + e1.connect().close() + e2.connect().close() + + e1.dialect.update_returning = True + e2.dialect.update_returning = False + + engines = [e1, e2] + + # a simulated version of the horizontal sharding extension + def execute_and_instances(orm_context): + execution_options = dict(orm_context.local_execution_options) + partial = [] + for engine in engines: + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["bind"] = engine + result_ = orm_context.invoke_statement( + bind_arguments=bind_arguments, + execution_options=execution_options, + ) + + partial.append(result_) + return partial[0].merge(*partial[1:]) + + User = self.classes.User + session = Session() + + event.listen( + session, "do_orm_execute", execute_and_instances, retval=True + ) + + stmt = ( + update(User) + .filter(User.id == 15) + .values(age=123) + .execution_options(synchronize_session="fetch") + ) + with expect_raises_message( + exc.InvalidRequestError, + "For synchronize_session='fetch', can't mix multiple backends " + "where some support RETURNING and others don't", + ): + session.execute(stmt) + class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): @classmethod @@ -1748,6 +2096,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): "Could not evaluate current criteria in Python.", q.update, {"samename": "ed"}, + synchronize_session="evaluate", ) @testing.requires.multi_table_update @@ -1901,7 +2250,7 @@ class ExpressionUpdateTest(fixtures.MappedTest): sess.commit() eq_(d1.cnt, 0) - sess.query(Data).update({Data.cnt: Data.cnt + 1}) + sess.query(Data).update({Data.cnt: Data.cnt + 1}, "evaluate") sess.flush() eq_(d1.cnt, 1) @@ -2443,7 +2792,8 @@ class LoadFromReturningTest(fixtures.MappedTest): ) @testing.requires.update_returning - def test_load_from_update(self, connection): + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_update(self, connection, use_from_statement): User = self.classes.User stmt = ( @@ -2453,7 +2803,16 @@ class LoadFromReturningTest(fixtures.MappedTest): .returning(User) ) - stmt = select(User).from_statement(stmt) + if use_from_statement: + # this is now a legacy-ish case, because as of 2.0 you can just + # use returning() directly to get the objects back. + # + # when from_statement is used, the UPDATE statement is no + # longer interpreted by + # BulkUDCompileState.orm_pre_session_exec or + # BulkUDCompileState.orm_setup_cursor_result. 
The compilation + # level routines still take place though + stmt = select(User).from_statement(stmt) with Session(connection) as sess: rows = sess.execute(stmt).scalars().all() @@ -2468,7 +2827,8 @@ class LoadFromReturningTest(fixtures.MappedTest): ("multiple", testing.requires.multivalues_inserts), argnames="params", ) - def test_load_from_insert(self, connection, params): + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_insert(self, connection, params, use_from_statement): User = self.classes.User if params == "multiple": @@ -2484,7 +2844,8 @@ class LoadFromReturningTest(fixtures.MappedTest): stmt = insert(User).values(values).returning(User) - stmt = select(User).from_statement(stmt) + if use_from_statement: + stmt = select(User).from_statement(stmt) with Session(connection) as sess: rows = sess.execute(stmt).scalars().all() @@ -2505,3 +2866,25 @@ class LoadFromReturningTest(fixtures.MappedTest): ) else: assert False + + @testing.requires.delete_returning + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_delete(self, connection, use_from_statement): + User = self.classes.User + + stmt = ( + delete(User).where(User.name.in_(["jack", "jill"])).returning(User) + ) + + if use_from_statement: + stmt = select(User).from_statement(stmt) + + with Session(connection) as sess: + rows = sess.execute(stmt).scalars().all() + + eq_( + rows, + [User(name="jack", age=47), User(name="jill", age=29)], + ) + + # TODO: state of above objects should be "deleted" diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 2e3874549..5f8cfc1f5 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -2012,7 +2012,8 @@ class JoinedNoFKSortingTest(fixtures.MappedTest): and testing.db.dialect.supports_default_metavalue, [ CompiledSQL( - "INSERT INTO a (id) VALUES (DEFAULT)", [{}, {}, {}, {}] + "INSERT INTO a (id) VALUES (DEFAULT) RETURNING a.id", + [{}, {}, {}, {}], ), ], [ diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index a6480365d..2f392cf6e 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -326,6 +326,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): ), ( lambda User: update(User) + .execution_options(synchronize_session=False) .values(name="not ed") .where(User.name == "ed"), lambda User: {"clause": mock.ANY, "mapper": inspect(User)}, @@ -392,7 +393,15 @@ class BindIntegrationTest(_fixtures.FixtureTest): engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name] with mock.patch( - "sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result" + "sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result" + ), mock.patch( + "sqlalchemy.orm.context.ORMCompileState.orm_execute_statement" + ), mock.patch( + "sqlalchemy.orm.bulk_persistence." + "BulkORMInsert.orm_execute_statement" + ), mock.patch( + "sqlalchemy.orm.bulk_persistence." 
+ "BulkUDCompileState.orm_setup_cursor_result" ): sess.execute(statement) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 3a789aff7..efa2ecb45 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -1,8 +1,10 @@ import dataclasses import operator +import random import sqlalchemy as sa from sqlalchemy import ForeignKey +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String @@ -233,7 +235,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): is g.edges[1] ) - def test_bulk_update_sql(self): + def test_update_crit_sql(self): Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -258,7 +260,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): dialect="default", ) - def test_bulk_update_evaluate(self): + def test_update_crit_evaluate(self): Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -287,7 +289,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(e1.end, Point(17, 8)) - def test_bulk_update_fetch(self): + def test_update_crit_fetch(self): Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -305,6 +307,205 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(e1.end, Point(17, 8)) + @testing.combinations( + "legacy", "statement", "values", "stmt_returning", "values_returning" + ) + def test_bulk_insert(self, type_): + Edge, Point = (self.classes.Edge, self.classes.Point) + Graph = self.classes.Graph + + sess = self._fixture() + + graph = Graph(id=2) + sess.add(graph) + sess.flush() + graph_id = 2 + + data = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "end": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(25) + ] + returning = False + if type_ == "statement": + sess.execute(insert(Edge), data) + elif type_ == "stmt_returning": + result = sess.scalars(insert(Edge).returning(Edge), data) + returning = True + elif type_ == "values": + sess.execute(insert(Edge).values(data)) + elif type_ == "values_returning": + result = sess.scalars(insert(Edge).values(data).returning(Edge)) + returning = True + elif type_ == "legacy": + sess.bulk_insert_mappings(Edge, data) + else: + assert False + + if returning: + eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data]) + + edges = self.tables.edges + eq_( + sess.execute( + select(edges.c["x1", "y1", "x2", "y2"]) + .where(edges.c.graph_id == graph_id) + .order_by(edges.c.id) + ).all(), + [ + (e["start"].x, e["start"].y, e["end"].x, e["end"].y) + for e in data + ], + ) + + @testing.combinations("legacy", "statement") + def test_bulk_insert_heterogeneous(self, type_): + Edge, Point = (self.classes.Edge, self.classes.Point) + Graph = self.classes.Graph + + sess = self._fixture() + + graph = Graph(id=2) + sess.add(graph) + sess.flush() + graph_id = 2 + + d1 = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "end": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(3) + ] + d2 = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(2) + ] + d3 = [ + { + "x2": random.randint(1, 50), + "y2": random.randint(1, 50), + "graph_id": graph_id, + } + for i in range(2) + ] + data = d1 + d2 + d3 + random.shuffle(data) + + assert_data = [ + { + "start": d["start"] if "start" in d 
else None, + "end": d["end"] + if "end" in d + else Point(d["x2"], d["y2"]) + if "x2" in d + else None, + "graph_id": d["graph_id"], + } + for d in data + ] + + if type_ == "statement": + sess.execute(insert(Edge), data) + elif type_ == "legacy": + sess.bulk_insert_mappings(Edge, data) + else: + assert False + + edges = self.tables.edges + eq_( + sess.execute( + select(edges.c["x1", "y1", "x2", "y2"]) + .where(edges.c.graph_id == graph_id) + .order_by(edges.c.id) + ).all(), + [ + ( + e["start"].x if e["start"] else None, + e["start"].y if e["start"] else None, + e["end"].x if e["end"] else None, + e["end"].y if e["end"] else None, + ) + for e in assert_data + ], + ) + + @testing.combinations("legacy", "statement") + def test_bulk_update(self, type_): + Edge, Point = (self.classes.Edge, self.classes.Point) + Graph = self.classes.Graph + + sess = self._fixture() + + graph = Graph(id=2) + sess.add(graph) + sess.flush() + graph_id = 2 + + data = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "end": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(25) + ] + sess.execute(insert(Edge), data) + + inserted_data = [ + dict(row._mapping) + for row in sess.execute( + select(Edge.id, Edge.start, Edge.end, Edge.graph_id) + .where(Edge.graph_id == graph_id) + .order_by(Edge.id) + ) + ] + + to_update = [] + updated_pks = {} + for rec in random.choices(inserted_data, k=7): + rec_copy = dict(rec) + updated_pks[rec_copy["id"]] = rec_copy + rec_copy["start"] = Point( + random.randint(1, 50), random.randint(1, 50) + ) + rec_copy["end"] = Point( + random.randint(1, 50), random.randint(1, 50) + ) + to_update.append(rec_copy) + + expected_dataset = [ + updated_pks[row["id"]] if row["id"] in updated_pks else row + for row in inserted_data + ] + + if type_ == "statement": + sess.execute(update(Edge), to_update) + elif type_ == "legacy": + sess.bulk_update_mappings(Edge, to_update) + else: + assert False + + edges = self.tables.edges + eq_( + sess.execute( + select(edges.c["x1", "y1", "x2", "y2"]) + .where(edges.c.graph_id == graph_id) + .order_by(edges.c.id) + ).all(), + [ + (e["start"].x, e["start"].y, e["end"].x, e["end"].y) + for e in expected_dataset + ], + ) + def test_get_history(self): Edge = self.classes.Edge Point = self.classes.Point diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index 15155293f..4d4a7ff64 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -1125,7 +1125,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): [ CompiledSQL( "INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", + "VALUES (:person_id, :data) RETURNING ball.id", [ {"person_id": None, "data": "some data"}, {"person_id": None, "data": "some data"}, diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index 7860f5eb1..e738689b8 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -383,20 +383,24 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): CompiledSQL( "UPDATE test SET foo=:foo WHERE test.id = :test_id", [{"foo": 5, "test_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test SET foo=:foo WHERE test.id = :test_id", [{"foo": 6, "test_id": 2}], + enable_returning=False, ), CompiledSQL( "SELECT test.bar AS test_bar FROM test " "WHERE test.id = :pk_1", [{"pk_1": 1}], + enable_returning=False, ), CompiledSQL( "SELECT test.bar AS test_bar FROM test " "WHERE test.id = :pk_1", [{"pk_1": 2}], + enable_returning=False, ), ) else: diff --git 
a/test/orm/test_events.py b/test/orm/test_events.py index 24870e20f..75955afb5 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -661,8 +661,17 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): canary = self._flag_fixture(sess) - sess.execute(delete(User).filter_by(id=18)) - sess.execute(update(User).filter_by(id=18).values(name="eighteen")) + sess.execute( + delete(User) + .filter_by(id=18) + .execution_options(synchronize_session="evaluate") + ) + sess.execute( + update(User) + .filter_by(id=18) + .values(name="eighteen") + .execution_options(synchronize_session="evaluate") + ) eq_( canary.mock_calls, diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index b94998716..fc452dc9c 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -2868,12 +2868,14 @@ class SaveTest2(_fixtures.FixtureTest): testing.db.dialect.insert_executemany_returning, [ CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", + "INSERT INTO users (name) VALUES (:name) " + "RETURNING users.id", [{"name": "u1"}, {"name": "u2"}], ), CompiledSQL( "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", [ {"user_id": 1, "email_address": "a1"}, {"user_id": 2, "email_address": "a2"}, diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index dd3b88915..855b44e81 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -98,7 +98,8 @@ class RudimentaryFlushTest(UOWTest): [ CompiledSQL( "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", lambda ctx: [ {"email_address": "a1", "user_id": u1.id}, {"email_address": "a2", "user_id": u1.id}, @@ -220,7 +221,8 @@ class RudimentaryFlushTest(UOWTest): [ CompiledSQL( "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", lambda ctx: [ {"email_address": "a1", "user_id": u1.id}, {"email_address": "a2", "user_id": u1.id}, @@ -889,7 +891,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n1.id, "data": "n2"}, {"parent_id": n1.id, "data": "n3"}, @@ -1003,7 +1005,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n1.id, "data": "n2"}, {"parent_id": n1.id, "data": "n3"}, @@ -1165,7 +1167,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n1.id, "data": "n11"}, {"parent_id": n1.id, "data": "n12"}, @@ -1196,7 +1198,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n12.id, "data": "n121"}, {"parent_id": n12.id, "data": "n122"}, @@ -2099,7 +2101,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): testing.db.dialect.insert_executemany_returning, [ CompiledSQL( - "INSERT INTO t (data) VALUES (:data)", + "INSERT INTO t (data) VALUES (:data) RETURNING t.id", [{"data": "t1"}, 
{"data": "t2"}], ), ], @@ -2472,20 +2474,24 @@ class EagerDefaultsTest(fixtures.MappedTest): CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)", [{"id": 1}], + enable_returning=False, ), CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)", [{"id": 2}], + enable_returning=False, ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :pk_1", [{"pk_1": 1}], + enable_returning=False, ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :pk_1", [{"pk_1": 2}], + enable_returning=False, ), ) @@ -2678,20 +2684,24 @@ class EagerDefaultsTest(fixtures.MappedTest): CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", [{"foo": 5, "test2_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", [{"foo": 6, "bar": 10, "test2_id": 2}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", [{"foo": 7, "test2_id": 3}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", [{"foo": 8, "bar": 12, "test2_id": 4}], + enable_returning=False, ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " @@ -2772,31 +2782,37 @@ class EagerDefaultsTest(fixtures.MappedTest): "UPDATE test4 SET foo=:foo, bar=5 + 3 " "WHERE test4.id = :test4_id", [{"foo": 5, "test4_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test4 SET foo=:foo, bar=:bar " "WHERE test4.id = :test4_id", [{"foo": 6, "bar": 10, "test4_id": 2}], + enable_returning=False, ), CompiledSQL( "UPDATE test4 SET foo=:foo, bar=5 + 3 " "WHERE test4.id = :test4_id", [{"foo": 7, "test4_id": 3}], + enable_returning=False, ), CompiledSQL( "UPDATE test4 SET foo=:foo, bar=:bar " "WHERE test4.id = :test4_id", [{"foo": 8, "bar": 12, "test4_id": 4}], + enable_returning=False, ), CompiledSQL( "SELECT test4.bar AS test4_bar FROM test4 " "WHERE test4.id = :pk_1", [{"pk_1": 1}], + enable_returning=False, ), CompiledSQL( "SELECT test4.bar AS test4_bar FROM test4 " "WHERE test4.id = :pk_1", [{"pk_1": 3}], + enable_returning=False, ), ], ), @@ -2871,20 +2887,24 @@ class EagerDefaultsTest(fixtures.MappedTest): "UPDATE test2 SET foo=:foo, bar=1 + 1 " "WHERE test2.id = :test2_id", [{"foo": 5, "test2_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", [{"foo": 6, "bar": 10, "test2_id": 2}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", [{"foo": 7, "test2_id": 3}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=5 + 7 " "WHERE test2.id = :test2_id", [{"foo": 8, "test2_id": 4}], + enable_returning=False, ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index abd5833be..84e5a83b0 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -1424,12 +1424,10 @@ class ServerVersioningTest(fixtures.MappedTest): sess.add(f1) statements = [ - # note that the assertsql tests the rule against - # "default" - on a "returning" backend, the statement - # includes "RETURNING" CompiledSQL( "INSERT INTO version_table (version_id, value) " - "VALUES (1, :value)", + "VALUES (1, :value) " + "RETURNING version_table.id, version_table.version_id", lambda ctx: [{"value": "f1"}], ) ] @@ -1493,6 +1491,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f2", } ], + enable_returning=False, ), CompiledSQL( 
"SELECT version_table.version_id " @@ -1618,6 +1617,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f1a", } ], + enable_returning=False, ), CompiledSQL( "UPDATE version_table SET version_id=2, value=:value " @@ -1630,6 +1630,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f2a", } ], + enable_returning=False, ), CompiledSQL( "UPDATE version_table SET version_id=2, value=:value " @@ -1642,6 +1643,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f3a", } ], + enable_returning=False, ), CompiledSQL( "SELECT version_table.version_id " diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 42cf31bf5..4f776e300 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -100,10 +100,55 @@ class CursorResultTest(fixtures.TablesTest): Table( "test", metadata, - Column("x", Integer, primary_key=True), + Column( + "x", Integer, primary_key=True, test_needs_autoincrement=False + ), Column("y", String(50)), ) + @testing.requires.insert_returning + def test_splice_horizontally(self, connection): + users = self.tables.users + addresses = self.tables.addresses + + r1 = connection.execute( + users.insert().returning(users.c.user_name, users.c.user_id), + [ + dict(user_id=1, user_name="john"), + dict(user_id=2, user_name="jack"), + ], + ) + + r2 = connection.execute( + addresses.insert().returning( + addresses.c.address_id, + addresses.c.address, + addresses.c.user_id, + ), + [ + dict(address_id=1, user_id=1, address="foo@bar.com"), + dict(address_id=2, user_id=2, address="bar@bat.com"), + ], + ) + + rows = r1.splice_horizontally(r2).all() + eq_( + rows, + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ], + ) + + eq_(rows[0]._mapping[users.c.user_id], 1) + eq_(rows[0]._mapping[addresses.c.user_id], 1) + eq_(rows[1].address, "bar@bat.com") + + with expect_raises_message( + exc.InvalidRequestError, "Ambiguous column name 'user_id'" + ): + rows[0].user_id + def test_keys_no_rows(self, connection): for i in range(2): diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index f8cc32517..c26f825c2 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -23,6 +23,7 @@ from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing import mock from sqlalchemy.testing import provision from sqlalchemy.testing.schema import Column @@ -76,6 +77,7 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): stmt = stmt.returning(t.c.x) stmt = stmt.return_defaults() + assert_raises_message( sa_exc.CompileError, r"Can't compile statement that includes returning\(\) " @@ -330,6 +332,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): table = self.tables.returning_tbl exprs = testing.resolve_lambda(testcase, table=table) + result = connection.execute( table.insert().returning(*exprs), {"persons": 5, "full": False, "strval": "str1"}, @@ -679,6 +682,30 @@ class InsertReturnDefaultsTest(fixtures.TablesTest): Column("upddef", Integer, onupdate=IncDefault()), ) + Table( + "table_no_addtl_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + class MyType(TypeDecorator): + impl = String(50) + + def process_result_value(self, value, dialect): + return f"PROCESSED! 
{value}" + + Table( + "table_datatype_has_result_proc", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", MyType()), + ) + def test_chained_insert_pk(self, connection): t1 = self.tables.t1 result = connection.execute( @@ -758,6 +785,38 @@ class InsertReturnDefaultsTest(fixtures.TablesTest): ) eq_(result.inserted_primary_key, (1,)) + def test_insert_w_defaults_supplemental_cols(self, connection): + t1 = self.tables.t1 + result = connection.execute( + t1.insert().return_defaults(supplemental_cols=[t1.c.id]), + {"data": "d1"}, + ) + eq_(result.all(), [(1, 0, None)]) + + def test_insert_w_no_defaults_supplemental_cols(self, connection): + t1 = self.tables.table_no_addtl_defaults + result = connection.execute( + t1.insert().return_defaults(supplemental_cols=[t1.c.id]), + {"data": "d1"}, + ) + eq_(result.all(), [(1,)]) + + def test_insert_w_defaults_supplemental_processor_cols(self, connection): + """test that the cursor._rewind() used by supplemental RETURNING + clears out result-row processors as we will have already processed + the rows. + + """ + + t1 = self.tables.table_datatype_has_result_proc + result = connection.execute( + t1.insert().return_defaults( + supplemental_cols=[t1.c.id, t1.c.data] + ), + {"data": "d1"}, + ) + eq_(result.all(), [(1, "PROCESSED! d1")]) + class UpdatedReturnDefaultsTest(fixtures.TablesTest): __requires__ = ("update_returning",) @@ -792,6 +851,7 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( t1.update().values(upddef=2).return_defaults(t1.c.data) ) @@ -800,6 +860,72 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest): [None], ) + def test_update_values_col_is_excluded(self, connection): + """columns that are in values() are not returned""" + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + + result = connection.execute( + t1.update().values(data="x", upddef=2).return_defaults(t1.c.data) + ) + is_(result.returned_defaults, None) + + result = connection.execute( + t1.update() + .values(data="x", upddef=2) + .return_defaults(t1.c.data, t1.c.id) + ) + eq_(result.returned_defaults, (1,)) + + def test_update_supplemental_cols(self, connection): + """with supplemental_cols, we can get back arbitrary cols.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.update() + .values(data="x", insdef=3) + .return_defaults(supplemental_cols=[t1.c.data, t1.c.insdef]) + ) + + row = result.returned_defaults + + # row has all the cols in it + eq_(row, ("x", 3, 1)) + eq_(row._mapping[t1.c.upddef], 1) + eq_(row._mapping[t1.c.insdef], 3) + + # result is rewound + # but has both return_defaults + supplemental_cols + eq_(result.all(), [("x", 3, 1)]) + + def test_update_expl_return_defaults_plus_supplemental_cols( + self, connection + ): + """with supplemental_cols, we can get back arbitrary cols.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.update() + .values(data="x", insdef=3) + .return_defaults( + t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef] + ) + ) + + row = result.returned_defaults + + # row has all the cols in it + eq_(row, (1, "x", 3)) + eq_(row._mapping[t1.c.id], 1) + eq_(row._mapping[t1.c.insdef], 3) + assert t1.c.upddef not in row._mapping + + # result is rewound + # but has both return_defaults + supplemental_cols + eq_(result.all(), [(1, "x", 3)]) 
+ def test_update_sql_expr(self, connection): from sqlalchemy import literal @@ -833,6 +959,75 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest): eq_(dict(result.returned_defaults._mapping), {"upddef": 1}) +class DeleteReturnDefaultsTest(fixtures.TablesTest): + __requires__ = ("delete_returning",) + run_define_tables = "each" + __backend__ = True + + define_tables = InsertReturnDefaultsTest.define_tables + + def test_delete(self, connection): + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute(t1.delete().return_defaults(t1.c.upddef)) + eq_( + [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] + ) + + def test_delete_empty_return_defaults(self, connection): + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=5)) + result = connection.execute(t1.delete().return_defaults()) + + # there's no "delete" default, so we get None. we have to + # ask for them in all cases + eq_(result.returned_defaults, None) + + def test_delete_non_default(self, connection): + """test that a column not marked at all as a + default works with this feature.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute(t1.delete().return_defaults(t1.c.data)) + eq_( + [result.returned_defaults._mapping[k] for k in (t1.c.data,)], + [None], + ) + + def test_delete_non_default_plus_default(self, connection): + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.delete().return_defaults(t1.c.data, t1.c.upddef) + ) + eq_( + dict(result.returned_defaults._mapping), + {"data": None, "upddef": 1}, + ) + + def test_delete_supplemental_cols(self, connection): + """with supplemental_cols, we can get back arbitrary cols.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.delete().return_defaults( + t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef] + ) + ) + + row = result.returned_defaults + + # row has all the cols in it + eq_(row, (1, None, 0)) + eq_(row._mapping[t1.c.insdef], 0) + + # result is rewound + # but has both return_defaults + supplemental_cols + eq_(result.all(), [(1, None, 0)]) + + class InsertManyReturnDefaultsTest(fixtures.TablesTest): __requires__ = ("insert_executemany_returning",) run_define_tables = "each" diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 64ff2e421..5ef927b15 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -44,6 +44,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import visitors +from sqlalchemy.sql.dml import Insert from sqlalchemy.sql.selectable import LABEL_STYLE_NONE from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -3029,6 +3030,26 @@ class AnnotationsTest(fixtures.TestBase): eq_(whereclause.left._annotations, {"foo": "bar"}) eq_(whereclause.right._annotations, {"foo": "bar"}) + @testing.combinations(True, False, None) + def test_setup_inherit_cache(self, inherit_cache_value): + if inherit_cache_value is None: + + class MyInsertThing(Insert): + pass + + else: + + class MyInsertThing(Insert): + inherit_cache = inherit_cache_value + + t = table("t", column("x")) + anno = MyInsertThing(t)._annotate({"foo": "bar"}) + + if inherit_cache_value is not None: + is_(type(anno).__dict__["inherit_cache"], inherit_cache_value) + else: + assert 
"inherit_cache" not in type(anno).__dict__ + def test_proxy_set_iteration_includes_annotated(self): from sqlalchemy.schema import Column |

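The CTETest added above targets the use case from #8544: selecting ORM entities from a CTE that wraps INSERT ... RETURNING. A rough sketch of that pattern, assuming a backend that allows DML inside a CTE (PostgreSQL, for example); the connection URL and model are placeholders only:

    # illustrative sketch only: SELECT entities from an INSERT...RETURNING
    # CTE, per the #8544 use case; requires a backend with DML-in-CTE
    # support such as PostgreSQL
    import uuid

    from sqlalchemy import create_engine, insert, select
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        Session,
        aliased,
        mapped_column,
    )


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "users"

        id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
        username: Mapped[str]


    # placeholder URL; any backend supporting RETURNING and DML in a CTE
    # should work
    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    Base.metadata.create_all(engine)

    with Session(engine) as sess:
        cte = (
            insert(User)
            .values(id=uuid.uuid4(), username="some user")
            .returning(User)
            .cte()
        )

        # wrapping the CTE in aliased() lets the outer SELECT return User
        # instances directly
        new_user = sess.scalars(select(aliased(User, cte))).one()
        sess.commit()
        print(new_user.username)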