Diffstat (limited to 'lib/sqlalchemy')
25 files changed, 2317 insertions, 700 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 2eef971cc..5eb6b9528 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -301,7 +301,7 @@ Fast Executemany Mode The SQL Server ``fast_executemany`` parameter may be used at the same time as ``insertmanyvalues`` is enabled; however, the parameter will not be used in as many cases as INSERT statements that are invoked using Core - :class:`.Insert` constructs as well as all ORM use no longer use the + :class:`_dml.Insert` constructs as well as all ORM use no longer use the ``.executemany()`` DBAPI cursor method. The PyODBC driver includes support for a "fast executemany" mode of execution diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index 8dd8a4995..4609701a2 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -134,9 +134,11 @@ def _upsert(cfg, table, returning, set_lambda=None): stmt = insert(table) + table_pk = inspect(table).selectable + if set_lambda: stmt = stmt.on_conflict_do_update( - index_elements=table.primary_key, set_=set_lambda(stmt.excluded) + index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded) ) else: stmt = stmt.on_conflict_do_nothing() diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index e57a84fe0..5f468edbe 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1466,11 +1466,6 @@ class SQLiteCompiler(compiler.SQLCompiler): return target_text - def visit_insert(self, insert_stmt, **kw): - if insert_stmt._post_values_clause is not None: - kw["disable_implicit_returning"] = True - return super().visit_insert(insert_stmt, **kw) - def visit_on_conflict_do_nothing(self, on_conflict, **kw): target_text = self._on_conflict_target(on_conflict, **kw) diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 8840b5916..07e782296 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -23,12 +23,14 @@ from typing import Iterator from typing import List from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .result import IteratorResult from .result import MergedResult from .result import Result from .result import ResultMetaData @@ -62,36 +64,80 @@ if typing.TYPE_CHECKING: from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType + from .result import _KeyMapType from .result import _KeyType from .result import _ProcessorsType + from .result import _TupleGetterType from ..sql.type_api import _ResultProcessorType _T = TypeVar("_T", bound=Any) + # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. 
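The MD_* constants that follow index positions in the keymap record tuple; a plain tuple avoids namedtuple attribute-access overhead on the per-row hot path. A rough sketch of that layout, with hypothetical values (the constants are internal to ``sqlalchemy.engine.cursor``)::

    from sqlalchemy.engine.cursor import MD_INDEX, MD_LOOKUP_KEY, MD_PROCESSOR

    # hypothetical keymap record for a column named "user_id"
    rec = (
        0,           # MD_INDEX: position in cursor.description
        0,           # MD_RESULT_MAP_INDEX: position in compiled._result_columns
        [],          # MD_OBJECTS: other objects that can match this column
        "user_id",   # MD_LOOKUP_KEY: key used for row._mapping lookups
        "user_id",   # MD_RENDERED_NAME: name as rendered in the statement
        None,        # MD_PROCESSOR: optional result-value processor
        "user_id",   # MD_UNTRANSLATED: raw name from cursor.description
    )
    assert rec[MD_INDEX] == 0 and rec[MD_LOOKUP_KEY] == "user_id"
    processor = rec[MD_PROCESSOR]  # None means the raw DBAPI value is used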
-MD_INDEX: Literal[0] = 0 # integer index in cursor.description -MD_RESULT_MAP_INDEX: Literal[ - 1 -] = 1 # integer index in compiled._result_columns -MD_OBJECTS: Literal[ - 2 -] = 2 # other string keys and ColumnElement obj that can match -MD_LOOKUP_KEY: Literal[ - 3 -] = 3 # string key we usually expect for key-based lookup -MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description -MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row -MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description +# these match up to the positions in +# _CursorKeyMapRecType +MD_INDEX: Literal[0] = 0 +"""integer index in cursor.description + +""" + +MD_RESULT_MAP_INDEX: Literal[1] = 1 +"""integer index in compiled._result_columns""" + +MD_OBJECTS: Literal[2] = 2 +"""other string keys and ColumnElement obj that can match. + +This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects + +""" + +MD_LOOKUP_KEY: Literal[3] = 3 +"""string key we usually expect for key-based lookup + +this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name +""" + + +MD_RENDERED_NAME: Literal[4] = 4 +"""name that is usually in cursor.description + +this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname +""" + + +MD_PROCESSOR: Literal[5] = 5 +"""callable to process a result value into a row""" + +MD_UNTRANSLATED: Literal[6] = 6 +"""raw name from cursor.description""" _CursorKeyMapRecType = Tuple[ - int, int, List[Any], str, str, Optional["_ResultProcessorType"], str + Optional[int], # MD_INDEX, None means the record is ambiguously named + int, # MD_RESULT_MAP_INDEX + List[Any], # MD_OBJECTS + str, # MD_LOOKUP_KEY + str, # MD_RENDERED_NAME + Optional["_ResultProcessorType"], # MD_PROCESSOR + Optional[str], # MD_UNTRANSLATED ] _CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] +# same as _CursorKeyMapRecType except the MD_INDEX value is definitely +# not None +_NonAmbigCursorKeyMapRecType = Tuple[ + int, + int, + List[Any], + str, + str, + Optional["_ResultProcessorType"], + str, +] + class CursorResultMetaData(ResultMetaData): """Result metadata for DBAPI cursors.""" @@ -127,38 +173,112 @@ class CursorResultMetaData(ResultMetaData): extra=[self._keymap[key][MD_OBJECTS] for key in self._keys], ) - def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: - recs = cast( - "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys)) + def _make_new_metadata( + self, + *, + unpickled: bool, + processors: _ProcessorsType, + keys: Sequence[str], + keymap: _KeyMapType, + tuplefilter: Optional[_TupleGetterType], + translated_indexes: Optional[List[int]], + safe_for_cache: bool, + keymap_by_result_column_idx: Any, + ) -> CursorResultMetaData: + new_obj = self.__class__.__new__(self.__class__) + new_obj._unpickled = unpickled + new_obj._processors = processors + new_obj._keys = keys + new_obj._keymap = keymap + new_obj._tuplefilter = tuplefilter + new_obj._translated_indexes = translated_indexes + new_obj._safe_for_cache = safe_for_cache + new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx + return new_obj + + def _remove_processors(self) -> CursorResultMetaData: + assert not self._tuplefilter + return self._make_new_metadata( + unpickled=self._unpickled, + processors=[None] * len(self._processors), + tuplefilter=None, + translated_indexes=None, + keymap={ + key: value[0:5] + (None,) + value[6:] + for key, value in self._keymap.items() + }, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + 
keymap_by_result_column_idx=self._keymap_by_result_column_idx, ) + def _splice_horizontally( + self, other: CursorResultMetaData + ) -> CursorResultMetaData: + + assert not self._tuplefilter + + keymap = self._keymap.copy() + offset = len(self._keys) + keymap.update( + { + key: ( + # int index should be None for ambiguous key + value[0] + offset + if value[0] is not None and key not in keymap + else None, + value[1] + offset, + *value[2:], + ) + for key, value in other._keymap.items() + } + ) + + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors + other._processors, # type: ignore + tuplefilter=None, + translated_indexes=None, + keys=self._keys + other._keys, # type: ignore + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx={ + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for metadata_entry in keymap.values() + }, + ) + + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + recs = list(self._metadata_for_keys(keys)) + indexes = [rec[MD_INDEX] for rec in recs] new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs] if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] tup = tuplegetter(*indexes) - - new_metadata = self.__class__.__new__(self.__class__) - new_metadata._unpickled = self._unpickled - new_metadata._processors = self._processors - new_metadata._keys = new_keys - new_metadata._tuplefilter = tup - new_metadata._translated_indexes = indexes - new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)] - new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} + keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} # TODO: need unit test for: # result = connection.execute("raw sql, no columns").scalars() # without the "or ()" it's failing because MD_OBJECTS is None - new_metadata._keymap.update( + keymap.update( (e, new_rec) for new_rec in new_recs for e in new_rec[MD_OBJECTS] or () ) - return new_metadata + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors, + keys=new_keys, + tuplefilter=tup, + translated_indexes=indexes, + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: """When using a cached Compiled construct that has a _result_map, @@ -168,6 +288,7 @@ class CursorResultMetaData(ResultMetaData): as matched to those of the cached statement. """ + if not context.compiled or not context.compiled._result_columns: return self @@ -189,7 +310,6 @@ class CursorResultMetaData(ResultMetaData): # make a copy and add the columns from the invoked statement # to the result map. 
- md = self.__class__.__new__(self.__class__) keymap_by_position = self._keymap_by_result_column_idx @@ -201,26 +321,26 @@ class CursorResultMetaData(ResultMetaData): for metadata_entry in self._keymap.values() } - md._keymap = compat.dict_union( - self._keymap, - { - new: keymap_by_position[idx] - for idx, new in enumerate( - invoked_statement._all_selected_columns - ) - if idx in keymap_by_position - }, - ) - - md._unpickled = self._unpickled - md._processors = self._processors assert not self._tuplefilter - md._tuplefilter = None - md._translated_indexes = None - md._keys = self._keys - md._keymap_by_result_column_idx = self._keymap_by_result_column_idx - md._safe_for_cache = self._safe_for_cache - return md + return self._make_new_metadata( + keymap=compat.dict_union( + self._keymap, + { + new: keymap_by_position[idx] + for idx, new in enumerate( + invoked_statement._all_selected_columns + ) + if idx in keymap_by_position + }, + ), + unpickled=self._unpickled, + processors=self._processors, + tuplefilter=None, + translated_indexes=None, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) def __init__( self, @@ -683,7 +803,27 @@ class CursorResultMetaData(ResultMetaData): untranslated, ) - def _key_fallback(self, key, err, raiseerr=True): + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[True] = ... + ) -> NoReturn: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[False] = ... + ) -> None: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = ... + ) -> Optional[NoReturn]: + ... + + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = True + ) -> Optional[NoReturn]: if raiseerr: if self._unpickled and isinstance(key, elements.ColumnElement): @@ -714,9 +854,9 @@ class CursorResultMetaData(ResultMetaData): try: rec = self._keymap[key] except KeyError as ke: - rec = self._key_fallback(key, ke, raiseerr) - if rec is None: - return None + x = self._key_fallback(key, ke, raiseerr) + assert x is None + return None index = rec[0] @@ -734,7 +874,7 @@ class CursorResultMetaData(ResultMetaData): def _metadata_for_keys( self, keys: Sequence[Any] - ) -> Iterator[_CursorKeyMapRecType]: + ) -> Iterator[_NonAmbigCursorKeyMapRecType]: for key in keys: if int in key.__class__.__mro__: key = self._keys[key] @@ -750,7 +890,7 @@ class CursorResultMetaData(ResultMetaData): if index is None: self._raise_for_ambiguous_column_name(rec) - yield rec + yield cast(_NonAmbigCursorKeyMapRecType, rec) def __getstate__(self): return { @@ -1237,6 +1377,12 @@ _NO_RESULT_METADATA = _NoResultMetaData() SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]") +def null_dml_result() -> IteratorResult[Any]: + it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([])) + it._soft_close() + return it + + class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. @@ -1586,6 +1732,142 @@ class CursorResult(Result[_T]): """ return self.context.returned_default_rows + def splice_horizontally(self, other): + """Return a new :class:`.CursorResult` that "horizontally splices" + together the rows of this :class:`.CursorResult` with that of another + :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. 
+ + "horizontally splices" means that for each row in the first and second + result sets, a new row that concatenates the two rows together is + produced, which then becomes the new row. The incoming + :class:`.CursorResult` must have the identical number of rows. It is + typically expected that the two result sets come from the same sort + order as well, as the result rows are spliced together based on their + position in the result. + + The expected use case here is so that multiple INSERT..RETURNING + statements against different tables can produce a single result + that looks like a JOIN of those two tables. + + E.g.:: + + r1 = connection.execute( + users.insert().returning(users.c.user_name, users.c.user_id), + user_values + ) + + r2 = connection.execute( + addresses.insert().returning( + addresses.c.address_id, + addresses.c.address, + addresses.c.user_id, + ), + address_values + ) + + rows = r1.splice_horizontally(r2).all() + assert ( + rows == + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ] + ) + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.CursorResult.splice_vertically` + + + """ + + clone = self._generate() + total_rows = [ + tuple(r1) + tuple(r2) + for r1, r2 in zip( + list(self._raw_row_iterator()), + list(other._raw_row_iterator()), + ) + ] + + clone._metadata = clone._metadata._splice_horizontally(other._metadata) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def splice_vertically(self, other): + """Return a new :class:`.CursorResult` that "vertically splices", + i.e. "extends", the rows of this :class:`.CursorResult` with that of + another :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "vertically splices" means the rows of the given result are appended to + the rows of this cursor result. The incoming :class:`.CursorResult` + must have rows that represent the identical list of columns in the + identical order as they are in this :class:`.CursorResult`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`.CursorResult.splice_horizontally` + + """ + clone = self._generate() + total_rows = list(self._raw_row_iterator()) + list( + other._raw_row_iterator() + ) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def _rewind(self, rows): + """rewind this result back to the given rowset. + + this is used internally for the case where an :class:`.Insert` + construct combines the use of + :meth:`.Insert.return_defaults` along with the + "supplemental columns" feature. + + """ + + if self._echo: + self.context.connection._log_debug( + "CursorResult rewound %d row(s)", len(rows) + ) + + # the rows given are expected to be Row objects, so we + # have to clear out processors which have already run on these + # rows + self._metadata = cast( + CursorResultMetaData, self._metadata + )._remove_processors() + + self.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + # TODO: if these are Row objects, can we save on not having to + # re-make new Row objects out of them a second time? is that + # what's actually happening right now? 
maybe look into this + initial_buffer=rows, + ) + self._reset_memoizations() + return self + @property def returned_defaults(self): """Return the values of default columns that were fetched using diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 11ab713d0..cb3d0528f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1007,6 +1007,7 @@ class DefaultExecutionContext(ExecutionContext): _is_implicit_returning = False _is_explicit_returning = False + _is_supplemental_returning = False _is_server_side = False _soft_closed = False @@ -1125,18 +1126,19 @@ class DefaultExecutionContext(ExecutionContext): self.is_text = compiled.isplaintext if ii or iu or id_: + dml_statement = compiled.compile_state.statement # type: ignore if TYPE_CHECKING: - assert isinstance(compiled.statement, UpdateBase) + assert isinstance(dml_statement, UpdateBase) self.is_crud = True - self._is_explicit_returning = ier = bool( - compiled.statement._returning - ) - self._is_implicit_returning = iir = is_implicit_returning = bool( + self._is_explicit_returning = ier = bool(dml_statement._returning) + self._is_implicit_returning = iir = bool( compiled.implicit_returning ) - assert not ( - is_implicit_returning and compiled.statement._returning - ) + if iir and dml_statement._supplemental_returning: + self._is_supplemental_returning = True + + # dont mix implicit and explicit returning + assert not (iir and ier) if (ier or iir) and compiled.for_executemany: if ii and not self.dialect.insert_executemany_returning: @@ -1711,7 +1713,14 @@ class DefaultExecutionContext(ExecutionContext): # are that the result has only one row, until executemany() # support is added here. assert result._metadata.returns_rows - result._soft_close() + + # Insert statement has both return_defaults() and + # returning(). rewind the result on the list of rows + # we just used. + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() elif not self._is_explicit_returning: result._soft_close() @@ -1721,21 +1730,18 @@ class DefaultExecutionContext(ExecutionContext): # function so this is not necessarily true. # assert not result.returns_rows - elif self.isupdate and self._is_implicit_returning: - # get rowcount - # (which requires open cursor on some drivers) - # we were not doing this in 1.4, however - # test_rowcount -> test_update_rowcount_return_defaults - # is testing this, and psycopg will no longer return - # rowcount after cursor is closed. - result.rowcount - self._has_rowcount = True + elif self._is_implicit_returning: + rows = result.all() - row = result.fetchone() - if row is not None: - self.returned_default_rows = [row] + if rows: + self.returned_default_rows = rows + result.rowcount = len(rows) + self._has_rowcount = True - result._soft_close() + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() # test that it has a cursor metadata that is accurate. # the rows have all been fetched however. 
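The supplemental-returning branch above lets a single INSERT serve both the implicit "fetch server defaults" bookkeeping and user-visible RETURNING rows, with ``_rewind()`` re-exposing the rows that were already fetched. A rough sketch of the calling pattern this enables, assuming the ``return_defaults(supplemental_cols=...)`` parameter introduced alongside this change and a hypothetical ``users`` table::

    stmt = users.insert().return_defaults(
        supplemental_cols=[users.c.user_name]
    )
    result = connection.execute(stmt, {"user_name": "john"})

    row = result.first()                 # supplemental RETURNING columns, via the rewound rows
    defaults = result.returned_defaults  # server-generated default values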
@@ -1750,7 +1756,6 @@ class DefaultExecutionContext(ExecutionContext): elif self.isupdate or self.isdelete: result.rowcount self._has_rowcount = True - return result @util.memoized_property diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index df5a8199c..05ca17063 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -109,9 +109,27 @@ class ResultMetaData: def _for_freeze(self) -> ResultMetaData: raise NotImplementedError() + @overload def _key_fallback( - self, key: _KeyType, err: Exception, raiseerr: bool = True + self, key: Any, err: Exception, raiseerr: Literal[True] = ... ) -> NoReturn: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: Literal[False] = ... + ) -> None: + ... + + @overload + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = ... + ) -> Optional[NoReturn]: + ... + + def _key_fallback( + self, key: Any, err: Exception, raiseerr: bool = True + ) -> Optional[NoReturn]: assert raiseerr raise KeyError(key) from err @@ -2148,6 +2166,7 @@ class IteratorResult(Result[_TP]): """ _hard_closed = False + _soft_closed = False def __init__( self, @@ -2168,6 +2187,7 @@ class IteratorResult(Result[_TP]): self.raw._soft_close(hard=hard, **kw) self.iterator = iter([]) self._reset_memoizations() + self._soft_closed = True def _raise_hard_closed(self) -> NoReturn: raise exc.ResourceClosedError("This result object is closed.") diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 225292d17..3ed34a57a 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -15,24 +15,32 @@ specifically outside of the flush() process. from __future__ import annotations from typing import Any +from typing import cast from typing import Dict from typing import Iterable +from typing import Optional +from typing import overload from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes +from . import context from . import evaluator from . import exc as orm_exc +from . import loading from . import persistence from .base import NO_VALUE from .context import AbstractORMCompileState +from .context import FromStatement +from .context import ORMFromStatementCompileState +from .context import QueryContext from .. import exc as sa_exc -from .. import sql from .. 
import util from ..engine import Dialect from ..engine import result as _result from ..sql import coercions +from ..sql import dml from ..sql import expression from ..sql import roles from ..sql import select @@ -48,16 +56,24 @@ from ..util.typing import Literal if TYPE_CHECKING: from .mapper import Mapper + from .session import _BindArguments from .session import ORMExecuteState + from .session import Session from .session import SessionTransaction from .state import InstanceState + from ..engine import Connection + from ..engine import cursor + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter _O = TypeVar("_O", bound=object) -_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"] +_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"] +_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"] +@overload def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], @@ -65,7 +81,36 @@ def _bulk_insert( isstates: bool, return_defaults: bool, render_nulls: bool, + use_orm_insert_stmt: Literal[None] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., ) -> None: + ... + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., +) -> cursor.CursorResult[Any]: + ... + + +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, +) -> Optional[cursor.CursorResult[Any]]: base_mapper = mapper.base_mapper if session_transaction.session.connection_callable: @@ -81,13 +126,27 @@ def _bulk_insert( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) + + return_result: Optional[cursor.CursorResult[Any]] = None + for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue + is_joined_inh_supertable = super_mapper is not mapper + bookkeeping = ( + is_joined_inh_supertable + or return_defaults + or ( + use_orm_insert_stmt is not None + and bool(use_orm_insert_stmt._returning) + ) + ) + records = ( ( None, @@ -112,18 +171,25 @@ def _bulk_insert( table, ((None, mapping, mapper, connection) for mapping in mappings), bulk=True, - return_defaults=return_defaults, + return_defaults=bookkeeping, render_nulls=render_nulls, ) ) - persistence._emit_insert_statements( + result = persistence._emit_insert_statements( base_mapper, None, super_mapper, table, records, - bookkeeping=return_defaults, + bookkeeping=bookkeeping, + use_orm_insert_stmt=use_orm_insert_stmt, + execution_options=execution_options, ) + if use_orm_insert_stmt is not None: + if not use_orm_insert_stmt._returning or return_result is None: + return_result = result + elif result.returns_rows: + return_result = 
return_result.splice_horizontally(result) if return_defaults and isstates: identity_cls = mapper._identity_class @@ -134,14 +200,43 @@ def _bulk_insert( tuple([dict_[key] for key in identity_props]), ) + if use_orm_insert_stmt is not None: + assert return_result is not None + return return_result + +@overload def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, update_changed_only: bool, + use_orm_update_stmt: Literal[None] = ..., ) -> None: + ... + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = ..., +) -> _result.Result[Any]: + ... + + +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = None, +) -> Optional[_result.Result[Any]]: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys @@ -161,7 +256,8 @@ def _bulk_update( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) if session_transaction.session.connection_callable: raise NotImplementedError( @@ -172,7 +268,7 @@ def _bulk_update( connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue records = persistence._collect_update_commands( @@ -193,8 +289,8 @@ def _bulk_update( for mapping in mappings ), bulk=True, + use_orm_update_stmt=use_orm_update_stmt, ) - persistence._emit_update_statements( base_mapper, None, @@ -202,10 +298,125 @@ def _bulk_update( table, records, bookkeeping=False, + use_orm_update_stmt=use_orm_update_stmt, ) + if use_orm_update_stmt is not None: + return _result.null_result() + + +def _expand_composites(mapper, mappings): + composite_attrs = mapper.composites + if not composite_attrs: + return + + composite_keys = set(composite_attrs.keys()) + populators = { + key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() + for key in composite_keys + } + for mapping in mappings: + for key in composite_keys.intersection(mapping): + populators[key](mapping) + class ORMDMLState(AbstractORMCompileState): + is_dml_returning = True + from_statement_ctx: Optional[ORMFromStatementCompileState] = None + + @classmethod + def _get_orm_crud_kv_pairs( + cls, mapper, statement, kv_iterator, needs_to_be_cacheable + ): + + core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, str): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + yield ( + coercions.expect(roles.DMLColumnRole, k), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v, + ) + else: + yield from core_get_crud_kv_pairs( + statement, + desc._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + 
k_anno["entity_namespace"], k_anno["proxy_key"] + ) + yield from core_get_crud_kv_pairs( + statement, + attr._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + else: + yield ( + k, + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + + @classmethod + def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_multi_crud_kv_pairs( + statement, kv_iterator + ) + + return [ + dict( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, statement, value_dict.items(), False + ) + ) + for value_dict in kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): + assert ( + needs_to_be_cacheable + ), "no test coverage for needs_to_be_cacheable=False" + + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_crud_kv_pairs( + statement, kv_iterator, needs_to_be_cacheable + ) + + return list( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, + statement, + kv_iterator, + needs_to_be_cacheable, + ) + ) + @classmethod def get_entity_description(cls, statement): ext_info = statement.table._annotations["parententity"] @@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState): ] ] + def _setup_orm_returning( + self, + compiler, + orm_level_statement, + dml_level_statement, + use_supplemental_cols=True, + dml_mapper=None, + ): + """establish ORM column handlers for an INSERT, UPDATE, or DELETE + which uses explicit returning(). + + called within compilation level create_for_statement. + + The _return_orm_returning() method then receives the Result + after the statement was executed, and applies ORM loading to the + state that we first established here. 
+ + """ + + if orm_level_statement._returning: + + fs = FromStatement( + orm_level_statement._returning, dml_level_statement + ) + fs = fs.options(*orm_level_statement._with_options) + self.select_statement = fs + self.from_statement_ctx = ( + fsc + ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + fsc.setup_dml_returning_compile_state(dml_mapper) + + dml_level_statement = dml_level_statement._generate() + dml_level_statement._returning = () + + cols_to_return = [c for c in fsc.primary_columns if c is not None] + + # since we are splicing result sets together, make sure there + # are columns of some kind returned in each result set + if not cols_to_return: + cols_to_return.extend(dml_mapper.primary_key) + + if use_supplemental_cols: + dml_level_statement = dml_level_statement.return_defaults( + supplemental_cols=cols_to_return + ) + else: + dml_level_statement = dml_level_statement.returning( + *cols_to_return + ) + + return dml_level_statement + + @classmethod + def _return_orm_returning( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + if compile_state.from_statement_ctx: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state.from_statement_ctx, + compile_state.select_statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + else: + return result + class BulkUDCompileState(ORMDMLState): class default_update_options(Options): - _synchronize_session: _SynchronizeSessionArgument = "evaluate" - _is_delete_using = False - _is_update_from = False - _autoflush = True - _subject_mapper = None + _dml_strategy: _DMLStrategyArgument = "auto" + _synchronize_session: _SynchronizeSessionArgument = "auto" + _can_use_returning: bool = False + _is_delete_using: bool = False + _is_update_from: bool = False + _autoflush: bool = True + _subject_mapper: Optional[Mapper[Any]] = None _resolved_values = EMPTY_DICT - _resolved_keys_as_propnames = EMPTY_DICT - _value_evaluators = EMPTY_DICT - _matched_objects = None + _eval_condition = None _matched_rows = None _refresh_identity_token = None @@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState): execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", - {"synchronize_session", "is_delete_using", "is_update_from"}, + { + "synchronize_session", + "is_delete_using", + "is_update_from", + "dml_strategy", + }, execution_options, statement._execution_options, ) - sync = update_options._synchronize_session - if sync is not None: - if sync not in ("evaluate", "fetch", False): - raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are 'evaluate', 'fetch', False" - ) - bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState): update_options += {"_subject_mapper": plugin_subject.mapper} + if not isinstance(params, list): + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "orm"} + elif update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if update_options._dml_strategy == "auto": + 
update_options += {"_dml_strategy": "bulk"} + elif update_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + sync = update_options._synchronize_session + if sync is not None: + if sync not in ("auto", "evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'auto', 'evaluate', 'fetch', False" + ) + if update_options._dml_strategy == "bulk" and sync == "fetch": + raise sa_exc.InvalidRequestError( + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates (i.e. multiple parameter sets)" + ) + if update_options._autoflush: session._autoflush() + if update_options._dml_strategy == "orm": + + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. statement = statement._annotate( { "synchronize_session": update_options._synchronize_session, "is_delete_using": update_options._is_delete_using, "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, } ) - # this stage of the execution is called before the do_orm_execute event - # hook. meaning for an extension like horizontal sharding, this step - # happens before the extension splits out into multiple backends and - # runs only once. if we do pre_sync_fetch, we execute a SELECT - # statement, which the horizontal sharding extension splits amongst the - # shards and combines the results together. - - if update_options._synchronize_session == "evaluate": - update_options = cls._do_pre_synchronize_evaluate( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "fetch": - update_options = cls._do_pre_synchronize_fetch( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - return ( statement, util.immutabledict(execution_options).union( @@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState): # individual ones we return here. 
update_options = execution_options["_sa_orm_update_options"] - if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, result, update_options) - elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, result, update_options) + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate( + session, statement, result, update_options + ) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch( + session, statement, result, update_options + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_bulk_evaluate( + session, params, result, update_options + ) + return result - return result + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) @classmethod def _adjust_for_extra_criteria(cls, global_attributes, ext_info): @@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState): primary_key_convert = [ lookup[bpk] for bpk in mapper.base_mapper.primary_key ] - return [tuple(row[idx] for idx in primary_key_convert) for row in rows] @classmethod - def _do_pre_synchronize_evaluate( + def _get_matched_objects_on_criteria(cls, update_options, states): + mapper = update_options._subject_mapper + eval_condition = update_options._eval_condition + + raw_data = [ + (state.obj(), state, state.dict) + for state in states + if state.mapper.isa(mapper) and not state.expired + ] + + identity_token = update_options._refresh_identity_token + if identity_token is not None: + raw_data = [ + (obj, state, dict_) + for obj, state, dict_ in raw_data + if state.identity_token == identity_token + ] + + result = [] + for obj, state, dict_ in raw_data: + evaled_condition = eval_condition(obj) + + # caution: don't use "in ()" or == here, _EXPIRE_OBJECT + # evaluates as True for all comparisons + if ( + evaled_condition is True + or evaled_condition is evaluator._EXPIRED_OBJECT + ): + result.append( + ( + obj, + state, + dict_, + evaled_condition is evaluator._EXPIRED_OBJECT, + ) + ) + return result + + @classmethod + def _eval_condition_from_statement(cls, update_options, statement): + mapper = update_options._subject_mapper + target_cls = mapper.class_ + + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + crit = () + if statement._where_criteria: + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria(global_attributes, mapper) + + if crit: + eval_condition = evaluator_compiler.process(*crit) + else: + + def eval_condition(obj): + return True + + return eval_condition + + @classmethod + def _do_pre_synchronize_auto( cls, session, statement, @@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState): bind_arguments, update_options, ): - mapper = update_options._subject_mapper - target_cls = mapper.class_ + """setup auto sync strategy + + + "auto" checks if we can use "evaluate" first, then falls back + to "fetch" + + evaluate is vastly more efficient for the common case + where session is empty, only has a few objects, and the UPDATE + statement can potentially match thousands/millions of rows. 
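From the calling side, "auto" means the synchronization strategy no longer needs to be chosen explicitly; a rough sketch with a hypothetical ``User`` mapped class::

    from sqlalchemy import update

    session.execute(
        update(User).where(User.status == "trial").values(status="expired"),
        execution_options={"synchronize_session": "auto"},
    )
    # "auto" tries "evaluate" against in-memory objects first and falls
    # back to "fetch" if the criteria can't be evaluated in Python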
- value_evaluators = resolved_keys_as_propnames = EMPTY_DICT + OTOH more complex criteria that fails to work with "evaluate" + we would hope usually correlates with fewer net rows. + + """ try: - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - crit = () - if statement._where_criteria: - crit += statement._where_criteria + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) - global_attributes = {} - for opt in statement._with_options: - if opt._is_criteria_option: - opt.get_global_criteria(global_attributes) + except evaluator.UnevaluatableError: + pass + else: + return update_options + { + "_eval_condition": eval_condition, + "_synchronize_session": "evaluate", + } - if global_attributes: - crit += cls._adjust_for_extra_criteria( - global_attributes, mapper - ) + update_options += {"_synchronize_session": "fetch"} + return cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) - if crit: - eval_condition = evaluator_compiler.process(*crit) - else: + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): - def eval_condition(obj): - return True + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( @@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState): "synchronize_session execution option." % err ) from err - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - # TODO: detect when the where clause is a trivial primary key match. 
- matched_objects = [ - state.obj() - for state in session.identity_map.all_states() - if state.mapper.isa(mapper) - and not state.expired - and eval_condition(state.obj()) - and ( - update_options._refresh_identity_token is None - # TODO: coverage for the case where horizontal sharding - # invokes an update() or delete() given an explicit identity - # token up front - or state.identity_token - == update_options._refresh_identity_token - ) - ] return update_options + { - "_matched_objects": matched_objects, - "_value_evaluators": value_evaluators, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_eval_condition": eval_condition, } @classmethod @@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState): def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] for k, v in resolved_values: - if isinstance(k, attributes.QueryableAttribute): - values.append((k.key, v)) - continue - elif hasattr(k, "__clause_element__"): - k = k.__clause_element__() - if mapper and isinstance(k, expression.ColumnElement): try: attr = mapper._columntoproperty[k] @@ -599,7 +992,8 @@ class BulkUDCompileState(ORMDMLState): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k + "Attribute name not found, can't be " + "synchronized back to objects: %r" % k ) return values @@ -622,14 +1016,43 @@ class BulkUDCompileState(ORMDMLState): ) select_stmt._where_criteria = statement._where_criteria + # conditionally run the SELECT statement for pre-fetch, testing the + # "bind" for if we can use RETURNING or not using the do_orm_execute + # event. If RETURNING is available, the do_orm_execute event + # will cancel the SELECT from being actually run. + # + # The way this is organized seems strange, why don't we just + # call can_use_returning() before invoking the statement and get + # answer?, why does this go through the whole execute phase using an + # event? Answer: because we are integrating with extensions such + # as the horizontal sharding extention that "multiplexes" an individual + # statement run through multiple engines, and it uses + # do_orm_execute() to do that. 
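The do_orm_execute() hook referenced above is the public session event; a listener that returns a Result replaces the actual statement execution, which is how the pre-fetch SELECT gets cancelled when RETURNING can be used. A minimal sketch of a listener in that style (handler body illustrative only)::

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    @event.listens_for(Session, "do_orm_execute")
    def _maybe_short_circuit(orm_execute_state):
        # an extension may return its own Result here to replace execution;
        # returning None lets the statement proceed normally
        if orm_execute_state.is_select:
            return None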
+ + can_use_returning = None + def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - if cls.can_use_returning( + nonlocal can_use_returning + + per_bind_result = cls.can_use_returning( bind.dialect, mapper, is_update_from=update_options._is_update_from, is_delete_using=update_options._is_delete_using, - ): + ) + + if can_use_returning is not None: + if can_use_returning != per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't mix multiple " + "backends where some support RETURNING and others " + "don't" + ) + else: + can_use_returning = per_bind_result + + if per_bind_result: return _result.null_result() else: return None @@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState): ) matched_rows = result.fetchall() - value_evaluators = EMPTY_DICT - - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - target_cls = mapper.class_ - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - value_evaluators = {} - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - else: - resolved_keys_as_propnames = EMPTY_DICT - return update_options + { - "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_can_use_returning": can_use_returning, } @CompileState.plugin_for("orm", "insert") -class ORMInsert(ORMDMLState, InsertDMLState): +class BulkORMInsert(ORMDMLState, InsertDMLState): + class default_insert_options(Options): + _dml_strategy: _DMLStrategyArgument = "auto" + _render_nulls: bool = False + _return_defaults: bool = False + _subject_mapper: Optional[Mapper[Any]] = None + + select_statement: Optional[FromStatement] = None + @classmethod def orm_pre_session_exec( cls, @@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState): bind_arguments, is_reentrant_invoke, ): + + ( + insert_options, + execution_options, + ) = BulkORMInsert.default_insert_options.from_execution_options( + "_sa_orm_insert_options", + {"dml_strategy"}, + execution_options, + statement._execution_options, + ) bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState): else: bind_arguments["mapper"] = plugin_subject.mapper + insert_options += {"_subject_mapper": plugin_subject.mapper} + + if not params: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "orm"} + elif insert_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "bulk"} + elif 
insert_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + if insert_options._dml_strategy != "raw": + # for ORM object loading, like ORMContext, we have to disable + # result set adapt_to_context, because we will be generating a + # new statement with specific columns that's cached inside of + # an ORMFromStatementCompileState, which we will re-use for + # each result. + if not execution_options: + execution_options = context._orm_load_exec_options + else: + execution_options = execution_options.union( + context._orm_load_exec_options + ) + + statement = statement._annotate( + {"dml_strategy": insert_options._dml_strategy} + ) + return ( statement, - util.immutabledict(execution_options), + util.immutabledict(execution_options).union( + {"_sa_orm_insert_options": insert_options} + ), ) @classmethod - def orm_setup_cursor_result( + def orm_execute_statement( cls, - session, - statement, - params, - execution_options, - bind_arguments, - result, - ): - return result + session: Session, + statement: dml.Insert, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + insert_options = execution_options.get( + "_sa_orm_insert_options", cls.default_insert_options + ) + + if insert_options._dml_strategy not in ( + "raw", + "bulk", + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM insert strategy " + "are 'raw', 'orm', 'bulk', 'auto" + ) + + result: _result.Result[Any] + + if insert_options._dml_strategy == "raw": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return result + + if insert_options._dml_strategy == "bulk": + mapper = insert_options._subject_mapper + + if ( + statement._post_values_clause is not None + and mapper._multiple_persistence_tables + ): + raise sa_exc.InvalidRequestError( + "bulk INSERT with a 'post values' clause " + "(typically upsert) not supported for multi-table " + f"mapper {mapper}" + ) + + assert mapper is not None + assert session._transaction is not None + result = _bulk_insert( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + return_defaults=insert_options._return_defaults, + render_nulls=insert_options._render_nulls, + use_orm_insert_stmt=statement, + execution_options=execution_options, + ) + elif insert_options._dml_strategy == "orm": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + else: + raise AssertionError() + + if not bool(statement._returning): + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: + + self = cast( + BulkORMInsert, + super().create_for_statement(statement, compiler, **kw), + ) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + if not toplevel: + return self + + mapper = statement._propagate_attrs["plugin_subject"] + dml_strategy = statement._annotations.get("dml_strategy", "raw") + if dml_strategy == "bulk": + self._setup_for_bulk_insert(compiler) + elif dml_strategy == "orm": + self._setup_for_orm_insert(compiler, mapper) + + return self + + @classmethod + def _resolved_keys_as_col_keys(cls, mapper, 
resolved_value_dict): + return { + col.key if col is not None else k: v + for col, k, v in ( + (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() + ) + } + + def _setup_for_orm_insert(self, compiler, mapper): + statement = orm_level_statement = cast(dml.Insert, self.statement) + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=False, + ) + self.statement = statement + + def _setup_for_bulk_insert(self, compiler): + """establish an INSERT statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_insert_statement(). + + """ + statement = orm_level_statement = cast(dml.Insert, self.statement) + an = statement._annotations + + emit_insert_table, emit_insert_mapper = ( + an["_emit_insert_table"], + an["_emit_insert_mapper"], + ) + + statement = statement._clone() + + statement.table = emit_insert_table + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_insert_table + } + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=True, + dml_mapper=emit_insert_mapper, + ) + + self.statement = statement @CompileState.plugin_for("orm", "update") @@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self = cls.__new__(cls) + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if dml_strategy == "bulk": + self._setup_for_bulk_update(statement, compiler) + elif dml_strategy in ("orm", "unspecified"): + self._setup_for_orm_update(statement, compiler) + + return self + + def _setup_for_orm_update(self, statement, compiler, **kw): + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper self.extra_criteria_entities = {} - self._resolved_values = cls._get_resolved_values(mapper, statement) + self._resolved_values = self._get_resolved_values(mapper, statement) extra_criteria_attributes = {} @@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if statement._values: self._resolved_values = dict(self._resolved_values) - new_stmt = sql.Update.__new__(sql.Update) - new_stmt.__dict__.update(statement.__dict__) + new_stmt = statement._clone() new_stmt.table = mapper.local_table # note if the statement has _multi_values, these @@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): elif statement._values: new_stmt._values = self._resolved_values - new_crit = cls._adjust_for_extra_criteria( + new_crit = self._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: @@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): UpdateDMLState.__init__(self, new_stmt, compiler, **kw) - if compiler._annotations.get( + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, mapper, is_multitable=self.is_multitable - ): - if new_stmt._returning: - raise sa_exc.InvalidRequestError( - "Can't use synchronize_session='fetch' " - "with explicit returning()" + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need 
to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable ) - self.statement = self.statement.returning( - *mapper.local_table.primary_key ) - return self + if synchronize_session == "fetch" and can_use_returning: + use_supplemental_cols = True + + # NOTE: we might want to RETURNING the actual columns to be + # synchronized also. however this is complicated and difficult + # to align against the behavior of "evaluate". Additionally, + # in a large number (if not the majority) of cases, we have the + # "evaluate" answer, usually a fixed value, in memory already and + # there's no need to re-fetch the same value + # over and over again. so perhaps if it could be RETURNING just + # the elements that were based on a SQL expression and not + # a constant. For now it doesn't quite seem worth it + new_stmt = new_stmt.return_defaults( + *(list(mapper.local_table.primary_key)) + ) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + def _setup_for_bulk_update(self, statement, compiler, **kw): + """establish an UPDATE statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_update_statement(). + + """ + statement = cast(dml.Update, statement) + an = statement._annotations + + emit_update_table, _ = ( + an["_emit_update_table"], + an["_emit_update_mapper"], + ) + + statement = statement._clone() + statement.table = emit_update_table + + UpdateDMLState.__init__(self, statement, compiler, **kw) + + if self._ordered_values: + raise sa_exc.InvalidRequestError( + "bulk ORM UPDATE does not support ordered_values() for " + "custom UPDATE statements with bulk parameter sets. Use a " + "non-bulk UPDATE statement or use values()." + ) + + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_update_table + } + self.statement = statement + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Update, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy not in ("orm", "auto", "bulk"): + raise sa_exc.ArgumentError( + "Valid strategies for ORM UPDATE strategy " + "are 'orm', 'auto', 'bulk'" + ) + + result: _result.Result[Any] + + if update_options._dml_strategy == "bulk": + if statement._where_criteria: + raise sa_exc.InvalidRequestError( + "WHERE clause with bulk ORM UPDATE not " + "supported right now. 
Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + mapper = update_options._subject_mapper + assert mapper is not None + assert session._transaction is not None + result = _bulk_update( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + update_changed_only=False, + use_orm_update_stmt=statement, + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + else: + return super().orm_execute_statement( + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) @classmethod def can_use_returning( @@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): return True @classmethod - def _get_crud_kv_pairs(cls, statement, kv_iterator): - plugin_subject = statement._propagate_attrs["plugin_subject"] - - core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs - - if not plugin_subject or not plugin_subject.mapper: - return core_get_crud_kv_pairs(statement, kv_iterator) - - mapper = plugin_subject.mapper - - values = [] - - for k, v in kv_iterator: - k = coercions.expect(roles.DMLColumnRole, k) + def _do_post_synchronize_bulk_evaluate( + cls, session, params, result, update_options + ): + if not params: + return - if isinstance(k, str): - desc = _entity_namespace_key(mapper, k, default=NO_VALUE) - if desc is NO_VALUE: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - else: - values.extend( - core_get_crud_kv_pairs( - statement, desc._bulk_update_tuples(v) - ) - ) - elif "entity_namespace" in k._annotations: - k_anno = k._annotations - attr = _entity_namespace_key( - k_anno["entity_namespace"], k_anno["proxy_key"] - ) - values.extend( - core_get_crud_kv_pairs( - statement, attr._bulk_update_tuples(v) - ) - ) - else: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - return values + mapper = update_options._subject_mapper + pk_keys = [prop.key for prop in mapper._identity_key_props] - @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): + identity_map = session.identity_map - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) - for obj in update_options._matched_objects: - - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), + for param in params: + identity_key = mapper.identity_key_from_primary_key( + (param[key] for key in pk_keys), + update_options._refresh_identity_token, ) - - # the evaluated states were gathered across all identity tokens. - # however the post_sync events are called per identity token, - # so filter. 
- if ( - update_options._refresh_identity_token is not None - and state.identity_token - != update_options._refresh_identity_token - ): + state = identity_map.fast_get_state(identity_key) + if not state: continue + evaluated_keys = set(param).difference(pk_keys) + + dict_ = state.dict # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + dict_[key] = param[key] state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) - to_expire = attrib.intersection(dict_).difference(to_evaluate) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = evaluated_keys.intersection(dict_).difference( + to_evaluate + ) if to_expire: state._expire_attributes(dict_, to_expire) - states.add(state) - session._register_altered(states) + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + ) @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows @@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if identity_key in session.identity_map ] - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) + if not objs: + return - for obj in objs: - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), - ) + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [ + ( + obj, + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + for obj in objs + ], + ) + + @classmethod + def _apply_update_set_values_to_objects( + cls, session, update_options, statement, matched_objects + ): + """apply values to objects derived from an update statement, e.g. 
+ UPDATE..SET <values> + + """ + mapper = update_options._subject_mapper + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + evaluated_keys = list(value_evaluators.keys()) + attrib = set(k for k, v in resolved_keys_as_propnames) + + states = set() + for obj, state, dict_ in matched_objects: to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + # only run eval for attributes that are present. + dict_[key] = value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: state._expire_attributes(dict_, to_expire) @@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper @@ -1002,31 +1743,97 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): if opt._is_criteria_option: opt.get_global_criteria(extra_criteria_attributes) + new_stmt = statement._clone() + new_stmt.table = mapper.local_table + new_crit = cls._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: - statement = statement.where(*new_crit) + new_stmt = new_stmt.where(*new_crit) # do this first as we need to determine if there is # DELETE..FROM - DeleteDMLState.__init__(self, statement, compiler, **kw) + DeleteDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False - if compiler._annotations.get( + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, - mapper, - is_multitable=self.is_multitable, - is_delete_using=compiler._annotations.get( - "is_delete_using", False - ), - ): - self.statement = statement.returning(*statement.table.primary_key) + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ) + ) + + if can_use_returning: + use_supplemental_cols = True + + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + 
) + + self.statement = new_stmt return self @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Delete, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + "Bulk ORM DELETE not supported right now. " + "Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + + if update_options._dml_strategy not in ( + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM DELETE strategy are 'orm', 'auto'" + ) + + return super().orm_execute_statement( + session, statement, params, execution_options, bind_arguments, conn + ) + + @classmethod def can_use_returning( cls, dialect: Dialect, @@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): return True @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): - - session._remove_newly_deleted( - [ - attributes.instance_state(obj) - for obj in update_options._matched_objects - ] + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), ) + to_delete = [] + + for _, state, dict_, is_partially_expired in matched_objects: + if is_partially_expired: + state._expire(dict_, session.identity_map._modified) + else: + to_delete.append(state) + + if to_delete: + session._remove_newly_deleted(to_delete) + @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index dc96f8c3c..f8c7ba714 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from .query import Query from .session import _BindArguments from .session import Session + from ..engine import Result from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..sql._typing import _ColumnsClauseArgument @@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict( class AbstractORMCompileState(CompileState): + is_dml_returning = False + @classmethod def create_for_statement( cls, statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> AbstractORMCompileState: """Create a context for a statement given a :class:`.Compiler`. + This method is always invoked in the context of SQLCompiler.process(). + For a Select object, this would be invoked from SQLCompiler.visit_select(). 
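
A user-facing sketch of the "bulk" UPDATE strategy handled by the BulkORMUpdate paths above: a minimal, self-contained example assuming SQLAlchemy 2.0 "ORM UPDATE by primary key" semantics (the User mapping and values are made up for illustration; they are not part of this patch).

    from sqlalchemy import create_engine, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([User(id=1, name="a"), User(id=2, name="b")])
        session.commit()

        # "bulk" strategy: an update() against the entity with no WHERE
        # criteria, executed with a list of parameter dictionaries; each
        # dictionary must include the primary key, and rows are matched by
        # primary key rather than by criteria.
        session.execute(
            update(User),
            [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}],
        )
        session.commit()
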
For the special FromStatement object used by Query to indicate "Query.from_statement()", this is called by @@ -233,6 +238,28 @@ class AbstractORMCompileState(CompileState): raise NotImplementedError() @classmethod + def orm_execute_statement( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) -> Result: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod def orm_setup_cursor_result( cls, session, @@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() + if TYPE_CHECKING: + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + ... + def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns if obj not in dedupe: @@ -333,26 +371,6 @@ class ORMCompileState(AbstractORMCompileState): return SelectState._column_naming_convention(label_style) @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: - """Create a context for a statement given a :class:`.Compiler`. - - This method is always invoked in the context of SQLCompiler.process(). - - For a Select object, this would be invoked from - SQLCompiler.visit_select(). For the special FromStatement object used - by Query to indicate "Query.from_statement()", this is called by - FromStatement._compiler_dispatch() that would be called by - SQLCompiler.process(). - - """ - raise NotImplementedError() - - @classmethod def get_column_descriptions(cls, statement): return _column_descriptions(statement) @@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState): ) +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. + + Has a subset of the interface used by + :class:`.ORMAdapter` and is used for :class:`._QueryEntity` + instances to set up their columns as used in RETURNING for a + DML statement. + + """ + + __slots__ = ("mapper", "columns", "__weakref__") + + def __init__(self, target_mapper, immediate_dml_mapper): + if ( + immediate_dml_mapper is not None + and target_mapper.local_table + is not immediate_dml_mapper.local_table + ): + # joined inh, or in theory other kinds of multi-table mappings + self.mapper = immediate_dml_mapper + else: + # single inh, normal mappings, etc. 
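
The adapt_check_present() logic of DMLReturningColFilter above comes down to a corresponding_column() check against the table actually targeted by the DML statement. A small, self-contained Core illustration of that primitive (the table names are made up):

    from sqlalchemy import Column, Integer, MetaData, String, Table

    metadata = MetaData()
    person = Table(
        "person",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )
    engineer = Table(
        "engineer",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("engineer_info", String(50)),
    )

    # person.name "corresponds" to a column on person but not on engineer,
    # so a RETURNING filter targeting engineer would drop it (None result)
    print(person.c.corresponding_column(person.c.name))    # person.name
    print(engineer.c.corresponding_column(person.c.name))  # None
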
+ self.mapper = target_mapper + self.columns = self.columns = util.WeakPopulateDict( + self.adapt_check_present # type: ignore + ) + + def __call__(self, col, as_filter): + for cc in sql_util._find_columns(col): + c2 = self.adapt_check_present(cc) + if c2 is not None: + return col + else: + return None + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is None: + return None + return mapper.local_table.c.corresponding_column(col) + + @sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None @@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: FromStatement requested_statement: Union[SelectBase, TextClause, UpdateBase] - dml_table: _DMLTableElement + dml_table: Optional[_DMLTableElement] = None _has_orm_entities = False multi_row_eager_loaders = False @@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMFromStatementCompileState: if compiler is not None: toplevel = not compiler.stack @@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState): if statement.is_dml: self.dml_table = statement.table + self.is_dml_returning = True self._entities = [] self._polymorphic_adapters = {} @@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState): def _get_current_adapter(self): return None + def setup_dml_returning_compile_state(self, dml_mapper): + """used by BulkORMInsert (and Update / Delete?) to set up a handler + for RETURNING to return ORM objects and expressions + + """ + target_mapper = self.statement._propagate_attrs.get( + "plugin_subject", None + ) + adapter = DMLReturningColFilter(target_mapper, dml_mapper) + for entity in self._entities: + entity.setup_dml_returning_compile_state(self, adapter) + class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): """Core construct that represents a load of ORM objects from various @@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMSelectCompileState: """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) @@ -2312,6 +2386,13 @@ class _QueryEntity: def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + raise NotImplementedError() + def row_processor(self, context, result): raise NotImplementedError() @@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity): return _instance, self._label_name, self._extra_entities - def setup_compile_state(self, compile_state): + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + def setup_compile_state(self, compile_state): adapter = self._get_entity_clauses(compile_state) single_table_crit = 
self.mapper._single_table_criterion @@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity): only_load_props=compile_state.compile_options._only_load_props, polymorphic_discriminator=self._polymorphic_discriminator, ) - compile_state._fallback_from_clauses.append(self.selectable) @@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity): getter, label_name, extra_entities = self._row_processor if self.translate_raw_column: extra_entities += ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, label_name, extra_entities @@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity): if self.translate_raw_column: extra_entities = self._extra_entities + ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, self._label_name, extra_entities else: @@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + return else: column = self.column @@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity): self.entity_zero ) and entity.common_parent(self.entity_zero) + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + self._fetch_column = self.column + column = adapter(self.column, False) + if column is not None: + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + assert compile_state.is_dml_returning + self._fetch_column = self.column + return else: column = self.column diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 52b70b9d4..13d3b70fe 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -19,6 +19,7 @@ import operator import typing from typing import Any from typing import Callable +from typing import Dict from typing import List from typing import NoReturn from typing import Optional @@ -602,6 +603,31 @@ class Composite( def _attribute_keys(self) -> Sequence[str]: return [prop.key for prop in self.props] + def _populate_composite_bulk_save_mappings_fn( + self, + ) -> Callable[[Dict[str, Any]], None]: + + if self._generated_composite_accessor: + get_values = self._generated_composite_accessor + else: + + def get_values(val: Any) -> Tuple[Any]: + return val.__composite_values__() # type: ignore + + attrs = [prop.key for prop in self.props] + + def populate(dest_dict: Dict[str, Any]) -> None: + dest_dict.update( + { + key: val + for key, val in zip( + attrs, get_values(dest_dict.pop(self.key)) + ) + } + ) + + return populate + def get_history( self, state: InstanceState[Any], diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index b3129afdd..5af14cc00 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -9,8 +9,8 @@ from __future__ import annotations -import operator - +from .base import LoaderCallableStatus +from .base import PassiveFlag from .. import exc from .. import inspect from .. 
import util @@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators): return None +class _ExpiredObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return self + + def reverse_operate(self, *arg, **kw): + return self + + _NO_OBJECT = _NoObject() +_EXPIRED_OBJECT = _ExpiredObject() class EvaluatorCompiler: @@ -73,6 +82,24 @@ class EvaluatorCompiler: f"alternate class {parentmapper.class_}" ) key = parentmapper._columntoproperty[clause].key + impl = parentmapper.class_manager[key].impl + + if impl is not None: + + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + state = inspect(obj) + dict_ = state.dict + + value = impl.get( + state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + if value is LoaderCallableStatus.PASSIVE_NO_RESULT: + return _EXPIRED_OBJECT + return value + + return get_corresponding_attr else: key = clause.key if ( @@ -85,15 +112,16 @@ class EvaluatorCompiler: "make use of the actual mapped columns in ORM-evaluated " "UPDATE / DELETE expressions." ) + else: raise UnevaluatableError(f"Cannot evaluate column: {clause}") - get_corresponding_attr = operator.attrgetter(key) - return ( - lambda obj: get_corresponding_attr(obj) - if obj is not None - else _NO_OBJECT - ) + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + return getattr(obj, key, _EXPIRED_OBJECT) + + return get_corresponding_attr def visit_tuple(self, clause): return self.visit_clauselist(clause) @@ -134,7 +162,9 @@ class EvaluatorCompiler: has_null = False for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value: return True has_null = has_null or value is None if has_null: @@ -147,6 +177,9 @@ class EvaluatorCompiler: def evaluate(obj): for sub_evaluate in evaluators: value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + if not value: if value is None or value is _NO_OBJECT: return None @@ -160,7 +193,9 @@ class EvaluatorCompiler: values = [] for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value is None or value is _NO_OBJECT: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None or value is _NO_OBJECT: return None values.append(value) return tuple(values) @@ -183,13 +218,21 @@ class EvaluatorCompiler: def visit_is_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) == eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val == right_val return evaluate def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) != eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val != right_val return evaluate @@ -197,8 +240,11 @@ class EvaluatorCompiler: def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) - if left_val is None or right_val is None: + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif left_val is None or right_val is None: return None + return operator(eval_left(obj), eval_right(obj)) return evaluate @@ -274,7 +320,9 @@ class EvaluatorCompiler: def evaluate(obj): value = eval_inner(obj) - if value is None: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value 
is None: return None return not value diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 63b131a78..4848f73f1 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -68,6 +68,11 @@ class IdentityMap: ) -> Optional[_O]: raise NotImplementedError() + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + def keys(self) -> Iterable[_IdentityKeyType[Any]]: return self._dict.keys() @@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap): self._dict[key] = state state._instance_dict = self._wr + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + return self._dict.get(key) + def get( self, key: _IdentityKeyType[_O], default: Optional[_O] = None ) -> Optional[_O]: diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 7317d48be..64f2542fd 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -29,7 +29,6 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.orm.context import FromStatement from . import attributes from . import exc as orm_exc from . import path_registry @@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED from .base import PassiveFlag +from .context import FromStatement from .util import _none_set from .util import state_str from .. import exc as sa_exc @@ -50,6 +50,7 @@ from ..sql import util as sql_util from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectState +from ..util import EMPTY_DICT if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -764,7 +765,7 @@ def _instance_processor( ) quick_populators = path.get( - context.attributes, "memoized_setups", _none_set + context.attributes, "memoized_setups", EMPTY_DICT ) todo = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c8df51b06..c9cf8f49b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -854,6 +854,7 @@ class Mapper( _memoized_values: Dict[Any, Callable[[], Any]] _inheriting_mappers: util.WeakSequence[Mapper[Any]] _all_tables: Set[Table] + _polymorphic_attr_key: Optional[str] _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] @@ -1653,6 +1654,7 @@ class Mapper( """ setter = False + polymorphic_key: Optional[str] = None if self.polymorphic_on is not None: setter = True @@ -1772,17 +1774,23 @@ class Mapper( self._set_polymorphic_identity = ( mapper._set_polymorphic_identity ) + self._polymorphic_attr_key = ( + mapper._polymorphic_attr_key + ) self._validate_polymorphic_identity = ( mapper._validate_polymorphic_identity ) else: self._set_polymorphic_identity = None + self._polymorphic_attr_key = None return if setter: def _set_polymorphic_identity(state): dict_ = state.dict + # TODO: what happens if polymorphic_on column attribute name + # does not match .key? 
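
In user-facing terms, the _EXPIRED_OBJECT sentinel introduced in the evaluator above appears to make synchronize_session="evaluate" tolerant of expired attributes: rather than evaluating stale values in Python, affected objects are expired and refresh on next access. A rough sketch, reusing the hypothetical User mapping from the earlier example (the expiration behavior described in the comments is an inference from the logic above, not a statement of the documented contract):

    from sqlalchemy import update

    with Session(engine) as session:
        u1 = session.get(User, 1)

        session.execute(
            update(User).where(User.name == "alice").values(name="alicia"),
            execution_options={"synchronize_session": "evaluate"},
        )

        # if u1.name was loaded, the new value is applied in Python with no
        # extra SELECT; if the attribute had been expired, the object is
        # expired instead and the value is loaded on next access.
        print(u1.name)
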
state.get_impl(polymorphic_key).set( state, dict_, @@ -1790,6 +1798,8 @@ class Mapper( None, ) + self._polymorphic_attr_key = polymorphic_key + def _validate_polymorphic_identity(mapper, state, dict_): if ( polymorphic_key in dict_ @@ -1808,6 +1818,7 @@ class Mapper( _validate_polymorphic_identity ) else: + self._polymorphic_attr_key = None self._set_polymorphic_identity = None _validate_polymorphic_identity = None @@ -3562,6 +3573,10 @@ class Mapper( return util.LRUCache(self._compiled_cache_size) @HasMemoized.memoized_attribute + def _multiple_persistence_tables(self): + return len(self.tables) > 1 + + @HasMemoized.memoized_attribute def _sorted_tables(self): table_to_mapper: Dict[Table, Mapper[Any]] = {} diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index abd528986..dfb61c28a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -31,6 +31,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..engine import cursor as _cursor from ..sql import operators from ..sql.elements import BooleanClauseList from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -398,6 +399,11 @@ def _collect_insert_commands( None ) + if bulk and mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + yield ( state, state_dict, @@ -411,7 +417,11 @@ def _collect_insert_commands( def _collect_update_commands( - uowtransaction, table, states_to_update, bulk=False + uowtransaction, + table, + states_to_update, + bulk=False, + use_orm_update_stmt=None, ): """Identify sets of values to use in UPDATE statements for a list of states. @@ -437,7 +447,11 @@ def _collect_update_commands( pks = mapper._pks_by_table[table] - value_params = {} + if use_orm_update_stmt is not None: + # TODO: ordered values, etc + value_params = use_orm_update_stmt._values + else: + value_params = {} propkey_to_col = mapper._propkey_to_col[table] @@ -697,6 +711,7 @@ def _emit_update_statements( table, update, bookkeeping=True, + use_orm_update_stmt=None, ): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" @@ -708,7 +723,7 @@ def _emit_update_statements( execution_options = {"compiled_cache": base_mapper._compiled_cache} - def update_stmt(): + def update_stmt(existing_stmt=None): clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: @@ -725,10 +740,17 @@ def _emit_update_statements( ) ) - stmt = table.update().where(clauses) + if existing_stmt is not None: + stmt = existing_stmt.where(clauses) + else: + stmt = table.update().where(clauses) return stmt - cached_stmt = base_mapper._memo(("update", table), update_stmt) + if use_orm_update_stmt is not None: + cached_stmt = update_stmt(use_orm_update_stmt) + + else: + cached_stmt = base_mapper._memo(("update", table), update_stmt) for ( (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), @@ -747,6 +769,15 @@ def _emit_update_statements( records = list(records) statement = cached_stmt + + if use_orm_update_stmt is not None: + statement = statement._annotate( + { + "_emit_update_table": table, + "_emit_update_mapper": mapper, + } + ) + return_defaults = False if not has_all_pks: @@ -904,16 +935,35 @@ def _emit_insert_statements( table, insert, bookkeeping=True, + use_orm_insert_stmt=None, + execution_options=None, ): """Emit INSERT statements corresponding to value lists collected by 
_collect_insert_commands().""" - cached_stmt = base_mapper._memo(("insert", table), table.insert) + if use_orm_insert_stmt is not None: + cached_stmt = use_orm_insert_stmt + exec_opt = util.EMPTY_DICT - execution_options = {"compiled_cache": base_mapper._compiled_cache} + # if a user query with RETURNING was passed, we definitely need + # to use RETURNING. + returning_is_required_anyway = bool(use_orm_insert_stmt._returning) + else: + returning_is_required_anyway = False + cached_stmt = base_mapper._memo(("insert", table), table.insert) + exec_opt = {"compiled_cache": base_mapper._compiled_cache} + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + exec_opt, execution_options + ) + else: + execution_options = exec_opt + + return_result = None for ( - (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + (connection, _, hasvalue, has_all_pks, has_all_defaults), records, ) in groupby( insert, @@ -928,17 +978,29 @@ def _emit_insert_statements( statement = cached_stmt + if use_orm_insert_stmt is not None: + statement = statement._annotate( + { + "_emit_insert_table": table, + "_emit_insert_mapper": mapper, + } + ) + if ( - not bookkeeping - or ( - has_all_defaults - or not base_mapper.eager_defaults - or not base_mapper.local_table.implicit_returning - or not connection.dialect.insert_returning + ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not base_mapper.local_table.implicit_returning + or not connection.dialect.insert_returning + ) ) + and not returning_is_required_anyway and has_all_pks and not hasvalue ): + # the "we don't need newly generated values back" section. # here we have all the PKs, all the defaults or we don't want # to fetch them, or the dialect doesn't support RETURNING at all @@ -946,7 +1008,7 @@ def _emit_insert_statements( records = list(records) multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) if bookkeeping: @@ -962,7 +1024,7 @@ def _emit_insert_statements( has_all_defaults, ), last_inserted_params, - ) in zip(records, c.context.compiled_parameters): + ) in zip(records, result.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -970,19 +1032,20 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, - c.returned_defaults - if not c.context.executemany + result.returned_defaults + if not result.context.executemany else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: - # here, we need defaults and/or pk values back. 
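
The returning_is_required_anyway branch corresponds, at the ORM level, to bulk INSERT statements that carry an explicit RETURNING, for example asking for the inserted objects back. A sketch reusing the hypothetical User mapping, assuming a backend whose driver supports executemany with RETURNING (the "insertmanyvalues" feature):

    from sqlalchemy import insert

    with Session(engine) as session:
        new_users = session.scalars(
            insert(User).returning(User),
            [{"name": "carol"}, {"name": "dave"}],
        ).all()

        # each inserted row came back through RETURNING and was turned into
        # a User instance; on a backend without executemany RETURNING
        # support this path raises InvalidRequestError, per the check above.
        print([u.name for u in new_users])
        session.commit()
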
+ # here, we need defaults and/or pk values back or we otherwise + # know that we are using RETURNING in any case records = list(records) if ( @@ -991,6 +1054,16 @@ def _emit_insert_statements( and len(records) > 1 ): do_executemany = True + elif returning_is_required_anyway: + if connection.dialect.insert_executemany_returning: + do_executemany = True + else: + raise sa_exc.InvalidRequestError( + f"Can't use explicit RETURNING for bulk INSERT " + f"operation with " + f"{connection.dialect.dialect_description} backend; " + f"executemany is not supported with RETURNING" + ) else: do_executemany = False @@ -998,6 +1071,7 @@ def _emit_insert_statements( statement = statement.return_defaults( *mapper._server_default_cols[table] ) + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) elif do_executemany: @@ -1006,10 +1080,16 @@ def _emit_insert_statements( if do_executemany: multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) + if use_orm_insert_stmt is not None: + if return_result is None: + return_result = result + else: + return_result = return_result.splice_vertically(result) + if bookkeeping: for ( ( @@ -1027,9 +1107,9 @@ def _emit_insert_statements( returned_defaults, ) in zip_longest( records, - c.context.compiled_parameters, - c.inserted_primary_key_rows, - c.returned_defaults_rows or (), + result.context.compiled_parameters, + result.inserted_primary_key_rows, + result.returned_defaults_rows or (), ): if inserted_primary_key is None: # this is a real problem and means that we didn't @@ -1062,7 +1142,7 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, @@ -1071,6 +1151,8 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + assert not returning_is_required_anyway + for ( state, state_dict, @@ -1132,6 +1214,12 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) + if use_orm_insert_stmt is not None: + if return_result is None: + return _cursor.null_dml_result() + else: + return return_result + def _emit_post_update_statements( base_mapper, uowtransaction, mapper, table, update diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6d0f055e4..4d5a98fcf 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2978,7 +2978,7 @@ class Query( ) def delete( - self, synchronize_session: _SynchronizeSessionArgument = "evaluate" + self, synchronize_session: _SynchronizeSessionArgument = "auto" ) -> int: r"""Perform a DELETE with an arbitrary WHERE clause. @@ -3042,7 +3042,7 @@ class Query( def update( self, values: Dict[_DMLColumnArgument, Any], - synchronize_session: _SynchronizeSessionArgument = "evaluate", + synchronize_session: _SynchronizeSessionArgument = "auto", update_args: Optional[Dict[Any, Any]] = None, ) -> int: r"""Perform an UPDATE with an arbitrary WHERE clause. 
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a690da0d5..64c013306 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget): statement._propagate_attrs.get("compile_state_plugin", None) == "orm" ): - # note that even without "future" mode, we need compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) if TYPE_CHECKING: - assert isinstance(compile_state_cls, ORMCompileState) + assert isinstance( + compile_state_cls, context.AbstractORMCompileState + ) else: compile_state_cls = None @@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget): statement, params or {}, execution_options=execution_options ) - result: Result[Any] = conn.execute( - statement, params or {}, execution_options=execution_options - ) - if compile_state_cls: - result = compile_state_cls.orm_setup_cursor_result( + result: Result[Any] = compile_state_cls.orm_execute_statement( self, statement, - params, + params or {}, execution_options, bind_arguments, - result, + conn, + ) + else: + result = conn.execute( + statement, params or {}, execution_options=execution_options ) if _scalar_result: @@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 19c6493db..8652591c8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy): fetch = self.columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None happens here only for dml bulk_persistence cases + # when context.DMLReturningColFilter is used + return + memoized_populators[self.parent_property] = fetch def init_class_attribute(self, mapper): @@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader): fetch = columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None is not expected to be the result of any + # adapter implementation here, however there may be theoretical + # usages of returning() with context.DMLReturningColFilter + return + memoized_populators[self.parent_property] = fetch def create_row_processor( diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 86b2952cb..262048bd1 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -552,6 +552,8 @@ def _new_annotation_type( # e.g. BindParameter, add it if present. 
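
On the legacy Query side shown above, synchronize_session now defaults to "auto", which (as I read the change) attempts the "evaluate" strategy and falls back to "fetch" when the criteria cannot be evaluated in Python. A brief 1.x-style sketch, reusing the hypothetical User mapping:

    with Session(engine) as session:
        # no explicit synchronize_session argument needed in the common case
        session.query(User).filter(User.name == "carol").update(
            {User.name: "caroline"}
        )
        session.commit()
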
if cls.__dict__.get("inherit_cache", False): anno_cls.inherit_cache = True # type: ignore + elif "inherit_cache" in cls.__dict__: + anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 201324a2a..c7e226fcc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -5166,6 +5166,8 @@ class SQLCompiler(Compiled): delete_stmt, delete_stmt.table, extra_froms ) + crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) + if delete_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( delete_stmt, table_text @@ -5178,13 +5180,14 @@ class SQLCompiler(Compiled): text += table_text - if delete_stmt._returning: - if self.returning_precedes_values: - text += " " + self.returning_clause( - delete_stmt, - delete_stmt._returning, - populate_result_map=toplevel, - ) + if ( + self.implicit_returning or delete_stmt._returning + ) and self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, + self.implicit_returning or delete_stmt._returning, + populate_result_map=toplevel, + ) if extra_froms: extra_from_text = self.delete_extra_from_clause( @@ -5204,10 +5207,12 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - if delete_stmt._returning and not self.returning_precedes_values: + if ( + self.implicit_returning or delete_stmt._returning + ) and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, - delete_stmt._returning, + self.implicit_returning or delete_stmt._returning, populate_result_map=toplevel, ) @@ -5297,7 +5302,6 @@ class StrSQLCompiler(SQLCompiler): self._label_select_column(None, c, True, False, {}) for c in base._select_iterables(returning_cols) ] - return "RETURNING " + ", ".join(columns) def update_from_clause( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index b13377a59..22fffb73a 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -150,6 +150,22 @@ def _get_crud_params( "return_defaults() simultaneously" ) + if compile_state.isdelete: + _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + (), + _getattr_col_key, + _column_as_key, + _col_bind_name, + (), + (), + toplevel, + kw, + ) + return _CrudParams([], []) + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: @@ -466,13 +482,6 @@ def _scan_insert_from_select_cols( kw, ): - ( - need_pks, - implicit_returning, - implicit_return_defaults, - postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) - cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] assert compiler.stack[-1]["selectable"] is stmt @@ -537,6 +546,8 @@ def _scan_cols( postfetch_lastrowid, ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) + assert compile_state.isupdate or compile_state.isinsert + if compile_state._parameter_ordering: parameter_ordering = [ _column_as_key(key) for key in compile_state._parameter_ordering @@ -563,6 +574,13 @@ def _scan_cols( else: autoincrement_col = insert_null_pk_still_autoincrements = None + if stmt._supplemental_returning: + supplemental_returning = set(stmt._supplemental_returning) + else: + supplemental_returning = set() + + compiler_implicit_returning = compiler.implicit_returning + for c 
in cols: # scan through every column in the target table @@ -627,11 +645,13 @@ def _scan_cols( # column has a DDL-level default, and is either not a pk # column or we don't need the pk. if implicit_return_defaults and c in implicit_return_defaults: - compiler.implicit_returning.append(c) + compiler_implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) + elif implicit_return_defaults and c in implicit_return_defaults: - compiler.implicit_returning.append(c) + compiler_implicit_returning.append(c) + elif ( c.primary_key and c is not stmt.table._autoincrement_column @@ -652,6 +672,59 @@ def _scan_cols( kw, ) + # adding supplemental cols to implicit_returning in table + # order so that order is maintained between multiple INSERT + # statements which may have different parameters included, but all + # have the same RETURNING clause + if ( + c in supplemental_returning + and c not in compiler_implicit_returning + ): + compiler_implicit_returning.append(c) + + if supplemental_returning: + # we should have gotten every col into implicit_returning, + # however supplemental returning can also have SQL functions etc. + # in it + remaining_supplemental = supplemental_returning.difference( + compiler_implicit_returning + ) + compiler_implicit_returning.extend( + c + for c in stmt._supplemental_returning + if c in remaining_supplemental + ) + + +def _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + (_, _, implicit_return_defaults, _) = _get_returning_modifiers( + compiler, stmt, compile_state, toplevel + ) + + if not implicit_return_defaults: + return + + if stmt._return_defaults_columns: + compiler.implicit_returning.extend(implicit_return_defaults) + + if stmt._supplemental_returning: + ir_set = set(compiler.implicit_returning) + compiler.implicit_returning.extend( + c for c in stmt._supplemental_returning if c not in ir_set + ) + def _append_param_parameter( compiler, @@ -743,7 +816,7 @@ def _append_param_parameter( elif compiler.dialect.postfetch_lastrowid: compiler.postfetch_lastrowid = True - elif implicit_return_defaults and c in implicit_return_defaults: + elif implicit_return_defaults and (c in implicit_return_defaults): compiler.implicit_returning.append(c) else: @@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): INSERT or UPDATE statement after it's invoked. 
""" + need_pks = ( toplevel and _compile_state_isinsert(compile_state) @@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): ) ) and not stmt._returning + # and (not stmt._returning or stmt._return_defaults) and not compile_state._has_multi_parameters ) @@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): or stmt._return_defaults ) ) - if implicit_returning: postfetch_lastrowid = False if _compile_state_isinsert(compile_state): - implicit_return_defaults = implicit_returning and stmt._return_defaults + should_implicit_return_defaults = ( + implicit_returning and stmt._return_defaults + ) elif compile_state.isupdate: - implicit_return_defaults = ( + should_implicit_return_defaults = ( stmt._return_defaults and compile_state._primary_table.implicit_returning and compile_state._supports_implicit_returning and compiler.dialect.update_returning ) + elif compile_state.isdelete: + should_implicit_return_defaults = ( + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and compiler.dialect.delete_returning + ) else: - # this line is unused, currently we are always - # isinsert or isupdate - implicit_return_defaults = False # pragma: no cover + should_implicit_return_defaults = False # pragma: no cover - if implicit_return_defaults: + if should_implicit_return_defaults: if not stmt._return_defaults_columns: implicit_return_defaults = set(stmt.table.c) else: implicit_return_defaults = set(stmt._return_defaults_columns) + else: + implicit_return_defaults = None return ( need_pks, - implicit_returning, + implicit_returning or should_implicit_return_defaults, implicit_return_defaults, postfetch_lastrowid, ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index a08e38800..5145a4a16 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -165,15 +165,32 @@ class DMLState(CompileState): ... @classmethod + def _get_multi_crud_kv_pairs( + cls, + statement: UpdateBase, + multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]], + ) -> List[Dict[_DMLColumnElement, Any]]: + return [ + { + coercions.expect(roles.DMLColumnRole, k): v + for k, v in mapping.items() + } + for mapping in multi_kv_iterator + ] + + @classmethod def _get_crud_kv_pairs( cls, statement: UpdateBase, kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]], + needs_to_be_cacheable: bool, ) -> List[Tuple[_DMLColumnElement, Any]]: return [ ( coercions.expect(roles.DMLColumnRole, k), - coercions.expect( + v + if not needs_to_be_cacheable + else coercions.expect( roles.ExpressionElementRole, v, type_=NullType(), @@ -269,7 +286,7 @@ class InsertDMLState(DMLState): def _insert_col_keys(self) -> List[str]: # this is also done in crud.py -> _key_getters_for_crud_column return [ - coercions.expect_as_key(roles.DMLColumnRole, col) + coercions.expect(roles.DMLColumnRole, col, as_key=True) for col in self._dict_parameters or () ] @@ -326,7 +343,6 @@ class UpdateDMLState(DMLState): self._extra_froms = ef self.is_multitable = mt = ef - self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -389,6 +405,7 @@ class UpdateBase( _return_defaults_columns: Optional[ Tuple[_ColumnsClauseElement, ...] ] = None + _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None _returning: Tuple[_ColumnsClauseElement, ...] 
= () is_dml = True @@ -435,6 +452,215 @@ class UpdateBase( return self @_generative + def return_defaults( + self: SelfUpdateBase, + *cols: _DMLColumnArgument, + supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None, + ) -> SelfUpdateBase: + """Make use of a :term:`RETURNING` clause for the purpose + of fetching server-side expressions and defaults, for supporting + backends only. + + .. deepalchemy:: + + The :meth:`.UpdateBase.return_defaults` method is used by the ORM + for its internal work in fetching newly generated primary key + and server default values, in particular to provide the underyling + implementation of the :paramref:`_orm.Mapper.eager_defaults` + ORM feature as well as to allow RETURNING support with bulk + ORM inserts. Its behavior is fairly idiosyncratic + and is not really intended for general use. End users should + stick with using :meth:`.UpdateBase.returning` in order to + add RETURNING clauses to their INSERT, UPDATE and DELETE + statements. + + Normally, a single row INSERT statement will automatically populate the + :attr:`.CursorResult.inserted_primary_key` attribute when executed, + which stores the primary key of the row that was just inserted in the + form of a :class:`.Row` object with column names as named tuple keys + (and the :attr:`.Row._mapping` view fully populated as well). The + dialect in use chooses the strategy to use in order to populate this + data; if it was generated using server-side defaults and / or SQL + expressions, dialect-specific approaches such as ``cursor.lastrowid`` + or ``RETURNING`` are typically used to acquire the new primary key + value. + + However, when the statement is modified by calling + :meth:`.UpdateBase.return_defaults` before executing the statement, + additional behaviors take place **only** for backends that support + RETURNING and for :class:`.Table` objects that maintain the + :paramref:`.Table.implicit_returning` parameter at its default value of + ``True``. In these cases, when the :class:`.CursorResult` is returned + from the statement's execution, not only will + :attr:`.CursorResult.inserted_primary_key` be populated as always, the + :attr:`.CursorResult.returned_defaults` attribute will also be + populated with a :class:`.Row` named-tuple representing the full range + of server generated + values from that single row, including values for any columns that + specify :paramref:`_schema.Column.server_default` or which make use of + :paramref:`_schema.Column.default` using a SQL expression. + + When invoking INSERT statements with multiple rows using + :ref:`insertmanyvalues <engine_insertmanyvalues>`, the + :meth:`.UpdateBase.return_defaults` modifier will have the effect of + the :attr:`_engine.CursorResult.inserted_primary_key_rows` and + :attr:`_engine.CursorResult.returned_defaults_rows` attributes being + fully populated with lists of :class:`.Row` objects representing newly + inserted primary key values as well as newly inserted server generated + values for each row inserted. The + :attr:`.CursorResult.inserted_primary_key` and + :attr:`.CursorResult.returned_defaults` attributes will also continue + to be populated with the first row of these two collections. + + If the backend does not support RETURNING or the :class:`.Table` in use + has disabled :paramref:`.Table.implicit_returning`, then no RETURNING + clause is added and no additional data is fetched, however the + INSERT, UPDATE or DELETE statement proceeds normally. 
+ + E.g.:: + + stmt = table.insert().values(data='newdata').return_defaults() + + result = connection.execute(stmt) + + server_created_at = result.returned_defaults['created_at'] + + When used against an UPDATE statement + :meth:`.UpdateBase.return_defaults` instead looks for columns that + include :paramref:`_schema.Column.onupdate` or + :paramref:`_schema.Column.server_onupdate` parameters assigned, when + constructing the columns that will be included in the RETURNING clause + by default if explicit columns were not specified. When used against a + DELETE statement, no columns are included in RETURNING by default, they + instead must be specified explicitly as there are no columns that + normally change values when a DELETE statement proceeds. + + .. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported + for DELETE statements also and has been moved from + :class:`.ValuesBase` to :class:`.UpdateBase`. + + The :meth:`.UpdateBase.return_defaults` method is mutually exclusive + against the :meth:`.UpdateBase.returning` method and errors will be + raised during the SQL compilation process if both are used at the same + time on one statement. The RETURNING clause of the INSERT, UPDATE or + DELETE statement is therefore controlled by only one of these methods + at a time. + + The :meth:`.UpdateBase.return_defaults` method differs from + :meth:`.UpdateBase.returning` in these ways: + + 1. :meth:`.UpdateBase.return_defaults` method causes the + :attr:`.CursorResult.returned_defaults` collection to be populated + with the first row from the RETURNING result. This attribute is not + populated when using :meth:`.UpdateBase.returning`. + + 2. :meth:`.UpdateBase.return_defaults` is compatible with existing + logic used to fetch auto-generated primary key values that are then + populated into the :attr:`.CursorResult.inserted_primary_key` + attribute. By contrast, using :meth:`.UpdateBase.returning` will + have the effect of the :attr:`.CursorResult.inserted_primary_key` + attribute being left unpopulated. + + 3. :meth:`.UpdateBase.return_defaults` can be called against any + backend. Backends that don't support RETURNING will skip the usage + of the feature, rather than raising an exception. The return value + of :attr:`_engine.CursorResult.returned_defaults` will be ``None`` + for backends that don't support RETURNING or for which the target + :class:`.Table` sets :paramref:`.Table.implicit_returning` to + ``False``. + + 4. An INSERT statement invoked with executemany() is supported if the + backend database driver supports the + :ref:`insertmanyvalues <engine_insertmanyvalues>` + feature which is now supported by most SQLAlchemy-included backends. + When executemany is used, the + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors + will return the inserted defaults and primary keys. + + .. versionadded:: 1.4 Added + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors. + In version 2.0, the underlying implementation which fetches and + populates the data for these attributes was generalized to be + supported by most backends, whereas in 1.4 they were only + supported by the ``psycopg2`` driver. + + + :param cols: optional list of column key names or + :class:`_schema.Column` that acts as a filter for those columns that + will be fetched. 
+ :param supplemental_cols: optional list of RETURNING expressions,
+ in the same form as one would pass to the
+ :meth:`.UpdateBase.returning` method. When present, the additional
+ columns will be included in the RETURNING clause, and the
+ :class:`.CursorResult` object will be "rewound" when returned, so
+ that methods like :meth:`.CursorResult.all` will return new rows
+ mostly as though the statement used :meth:`.UpdateBase.returning`
+ directly. However, unlike when using :meth:`.UpdateBase.returning`
+ directly, the **order of the columns is undefined**, so they can only
+ be targeted using names or :attr:`.Row._mapping` keys; they cannot
+ reliably be targeted positionally.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.UpdateBase.returning`
+
+ :attr:`_engine.CursorResult.returned_defaults`
+
+ :attr:`_engine.CursorResult.returned_defaults_rows`
+
+ :attr:`_engine.CursorResult.inserted_primary_key`
+
+ :attr:`_engine.CursorResult.inserted_primary_key_rows`
+
+ """
+
+ if self._return_defaults:
+ # note _return_defaults_columns = () means return all columns,
+ # so if we have been here before, only update collection if there
+ # are columns in the collection
+ if self._return_defaults_columns and cols:
+ self._return_defaults_columns = tuple(
+ util.OrderedSet(self._return_defaults_columns).union(
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in cols
+ )
+ )
+ else:
+ # set for all columns
+ self._return_defaults_columns = ()
+ else:
+ self._return_defaults_columns = tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
+ self._return_defaults = True
+
+ if supplemental_cols:
+ # uniquifying while also maintaining order (maintaining order
+ # matters for test suites as well as for vertical splicing)
+ supplemental_col_tup = (
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in supplemental_cols
+ )
+
+ if self._supplemental_returning is None:
+ self._supplemental_returning = tuple(
+ util.unique_list(supplemental_col_tup)
+ )
+ else:
+ self._supplemental_returning = tuple(
+ util.unique_list(
+ self._supplemental_returning
+ + tuple(supplemental_col_tup)
+ )
+ )
+
+ return self
+
+ @_generative
def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any ) -> UpdateBase: @@ -500,7 +726,7 @@ class UpdateBase( .. seealso::
- :meth:`.ValuesBase.return_defaults` - an alternative method tailored
+ :meth:`.UpdateBase.return_defaults` - an alternative method tailored
towards efficient fetching of server-side defaults and triggers for single-row INSERTs or UPDATEs. @@ -703,7 +929,6 @@ class ValuesBase(UpdateBase): _select_names: Optional[List[str]] = None _inline: bool = False
- _returning: Tuple[_ColumnsClauseElement, ...] = ()
def __init__(self, table: _DMLTableArgument): self.table = coercions.expect( @@ -859,7 +1084,15 @@ class ValuesBase(UpdateBase): ) elif isinstance(arg, collections_abc.Sequence):
- if arg and isinstance(arg[0], (list, dict, tuple)):
+
+ if arg and isinstance(arg[0], dict):
+ multi_kv_generator = DMLState.get_plugin_class(
+ self
+ )._get_multi_crud_kv_pairs
+ self._multi_values += (multi_kv_generator(self, arg),)
+ return self
+
+ if arg and isinstance(arg[0], (list, tuple)):
self._multi_values += (arg,)
return self
@@ -888,173 +1121,13 @@ class ValuesBase(UpdateBase): # and ensures they get the "crud"-style name when rendered.
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs - coerced_arg = {k: v for k, v in kv_generator(self, arg.items())} + coerced_arg = dict(kv_generator(self, arg.items(), True)) if self._values: self._values = self._values.union(coerced_arg) else: self._values = util.immutabledict(coerced_arg) return self - @_generative - def return_defaults( - self: SelfValuesBase, *cols: _DMLColumnArgument - ) -> SelfValuesBase: - """Make use of a :term:`RETURNING` clause for the purpose - of fetching server-side expressions and defaults, for supporting - backends only. - - .. tip:: - - The :meth:`.ValuesBase.return_defaults` method is used by the ORM - for its internal work in fetching newly generated primary key - and server default values, in particular to provide the underyling - implementation of the :paramref:`_orm.Mapper.eager_defaults` - ORM feature. Its behavior is fairly idiosyncratic - and is not really intended for general use. End users should - stick with using :meth:`.UpdateBase.returning` in order to - add RETURNING clauses to their INSERT, UPDATE and DELETE - statements. - - Normally, a single row INSERT statement will automatically populate the - :attr:`.CursorResult.inserted_primary_key` attribute when executed, - which stores the primary key of the row that was just inserted in the - form of a :class:`.Row` object with column names as named tuple keys - (and the :attr:`.Row._mapping` view fully populated as well). The - dialect in use chooses the strategy to use in order to populate this - data; if it was generated using server-side defaults and / or SQL - expressions, dialect-specific approaches such as ``cursor.lastrowid`` - or ``RETURNING`` are typically used to acquire the new primary key - value. - - However, when the statement is modified by calling - :meth:`.ValuesBase.return_defaults` before executing the statement, - additional behaviors take place **only** for backends that support - RETURNING and for :class:`.Table` objects that maintain the - :paramref:`.Table.implicit_returning` parameter at its default value of - ``True``. In these cases, when the :class:`.CursorResult` is returned - from the statement's execution, not only will - :attr:`.CursorResult.inserted_primary_key` be populated as always, the - :attr:`.CursorResult.returned_defaults` attribute will also be - populated with a :class:`.Row` named-tuple representing the full range - of server generated - values from that single row, including values for any columns that - specify :paramref:`_schema.Column.server_default` or which make use of - :paramref:`_schema.Column.default` using a SQL expression. - - When invoking INSERT statements with multiple rows using - :ref:`insertmanyvalues <engine_insertmanyvalues>`, the - :meth:`.ValuesBase.return_defaults` modifier will have the effect of - the :attr:`_engine.CursorResult.inserted_primary_key_rows` and - :attr:`_engine.CursorResult.returned_defaults_rows` attributes being - fully populated with lists of :class:`.Row` objects representing newly - inserted primary key values as well as newly inserted server generated - values for each row inserted. The - :attr:`.CursorResult.inserted_primary_key` and - :attr:`.CursorResult.returned_defaults` attributes will also continue - to be populated with the first row of these two collections. 
- - If the backend does not support RETURNING or the :class:`.Table` in use - has disabled :paramref:`.Table.implicit_returning`, then no RETURNING - clause is added and no additional data is fetched, however the - INSERT or UPDATE statement proceeds normally. - - - E.g.:: - - stmt = table.insert().values(data='newdata').return_defaults() - - result = connection.execute(stmt) - - server_created_at = result.returned_defaults['created_at'] - - - The :meth:`.ValuesBase.return_defaults` method is mutually exclusive - against the :meth:`.UpdateBase.returning` method and errors will be - raised during the SQL compilation process if both are used at the same - time on one statement. The RETURNING clause of the INSERT or UPDATE - statement is therefore controlled by only one of these methods at a - time. - - The :meth:`.ValuesBase.return_defaults` method differs from - :meth:`.UpdateBase.returning` in these ways: - - 1. :meth:`.ValuesBase.return_defaults` method causes the - :attr:`.CursorResult.returned_defaults` collection to be populated - with the first row from the RETURNING result. This attribute is not - populated when using :meth:`.UpdateBase.returning`. - - 2. :meth:`.ValuesBase.return_defaults` is compatible with existing - logic used to fetch auto-generated primary key values that are then - populated into the :attr:`.CursorResult.inserted_primary_key` - attribute. By contrast, using :meth:`.UpdateBase.returning` will - have the effect of the :attr:`.CursorResult.inserted_primary_key` - attribute being left unpopulated. - - 3. :meth:`.ValuesBase.return_defaults` can be called against any - backend. Backends that don't support RETURNING will skip the usage - of the feature, rather than raising an exception. The return value - of :attr:`_engine.CursorResult.returned_defaults` will be ``None`` - for backends that don't support RETURNING or for which the target - :class:`.Table` sets :paramref:`.Table.implicit_returning` to - ``False``. - - 4. An INSERT statement invoked with executemany() is supported if the - backend database driver supports the - :ref:`insertmanyvalues <engine_insertmanyvalues>` - feature which is now supported by most SQLAlchemy-included backends. - When executemany is used, the - :attr:`_engine.CursorResult.returned_defaults_rows` and - :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors - will return the inserted defaults and primary keys. - - .. versionadded:: 1.4 Added - :attr:`_engine.CursorResult.returned_defaults_rows` and - :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors. - In version 2.0, the underlying implementation which fetches and - populates the data for these attributes was generalized to be - supported by most backends, whereas in 1.4 they were only - supported by the ``psycopg2`` driver. - - - :param cols: optional list of column key names or - :class:`_schema.Column` that acts as a filter for those columns that - will be fetched. - - .. 
seealso:: - - :meth:`.UpdateBase.returning` - - :attr:`_engine.CursorResult.returned_defaults` - - :attr:`_engine.CursorResult.returned_defaults_rows` - - :attr:`_engine.CursorResult.inserted_primary_key` - - :attr:`_engine.CursorResult.inserted_primary_key_rows` - - """ - - if self._return_defaults: - # note _return_defaults_columns = () means return all columns, - # so if we have been here before, only update collection if there - # are columns in the collection - if self._return_defaults_columns and cols: - self._return_defaults_columns = tuple( - set(self._return_defaults_columns).union( - coercions.expect(roles.ColumnsClauseRole, c) - for c in cols - ) - ) - else: - # set for all columns - self._return_defaults_columns = () - else: - self._return_defaults_columns = tuple( - coercions.expect(roles.ColumnsClauseRole, c) for c in cols - ) - self._return_defaults = True - return self - SelfInsert = typing.TypeVar("SelfInsert", bound="Insert") @@ -1459,7 +1532,7 @@ class Update(DMLWhereBase, ValuesBase): ) kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs - self._ordered_values = kv_generator(self, args) + self._ordered_values = kv_generator(self, args, True) return self @_generative diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 4416fe630..b3a71dbff 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -68,7 +68,7 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): def __init__( - self, statement, params=None, dialect="default", enable_returning=False + self, statement, params=None, dialect="default", enable_returning=True ): self.statement = statement self.params = params @@ -90,6 +90,17 @@ class CompiledSQL(SQLMatchRule): dialect.insert_returning = ( dialect.update_returning ) = dialect.delete_returning = True + dialect.use_insertmanyvalues = True + dialect.supports_multivalues_insert = True + dialect.update_returning_multifrom = True + dialect.delete_returning_multifrom = True + # dialect.favor_returning_over_lastrowid = True + # dialect.insert_null_pk_still_autoincrements = True + + # this is calculated but we need it to be True for this + # to look like all the current RETURNING dialects + assert dialect.insert_executemany_returning + return dialect else: return url.URL.create(self.dialect).get_dialect()() diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 20dee5273..ef284babc 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -23,7 +23,6 @@ from .util import adict from .util import drop_all_tables_from_metadata from .. import event from .. 
import util -from ..orm import declarative_base from ..orm import DeclarativeBase from ..orm import MappedAsDataclass from ..orm import registry @@ -117,7 +116,7 @@ class TestBase: metadata=metadata, type_annotation_map={ str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb" + sa.String(50), "mysql", "mariadb", "oracle" ) }, ) @@ -132,7 +131,7 @@ class TestBase: metadata = _md type_annotation_map = { str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb" + sa.String(50), "mysql", "mariadb", "oracle" ) } @@ -780,18 +779,19 @@ class DeclarativeMappedTest(MappedTest): def _with_register_classes(cls, fn): cls_registry = cls.classes - class DeclarativeBasic: + class _DeclBase(DeclarativeBase): __table_cls__ = schema.Table + metadata = cls._tables_metadata + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } - def __init_subclass__(cls) -> None: + def __init_subclass__(cls, **kw) -> None: assert cls_registry is not None cls_registry[cls.__name__] = cls - super().__init_subclass__() - - _DeclBase = declarative_base( - metadata=cls._tables_metadata, - cls=DeclarativeBasic, - ) + super().__init_subclass__(**kw) cls.DeclarativeBasic = _DeclBase diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py index b7d4b7452..8e19a24a8 100644 --- a/lib/sqlalchemy/testing/suite/test_rowcount.py +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -89,8 +89,13 @@ class RowCountTest(fixtures.TablesTest): eq_(r.rowcount, 3) @testing.requires.update_returning - @testing.requires.sane_rowcount_w_returning def test_update_rowcount_return_defaults(self, connection): + """note this test should succeed for all RETURNING backends + as of 2.0. In + Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use + len(rows) when we have implicit returning + + """ employees_table = self.tables.employees department = employees_table.c.department diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index f8348714c..488229abb 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -11,6 +11,7 @@ from __future__ import annotations from itertools import filterfalse from typing import AbstractSet from typing import Any +from typing import Callable from typing import cast from typing import Collection from typing import Dict @@ -481,7 +482,9 @@ class IdentitySet: return "%s(%r)" % (type(self).__name__, list(self._members.values())) -def unique_list(seq, hashfunc=None): +def unique_list( + seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None +) -> List[_T]: seen: Set[Any] = set() seen_add = seen.add if not hashfunc: |
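Stepping back from the individual hunks, the ``supplemental_cols`` parameter introduced in the ``dml.py`` changes above is easiest to see in a small usage sketch. This is not part of the commit; the table and column names are assumptions, and a RETURNING-capable backend (e.g. SQLite 3.35+) is needed for the extra rows to come back.

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine, insert
    )

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    # fetch server-generated defaults / primary keys as usual, and also
    # splice "name" into the RETURNING clause
    stmt = insert(users).return_defaults(supplemental_cols=[users.c.name])

    with engine.begin() as conn:
        result = conn.execute(stmt, [{"name": "spongebob"}, {"name": "sandy"}])

        # newly generated primary keys for every inserted row
        print(result.inserted_primary_key_rows)

        # the result is "rewound", so rows come back much as with .returning();
        # per the docstring, column order is undefined, so target by name
        for row in result.all():
            print(row._mapping["name"])

This is essentially the mechanism the ORM relies on for bulk INSERTs with RETURNING, which is also why the testing changes in this diff (``CompiledSQL(enable_returning=True)`` and the RETURNING-oriented dialect flags) default to RETURNING-capable behavior.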
