diff options
| author | Federico Caselli <cfederico87@gmail.com> | 2021-03-10 23:54:52 +0100 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-03-15 20:11:20 -0400 |
| commit | dfa1d3b28f1a0abf1e11c76a94f7a65bf98d29af (patch) | |
| tree | 975a06018edcc9a9fa75b709f40698842a82e494 /lib/sqlalchemy | |
| parent | 28b0b6515af26ee3ba09600a8212849b2dae0699 (diff) | |
| download | sqlalchemy-dfa1d3b28f1a0abf1e11c76a94f7a65bf98d29af.tar.gz | |
CAST the elements in ARRAYs when using psycopg2
Adjusted the psycopg2 dialect to emit an explicit PostgreSQL-style cast for
bound parameters that contain ARRAY elements. This allows the full range of
datatypes to function correctly within arrays. The asyncpg dialect already
generated these internal casts in the final statement. This also includes
support for array slice updates as well as the PostgreSQL-specific
:meth:`_postgresql.ARRAY.contains` method.
Fixes: #6023
Change-Id: Ia7519ac4371a635f05ac69a3a4d0f4e6d2f04cad
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/array.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/config.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/schema.py | 41 |
9 files changed, 78 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 91bb89ea9..c2d99845f 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -331,12 +331,6 @@ class ARRAY(sqltypes.ARRAY): ) @util.memoized_property - def _require_cast(self): - return self._against_native_enum or isinstance( - self.item_type, sqltypes.JSON - ) - - @util.memoized_property def _against_native_enum(self): return ( isinstance(self.item_type, sqltypes.Enum) @@ -344,10 +338,7 @@ class ARRAY(sqltypes.ARRAY): ) def bind_expression(self, bindvalue): - if self._require_cast: - return expression.cast(bindvalue, self) - else: - return bindvalue + return bindvalue def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor( diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index a52eacd8b..1969eb844 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -450,6 +450,7 @@ from ... import processors from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor +from ...sql import elements from ...util import collections_abc @@ -597,7 +598,20 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): - pass + def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw): + + text = super(PGCompiler_psycopg2, self).visit_bindparam( + bindparam, skip_bind_expression=skip_bind_expression, **kw + ) + # note that if the type has a bind_expression(), we will get a + # double compile here + if not skip_bind_expression and bindparam.type._is_array: + text += "::%s" % ( + elements.TypeClause(bindparam.type)._compiler_dispatch( + self, skip_bind_expression=skip_bind_expression, **kw + ), + ) + return text class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 8b4950aa3..174a1c131 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -943,6 +943,9 @@ def _get_stmt_parameter_tuples_params( # add it to values() in an "as-is" state, # coercing right side to bound param + # note one of the main use cases for this is array slice + # updates on PostgreSQL, as the left side is also an expression. + col_expr = compiler.process( k, include_table=compile_state.include_table_with_column_exprs ) @@ -952,6 +955,12 @@ def _get_stmt_parameter_tuples_params( elements.BindParameter(None, v, type_=k.type), **kw ) else: + if v._is_bind_parameter and v.type._isnull: + # either unique parameter, or other bound parameters that + # were passed in directly + # set type to that of the column unconditionally + v = v._with_binary_element_type(k.type) + v = compiler.process(v.self_group(), **kw) values.append((k, col_expr, v)) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index d075ef77d..816423d1b 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2675,6 +2675,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): __visit_name__ = "ARRAY" + _is_array = True + zero_indexes = False """If True, Python zero-based indexes should be interpreted as one-based on the SQL expression side.""" diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 46751cb22..9752750c5 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -47,6 +47,7 @@ class TypeEngine(Traversible): _isnull = False _is_tuple_type = False _is_table_value = False + _is_array = False class Comparator(operators.ColumnOperators): """Base class for custom comparison operations defined at the diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index adbb8f643..a3ce24226 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -42,6 +42,7 @@ from .assertions import startswith_ from .assertions import uses_deprecated from .config import async_test from .config import combinations +from .config import combinations_list from .config import db from .config import fixture from .config import requirements as requires diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 750671f9f..6589e5097 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -89,6 +89,11 @@ def combinations(*comb, **kw): return _fixture_functions.combinations(*comb, **kw) +def combinations_list(arg_iterable, **kw): + "As combination, but takes a single iterable" + return combinations(*arg_iterable, **kw) + + def fixture(*arg, **kw): return _fixture_functions.fixture(*arg, **kw) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 4eaaecebb..388d71c73 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -578,7 +578,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): "i": lambda obj: obj, "r": repr, "s": str, - "n": operator.attrgetter("__name__"), + "n": lambda obj: obj.__name__ + if hasattr(obj, "__name__") + else type(obj).__name__, } def combinations(self, *arg_sets, **kw): diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 22b1f7b77..fee021cff 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -5,11 +5,14 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import sys + from . import config from . import exclusions from .. import event from .. import schema from .. import types as sqltypes +from ..util import OrderedDict __all__ = ["Table", "Column"] @@ -162,3 +165,41 @@ def _truncate_name(dialect, name): ) else: return name + + +def pep435_enum(name): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value, alias=None): + self.name = name + self.value = value + self.__members__[name] = self + value_to_member[value] = self + setattr(self.__class__, name, self) + if alias: + self.__members__[alias] = self + setattr(self.__class__, alias, self) + + value_to_member = {} + + @classmethod + def get(cls, value): + return value_to_member[value] + + someenum = type( + name, + (object,), + {"__members__": __members__, "__init__": __init__, "get": get}, + ) + + # getframe() trick for pickling I don't understand courtesy + # Python namedtuple() + try: + module = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + if module is not None: + someenum.__module__ = module + + return someenum |
