diff options
| author | Federico Caselli <cfederico87@gmail.com> | 2022-11-08 22:12:47 +0100 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-11-26 18:49:06 -0500 |
| commit | 61443aa62bbef158274ae393db399fec7f054c2d (patch) | |
| tree | 18d8794c2da57295f7b48530457ca9e71a60dfdb /test/sql | |
| parent | 5cc3825da3cdda6bd80e4fe7250b795c15ca4be3 (diff) | |
| download | sqlalchemy-61443aa62bbef158274ae393db399fec7f054c2d.tar.gz | |
Implement ScalarValue
Added :class:`_expression.ScalarValues` that can be used as a column
element allowing using :class:`_expression.Values` inside IN clauses
or in conjunction with ``ANY`` or ``ALL`` collection aggregates.
This new class is generated using the method
:meth:`_expression.Values.scalar_values`.
The :class:`_expression.Values` instance is now coerced to a
:class:`_expression.ScalarValues` when used in a ``IN`` or ``NOT IN``
operation.
Fixes: #6289
Change-Id: Iac22487ccb01553684b908e54d01c0687fa739f1
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_compare.py | 79 | ||||
| -rw-r--r-- | test/sql/test_compiler.py | 12 | ||||
| -rw-r--r-- | test/sql/test_operators.py | 247 | ||||
| -rw-r--r-- | test/sql/test_roles.py | 17 |
4 files changed, 199 insertions, 156 deletions
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index f18c79c7b..87710fdd9 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -625,6 +625,15 @@ class CoreFixtures: column("mykey", Integer), column("mytext", String), column("myint", Integer), + name="myvalues", + literal_binds=True, + ) + .data([(1, "textA", 99), (2, "textB", 88)]) + ._annotate({"nocache": True}), + values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), name="myothervalues", ) .data([(1, "textA", 99), (2, "textB", 88)]) @@ -647,15 +656,62 @@ class CoreFixtures: ._annotate({"nocache": True}), # TODO: difference in type # values( - # [ - # column("mykey", Integer), - # column("mytext", Text), - # column("myint", Integer), - # ], - # (1, "textA", 99), - # (2, "textB", 88), - # alias_name="myvalues", - # ), + # column("mykey", Integer), + # column("mytext", Text), + # column("myint", Integer), + # name="myvalues", + # ) + # .data([(1, "textA", 99), (2, "textB", 88)]) + # ._annotate({"nocache": True}), + ), + lambda: ( + values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myvalues", + ) + .data([(1, "textA", 99), (2, "textB", 88)]) + .scalar_values() + ._annotate({"nocache": True}), + values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myvalues", + literal_binds=True, + ) + .data([(1, "textA", 99), (2, "textB", 88)]) + .scalar_values() + ._annotate({"nocache": True}), + values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myvalues", + ) + .data([(1, "textA", 89), (2, "textG", 88)]) + .scalar_values() + ._annotate({"nocache": True}), + values( + column("mykey", Integer), + column("mynottext", String), + column("myint", Integer), + name="myvalues", + ) + .data([(1, "textA", 99), (2, "textB", 88)]) + .scalar_values() + ._annotate({"nocache": True}), + # TODO: difference in type + # values( + # column("mykey", Integer), + # column("mytext", Text), + # column("myint", Integer), + # name="myvalues", + # ) + # .data([(1, "textA", 99), (2, "textB", 88)]) + # .scalar_values() + # ._annotate({"nocache": True}), ), lambda: ( select(table_a.c.a), @@ -1304,9 +1360,8 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): compare_annotations=True, compare_values=compare_values, ), - "%r != %r" % (case_a[a], case_b[b]), + f"{case_a[a]!r} != {case_b[b]!r} (index {a} {b})", ) - else: is_false( case_a[a].compare( @@ -1314,7 +1369,7 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): compare_annotations=True, compare_values=compare_values, ), - "%r == %r" % (case_a[a], case_b[b]), + f"{case_a[a]!r} == {case_b[b]!r} (index {a} {b})", ) def test_compare_col_identity(self): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index e5a149c49..c71cfd61f 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -4526,10 +4526,16 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): "((myothertable.otherid, myothertable.othername))", ) + @testing.variation("scalar_subquery", [True, False]) + def test_select_in(self, scalar_subquery): + + stmt = select(table2.c.otherid, table2.c.othername) + + if scalar_subquery: + stmt = stmt.scalar_subquery() + self.assert_compile( - tuple_(table1.c.myid, table1.c.name).in_( - select(table2.c.otherid, table2.c.othername) - ), + tuple_(table1.c.myid, table1.c.name).in_(stmt), "(mytable.myid, mytable.name) IN (SELECT " "myothertable.otherid, myothertable.othername FROM myothertable)", ) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index e00cacad8..103520f1f 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -43,6 +43,7 @@ from sqlalchemy.sql import roles from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql import true +from sqlalchemy.sql import values from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.elements import Label @@ -2390,6 +2391,23 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): dialect="default_enhanced", ) + @testing.combinations(lambda v: v, lambda v: v.scalar_values()) + def test_in_values(self, scalar): + t1, t2 = self.table1, self.table2 + v = scalar(values(t2.c.otherid).data([(1,), (42,)])) + self.assert_compile( + select(t1.c.myid.in_(v)), + "SELECT mytable.myid IN (VALUES (:param_1), (:param_2)) " + "AS anon_1 FROM mytable", + params={"param_1": 1, "param_2": 42}, + ) + self.assert_compile( + select(t1.c.myid.not_in(v)), + "SELECT (mytable.myid NOT IN (VALUES (:param_1), (:param_2))) " + "AS anon_1 FROM mytable", + params={"param_1": 1, "param_2": 42}, + ) + class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -2708,6 +2726,25 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): assert not text("x = y")._is_implicitly_boolean assert not literal_column("x = y")._is_implicitly_boolean + def test_scalar_select(self): + t = self.table1 + expr = select(t.c.myid).where(t.c.myid > 5).scalar_subquery() + self.assert_compile( + not_(expr), + "NOT (SELECT mytable.myid FROM mytable " + "WHERE mytable.myid > :myid_1)", + params={"myid_1": 5}, + ) + + def test_scalar_values(self): + t = self.table1 + expr = values(t.c.myid).data([(7,), (42,)]).scalar_values() + self.assert_compile( + not_(expr), + "NOT (VALUES (:param_1), (:param_2))", + params={"param_1": 7, "param_2": 42}, + ) + class LikeTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -4324,85 +4361,61 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile(expr(col), "NULL = ANY (tab1.%s)" % col.name) - def test_any_array(self, t_fixture): - t = t_fixture - - self.assert_compile( - 5 == any_(t.c.arrval), - ":param_1 = ANY (tab1.arrval)", - checkparams={"param_1": 5}, - ) - - def test_any_array_method(self, t_fixture): - t = t_fixture - - self.assert_compile( - 5 == t.c.arrval.any_(), - ":param_1 = ANY (tab1.arrval)", - checkparams={"param_1": 5}, - ) - - def test_all_array(self, t_fixture): - t = t_fixture - - self.assert_compile( - 5 == all_(t.c.arrval), - ":param_1 = ALL (tab1.arrval)", - checkparams={"param_1": 5}, - ) - - def test_all_array_method(self, t_fixture): - t = t_fixture - - self.assert_compile( - 5 == t.c.arrval.all_(), - ":param_1 = ALL (tab1.arrval)", - checkparams={"param_1": 5}, - ) + @testing.fixture( + params=[ + ("ANY", any_), + ("ANY", lambda x: x.any_()), + ("ALL", all_), + ("ALL", lambda x: x.all_()), + ] + ) + def operator(self, request): + return request.param + + @testing.fixture( + params=[ + ("ANY", lambda x, *o: x.any(*o)), + ("ALL", lambda x, *o: x.all(*o)), + ] + ) + def array_op(self, request): + return request.param - def test_any_comparator_array(self, t_fixture): + def test_array(self, t_fixture, operator): t = t_fixture - + op, fn = operator self.assert_compile( - 5 > any_(t.c.arrval), - ":param_1 > ANY (tab1.arrval)", + 5 == fn(t.c.arrval), + f":param_1 = {op} (tab1.arrval)", checkparams={"param_1": 5}, ) - def test_all_comparator_array(self, t_fixture): + def test_comparator_array(self, t_fixture, operator): t = t_fixture - + op, fn = operator self.assert_compile( - 5 > all_(t.c.arrval), - ":param_1 > ALL (tab1.arrval)", + 5 > fn(t.c.arrval), + f":param_1 > {op} (tab1.arrval)", checkparams={"param_1": 5}, ) - def test_any_comparator_array_wexpr(self, t_fixture): + def test_comparator_array_wexpr(self, t_fixture, operator): t = t_fixture - - self.assert_compile( - t.c.data > any_(t.c.arrval), - "tab1.data > ANY (tab1.arrval)", - checkparams={}, - ) - - def test_all_comparator_array_wexpr(self, t_fixture): - t = t_fixture - + op, fn = operator self.assert_compile( - t.c.data > all_(t.c.arrval), - "tab1.data > ALL (tab1.arrval)", + t.c.data > fn(t.c.arrval), + f"tab1.data > {op} (tab1.arrval)", checkparams={}, ) - def test_illegal_ops(self, t_fixture): + def test_illegal_ops(self, t_fixture, operator): t = t_fixture + op, fn = operator assert_raises_message( exc.ArgumentError, "Only comparison operators may be used with ANY/ALL", - lambda: 5 + all_(t.c.arrval), + lambda: 5 + fn(t.c.arrval), ) # TODO: @@ -4410,86 +4423,47 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): # as the left-hand side just does its thing. Types # would need to reject their right-hand side. self.assert_compile( - t.c.data + all_(t.c.arrval), "tab1.data + ALL (tab1.arrval)" + t.c.data + fn(t.c.arrval), f"tab1.data + {op} (tab1.arrval)" ) - @testing.combinations("all", "any", argnames="op") - def test_any_all_bindparam_coercion(self, t_fixture, op): + def test_bindparam_coercion(self, t_fixture, array_op): """test #7979""" t = t_fixture + op, fn = array_op - if op == "all": - expr = t.c.arrval.all(bindparam("param")) - expected = "%(param)s = ALL (tab1.arrval)" - elif op == "any": - expr = t.c.arrval.any(bindparam("param")) - expected = "%(param)s = ANY (tab1.arrval)" - else: - assert False - + expr = fn(t.c.arrval, bindparam("param")) + expected = f"%(param)s = {op} (tab1.arrval)" is_(expr.left.type._type_affinity, Integer) self.assert_compile(expr, expected, dialect="postgresql") - def test_any_array_comparator_accessor(self, t_fixture): - t = t_fixture - - self.assert_compile( - t.c.arrval.any(5, operator.gt), - ":arrval_1 > ANY (tab1.arrval)", - checkparams={"arrval_1": 5}, - ) - - def test_any_array_comparator_negate_accessor(self, t_fixture): - t = t_fixture - - self.assert_compile( - ~t.c.arrval.any(5, operator.gt), - "NOT (:arrval_1 > ANY (tab1.arrval))", - checkparams={"arrval_1": 5}, - ) - - def test_all_array_comparator_accessor(self, t_fixture): + def test_array_comparator_accessor(self, t_fixture, array_op): t = t_fixture + op, fn = array_op self.assert_compile( - t.c.arrval.all(5, operator.gt), - ":arrval_1 > ALL (tab1.arrval)", + fn(t.c.arrval, 5, operator.gt), + f":arrval_1 > {op} (tab1.arrval)", checkparams={"arrval_1": 5}, ) - def test_all_array_comparator_negate_accessor(self, t_fixture): + def test_array_comparator_negate_accessor(self, t_fixture, array_op): t = t_fixture + op, fn = array_op self.assert_compile( - ~t.c.arrval.all(5, operator.gt), - "NOT (:arrval_1 > ALL (tab1.arrval))", + ~fn(t.c.arrval, 5, operator.gt), + f"NOT (:arrval_1 > {op} (tab1.arrval))", checkparams={"arrval_1": 5}, ) - def test_any_array_expression(self, t_fixture): - t = t_fixture - - self.assert_compile( - 5 == any_(t.c.arrval[5:6] + postgresql.array([3, 4])), - "%(param_1)s = ANY (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || " - "ARRAY[%(param_2)s, %(param_3)s])", - checkparams={ - "arrval_2": 6, - "param_1": 5, - "param_3": 4, - "arrval_1": 5, - "param_2": 3, - }, - dialect="postgresql", - ) - - def test_all_array_expression(self, t_fixture): + def test_array_expression(self, t_fixture, operator): t = t_fixture + op, fn = operator self.assert_compile( - 5 == all_(t.c.arrval[5:6] + postgresql.array([3, 4])), - "%(param_1)s = ALL (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || " + 5 == fn(t.c.arrval[5:6] + postgresql.array([3, 4])), + f"%(param_1)s = {op} (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || " "ARRAY[%(param_2)s, %(param_3)s])", checkparams={ "arrval_2": 6, @@ -4501,44 +4475,35 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): dialect="postgresql", ) - def test_any_subq(self, t_fixture): - t = t_fixture - - self.assert_compile( - 5 == any_(select(t.c.data).where(t.c.data < 10).scalar_subquery()), - ":param_1 = ANY (SELECT tab1.data " - "FROM tab1 WHERE tab1.data < :data_1)", - checkparams={"data_1": 10, "param_1": 5}, - ) - - def test_any_subq_method(self, t_fixture): + def test_subq(self, t_fixture, operator): t = t_fixture + op, fn = operator self.assert_compile( - 5 - == select(t.c.data).where(t.c.data < 10).scalar_subquery().any_(), - ":param_1 = ANY (SELECT tab1.data " + 5 == fn(select(t.c.data).where(t.c.data < 10).scalar_subquery()), + f":param_1 = {op} (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", checkparams={"data_1": 10, "param_1": 5}, ) - def test_all_subq(self, t_fixture): + def test_scalar_values(self, t_fixture, operator): t = t_fixture + op, fn = operator self.assert_compile( - 5 == all_(select(t.c.data).where(t.c.data < 10).scalar_subquery()), - ":param_1 = ALL (SELECT tab1.data " - "FROM tab1 WHERE tab1.data < :data_1)", - checkparams={"data_1": 10, "param_1": 5}, + 5 == fn(values(t.c.data).data([(1,), (42,)]).scalar_values()), + f":param_1 = {op} (VALUES (:param_2), (:param_3))", + checkparams={"param_1": 5, "param_2": 1, "param_3": 42}, ) - def test_all_subq_method(self, t_fixture): + @testing.combinations(any_, all_, argnames="fn") + def test_values_illegal(self, t_fixture, fn): t = t_fixture - self.assert_compile( - 5 - == select(t.c.data).where(t.c.data < 10).scalar_subquery().all_(), - ":param_1 = ALL (SELECT tab1.data " - "FROM tab1 WHERE tab1.data < :data_1)", - checkparams={"data_1": 10, "param_1": 5}, - ) + with expect_raises_message( + exc.ArgumentError, + "SQL expression element expected, got .* " + "To create a column expression from a VALUES clause, " + r"use the .scalar_values\(\) method.", + ): + fn(values(t.c.data).data([(1,), (42,)])) diff --git a/test/sql/test_roles.py b/test/sql/test_roles.py index 5c9ed3588..d181e0d1a 100644 --- a/test/sql/test_roles.py +++ b/test/sql/test_roles.py @@ -12,6 +12,7 @@ from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import update +from sqlalchemy import values from sqlalchemy.schema import DDL from sqlalchemy.schema import Sequence from sqlalchemy.sql import ClauseElement @@ -189,6 +190,22 @@ class RoleTest(fixtures.TestBase): select(column("q")).alias(), ) + def test_values_advice(self): + value_expr = values( + column("id", Integer), column("name", String), name="my_values" + ).data([(1, "name1"), (2, "name2"), (3, "name3")]) + + assert_raises_message( + exc.ArgumentError, + r"SQL expression element expected, got <.*Values.*my_values>. To " + r"create a " + r"column expression from a VALUES clause, " + r"use the .scalar_values\(\) method.", + expect, + roles.ExpressionElementRole, + value_expr, + ) + def test_table_valued_advice(self): msg = ( r"SQL expression element expected, got %s. To create a " |
