diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-22 10:57:00 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-22 23:29:11 -0400 |
| commit | 63191fbef63ebfbf57e7b66bd6529305fc62c605 (patch) | |
| tree | 7572c1bdcecbc4e3640e5860c42a8b031f595bce | |
| parent | fe2045fb1c767436ed1e32359fe005dabead504a (diff) | |
| download | sqlalchemy-63191fbef63ebfbf57e7b66bd6529305fc62c605.tar.gz | |
properly type array element in any() / all()
Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on
PostgreSQL where using the ``.any()`` method to render SQL ANY(), given
members of the Python enumeration as arguments, would produce a type
adaptation failure on all drivers.
Fixes: #6515
Change-Id: Ia1e3b4e10aaf264ed436ce6030d105fc60023433
| -rw-r--r-- | doc/build/changelog/unreleased_14/6515.rst | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 24 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_compiler.py | 32 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 33 | ||||
| -rw-r--r-- | test/sql/test_operators.py | 16 |
5 files changed, 81 insertions, 32 deletions
diff --git a/doc/build/changelog/unreleased_14/6515.rst b/doc/build/changelog/unreleased_14/6515.rst new file mode 100644 index 000000000..0ac5332b5 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6515.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, postgresql + :tickets: 6515 + + Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on + PostgreSQL where using the ``.any()`` method to render SQL ANY(), given + members of the Python enumeration as arguments, would produce a type + adaptation failure on all drivers. diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 64d6ea81b..65b97d565 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2715,9 +2715,11 @@ class ARRAY( __slots__ = () + type: ARRAY + def _setup_getitem(self, index): - arr_type = cast(ARRAY, self.type) + arr_type = self.type return_type: TypeEngine[Any] @@ -2784,10 +2786,18 @@ class ARRAY( elements = util.preloaded.sql_elements operator = operator if operator else operators.eq + arr_type = self.type + # send plain BinaryExpression so that negate remains at None, # leading to NOT expr for negation. return elements.BinaryExpression( - coercions.expect(roles.ExpressionElementRole, other), + coercions.expect( + roles.BinaryElementRole, + element=other, + operator=operator, + expr=self.expr, + bindparam_type=arr_type.item_type, + ), elements.CollectionAggregate._create_any(self.expr), operator, ) @@ -2828,10 +2838,18 @@ class ARRAY( elements = util.preloaded.sql_elements operator = operator if operator else operators.eq + arr_type = self.type + # send plain BinaryExpression so that negate remains at None, # leading to NOT expr for negation. return elements.BinaryExpression( - coercions.expect(roles.ExpressionElementRole, other), + coercions.expect( + roles.BinaryElementRole, + element=other, + operator=operator, + expr=self.expr, + bindparam_type=arr_type.item_type, + ), elements.CollectionAggregate._create_all(self.expr), operator, ) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 2221fd30a..0fe5f7066 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1506,48 +1506,48 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( postgresql.Any(4, c), - "%(param_1)s = ANY (x)", - checkparams={"param_1": 4}, + "%(x_1)s = ANY (x)", + checkparams={"x_1": 4}, ) self.assert_compile( c.any(5), - "%(param_1)s = ANY (x)", - checkparams={"param_1": 5}, + "%(x_1)s = ANY (x)", + checkparams={"x_1": 5}, ) self.assert_compile( ~c.any(5), - "NOT (%(param_1)s = ANY (x))", - checkparams={"param_1": 5}, + "NOT (%(x_1)s = ANY (x))", + checkparams={"x_1": 5}, ) self.assert_compile( c.all(5), - "%(param_1)s = ALL (x)", - checkparams={"param_1": 5}, + "%(x_1)s = ALL (x)", + checkparams={"x_1": 5}, ) self.assert_compile( ~c.all(5), - "NOT (%(param_1)s = ALL (x))", - checkparams={"param_1": 5}, + "NOT (%(x_1)s = ALL (x))", + checkparams={"x_1": 5}, ) self.assert_compile( c.any(5, operator=operators.ne), - "%(param_1)s != ANY (x)", - checkparams={"param_1": 5}, + "%(x_1)s != ANY (x)", + checkparams={"x_1": 5}, ) self.assert_compile( postgresql.All(6, c, operator=operators.gt), - "%(param_1)s > ALL (x)", - checkparams={"param_1": 6}, + "%(x_1)s > ALL (x)", + checkparams={"x_1": 6}, ) self.assert_compile( c.all(7, operator=operators.lt), - "%(param_1)s < ALL (x)", - checkparams={"param_1": 7}, + "%(x_1)s < ALL (x)", + checkparams={"x_1": 7}, ) @testing.combinations( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 8c4bb7fe7..bca952ade 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1271,16 +1271,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select(col.any(7, operator=operators.lt)), - "SELECT %(param_1)s < ANY (x) AS anon_1", - checkparams={"param_1": 7}, + "SELECT %(x_1)s < ANY (x) AS anon_1", + checkparams={"x_1": 7}, ) def test_array_all(self): col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select(col.all(7, operator=operators.lt)), - "SELECT %(param_1)s < ALL (x) AS anon_1", - checkparams={"param_1": 7}, + "SELECT %(x_1)s < ALL (x) AS anon_1", + checkparams={"x_1": 7}, ) def test_array_contains(self): @@ -2404,7 +2404,10 @@ class ArrayEnum(fixtures.TestBase): metadata.create_all(connection) connection.execute( tbl.insert(), - [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}], + [ + {"enum_col": ["foo"], "pyenum_col": [MyEnum.a, MyEnum.b]}, + {"enum_col": ["foo", "bar"], "pyenum_col": [MyEnum.b]}, + ], ) return tbl, MyEnum @@ -2423,6 +2426,26 @@ class ArrayEnum(fixtures.TestBase): ) @_enum_combinations + @testing.combinations("all", "any", argnames="fn") + def test_any_all_roundtrip( + self, array_of_enum_fixture, connection, array_cls, enum_cls, fn + ): + """test #6515""" + + tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls) + + if fn == "all": + expr = tbl.c.pyenum_col.all(MyEnum.b) + result = [([MyEnum.b],)] + elif fn == "any": + expr = tbl.c.pyenum_col.any(MyEnum.b) + result = [([MyEnum.a, MyEnum.b],), ([MyEnum.b],)] + else: + assert False + sel = select(tbl.c.pyenum_col).where(expr).order_by(tbl.c.id) + eq_(connection.execute(sel).fetchall(), result) + + @_enum_combinations def test_array_of_enums_roundtrip( self, array_of_enum_fixture, connection, array_cls, enum_cls ): diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 88d1ea053..77ca95de7 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -3795,8 +3795,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.any(5, operator.gt), - ":param_1 > ANY (tab1.arrval)", - checkparams={"param_1": 5}, + ":arrval_1 > ANY (tab1.arrval)", + checkparams={"arrval_1": 5}, ) def test_any_array_comparator_negate_accessor(self, t_fixture): @@ -3804,8 +3804,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( ~t.c.arrval.any(5, operator.gt), - "NOT (:param_1 > ANY (tab1.arrval))", - checkparams={"param_1": 5}, + "NOT (:arrval_1 > ANY (tab1.arrval))", + checkparams={"arrval_1": 5}, ) def test_all_array_comparator_accessor(self, t_fixture): @@ -3813,8 +3813,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.all(5, operator.gt), - ":param_1 > ALL (tab1.arrval)", - checkparams={"param_1": 5}, + ":arrval_1 > ALL (tab1.arrval)", + checkparams={"arrval_1": 5}, ) def test_all_array_comparator_negate_accessor(self, t_fixture): @@ -3822,8 +3822,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( ~t.c.arrval.all(5, operator.gt), - "NOT (:param_1 > ALL (tab1.arrval))", - checkparams={"param_1": 5}, + "NOT (:arrval_1 > ALL (tab1.arrval))", + checkparams={"arrval_1": 5}, ) def test_any_array_expression(self, t_fixture): |
