summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/changelog_12.rst7
-rw-r--r--lib/sqlalchemy/dialects/postgresql/array.py17
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py19
-rw-r--r--test/dialect/postgresql/test_types.py231
4 files changed, 160 insertions, 114 deletions
diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst
index 815f587f8..7c0421019 100644
--- a/doc/build/changelog/changelog_12.rst
+++ b/doc/build/changelog/changelog_12.rst
@@ -13,6 +13,13 @@
.. changelog::
:version: 1.2.0b1
+ .. change:: 3964
+ :tags: bug, postgresql
+ :tickets: 3964
+
+ Fixed bug where the base :class:`.sqltypes.ARRAY` datatype would not
+ invoke the bind/result processors of :class:`.postgresql.ARRAY`.
+
.. change:: 3963
:tags: bug, orm
:tickets: 3963
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 98cab9562..009c83c0d 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -5,7 +5,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from .base import ischema_names
+from .base import ischema_names, colspecs
from ...sql import expression, operators
from ...sql.base import SchemaEventTarget
from ... import types as sqltypes
@@ -114,7 +114,7 @@ CONTAINED_BY = operators.custom_op("<@", precedence=5)
OVERLAP = operators.custom_op("&&", precedence=5)
-class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
+class ARRAY(sqltypes.ARRAY):
"""PostgreSQL ARRAY type.
@@ -248,18 +248,6 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
def compare_values(self, x, y):
return x == y
- def _set_parent(self, column):
- """Support SchemaEventTarget"""
-
- if isinstance(self.item_type, SchemaEventTarget):
- self.item_type._set_parent(column)
-
- def _set_parent_with_dispatch(self, parent):
- """Support SchemaEventTarget"""
-
- if isinstance(self.item_type, SchemaEventTarget):
- self.item_type._set_parent_with_dispatch(parent)
-
def _proc_array(self, arr, itemproc, dim, collection):
if dim is None:
arr = list(arr)
@@ -311,4 +299,5 @@ class ARRAY(SchemaEventTarget, sqltypes.ARRAY):
tuple if self.as_tuple else list)
return process
+colspecs[sqltypes.ARRAY] = ARRAY
ischema_names['_array'] = ARRAY
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 8a114ece6..b8117e3ca 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2061,7 +2061,7 @@ class JSON(Indexable, TypeEngine):
return process
-class ARRAY(Indexable, Concatenable, TypeEngine):
+class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
"""Represent a SQL Array type.
.. note:: This type serves as the basis for all ARRAY operations.
@@ -2199,6 +2199,11 @@ class ARRAY(Indexable, Concatenable, TypeEngine):
return operators.getitem, index, return_type
+ def contains(self, *arg, **kw):
+ raise NotImplementedError(
+ "ARRAY.contains() not implemented for the base "
+ "ARRAY type; please use the dialect-specific ARRAY type")
+
@util.dependencies("sqlalchemy.sql.elements")
def any(self, elements, other, operator=None):
"""Return ``other operator ANY (array)`` clause.
@@ -2325,6 +2330,18 @@ class ARRAY(Indexable, Concatenable, TypeEngine):
def compare_values(self, x, y):
return x == y
+ def _set_parent(self, column):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent(column)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent_with_dispatch(parent)
+
class REAL(Float):
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 807eeb60c..d2e19a04a 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -4,6 +4,7 @@ from sqlalchemy.testing.assertions import eq_, assert_raises, \
AssertsCompiledSQL, ComparesTables
from sqlalchemy.testing import engines, fixtures
from sqlalchemy import testing
+from sqlalchemy.sql import sqltypes
import datetime
from sqlalchemy import Table, MetaData, Column, Integer, Enum, Float, select, \
func, DateTime, Numeric, exc, String, cast, REAL, TypeDecorator, Unicode, \
@@ -85,7 +86,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
@testing.fails_on('postgresql+zxjdbc',
'zxjdbc has no support for PG arrays')
@testing.provide_metadata
- def test_arrays(self):
+ def test_arrays_pg(self):
metadata = self.metadata
t1 = Table('t', metadata,
Column('x', postgresql.ARRAY(Float)),
@@ -101,6 +102,25 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
([5], [5], [6], [decimal.Decimal("6.4")])
)
+ @testing.fails_on('postgresql+zxjdbc',
+ 'zxjdbc has no support for PG arrays')
+ @testing.provide_metadata
+ def test_arrays_base(self):
+ metadata = self.metadata
+ t1 = Table('t', metadata,
+ Column('x', sqltypes.ARRAY(Float)),
+ Column('y', sqltypes.ARRAY(REAL)),
+ Column('z', sqltypes.ARRAY(postgresql.DOUBLE_PRECISION)),
+ Column('q', sqltypes.ARRAY(Numeric))
+ )
+ metadata.create_all()
+ t1.insert().execute(x=[5], y=[5], z=[6], q=[decimal.Decimal("6.4")])
+ row = t1.select().execute().first()
+ eq_(
+ row,
+ ([5], [5], [6], [decimal.Decimal("6.4")])
+ )
+
class EnumTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@@ -987,17 +1007,19 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
is_(expr.type.item_type.__class__, Integer)
-class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
+class ArrayRoundTripTest(object):
__only_on__ = 'postgresql'
__backend__ = True
__unsupported_on__ = 'postgresql+pg8000', 'postgresql+zxjdbc'
+ ARRAY = postgresql.ARRAY
+
@classmethod
def define_tables(cls, metadata):
class ProcValue(TypeDecorator):
- impl = postgresql.ARRAY(Integer, dimensions=2)
+ impl = cls.ARRAY(Integer, dimensions=2)
def process_bind_param(self, value, dialect):
if value is None:
@@ -1017,15 +1039,15 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
Table('arrtable', metadata,
Column('id', Integer, primary_key=True),
- Column('intarr', postgresql.ARRAY(Integer)),
- Column('strarr', postgresql.ARRAY(Unicode())),
+ Column('intarr', cls.ARRAY(Integer)),
+ Column('strarr', cls.ARRAY(Unicode())),
Column('dimarr', ProcValue)
)
Table('dim_arrtable', metadata,
Column('id', Integer, primary_key=True),
- Column('intarr', postgresql.ARRAY(Integer, dimensions=1)),
- Column('strarr', postgresql.ARRAY(Unicode(), dimensions=1)),
+ Column('intarr', cls.ARRAY(Integer, dimensions=1)),
+ Column('strarr', cls.ARRAY(Unicode(), dimensions=1)),
Column('dimarr', ProcValue)
)
@@ -1038,8 +1060,8 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
def test_reflect_array_column(self):
metadata2 = MetaData(testing.db)
tbl = Table('arrtable', metadata2, autoload=True)
- assert isinstance(tbl.c.intarr.type, postgresql.ARRAY)
- assert isinstance(tbl.c.strarr.type, postgresql.ARRAY)
+ assert isinstance(tbl.c.intarr.type, self.ARRAY)
+ assert isinstance(tbl.c.strarr.type, self.ARRAY)
assert isinstance(tbl.c.intarr.type.item_type, Integer)
assert isinstance(tbl.c.strarr.type.item_type, String)
@@ -1107,19 +1129,19 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
func.array_cat(
array([1, 2, 3]),
array([4, 5, 6]),
- type_=postgresql.ARRAY(Integer)
+ type_=self.ARRAY(Integer)
)[2:5]
])
eq_(
testing.db.execute(stmt).scalar(), [2, 3, 4, 5]
)
- def test_any_all_exprs(self):
+ def test_any_all_exprs_array(self):
stmt = select([
3 == any_(func.array_cat(
array([1, 2, 3]),
array([4, 5, 6]),
- type_=postgresql.ARRAY(Integer)
+ type_=self.ARRAY(Integer)
))
])
eq_(
@@ -1225,17 +1247,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
7
)
- def test_undim_array_empty(self):
- arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- eq_(
- testing.db.scalar(
- select([arrtable.c.intarr]).
- where(arrtable.c.intarr.contains([]))
- ),
- [4, 5, 6]
- )
-
def test_array_getitem_slice_exec(self):
arrtable = self.tables.arrtable
testing.db.execute(
@@ -1255,49 +1266,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
[7, 8]
)
- def _test_undim_array_contains_typed_exec(self, struct):
- arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- eq_(
- testing.db.scalar(
- select([arrtable.c.intarr]).
- where(arrtable.c.intarr.contains(struct([4, 5])))
- ),
- [4, 5, 6]
- )
-
- def test_undim_array_contains_set_exec(self):
- self._test_undim_array_contains_typed_exec(set)
-
- def test_undim_array_contains_list_exec(self):
- self._test_undim_array_contains_typed_exec(list)
-
- def test_undim_array_contains_generator_exec(self):
- self._test_undim_array_contains_typed_exec(
- lambda elem: (x for x in elem))
-
- def _test_dim_array_contains_typed_exec(self, struct):
- dim_arrtable = self.tables.dim_arrtable
- self._fixture_456(dim_arrtable)
- eq_(
- testing.db.scalar(
- select([dim_arrtable.c.intarr]).
- where(dim_arrtable.c.intarr.contains(struct([4, 5])))
- ),
- [4, 5, 6]
- )
-
- def test_dim_array_contains_set_exec(self):
- self._test_dim_array_contains_typed_exec(set)
-
- def test_dim_array_contains_list_exec(self):
- self._test_dim_array_contains_typed_exec(list)
-
- def test_dim_array_contains_generator_exec(self):
- self._test_dim_array_contains_typed_exec(
- lambda elem: (
- x for x in elem))
-
def test_multi_dim_roundtrip(self):
arrtable = self.tables.arrtable
testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4, 5, 6]])
@@ -1306,35 +1274,6 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
[[-1, 0, 1], [2, 3, 4]]
)
- def test_array_contained_by_exec(self):
- arrtable = self.tables.arrtable
- with testing.db.connect() as conn:
- conn.execute(
- arrtable.insert(),
- intarr=[6, 5, 4]
- )
- eq_(
- conn.scalar(
- select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
- ),
- True
- )
-
- def test_array_overlap_exec(self):
- arrtable = self.tables.arrtable
- with testing.db.connect() as conn:
- conn.execute(
- arrtable.insert(),
- intarr=[4, 5, 6]
- )
- eq_(
- conn.scalar(
- select([arrtable.c.intarr]).
- where(arrtable.c.intarr.overlap([7, 6]))
- ),
- [4, 5, 6]
- )
-
def test_array_any_exec(self):
arrtable = self.tables.arrtable
with testing.db.connect() as conn:
@@ -1372,10 +1311,10 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
t1 = Table(
't1', metadata,
Column('id', Integer, primary_key=True),
- Column('data', postgresql.ARRAY(String(5), as_tuple=True)),
+ Column('data', self.ARRAY(String(5), as_tuple=True)),
Column(
'data2',
- postgresql.ARRAY(
+ self.ARRAY(
Numeric(asdecimal=False), as_tuple=True)
)
)
@@ -1416,13 +1355,13 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
't', m,
Column(
'data_1',
- postgresql.ARRAY(
+ self.ARRAY(
postgresql.ENUM('a', 'b', 'c', name='my_enum_1')
)
),
Column(
'data_2',
- postgresql.ARRAY(
+ self.ARRAY(
types.Enum('a', 'b', 'c', name='my_enum_2')
)
)
@@ -1437,6 +1376,100 @@ class ArrayRoundTripTest(fixtures.TablesTest, AssertsExecutionResults):
eq_(inspect(testing.db).get_enums(), [])
+class CoreArrayRoundTripTest(ArrayRoundTripTest,
+ fixtures.TablesTest, AssertsExecutionResults):
+
+ ARRAY = sqltypes.ARRAY
+
+
+class PGArrayRoundTripTest(ArrayRoundTripTest,
+ fixtures.TablesTest, AssertsExecutionResults):
+ ARRAY = postgresql.ARRAY
+
+ def _test_undim_array_contains_typed_exec(self, struct):
+ arrtable = self.tables.arrtable
+ self._fixture_456(arrtable)
+ eq_(
+ testing.db.scalar(
+ select([arrtable.c.intarr]).
+ where(arrtable.c.intarr.contains(struct([4, 5])))
+ ),
+ [4, 5, 6]
+ )
+
+ def test_undim_array_contains_set_exec(self):
+ self._test_undim_array_contains_typed_exec(set)
+
+ def test_undim_array_contains_list_exec(self):
+ self._test_undim_array_contains_typed_exec(list)
+
+ def test_undim_array_contains_generator_exec(self):
+ self._test_undim_array_contains_typed_exec(
+ lambda elem: (x for x in elem))
+
+ def _test_dim_array_contains_typed_exec(self, struct):
+ dim_arrtable = self.tables.dim_arrtable
+ self._fixture_456(dim_arrtable)
+ eq_(
+ testing.db.scalar(
+ select([dim_arrtable.c.intarr]).
+ where(dim_arrtable.c.intarr.contains(struct([4, 5])))
+ ),
+ [4, 5, 6]
+ )
+
+ def test_dim_array_contains_set_exec(self):
+ self._test_dim_array_contains_typed_exec(set)
+
+ def test_dim_array_contains_list_exec(self):
+ self._test_dim_array_contains_typed_exec(list)
+
+ def test_dim_array_contains_generator_exec(self):
+ self._test_dim_array_contains_typed_exec(
+ lambda elem: (
+ x for x in elem))
+
+ def test_array_contained_by_exec(self):
+ arrtable = self.tables.arrtable
+ with testing.db.connect() as conn:
+ conn.execute(
+ arrtable.insert(),
+ intarr=[6, 5, 4]
+ )
+ eq_(
+ conn.scalar(
+ select([arrtable.c.intarr.contained_by([4, 5, 6, 7])])
+ ),
+ True
+ )
+
+ def test_undim_array_empty(self):
+ arrtable = self.tables.arrtable
+ self._fixture_456(arrtable)
+ eq_(
+ testing.db.scalar(
+ select([arrtable.c.intarr]).
+ where(arrtable.c.intarr.contains([]))
+ ),
+ [4, 5, 6]
+ )
+
+ def test_array_overlap_exec(self):
+ arrtable = self.tables.arrtable
+ with testing.db.connect() as conn:
+ conn.execute(
+ arrtable.insert(),
+ intarr=[4, 5, 6]
+ )
+ eq_(
+ conn.scalar(
+ select([arrtable.c.intarr]).
+ where(arrtable.c.intarr.overlap([7, 6]))
+ ),
+ [4, 5, 6]
+ )
+
+
class HashableFlagORMTest(fixtures.TestBase):
"""test the various 'collection' types that they flip the 'hashable' flag
appropriately. [ticket:3499]"""