diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-09-26 01:17:44 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-09-26 01:17:44 +0000 |
| commit | 6201b4d88666983b883b96d22a159aa2594de94b (patch) | |
| tree | 4036c155ca7c274ea4bd12c059fd8fcd277fc026 /test | |
| parent | f81fdd9a9008a6517f89f2115765b7db9a32721b (diff) | |
| parent | a8029f5a7e3e376ec57f1614ab0294b717d53c05 (diff) | |
| download | sqlalchemy-6201b4d88666983b883b96d22a159aa2594de94b.tar.gz | |
Merge "ORM bulk insert via execute" into main
Diffstat (limited to 'test')
| -rw-r--r-- | test/ext/test_horizontal_shard.py | 212 | ||||
| -rw-r--r-- | test/ext/test_hybrid.py | 35 | ||||
| -rw-r--r-- | test/orm/dml/__init__.py | 0 | ||||
| -rw-r--r-- | test/orm/dml/test_bulk.py (renamed from test/orm/test_bulk.py) | 231 | ||||
| -rw-r--r-- | test/orm/dml/test_bulk_statements.py | 1199 | ||||
| -rw-r--r-- | test/orm/dml/test_evaluator.py (renamed from test/orm/test_evaluator.py) | 1 | ||||
| -rw-r--r-- | test/orm/dml/test_update_delete_where.py (renamed from test/orm/test_update_delete.py) | 499 | ||||
| -rw-r--r-- | test/orm/inheritance/test_basic.py | 3 | ||||
| -rw-r--r-- | test/orm/test_bind.py | 11 | ||||
| -rw-r--r-- | test/orm/test_composites.py | 207 | ||||
| -rw-r--r-- | test/orm/test_cycles.py | 2 | ||||
| -rw-r--r-- | test/orm/test_defaults.py | 4 | ||||
| -rw-r--r-- | test/orm/test_events.py | 13 | ||||
| -rw-r--r-- | test/orm/test_unitofwork.py | 6 | ||||
| -rw-r--r-- | test/orm/test_unitofworkv2.py | 34 | ||||
| -rw-r--r-- | test/orm/test_versioning.py | 10 | ||||
| -rw-r--r-- | test/sql/test_resultset.py | 47 | ||||
| -rw-r--r-- | test/sql/test_returning.py | 195 | ||||
| -rw-r--r-- | test/sql/test_selectable.py | 21 |
19 files changed, 2420 insertions, 310 deletions
diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 7cc6a6f79..667f4bfb0 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -465,7 +465,11 @@ class ShardTest: t = get_tokyo(sess2) eq_(t.city, tokyo.city) - def test_bulk_update_synchronize_evaluate(self): + @testing.combinations( + "fetch", "evaluate", "auto", argnames="synchronize_session" + ) + @testing.combinations(True, False, argnames="legacy") + def test_orm_update_synchronize(self, synchronize_session, legacy): sess = self._fixture_data() eq_( @@ -476,33 +480,25 @@ class ShardTest: temps = sess.query(Report).all() eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - sess.query(Report).filter(Report.temperature >= 80).update( - {"temperature": Report.temperature + 6}, - synchronize_session="evaluate", - ) - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {86.0, 75.0, 91.0}, - ) - - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - - def test_bulk_update_synchronize_fetch(self): - sess = self._fixture_data() - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {80.0, 75.0, 85.0}, - ) + if legacy: + sess.query(Report).filter(Report.temperature >= 80).update( + {"temperature": Report.temperature + 6}, + synchronize_session=synchronize_session, + ) + else: + sess.execute( + update(Report) + .filter(Report.temperature >= 80) + .values(temperature=Report.temperature + 6) + .execution_options(synchronize_session=synchronize_session) + ) - temps = sess.query(Report).all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + # test synchronize session + def go(): + eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - sess.query(Report).filter(Report.temperature >= 80).update( - {"temperature": Report.temperature + 6}, - synchronize_session="fetch", + self.assert_sql_count( + sess._ShardedSession__binds["north_america"], go, 0 ) eq_( @@ -510,165 +506,41 @@ class ShardTest: {86.0, 75.0, 91.0}, ) - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - - def test_bulk_delete_synchronize_evaluate(self): - sess = self._fixture_data() - - temps = sess.query(Report).all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - sess.query(Report).filter(Report.temperature >= 80).delete( - synchronize_session="evaluate" - ) - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {75.0}, - ) - - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - - def test_bulk_delete_synchronize_fetch(self): + @testing.combinations( + "fetch", "evaluate", "auto", argnames="synchronize_session" + ) + @testing.combinations(True, False, argnames="legacy") + def test_orm_delete_synchronize(self, synchronize_session, legacy): sess = self._fixture_data() temps = sess.query(Report).all() eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - sess.query(Report).filter(Report.temperature >= 80).delete( - synchronize_session="fetch" - ) - - eq_( - set(row.temperature for row in sess.query(Report.temperature)), - {75.0}, - ) - - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - - def test_bulk_update_future_synchronize_evaluate(self): - sess = self._fixture_data() - - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {80.0, 75.0, 85.0}, - ) - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - sess.execute( - update(Report) - .filter(Report.temperature >= 80) - .values( - {"temperature": Report.temperature + 6}, + if legacy: + sess.query(Report).filter(Report.temperature >= 80).delete( + synchronize_session=synchronize_session ) - .execution_options(synchronize_session="evaluate") - ) - - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {86.0, 75.0, 91.0}, - ) - - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - - def test_bulk_update_future_synchronize_fetch(self): - sess = self._fixture_data() - - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {80.0, 75.0, 85.0}, - ) - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - # MARKMARK - # omitting the criteria so that the UPDATE affects three out of - # four shards - sess.execute( - update(Report) - .values( - {"temperature": Report.temperature + 6}, + else: + sess.execute( + delete(Report) + .filter(Report.temperature >= 80) + .execution_options(synchronize_session=synchronize_session) ) - .execution_options(synchronize_session="fetch") - ) - - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {86.0, 81.0, 91.0}, - ) - - # test synchronize session as well - eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0}) - - def test_bulk_delete_future_synchronize_evaluate(self): - sess = self._fixture_data() - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - - sess.execute( - delete(Report) - .filter(Report.temperature >= 80) - .execution_options(synchronize_session="evaluate") - ) - eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), - {75.0}, - ) - - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - - def test_bulk_delete_future_synchronize_fetch(self): - sess = self._fixture_data() - - temps = sess.execute(select(Report)).scalars().all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + def go(): + # test synchronize session + for t in temps: + assert inspect(t).deleted is (t.temperature >= 80) - sess.execute( - delete(Report) - .filter(Report.temperature >= 80) - .execution_options(synchronize_session="fetch") + self.assert_sql_count( + sess._ShardedSession__binds["north_america"], go, 0 ) eq_( - set( - row.temperature - for row in sess.execute(select(Report.temperature)) - ), + set(row.temperature for row in sess.query(Report.temperature)), {75.0}, ) - # test synchronize session as well - for t in temps: - assert inspect(t).deleted is (t.temperature >= 80) - class DistinctEngineShardTest(ShardTest, fixtures.MappedTest): def _init_dbs(self): diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index de5f89b25..0cba8f3a1 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -3,6 +3,7 @@ from decimal import Decimal from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL @@ -1017,15 +1018,43 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): params={"first_name": "Dr."}, ) - def test_update_expr(self): + @testing.combinations("attr", "str", "kwarg", argnames="keytype") + def test_update_expr(self, keytype): Person = self.classes.Person - statement = update(Person).values({Person.name: "Dr. No"}) + if keytype == "attr": + statement = update(Person).values({Person.name: "Dr. No"}) + elif keytype == "str": + statement = update(Person).values({"name": "Dr. No"}) + elif keytype == "kwarg": + statement = update(Person).values(name="Dr. No") + else: + assert False self.assert_compile( statement, "UPDATE person SET first_name=:first_name, last_name=:last_name", - params={"first_name": "Dr.", "last_name": "No"}, + checkparams={"first_name": "Dr.", "last_name": "No"}, + ) + + @testing.combinations("attr", "str", "kwarg", argnames="keytype") + def test_insert_expr(self, keytype): + Person = self.classes.Person + + if keytype == "attr": + statement = insert(Person).values({Person.name: "Dr. No"}) + elif keytype == "str": + statement = insert(Person).values({"name": "Dr. No"}) + elif keytype == "kwarg": + statement = insert(Person).values(name="Dr. No") + else: + assert False + + self.assert_compile( + statement, + "INSERT INTO person (first_name, last_name) VALUES " + "(:first_name, :last_name)", + checkparams={"first_name": "Dr.", "last_name": "No"}, ) # these tests all run two UPDATES to assert that caching is not diff --git a/test/orm/dml/__init__.py b/test/orm/dml/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/test/orm/dml/__init__.py diff --git a/test/orm/test_bulk.py b/test/orm/dml/test_bulk.py index 802cdfac5..52db4247f 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/dml/test_bulk.py @@ -1,8 +1,11 @@ from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey +from sqlalchemy import Identity +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import update from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock @@ -20,6 +23,8 @@ class BulkTest(testing.AssertsExecutionResults): class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -73,6 +78,8 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest): class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): + __backend__ = True + @classmethod def setup_mappers(cls): User, Address, Order = cls.classes("User", "Address", "Order") @@ -82,22 +89,42 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): cls.mapper_registry.map_imperatively(Address, a) cls.mapper_registry.map_imperatively(Order, o) - def test_bulk_save_return_defaults(self): + @testing.combinations("save_objects", "insert_mappings", "insert_stmt") + def test_bulk_save_return_defaults(self, statement_type): (User,) = self.classes("User") s = fixture_session() - objects = [User(name="u1"), User(name="u2"), User(name="u3")] - assert "id" not in objects[0].__dict__ - with self.sql_execution_asserter() as asserter: - s.bulk_save_objects(objects, return_defaults=True) + if statement_type == "save_objects": + objects = [User(name="u1"), User(name="u2"), User(name="u3")] + assert "id" not in objects[0].__dict__ + + returning_users_id = " RETURNING users.id" + with self.sql_execution_asserter() as asserter: + s.bulk_save_objects(objects, return_defaults=True) + elif statement_type == "insert_mappings": + data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] + returning_users_id = " RETURNING users.id" + with self.sql_execution_asserter() as asserter: + s.bulk_insert_mappings(User, data, return_defaults=True) + elif statement_type == "insert_stmt": + data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] + + # for statement, "return_defaults" is heuristic on if we are + # a joined inh mapping if we don't otherwise include + # .returning() on the statement itself + returning_users_id = "" + with self.sql_execution_asserter() as asserter: + s.execute(insert(User), data) asserter.assert_( Conditional( - testing.db.dialect.insert_executemany_returning, + testing.db.dialect.insert_executemany_returning + or statement_type == "insert_stmt", [ CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", + "INSERT INTO users (name) " + f"VALUES (:name){returning_users_id}", [{"name": "u1"}, {"name": "u2"}, {"name": "u3"}], ), ], @@ -117,7 +144,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): ], ) ) - eq_(objects[0].__dict__["id"], 1) + if statement_type == "save_objects": + eq_(objects[0].__dict__["id"], 1) def test_bulk_save_mappings_preserve_order(self): (User,) = self.classes("User") @@ -219,8 +247,9 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): ) ) - def test_bulk_update(self): - (User,) = self.classes("User") + @testing.combinations("update_mappings", "update_stmt") + def test_bulk_update(self, statement_type): + User = self.classes.User s = fixture_session(expire_on_commit=False) objects = [User(name="u1"), User(name="u2"), User(name="u3")] @@ -228,15 +257,18 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): s.commit() s = fixture_session() - with self.sql_execution_asserter() as asserter: - s.bulk_update_mappings( - User, - [ - {"id": 1, "name": "u1new"}, - {"id": 2, "name": "u2"}, - {"id": 3, "name": "u3new"}, - ], - ) + data = [ + {"id": 1, "name": "u1new"}, + {"id": 2, "name": "u2"}, + {"id": 3, "name": "u3new"}, + ] + + if statement_type == "update_mappings": + with self.sql_execution_asserter() as asserter: + s.bulk_update_mappings(User, data) + elif statement_type == "update_stmt": + with self.sql_execution_asserter() as asserter: + s.execute(update(User), data) asserter.assert_( CompiledSQL( @@ -303,6 +335,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -360,6 +394,8 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest): class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -547,6 +583,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest): class BulkInheritanceTest(BulkTest, fixtures.MappedTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -643,6 +681,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): ) s = fixture_session() + objects = [ Manager(name="m1", status="s1", manager_name="mn1"), Engineer(name="e1", status="s2", primary_language="l1"), @@ -669,7 +708,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): [ CompiledSQL( "INSERT INTO people (name, type) " - "VALUES (:name, :type)", + "VALUES (:name, :type) RETURNING people.person_id", [ {"type": "engineer", "name": "e1"}, {"type": "engineer", "name": "e2"}, @@ -798,59 +837,74 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): ), ) - def test_bulk_insert_joined_inh_return_defaults(self): + @testing.combinations("insert_mappings", "insert_stmt") + def test_bulk_insert_joined_inh_return_defaults(self, statement_type): Person, Engineer, Manager, Boss = self.classes( "Person", "Engineer", "Manager", "Boss" ) s = fixture_session() - with self.sql_execution_asserter() as asserter: - s.bulk_insert_mappings( - Boss, - [ - dict( - name="b1", - status="s1", - manager_name="mn1", - golf_swing="g1", - ), - dict( - name="b2", - status="s2", - manager_name="mn2", - golf_swing="g2", - ), - dict( - name="b3", - status="s3", - manager_name="mn3", - golf_swing="g3", - ), - ], - return_defaults=True, - ) + data = [ + dict( + name="b1", + status="s1", + manager_name="mn1", + golf_swing="g1", + ), + dict( + name="b2", + status="s2", + manager_name="mn2", + golf_swing="g2", + ), + dict( + name="b3", + status="s3", + manager_name="mn3", + golf_swing="g3", + ), + ] + + if statement_type == "insert_mappings": + with self.sql_execution_asserter() as asserter: + s.bulk_insert_mappings( + Boss, + data, + return_defaults=True, + ) + elif statement_type == "insert_stmt": + with self.sql_execution_asserter() as asserter: + s.execute(insert(Boss), data) asserter.assert_( Conditional( testing.db.dialect.insert_executemany_returning, [ CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b1"}, {"name": "b2"}, {"name": "b3"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type) RETURNING people.person_id", + [ + {"name": "b1", "type": "boss"}, + {"name": "b2", "type": "boss"}, + {"name": "b3", "type": "boss"}, + ], ), ], [ CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b1"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type)", + [{"name": "b1", "type": "boss"}], ), CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b2"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type)", + [{"name": "b2", "type": "boss"}], ), CompiledSQL( - "INSERT INTO people (name) VALUES (:name)", - [{"name": "b3"}], + "INSERT INTO people (name, type) " + "VALUES (:name, :type)", + [{"name": "b3", "type": "boss"}], ), ], ), @@ -874,15 +928,79 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): ), ) + @testing.combinations("update_mappings", "update_stmt") + def test_bulk_update(self, statement_type): + Person, Engineer, Manager, Boss = self.classes( + "Person", "Engineer", "Manager", "Boss" + ) + + s = fixture_session() + + b1, b2, b3 = ( + Boss(name="b1", status="s1", manager_name="mn1", golf_swing="g1"), + Boss(name="b2", status="s2", manager_name="mn2", golf_swing="g2"), + Boss(name="b3", status="s3", manager_name="mn3", golf_swing="g3"), + ) + s.add_all([b1, b2, b3]) + s.commit() + + # slight non-convenient thing. we have to fill in boss_id here + # for update, this is not sent along automatically. this is not a + # new behavior in bulk + new_data = [ + { + "person_id": b1.person_id, + "boss_id": b1.boss_id, + "name": "b1_updated", + "manager_name": "mn1_updated", + }, + { + "person_id": b3.person_id, + "boss_id": b3.boss_id, + "manager_name": "mn2_updated", + "golf_swing": "g1_updated", + }, + ] + + if statement_type == "update_mappings": + with self.sql_execution_asserter() as asserter: + s.bulk_update_mappings(Boss, new_data) + elif statement_type == "update_stmt": + with self.sql_execution_asserter() as asserter: + s.execute(update(Boss), new_data) + + asserter.assert_( + CompiledSQL( + "UPDATE people SET name=:name WHERE " + "people.person_id = :people_person_id", + [{"name": "b1_updated", "people_person_id": 1}], + ), + CompiledSQL( + "UPDATE managers SET manager_name=:manager_name WHERE " + "managers.person_id = :managers_person_id", + [ + {"manager_name": "mn1_updated", "managers_person_id": 1}, + {"manager_name": "mn2_updated", "managers_person_id": 3}, + ], + ), + CompiledSQL( + "UPDATE boss SET golf_swing=:golf_swing WHERE " + "boss.boss_id = :boss_boss_id", + [{"golf_swing": "g1_updated", "boss_boss_id": 3}], + ), + ) + class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest): + __backend__ = True + @classmethod def setup_classes(cls): Base = cls.DeclarativeBasic class User(Base): __tablename__ = "users" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) name = Column(String(255), nullable=False) def test_issue_6793(self): @@ -907,7 +1025,8 @@ class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest): [{"name": "A"}, {"name": "B"}], ), CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", + "INSERT INTO users (name) VALUES (:name) " + "RETURNING users.id", [{"name": "C"}, {"name": "D"}], ), ], diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py new file mode 100644 index 000000000..0cca9e6f5 --- /dev/null +++ b/test/orm/dml/test_bulk_statements.py @@ -0,0 +1,1199 @@ +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Optional +import uuid + +from sqlalchemy import exc +from sqlalchemy import ForeignKey +from sqlalchemy import func +from sqlalchemy import Identity +from sqlalchemy import insert +from sqlalchemy import inspect +from sqlalchemy import literal_column +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy import update +from sqlalchemy.orm import aliased +from sqlalchemy.orm import load_only +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.testing import config +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import mock +from sqlalchemy.testing import provision +from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session + + +class NoReturningTest(fixtures.TestBase): + def test_no_returning_error(self, decl_base): + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + if testing.requires.insert_executemany_returning.enabled: + result = s.scalars( + insert(A).returning(A), + [ + {"data": "d3", "x": 5}, + {"data": "d4", "x": 6}, + ], + ) + eq_(result.all(), [A(data="d3", x=5), A(data="d4", x=6)]) + + else: + with expect_raises_message( + exc.InvalidRequestError, + "Can't use explicit RETURNING for bulk INSERT operation", + ): + s.scalars( + insert(A).returning(A), + [ + {"data": "d3", "x": 5}, + {"data": "d4", "x": 6}, + ], + ) + + def test_omit_returning_ok(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + s.execute( + insert(A), + [ + {"data": "d3", "x": 5}, + {"data": "d4", "x": 6}, + ], + ) + eq_( + s.execute(select(A.data, A.x).order_by(A.id)).all(), + [("d3", 5), ("d4", 6)], + ) + + +class BulkDMLReturningInhTest: + def test_insert_col_key_also_works_currently(self): + """using the column key, not mapped attr key. + + right now this passes through to the INSERT. when doing this with + an UPDATE, it tends to fail because the synchronize session + strategies can't match "xcol" back. however w/ INSERT we aren't + doing that, so there's no place this gets checked. UPDATE also + succeeds if synchronize_session is turned off. + + """ + A, B = self.classes("A", "B") + + s = fixture_session() + s.execute(insert(A).values(type="a", data="d", xcol=10)) + eq_(s.scalars(select(A.x)).all(), [10]) + + @testing.combinations(True, False, argnames="use_returning") + def test_heterogeneous_keys(self, use_returning): + A, B = self.classes("A", "B") + + values = [ + {"data": "d3", "x": 5, "type": "a"}, + {"data": "d4", "x": 6, "type": "a"}, + {"data": "d5", "type": "a"}, + {"data": "d6", "x": 8, "y": 9, "type": "a"}, + {"data": "d7", "x": 12, "y": 12, "type": "a"}, + {"data": "d8", "x": 7, "type": "a"}, + ] + + s = fixture_session() + + stmt = insert(A) + if use_returning: + stmt = stmt.returning(A) + + with self.sql_execution_asserter() as asserter: + result = s.execute(stmt, values) + + if inspect(B).single: + single_inh = ", a.bd, a.zcol, a.q" + else: + single_inh = "" + + if use_returning: + asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol) VALUES " + "(:type, :data, :xcol) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [ + {"type": "a", "data": "d3", "xcol": 5}, + {"type": "a", "data": "d4", "xcol": 6}, + ], + ), + CompiledSQL( + "INSERT INTO a (type, data) VALUES (:type, :data) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [{"type": "a", "data": "d5"}], + ), + CompiledSQL( + "INSERT INTO a (type, data, xcol, y) " + "VALUES (:type, :data, :xcol, :y) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [ + {"type": "a", "data": "d6", "xcol": 8, "y": 9}, + {"type": "a", "data": "d7", "xcol": 12, "y": 12}, + ], + ), + CompiledSQL( + "INSERT INTO a (type, data, xcol) " + "VALUES (:type, :data, :xcol) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}", + [{"type": "a", "data": "d8", "xcol": 7}], + ), + ) + else: + asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol) VALUES " + "(:type, :data, :xcol)", + [ + {"type": "a", "data": "d3", "xcol": 5}, + {"type": "a", "data": "d4", "xcol": 6}, + ], + ), + CompiledSQL( + "INSERT INTO a (type, data) VALUES (:type, :data)", + [{"type": "a", "data": "d5"}], + ), + CompiledSQL( + "INSERT INTO a (type, data, xcol, y) " + "VALUES (:type, :data, :xcol, :y)", + [ + {"type": "a", "data": "d6", "xcol": 8, "y": 9}, + {"type": "a", "data": "d7", "xcol": 12, "y": 12}, + ], + ), + CompiledSQL( + "INSERT INTO a (type, data, xcol) " + "VALUES (:type, :data, :xcol)", + [{"type": "a", "data": "d8", "xcol": 7}], + ), + ) + + if use_returning: + eq_( + result.scalars().all(), + [ + A(data="d3", id=mock.ANY, type="a", x=5, y=None), + A(data="d4", id=mock.ANY, type="a", x=6, y=None), + A(data="d5", id=mock.ANY, type="a", x=None, y=None), + A(data="d6", id=mock.ANY, type="a", x=8, y=9), + A(data="d7", id=mock.ANY, type="a", x=12, y=12), + A(data="d8", id=mock.ANY, type="a", x=7, y=None), + ], + ) + + @testing.combinations( + "strings", + "cols", + "strings_w_exprs", + "cols_w_exprs", + argnames="paramstyle", + ) + @testing.combinations( + True, + (False, testing.requires.multivalues_inserts), + argnames="single_element", + ) + def test_single_values_returning_fn(self, paramstyle, single_element): + """test using insert().values(). + + these INSERT statements go straight in as a single execute without any + insertmanyreturning or bulk_insert_mappings thing going on. the + advantage here is that SQL expressions can be used in the values also. + Disadvantage is none of the automation for inheritance mappers. + + """ + A, B = self.classes("A", "B") + + if paramstyle == "strings": + values = [ + {"data": "d3", "x": 5, "y": 9, "type": "a"}, + {"data": "d4", "x": 10, "y": 8, "type": "a"}, + ] + elif paramstyle == "cols": + values = [ + {A.data: "d3", A.x: 5, A.y: 9, A.type: "a"}, + {A.data: "d4", A.x: 10, A.y: 8, A.type: "a"}, + ] + elif paramstyle == "strings_w_exprs": + values = [ + {"data": func.lower("D3"), "x": 5, "y": 9, "type": "a"}, + { + "data": "d4", + "x": literal_column("5") + 5, + "y": 8, + "type": "a", + }, + ] + elif paramstyle == "cols_w_exprs": + values = [ + {A.data: func.lower("D3"), A.x: 5, A.y: 9, A.type: "a"}, + { + A.data: "d4", + A.x: literal_column("5") + 5, + A.y: 8, + A.type: "a", + }, + ] + else: + assert False + + s = fixture_session() + + if single_element: + if paramstyle.startswith("strings"): + stmt = ( + insert(A) + .values(**values[0]) + .returning(A, func.upper(A.data, type_=String)) + ) + else: + stmt = ( + insert(A) + .values(values[0]) + .returning(A, func.upper(A.data, type_=String)) + ) + else: + stmt = ( + insert(A) + .values(values) + .returning(A, func.upper(A.data, type_=String)) + ) + + for i in range(3): + result = s.execute(stmt) + expected: List[Any] = [(A(data="d3", x=5, y=9), "D3")] + if not single_element: + expected.append((A(data="d4", x=10, y=8), "D4")) + eq_(result.all(), expected) + + def test_bulk_w_sql_expressions(self): + A, B = self.classes("A", "B") + + data = [ + {"x": 5, "y": 9, "type": "a"}, + { + "x": 10, + "y": 8, + "type": "a", + }, + ] + + s = fixture_session() + + stmt = ( + insert(A) + .values(data=func.lower("DD")) + .returning(A, func.upper(A.data, type_=String)) + ) + + for i in range(3): + result = s.execute(stmt, data) + expected: List[Any] = [ + (A(data="dd", x=5, y=9), "DD"), + (A(data="dd", x=10, y=8), "DD"), + ] + eq_(result.all(), expected) + + def test_bulk_w_sql_expressions_subclass(self): + A, B = self.classes("A", "B") + + data = [ + {"bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + s = fixture_session() + + stmt = ( + insert(B) + .values(data=func.lower("DD")) + .returning(B, func.upper(B.data, type_=String)) + ) + + for i in range(3): + result = s.execute(stmt, data) + expected: List[Any] = [ + (B(bd="bd1", data="dd", q=4, type="b", x=1, y=2, z=3), "DD"), + (B(bd="bd2", data="dd", q=8, type="b", x=5, y=6, z=7), "DD"), + ] + eq_(result.all(), expected) + + @testing.combinations(True, False, argnames="use_ordered") + def test_bulk_upd_w_sql_expressions_no_ordered_values(self, use_ordered): + A, B = self.classes("A", "B") + + s = fixture_session() + + stmt = update(B).ordered_values( + ("data", func.lower("DD_UPDATE")), + ("z", literal_column("3 + 12")), + ) + with expect_raises_message( + exc.InvalidRequestError, + r"bulk ORM UPDATE does not support ordered_values\(\) " + r"for custom UPDATE", + ): + s.execute( + stmt, + [ + {"id": 5, "bd": "bd1_updated"}, + {"id": 6, "bd": "bd2_updated"}, + ], + ) + + def test_bulk_upd_w_sql_expressions_subclass(self): + A, B = self.classes("A", "B") + + s = fixture_session() + + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + ids = s.scalars(insert(B).returning(B.id), data).all() + + stmt = update(B).values( + data=func.lower("DD_UPDATE"), z=literal_column("3 + 12") + ) + + result = s.execute( + stmt, + [ + {"id": ids[0], "bd": "bd1_updated"}, + {"id": ids[1], "bd": "bd2_updated"}, + ], + ) + + # this is a nullresult at the moment + assert result is not None + + eq_( + s.scalars(select(B)).all(), + [ + B( + bd="bd1_updated", + data="dd_update", + id=ids[0], + q=4, + type="b", + x=1, + y=2, + z=15, + ), + B( + bd="bd2_updated", + data="dd_update", + id=ids[1], + q=8, + type="b", + x=5, + y=6, + z=15, + ), + ], + ) + + def test_single_returning_fn(self): + A, B = self.classes("A", "B") + + s = fixture_session() + for i in range(3): + result = s.execute( + insert(A).returning(A, func.upper(A.data, type_=String)), + [{"data": "d3"}, {"data": "d4"}], + ) + eq_(result.all(), [(A(data="d3"), "D3"), (A(data="d4"), "D4")]) + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_no_returning(self, single_element): + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + result = s.execute(insert(B), data) + assert result._soft_closed + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_load_only(self, single_element): + """test that load_only() prevents additional attributes from being + populated. + + """ + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + for i in range(3): + # tests both caching and that the data dictionaries aren't + # mutated... + result = s.execute( + insert(B).returning(B).options(load_only(B.data, B.y, B.q)), + data, + ) + objects = result.scalars().all() + for obj in objects: + assert "data" in obj.__dict__ + assert "q" in obj.__dict__ + assert "z" not in obj.__dict__ + assert "x" not in obj.__dict__ + + expected = [ + B(data="d3", bd="bd1", x=1, y=2, z=3, q=4), + ] + if not single_element: + expected.append(B(data="d4", bd="bd2", x=5, y=6, z=7, q=8)) + eq_(objects, expected) + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_load_only_doesnt_fetch_cols(self, single_element): + """test that when using load_only(), the actual INSERT statement + does not include the deferred columns + + """ + A, B = self.classes("A", "B") + + s = fixture_session() + + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + if single_element: + data = data[0] + + with self.sql_execution_asserter() as asserter: + + # tests both caching and that the data dictionaries aren't + # mutated... + + # note that if we don't put B.id here, accessing .id on the + # B object for joined inheritance is triggering a SELECT + # (and not for single inheritance). this seems not great, but is + # likely a different issue + result = s.execute( + insert(B) + .returning(B) + .options(load_only(B.id, B.data, B.y, B.q)), + data, + ) + objects = result.scalars().all() + if single_element: + id0 = objects[0].id + id1 = None + else: + id0, id1 = objects[0].id, objects[1].id + + if inspect(B).single or inspect(B).concrete: + expected_params = [ + { + "type": "b", + "data": "d3", + "xcol": 1, + "y": 2, + "bd": "bd1", + "zcol": 3, + "q": 4, + }, + { + "type": "b", + "data": "d4", + "xcol": 5, + "y": 6, + "bd": "bd2", + "zcol": 7, + "q": 8, + }, + ] + if single_element: + expected_params[1:] = [] + # RETURNING only includes PK, discriminator, then the cols + # we asked for data, y, q. xcol, z, bd are omitted + + if inspect(B).single: + asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol, y, bd, zcol, q) " + "VALUES " + "(:type, :data, :xcol, :y, :bd, :zcol, :q) " + "RETURNING a.id, a.type, a.data, a.y, a.q", + expected_params, + ), + ) + else: + asserter.assert_( + CompiledSQL( + "INSERT INTO b (type, data, xcol, y, bd, zcol, q) " + "VALUES " + "(:type, :data, :xcol, :y, :bd, :zcol, :q) " + "RETURNING b.id, b.type, b.data, b.y, b.q", + expected_params, + ), + ) + else: + a_data = [ + {"type": "b", "data": "d3", "xcol": 1, "y": 2}, + {"type": "b", "data": "d4", "xcol": 5, "y": 6}, + ] + b_data = [ + {"id": id0, "bd": "bd1", "zcol": 3, "q": 4}, + {"id": id1, "bd": "bd2", "zcol": 7, "q": 8}, + ] + if single_element: + a_data[1:] = [] + b_data[1:] = [] + # RETURNING only includes PK, discriminator, then the cols + # we asked for data, y, q. xcol, z, bd are omitted. plus they + # are broken out correctly in the two statements. + asserter.assert_( + CompiledSQL( + "INSERT INTO a (type, data, xcol, y) VALUES " + "(:type, :data, :xcol, :y) " + "RETURNING a.id, a.type, a.data, a.y", + a_data, + ), + CompiledSQL( + "INSERT INTO b (id, bd, zcol, q) " + "VALUES (:id, :bd, :zcol, :q) " + "RETURNING b.id, b.q", + b_data, + ), + ) + + @testing.combinations( + True, + False, + argnames="single_element", + ) + def test_subclass_returning_bind_expr(self, single_element): + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + # note there's a fix in compiler.py -> + # _deliver_insertmanyvalues_batches + # for this re: the parameter rendering that isn't tested anywhere + # else. two different versions of the bug for both positional + # and non + result = s.execute(insert(B).returning(B.data, B.y, B.q + 5), data) + if single_element: + eq_(result.all(), [("d3", 2, 9)]) + else: + eq_(result.all(), [("d3", 2, 9), ("d4", 6, 13)]) + + def test_subclass_bulk_update(self): + A, B = self.classes("A", "B") + + s = fixture_session() + + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + ids = s.scalars(insert(B).returning(B.id), data).all() + + result = s.execute( + update(B), + [ + {"id": ids[0], "data": "d3_updated", "bd": "bd1_updated"}, + {"id": ids[1], "data": "d4_updated", "bd": "bd2_updated"}, + ], + ) + + # this is a nullresult at the moment + assert result is not None + + eq_( + s.scalars(select(B)).all(), + [ + B( + bd="bd1_updated", + data="d3_updated", + id=ids[0], + q=4, + type="b", + x=1, + y=2, + z=3, + ), + B( + bd="bd2_updated", + data="d4_updated", + id=ids[1], + q=8, + type="b", + x=5, + y=6, + z=7, + ), + ], + ) + + @testing.combinations(True, False, argnames="single_element") + def test_subclass_return_just_subclass_ids(self, single_element): + A, B = self.classes("A", "B") + + s = fixture_session() + + if single_element: + data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} + else: + data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + + ids = s.scalars(insert(B).returning(B.id), data).all() + actual_ids = s.scalars(select(B.id).order_by(B.data)).all() + + eq_(ids, actual_ids) + + @testing.combinations( + "orm", + "bulk", + argnames="insert_strategy", + ) + @testing.requires.provisioned_upsert + def test_base_class_upsert(self, insert_strategy): + """upsert is really tricky. if you dont have any data updated, + then you dont get the rows back and things dont work so well. + + so we need to be careful how much we document this because this is + still a thorny use case. + + """ + A = self.classes.A + + s = fixture_session() + + initial_data = [ + {"data": "d3", "x": 1, "y": 2, "q": 4}, + {"data": "d4", "x": 5, "y": 6, "q": 8}, + ] + ids = s.scalars(insert(A).returning(A.id), initial_data).all() + + upsert_data = [ + { + "id": ids[0], + "type": "a", + "data": "d3", + "x": 1, + "y": 2, + }, + { + "id": 32, + "type": "a", + "data": "d32", + "x": 19, + "y": 5, + }, + { + "id": ids[1], + "type": "a", + "data": "d4", + "x": 5, + "y": 6, + }, + { + "id": 28, + "type": "a", + "data": "d28", + "x": 9, + "y": 15, + }, + ] + + stmt = provision.upsert( + config, + A, + (A,), + lambda inserted: {"data": inserted.data + " upserted"}, + ) + + if insert_strategy == "orm": + result = s.scalars(stmt.values(upsert_data)) + elif insert_strategy == "bulk": + result = s.scalars(stmt, upsert_data) + else: + assert False + + eq_( + result.all(), + [ + A(data="d3 upserted", id=ids[0], type="a", x=1, y=2), + A(data="d32", id=32, type="a", x=19, y=5), + A(data="d4 upserted", id=ids[1], type="a", x=5, y=6), + A(data="d28", id=28, type="a", x=9, y=15), + ], + ) + + @testing.combinations( + "orm", + "bulk", + argnames="insert_strategy", + ) + @testing.requires.provisioned_upsert + def test_subclass_upsert(self, insert_strategy): + """note this is overridden in the joined version to expect failure""" + + A, B = self.classes("A", "B") + + s = fixture_session() + + idd3 = 1 + idd4 = 2 + id32 = 32 + id28 = 28 + + initial_data = [ + { + "id": idd3, + "data": "d3", + "bd": "bd1", + "x": 1, + "y": 2, + "z": 3, + "q": 4, + }, + { + "id": idd4, + "data": "d4", + "bd": "bd2", + "x": 5, + "y": 6, + "z": 7, + "q": 8, + }, + ] + ids = s.scalars(insert(B).returning(B.id), initial_data).all() + + upsert_data = [ + { + "id": ids[0], + "type": "b", + "data": "d3", + "bd": "bd1_upserted", + "x": 1, + "y": 2, + "z": 33, + "q": 44, + }, + { + "id": id32, + "type": "b", + "data": "d32", + "bd": "bd 32", + "x": 19, + "y": 5, + "z": 20, + "q": 21, + }, + { + "id": ids[1], + "type": "b", + "bd": "bd2_upserted", + "data": "d4", + "x": 5, + "y": 6, + "z": 77, + "q": 88, + }, + { + "id": id28, + "type": "b", + "data": "d28", + "bd": "bd 28", + "x": 9, + "y": 15, + "z": 10, + "q": 11, + }, + ] + + stmt = provision.upsert( + config, + B, + (B,), + lambda inserted: { + "data": inserted.data + " upserted", + "bd": inserted.bd + " upserted", + }, + ) + result = s.scalars(stmt, upsert_data) + eq_( + result.all(), + [ + B( + bd="bd1_upserted upserted", + data="d3 upserted", + id=ids[0], + q=4, + type="b", + x=1, + y=2, + z=3, + ), + B( + bd="bd 32", + data="d32", + id=32, + q=21, + type="b", + x=19, + y=5, + z=20, + ), + B( + bd="bd2_upserted upserted", + data="d4 upserted", + id=ids[1], + q=8, + type="b", + x=5, + y=6, + z=7, + ), + B( + bd="bd 28", + data="d28", + id=28, + q=11, + type="b", + x=9, + y=15, + z=10, + ), + ], + ) + + +class BulkDMLReturningJoinedInhTest( + BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest +): + + __requires__ = ("insert_returning",) + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "a", + "polymorphic_on": "type", + } + + class B(A): + __tablename__ = "b" + id: Mapped[int] = mapped_column( + ForeignKey("a.id"), primary_key=True + ) + bd: Mapped[str] + z: Mapped[Optional[int]] = mapped_column("zcol") + q: Mapped[Optional[int]] + + __mapper_args__ = {"polymorphic_identity": "b"} + + @testing.combinations( + "orm", + "bulk", + argnames="insert_strategy", + ) + @testing.combinations( + True, + False, + argnames="single_param", + ) + @testing.requires.provisioned_upsert + def test_subclass_upsert(self, insert_strategy, single_param): + A, B = self.classes("A", "B") + + s = fixture_session() + + initial_data = [ + {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, + {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, + ] + ids = s.scalars(insert(B).returning(B.id), initial_data).all() + + upsert_data = [ + { + "id": ids[0], + "type": "b", + }, + { + "id": 32, + "type": "b", + }, + ] + if single_param: + upsert_data = upsert_data[0] + + stmt = provision.upsert( + config, + B, + (B,), + lambda inserted: { + "bd": inserted.bd + " upserted", + }, + ) + + with expect_raises_message( + exc.InvalidRequestError, + r"bulk INSERT with a 'post values' clause \(typically upsert\) " + r"not supported for multi-table mapper", + ): + s.scalars(stmt, upsert_data) + + +class BulkDMLReturningSingleInhTest( + BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest +): + __requires__ = ("insert_returning",) + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "a", + "polymorphic_on": "type", + } + + class B(A): + bd: Mapped[str] = mapped_column(nullable=True) + z: Mapped[Optional[int]] = mapped_column("zcol") + q: Mapped[Optional[int]] + + __mapper_args__ = {"polymorphic_identity": "b"} + + +class BulkDMLReturningConcreteInhTest( + BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest +): + __requires__ = ("insert_returning",) + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class A(fixtures.ComparableEntity, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "a", + "polymorphic_on": "type", + } + + class B(A): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + type: Mapped[str] + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column("xcol") + y: Mapped[Optional[int]] + + bd: Mapped[str] = mapped_column(nullable=True) + z: Mapped[Optional[int]] = mapped_column("zcol") + q: Mapped[Optional[int]] + + __mapper_args__ = { + "polymorphic_identity": "b", + "concrete": True, + "polymorphic_on": "type", + } + + +class CTETest(fixtures.DeclarativeMappedTest): + __requires__ = ("insert_returning", "ctes_on_dml") + __backend__ = True + + @classmethod + def setup_classes(cls): + decl_base = cls.DeclarativeBasic + + class User(fixtures.ComparableEntity, decl_base): + __tablename__ = "users" + id: Mapped[uuid.UUID] = mapped_column(primary_key=True) + username: Mapped[str] + + @testing.combinations( + ("cte_aliased", True), + ("cte", False), + argnames="wrap_cte_in_aliased", + id_="ia", + ) + @testing.combinations( + ("use_union", True), + ("no_union", False), + argnames="use_a_union", + id_="ia", + ) + @testing.combinations( + "from_statement", "aliased", "direct", argnames="fetch_entity_type" + ) + def test_select_from_insert_cte( + self, wrap_cte_in_aliased, use_a_union, fetch_entity_type + ): + """test the use case from #8544; SELECT that selects from a + CTE INSERT...RETURNING. + + """ + User = self.classes.User + + id_ = uuid.uuid4() + + cte = ( + insert(User) + .values(id=id_, username="some user") + .returning(User) + .cte() + ) + if wrap_cte_in_aliased: + cte = aliased(User, cte) + + if use_a_union: + stmt = select(User).where(User.id == id_).union(select(cte)) + else: + stmt = select(cte) + + if fetch_entity_type == "from_statement": + outer_stmt = select(User).from_statement(stmt) + expect_entity = True + elif fetch_entity_type == "aliased": + outer_stmt = select(aliased(User, stmt.subquery())) + expect_entity = True + elif fetch_entity_type == "direct": + outer_stmt = stmt + expect_entity = not use_a_union and wrap_cte_in_aliased + else: + assert False + + sess = fixture_session() + with self.sql_execution_asserter() as asserter: + + if not expect_entity: + row = sess.execute(outer_stmt).one() + eq_(row, (id_, "some user")) + else: + new_user = sess.scalars(outer_stmt).one() + eq_(new_user, User(id=id_, username="some user")) + + cte_sql = ( + "(INSERT INTO users (id, username) " + "VALUES (:param_1, :param_2) " + "RETURNING users.id, users.username)" + ) + + if fetch_entity_type == "aliased" and not use_a_union: + expected = ( + f"WITH anon_2 AS {cte_sql} " + "SELECT anon_1.id, anon_1.username " + "FROM (SELECT anon_2.id AS id, anon_2.username AS username " + "FROM anon_2) AS anon_1" + ) + elif not use_a_union: + expected = ( + f"WITH anon_1 AS {cte_sql} " + "SELECT anon_1.id, anon_1.username FROM anon_1" + ) + elif fetch_entity_type == "aliased": + expected = ( + f"WITH anon_2 AS {cte_sql} SELECT anon_1.id, anon_1.username " + "FROM (SELECT users.id AS id, users.username AS username " + "FROM users WHERE users.id = :id_1 " + "UNION SELECT anon_2.id AS id, anon_2.username AS username " + "FROM anon_2) AS anon_1" + ) + else: + expected = ( + f"WITH anon_1 AS {cte_sql} " + "SELECT users.id, users.username FROM users " + "WHERE users.id = :id_1 " + "UNION SELECT anon_1.id, anon_1.username FROM anon_1" + ) + + asserter.assert_( + CompiledSQL(expected, [{"param_1": id_, "param_2": "some user"}]) + ) diff --git a/test/orm/test_evaluator.py b/test/orm/dml/test_evaluator.py index ff40cd201..4b903b863 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/dml/test_evaluator.py @@ -324,7 +324,6 @@ class EvaluateTest(fixtures.MappedTest): """test #3162""" User = self.classes.User - with expect_raises_message( evaluator.UnevaluatableError, r"Custom operator '\^\^' can't be evaluated in " diff --git a/test/orm/test_update_delete.py b/test/orm/dml/test_update_delete_where.py index 1e93f88de..836feb659 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/dml/test_update_delete_where.py @@ -1,3 +1,4 @@ +from sqlalchemy import bindparam from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column @@ -7,6 +8,7 @@ from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import insert +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import lambda_stmt from sqlalchemy import MetaData @@ -17,6 +19,7 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import update from sqlalchemy.orm import backref +from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import joinedload from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -26,6 +29,7 @@ from sqlalchemy.orm import with_loader_criteria from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in @@ -123,6 +127,25 @@ class UpdateDeleteTest(fixtures.MappedTest): }, ) + def test_update_dont_use_col_key(self): + User = self.classes.User + + s = fixture_session() + + # make sure objects are present to synchronize + _ = s.query(User).all() + + with expect_raises_message( + exc.InvalidRequestError, + "Attribute name not found, can't be synchronized back " + "to objects: 'age_int'", + ): + s.execute(update(User).values(age_int=5)) + + stmt = update(User).values(age=5) + s.execute(stmt) + eq_(s.scalars(select(User.age)).all(), [5, 5, 5, 5]) + @testing.combinations("table", "mapper", "both", argnames="bind_type") @testing.combinations( "update", "insert", "delete", argnames="statement_type" @@ -162,7 +185,7 @@ class UpdateDeleteTest(fixtures.MappedTest): assert_raises_message( exc.ArgumentError, "Valid strategies for session synchronization " - "are 'evaluate', 'fetch', False", + "are 'auto', 'evaluate', 'fetch', False", s.query(User).update, {}, synchronize_session="fake", @@ -351,6 +374,12 @@ class UpdateDeleteTest(fixtures.MappedTest): def test_evaluate_dont_refresh_expired_objects( self, expire_jane_age, add_filter_criteria ): + """test #5664. + + approach is revised in SQLAlchemy 2.0 to not pre-emptively + unexpire the involved attributes + + """ User = self.classes.User sess = fixture_session() @@ -379,15 +408,10 @@ class UpdateDeleteTest(fixtures.MappedTest): if add_filter_criteria: if expire_jane_age: asserter.assert_( - # it has to unexpire jane.name, because jane is not fully - # expired and the criteria needs to look at this particular - # key - CompiledSQL( - "SELECT users.age_int AS users_age_int, " - "users.name AS users_name FROM users " - "WHERE users.id = :pk_1", - [{"pk_1": 4}], - ), + # previously, this would unexpire the attribute and + # cause an additional SELECT. The + # 2.0 approach is that if the object has expired attrs + # we just expire the whole thing, avoiding SQL up front CompiledSQL( "UPDATE users " "SET age_int=(users.age_int + :age_int_1) " @@ -397,14 +421,10 @@ class UpdateDeleteTest(fixtures.MappedTest): ) else: asserter.assert_( - # it has to unexpire jane.name, because jane is not fully - # expired and the criteria needs to look at this particular - # key - CompiledSQL( - "SELECT users.name AS users_name FROM users " - "WHERE users.id = :pk_1", - [{"pk_1": 4}], - ), + # previously, this would unexpire the attribute and + # cause an additional SELECT. The + # 2.0 approach is that if the object has expired attrs + # we just expire the whole thing, avoiding SQL up front CompiledSQL( "UPDATE users SET " "age_int=(users.age_int + :age_int_1) " @@ -443,9 +463,9 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ] - if expire_jane_age and not add_filter_criteria: + if expire_jane_age: to_assert.append( - # refresh jane + # refresh jane for partial attributes CompiledSQL( "SELECT users.age_int AS users_age_int, " "users.name AS users_name FROM users " @@ -455,6 +475,75 @@ class UpdateDeleteTest(fixtures.MappedTest): ) asserter.assert_(*to_assert) + @testing.combinations(True, False, argnames="is_evaluable") + def test_auto_synchronize(self, is_evaluable): + User = self.classes.User + + sess = fixture_session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + if is_evaluable: + crit = or_(User.name == "jack", User.name == "jane") + else: + crit = case((User.name.in_(["jack", "jane"]), True), else_=False) + + with self.sql_execution_asserter() as asserter: + sess.execute(update(User).where(crit).values(age=User.age + 10)) + + if is_evaluable: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + :age_int_1) " + "WHERE users.name = :name_1 OR users.name = :name_2", + [{"age_int_1": 10, "name_1": "jack", "name_2": "jane"}], + ), + ) + elif testing.db.dialect.update_returning: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + :age_int_1) " + "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) " + "THEN :param_1 ELSE :param_2 END = 1 RETURNING users.id", + [ + { + "age_int_1": 10, + "name_1": ["jack", "jane"], + "param_1": True, + "param_2": False, + } + ], + ), + ) + else: + asserter.assert_( + CompiledSQL( + "SELECT users.id FROM users WHERE CASE WHEN " + "(users.name IN (__[POSTCOMPILE_name_1])) " + "THEN :param_1 ELSE :param_2 END = 1", + [ + { + "name_1": ["jack", "jane"], + "param_1": True, + "param_2": False, + } + ], + ), + CompiledSQL( + "UPDATE users SET age_int=(users.age_int + :age_int_1) " + "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) " + "THEN :param_1 ELSE :param_2 END = 1", + [ + { + "age_int_1": 10, + "name_1": ["jack", "jane"], + "param_1": True, + "param_2": False, + } + ], + ), + ) + def test_fetch_dont_refresh_expired_objects(self): User = self.classes.User @@ -518,17 +607,25 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ) - def test_delete(self): + @testing.combinations(False, None, "auto", "evaluate", "fetch") + def test_delete(self, synchronize_session): User = self.classes.User sess = fixture_session() john, jack, jill, jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter( + + stmt = delete(User).filter( or_(User.name == "john", User.name == "jill") - ).delete() + ) + if synchronize_session is not None: + stmt = stmt.execution_options( + synchronize_session=synchronize_session + ) + sess.execute(stmt) - assert john not in sess and jill not in sess + if synchronize_session not in (False, None): + assert john not in sess and jill not in sess eq_(sess.query(User).order_by(User.id).all(), [jack, jane]) @@ -629,6 +726,33 @@ class UpdateDeleteTest(fixtures.MappedTest): eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane]) + def test_update_multirow_not_supported(self): + User = self.classes.User + + sess = fixture_session() + + with expect_raises_message( + exc.InvalidRequestError, + "WHERE clause with bulk ORM UPDATE not supported " "right now.", + ): + sess.execute( + update(User).where(User.id == bindparam("id")), + [{"id": 1, "age": 27}, {"id": 2, "age": 37}], + ) + + def test_delete_bulk_not_supported(self): + User = self.classes.User + + sess = fixture_session() + + with expect_raises_message( + exc.InvalidRequestError, "Bulk ORM DELETE not supported right now." + ): + sess.execute( + delete(User), + [{"id": 1}, {"id": 2}], + ) + def test_update(self): User, users = self.classes.User, self.tables.users @@ -640,6 +764,7 @@ class UpdateDeleteTest(fixtures.MappedTest): ) eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + eq_( sess.query(User.age).order_by(User.id).all(), list(zip([25, 37, 29, 27])), @@ -974,7 +1099,7 @@ class UpdateDeleteTest(fixtures.MappedTest): ) @testing.requires.update_returning - def test_update_explicit_returning(self): + def test_update_evaluate_w_explicit_returning(self): User = self.classes.User sess = fixture_session() @@ -987,6 +1112,7 @@ class UpdateDeleteTest(fixtures.MappedTest): .filter(User.age > 29) .values({"age": User.age - 10}) .returning(User.id) + .execution_options(synchronize_session="evaluate") ) rows = sess.execute(stmt).all() @@ -1006,24 +1132,41 @@ class UpdateDeleteTest(fixtures.MappedTest): ) @testing.requires.update_returning - def test_no_fetch_w_explicit_returning(self): + @testing.combinations("update", "delete", argnames="crud_type") + def test_fetch_w_explicit_returning(self, crud_type): User = self.classes.User sess = fixture_session() - stmt = ( - update(User) - .filter(User.age > 29) - .values({"age": User.age - 10}) - .execution_options(synchronize_session="fetch") - .returning(User.id) - ) - with expect_raises_message( - exc.InvalidRequestError, - r"Can't use synchronize_session='fetch' " - r"with explicit returning\(\)", - ): - sess.execute(stmt) + if crud_type == "update": + stmt = ( + update(User) + .filter(User.age > 29) + .values({"age": User.age - 10}) + .execution_options(synchronize_session="fetch") + .returning(User, User.name) + ) + expected = [ + (User(age=37), "jack"), + (User(age=27), "jane"), + ] + elif crud_type == "delete": + stmt = ( + delete(User) + .filter(User.age > 29) + .execution_options(synchronize_session="fetch") + .returning(User, User.name) + ) + expected = [ + (User(age=47), "jack"), + (User(age=37), "jane"), + ] + else: + assert False + + result = sess.execute(stmt) + + eq_(result.all(), expected) @testing.combinations(True, False, argnames="implicit_returning") def test_delete_fetch_returning(self, implicit_returning): @@ -1142,7 +1285,8 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([25, 47, 44, 37])), ) - def test_update_changes_resets_dirty(self): + @testing.combinations("orm", "bulk") + def test_update_changes_resets_dirty(self, update_type): User = self.classes.User sess = fixture_session(autoflush=False) @@ -1155,9 +1299,30 @@ class UpdateDeleteTest(fixtures.MappedTest): # autoflush is false. therefore our '50' and '37' are getting # blown away by this operation. - sess.query(User).filter(User.age > 29).update( - {"age": User.age - 10}, synchronize_session="evaluate" - ) + if update_type == "orm": + sess.execute( + update(User) + .filter(User.age > 29) + .values({"age": User.age - 10}), + execution_options=dict(synchronize_session="evaluate"), + ) + elif update_type == "bulk": + + data = [ + {"id": john.id, "age": 25}, + {"id": jack.id, "age": 37}, + {"id": jill.id, "age": 29}, + {"id": jane.id, "age": 27}, + ] + + sess.execute( + update(User), + data, + execution_options=dict(synchronize_session="evaluate"), + ) + + else: + assert False for x in (john, jack, jill, jane): assert not sess.is_modified(x) @@ -1171,6 +1336,93 @@ class UpdateDeleteTest(fixtures.MappedTest): assert not sess.is_modified(john) assert not sess.is_modified(jack) + @testing.combinations( + None, False, "evaluate", "fetch", argnames="synchronize_session" + ) + @testing.combinations(True, False, argnames="homogeneous_keys") + def test_bulk_update_synchronize_session( + self, synchronize_session, homogeneous_keys + ): + User = self.classes.User + + sess = fixture_session(expire_on_commit=False) + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + if homogeneous_keys: + data = [ + {"id": john.id, "age": 35}, + {"id": jack.id, "age": 27}, + {"id": jill.id, "age": 30}, + ] + else: + data = [ + {"id": john.id, "age": 35}, + {"id": jack.id, "name": "new jack"}, + {"id": jill.id, "age": 30, "name": "new jill"}, + ] + + with self.sql_execution_asserter() as asserter: + if synchronize_session is not None: + opts = {"synchronize_session": synchronize_session} + else: + opts = {} + + if synchronize_session == "fetch": + with expect_raises_message( + exc.InvalidRequestError, + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates", + ): + sess.execute(update(User), data, execution_options=opts) + return + else: + sess.execute(update(User), data, execution_options=opts) + + if homogeneous_keys: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=:age_int " + "WHERE users.id = :users_id", + [ + {"age_int": 35, "users_id": 1}, + {"age_int": 27, "users_id": 2}, + {"age_int": 30, "users_id": 3}, + ], + ) + ) + else: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=:age_int " + "WHERE users.id = :users_id", + [{"age_int": 35, "users_id": 1}], + ), + CompiledSQL( + "UPDATE users SET name=:name WHERE users.id = :users_id", + [{"name": "new jack", "users_id": 2}], + ), + CompiledSQL( + "UPDATE users SET name=:name, age_int=:age_int " + "WHERE users.id = :users_id", + [{"name": "new jill", "age_int": 30, "users_id": 3}], + ), + ) + + if synchronize_session is False: + eq_(jill.name, "jill") + eq_(jack.name, "jack") + eq_(jill.age, 29) + eq_(jack.age, 47) + else: + if not homogeneous_keys: + eq_(jill.name, "new jill") + eq_(jack.name, "new jack") + eq_(jack.age, 47) + else: + eq_(jack.age, 27) + eq_(jill.age, 30) + def test_update_changes_with_autoflush(self): User = self.classes.User @@ -1214,7 +1466,8 @@ class UpdateDeleteTest(fixtures.MappedTest): ) @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) - def test_update_returns_rowcount(self): + @testing.combinations("auto", "fetch", "evaluate") + def test_update_returns_rowcount(self, synchronize_session): User = self.classes.User sess = fixture_session() @@ -1222,20 +1475,25 @@ class UpdateDeleteTest(fixtures.MappedTest): rowcount = ( sess.query(User) .filter(User.age > 29) - .update({"age": User.age + 0}) + .update( + {"age": User.age + 0}, synchronize_session=synchronize_session + ) ) eq_(rowcount, 2) rowcount = ( sess.query(User) .filter(User.age > 29) - .update({"age": User.age - 10}) + .update( + {"age": User.age - 10}, synchronize_session=synchronize_session + ) ) eq_(rowcount, 2) # test future result = sess.execute( - update(User).where(User.age > 19).values({"age": User.age - 10}) + update(User).where(User.age > 19).values({"age": User.age - 10}), + execution_options={"synchronize_session": synchronize_session}, ) eq_(result.rowcount, 4) @@ -1327,12 +1585,17 @@ class UpdateDeleteTest(fixtures.MappedTest): ) assert john not in sess - def test_evaluate_before_update(self): + @testing.combinations(True, False) + def test_evaluate_before_update(self, full_expiration): User = self.classes.User sess = fixture_session() john = sess.query(User).filter_by(name="john").one() - sess.expire(john, ["age"]) + + if full_expiration: + sess.expire(john) + else: + sess.expire(john, ["age"]) # eval must be before the update. otherwise # we eval john, age has been expired and doesn't @@ -1356,17 +1619,47 @@ class UpdateDeleteTest(fixtures.MappedTest): eq_(john.name, "j2") eq_(john.age, 40) - def test_evaluate_before_delete(self): + @testing.combinations(True, False) + def test_evaluate_before_delete(self, full_expiration): User = self.classes.User sess = fixture_session() john = sess.query(User).filter_by(name="john").one() - sess.expire(john, ["age"]) + jill = sess.query(User).filter_by(name="jill").one() + jane = sess.query(User).filter_by(name="jane").one() - sess.query(User).filter_by(name="john").filter_by(age=25).delete( + if full_expiration: + sess.expire(jill) + sess.expire(john) + else: + sess.expire(jill, ["age"]) + sess.expire(john, ["age"]) + + sess.query(User).filter(or_(User.age == 25, User.age == 37)).delete( synchronize_session="evaluate" ) - assert john not in sess + + # was fully deleted + assert jane not in sess + + # deleted object was expired, but not otherwise affected + assert jill in sess + + # deleted object was expired, but not otherwise affected + assert john in sess + + # partially expired row fully expired + assert inspect(jill).expired + + # non-deleted row still present + eq_(jill.age, 29) + + # partially expired row fully expired + assert inspect(john).expired + + # is deleted + with expect_raises(orm_exc.ObjectDeletedError): + john.name def test_fetch_before_delete(self): User = self.classes.User @@ -1378,6 +1671,7 @@ class UpdateDeleteTest(fixtures.MappedTest): sess.query(User).filter_by(name="john").filter_by(age=25).delete( synchronize_session="fetch" ) + assert john not in sess def test_update_unordered_dict(self): @@ -1495,6 +1789,60 @@ class UpdateDeleteTest(fixtures.MappedTest): ] eq_(["name", "age_int"], cols) + @testing.requires.sqlite + def test_sharding_extension_returning_mismatch(self, testing_engine): + """test one horizontal shard case where the given binds don't match + for RETURNING support; we dont support this. + + See test/ext/test_horizontal_shard.py for complete round trip + test cases for ORM update/delete + + """ + e1 = testing_engine("sqlite://") + e2 = testing_engine("sqlite://") + e1.connect().close() + e2.connect().close() + + e1.dialect.update_returning = True + e2.dialect.update_returning = False + + engines = [e1, e2] + + # a simulated version of the horizontal sharding extension + def execute_and_instances(orm_context): + execution_options = dict(orm_context.local_execution_options) + partial = [] + for engine in engines: + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["bind"] = engine + result_ = orm_context.invoke_statement( + bind_arguments=bind_arguments, + execution_options=execution_options, + ) + + partial.append(result_) + return partial[0].merge(*partial[1:]) + + User = self.classes.User + session = Session() + + event.listen( + session, "do_orm_execute", execute_and_instances, retval=True + ) + + stmt = ( + update(User) + .filter(User.id == 15) + .values(age=123) + .execution_options(synchronize_session="fetch") + ) + with expect_raises_message( + exc.InvalidRequestError, + "For synchronize_session='fetch', can't mix multiple backends " + "where some support RETURNING and others don't", + ): + session.execute(stmt) + class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): @classmethod @@ -1748,6 +2096,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): "Could not evaluate current criteria in Python.", q.update, {"samename": "ed"}, + synchronize_session="evaluate", ) @testing.requires.multi_table_update @@ -1901,7 +2250,7 @@ class ExpressionUpdateTest(fixtures.MappedTest): sess.commit() eq_(d1.cnt, 0) - sess.query(Data).update({Data.cnt: Data.cnt + 1}) + sess.query(Data).update({Data.cnt: Data.cnt + 1}, "evaluate") sess.flush() eq_(d1.cnt, 1) @@ -2443,7 +2792,8 @@ class LoadFromReturningTest(fixtures.MappedTest): ) @testing.requires.update_returning - def test_load_from_update(self, connection): + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_update(self, connection, use_from_statement): User = self.classes.User stmt = ( @@ -2453,7 +2803,16 @@ class LoadFromReturningTest(fixtures.MappedTest): .returning(User) ) - stmt = select(User).from_statement(stmt) + if use_from_statement: + # this is now a legacy-ish case, because as of 2.0 you can just + # use returning() directly to get the objects back. + # + # when from_statement is used, the UPDATE statement is no + # longer interpreted by + # BulkUDCompileState.orm_pre_session_exec or + # BulkUDCompileState.orm_setup_cursor_result. The compilation + # level routines still take place though + stmt = select(User).from_statement(stmt) with Session(connection) as sess: rows = sess.execute(stmt).scalars().all() @@ -2468,7 +2827,8 @@ class LoadFromReturningTest(fixtures.MappedTest): ("multiple", testing.requires.multivalues_inserts), argnames="params", ) - def test_load_from_insert(self, connection, params): + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_insert(self, connection, params, use_from_statement): User = self.classes.User if params == "multiple": @@ -2484,7 +2844,8 @@ class LoadFromReturningTest(fixtures.MappedTest): stmt = insert(User).values(values).returning(User) - stmt = select(User).from_statement(stmt) + if use_from_statement: + stmt = select(User).from_statement(stmt) with Session(connection) as sess: rows = sess.execute(stmt).scalars().all() @@ -2505,3 +2866,25 @@ class LoadFromReturningTest(fixtures.MappedTest): ) else: assert False + + @testing.requires.delete_returning + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_delete(self, connection, use_from_statement): + User = self.classes.User + + stmt = ( + delete(User).where(User.name.in_(["jack", "jill"])).returning(User) + ) + + if use_from_statement: + stmt = select(User).from_statement(stmt) + + with Session(connection) as sess: + rows = sess.execute(stmt).scalars().all() + + eq_( + rows, + [User(name="jack", age=47), User(name="jill", age=29)], + ) + + # TODO: state of above objects should be "deleted" diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 2e3874549..5f8cfc1f5 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -2012,7 +2012,8 @@ class JoinedNoFKSortingTest(fixtures.MappedTest): and testing.db.dialect.supports_default_metavalue, [ CompiledSQL( - "INSERT INTO a (id) VALUES (DEFAULT)", [{}, {}, {}, {}] + "INSERT INTO a (id) VALUES (DEFAULT) RETURNING a.id", + [{}, {}, {}, {}], ), ], [ diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index a6480365d..2f392cf6e 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -326,6 +326,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): ), ( lambda User: update(User) + .execution_options(synchronize_session=False) .values(name="not ed") .where(User.name == "ed"), lambda User: {"clause": mock.ANY, "mapper": inspect(User)}, @@ -392,7 +393,15 @@ class BindIntegrationTest(_fixtures.FixtureTest): engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name] with mock.patch( - "sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result" + "sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result" + ), mock.patch( + "sqlalchemy.orm.context.ORMCompileState.orm_execute_statement" + ), mock.patch( + "sqlalchemy.orm.bulk_persistence." + "BulkORMInsert.orm_execute_statement" + ), mock.patch( + "sqlalchemy.orm.bulk_persistence." + "BulkUDCompileState.orm_setup_cursor_result" ): sess.execute(statement) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 3a789aff7..efa2ecb45 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -1,8 +1,10 @@ import dataclasses import operator +import random import sqlalchemy as sa from sqlalchemy import ForeignKey +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String @@ -233,7 +235,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): is g.edges[1] ) - def test_bulk_update_sql(self): + def test_update_crit_sql(self): Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -258,7 +260,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): dialect="default", ) - def test_bulk_update_evaluate(self): + def test_update_crit_evaluate(self): Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -287,7 +289,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(e1.end, Point(17, 8)) - def test_bulk_update_fetch(self): + def test_update_crit_fetch(self): Edge, Point = (self.classes.Edge, self.classes.Point) sess = self._fixture() @@ -305,6 +307,205 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(e1.end, Point(17, 8)) + @testing.combinations( + "legacy", "statement", "values", "stmt_returning", "values_returning" + ) + def test_bulk_insert(self, type_): + Edge, Point = (self.classes.Edge, self.classes.Point) + Graph = self.classes.Graph + + sess = self._fixture() + + graph = Graph(id=2) + sess.add(graph) + sess.flush() + graph_id = 2 + + data = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "end": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(25) + ] + returning = False + if type_ == "statement": + sess.execute(insert(Edge), data) + elif type_ == "stmt_returning": + result = sess.scalars(insert(Edge).returning(Edge), data) + returning = True + elif type_ == "values": + sess.execute(insert(Edge).values(data)) + elif type_ == "values_returning": + result = sess.scalars(insert(Edge).values(data).returning(Edge)) + returning = True + elif type_ == "legacy": + sess.bulk_insert_mappings(Edge, data) + else: + assert False + + if returning: + eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data]) + + edges = self.tables.edges + eq_( + sess.execute( + select(edges.c["x1", "y1", "x2", "y2"]) + .where(edges.c.graph_id == graph_id) + .order_by(edges.c.id) + ).all(), + [ + (e["start"].x, e["start"].y, e["end"].x, e["end"].y) + for e in data + ], + ) + + @testing.combinations("legacy", "statement") + def test_bulk_insert_heterogeneous(self, type_): + Edge, Point = (self.classes.Edge, self.classes.Point) + Graph = self.classes.Graph + + sess = self._fixture() + + graph = Graph(id=2) + sess.add(graph) + sess.flush() + graph_id = 2 + + d1 = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "end": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(3) + ] + d2 = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(2) + ] + d3 = [ + { + "x2": random.randint(1, 50), + "y2": random.randint(1, 50), + "graph_id": graph_id, + } + for i in range(2) + ] + data = d1 + d2 + d3 + random.shuffle(data) + + assert_data = [ + { + "start": d["start"] if "start" in d else None, + "end": d["end"] + if "end" in d + else Point(d["x2"], d["y2"]) + if "x2" in d + else None, + "graph_id": d["graph_id"], + } + for d in data + ] + + if type_ == "statement": + sess.execute(insert(Edge), data) + elif type_ == "legacy": + sess.bulk_insert_mappings(Edge, data) + else: + assert False + + edges = self.tables.edges + eq_( + sess.execute( + select(edges.c["x1", "y1", "x2", "y2"]) + .where(edges.c.graph_id == graph_id) + .order_by(edges.c.id) + ).all(), + [ + ( + e["start"].x if e["start"] else None, + e["start"].y if e["start"] else None, + e["end"].x if e["end"] else None, + e["end"].y if e["end"] else None, + ) + for e in assert_data + ], + ) + + @testing.combinations("legacy", "statement") + def test_bulk_update(self, type_): + Edge, Point = (self.classes.Edge, self.classes.Point) + Graph = self.classes.Graph + + sess = self._fixture() + + graph = Graph(id=2) + sess.add(graph) + sess.flush() + graph_id = 2 + + data = [ + { + "start": Point(random.randint(1, 50), random.randint(1, 50)), + "end": Point(random.randint(1, 50), random.randint(1, 50)), + "graph_id": graph_id, + } + for i in range(25) + ] + sess.execute(insert(Edge), data) + + inserted_data = [ + dict(row._mapping) + for row in sess.execute( + select(Edge.id, Edge.start, Edge.end, Edge.graph_id) + .where(Edge.graph_id == graph_id) + .order_by(Edge.id) + ) + ] + + to_update = [] + updated_pks = {} + for rec in random.choices(inserted_data, k=7): + rec_copy = dict(rec) + updated_pks[rec_copy["id"]] = rec_copy + rec_copy["start"] = Point( + random.randint(1, 50), random.randint(1, 50) + ) + rec_copy["end"] = Point( + random.randint(1, 50), random.randint(1, 50) + ) + to_update.append(rec_copy) + + expected_dataset = [ + updated_pks[row["id"]] if row["id"] in updated_pks else row + for row in inserted_data + ] + + if type_ == "statement": + sess.execute(update(Edge), to_update) + elif type_ == "legacy": + sess.bulk_update_mappings(Edge, to_update) + else: + assert False + + edges = self.tables.edges + eq_( + sess.execute( + select(edges.c["x1", "y1", "x2", "y2"]) + .where(edges.c.graph_id == graph_id) + .order_by(edges.c.id) + ).all(), + [ + (e["start"].x, e["start"].y, e["end"].x, e["end"].y) + for e in expected_dataset + ], + ) + def test_get_history(self): Edge = self.classes.Edge Point = self.classes.Point diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index a23d8b735..7f0f504b5 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -1122,7 +1122,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest): [ CompiledSQL( "INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", + "VALUES (:person_id, :data) RETURNING ball.id", [ {"person_id": None, "data": "some data"}, {"person_id": None, "data": "some data"}, diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index 7860f5eb1..e738689b8 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -383,20 +383,24 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): CompiledSQL( "UPDATE test SET foo=:foo WHERE test.id = :test_id", [{"foo": 5, "test_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test SET foo=:foo WHERE test.id = :test_id", [{"foo": 6, "test_id": 2}], + enable_returning=False, ), CompiledSQL( "SELECT test.bar AS test_bar FROM test " "WHERE test.id = :pk_1", [{"pk_1": 1}], + enable_returning=False, ), CompiledSQL( "SELECT test.bar AS test_bar FROM test " "WHERE test.id = :pk_1", [{"pk_1": 2}], + enable_returning=False, ), ) else: diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 24870e20f..75955afb5 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -661,8 +661,17 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): canary = self._flag_fixture(sess) - sess.execute(delete(User).filter_by(id=18)) - sess.execute(update(User).filter_by(id=18).values(name="eighteen")) + sess.execute( + delete(User) + .filter_by(id=18) + .execution_options(synchronize_session="evaluate") + ) + sess.execute( + update(User) + .filter_by(id=18) + .values(name="eighteen") + .execution_options(synchronize_session="evaluate") + ) eq_( canary.mock_calls, diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index b94998716..fc452dc9c 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -2868,12 +2868,14 @@ class SaveTest2(_fixtures.FixtureTest): testing.db.dialect.insert_executemany_returning, [ CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", + "INSERT INTO users (name) VALUES (:name) " + "RETURNING users.id", [{"name": "u1"}, {"name": "u2"}], ), CompiledSQL( "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", [ {"user_id": 1, "email_address": "a1"}, {"user_id": 2, "email_address": "a2"}, diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index dd3b88915..855b44e81 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -98,7 +98,8 @@ class RudimentaryFlushTest(UOWTest): [ CompiledSQL( "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", lambda ctx: [ {"email_address": "a1", "user_id": u1.id}, {"email_address": "a2", "user_id": u1.id}, @@ -220,7 +221,8 @@ class RudimentaryFlushTest(UOWTest): [ CompiledSQL( "INSERT INTO addresses (user_id, email_address) " - "VALUES (:user_id, :email_address)", + "VALUES (:user_id, :email_address) " + "RETURNING addresses.id", lambda ctx: [ {"email_address": "a1", "user_id": u1.id}, {"email_address": "a2", "user_id": u1.id}, @@ -889,7 +891,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n1.id, "data": "n2"}, {"parent_id": n1.id, "data": "n3"}, @@ -1003,7 +1005,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n1.id, "data": "n2"}, {"parent_id": n1.id, "data": "n3"}, @@ -1165,7 +1167,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n1.id, "data": "n11"}, {"parent_id": n1.id, "data": "n12"}, @@ -1196,7 +1198,7 @@ class SingleCycleTest(UOWTest): [ CompiledSQL( "INSERT INTO nodes (parent_id, data) VALUES " - "(:parent_id, :data)", + "(:parent_id, :data) RETURNING nodes.id", lambda ctx: [ {"parent_id": n12.id, "data": "n121"}, {"parent_id": n12.id, "data": "n122"}, @@ -2099,7 +2101,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): testing.db.dialect.insert_executemany_returning, [ CompiledSQL( - "INSERT INTO t (data) VALUES (:data)", + "INSERT INTO t (data) VALUES (:data) RETURNING t.id", [{"data": "t1"}, {"data": "t2"}], ), ], @@ -2472,20 +2474,24 @@ class EagerDefaultsTest(fixtures.MappedTest): CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)", [{"id": 1}], + enable_returning=False, ), CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)", [{"id": 2}], + enable_returning=False, ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :pk_1", [{"pk_1": 1}], + enable_returning=False, ), CompiledSQL( "SELECT test.foo AS test_foo FROM test " "WHERE test.id = :pk_1", [{"pk_1": 2}], + enable_returning=False, ), ) @@ -2678,20 +2684,24 @@ class EagerDefaultsTest(fixtures.MappedTest): CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", [{"foo": 5, "test2_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", [{"foo": 6, "bar": 10, "test2_id": 2}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", [{"foo": 7, "test2_id": 3}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", [{"foo": 8, "bar": 12, "test2_id": 4}], + enable_returning=False, ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " @@ -2772,31 +2782,37 @@ class EagerDefaultsTest(fixtures.MappedTest): "UPDATE test4 SET foo=:foo, bar=5 + 3 " "WHERE test4.id = :test4_id", [{"foo": 5, "test4_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test4 SET foo=:foo, bar=:bar " "WHERE test4.id = :test4_id", [{"foo": 6, "bar": 10, "test4_id": 2}], + enable_returning=False, ), CompiledSQL( "UPDATE test4 SET foo=:foo, bar=5 + 3 " "WHERE test4.id = :test4_id", [{"foo": 7, "test4_id": 3}], + enable_returning=False, ), CompiledSQL( "UPDATE test4 SET foo=:foo, bar=:bar " "WHERE test4.id = :test4_id", [{"foo": 8, "bar": 12, "test4_id": 4}], + enable_returning=False, ), CompiledSQL( "SELECT test4.bar AS test4_bar FROM test4 " "WHERE test4.id = :pk_1", [{"pk_1": 1}], + enable_returning=False, ), CompiledSQL( "SELECT test4.bar AS test4_bar FROM test4 " "WHERE test4.id = :pk_1", [{"pk_1": 3}], + enable_returning=False, ), ], ), @@ -2871,20 +2887,24 @@ class EagerDefaultsTest(fixtures.MappedTest): "UPDATE test2 SET foo=:foo, bar=1 + 1 " "WHERE test2.id = :test2_id", [{"foo": 5, "test2_id": 1}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=:bar " "WHERE test2.id = :test2_id", [{"foo": 6, "bar": 10, "test2_id": 2}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", [{"foo": 7, "test2_id": 3}], + enable_returning=False, ), CompiledSQL( "UPDATE test2 SET foo=:foo, bar=5 + 7 " "WHERE test2.id = :test2_id", [{"foo": 8, "test2_id": 4}], + enable_returning=False, ), CompiledSQL( "SELECT test2.bar AS test2_bar FROM test2 " diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index abd5833be..84e5a83b0 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -1424,12 +1424,10 @@ class ServerVersioningTest(fixtures.MappedTest): sess.add(f1) statements = [ - # note that the assertsql tests the rule against - # "default" - on a "returning" backend, the statement - # includes "RETURNING" CompiledSQL( "INSERT INTO version_table (version_id, value) " - "VALUES (1, :value)", + "VALUES (1, :value) " + "RETURNING version_table.id, version_table.version_id", lambda ctx: [{"value": "f1"}], ) ] @@ -1493,6 +1491,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f2", } ], + enable_returning=False, ), CompiledSQL( "SELECT version_table.version_id " @@ -1618,6 +1617,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f1a", } ], + enable_returning=False, ), CompiledSQL( "UPDATE version_table SET version_id=2, value=:value " @@ -1630,6 +1630,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f2a", } ], + enable_returning=False, ), CompiledSQL( "UPDATE version_table SET version_id=2, value=:value " @@ -1642,6 +1643,7 @@ class ServerVersioningTest(fixtures.MappedTest): "value": "f3a", } ], + enable_returning=False, ), CompiledSQL( "SELECT version_table.version_id " diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 42cf31bf5..4f776e300 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -100,10 +100,55 @@ class CursorResultTest(fixtures.TablesTest): Table( "test", metadata, - Column("x", Integer, primary_key=True), + Column( + "x", Integer, primary_key=True, test_needs_autoincrement=False + ), Column("y", String(50)), ) + @testing.requires.insert_returning + def test_splice_horizontally(self, connection): + users = self.tables.users + addresses = self.tables.addresses + + r1 = connection.execute( + users.insert().returning(users.c.user_name, users.c.user_id), + [ + dict(user_id=1, user_name="john"), + dict(user_id=2, user_name="jack"), + ], + ) + + r2 = connection.execute( + addresses.insert().returning( + addresses.c.address_id, + addresses.c.address, + addresses.c.user_id, + ), + [ + dict(address_id=1, user_id=1, address="foo@bar.com"), + dict(address_id=2, user_id=2, address="bar@bat.com"), + ], + ) + + rows = r1.splice_horizontally(r2).all() + eq_( + rows, + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ], + ) + + eq_(rows[0]._mapping[users.c.user_id], 1) + eq_(rows[0]._mapping[addresses.c.user_id], 1) + eq_(rows[1].address, "bar@bat.com") + + with expect_raises_message( + exc.InvalidRequestError, "Ambiguous column name 'user_id'" + ): + rows[0].user_id + def test_keys_no_rows(self, connection): for i in range(2): diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index f8cc32517..c26f825c2 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -23,6 +23,7 @@ from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing import mock from sqlalchemy.testing import provision from sqlalchemy.testing.schema import Column @@ -76,6 +77,7 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): stmt = stmt.returning(t.c.x) stmt = stmt.return_defaults() + assert_raises_message( sa_exc.CompileError, r"Can't compile statement that includes returning\(\) " @@ -330,6 +332,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): table = self.tables.returning_tbl exprs = testing.resolve_lambda(testcase, table=table) + result = connection.execute( table.insert().returning(*exprs), {"persons": 5, "full": False, "strval": "str1"}, @@ -679,6 +682,30 @@ class InsertReturnDefaultsTest(fixtures.TablesTest): Column("upddef", Integer, onupdate=IncDefault()), ) + Table( + "table_no_addtl_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + class MyType(TypeDecorator): + impl = String(50) + + def process_result_value(self, value, dialect): + return f"PROCESSED! {value}" + + Table( + "table_datatype_has_result_proc", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", MyType()), + ) + def test_chained_insert_pk(self, connection): t1 = self.tables.t1 result = connection.execute( @@ -758,6 +785,38 @@ class InsertReturnDefaultsTest(fixtures.TablesTest): ) eq_(result.inserted_primary_key, (1,)) + def test_insert_w_defaults_supplemental_cols(self, connection): + t1 = self.tables.t1 + result = connection.execute( + t1.insert().return_defaults(supplemental_cols=[t1.c.id]), + {"data": "d1"}, + ) + eq_(result.all(), [(1, 0, None)]) + + def test_insert_w_no_defaults_supplemental_cols(self, connection): + t1 = self.tables.table_no_addtl_defaults + result = connection.execute( + t1.insert().return_defaults(supplemental_cols=[t1.c.id]), + {"data": "d1"}, + ) + eq_(result.all(), [(1,)]) + + def test_insert_w_defaults_supplemental_processor_cols(self, connection): + """test that the cursor._rewind() used by supplemental RETURNING + clears out result-row processors as we will have already processed + the rows. + + """ + + t1 = self.tables.table_datatype_has_result_proc + result = connection.execute( + t1.insert().return_defaults( + supplemental_cols=[t1.c.id, t1.c.data] + ), + {"data": "d1"}, + ) + eq_(result.all(), [(1, "PROCESSED! d1")]) + class UpdatedReturnDefaultsTest(fixtures.TablesTest): __requires__ = ("update_returning",) @@ -792,6 +851,7 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( t1.update().values(upddef=2).return_defaults(t1.c.data) ) @@ -800,6 +860,72 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest): [None], ) + def test_update_values_col_is_excluded(self, connection): + """columns that are in values() are not returned""" + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + + result = connection.execute( + t1.update().values(data="x", upddef=2).return_defaults(t1.c.data) + ) + is_(result.returned_defaults, None) + + result = connection.execute( + t1.update() + .values(data="x", upddef=2) + .return_defaults(t1.c.data, t1.c.id) + ) + eq_(result.returned_defaults, (1,)) + + def test_update_supplemental_cols(self, connection): + """with supplemental_cols, we can get back arbitrary cols.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.update() + .values(data="x", insdef=3) + .return_defaults(supplemental_cols=[t1.c.data, t1.c.insdef]) + ) + + row = result.returned_defaults + + # row has all the cols in it + eq_(row, ("x", 3, 1)) + eq_(row._mapping[t1.c.upddef], 1) + eq_(row._mapping[t1.c.insdef], 3) + + # result is rewound + # but has both return_defaults + supplemental_cols + eq_(result.all(), [("x", 3, 1)]) + + def test_update_expl_return_defaults_plus_supplemental_cols( + self, connection + ): + """with supplemental_cols, we can get back arbitrary cols.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.update() + .values(data="x", insdef=3) + .return_defaults( + t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef] + ) + ) + + row = result.returned_defaults + + # row has all the cols in it + eq_(row, (1, "x", 3)) + eq_(row._mapping[t1.c.id], 1) + eq_(row._mapping[t1.c.insdef], 3) + assert t1.c.upddef not in row._mapping + + # result is rewound + # but has both return_defaults + supplemental_cols + eq_(result.all(), [(1, "x", 3)]) + def test_update_sql_expr(self, connection): from sqlalchemy import literal @@ -833,6 +959,75 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest): eq_(dict(result.returned_defaults._mapping), {"upddef": 1}) +class DeleteReturnDefaultsTest(fixtures.TablesTest): + __requires__ = ("delete_returning",) + run_define_tables = "each" + __backend__ = True + + define_tables = InsertReturnDefaultsTest.define_tables + + def test_delete(self, connection): + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute(t1.delete().return_defaults(t1.c.upddef)) + eq_( + [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] + ) + + def test_delete_empty_return_defaults(self, connection): + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=5)) + result = connection.execute(t1.delete().return_defaults()) + + # there's no "delete" default, so we get None. we have to + # ask for them in all cases + eq_(result.returned_defaults, None) + + def test_delete_non_default(self, connection): + """test that a column not marked at all as a + default works with this feature.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute(t1.delete().return_defaults(t1.c.data)) + eq_( + [result.returned_defaults._mapping[k] for k in (t1.c.data,)], + [None], + ) + + def test_delete_non_default_plus_default(self, connection): + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.delete().return_defaults(t1.c.data, t1.c.upddef) + ) + eq_( + dict(result.returned_defaults._mapping), + {"data": None, "upddef": 1}, + ) + + def test_delete_supplemental_cols(self, connection): + """with supplemental_cols, we can get back arbitrary cols.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.delete().return_defaults( + t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef] + ) + ) + + row = result.returned_defaults + + # row has all the cols in it + eq_(row, (1, None, 0)) + eq_(row._mapping[t1.c.insdef], 0) + + # result is rewound + # but has both return_defaults + supplemental_cols + eq_(result.all(), [(1, None, 0)]) + + class InsertManyReturnDefaultsTest(fixtures.TablesTest): __requires__ = ("insert_executemany_returning",) run_define_tables = "each" diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 64ff2e421..5ef927b15 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -44,6 +44,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import visitors +from sqlalchemy.sql.dml import Insert from sqlalchemy.sql.selectable import LABEL_STYLE_NONE from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -3029,6 +3030,26 @@ class AnnotationsTest(fixtures.TestBase): eq_(whereclause.left._annotations, {"foo": "bar"}) eq_(whereclause.right._annotations, {"foo": "bar"}) + @testing.combinations(True, False, None) + def test_setup_inherit_cache(self, inherit_cache_value): + if inherit_cache_value is None: + + class MyInsertThing(Insert): + pass + + else: + + class MyInsertThing(Insert): + inherit_cache = inherit_cache_value + + t = table("t", column("x")) + anno = MyInsertThing(t)._annotate({"foo": "bar"}) + + if inherit_cache_value is not None: + is_(type(anno).__dict__["inherit_cache"], inherit_cache_value) + else: + assert "inherit_cache" not in type(anno).__dict__ + def test_proxy_set_iteration_includes_annotated(self): from sqlalchemy.schema import Column |
