summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-08-02 16:18:18 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-08-05 10:07:15 -0400
commit82a1d4096fbfe94e2fa626d65d5c3beb2c6afa37 (patch)
treebaca62a1a0784f192e65402f824319b0403c6847 /lib/sqlalchemy
parent0027b3a4bc54599ac8102a4a3d81d8007738903e (diff)
downloadsqlalchemy-82a1d4096fbfe94e2fa626d65d5c3beb2c6afa37.tar.gz
include column.default, column.onupdate in eager_defaults
Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults` parameter such that client-side SQL default or onupdate expressions in the table definition alone will trigger a fetch operation using RETURNING or SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only server side defaults established as part of table DDL and/or server-side onupdate expressions would trigger this fetch, even though client-side SQL expressions would be included when the fetch was rendered. Fixes: #7438 Change-Id: Iba719298ba4a26d185edec97ba77d2d54585e5a4
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/mapper.py72
-rw-r--r--lib/sqlalchemy/orm/persistence.py26
-rw-r--r--lib/sqlalchemy/sql/dml.py22
3 files changed, 87 insertions, 33 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 769b1b623..6a95030b5 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -28,6 +28,7 @@ from typing import cast
from typing import Collection
from typing import Deque
from typing import Dict
+from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Iterator
@@ -2397,15 +2398,21 @@ class Mapper(
)
@HasMemoized.memoized_attribute
- def _server_default_cols(self):
+ def _server_default_cols(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
return dict(
(
table,
frozenset(
[
- col.key
- for col in columns
+ col
+ for col in cast("Iterable[Column[Any]]", columns)
if col.server_default is not None
+ or (
+ col.default is not None
+ and col.default.is_clause_element
+ )
]
),
)
@@ -2413,35 +2420,60 @@ class Mapper(
)
@HasMemoized.memoized_attribute
- def _server_default_plus_onupdate_propkeys(self):
- result = set()
-
- for table, columns in self._cols_by_table.items():
- for col in columns:
- if (
- col.server_default is not None
- or col.server_onupdate is not None
- ) and col in self._columntoproperty:
- result.add(self._columntoproperty[col].key)
-
- return result
-
- @HasMemoized.memoized_attribute
- def _server_onupdate_default_cols(self):
+ def _server_onupdate_default_cols(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
return dict(
(
table,
frozenset(
[
- col.key
- for col in columns
+ col
+ for col in cast("Iterable[Column[Any]]", columns)
if col.server_onupdate is not None
+ or (
+ col.onupdate is not None
+ and col.onupdate.is_clause_element
+ )
]
),
)
for table, columns in self._cols_by_table.items()
)
+ @HasMemoized.memoized_attribute
+ def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]:
+ return {
+ table: frozenset(col.key for col in cols if col.key is not None)
+ for table, cols in self._server_default_cols.items()
+ }
+
+ @HasMemoized.memoized_attribute
+ def _server_onupdate_default_col_keys(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[str]]:
+ return {
+ table: frozenset(col.key for col in cols if col.key is not None)
+ for table, cols in self._server_onupdate_default_cols.items()
+ }
+
+ @HasMemoized.memoized_attribute
+ def _server_default_plus_onupdate_propkeys(self) -> Set[str]:
+ result: Set[str] = set()
+
+ col_to_property = self._columntoproperty
+ for table, columns in self._server_default_cols.items():
+ result.update(
+ col_to_property[col].key
+ for col in columns.intersection(col_to_property)
+ )
+ for table, columns in self._server_onupdate_default_cols.items():
+ result.update(
+ col_to_property[col].key
+ for col in columns.intersection(col_to_property)
+ )
+ return result
+
@HasMemoized.memoized_instancemethod
def __clause_element__(self):
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index c10f4701e..7cd66513b 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -561,9 +561,9 @@ def _collect_insert_commands(
has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_default_cols[table].issubset(
- params
- )
+ has_all_defaults = mapper._server_default_col_keys[
+ table
+ ].issubset(params)
else:
has_all_defaults = True
else:
@@ -659,7 +659,7 @@ def _collect_update_commands(
if mapper.base_mapper.eager_defaults:
has_all_defaults = (
- mapper._server_onupdate_default_cols[table]
+ mapper._server_onupdate_default_col_keys[table]
).issubset(params)
else:
has_all_defaults = True
@@ -930,16 +930,20 @@ def _emit_update_statements(
return_defaults = False
if not has_all_pks:
- statement = statement.return_defaults()
+ statement = statement.return_defaults(*mapper._pks_by_table[table])
return_defaults = True
- elif (
+
+ if (
bookkeeping
and not has_all_defaults
and mapper.base_mapper.eager_defaults
):
- statement = statement.return_defaults()
+ statement = statement.return_defaults(
+ *mapper._server_onupdate_default_cols[table]
+ )
return_defaults = True
- elif mapper.version_id_col is not None:
+
+ if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
return_defaults = True
@@ -1171,8 +1175,10 @@ def _emit_insert_statements(
do_executemany = False
if not has_all_defaults and base_mapper.eager_defaults:
- statement = statement.return_defaults()
- elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(
+ *mapper._server_default_cols[table]
+ )
+ if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
elif do_executemany:
statement = statement.return_defaults(*table.primary_key)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 76a16eb1c..9d489ed98 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -989,10 +989,26 @@ class ValuesBase(UpdateBase):
:attr:`_engine.CursorResult.inserted_primary_key_rows`
"""
+
+ if self._return_defaults:
+ # note _return_defaults_columns = () means return all columns,
+ # so if we have been here before, only update collection if there
+ # are columns in the collection
+ if self._return_defaults_columns and cols:
+ self._return_defaults_columns = tuple(
+ set(self._return_defaults_columns).union(
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in cols
+ )
+ )
+ else:
+ # set for all columns
+ self._return_defaults_columns = ()
+ else:
+ self._return_defaults_columns = tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
self._return_defaults = True
- self._return_defaults_columns = tuple(
- coercions.expect(roles.ColumnsClauseRole, c) for c in cols
- )
return self