summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-02-10 08:39:21 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2023-02-10 10:42:11 -0500
commiteb0861e8e69f8ce702301c558e552e1aeb2e9eba (patch)
tree638c8a08641e524b3ed54ee824275e5eb520ecbe /test
parent48aad8b244737ad7d2000056ce7320d6b32fa2de (diff)
downloadsqlalchemy-eb0861e8e69f8ce702301c558e552e1aeb2e9eba.tar.gz
generalize adapt_on_names to expect non-named elements
The fix in #9217 opened up adapt_on_names to more kinds of expressions than it was prepared for; adjust that logic and also refine in the ORM where we are using it, as we dont need it (yet) for the DML RETURNING use case. Fixed regression introduced in version 2.0.2 due to :ticket:`9217` where using DML RETURNING statements, as well as :meth:`_sql.Select.from_statement` constructs as was "fixed" in :ticket:`9217`, in conjunction with ORM mapped classes that used expressions such as with :func:`_orm.column_property`, would lead to an internal error within Core where it would attempt to match the expression by name. The fix repairs the Core issue, and also adjusts the fix in :ticket:`9217` to not take effect for the DML RETURNING use case, where it adds unnecessary overhead. Fixes: #9273 Change-Id: Ie0344efb12ff7df48f21e71e62dc598c76a6a0de
Diffstat (limited to 'test')
-rw-r--r--test/orm/dml/test_bulk_statements.py37
-rw-r--r--test/orm/test_froms.py35
-rw-r--r--test/sql/test_external_traversal.py34
3 files changed, 97 insertions, 9 deletions
diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py
index 78607e03d..0b26786d4 100644
--- a/test/orm/dml/test_bulk_statements.py
+++ b/test/orm/dml/test_bulk_statements.py
@@ -11,12 +11,14 @@ from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy import insert
from sqlalchemy import inspect
+from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.orm import aliased
+from sqlalchemy.orm import column_property
from sqlalchemy.orm import load_only
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
@@ -27,10 +29,11 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
from sqlalchemy.testing import provision
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
-class NoReturningTest(fixtures.TestBase):
+class InsertStmtTest(fixtures.TestBase):
def test_no_returning_error(self, decl_base):
class A(fixtures.ComparableEntity, decl_base):
__tablename__ = "a"
@@ -86,6 +89,38 @@ class NoReturningTest(fixtures.TestBase):
[("d3", 5), ("d4", 6)],
)
+ def test_insert_from_select_col_property(self, decl_base):
+ """test #9273"""
+
+ class User(ComparableEntity, decl_base):
+ __tablename__ = "users"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ name: Mapped[str] = mapped_column()
+ age: Mapped[int] = mapped_column()
+
+ is_adult: Mapped[bool] = column_property(age >= 18)
+
+ decl_base.metadata.create_all(testing.db)
+
+ stmt = select(
+ literal(1).label("id"),
+ literal("John").label("name"),
+ literal(30).label("age"),
+ )
+
+ insert_stmt = (
+ insert(User)
+ .from_select(["id", "name", "age"], stmt)
+ .returning(User)
+ )
+
+ s = fixture_session()
+ result = s.scalars(insert_stmt)
+
+ eq_(result.all(), [User(id=1, name="John", age=30)])
+
class BulkDMLReturningInhTest:
def test_insert_col_key_also_works_currently(self):
diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py
index e24062469..c2c237587 100644
--- a/test/orm/test_froms.py
+++ b/test/orm/test_froms.py
@@ -6,6 +6,7 @@ from sqlalchemy import exists
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
+from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy import String
@@ -25,6 +26,8 @@ from sqlalchemy.orm import configure_mappers
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm.context import ORMSelectCompileState
@@ -36,8 +39,8 @@ from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import in_
from sqlalchemy.testing import is_
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from test.orm import _fixtures
@@ -2728,7 +2731,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
eq_(q.all(), expected)
def test_unrelated_column(self):
- """Test for 9217"""
+ """Test for #9217"""
User = self.classes.User
@@ -2739,8 +2742,32 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
s = select(User).from_statement(q)
sess = fixture_session()
res = sess.scalars(s).one()
- in_("name", res.__dict__)
- eq_(res.name, "sandy")
+ eq_(res, User(name="sandy", id=7))
+
+ def test_unrelated_column_col_prop(self, decl_base):
+ """Test for #9217 combined with #9273"""
+
+ class User(ComparableEntity, decl_base):
+ __tablename__ = "some_user_table"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ name: Mapped[str] = mapped_column()
+ age: Mapped[int] = mapped_column()
+
+ is_adult: Mapped[bool] = column_property(age >= 18)
+
+ stmt = select(
+ literal(1).label("id"),
+ literal("John").label("name"),
+ literal(30).label("age"),
+ )
+
+ s = select(User).from_statement(stmt)
+ sess = fixture_session()
+ res = sess.scalars(s).one()
+
+ eq_(res, User(name="John", age=30, id=1))
def test_expression_selectable_matches_mzero(self):
User, Address = self.classes.User, self.classes.Address
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index 8ccbd8d20..b8f6e5685 100644
--- a/test/sql/test_external_traversal.py
+++ b/test/sql/test_external_traversal.py
@@ -1194,7 +1194,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
)
def test_this_thing_using_setup_joins_three(self):
-
j = t1.join(t2, t1.c.col1 == t2.c.col2)
s1 = select(j)
@@ -1239,7 +1238,6 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
)
def test_this_thing_using_setup_joins_four(self):
-
j = t1.join(t2, t1.c.col1 == t2.c.col2)
s1 = select(j)
@@ -1606,6 +1604,36 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
# not covered by a1, rejected by a2
is_(a3.columns[c2a1], c2a1)
+ @testing.combinations(True, False, argnames="colpresent")
+ @testing.combinations(True, False, argnames="adapt_on_names")
+ @testing.combinations(True, False, argnames="use_label")
+ def test_adapt_binary_col(self, colpresent, use_label, adapt_on_names):
+ """test #9273"""
+
+ if use_label:
+ stmt = select(t1.c.col1, (t1.c.col2 > 18).label("foo"))
+ else:
+ stmt = select(t1.c.col1, (t1.c.col2 > 18))
+
+ sq = stmt.subquery()
+
+ if colpresent:
+ s2 = select(sq.c[0], sq.c[1])
+ else:
+ s2 = select(sq.c[0])
+
+ a1 = sql_util.ColumnAdapter(s2, adapt_on_names=adapt_on_names)
+
+ is_(a1.columns[stmt.selected_columns[0]], s2.selected_columns[0])
+
+ if colpresent:
+ is_(a1.columns[stmt.selected_columns[1]], s2.selected_columns[1])
+ else:
+ is_(
+ a1.columns[stmt.selected_columns[1]],
+ a1.columns[stmt.selected_columns[1]],
+ )
+
class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@@ -1735,7 +1763,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
)
def test_adapt_select_w_unlabeled_fn(self):
-
expr = func.count(t1.c.col1)
stmt = select(t1, expr)
@@ -2335,7 +2362,6 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
assert s2.is_derived_from(s1)
def test_aliasedselect_to_aliasedselect_straight(self):
-
# original issue from ticket #904
s1 = select(t1).alias("foo")