diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-02-10 08:39:21 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-02-10 10:42:11 -0500 |
| commit | eb0861e8e69f8ce702301c558e552e1aeb2e9eba (patch) | |
| tree | 638c8a08641e524b3ed54ee824275e5eb520ecbe /test | |
| parent | 48aad8b244737ad7d2000056ce7320d6b32fa2de (diff) | |
| download | sqlalchemy-eb0861e8e69f8ce702301c558e552e1aeb2e9eba.tar.gz | |
generalize adapt_on_names to expect non-named elements
The fix in #9217 opened up adapt_on_names to more kinds of
expressions than it was prepared for; adjust that logic
and also refine in the ORM where we are using it, as we
dont need it (yet) for the DML RETURNING use case.
Fixed regression introduced in version 2.0.2 due to :ticket:`9217` where
using DML RETURNING statements, as well as
:meth:`_sql.Select.from_statement` constructs as was "fixed" in
:ticket:`9217`, in conjunction with ORM mapped classes that used
expressions such as with :func:`_orm.column_property`, would lead to an
internal error within Core where it would attempt to match the expression
by name. The fix repairs the Core issue, and also adjusts the fix in
:ticket:`9217` to not take effect for the DML RETURNING use case, where it
adds unnecessary overhead.
Fixes: #9273
Change-Id: Ie0344efb12ff7df48f21e71e62dc598c76a6a0de
Diffstat (limited to 'test')
| -rw-r--r-- | test/orm/dml/test_bulk_statements.py | 37 | ||||
| -rw-r--r-- | test/orm/test_froms.py | 35 | ||||
| -rw-r--r-- | test/sql/test_external_traversal.py | 34 |
3 files changed, 97 insertions, 9 deletions
diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 78607e03d..0b26786d4 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -11,12 +11,14 @@ from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import insert from sqlalchemy import inspect +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.orm import aliased +from sqlalchemy.orm import column_property from sqlalchemy.orm import load_only from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -27,10 +29,11 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing import provision from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session -class NoReturningTest(fixtures.TestBase): +class InsertStmtTest(fixtures.TestBase): def test_no_returning_error(self, decl_base): class A(fixtures.ComparableEntity, decl_base): __tablename__ = "a" @@ -86,6 +89,38 @@ class NoReturningTest(fixtures.TestBase): [("d3", 5), ("d4", 6)], ) + def test_insert_from_select_col_property(self, decl_base): + """test #9273""" + + class User(ComparableEntity, decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + + name: Mapped[str] = mapped_column() + age: Mapped[int] = mapped_column() + + is_adult: Mapped[bool] = column_property(age >= 18) + + decl_base.metadata.create_all(testing.db) + + stmt = select( + literal(1).label("id"), + literal("John").label("name"), + literal(30).label("age"), + ) + + insert_stmt = ( + insert(User) + .from_select(["id", "name", "age"], stmt) + .returning(User) + ) + + s = fixture_session() + result = s.scalars(insert_stmt) + + eq_(result.all(), [User(id=1, name="John", age=30)]) + class BulkDMLReturningInhTest: def test_insert_col_key_also_works_currently(self): diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index e24062469..c2c237587 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -6,6 +6,7 @@ from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String @@ -25,6 +26,8 @@ from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import contains_eager from sqlalchemy.orm import declarative_base from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm.context import ORMSelectCompileState @@ -36,8 +39,8 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from test.orm import _fixtures @@ -2728,7 +2731,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): eq_(q.all(), expected) def test_unrelated_column(self): - """Test for 9217""" + """Test for #9217""" User = self.classes.User @@ -2739,8 +2742,32 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): s = select(User).from_statement(q) sess = fixture_session() res = sess.scalars(s).one() - in_("name", res.__dict__) - eq_(res.name, "sandy") + eq_(res, User(name="sandy", id=7)) + + def test_unrelated_column_col_prop(self, decl_base): + """Test for #9217 combined with #9273""" + + class User(ComparableEntity, decl_base): + __tablename__ = "some_user_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + name: Mapped[str] = mapped_column() + age: Mapped[int] = mapped_column() + + is_adult: Mapped[bool] = column_property(age >= 18) + + stmt = select( + literal(1).label("id"), + literal("John").label("name"), + literal(30).label("age"), + ) + + s = select(User).from_statement(stmt) + sess = fixture_session() + res = sess.scalars(s).one() + + eq_(res, User(name="John", age=30, id=1)) def test_expression_selectable_matches_mzero(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 8ccbd8d20..b8f6e5685 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -1194,7 +1194,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_this_thing_using_setup_joins_three(self): - j = t1.join(t2, t1.c.col1 == t2.c.col2) s1 = select(j) @@ -1239,7 +1238,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_this_thing_using_setup_joins_four(self): - j = t1.join(t2, t1.c.col1 == t2.c.col2) s1 = select(j) @@ -1606,6 +1604,36 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # not covered by a1, rejected by a2 is_(a3.columns[c2a1], c2a1) + @testing.combinations(True, False, argnames="colpresent") + @testing.combinations(True, False, argnames="adapt_on_names") + @testing.combinations(True, False, argnames="use_label") + def test_adapt_binary_col(self, colpresent, use_label, adapt_on_names): + """test #9273""" + + if use_label: + stmt = select(t1.c.col1, (t1.c.col2 > 18).label("foo")) + else: + stmt = select(t1.c.col1, (t1.c.col2 > 18)) + + sq = stmt.subquery() + + if colpresent: + s2 = select(sq.c[0], sq.c[1]) + else: + s2 = select(sq.c[0]) + + a1 = sql_util.ColumnAdapter(s2, adapt_on_names=adapt_on_names) + + is_(a1.columns[stmt.selected_columns[0]], s2.selected_columns[0]) + + if colpresent: + is_(a1.columns[stmt.selected_columns[1]], s2.selected_columns[1]) + else: + is_( + a1.columns[stmt.selected_columns[1]], + a1.columns[stmt.selected_columns[1]], + ) + class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @@ -1735,7 +1763,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_adapt_select_w_unlabeled_fn(self): - expr = func.count(t1.c.col1) stmt = select(t1, expr) @@ -2335,7 +2362,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): assert s2.is_derived_from(s1) def test_aliasedselect_to_aliasedselect_straight(self): - # original issue from ticket #904 s1 = select(t1).alias("foo") |
