diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2021-03-16 22:14:38 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2021-03-16 22:14:38 +0000 |
| commit | 520c5cd0f4ed02b7595d9d83ed573688223fffc2 (patch) | |
| tree | 055cea0ff86f7b1c7b6f6883feae6bf245e9c867 /lib/sqlalchemy | |
| parent | 4a827330616a90b1fa0a10f86d8e7cb6e92047ba (diff) | |
| parent | dfa1d3b28f1a0abf1e11c76a94f7a65bf98d29af (diff) | |
| download | sqlalchemy-520c5cd0f4ed02b7595d9d83ed573688223fffc2.tar.gz | |
Merge "CAST the elements in ARRAYs when using psycopg2"
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/array.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/config.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/schema.py | 41 |
9 files changed, 78 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 91bb89ea9..c2d99845f 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -331,12 +331,6 @@ class ARRAY(sqltypes.ARRAY): ) @util.memoized_property - def _require_cast(self): - return self._against_native_enum or isinstance( - self.item_type, sqltypes.JSON - ) - - @util.memoized_property def _against_native_enum(self): return ( isinstance(self.item_type, sqltypes.Enum) @@ -344,10 +338,7 @@ class ARRAY(sqltypes.ARRAY): ) def bind_expression(self, bindvalue): - if self._require_cast: - return expression.cast(bindvalue, self) - else: - return bindvalue + return bindvalue def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor( diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index a52eacd8b..1969eb844 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -450,6 +450,7 @@ from ... import processors from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor +from ...sql import elements from ...util import collections_abc @@ -597,7 +598,20 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): - pass + def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw): + + text = super(PGCompiler_psycopg2, self).visit_bindparam( + bindparam, skip_bind_expression=skip_bind_expression, **kw + ) + # note that if the type has a bind_expression(), we will get a + # double compile here + if not skip_bind_expression and bindparam.type._is_array: + text += "::%s" % ( + elements.TypeClause(bindparam.type)._compiler_dispatch( + self, skip_bind_expression=skip_bind_expression, **kw + ), + ) + return text class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 8b4950aa3..174a1c131 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -943,6 +943,9 @@ def _get_stmt_parameter_tuples_params( # add it to values() in an "as-is" state, # coercing right side to bound param + # note one of the main use cases for this is array slice + # updates on PostgreSQL, as the left side is also an expression. + col_expr = compiler.process( k, include_table=compile_state.include_table_with_column_exprs ) @@ -952,6 +955,12 @@ def _get_stmt_parameter_tuples_params( elements.BindParameter(None, v, type_=k.type), **kw ) else: + if v._is_bind_parameter and v.type._isnull: + # either unique parameter, or other bound parameters that + # were passed in directly + # set type to that of the column unconditionally + v = v._with_binary_element_type(k.type) + v = compiler.process(v.self_group(), **kw) values.append((k, col_expr, v)) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index d075ef77d..816423d1b 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2675,6 +2675,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): __visit_name__ = "ARRAY" + _is_array = True + zero_indexes = False """If True, Python zero-based indexes should be interpreted as one-based on the SQL expression side.""" diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 46751cb22..9752750c5 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -47,6 +47,7 @@ class TypeEngine(Traversible): _isnull = False _is_tuple_type = False _is_table_value = False + _is_array = False class Comparator(operators.ColumnOperators): """Base class for custom comparison operations defined at the diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index adbb8f643..a3ce24226 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -42,6 +42,7 @@ from .assertions import startswith_ from .assertions import uses_deprecated from .config import async_test from .config import combinations +from .config import combinations_list from .config import db from .config import fixture from .config import requirements as requires diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 750671f9f..6589e5097 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -89,6 +89,11 @@ def combinations(*comb, **kw): return _fixture_functions.combinations(*comb, **kw) +def combinations_list(arg_iterable, **kw): + "As combination, but takes a single iterable" + return combinations(*arg_iterable, **kw) + + def fixture(*arg, **kw): return _fixture_functions.fixture(*arg, **kw) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 4eaaecebb..388d71c73 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -578,7 +578,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): "i": lambda obj: obj, "r": repr, "s": str, - "n": operator.attrgetter("__name__"), + "n": lambda obj: obj.__name__ + if hasattr(obj, "__name__") + else type(obj).__name__, } def combinations(self, *arg_sets, **kw): diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 22b1f7b77..fee021cff 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -5,11 +5,14 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import sys + from . import config from . import exclusions from .. import event from .. import schema from .. import types as sqltypes +from ..util import OrderedDict __all__ = ["Table", "Column"] @@ -162,3 +165,41 @@ def _truncate_name(dialect, name): ) else: return name + + +def pep435_enum(name): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value, alias=None): + self.name = name + self.value = value + self.__members__[name] = self + value_to_member[value] = self + setattr(self.__class__, name, self) + if alias: + self.__members__[alias] = self + setattr(self.__class__, alias, self) + + value_to_member = {} + + @classmethod + def get(cls, value): + return value_to_member[value] + + someenum = type( + name, + (object,), + {"__members__": __members__, "__init__": __init__, "get": get}, + ) + + # getframe() trick for pickling I don't understand courtesy + # Python namedtuple() + try: + module = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + if module is not None: + someenum.__module__ = module + + return someenum |
