diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-08-02 16:18:18 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-08-05 10:07:15 -0400 |
| commit | 82a1d4096fbfe94e2fa626d65d5c3beb2c6afa37 (patch) | |
| tree | baca62a1a0784f192e65402f824319b0403c6847 /lib/sqlalchemy | |
| parent | 0027b3a4bc54599ac8102a4a3d81d8007738903e (diff) | |
| download | sqlalchemy-82a1d4096fbfe94e2fa626d65d5c3beb2c6afa37.tar.gz | |
include column.default, column.onupdate in eager_defaults
Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults`
parameter such that client-side SQL default or onupdate expressions in the
table definition alone will trigger a fetch operation using RETURNING or
SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only
server side defaults established as part of table DDL and/or server-side
onupdate expressions would trigger this fetch, even though client-side SQL
expressions would be included when the fetch was rendered.
Fixes: #7438
Change-Id: Iba719298ba4a26d185edec97ba77d2d54585e5a4
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 22 |
3 files changed, 87 insertions, 33 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 769b1b623..6a95030b5 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -28,6 +28,7 @@ from typing import cast from typing import Collection from typing import Deque from typing import Dict +from typing import FrozenSet from typing import Generic from typing import Iterable from typing import Iterator @@ -2397,15 +2398,21 @@ class Mapper( ) @HasMemoized.memoized_attribute - def _server_default_cols(self): + def _server_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: return dict( ( table, frozenset( [ - col.key - for col in columns + col + for col in cast("Iterable[Column[Any]]", columns) if col.server_default is not None + or ( + col.default is not None + and col.default.is_clause_element + ) ] ), ) @@ -2413,35 +2420,60 @@ class Mapper( ) @HasMemoized.memoized_attribute - def _server_default_plus_onupdate_propkeys(self): - result = set() - - for table, columns in self._cols_by_table.items(): - for col in columns: - if ( - col.server_default is not None - or col.server_onupdate is not None - ) and col in self._columntoproperty: - result.add(self._columntoproperty[col].key) - - return result - - @HasMemoized.memoized_attribute - def _server_onupdate_default_cols(self): + def _server_onupdate_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: return dict( ( table, frozenset( [ - col.key - for col in columns + col + for col in cast("Iterable[Column[Any]]", columns) if col.server_onupdate is not None + or ( + col.onupdate is not None + and col.onupdate.is_clause_element + ) ] ), ) for table, columns in self._cols_by_table.items() ) + @HasMemoized.memoized_attribute + def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_onupdate_default_col_keys( + self, + ) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_onupdate_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_default_plus_onupdate_propkeys(self) -> Set[str]: + result: Set[str] = set() + + col_to_property = self._columntoproperty + for table, columns in self._server_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + for table, columns in self._server_onupdate_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + return result + @HasMemoized.memoized_instancemethod def __clause_element__(self): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index c10f4701e..7cd66513b 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -561,9 +561,9 @@ def _collect_insert_commands( has_all_pks = mapper._pk_keys_by_table[table].issubset(params) if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].issubset( - params - ) + has_all_defaults = mapper._server_default_col_keys[ + table + ].issubset(params) else: has_all_defaults = True else: @@ -659,7 +659,7 @@ def _collect_update_commands( if mapper.base_mapper.eager_defaults: has_all_defaults = ( - mapper._server_onupdate_default_cols[table] + mapper._server_onupdate_default_col_keys[table] ).issubset(params) else: has_all_defaults = True @@ -930,16 +930,20 @@ def _emit_update_statements( return_defaults = False if not has_all_pks: - statement = statement.return_defaults() + statement = statement.return_defaults(*mapper._pks_by_table[table]) return_defaults = True - elif ( + + if ( bookkeeping and not has_all_defaults and mapper.base_mapper.eager_defaults ): - statement = statement.return_defaults() + statement = statement.return_defaults( + *mapper._server_onupdate_default_cols[table] + ) return_defaults = True - elif mapper.version_id_col is not None: + + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) return_defaults = True @@ -1171,8 +1175,10 @@ def _emit_insert_statements( do_executemany = False if not has_all_defaults and base_mapper.eager_defaults: - statement = statement.return_defaults() - elif mapper.version_id_col is not None: + statement = statement.return_defaults( + *mapper._server_default_cols[table] + ) + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) elif do_executemany: statement = statement.return_defaults(*table.primary_key) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 76a16eb1c..9d489ed98 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -989,10 +989,26 @@ class ValuesBase(UpdateBase): :attr:`_engine.CursorResult.inserted_primary_key_rows` """ + + if self._return_defaults: + # note _return_defaults_columns = () means return all columns, + # so if we have been here before, only update collection if there + # are columns in the collection + if self._return_defaults_columns and cols: + self._return_defaults_columns = tuple( + set(self._return_defaults_columns).union( + coercions.expect(roles.ColumnsClauseRole, c) + for c in cols + ) + ) + else: + # set for all columns + self._return_defaults_columns = () + else: + self._return_defaults_columns = tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) self._return_defaults = True - self._return_defaults_columns = tuple( - coercions.expect(roles.ColumnsClauseRole, c) for c in cols - ) return self |
