summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-09-26 01:17:44 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-09-26 01:17:44 +0000
commit6201b4d88666983b883b96d22a159aa2594de94b (patch)
tree4036c155ca7c274ea4bd12c059fd8fcd277fc026 /test
parentf81fdd9a9008a6517f89f2115765b7db9a32721b (diff)
parenta8029f5a7e3e376ec57f1614ab0294b717d53c05 (diff)
downloadsqlalchemy-6201b4d88666983b883b96d22a159aa2594de94b.tar.gz
Merge "ORM bulk insert via execute" into main
Diffstat (limited to 'test')
-rw-r--r--test/ext/test_horizontal_shard.py212
-rw-r--r--test/ext/test_hybrid.py35
-rw-r--r--test/orm/dml/__init__.py0
-rw-r--r--test/orm/dml/test_bulk.py (renamed from test/orm/test_bulk.py)231
-rw-r--r--test/orm/dml/test_bulk_statements.py1199
-rw-r--r--test/orm/dml/test_evaluator.py (renamed from test/orm/test_evaluator.py)1
-rw-r--r--test/orm/dml/test_update_delete_where.py (renamed from test/orm/test_update_delete.py)499
-rw-r--r--test/orm/inheritance/test_basic.py3
-rw-r--r--test/orm/test_bind.py11
-rw-r--r--test/orm/test_composites.py207
-rw-r--r--test/orm/test_cycles.py2
-rw-r--r--test/orm/test_defaults.py4
-rw-r--r--test/orm/test_events.py13
-rw-r--r--test/orm/test_unitofwork.py6
-rw-r--r--test/orm/test_unitofworkv2.py34
-rw-r--r--test/orm/test_versioning.py10
-rw-r--r--test/sql/test_resultset.py47
-rw-r--r--test/sql/test_returning.py195
-rw-r--r--test/sql/test_selectable.py21
19 files changed, 2420 insertions, 310 deletions
diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py
index 7cc6a6f79..667f4bfb0 100644
--- a/test/ext/test_horizontal_shard.py
+++ b/test/ext/test_horizontal_shard.py
@@ -465,7 +465,11 @@ class ShardTest:
t = get_tokyo(sess2)
eq_(t.city, tokyo.city)
- def test_bulk_update_synchronize_evaluate(self):
+ @testing.combinations(
+ "fetch", "evaluate", "auto", argnames="synchronize_session"
+ )
+ @testing.combinations(True, False, argnames="legacy")
+ def test_orm_update_synchronize(self, synchronize_session, legacy):
sess = self._fixture_data()
eq_(
@@ -476,33 +480,25 @@ class ShardTest:
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
- sess.query(Report).filter(Report.temperature >= 80).update(
- {"temperature": Report.temperature + 6},
- synchronize_session="evaluate",
- )
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {86.0, 75.0, 91.0},
- )
-
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
- def test_bulk_update_synchronize_fetch(self):
- sess = self._fixture_data()
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {80.0, 75.0, 85.0},
- )
+ if legacy:
+ sess.query(Report).filter(Report.temperature >= 80).update(
+ {"temperature": Report.temperature + 6},
+ synchronize_session=synchronize_session,
+ )
+ else:
+ sess.execute(
+ update(Report)
+ .filter(Report.temperature >= 80)
+ .values(temperature=Report.temperature + 6)
+ .execution_options(synchronize_session=synchronize_session)
+ )
- temps = sess.query(Report).all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+ # test synchronize session
+ def go():
+ eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
- sess.query(Report).filter(Report.temperature >= 80).update(
- {"temperature": Report.temperature + 6},
- synchronize_session="fetch",
+ self.assert_sql_count(
+ sess._ShardedSession__binds["north_america"], go, 0
)
eq_(
@@ -510,165 +506,41 @@ class ShardTest:
{86.0, 75.0, 91.0},
)
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
- def test_bulk_delete_synchronize_evaluate(self):
- sess = self._fixture_data()
-
- temps = sess.query(Report).all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- sess.query(Report).filter(Report.temperature >= 80).delete(
- synchronize_session="evaluate"
- )
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {75.0},
- )
-
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
- def test_bulk_delete_synchronize_fetch(self):
+ @testing.combinations(
+ "fetch", "evaluate", "auto", argnames="synchronize_session"
+ )
+ @testing.combinations(True, False, argnames="legacy")
+ def test_orm_delete_synchronize(self, synchronize_session, legacy):
sess = self._fixture_data()
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
- sess.query(Report).filter(Report.temperature >= 80).delete(
- synchronize_session="fetch"
- )
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {75.0},
- )
-
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
- def test_bulk_update_future_synchronize_evaluate(self):
- sess = self._fixture_data()
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {80.0, 75.0, 85.0},
- )
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- sess.execute(
- update(Report)
- .filter(Report.temperature >= 80)
- .values(
- {"temperature": Report.temperature + 6},
+ if legacy:
+ sess.query(Report).filter(Report.temperature >= 80).delete(
+ synchronize_session=synchronize_session
)
- .execution_options(synchronize_session="evaluate")
- )
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {86.0, 75.0, 91.0},
- )
-
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
- def test_bulk_update_future_synchronize_fetch(self):
- sess = self._fixture_data()
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {80.0, 75.0, 85.0},
- )
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- # MARKMARK
- # omitting the criteria so that the UPDATE affects three out of
- # four shards
- sess.execute(
- update(Report)
- .values(
- {"temperature": Report.temperature + 6},
+ else:
+ sess.execute(
+ delete(Report)
+ .filter(Report.temperature >= 80)
+ .execution_options(synchronize_session=synchronize_session)
)
- .execution_options(synchronize_session="fetch")
- )
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {86.0, 81.0, 91.0},
- )
-
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0})
-
- def test_bulk_delete_future_synchronize_evaluate(self):
- sess = self._fixture_data()
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- sess.execute(
- delete(Report)
- .filter(Report.temperature >= 80)
- .execution_options(synchronize_session="evaluate")
- )
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {75.0},
- )
-
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
- def test_bulk_delete_future_synchronize_fetch(self):
- sess = self._fixture_data()
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+ def go():
+ # test synchronize session
+ for t in temps:
+ assert inspect(t).deleted is (t.temperature >= 80)
- sess.execute(
- delete(Report)
- .filter(Report.temperature >= 80)
- .execution_options(synchronize_session="fetch")
+ self.assert_sql_count(
+ sess._ShardedSession__binds["north_america"], go, 0
)
eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
+ set(row.temperature for row in sess.query(Report.temperature)),
{75.0},
)
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
class DistinctEngineShardTest(ShardTest, fixtures.MappedTest):
def _init_dbs(self):
diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py
index de5f89b25..0cba8f3a1 100644
--- a/test/ext/test_hybrid.py
+++ b/test/ext/test_hybrid.py
@@ -3,6 +3,7 @@ from decimal import Decimal
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
+from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -1017,15 +1018,43 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
params={"first_name": "Dr."},
)
- def test_update_expr(self):
+ @testing.combinations("attr", "str", "kwarg", argnames="keytype")
+ def test_update_expr(self, keytype):
Person = self.classes.Person
- statement = update(Person).values({Person.name: "Dr. No"})
+ if keytype == "attr":
+ statement = update(Person).values({Person.name: "Dr. No"})
+ elif keytype == "str":
+ statement = update(Person).values({"name": "Dr. No"})
+ elif keytype == "kwarg":
+ statement = update(Person).values(name="Dr. No")
+ else:
+ assert False
self.assert_compile(
statement,
"UPDATE person SET first_name=:first_name, last_name=:last_name",
- params={"first_name": "Dr.", "last_name": "No"},
+ checkparams={"first_name": "Dr.", "last_name": "No"},
+ )
+
+ @testing.combinations("attr", "str", "kwarg", argnames="keytype")
+ def test_insert_expr(self, keytype):
+ Person = self.classes.Person
+
+ if keytype == "attr":
+ statement = insert(Person).values({Person.name: "Dr. No"})
+ elif keytype == "str":
+ statement = insert(Person).values({"name": "Dr. No"})
+ elif keytype == "kwarg":
+ statement = insert(Person).values(name="Dr. No")
+ else:
+ assert False
+
+ self.assert_compile(
+ statement,
+ "INSERT INTO person (first_name, last_name) VALUES "
+ "(:first_name, :last_name)",
+ checkparams={"first_name": "Dr.", "last_name": "No"},
)
# these tests all run two UPDATES to assert that caching is not
diff --git a/test/orm/dml/__init__.py b/test/orm/dml/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/test/orm/dml/__init__.py
diff --git a/test/orm/test_bulk.py b/test/orm/dml/test_bulk.py
index 802cdfac5..52db4247f 100644
--- a/test/orm/test_bulk.py
+++ b/test/orm/dml/test_bulk.py
@@ -1,8 +1,11 @@
from sqlalchemy import FetchedValue
from sqlalchemy import ForeignKey
+from sqlalchemy import Identity
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import update
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
@@ -20,6 +23,8 @@ class BulkTest(testing.AssertsExecutionResults):
class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
@@ -73,6 +78,8 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
+ __backend__ = True
+
@classmethod
def setup_mappers(cls):
User, Address, Order = cls.classes("User", "Address", "Order")
@@ -82,22 +89,42 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
cls.mapper_registry.map_imperatively(Address, a)
cls.mapper_registry.map_imperatively(Order, o)
- def test_bulk_save_return_defaults(self):
+ @testing.combinations("save_objects", "insert_mappings", "insert_stmt")
+ def test_bulk_save_return_defaults(self, statement_type):
(User,) = self.classes("User")
s = fixture_session()
- objects = [User(name="u1"), User(name="u2"), User(name="u3")]
- assert "id" not in objects[0].__dict__
- with self.sql_execution_asserter() as asserter:
- s.bulk_save_objects(objects, return_defaults=True)
+ if statement_type == "save_objects":
+ objects = [User(name="u1"), User(name="u2"), User(name="u3")]
+ assert "id" not in objects[0].__dict__
+
+ returning_users_id = " RETURNING users.id"
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_save_objects(objects, return_defaults=True)
+ elif statement_type == "insert_mappings":
+ data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
+ returning_users_id = " RETURNING users.id"
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_insert_mappings(User, data, return_defaults=True)
+ elif statement_type == "insert_stmt":
+ data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
+
+ # for statement, "return_defaults" is heuristic on if we are
+ # a joined inh mapping if we don't otherwise include
+ # .returning() on the statement itself
+ returning_users_id = ""
+ with self.sql_execution_asserter() as asserter:
+ s.execute(insert(User), data)
asserter.assert_(
Conditional(
- testing.db.dialect.insert_executemany_returning,
+ testing.db.dialect.insert_executemany_returning
+ or statement_type == "insert_stmt",
[
CompiledSQL(
- "INSERT INTO users (name) VALUES (:name)",
+ "INSERT INTO users (name) "
+ f"VALUES (:name){returning_users_id}",
[{"name": "u1"}, {"name": "u2"}, {"name": "u3"}],
),
],
@@ -117,7 +144,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
],
)
)
- eq_(objects[0].__dict__["id"], 1)
+ if statement_type == "save_objects":
+ eq_(objects[0].__dict__["id"], 1)
def test_bulk_save_mappings_preserve_order(self):
(User,) = self.classes("User")
@@ -219,8 +247,9 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
)
)
- def test_bulk_update(self):
- (User,) = self.classes("User")
+ @testing.combinations("update_mappings", "update_stmt")
+ def test_bulk_update(self, statement_type):
+ User = self.classes.User
s = fixture_session(expire_on_commit=False)
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
@@ -228,15 +257,18 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
s.commit()
s = fixture_session()
- with self.sql_execution_asserter() as asserter:
- s.bulk_update_mappings(
- User,
- [
- {"id": 1, "name": "u1new"},
- {"id": 2, "name": "u2"},
- {"id": 3, "name": "u3new"},
- ],
- )
+ data = [
+ {"id": 1, "name": "u1new"},
+ {"id": 2, "name": "u2"},
+ {"id": 3, "name": "u3new"},
+ ]
+
+ if statement_type == "update_mappings":
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_update_mappings(User, data)
+ elif statement_type == "update_stmt":
+ with self.sql_execution_asserter() as asserter:
+ s.execute(update(User), data)
asserter.assert_(
CompiledSQL(
@@ -303,6 +335,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
@@ -360,6 +394,8 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
@@ -547,6 +583,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
@@ -643,6 +681,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
)
s = fixture_session()
+
objects = [
Manager(name="m1", status="s1", manager_name="mn1"),
Engineer(name="e1", status="s2", primary_language="l1"),
@@ -669,7 +708,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
[
CompiledSQL(
"INSERT INTO people (name, type) "
- "VALUES (:name, :type)",
+ "VALUES (:name, :type) RETURNING people.person_id",
[
{"type": "engineer", "name": "e1"},
{"type": "engineer", "name": "e2"},
@@ -798,59 +837,74 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
),
)
- def test_bulk_insert_joined_inh_return_defaults(self):
+ @testing.combinations("insert_mappings", "insert_stmt")
+ def test_bulk_insert_joined_inh_return_defaults(self, statement_type):
Person, Engineer, Manager, Boss = self.classes(
"Person", "Engineer", "Manager", "Boss"
)
s = fixture_session()
- with self.sql_execution_asserter() as asserter:
- s.bulk_insert_mappings(
- Boss,
- [
- dict(
- name="b1",
- status="s1",
- manager_name="mn1",
- golf_swing="g1",
- ),
- dict(
- name="b2",
- status="s2",
- manager_name="mn2",
- golf_swing="g2",
- ),
- dict(
- name="b3",
- status="s3",
- manager_name="mn3",
- golf_swing="g3",
- ),
- ],
- return_defaults=True,
- )
+ data = [
+ dict(
+ name="b1",
+ status="s1",
+ manager_name="mn1",
+ golf_swing="g1",
+ ),
+ dict(
+ name="b2",
+ status="s2",
+ manager_name="mn2",
+ golf_swing="g2",
+ ),
+ dict(
+ name="b3",
+ status="s3",
+ manager_name="mn3",
+ golf_swing="g3",
+ ),
+ ]
+
+ if statement_type == "insert_mappings":
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_insert_mappings(
+ Boss,
+ data,
+ return_defaults=True,
+ )
+ elif statement_type == "insert_stmt":
+ with self.sql_execution_asserter() as asserter:
+ s.execute(insert(Boss), data)
asserter.assert_(
Conditional(
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b1"}, {"name": "b2"}, {"name": "b3"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type) RETURNING people.person_id",
+ [
+ {"name": "b1", "type": "boss"},
+ {"name": "b2", "type": "boss"},
+ {"name": "b3", "type": "boss"},
+ ],
),
],
[
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b1"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type)",
+ [{"name": "b1", "type": "boss"}],
),
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b2"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type)",
+ [{"name": "b2", "type": "boss"}],
),
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b3"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type)",
+ [{"name": "b3", "type": "boss"}],
),
],
),
@@ -874,15 +928,79 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
),
)
+ @testing.combinations("update_mappings", "update_stmt")
+ def test_bulk_update(self, statement_type):
+ Person, Engineer, Manager, Boss = self.classes(
+ "Person", "Engineer", "Manager", "Boss"
+ )
+
+ s = fixture_session()
+
+ b1, b2, b3 = (
+ Boss(name="b1", status="s1", manager_name="mn1", golf_swing="g1"),
+ Boss(name="b2", status="s2", manager_name="mn2", golf_swing="g2"),
+ Boss(name="b3", status="s3", manager_name="mn3", golf_swing="g3"),
+ )
+ s.add_all([b1, b2, b3])
+ s.commit()
+
+ # slight non-convenient thing. we have to fill in boss_id here
+ # for update, this is not sent along automatically. this is not a
+ # new behavior in bulk
+ new_data = [
+ {
+ "person_id": b1.person_id,
+ "boss_id": b1.boss_id,
+ "name": "b1_updated",
+ "manager_name": "mn1_updated",
+ },
+ {
+ "person_id": b3.person_id,
+ "boss_id": b3.boss_id,
+ "manager_name": "mn2_updated",
+ "golf_swing": "g1_updated",
+ },
+ ]
+
+ if statement_type == "update_mappings":
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_update_mappings(Boss, new_data)
+ elif statement_type == "update_stmt":
+ with self.sql_execution_asserter() as asserter:
+ s.execute(update(Boss), new_data)
+
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE people SET name=:name WHERE "
+ "people.person_id = :people_person_id",
+ [{"name": "b1_updated", "people_person_id": 1}],
+ ),
+ CompiledSQL(
+ "UPDATE managers SET manager_name=:manager_name WHERE "
+ "managers.person_id = :managers_person_id",
+ [
+ {"manager_name": "mn1_updated", "managers_person_id": 1},
+ {"manager_name": "mn2_updated", "managers_person_id": 3},
+ ],
+ ),
+ CompiledSQL(
+ "UPDATE boss SET golf_swing=:golf_swing WHERE "
+ "boss.boss_id = :boss_boss_id",
+ [{"golf_swing": "g1_updated", "boss_boss_id": 3}],
+ ),
+ )
+
class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
+ __backend__ = True
+
@classmethod
def setup_classes(cls):
Base = cls.DeclarativeBasic
class User(Base):
__tablename__ = "users"
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, Identity(), primary_key=True)
name = Column(String(255), nullable=False)
def test_issue_6793(self):
@@ -907,7 +1025,8 @@ class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
[{"name": "A"}, {"name": "B"}],
),
CompiledSQL(
- "INSERT INTO users (name) VALUES (:name)",
+ "INSERT INTO users (name) VALUES (:name) "
+ "RETURNING users.id",
[{"name": "C"}, {"name": "D"}],
),
],
diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py
new file mode 100644
index 000000000..0cca9e6f5
--- /dev/null
+++ b/test/orm/dml/test_bulk_statements.py
@@ -0,0 +1,1199 @@
+from __future__ import annotations
+
+from typing import Any
+from typing import List
+from typing import Optional
+import uuid
+
+from sqlalchemy import exc
+from sqlalchemy import ForeignKey
+from sqlalchemy import func
+from sqlalchemy import Identity
+from sqlalchemy import insert
+from sqlalchemy import inspect
+from sqlalchemy import literal_column
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import load_only
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.testing import config
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import mock
+from sqlalchemy.testing import provision
+from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import fixture_session
+
+
+class NoReturningTest(fixtures.TestBase):
+ def test_no_returning_error(self, decl_base):
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+
+ decl_base.metadata.create_all(testing.db)
+ s = fixture_session()
+
+ if testing.requires.insert_executemany_returning.enabled:
+ result = s.scalars(
+ insert(A).returning(A),
+ [
+ {"data": "d3", "x": 5},
+ {"data": "d4", "x": 6},
+ ],
+ )
+ eq_(result.all(), [A(data="d3", x=5), A(data="d4", x=6)])
+
+ else:
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "Can't use explicit RETURNING for bulk INSERT operation",
+ ):
+ s.scalars(
+ insert(A).returning(A),
+ [
+ {"data": "d3", "x": 5},
+ {"data": "d4", "x": 6},
+ ],
+ )
+
+ def test_omit_returning_ok(self, decl_base):
+ class A(decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+
+ decl_base.metadata.create_all(testing.db)
+ s = fixture_session()
+
+ s.execute(
+ insert(A),
+ [
+ {"data": "d3", "x": 5},
+ {"data": "d4", "x": 6},
+ ],
+ )
+ eq_(
+ s.execute(select(A.data, A.x).order_by(A.id)).all(),
+ [("d3", 5), ("d4", 6)],
+ )
+
+
+class BulkDMLReturningInhTest:
+ def test_insert_col_key_also_works_currently(self):
+ """using the column key, not mapped attr key.
+
+ right now this passes through to the INSERT. when doing this with
+ an UPDATE, it tends to fail because the synchronize session
+ strategies can't match "xcol" back. however w/ INSERT we aren't
+ doing that, so there's no place this gets checked. UPDATE also
+ succeeds if synchronize_session is turned off.
+
+ """
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+ s.execute(insert(A).values(type="a", data="d", xcol=10))
+ eq_(s.scalars(select(A.x)).all(), [10])
+
+ @testing.combinations(True, False, argnames="use_returning")
+ def test_heterogeneous_keys(self, use_returning):
+ A, B = self.classes("A", "B")
+
+ values = [
+ {"data": "d3", "x": 5, "type": "a"},
+ {"data": "d4", "x": 6, "type": "a"},
+ {"data": "d5", "type": "a"},
+ {"data": "d6", "x": 8, "y": 9, "type": "a"},
+ {"data": "d7", "x": 12, "y": 12, "type": "a"},
+ {"data": "d8", "x": 7, "type": "a"},
+ ]
+
+ s = fixture_session()
+
+ stmt = insert(A)
+ if use_returning:
+ stmt = stmt.returning(A)
+
+ with self.sql_execution_asserter() as asserter:
+ result = s.execute(stmt, values)
+
+ if inspect(B).single:
+ single_inh = ", a.bd, a.zcol, a.q"
+ else:
+ single_inh = ""
+
+ if use_returning:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) VALUES "
+ "(:type, :data, :xcol) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [
+ {"type": "a", "data": "d3", "xcol": 5},
+ {"type": "a", "data": "d4", "xcol": 6},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data) VALUES (:type, :data) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [{"type": "a", "data": "d5"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y) "
+ "VALUES (:type, :data, :xcol, :y) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [
+ {"type": "a", "data": "d6", "xcol": 8, "y": 9},
+ {"type": "a", "data": "d7", "xcol": 12, "y": 12},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) "
+ "VALUES (:type, :data, :xcol) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [{"type": "a", "data": "d8", "xcol": 7}],
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) VALUES "
+ "(:type, :data, :xcol)",
+ [
+ {"type": "a", "data": "d3", "xcol": 5},
+ {"type": "a", "data": "d4", "xcol": 6},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data) VALUES (:type, :data)",
+ [{"type": "a", "data": "d5"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y) "
+ "VALUES (:type, :data, :xcol, :y)",
+ [
+ {"type": "a", "data": "d6", "xcol": 8, "y": 9},
+ {"type": "a", "data": "d7", "xcol": 12, "y": 12},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) "
+ "VALUES (:type, :data, :xcol)",
+ [{"type": "a", "data": "d8", "xcol": 7}],
+ ),
+ )
+
+ if use_returning:
+ eq_(
+ result.scalars().all(),
+ [
+ A(data="d3", id=mock.ANY, type="a", x=5, y=None),
+ A(data="d4", id=mock.ANY, type="a", x=6, y=None),
+ A(data="d5", id=mock.ANY, type="a", x=None, y=None),
+ A(data="d6", id=mock.ANY, type="a", x=8, y=9),
+ A(data="d7", id=mock.ANY, type="a", x=12, y=12),
+ A(data="d8", id=mock.ANY, type="a", x=7, y=None),
+ ],
+ )
+
+ @testing.combinations(
+ "strings",
+ "cols",
+ "strings_w_exprs",
+ "cols_w_exprs",
+ argnames="paramstyle",
+ )
+ @testing.combinations(
+ True,
+ (False, testing.requires.multivalues_inserts),
+ argnames="single_element",
+ )
+ def test_single_values_returning_fn(self, paramstyle, single_element):
+ """test using insert().values().
+
+ these INSERT statements go straight in as a single execute without any
+ insertmanyreturning or bulk_insert_mappings thing going on. the
+ advantage here is that SQL expressions can be used in the values also.
+ Disadvantage is none of the automation for inheritance mappers.
+
+ """
+ A, B = self.classes("A", "B")
+
+ if paramstyle == "strings":
+ values = [
+ {"data": "d3", "x": 5, "y": 9, "type": "a"},
+ {"data": "d4", "x": 10, "y": 8, "type": "a"},
+ ]
+ elif paramstyle == "cols":
+ values = [
+ {A.data: "d3", A.x: 5, A.y: 9, A.type: "a"},
+ {A.data: "d4", A.x: 10, A.y: 8, A.type: "a"},
+ ]
+ elif paramstyle == "strings_w_exprs":
+ values = [
+ {"data": func.lower("D3"), "x": 5, "y": 9, "type": "a"},
+ {
+ "data": "d4",
+ "x": literal_column("5") + 5,
+ "y": 8,
+ "type": "a",
+ },
+ ]
+ elif paramstyle == "cols_w_exprs":
+ values = [
+ {A.data: func.lower("D3"), A.x: 5, A.y: 9, A.type: "a"},
+ {
+ A.data: "d4",
+ A.x: literal_column("5") + 5,
+ A.y: 8,
+ A.type: "a",
+ },
+ ]
+ else:
+ assert False
+
+ s = fixture_session()
+
+ if single_element:
+ if paramstyle.startswith("strings"):
+ stmt = (
+ insert(A)
+ .values(**values[0])
+ .returning(A, func.upper(A.data, type_=String))
+ )
+ else:
+ stmt = (
+ insert(A)
+ .values(values[0])
+ .returning(A, func.upper(A.data, type_=String))
+ )
+ else:
+ stmt = (
+ insert(A)
+ .values(values)
+ .returning(A, func.upper(A.data, type_=String))
+ )
+
+ for i in range(3):
+ result = s.execute(stmt)
+ expected: List[Any] = [(A(data="d3", x=5, y=9), "D3")]
+ if not single_element:
+ expected.append((A(data="d4", x=10, y=8), "D4"))
+ eq_(result.all(), expected)
+
+ def test_bulk_w_sql_expressions(self):
+ A, B = self.classes("A", "B")
+
+ data = [
+ {"x": 5, "y": 9, "type": "a"},
+ {
+ "x": 10,
+ "y": 8,
+ "type": "a",
+ },
+ ]
+
+ s = fixture_session()
+
+ stmt = (
+ insert(A)
+ .values(data=func.lower("DD"))
+ .returning(A, func.upper(A.data, type_=String))
+ )
+
+ for i in range(3):
+ result = s.execute(stmt, data)
+ expected: List[Any] = [
+ (A(data="dd", x=5, y=9), "DD"),
+ (A(data="dd", x=10, y=8), "DD"),
+ ]
+ eq_(result.all(), expected)
+
+ def test_bulk_w_sql_expressions_subclass(self):
+ A, B = self.classes("A", "B")
+
+ data = [
+ {"bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ s = fixture_session()
+
+ stmt = (
+ insert(B)
+ .values(data=func.lower("DD"))
+ .returning(B, func.upper(B.data, type_=String))
+ )
+
+ for i in range(3):
+ result = s.execute(stmt, data)
+ expected: List[Any] = [
+ (B(bd="bd1", data="dd", q=4, type="b", x=1, y=2, z=3), "DD"),
+ (B(bd="bd2", data="dd", q=8, type="b", x=5, y=6, z=7), "DD"),
+ ]
+ eq_(result.all(), expected)
+
+ @testing.combinations(True, False, argnames="use_ordered")
+ def test_bulk_upd_w_sql_expressions_no_ordered_values(self, use_ordered):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ stmt = update(B).ordered_values(
+ ("data", func.lower("DD_UPDATE")),
+ ("z", literal_column("3 + 12")),
+ )
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ r"bulk ORM UPDATE does not support ordered_values\(\) "
+ r"for custom UPDATE",
+ ):
+ s.execute(
+ stmt,
+ [
+ {"id": 5, "bd": "bd1_updated"},
+ {"id": 6, "bd": "bd2_updated"},
+ ],
+ )
+
+ def test_bulk_upd_w_sql_expressions_subclass(self):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ ids = s.scalars(insert(B).returning(B.id), data).all()
+
+ stmt = update(B).values(
+ data=func.lower("DD_UPDATE"), z=literal_column("3 + 12")
+ )
+
+ result = s.execute(
+ stmt,
+ [
+ {"id": ids[0], "bd": "bd1_updated"},
+ {"id": ids[1], "bd": "bd2_updated"},
+ ],
+ )
+
+ # this is a nullresult at the moment
+ assert result is not None
+
+ eq_(
+ s.scalars(select(B)).all(),
+ [
+ B(
+ bd="bd1_updated",
+ data="dd_update",
+ id=ids[0],
+ q=4,
+ type="b",
+ x=1,
+ y=2,
+ z=15,
+ ),
+ B(
+ bd="bd2_updated",
+ data="dd_update",
+ id=ids[1],
+ q=8,
+ type="b",
+ x=5,
+ y=6,
+ z=15,
+ ),
+ ],
+ )
+
+ def test_single_returning_fn(self):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+ for i in range(3):
+ result = s.execute(
+ insert(A).returning(A, func.upper(A.data, type_=String)),
+ [{"data": "d3"}, {"data": "d4"}],
+ )
+ eq_(result.all(), [(A(data="d3"), "D3"), (A(data="d4"), "D4")])
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_no_returning(self, single_element):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ result = s.execute(insert(B), data)
+ assert result._soft_closed
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_load_only(self, single_element):
+ """test that load_only() prevents additional attributes from being
+ populated.
+
+ """
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ for i in range(3):
+ # tests both caching and that the data dictionaries aren't
+ # mutated...
+ result = s.execute(
+ insert(B).returning(B).options(load_only(B.data, B.y, B.q)),
+ data,
+ )
+ objects = result.scalars().all()
+ for obj in objects:
+ assert "data" in obj.__dict__
+ assert "q" in obj.__dict__
+ assert "z" not in obj.__dict__
+ assert "x" not in obj.__dict__
+
+ expected = [
+ B(data="d3", bd="bd1", x=1, y=2, z=3, q=4),
+ ]
+ if not single_element:
+ expected.append(B(data="d4", bd="bd2", x=5, y=6, z=7, q=8))
+ eq_(objects, expected)
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_load_only_doesnt_fetch_cols(self, single_element):
+ """test that when using load_only(), the actual INSERT statement
+ does not include the deferred columns
+
+ """
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ if single_element:
+ data = data[0]
+
+ with self.sql_execution_asserter() as asserter:
+
+ # tests both caching and that the data dictionaries aren't
+ # mutated...
+
+ # note that if we don't put B.id here, accessing .id on the
+ # B object for joined inheritance is triggering a SELECT
+ # (and not for single inheritance). this seems not great, but is
+ # likely a different issue
+ result = s.execute(
+ insert(B)
+ .returning(B)
+ .options(load_only(B.id, B.data, B.y, B.q)),
+ data,
+ )
+ objects = result.scalars().all()
+ if single_element:
+ id0 = objects[0].id
+ id1 = None
+ else:
+ id0, id1 = objects[0].id, objects[1].id
+
+ if inspect(B).single or inspect(B).concrete:
+ expected_params = [
+ {
+ "type": "b",
+ "data": "d3",
+ "xcol": 1,
+ "y": 2,
+ "bd": "bd1",
+ "zcol": 3,
+ "q": 4,
+ },
+ {
+ "type": "b",
+ "data": "d4",
+ "xcol": 5,
+ "y": 6,
+ "bd": "bd2",
+ "zcol": 7,
+ "q": 8,
+ },
+ ]
+ if single_element:
+ expected_params[1:] = []
+ # RETURNING only includes PK, discriminator, then the cols
+ # we asked for data, y, q. xcol, z, bd are omitted
+
+ if inspect(B).single:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y, bd, zcol, q) "
+ "VALUES "
+ "(:type, :data, :xcol, :y, :bd, :zcol, :q) "
+ "RETURNING a.id, a.type, a.data, a.y, a.q",
+ expected_params,
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO b (type, data, xcol, y, bd, zcol, q) "
+ "VALUES "
+ "(:type, :data, :xcol, :y, :bd, :zcol, :q) "
+ "RETURNING b.id, b.type, b.data, b.y, b.q",
+ expected_params,
+ ),
+ )
+ else:
+ a_data = [
+ {"type": "b", "data": "d3", "xcol": 1, "y": 2},
+ {"type": "b", "data": "d4", "xcol": 5, "y": 6},
+ ]
+ b_data = [
+ {"id": id0, "bd": "bd1", "zcol": 3, "q": 4},
+ {"id": id1, "bd": "bd2", "zcol": 7, "q": 8},
+ ]
+ if single_element:
+ a_data[1:] = []
+ b_data[1:] = []
+ # RETURNING only includes PK, discriminator, then the cols
+ # we asked for data, y, q. xcol, z, bd are omitted. plus they
+ # are broken out correctly in the two statements.
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y) VALUES "
+ "(:type, :data, :xcol, :y) "
+ "RETURNING a.id, a.type, a.data, a.y",
+ a_data,
+ ),
+ CompiledSQL(
+ "INSERT INTO b (id, bd, zcol, q) "
+ "VALUES (:id, :bd, :zcol, :q) "
+ "RETURNING b.id, b.q",
+ b_data,
+ ),
+ )
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_returning_bind_expr(self, single_element):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ # note there's a fix in compiler.py ->
+ # _deliver_insertmanyvalues_batches
+ # for this re: the parameter rendering that isn't tested anywhere
+ # else. two different versions of the bug for both positional
+ # and non
+ result = s.execute(insert(B).returning(B.data, B.y, B.q + 5), data)
+ if single_element:
+ eq_(result.all(), [("d3", 2, 9)])
+ else:
+ eq_(result.all(), [("d3", 2, 9), ("d4", 6, 13)])
+
+ def test_subclass_bulk_update(self):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ ids = s.scalars(insert(B).returning(B.id), data).all()
+
+ result = s.execute(
+ update(B),
+ [
+ {"id": ids[0], "data": "d3_updated", "bd": "bd1_updated"},
+ {"id": ids[1], "data": "d4_updated", "bd": "bd2_updated"},
+ ],
+ )
+
+ # this is a nullresult at the moment
+ assert result is not None
+
+ eq_(
+ s.scalars(select(B)).all(),
+ [
+ B(
+ bd="bd1_updated",
+ data="d3_updated",
+ id=ids[0],
+ q=4,
+ type="b",
+ x=1,
+ y=2,
+ z=3,
+ ),
+ B(
+ bd="bd2_updated",
+ data="d4_updated",
+ id=ids[1],
+ q=8,
+ type="b",
+ x=5,
+ y=6,
+ z=7,
+ ),
+ ],
+ )
+
+ @testing.combinations(True, False, argnames="single_element")
+ def test_subclass_return_just_subclass_ids(self, single_element):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ ids = s.scalars(insert(B).returning(B.id), data).all()
+ actual_ids = s.scalars(select(B.id).order_by(B.data)).all()
+
+ eq_(ids, actual_ids)
+
+ @testing.combinations(
+ "orm",
+ "bulk",
+ argnames="insert_strategy",
+ )
+ @testing.requires.provisioned_upsert
+ def test_base_class_upsert(self, insert_strategy):
+ """upsert is really tricky. if you dont have any data updated,
+ then you dont get the rows back and things dont work so well.
+
+ so we need to be careful how much we document this because this is
+ still a thorny use case.
+
+ """
+ A = self.classes.A
+
+ s = fixture_session()
+
+ initial_data = [
+ {"data": "d3", "x": 1, "y": 2, "q": 4},
+ {"data": "d4", "x": 5, "y": 6, "q": 8},
+ ]
+ ids = s.scalars(insert(A).returning(A.id), initial_data).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "type": "a",
+ "data": "d3",
+ "x": 1,
+ "y": 2,
+ },
+ {
+ "id": 32,
+ "type": "a",
+ "data": "d32",
+ "x": 19,
+ "y": 5,
+ },
+ {
+ "id": ids[1],
+ "type": "a",
+ "data": "d4",
+ "x": 5,
+ "y": 6,
+ },
+ {
+ "id": 28,
+ "type": "a",
+ "data": "d28",
+ "x": 9,
+ "y": 15,
+ },
+ ]
+
+ stmt = provision.upsert(
+ config,
+ A,
+ (A,),
+ lambda inserted: {"data": inserted.data + " upserted"},
+ )
+
+ if insert_strategy == "orm":
+ result = s.scalars(stmt.values(upsert_data))
+ elif insert_strategy == "bulk":
+ result = s.scalars(stmt, upsert_data)
+ else:
+ assert False
+
+ eq_(
+ result.all(),
+ [
+ A(data="d3 upserted", id=ids[0], type="a", x=1, y=2),
+ A(data="d32", id=32, type="a", x=19, y=5),
+ A(data="d4 upserted", id=ids[1], type="a", x=5, y=6),
+ A(data="d28", id=28, type="a", x=9, y=15),
+ ],
+ )
+
+ @testing.combinations(
+ "orm",
+ "bulk",
+ argnames="insert_strategy",
+ )
+ @testing.requires.provisioned_upsert
+ def test_subclass_upsert(self, insert_strategy):
+ """note this is overridden in the joined version to expect failure"""
+
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ idd3 = 1
+ idd4 = 2
+ id32 = 32
+ id28 = 28
+
+ initial_data = [
+ {
+ "id": idd3,
+ "data": "d3",
+ "bd": "bd1",
+ "x": 1,
+ "y": 2,
+ "z": 3,
+ "q": 4,
+ },
+ {
+ "id": idd4,
+ "data": "d4",
+ "bd": "bd2",
+ "x": 5,
+ "y": 6,
+ "z": 7,
+ "q": 8,
+ },
+ ]
+ ids = s.scalars(insert(B).returning(B.id), initial_data).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "type": "b",
+ "data": "d3",
+ "bd": "bd1_upserted",
+ "x": 1,
+ "y": 2,
+ "z": 33,
+ "q": 44,
+ },
+ {
+ "id": id32,
+ "type": "b",
+ "data": "d32",
+ "bd": "bd 32",
+ "x": 19,
+ "y": 5,
+ "z": 20,
+ "q": 21,
+ },
+ {
+ "id": ids[1],
+ "type": "b",
+ "bd": "bd2_upserted",
+ "data": "d4",
+ "x": 5,
+ "y": 6,
+ "z": 77,
+ "q": 88,
+ },
+ {
+ "id": id28,
+ "type": "b",
+ "data": "d28",
+ "bd": "bd 28",
+ "x": 9,
+ "y": 15,
+ "z": 10,
+ "q": 11,
+ },
+ ]
+
+ stmt = provision.upsert(
+ config,
+ B,
+ (B,),
+ lambda inserted: {
+ "data": inserted.data + " upserted",
+ "bd": inserted.bd + " upserted",
+ },
+ )
+ result = s.scalars(stmt, upsert_data)
+ eq_(
+ result.all(),
+ [
+ B(
+ bd="bd1_upserted upserted",
+ data="d3 upserted",
+ id=ids[0],
+ q=4,
+ type="b",
+ x=1,
+ y=2,
+ z=3,
+ ),
+ B(
+ bd="bd 32",
+ data="d32",
+ id=32,
+ q=21,
+ type="b",
+ x=19,
+ y=5,
+ z=20,
+ ),
+ B(
+ bd="bd2_upserted upserted",
+ data="d4 upserted",
+ id=ids[1],
+ q=8,
+ type="b",
+ x=5,
+ y=6,
+ z=7,
+ ),
+ B(
+ bd="bd 28",
+ data="d28",
+ id=28,
+ q=11,
+ type="b",
+ x=9,
+ y=15,
+ z=10,
+ ),
+ ],
+ )
+
+
+class BulkDMLReturningJoinedInhTest(
+ BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+
+ __requires__ = ("insert_returning",)
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "a",
+ "polymorphic_on": "type",
+ }
+
+ class B(A):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(
+ ForeignKey("a.id"), primary_key=True
+ )
+ bd: Mapped[str]
+ z: Mapped[Optional[int]] = mapped_column("zcol")
+ q: Mapped[Optional[int]]
+
+ __mapper_args__ = {"polymorphic_identity": "b"}
+
+ @testing.combinations(
+ "orm",
+ "bulk",
+ argnames="insert_strategy",
+ )
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_param",
+ )
+ @testing.requires.provisioned_upsert
+ def test_subclass_upsert(self, insert_strategy, single_param):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ initial_data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ ids = s.scalars(insert(B).returning(B.id), initial_data).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "type": "b",
+ },
+ {
+ "id": 32,
+ "type": "b",
+ },
+ ]
+ if single_param:
+ upsert_data = upsert_data[0]
+
+ stmt = provision.upsert(
+ config,
+ B,
+ (B,),
+ lambda inserted: {
+ "bd": inserted.bd + " upserted",
+ },
+ )
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ r"bulk INSERT with a 'post values' clause \(typically upsert\) "
+ r"not supported for multi-table mapper",
+ ):
+ s.scalars(stmt, upsert_data)
+
+
+class BulkDMLReturningSingleInhTest(
+ BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+ __requires__ = ("insert_returning",)
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "a",
+ "polymorphic_on": "type",
+ }
+
+ class B(A):
+ bd: Mapped[str] = mapped_column(nullable=True)
+ z: Mapped[Optional[int]] = mapped_column("zcol")
+ q: Mapped[Optional[int]]
+
+ __mapper_args__ = {"polymorphic_identity": "b"}
+
+
+class BulkDMLReturningConcreteInhTest(
+ BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+ __requires__ = ("insert_returning",)
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "a",
+ "polymorphic_on": "type",
+ }
+
+ class B(A):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ bd: Mapped[str] = mapped_column(nullable=True)
+ z: Mapped[Optional[int]] = mapped_column("zcol")
+ q: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "b",
+ "concrete": True,
+ "polymorphic_on": "type",
+ }
+
+
+class CTETest(fixtures.DeclarativeMappedTest):
+ __requires__ = ("insert_returning", "ctes_on_dml")
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class User(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "users"
+ id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+ username: Mapped[str]
+
+ @testing.combinations(
+ ("cte_aliased", True),
+ ("cte", False),
+ argnames="wrap_cte_in_aliased",
+ id_="ia",
+ )
+ @testing.combinations(
+ ("use_union", True),
+ ("no_union", False),
+ argnames="use_a_union",
+ id_="ia",
+ )
+ @testing.combinations(
+ "from_statement", "aliased", "direct", argnames="fetch_entity_type"
+ )
+ def test_select_from_insert_cte(
+ self, wrap_cte_in_aliased, use_a_union, fetch_entity_type
+ ):
+ """test the use case from #8544; SELECT that selects from a
+ CTE INSERT...RETURNING.
+
+ """
+ User = self.classes.User
+
+ id_ = uuid.uuid4()
+
+ cte = (
+ insert(User)
+ .values(id=id_, username="some user")
+ .returning(User)
+ .cte()
+ )
+ if wrap_cte_in_aliased:
+ cte = aliased(User, cte)
+
+ if use_a_union:
+ stmt = select(User).where(User.id == id_).union(select(cte))
+ else:
+ stmt = select(cte)
+
+ if fetch_entity_type == "from_statement":
+ outer_stmt = select(User).from_statement(stmt)
+ expect_entity = True
+ elif fetch_entity_type == "aliased":
+ outer_stmt = select(aliased(User, stmt.subquery()))
+ expect_entity = True
+ elif fetch_entity_type == "direct":
+ outer_stmt = stmt
+ expect_entity = not use_a_union and wrap_cte_in_aliased
+ else:
+ assert False
+
+ sess = fixture_session()
+ with self.sql_execution_asserter() as asserter:
+
+ if not expect_entity:
+ row = sess.execute(outer_stmt).one()
+ eq_(row, (id_, "some user"))
+ else:
+ new_user = sess.scalars(outer_stmt).one()
+ eq_(new_user, User(id=id_, username="some user"))
+
+ cte_sql = (
+ "(INSERT INTO users (id, username) "
+ "VALUES (:param_1, :param_2) "
+ "RETURNING users.id, users.username)"
+ )
+
+ if fetch_entity_type == "aliased" and not use_a_union:
+ expected = (
+ f"WITH anon_2 AS {cte_sql} "
+ "SELECT anon_1.id, anon_1.username "
+ "FROM (SELECT anon_2.id AS id, anon_2.username AS username "
+ "FROM anon_2) AS anon_1"
+ )
+ elif not use_a_union:
+ expected = (
+ f"WITH anon_1 AS {cte_sql} "
+ "SELECT anon_1.id, anon_1.username FROM anon_1"
+ )
+ elif fetch_entity_type == "aliased":
+ expected = (
+ f"WITH anon_2 AS {cte_sql} SELECT anon_1.id, anon_1.username "
+ "FROM (SELECT users.id AS id, users.username AS username "
+ "FROM users WHERE users.id = :id_1 "
+ "UNION SELECT anon_2.id AS id, anon_2.username AS username "
+ "FROM anon_2) AS anon_1"
+ )
+ else:
+ expected = (
+ f"WITH anon_1 AS {cte_sql} "
+ "SELECT users.id, users.username FROM users "
+ "WHERE users.id = :id_1 "
+ "UNION SELECT anon_1.id, anon_1.username FROM anon_1"
+ )
+
+ asserter.assert_(
+ CompiledSQL(expected, [{"param_1": id_, "param_2": "some user"}])
+ )
diff --git a/test/orm/test_evaluator.py b/test/orm/dml/test_evaluator.py
index ff40cd201..4b903b863 100644
--- a/test/orm/test_evaluator.py
+++ b/test/orm/dml/test_evaluator.py
@@ -324,7 +324,6 @@ class EvaluateTest(fixtures.MappedTest):
"""test #3162"""
User = self.classes.User
-
with expect_raises_message(
evaluator.UnevaluatableError,
r"Custom operator '\^\^' can't be evaluated in "
diff --git a/test/orm/test_update_delete.py b/test/orm/dml/test_update_delete_where.py
index 1e93f88de..836feb659 100644
--- a/test/orm/test_update_delete.py
+++ b/test/orm/dml/test_update_delete_where.py
@@ -1,3 +1,4 @@
+from sqlalchemy import bindparam
from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
@@ -7,6 +8,7 @@ from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import insert
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import lambda_stmt
from sqlalchemy import MetaData
@@ -17,6 +19,7 @@ from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import update
from sqlalchemy.orm import backref
+from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
@@ -26,6 +29,7 @@ from sqlalchemy.orm import with_loader_criteria
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import in_
from sqlalchemy.testing import not_in
@@ -123,6 +127,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
},
)
+ def test_update_dont_use_col_key(self):
+ User = self.classes.User
+
+ s = fixture_session()
+
+ # make sure objects are present to synchronize
+ _ = s.query(User).all()
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "Attribute name not found, can't be synchronized back "
+ "to objects: 'age_int'",
+ ):
+ s.execute(update(User).values(age_int=5))
+
+ stmt = update(User).values(age=5)
+ s.execute(stmt)
+ eq_(s.scalars(select(User.age)).all(), [5, 5, 5, 5])
+
@testing.combinations("table", "mapper", "both", argnames="bind_type")
@testing.combinations(
"update", "insert", "delete", argnames="statement_type"
@@ -162,7 +185,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
assert_raises_message(
exc.ArgumentError,
"Valid strategies for session synchronization "
- "are 'evaluate', 'fetch', False",
+ "are 'auto', 'evaluate', 'fetch', False",
s.query(User).update,
{},
synchronize_session="fake",
@@ -351,6 +374,12 @@ class UpdateDeleteTest(fixtures.MappedTest):
def test_evaluate_dont_refresh_expired_objects(
self, expire_jane_age, add_filter_criteria
):
+ """test #5664.
+
+ approach is revised in SQLAlchemy 2.0 to not pre-emptively
+ unexpire the involved attributes
+
+ """
User = self.classes.User
sess = fixture_session()
@@ -379,15 +408,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
if add_filter_criteria:
if expire_jane_age:
asserter.assert_(
- # it has to unexpire jane.name, because jane is not fully
- # expired and the criteria needs to look at this particular
- # key
- CompiledSQL(
- "SELECT users.age_int AS users_age_int, "
- "users.name AS users_name FROM users "
- "WHERE users.id = :pk_1",
- [{"pk_1": 4}],
- ),
+ # previously, this would unexpire the attribute and
+ # cause an additional SELECT. The
+ # 2.0 approach is that if the object has expired attrs
+ # we just expire the whole thing, avoiding SQL up front
CompiledSQL(
"UPDATE users "
"SET age_int=(users.age_int + :age_int_1) "
@@ -397,14 +421,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
else:
asserter.assert_(
- # it has to unexpire jane.name, because jane is not fully
- # expired and the criteria needs to look at this particular
- # key
- CompiledSQL(
- "SELECT users.name AS users_name FROM users "
- "WHERE users.id = :pk_1",
- [{"pk_1": 4}],
- ),
+ # previously, this would unexpire the attribute and
+ # cause an additional SELECT. The
+ # 2.0 approach is that if the object has expired attrs
+ # we just expire the whole thing, avoiding SQL up front
CompiledSQL(
"UPDATE users SET "
"age_int=(users.age_int + :age_int_1) "
@@ -443,9 +463,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
),
]
- if expire_jane_age and not add_filter_criteria:
+ if expire_jane_age:
to_assert.append(
- # refresh jane
+ # refresh jane for partial attributes
CompiledSQL(
"SELECT users.age_int AS users_age_int, "
"users.name AS users_name FROM users "
@@ -455,6 +475,75 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
asserter.assert_(*to_assert)
+ @testing.combinations(True, False, argnames="is_evaluable")
+ def test_auto_synchronize(self, is_evaluable):
+ User = self.classes.User
+
+ sess = fixture_session()
+
+ john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+ if is_evaluable:
+ crit = or_(User.name == "jack", User.name == "jane")
+ else:
+ crit = case((User.name.in_(["jack", "jane"]), True), else_=False)
+
+ with self.sql_execution_asserter() as asserter:
+ sess.execute(update(User).where(crit).values(age=User.age + 10))
+
+ if is_evaluable:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+ "WHERE users.name = :name_1 OR users.name = :name_2",
+ [{"age_int_1": 10, "name_1": "jack", "name_2": "jane"}],
+ ),
+ )
+ elif testing.db.dialect.update_returning:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+ "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
+ "THEN :param_1 ELSE :param_2 END = 1 RETURNING users.id",
+ [
+ {
+ "age_int_1": 10,
+ "name_1": ["jack", "jane"],
+ "param_1": True,
+ "param_2": False,
+ }
+ ],
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT users.id FROM users WHERE CASE WHEN "
+ "(users.name IN (__[POSTCOMPILE_name_1])) "
+ "THEN :param_1 ELSE :param_2 END = 1",
+ [
+ {
+ "name_1": ["jack", "jane"],
+ "param_1": True,
+ "param_2": False,
+ }
+ ],
+ ),
+ CompiledSQL(
+ "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+ "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
+ "THEN :param_1 ELSE :param_2 END = 1",
+ [
+ {
+ "age_int_1": 10,
+ "name_1": ["jack", "jane"],
+ "param_1": True,
+ "param_2": False,
+ }
+ ],
+ ),
+ )
+
def test_fetch_dont_refresh_expired_objects(self):
User = self.classes.User
@@ -518,17 +607,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
),
)
- def test_delete(self):
+ @testing.combinations(False, None, "auto", "evaluate", "fetch")
+ def test_delete(self, synchronize_session):
User = self.classes.User
sess = fixture_session()
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
- sess.query(User).filter(
+
+ stmt = delete(User).filter(
or_(User.name == "john", User.name == "jill")
- ).delete()
+ )
+ if synchronize_session is not None:
+ stmt = stmt.execution_options(
+ synchronize_session=synchronize_session
+ )
+ sess.execute(stmt)
- assert john not in sess and jill not in sess
+ if synchronize_session not in (False, None):
+ assert john not in sess and jill not in sess
eq_(sess.query(User).order_by(User.id).all(), [jack, jane])
@@ -629,6 +726,33 @@ class UpdateDeleteTest(fixtures.MappedTest):
eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
+ def test_update_multirow_not_supported(self):
+ User = self.classes.User
+
+ sess = fixture_session()
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "WHERE clause with bulk ORM UPDATE not supported " "right now.",
+ ):
+ sess.execute(
+ update(User).where(User.id == bindparam("id")),
+ [{"id": 1, "age": 27}, {"id": 2, "age": 37}],
+ )
+
+ def test_delete_bulk_not_supported(self):
+ User = self.classes.User
+
+ sess = fixture_session()
+
+ with expect_raises_message(
+ exc.InvalidRequestError, "Bulk ORM DELETE not supported right now."
+ ):
+ sess.execute(
+ delete(User),
+ [{"id": 1}, {"id": 2}],
+ )
+
def test_update(self):
User, users = self.classes.User, self.tables.users
@@ -640,6 +764,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+
eq_(
sess.query(User.age).order_by(User.id).all(),
list(zip([25, 37, 29, 27])),
@@ -974,7 +1099,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
@testing.requires.update_returning
- def test_update_explicit_returning(self):
+ def test_update_evaluate_w_explicit_returning(self):
User = self.classes.User
sess = fixture_session()
@@ -987,6 +1112,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
.filter(User.age > 29)
.values({"age": User.age - 10})
.returning(User.id)
+ .execution_options(synchronize_session="evaluate")
)
rows = sess.execute(stmt).all()
@@ -1006,24 +1132,41 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
@testing.requires.update_returning
- def test_no_fetch_w_explicit_returning(self):
+ @testing.combinations("update", "delete", argnames="crud_type")
+ def test_fetch_w_explicit_returning(self, crud_type):
User = self.classes.User
sess = fixture_session()
- stmt = (
- update(User)
- .filter(User.age > 29)
- .values({"age": User.age - 10})
- .execution_options(synchronize_session="fetch")
- .returning(User.id)
- )
- with expect_raises_message(
- exc.InvalidRequestError,
- r"Can't use synchronize_session='fetch' "
- r"with explicit returning\(\)",
- ):
- sess.execute(stmt)
+ if crud_type == "update":
+ stmt = (
+ update(User)
+ .filter(User.age > 29)
+ .values({"age": User.age - 10})
+ .execution_options(synchronize_session="fetch")
+ .returning(User, User.name)
+ )
+ expected = [
+ (User(age=37), "jack"),
+ (User(age=27), "jane"),
+ ]
+ elif crud_type == "delete":
+ stmt = (
+ delete(User)
+ .filter(User.age > 29)
+ .execution_options(synchronize_session="fetch")
+ .returning(User, User.name)
+ )
+ expected = [
+ (User(age=47), "jack"),
+ (User(age=37), "jane"),
+ ]
+ else:
+ assert False
+
+ result = sess.execute(stmt)
+
+ eq_(result.all(), expected)
@testing.combinations(True, False, argnames="implicit_returning")
def test_delete_fetch_returning(self, implicit_returning):
@@ -1142,7 +1285,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
list(zip([25, 47, 44, 37])),
)
- def test_update_changes_resets_dirty(self):
+ @testing.combinations("orm", "bulk")
+ def test_update_changes_resets_dirty(self, update_type):
User = self.classes.User
sess = fixture_session(autoflush=False)
@@ -1155,9 +1299,30 @@ class UpdateDeleteTest(fixtures.MappedTest):
# autoflush is false. therefore our '50' and '37' are getting
# blown away by this operation.
- sess.query(User).filter(User.age > 29).update(
- {"age": User.age - 10}, synchronize_session="evaluate"
- )
+ if update_type == "orm":
+ sess.execute(
+ update(User)
+ .filter(User.age > 29)
+ .values({"age": User.age - 10}),
+ execution_options=dict(synchronize_session="evaluate"),
+ )
+ elif update_type == "bulk":
+
+ data = [
+ {"id": john.id, "age": 25},
+ {"id": jack.id, "age": 37},
+ {"id": jill.id, "age": 29},
+ {"id": jane.id, "age": 27},
+ ]
+
+ sess.execute(
+ update(User),
+ data,
+ execution_options=dict(synchronize_session="evaluate"),
+ )
+
+ else:
+ assert False
for x in (john, jack, jill, jane):
assert not sess.is_modified(x)
@@ -1171,6 +1336,93 @@ class UpdateDeleteTest(fixtures.MappedTest):
assert not sess.is_modified(john)
assert not sess.is_modified(jack)
+ @testing.combinations(
+ None, False, "evaluate", "fetch", argnames="synchronize_session"
+ )
+ @testing.combinations(True, False, argnames="homogeneous_keys")
+ def test_bulk_update_synchronize_session(
+ self, synchronize_session, homogeneous_keys
+ ):
+ User = self.classes.User
+
+ sess = fixture_session(expire_on_commit=False)
+
+ john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+ if homogeneous_keys:
+ data = [
+ {"id": john.id, "age": 35},
+ {"id": jack.id, "age": 27},
+ {"id": jill.id, "age": 30},
+ ]
+ else:
+ data = [
+ {"id": john.id, "age": 35},
+ {"id": jack.id, "name": "new jack"},
+ {"id": jill.id, "age": 30, "name": "new jill"},
+ ]
+
+ with self.sql_execution_asserter() as asserter:
+ if synchronize_session is not None:
+ opts = {"synchronize_session": synchronize_session}
+ else:
+ opts = {}
+
+ if synchronize_session == "fetch":
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "The 'fetch' synchronization strategy is not available "
+ "for 'bulk' ORM updates",
+ ):
+ sess.execute(update(User), data, execution_options=opts)
+ return
+ else:
+ sess.execute(update(User), data, execution_options=opts)
+
+ if homogeneous_keys:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=:age_int "
+ "WHERE users.id = :users_id",
+ [
+ {"age_int": 35, "users_id": 1},
+ {"age_int": 27, "users_id": 2},
+ {"age_int": 30, "users_id": 3},
+ ],
+ )
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=:age_int "
+ "WHERE users.id = :users_id",
+ [{"age_int": 35, "users_id": 1}],
+ ),
+ CompiledSQL(
+ "UPDATE users SET name=:name WHERE users.id = :users_id",
+ [{"name": "new jack", "users_id": 2}],
+ ),
+ CompiledSQL(
+ "UPDATE users SET name=:name, age_int=:age_int "
+ "WHERE users.id = :users_id",
+ [{"name": "new jill", "age_int": 30, "users_id": 3}],
+ ),
+ )
+
+ if synchronize_session is False:
+ eq_(jill.name, "jill")
+ eq_(jack.name, "jack")
+ eq_(jill.age, 29)
+ eq_(jack.age, 47)
+ else:
+ if not homogeneous_keys:
+ eq_(jill.name, "new jill")
+ eq_(jack.name, "new jack")
+ eq_(jack.age, 47)
+ else:
+ eq_(jack.age, 27)
+ eq_(jill.age, 30)
+
def test_update_changes_with_autoflush(self):
User = self.classes.User
@@ -1214,7 +1466,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
@testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
- def test_update_returns_rowcount(self):
+ @testing.combinations("auto", "fetch", "evaluate")
+ def test_update_returns_rowcount(self, synchronize_session):
User = self.classes.User
sess = fixture_session()
@@ -1222,20 +1475,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
rowcount = (
sess.query(User)
.filter(User.age > 29)
- .update({"age": User.age + 0})
+ .update(
+ {"age": User.age + 0}, synchronize_session=synchronize_session
+ )
)
eq_(rowcount, 2)
rowcount = (
sess.query(User)
.filter(User.age > 29)
- .update({"age": User.age - 10})
+ .update(
+ {"age": User.age - 10}, synchronize_session=synchronize_session
+ )
)
eq_(rowcount, 2)
# test future
result = sess.execute(
- update(User).where(User.age > 19).values({"age": User.age - 10})
+ update(User).where(User.age > 19).values({"age": User.age - 10}),
+ execution_options={"synchronize_session": synchronize_session},
)
eq_(result.rowcount, 4)
@@ -1327,12 +1585,17 @@ class UpdateDeleteTest(fixtures.MappedTest):
)
assert john not in sess
- def test_evaluate_before_update(self):
+ @testing.combinations(True, False)
+ def test_evaluate_before_update(self, full_expiration):
User = self.classes.User
sess = fixture_session()
john = sess.query(User).filter_by(name="john").one()
- sess.expire(john, ["age"])
+
+ if full_expiration:
+ sess.expire(john)
+ else:
+ sess.expire(john, ["age"])
# eval must be before the update. otherwise
# we eval john, age has been expired and doesn't
@@ -1356,17 +1619,47 @@ class UpdateDeleteTest(fixtures.MappedTest):
eq_(john.name, "j2")
eq_(john.age, 40)
- def test_evaluate_before_delete(self):
+ @testing.combinations(True, False)
+ def test_evaluate_before_delete(self, full_expiration):
User = self.classes.User
sess = fixture_session()
john = sess.query(User).filter_by(name="john").one()
- sess.expire(john, ["age"])
+ jill = sess.query(User).filter_by(name="jill").one()
+ jane = sess.query(User).filter_by(name="jane").one()
- sess.query(User).filter_by(name="john").filter_by(age=25).delete(
+ if full_expiration:
+ sess.expire(jill)
+ sess.expire(john)
+ else:
+ sess.expire(jill, ["age"])
+ sess.expire(john, ["age"])
+
+ sess.query(User).filter(or_(User.age == 25, User.age == 37)).delete(
synchronize_session="evaluate"
)
- assert john not in sess
+
+ # was fully deleted
+ assert jane not in sess
+
+ # deleted object was expired, but not otherwise affected
+ assert jill in sess
+
+ # deleted object was expired, but not otherwise affected
+ assert john in sess
+
+ # partially expired row fully expired
+ assert inspect(jill).expired
+
+ # non-deleted row still present
+ eq_(jill.age, 29)
+
+ # partially expired row fully expired
+ assert inspect(john).expired
+
+ # is deleted
+ with expect_raises(orm_exc.ObjectDeletedError):
+ john.name
def test_fetch_before_delete(self):
User = self.classes.User
@@ -1378,6 +1671,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
sess.query(User).filter_by(name="john").filter_by(age=25).delete(
synchronize_session="fetch"
)
+
assert john not in sess
def test_update_unordered_dict(self):
@@ -1495,6 +1789,60 @@ class UpdateDeleteTest(fixtures.MappedTest):
]
eq_(["name", "age_int"], cols)
+ @testing.requires.sqlite
+ def test_sharding_extension_returning_mismatch(self, testing_engine):
+ """test one horizontal shard case where the given binds don't match
+ for RETURNING support; we don't support this.
+
+ See test/ext/test_horizontal_shard.py for complete round trip
+ test cases for ORM update/delete
+
+ """
+ e1 = testing_engine("sqlite://")
+ e2 = testing_engine("sqlite://")
+ e1.connect().close()
+ e2.connect().close()
+
+ e1.dialect.update_returning = True
+ e2.dialect.update_returning = False
+
+ engines = [e1, e2]
+
+ # a simulated version of the horizontal sharding extension
+ def execute_and_instances(orm_context):
+ execution_options = dict(orm_context.local_execution_options)
+ partial = []
+ for engine in engines:
+ bind_arguments = dict(orm_context.bind_arguments)
+ bind_arguments["bind"] = engine
+ result_ = orm_context.invoke_statement(
+ bind_arguments=bind_arguments,
+ execution_options=execution_options,
+ )
+
+ partial.append(result_)
+ return partial[0].merge(*partial[1:])
+
+ User = self.classes.User
+ session = Session()
+
+ event.listen(
+ session, "do_orm_execute", execute_and_instances, retval=True
+ )
+
+ stmt = (
+ update(User)
+ .filter(User.id == 15)
+ .values(age=123)
+ .execution_options(synchronize_session="fetch")
+ )
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "For synchronize_session='fetch', can't mix multiple backends "
+ "where some support RETURNING and others don't",
+ ):
+ session.execute(stmt)
+
class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
@classmethod
@@ -1748,6 +2096,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
"Could not evaluate current criteria in Python.",
q.update,
{"samename": "ed"},
+ synchronize_session="evaluate",
)
@testing.requires.multi_table_update
@@ -1901,7 +2250,7 @@ class ExpressionUpdateTest(fixtures.MappedTest):
sess.commit()
eq_(d1.cnt, 0)
- sess.query(Data).update({Data.cnt: Data.cnt + 1})
+ sess.query(Data).update({Data.cnt: Data.cnt + 1}, "evaluate")
sess.flush()
eq_(d1.cnt, 1)
@@ -2443,7 +2792,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
)
@testing.requires.update_returning
- def test_load_from_update(self, connection):
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_update(self, connection, use_from_statement):
User = self.classes.User
stmt = (
@@ -2453,7 +2803,16 @@ class LoadFromReturningTest(fixtures.MappedTest):
.returning(User)
)
- stmt = select(User).from_statement(stmt)
+ if use_from_statement:
+ # this is now a legacy-ish case, because as of 2.0 you can just
+ # use returning() directly to get the objects back.
+ #
+ # when from_statement is used, the UPDATE statement is no
+ # longer interpreted by
+ # BulkUDCompileState.orm_pre_session_exec or
+ # BulkUDCompileState.orm_setup_cursor_result. The compilation
+ # level routines still take place though
+ stmt = select(User).from_statement(stmt)
with Session(connection) as sess:
rows = sess.execute(stmt).scalars().all()
@@ -2468,7 +2827,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
("multiple", testing.requires.multivalues_inserts),
argnames="params",
)
- def test_load_from_insert(self, connection, params):
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_insert(self, connection, params, use_from_statement):
User = self.classes.User
if params == "multiple":
@@ -2484,7 +2844,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
stmt = insert(User).values(values).returning(User)
- stmt = select(User).from_statement(stmt)
+ if use_from_statement:
+ stmt = select(User).from_statement(stmt)
with Session(connection) as sess:
rows = sess.execute(stmt).scalars().all()
@@ -2505,3 +2866,25 @@ class LoadFromReturningTest(fixtures.MappedTest):
)
else:
assert False
+
+ @testing.requires.delete_returning
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_delete(self, connection, use_from_statement):
+ User = self.classes.User
+
+ stmt = (
+ delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
+ )
+
+ if use_from_statement:
+ stmt = select(User).from_statement(stmt)
+
+ with Session(connection) as sess:
+ rows = sess.execute(stmt).scalars().all()
+
+ eq_(
+ rows,
+ [User(name="jack", age=47), User(name="jill", age=29)],
+ )
+
+ # TODO: state of above objects should be "deleted"
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
index 2e3874549..5f8cfc1f5 100644
--- a/test/orm/inheritance/test_basic.py
+++ b/test/orm/inheritance/test_basic.py
@@ -2012,7 +2012,8 @@ class JoinedNoFKSortingTest(fixtures.MappedTest):
and testing.db.dialect.supports_default_metavalue,
[
CompiledSQL(
- "INSERT INTO a (id) VALUES (DEFAULT)", [{}, {}, {}, {}]
+ "INSERT INTO a (id) VALUES (DEFAULT) RETURNING a.id",
+ [{}, {}, {}, {}],
),
],
[
diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py
index a6480365d..2f392cf6e 100644
--- a/test/orm/test_bind.py
+++ b/test/orm/test_bind.py
@@ -326,6 +326,7 @@ class BindIntegrationTest(_fixtures.FixtureTest):
),
(
lambda User: update(User)
+ .execution_options(synchronize_session=False)
.values(name="not ed")
.where(User.name == "ed"),
lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
@@ -392,7 +393,15 @@ class BindIntegrationTest(_fixtures.FixtureTest):
engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name]
with mock.patch(
- "sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result"
+ "sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result"
+ ), mock.patch(
+ "sqlalchemy.orm.context.ORMCompileState.orm_execute_statement"
+ ), mock.patch(
+ "sqlalchemy.orm.bulk_persistence."
+ "BulkORMInsert.orm_execute_statement"
+ ), mock.patch(
+ "sqlalchemy.orm.bulk_persistence."
+ "BulkUDCompileState.orm_setup_cursor_result"
):
sess.execute(statement)
diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py
index 3a789aff7..efa2ecb45 100644
--- a/test/orm/test_composites.py
+++ b/test/orm/test_composites.py
@@ -1,8 +1,10 @@
import dataclasses
import operator
+import random
import sqlalchemy as sa
from sqlalchemy import ForeignKey
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
@@ -233,7 +235,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
is g.edges[1]
)
- def test_bulk_update_sql(self):
+ def test_update_crit_sql(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
@@ -258,7 +260,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
dialect="default",
)
- def test_bulk_update_evaluate(self):
+ def test_update_crit_evaluate(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
@@ -287,7 +289,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
eq_(e1.end, Point(17, 8))
- def test_bulk_update_fetch(self):
+ def test_update_crit_fetch(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
@@ -305,6 +307,205 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
eq_(e1.end, Point(17, 8))
+ @testing.combinations(
+ "legacy", "statement", "values", "stmt_returning", "values_returning"
+ )
+ def test_bulk_insert(self, type_):
+ Edge, Point = (self.classes.Edge, self.classes.Point)
+ Graph = self.classes.Graph
+
+ sess = self._fixture()
+
+ graph = Graph(id=2)
+ sess.add(graph)
+ sess.flush()
+ graph_id = 2
+
+ data = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "end": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(25)
+ ]
+ returning = False
+ if type_ == "statement":
+ sess.execute(insert(Edge), data)
+ elif type_ == "stmt_returning":
+ result = sess.scalars(insert(Edge).returning(Edge), data)
+ returning = True
+ elif type_ == "values":
+ sess.execute(insert(Edge).values(data))
+ elif type_ == "values_returning":
+ result = sess.scalars(insert(Edge).values(data).returning(Edge))
+ returning = True
+ elif type_ == "legacy":
+ sess.bulk_insert_mappings(Edge, data)
+ else:
+ assert False
+
+ if returning:
+ eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data])
+
+ edges = self.tables.edges
+ eq_(
+ sess.execute(
+ select(edges.c["x1", "y1", "x2", "y2"])
+ .where(edges.c.graph_id == graph_id)
+ .order_by(edges.c.id)
+ ).all(),
+ [
+ (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
+ for e in data
+ ],
+ )
+
+ @testing.combinations("legacy", "statement")
+ def test_bulk_insert_heterogeneous(self, type_):
+ Edge, Point = (self.classes.Edge, self.classes.Point)
+ Graph = self.classes.Graph
+
+ sess = self._fixture()
+
+ graph = Graph(id=2)
+ sess.add(graph)
+ sess.flush()
+ graph_id = 2
+
+ d1 = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "end": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(3)
+ ]
+ d2 = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(2)
+ ]
+ d3 = [
+ {
+ "x2": random.randint(1, 50),
+ "y2": random.randint(1, 50),
+ "graph_id": graph_id,
+ }
+ for i in range(2)
+ ]
+ data = d1 + d2 + d3
+ random.shuffle(data)
+
+ assert_data = [
+ {
+ "start": d["start"] if "start" in d else None,
+ "end": d["end"]
+ if "end" in d
+ else Point(d["x2"], d["y2"])
+ if "x2" in d
+ else None,
+ "graph_id": d["graph_id"],
+ }
+ for d in data
+ ]
+
+ if type_ == "statement":
+ sess.execute(insert(Edge), data)
+ elif type_ == "legacy":
+ sess.bulk_insert_mappings(Edge, data)
+ else:
+ assert False
+
+ edges = self.tables.edges
+ eq_(
+ sess.execute(
+ select(edges.c["x1", "y1", "x2", "y2"])
+ .where(edges.c.graph_id == graph_id)
+ .order_by(edges.c.id)
+ ).all(),
+ [
+ (
+ e["start"].x if e["start"] else None,
+ e["start"].y if e["start"] else None,
+ e["end"].x if e["end"] else None,
+ e["end"].y if e["end"] else None,
+ )
+ for e in assert_data
+ ],
+ )
+
+ @testing.combinations("legacy", "statement")
+ def test_bulk_update(self, type_):
+ Edge, Point = (self.classes.Edge, self.classes.Point)
+ Graph = self.classes.Graph
+
+ sess = self._fixture()
+
+ graph = Graph(id=2)
+ sess.add(graph)
+ sess.flush()
+ graph_id = 2
+
+ data = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "end": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(25)
+ ]
+ sess.execute(insert(Edge), data)
+
+ inserted_data = [
+ dict(row._mapping)
+ for row in sess.execute(
+ select(Edge.id, Edge.start, Edge.end, Edge.graph_id)
+ .where(Edge.graph_id == graph_id)
+ .order_by(Edge.id)
+ )
+ ]
+
+ to_update = []
+ updated_pks = {}
+ for rec in random.choices(inserted_data, k=7):
+ rec_copy = dict(rec)
+ updated_pks[rec_copy["id"]] = rec_copy
+ rec_copy["start"] = Point(
+ random.randint(1, 50), random.randint(1, 50)
+ )
+ rec_copy["end"] = Point(
+ random.randint(1, 50), random.randint(1, 50)
+ )
+ to_update.append(rec_copy)
+
+ expected_dataset = [
+ updated_pks[row["id"]] if row["id"] in updated_pks else row
+ for row in inserted_data
+ ]
+
+ if type_ == "statement":
+ sess.execute(update(Edge), to_update)
+ elif type_ == "legacy":
+ sess.bulk_update_mappings(Edge, to_update)
+ else:
+ assert False
+
+ edges = self.tables.edges
+ eq_(
+ sess.execute(
+ select(edges.c["x1", "y1", "x2", "y2"])
+ .where(edges.c.graph_id == graph_id)
+ .order_by(edges.c.id)
+ ).all(),
+ [
+ (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
+ for e in expected_dataset
+ ],
+ )
+
def test_get_history(self):
Edge = self.classes.Edge
Point = self.classes.Point
diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py
index a23d8b735..7f0f504b5 100644
--- a/test/orm/test_cycles.py
+++ b/test/orm/test_cycles.py
@@ -1122,7 +1122,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest):
[
CompiledSQL(
"INSERT INTO ball (person_id, data) "
- "VALUES (:person_id, :data)",
+ "VALUES (:person_id, :data) RETURNING ball.id",
[
{"person_id": None, "data": "some data"},
{"person_id": None, "data": "some data"},
diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py
index 7860f5eb1..e738689b8 100644
--- a/test/orm/test_defaults.py
+++ b/test/orm/test_defaults.py
@@ -383,20 +383,24 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest):
CompiledSQL(
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
[{"foo": 5, "test_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
[{"foo": 6, "test_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.bar AS test_bar FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 1}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.bar AS test_bar FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 2}],
+ enable_returning=False,
),
)
else:
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index 24870e20f..75955afb5 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -661,8 +661,17 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
canary = self._flag_fixture(sess)
- sess.execute(delete(User).filter_by(id=18))
- sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
+ sess.execute(
+ delete(User)
+ .filter_by(id=18)
+ .execution_options(synchronize_session="evaluate")
+ )
+ sess.execute(
+ update(User)
+ .filter_by(id=18)
+ .values(name="eighteen")
+ .execution_options(synchronize_session="evaluate")
+ )
eq_(
canary.mock_calls,
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index b94998716..fc452dc9c 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -2868,12 +2868,14 @@ class SaveTest2(_fixtures.FixtureTest):
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
- "INSERT INTO users (name) VALUES (:name)",
+ "INSERT INTO users (name) VALUES (:name) "
+ "RETURNING users.id",
[{"name": "u1"}, {"name": "u2"}],
),
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address)",
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
[
{"user_id": 1, "email_address": "a1"},
{"user_id": 2, "email_address": "a2"},
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index dd3b88915..855b44e81 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -98,7 +98,8 @@ class RudimentaryFlushTest(UOWTest):
[
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address)",
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
lambda ctx: [
{"email_address": "a1", "user_id": u1.id},
{"email_address": "a2", "user_id": u1.id},
@@ -220,7 +221,8 @@ class RudimentaryFlushTest(UOWTest):
[
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address)",
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
lambda ctx: [
{"email_address": "a1", "user_id": u1.id},
{"email_address": "a2", "user_id": u1.id},
@@ -889,7 +891,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n2"},
{"parent_id": n1.id, "data": "n3"},
@@ -1003,7 +1005,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n2"},
{"parent_id": n1.id, "data": "n3"},
@@ -1165,7 +1167,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n11"},
{"parent_id": n1.id, "data": "n12"},
@@ -1196,7 +1198,7 @@ class SingleCycleTest(UOWTest):
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n12.id, "data": "n121"},
{"parent_id": n12.id, "data": "n122"},
@@ -2099,7 +2101,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
- "INSERT INTO t (data) VALUES (:data)",
+ "INSERT INTO t (data) VALUES (:data) RETURNING t.id",
[{"data": "t1"}, {"data": "t2"}],
),
],
@@ -2472,20 +2474,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
[{"id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
[{"id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.foo AS test_foo FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 1}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.foo AS test_foo FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 2}],
+ enable_returning=False,
),
)
@@ -2678,20 +2684,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 5, "test2_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 6, "bar": 10, "test2_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 7, "test2_id": 3}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 8, "bar": 12, "test2_id": 4}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test2.bar AS test2_bar FROM test2 "
@@ -2772,31 +2782,37 @@ class EagerDefaultsTest(fixtures.MappedTest):
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
"WHERE test4.id = :test4_id",
[{"foo": 5, "test4_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=:bar "
"WHERE test4.id = :test4_id",
[{"foo": 6, "bar": 10, "test4_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
"WHERE test4.id = :test4_id",
[{"foo": 7, "test4_id": 3}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=:bar "
"WHERE test4.id = :test4_id",
[{"foo": 8, "bar": 12, "test4_id": 4}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test4.bar AS test4_bar FROM test4 "
"WHERE test4.id = :pk_1",
[{"pk_1": 1}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test4.bar AS test4_bar FROM test4 "
"WHERE test4.id = :pk_1",
[{"pk_1": 3}],
+ enable_returning=False,
),
],
),
@@ -2871,20 +2887,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
"UPDATE test2 SET foo=:foo, bar=1 + 1 "
"WHERE test2.id = :test2_id",
[{"foo": 5, "test2_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 6, "bar": 10, "test2_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 7, "test2_id": 3}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=5 + 7 "
"WHERE test2.id = :test2_id",
[{"foo": 8, "test2_id": 4}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test2.bar AS test2_bar FROM test2 "
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index abd5833be..84e5a83b0 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -1424,12 +1424,10 @@ class ServerVersioningTest(fixtures.MappedTest):
sess.add(f1)
statements = [
- # note that the assertsql tests the rule against
- # "default" - on a "returning" backend, the statement
- # includes "RETURNING"
CompiledSQL(
"INSERT INTO version_table (version_id, value) "
- "VALUES (1, :value)",
+ "VALUES (1, :value) "
+ "RETURNING version_table.id, version_table.version_id",
lambda ctx: [{"value": "f1"}],
)
]
@@ -1493,6 +1491,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f2",
}
],
+ enable_returning=False,
),
CompiledSQL(
"SELECT version_table.version_id "
@@ -1618,6 +1617,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f1a",
}
],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE version_table SET version_id=2, value=:value "
@@ -1630,6 +1630,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f2a",
}
],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE version_table SET version_id=2, value=:value "
@@ -1642,6 +1643,7 @@ class ServerVersioningTest(fixtures.MappedTest):
"value": "f3a",
}
],
+ enable_returning=False,
),
CompiledSQL(
"SELECT version_table.version_id "
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 42cf31bf5..4f776e300 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -100,10 +100,55 @@ class CursorResultTest(fixtures.TablesTest):
Table(
"test",
metadata,
- Column("x", Integer, primary_key=True),
+ Column(
+ "x", Integer, primary_key=True, test_needs_autoincrement=False
+ ),
Column("y", String(50)),
)
+ @testing.requires.insert_returning
+ def test_splice_horizontally(self, connection):
+ users = self.tables.users
+ addresses = self.tables.addresses
+
+ r1 = connection.execute(
+ users.insert().returning(users.c.user_name, users.c.user_id),
+ [
+ dict(user_id=1, user_name="john"),
+ dict(user_id=2, user_name="jack"),
+ ],
+ )
+
+ r2 = connection.execute(
+ addresses.insert().returning(
+ addresses.c.address_id,
+ addresses.c.address,
+ addresses.c.user_id,
+ ),
+ [
+ dict(address_id=1, user_id=1, address="foo@bar.com"),
+ dict(address_id=2, user_id=2, address="bar@bat.com"),
+ ],
+ )
+
+ rows = r1.splice_horizontally(r2).all()
+ eq_(
+ rows,
+ [
+ ("john", 1, 1, "foo@bar.com", 1),
+ ("jack", 2, 2, "bar@bat.com", 2),
+ ],
+ )
+
+ eq_(rows[0]._mapping[users.c.user_id], 1)
+ eq_(rows[0]._mapping[addresses.c.user_id], 1)
+ eq_(rows[1].address, "bar@bat.com")
+
+ with expect_raises_message(
+ exc.InvalidRequestError, "Ambiguous column name 'user_id'"
+ ):
+ rows[0].user_id
+
def test_keys_no_rows(self, connection):
for i in range(2):
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index f8cc32517..c26f825c2 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -23,6 +23,7 @@ from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing import provision
from sqlalchemy.testing.schema import Column
@@ -76,6 +77,7 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
stmt = stmt.returning(t.c.x)
stmt = stmt.return_defaults()
+
assert_raises_message(
sa_exc.CompileError,
r"Can't compile statement that includes returning\(\) "
@@ -330,6 +332,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
table = self.tables.returning_tbl
exprs = testing.resolve_lambda(testcase, table=table)
+
result = connection.execute(
table.insert().returning(*exprs),
{"persons": 5, "full": False, "strval": "str1"},
@@ -679,6 +682,30 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
Column("upddef", Integer, onupdate=IncDefault()),
)
+ Table(
+ "table_no_addtl_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ class MyType(TypeDecorator):
+ impl = String(50)
+
+ def process_result_value(self, value, dialect):
+ return f"PROCESSED! {value}"
+
+ Table(
+ "table_datatype_has_result_proc",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", MyType()),
+ )
+
def test_chained_insert_pk(self, connection):
t1 = self.tables.t1
result = connection.execute(
@@ -758,6 +785,38 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
)
eq_(result.inserted_primary_key, (1,))
+ def test_insert_w_defaults_supplemental_cols(self, connection):
+ t1 = self.tables.t1
+ result = connection.execute(
+ t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
+ {"data": "d1"},
+ )
+ eq_(result.all(), [(1, 0, None)])
+
+ def test_insert_w_no_defaults_supplemental_cols(self, connection):
+ t1 = self.tables.table_no_addtl_defaults
+ result = connection.execute(
+ t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
+ {"data": "d1"},
+ )
+ eq_(result.all(), [(1,)])
+
+ def test_insert_w_defaults_supplemental_processor_cols(self, connection):
+ """test that the cursor._rewind() used by supplemental RETURNING
+ clears out result-row processors as we will have already processed
+ the rows.
+
+ """
+
+ t1 = self.tables.table_datatype_has_result_proc
+ result = connection.execute(
+ t1.insert().return_defaults(
+ supplemental_cols=[t1.c.id, t1.c.data]
+ ),
+ {"data": "d1"},
+ )
+ eq_(result.all(), [(1, "PROCESSED! d1")])
+
class UpdatedReturnDefaultsTest(fixtures.TablesTest):
__requires__ = ("update_returning",)
@@ -792,6 +851,7 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
+
result = connection.execute(
t1.update().values(upddef=2).return_defaults(t1.c.data)
)
@@ -800,6 +860,72 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
[None],
)
+ def test_update_values_col_is_excluded(self, connection):
+ """columns that are in values() are not returned"""
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+
+ result = connection.execute(
+ t1.update().values(data="x", upddef=2).return_defaults(t1.c.data)
+ )
+ is_(result.returned_defaults, None)
+
+ result = connection.execute(
+ t1.update()
+ .values(data="x", upddef=2)
+ .return_defaults(t1.c.data, t1.c.id)
+ )
+ eq_(result.returned_defaults, (1,))
+
+ def test_update_supplemental_cols(self, connection):
+ """with supplemental_cols, we can get back arbitrary cols."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.update()
+ .values(data="x", insdef=3)
+ .return_defaults(supplemental_cols=[t1.c.data, t1.c.insdef])
+ )
+
+ row = result.returned_defaults
+
+ # row has all the cols in it
+ eq_(row, ("x", 3, 1))
+ eq_(row._mapping[t1.c.upddef], 1)
+ eq_(row._mapping[t1.c.insdef], 3)
+
+ # result is rewound
+ # but has both return_defaults + supplemental_cols
+ eq_(result.all(), [("x", 3, 1)])
+
+ def test_update_expl_return_defaults_plus_supplemental_cols(
+ self, connection
+ ):
+ """with supplemental_cols, we can get back arbitrary cols."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.update()
+ .values(data="x", insdef=3)
+ .return_defaults(
+ t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
+ )
+ )
+
+ row = result.returned_defaults
+
+ # row has all the cols in it
+ eq_(row, (1, "x", 3))
+ eq_(row._mapping[t1.c.id], 1)
+ eq_(row._mapping[t1.c.insdef], 3)
+ assert t1.c.upddef not in row._mapping
+
+ # result is rewound
+ # but has both return_defaults + supplemental_cols
+ eq_(result.all(), [(1, "x", 3)])
+
def test_update_sql_expr(self, connection):
from sqlalchemy import literal
@@ -833,6 +959,75 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
eq_(dict(result.returned_defaults._mapping), {"upddef": 1})
+class DeleteReturnDefaultsTest(fixtures.TablesTest):
+ __requires__ = ("delete_returning",)
+ run_define_tables = "each"
+ __backend__ = True
+
+ define_tables = InsertReturnDefaultsTest.define_tables
+
+ def test_delete(self, connection):
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(t1.delete().return_defaults(t1.c.upddef))
+ eq_(
+ [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
+ )
+
+ def test_delete_empty_return_defaults(self, connection):
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=5))
+ result = connection.execute(t1.delete().return_defaults())
+
+ # there's no "delete" default, so we get None. We have to
+ # ask for them in all cases
+ eq_(result.returned_defaults, None)
+
+ def test_delete_non_default(self, connection):
+ """test that a column not marked at all as a
+ default works with this feature."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(t1.delete().return_defaults(t1.c.data))
+ eq_(
+ [result.returned_defaults._mapping[k] for k in (t1.c.data,)],
+ [None],
+ )
+
+ def test_delete_non_default_plus_default(self, connection):
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.delete().return_defaults(t1.c.data, t1.c.upddef)
+ )
+ eq_(
+ dict(result.returned_defaults._mapping),
+ {"data": None, "upddef": 1},
+ )
+
+ def test_delete_supplemental_cols(self, connection):
+ """with supplemental_cols, we can get back arbitrary cols."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.delete().return_defaults(
+ t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
+ )
+ )
+
+ row = result.returned_defaults
+
+ # row has all the cols in it
+ eq_(row, (1, None, 0))
+ eq_(row._mapping[t1.c.insdef], 0)
+
+ # result is rewound
+ # but has both return_defaults + supplemental_cols
+ eq_(result.all(), [(1, None, 0)])
+
+
class InsertManyReturnDefaultsTest(fixtures.TablesTest):
__requires__ = ("insert_executemany_returning",)
run_define_tables = "each"
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 64ff2e421..5ef927b15 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -44,6 +44,7 @@ from sqlalchemy.sql import operators
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import visitors
+from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
@@ -3029,6 +3030,26 @@ class AnnotationsTest(fixtures.TestBase):
eq_(whereclause.left._annotations, {"foo": "bar"})
eq_(whereclause.right._annotations, {"foo": "bar"})
+ @testing.combinations(True, False, None)
+ def test_setup_inherit_cache(self, inherit_cache_value):
+ if inherit_cache_value is None:
+
+ class MyInsertThing(Insert):
+ pass
+
+ else:
+
+ class MyInsertThing(Insert):
+ inherit_cache = inherit_cache_value
+
+ t = table("t", column("x"))
+ anno = MyInsertThing(t)._annotate({"foo": "bar"})
+
+ if inherit_cache_value is not None:
+ is_(type(anno).__dict__["inherit_cache"], inherit_cache_value)
+ else:
+ assert "inherit_cache" not in type(anno).__dict__
+
def test_proxy_set_iteration_includes_annotated(self):
from sqlalchemy.schema import Column