summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-22 10:57:00 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-22 23:29:11 -0400
commit63191fbef63ebfbf57e7b66bd6529305fc62c605 (patch)
tree7572c1bdcecbc4e3640e5860c42a8b031f595bce
parentfe2045fb1c767436ed1e32359fe005dabead504a (diff)
downloadsqlalchemy-63191fbef63ebfbf57e7b66bd6529305fc62c605.tar.gz
properly type array element in any() / all()
Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on PostgreSQL where using the ``.any()`` method to render SQL ANY(), given members of the Python enumeration as arguments, would produce a type adaptation failure on all drivers. Fixes: #6515 Change-Id: Ia1e3b4e10aaf264ed436ce6030d105fc60023433
-rw-r--r--doc/build/changelog/unreleased_14/6515.rst8
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py24
-rw-r--r--test/dialect/postgresql/test_compiler.py32
-rw-r--r--test/dialect/postgresql/test_types.py33
-rw-r--r--test/sql/test_operators.py16
5 files changed, 81 insertions, 32 deletions
diff --git a/doc/build/changelog/unreleased_14/6515.rst b/doc/build/changelog/unreleased_14/6515.rst
new file mode 100644
index 000000000..0ac5332b5
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/6515.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: bug, postgresql
+ :tickets: 6515
+
+ Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on
+ PostgreSQL where using the ``.any()`` method to render SQL ANY(), given
+ members of the Python enumeration as arguments, would produce a type
+ adaptation failure on all drivers.
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 64d6ea81b..65b97d565 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2715,9 +2715,11 @@ class ARRAY(
__slots__ = ()
+ type: ARRAY
+
def _setup_getitem(self, index):
- arr_type = cast(ARRAY, self.type)
+ arr_type = self.type
return_type: TypeEngine[Any]
@@ -2784,10 +2786,18 @@ class ARRAY(
elements = util.preloaded.sql_elements
operator = operator if operator else operators.eq
+ arr_type = self.type
+
# send plain BinaryExpression so that negate remains at None,
# leading to NOT expr for negation.
return elements.BinaryExpression(
- coercions.expect(roles.ExpressionElementRole, other),
+ coercions.expect(
+ roles.BinaryElementRole,
+ element=other,
+ operator=operator,
+ expr=self.expr,
+ bindparam_type=arr_type.item_type,
+ ),
elements.CollectionAggregate._create_any(self.expr),
operator,
)
@@ -2828,10 +2838,18 @@ class ARRAY(
elements = util.preloaded.sql_elements
operator = operator if operator else operators.eq
+ arr_type = self.type
+
# send plain BinaryExpression so that negate remains at None,
# leading to NOT expr for negation.
return elements.BinaryExpression(
- coercions.expect(roles.ExpressionElementRole, other),
+ coercions.expect(
+ roles.BinaryElementRole,
+ element=other,
+ operator=operator,
+ expr=self.expr,
+ bindparam_type=arr_type.item_type,
+ ),
elements.CollectionAggregate._create_all(self.expr),
operator,
)
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 2221fd30a..0fe5f7066 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1506,48 +1506,48 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
)
self.assert_compile(
postgresql.Any(4, c),
- "%(param_1)s = ANY (x)",
- checkparams={"param_1": 4},
+ "%(x_1)s = ANY (x)",
+ checkparams={"x_1": 4},
)
self.assert_compile(
c.any(5),
- "%(param_1)s = ANY (x)",
- checkparams={"param_1": 5},
+ "%(x_1)s = ANY (x)",
+ checkparams={"x_1": 5},
)
self.assert_compile(
~c.any(5),
- "NOT (%(param_1)s = ANY (x))",
- checkparams={"param_1": 5},
+ "NOT (%(x_1)s = ANY (x))",
+ checkparams={"x_1": 5},
)
self.assert_compile(
c.all(5),
- "%(param_1)s = ALL (x)",
- checkparams={"param_1": 5},
+ "%(x_1)s = ALL (x)",
+ checkparams={"x_1": 5},
)
self.assert_compile(
~c.all(5),
- "NOT (%(param_1)s = ALL (x))",
- checkparams={"param_1": 5},
+ "NOT (%(x_1)s = ALL (x))",
+ checkparams={"x_1": 5},
)
self.assert_compile(
c.any(5, operator=operators.ne),
- "%(param_1)s != ANY (x)",
- checkparams={"param_1": 5},
+ "%(x_1)s != ANY (x)",
+ checkparams={"x_1": 5},
)
self.assert_compile(
postgresql.All(6, c, operator=operators.gt),
- "%(param_1)s > ALL (x)",
- checkparams={"param_1": 6},
+ "%(x_1)s > ALL (x)",
+ checkparams={"x_1": 6},
)
self.assert_compile(
c.all(7, operator=operators.lt),
- "%(param_1)s < ALL (x)",
- checkparams={"param_1": 7},
+ "%(x_1)s < ALL (x)",
+ checkparams={"x_1": 7},
)
@testing.combinations(
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 8c4bb7fe7..bca952ade 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -1271,16 +1271,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
col = column("x", postgresql.ARRAY(Integer))
self.assert_compile(
select(col.any(7, operator=operators.lt)),
- "SELECT %(param_1)s < ANY (x) AS anon_1",
- checkparams={"param_1": 7},
+ "SELECT %(x_1)s < ANY (x) AS anon_1",
+ checkparams={"x_1": 7},
)
def test_array_all(self):
col = column("x", postgresql.ARRAY(Integer))
self.assert_compile(
select(col.all(7, operator=operators.lt)),
- "SELECT %(param_1)s < ALL (x) AS anon_1",
- checkparams={"param_1": 7},
+ "SELECT %(x_1)s < ALL (x) AS anon_1",
+ checkparams={"x_1": 7},
)
def test_array_contains(self):
@@ -2404,7 +2404,10 @@ class ArrayEnum(fixtures.TestBase):
metadata.create_all(connection)
connection.execute(
tbl.insert(),
- [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}],
+ [
+ {"enum_col": ["foo"], "pyenum_col": [MyEnum.a, MyEnum.b]},
+ {"enum_col": ["foo", "bar"], "pyenum_col": [MyEnum.b]},
+ ],
)
return tbl, MyEnum
@@ -2423,6 +2426,26 @@ class ArrayEnum(fixtures.TestBase):
)
@_enum_combinations
+ @testing.combinations("all", "any", argnames="fn")
+ def test_any_all_roundtrip(
+ self, array_of_enum_fixture, connection, array_cls, enum_cls, fn
+ ):
+ """test #6515"""
+
+ tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+ if fn == "all":
+ expr = tbl.c.pyenum_col.all(MyEnum.b)
+ result = [([MyEnum.b],)]
+ elif fn == "any":
+ expr = tbl.c.pyenum_col.any(MyEnum.b)
+ result = [([MyEnum.a, MyEnum.b],), ([MyEnum.b],)]
+ else:
+ assert False
+ sel = select(tbl.c.pyenum_col).where(expr).order_by(tbl.c.id)
+ eq_(connection.execute(sel).fetchall(), result)
+
+ @_enum_combinations
def test_array_of_enums_roundtrip(
self, array_of_enum_fixture, connection, array_cls, enum_cls
):
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index 88d1ea053..77ca95de7 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -3795,8 +3795,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
self.assert_compile(
t.c.arrval.any(5, operator.gt),
- ":param_1 > ANY (tab1.arrval)",
- checkparams={"param_1": 5},
+ ":arrval_1 > ANY (tab1.arrval)",
+ checkparams={"arrval_1": 5},
)
def test_any_array_comparator_negate_accessor(self, t_fixture):
@@ -3804,8 +3804,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
self.assert_compile(
~t.c.arrval.any(5, operator.gt),
- "NOT (:param_1 > ANY (tab1.arrval))",
- checkparams={"param_1": 5},
+ "NOT (:arrval_1 > ANY (tab1.arrval))",
+ checkparams={"arrval_1": 5},
)
def test_all_array_comparator_accessor(self, t_fixture):
@@ -3813,8 +3813,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
self.assert_compile(
t.c.arrval.all(5, operator.gt),
- ":param_1 > ALL (tab1.arrval)",
- checkparams={"param_1": 5},
+ ":arrval_1 > ALL (tab1.arrval)",
+ checkparams={"arrval_1": 5},
)
def test_all_array_comparator_negate_accessor(self, t_fixture):
@@ -3822,8 +3822,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
self.assert_compile(
~t.c.arrval.all(5, operator.gt),
- "NOT (:param_1 > ALL (tab1.arrval))",
- checkparams={"param_1": 5},
+ "NOT (:arrval_1 > ALL (tab1.arrval))",
+ checkparams={"arrval_1": 5},
)
def test_any_array_expression(self, t_fixture):