diff options
| -rw-r--r-- | doc/build/changelog/changelog_12.rst | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/array.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 19 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 231 |
4 files changed, 160 insertions, 114 deletions
diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst index 815f587f8..7c0421019 100644 --- a/doc/build/changelog/changelog_12.rst +++ b/doc/build/changelog/changelog_12.rst @@ -13,6 +13,13 @@ .. changelog:: :version: 1.2.0b1 + .. change:: 3964 + :tags: bug, postgresql + :tickets: 3964 + + Fixed bug where the base :class:`.sqltypes.ARRAY` datatype would not + invoke the bind/result processors of :class:`.postgresql.ARRAY`. + .. change:: 3963 :tags: bug, orm :tickets: 3963 diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 98cab9562..009c83c0d 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -5,7 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from .base import ischema_names +from .base import ischema_names, colspecs from ...sql import expression, operators from ...sql.base import SchemaEventTarget from ... import types as sqltypes @@ -114,7 +114,7 @@ CONTAINED_BY = operators.custom_op("<@", precedence=5) OVERLAP = operators.custom_op("&&", precedence=5) -class ARRAY(SchemaEventTarget, sqltypes.ARRAY): +class ARRAY(sqltypes.ARRAY): """PostgreSQL ARRAY type. @@ -248,18 +248,6 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY): def compare_values(self, x, y): return x == y - def _set_parent(self, column): - """Support SchemaEventTarget""" - - if isinstance(self.item_type, SchemaEventTarget): - self.item_type._set_parent(column) - - def _set_parent_with_dispatch(self, parent): - """Support SchemaEventTarget""" - - if isinstance(self.item_type, SchemaEventTarget): - self.item_type._set_parent_with_dispatch(parent) - def _proc_array(self, arr, itemproc, dim, collection): if dim is None: arr = list(arr) @@ -311,4 +299,5 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY): tuple if self.as_tuple else list) return process +colspecs[sqltypes.ARRAY] = ARRAY ischema_names['_array'] = ARRAY diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 8a114ece6..b8117e3ca 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2061,7 +2061,7 @@ class JSON(Indexable, TypeEngine): return process -class ARRAY(Indexable, Concatenable, TypeEngine): +class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """Represent a SQL Array type. .. note:: This type serves as the basis for all ARRAY operations. @@ -2199,6 +2199,11 @@ class ARRAY(Indexable, Concatenable, TypeEngine): return operators.getitem, index, return_type + def contains(self, *arg, **kw): + raise NotImplementedError( + "ARRAY.contains() not implemented for the base " + "ARRAY type; please use the dialect-specific ARRAY type") + @util.dependencies("sqlalchemy.sql.elements") def any(self, elements, other, operator=None): """Return ``other operator ANY (array)`` clause. @@ -2325,6 +2330,18 @@ class ARRAY(Indexable, Concatenable, TypeEngine): def compare_values(self, x, y): return x == y + def _set_parent(self, column): + """Support SchemaEventTarget""" + + if isinstance(self.item_type, SchemaEventTarget): + self.item_type._set_parent(column) + + def _set_parent_with_dispatch(self, parent): + """Support SchemaEventTarget""" + + if isinstance(self.item_type, SchemaEventTarget): + self.item_type._set_parent_with_dispatch(parent) + class REAL(Float): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 807eeb60c..d2e19a04a 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4,6 +4,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises, \ AssertsCompiledSQL, ComparesTables from sqlalchemy.testing import engines, fixtures from sqlalchemy import testing +from sqlalchemy.sql import sqltypes import datetime from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \ func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \ @@ -85,7 +86,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') @testing.provide_metadata - def test_arrays(self): + def test_arrays_pg(self): metadata = self.metadata t1 = Table('t', metadata, Column('x', postgresql.ARRAY(Float)), @@ -101,6 +102,25 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): ([5], [5], [6], [decimal.Decimal("6.4")]) ) + @testing.fails_on('postgresql+zxjdbc', + 'zxjdbc has no support for PG arrays') + @testing.provide_metadata + def test_arrays_base(self): + metadata = self.metadata + t1 = Table('t', metadata, + Column('x', sqltypes.ARRAY(Float)), + Column('y', sqltypes.ARRAY(REAL)), + Column('z', sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)), + Column('q', sqltypes.ARRAY(Numeric)) + ) + metadata.create_all() + t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")]) + row = t1.select().execute().first() + eq_( + row, + ([5], [5], [6], [decimal.Decimal("6.4")]) + ) + class EnumTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @@ -987,17 +1007,19 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): is_(expr.type.item_type.__class__, Integer) -class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): +class ArrayRoundTripTest(object): __only_on__ = 'postgresql' __backend__ = True __unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc' + ARRAY = postgresql.ARRAY + @classmethod def define_tables(cls, metadata): class ProcValue(TypeDecorator): - impl = postgresql.ARRAY(Integer, dimensions=2) + impl = cls.ARRAY(Integer, dimensions=2) def process_bind_param(self, value, dialect): if value is None: @@ -1017,15 +1039,15 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): Table('arrtable', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgresql.ARRAY(Integer)), - Column('strarr', postgresql.ARRAY(Unicode())), + Column('intarr', cls.ARRAY(Integer)), + Column('strarr', cls.ARRAY(Unicode())), Column('dimarr', ProcValue) ) Table('dim_arrtable', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgresql.ARRAY(Integer, dimensions=1)), - Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)), + Column('intarr', cls.ARRAY(Integer, dimensions=1)), + Column('strarr', cls.ARRAY(Unicode(), dimensions=1)), Column('dimarr', ProcValue) ) @@ -1038,8 +1060,8 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): def test_reflect_array_column(self): metadata2 = MetaData(testing.db) tbl = Table('arrtable', metadata2, autoload=True) - assert isinstance(tbl.c.intarr.type, postgresql.ARRAY) - assert isinstance(tbl.c.strarr.type, postgresql.ARRAY) + assert isinstance(tbl.c.intarr.type, self.ARRAY) + assert isinstance(tbl.c.strarr.type, self.ARRAY) assert isinstance(tbl.c.intarr.type.item_type, Integer) assert isinstance(tbl.c.strarr.type.item_type, String) @@ -1107,19 +1129,19 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): func.array_cat( array([1, 2, 3]), array([4, 5, 6]), - type_=postgresql.ARRAY(Integer) + type_=self.ARRAY(Integer) )[2:5] ]) eq_( testing.db.execute(stmt).scalar(), [2, 3, 4, 5] ) - def test_any_all_exprs(self): + def test_any_all_exprs_array(self): stmt = select([ 3 == any_(func.array_cat( array([1, 2, 3]), array([4, 5, 6]), - type_=postgresql.ARRAY(Integer) + type_=self.ARRAY(Integer) )) ]) eq_( @@ -1225,17 +1247,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): 7 ) - def test_undim_array_empty(self): - arrtable = self.tables.arrtable - self._fixture_456(arrtable) - eq_( - testing.db.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.contains([])) - ), - [4, 5, 6] - ) - def test_array_getitem_slice_exec(self): arrtable = self.tables.arrtable testing.db.execute( @@ -1255,49 +1266,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): [7, 8] ) - def _test_undim_array_contains_typed_exec(self, struct): - arrtable = self.tables.arrtable - self._fixture_456(arrtable) - eq_( - testing.db.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.contains(struct([4, 5]))) - ), - [4, 5, 6] - ) - - def test_undim_array_contains_set_exec(self): - self._test_undim_array_contains_typed_exec(set) - - def test_undim_array_contains_list_exec(self): - self._test_undim_array_contains_typed_exec(list) - - def test_undim_array_contains_generator_exec(self): - self._test_undim_array_contains_typed_exec( - lambda elem: (x for x in elem)) - - def _test_dim_array_contains_typed_exec(self, struct): - dim_arrtable = self.tables.dim_arrtable - self._fixture_456(dim_arrtable) - eq_( - testing.db.scalar( - select([dim_arrtable.c.intarr]). - where(dim_arrtable.c.intarr.contains(struct([4, 5]))) - ), - [4, 5, 6] - ) - - def test_dim_array_contains_set_exec(self): - self._test_dim_array_contains_typed_exec(set) - - def test_dim_array_contains_list_exec(self): - self._test_dim_array_contains_typed_exec(list) - - def test_dim_array_contains_generator_exec(self): - self._test_dim_array_contains_typed_exec( - lambda elem: ( - x for x in elem)) - def test_multi_dim_roundtrip(self): arrtable = self.tables.arrtable testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]]) @@ -1306,35 +1274,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): [[-1, 0, 1], [2, 3, 4]] ) - def test_array_contained_by_exec(self): - arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[6, 5, 4] - ) - eq_( - conn.scalar( - select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) - ), - True - ) - - def test_array_overlap_exec(self): - arrtable = self.tables.arrtable - with testing.db.connect() as conn: - conn.execute( - arrtable.insert(), - intarr=[4, 5, 6] - ) - eq_( - conn.scalar( - select([arrtable.c.intarr]). - where(arrtable.c.intarr.overlap([7, 6])) - ), - [4, 5, 6] - ) - def test_array_any_exec(self): arrtable = self.tables.arrtable with testing.db.connect() as conn: @@ -1372,10 +1311,10 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): t1 = Table( 't1', metadata, Column('id', Integer, primary_key=True), - Column('data', postgresql.ARRAY(String(5), as_tuple=True)), + Column('data', self.ARRAY(String(5), as_tuple=True)), Column( 'data2', - postgresql.ARRAY( + self.ARRAY( Numeric(asdecimal=False), as_tuple=True) ) ) @@ -1416,13 +1355,13 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): 't', m, Column( 'data_1', - postgresql.ARRAY( + self.ARRAY( postgresql.ENUM('a', 'b', 'c', name='my_enum_1') ) ), Column( 'data_2', - postgresql.ARRAY( + self.ARRAY( types.Enum('a', 'b', 'c', name='my_enum_2') ) ) @@ -1437,6 +1376,100 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults): eq_(inspect(testing.db).get_enums(), []) +class CoreArrayRoundTripTest(ArrayRoundTripTest, + fixtures.TablesTest, AssertsExecutionResults): + + ARRAY = sqltypes.ARRAY + + +class PGArrayRoundTripTest(ArrayRoundTripTest, + fixtures.TablesTest, AssertsExecutionResults): + ARRAY = postgresql.ARRAY + + def _test_undim_array_contains_typed_exec(self, struct): + arrtable = self.tables.arrtable + self._fixture_456(arrtable) + eq_( + testing.db.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.contains(struct([4, 5]))) + ), + [4, 5, 6] + ) + + def test_undim_array_contains_set_exec(self): + self._test_undim_array_contains_typed_exec(set) + + def test_undim_array_contains_list_exec(self): + self._test_undim_array_contains_typed_exec(list) + + def test_undim_array_contains_generator_exec(self): + self._test_undim_array_contains_typed_exec( + lambda elem: (x for x in elem)) + + def _test_dim_array_contains_typed_exec(self, struct): + dim_arrtable = self.tables.dim_arrtable + self._fixture_456(dim_arrtable) + eq_( + testing.db.scalar( + select([dim_arrtable.c.intarr]). + where(dim_arrtable.c.intarr.contains(struct([4, 5]))) + ), + [4, 5, 6] + ) + + def test_dim_array_contains_set_exec(self): + self._test_dim_array_contains_typed_exec(set) + + def test_dim_array_contains_list_exec(self): + self._test_dim_array_contains_typed_exec(list) + + def test_dim_array_contains_generator_exec(self): + self._test_dim_array_contains_typed_exec( + lambda elem: ( + x for x in elem)) + + def test_array_contained_by_exec(self): + arrtable = self.tables.arrtable + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[6, 5, 4] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) + ), + True + ) + + def test_undim_array_empty(self): + arrtable = self.tables.arrtable + self._fixture_456(arrtable) + eq_( + testing.db.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.contains([])) + ), + [4, 5, 6] + ) + + def test_array_overlap_exec(self): + arrtable = self.tables.arrtable + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[4, 5, 6] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.overlap([7, 6])) + ), + [4, 5, 6] + ) + + class HashableFlagORMTest(fixtures.TestBase): """test the various 'collection' types that they flip the 'hashable' flag appropriately. [ticket:3499]""" |
