summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-04-27 00:13:00 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-04-27 00:13:00 +0000
commit11535752b94acb41ff684cf8d9c745038addc447 (patch)
treefec970fe35228d33c45280ad558ed9bc251b5208 /lib/sqlalchemy
parentc89c2b3d9a18bd0eb4c8ace50ef875101c9f4b70 (diff)
parent8ec396873c9bbfcc4416e55b5f9d8653554a1df0 (diff)
downloadsqlalchemy-11535752b94acb41ff684cf8d9c745038addc447.tar.gz
Merge "support parameters in all ORM insert modes" into main
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/bulk_persistence.py77
-rw-r--r--lib/sqlalchemy/orm/persistence.py31
-rw-r--r--lib/sqlalchemy/sql/elements.py22
3 files changed, 107 insertions, 23 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
index 8388d3980..cb416d69e 100644
--- a/lib/sqlalchemy/orm/bulk_persistence.py
+++ b/lib/sqlalchemy/orm/bulk_persistence.py
@@ -150,6 +150,23 @@ def _bulk_insert(
for table, super_mapper in mappers_to_run:
+ # find bindparams in the statement. For bulk, we don't really know if
+ # a key in the params applies to a different table since we are
+ # potentially inserting for multiple tables here; looking at the
+ # bindparam() is a lot more direct. in most cases this will
+ # use _generate_cache_key() which is memoized, although in practice
+ # the ultimate statement that's executed is probably not the same
+ # object so that memoization might not matter much.
+ extra_bp_names = (
+ [
+ b.key
+ for b in use_orm_insert_stmt._get_embedded_bindparams()
+ if b.key in mappings[0]
+ ]
+ if use_orm_insert_stmt is not None
+ else ()
+ )
+
records = (
(
None,
@@ -176,6 +193,7 @@ def _bulk_insert(
bulk=True,
return_defaults=bookkeeping,
render_nulls=render_nulls,
+ include_bulk_keys=extra_bp_names,
)
)
@@ -218,6 +236,7 @@ def _bulk_update(
isstates: bool,
update_changed_only: bool,
use_orm_update_stmt: Literal[None] = ...,
+ enable_check_rowcount: bool = True,
) -> None:
...
@@ -230,6 +249,7 @@ def _bulk_update(
isstates: bool,
update_changed_only: bool,
use_orm_update_stmt: Optional[dml.Update] = ...,
+ enable_check_rowcount: bool = True,
) -> _result.Result[Any]:
...
@@ -241,6 +261,7 @@ def _bulk_update(
isstates: bool,
update_changed_only: bool,
use_orm_update_stmt: Optional[dml.Update] = None,
+ enable_check_rowcount: bool = True,
) -> Optional[_result.Result[Any]]:
base_mapper = mapper.base_mapper
@@ -272,6 +293,18 @@ def _bulk_update(
connection = session_transaction.connection(base_mapper)
+ # find bindparams in the statement. see _bulk_insert for similar
+ # notes for the insert case
+ extra_bp_names = (
+ [
+ b.key
+ for b in use_orm_update_stmt._get_embedded_bindparams()
+ if b.key in mappings[0]
+ ]
+ if use_orm_update_stmt is not None
+ else ()
+ )
+
for table, super_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
@@ -295,6 +328,7 @@ def _bulk_update(
),
bulk=True,
use_orm_update_stmt=use_orm_update_stmt,
+ include_bulk_keys=extra_bp_names,
)
persistence._emit_update_statements(
base_mapper,
@@ -304,6 +338,7 @@ def _bulk_update(
records,
bookkeeping=False,
use_orm_update_stmt=use_orm_update_stmt,
+ enable_check_rowcount=enable_check_rowcount,
)
if use_orm_update_stmt is not None:
@@ -588,6 +623,7 @@ class BulkUDCompileState(ORMDMLState):
is_multitable: bool = False,
is_update_from: bool = False,
is_delete_using: bool = False,
+ is_executemany: bool = False,
) -> bool:
raise NotImplementedError()
@@ -639,11 +675,6 @@ class BulkUDCompileState(ORMDMLState):
else:
if update_options._dml_strategy == "auto":
update_options += {"_dml_strategy": "bulk"}
- elif update_options._dml_strategy == "orm":
- raise sa_exc.InvalidRequestError(
- 'Can\'t use "orm" ORM insert strategy with a '
- "separate parameter list"
- )
sync = update_options._synchronize_session
if sync is not None:
@@ -1062,6 +1093,7 @@ class BulkUDCompileState(ORMDMLState):
mapper,
is_update_from=update_options._is_update_from,
is_delete_using=update_options._is_delete_using,
+ is_executemany=orm_context.is_executemany,
)
if can_use_returning is not None:
@@ -1071,6 +1103,12 @@ class BulkUDCompileState(ORMDMLState):
"backends where some support RETURNING and others "
"don't"
)
+ elif orm_context.is_executemany and not per_bind_result:
+ raise sa_exc.InvalidRequestError(
+ "For synchronize_session='fetch', can't use multiple "
+ "parameter sets in ORM mode, which this backend does not "
+ "support with RETURNING"
+ )
else:
can_use_returning = per_bind_result
@@ -1146,11 +1184,6 @@ class BulkORMInsert(ORMDMLState, InsertDMLState):
else:
if insert_options._dml_strategy == "auto":
insert_options += {"_dml_strategy": "bulk"}
- elif insert_options._dml_strategy == "orm":
- raise sa_exc.InvalidRequestError(
- 'Can\'t use "orm" ORM insert strategy with a '
- "separate parameter list"
- )
if insert_options._dml_strategy != "raw":
# for ORM object loading, like ORMContext, we have to disable
@@ -1512,12 +1545,20 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
result: _result.Result[Any]
if update_options._dml_strategy == "bulk":
- if statement._where_criteria:
+ enable_check_rowcount = not statement._where_criteria
+
+ assert update_options._synchronize_session != "fetch"
+
+ if (
+ statement._where_criteria
+ and update_options._synchronize_session == "evaluate"
+ ):
raise sa_exc.InvalidRequestError(
- "WHERE clause with bulk ORM UPDATE not "
- "supported right now. Statement may be invoked at the "
- "Core level using "
- "session.connection().execute(stmt, parameters)"
+ "bulk synchronize of persistent objects not supported "
+ "when using bulk update with additional WHERE "
+ "criteria right now. add synchronize_session=None "
+ "execution option to bypass synchronize of persistent "
+ "objects."
)
mapper = update_options._subject_mapper
assert mapper is not None
@@ -1532,6 +1573,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
isstates=False,
update_changed_only=False,
use_orm_update_stmt=statement,
+ enable_check_rowcount=enable_check_rowcount,
)
return cls.orm_setup_cursor_result(
session,
@@ -1560,6 +1602,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
is_multitable: bool = False,
is_update_from: bool = False,
is_delete_using: bool = False,
+ is_executemany: bool = False,
) -> bool:
# normal answer for "should we use RETURNING" at all.
@@ -1569,6 +1612,9 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
if not normal_answer:
return False
+ if is_executemany:
+ return dialect.update_executemany_returning
+
# these workarounds are currently hypothetical for UPDATE,
# unlike DELETE where they impact MariaDB
if is_update_from:
@@ -1869,6 +1915,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
is_multitable: bool = False,
is_update_from: bool = False,
is_delete_using: bool = False,
+ is_executemany: bool = False,
) -> bool:
# normal answer for "should we use RETURNING" at all.
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 1af55df00..6fa338ced 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -326,9 +326,11 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
def _collect_insert_commands(
table,
states_to_insert,
+ *,
bulk=False,
return_defaults=False,
render_nulls=False,
+ include_bulk_keys=(),
):
"""Identify sets of values to use in INSERT statements for a
list of states.
@@ -401,10 +403,14 @@ def _collect_insert_commands(
None
)
- if bulk and mapper._set_polymorphic_identity:
- params.setdefault(
- mapper._polymorphic_attr_key, mapper.polymorphic_identity
- )
+ if bulk:
+ if mapper._set_polymorphic_identity:
+ params.setdefault(
+ mapper._polymorphic_attr_key, mapper.polymorphic_identity
+ )
+
+ if include_bulk_keys:
+ params.update((k, state_dict[k]) for k in include_bulk_keys)
yield (
state,
@@ -422,8 +428,10 @@ def _collect_update_commands(
uowtransaction,
table,
states_to_update,
+ *,
bulk=False,
use_orm_update_stmt=None,
+ include_bulk_keys=(),
):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -581,6 +589,9 @@ def _collect_update_commands(
"key value on column %s" % (table, col)
)
+ if include_bulk_keys:
+ params.update((k, state_dict[k]) for k in include_bulk_keys)
+
if params or value_params:
params.update(pk_params)
yield (
@@ -712,8 +723,10 @@ def _emit_update_statements(
mapper,
table,
update,
+ *,
bookkeeping=True,
use_orm_update_stmt=None,
+ enable_check_rowcount=True,
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
@@ -847,10 +860,10 @@ def _emit_update_statements(
c.returned_defaults,
)
rows += c.rowcount
- check_rowcount = assert_singlerow
+ check_rowcount = enable_check_rowcount and assert_singlerow
else:
if not allow_executemany:
- check_rowcount = assert_singlerow
+ check_rowcount = enable_check_rowcount and assert_singlerow
for (
state,
state_dict,
@@ -883,8 +896,9 @@ def _emit_update_statements(
else:
multiparams = [rec[2] for rec in records]
- check_rowcount = assert_multirow or (
- assert_singlerow and len(multiparams) == 1
+ check_rowcount = enable_check_rowcount and (
+ assert_multirow
+ or (assert_singlerow and len(multiparams) == 1)
)
c = connection.execute(
@@ -941,6 +955,7 @@ def _emit_insert_statements(
mapper,
table,
insert,
+ *,
bookkeeping=True,
use_orm_insert_stmt=None,
execution_options=None,
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index ff47ec79d..2e32da754 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -502,6 +502,28 @@ class ClauseElement(
connection, distilled_params, execution_options
).scalar()
+ def _get_embedded_bindparams(self) -> Sequence[BindParameter[Any]]:
+ """Return the list of :class:`.BindParameter` objects embedded in the
+ object.
+
+ This accomplishes the same purpose as ``visitors.traverse()`` or
+ similar would provide, however by making use of the cache key
+ it takes advantage of memoization of the key to result in fewer
+ net method calls, assuming the statement is also going to be
+ executed.
+
+ """
+
+ key = self._generate_cache_key()
+ if key is None:
+ bindparams: List[BindParameter[Any]] = []
+
+ traverse(self, {}, {"bindparam": bindparams.append})
+ return bindparams
+
+ else:
+ return key.bindparams
+
def unique_params(
self,
__optionaldict: Optional[Dict[str, Any]] = None,