summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2021-03-16 22:14:38 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2021-03-16 22:14:38 +0000
commit520c5cd0f4ed02b7595d9d83ed573688223fffc2 (patch)
tree055cea0ff86f7b1c7b6f6883feae6bf245e9c867 /lib/sqlalchemy
parent4a827330616a90b1fa0a10f86d8e7cb6e92047ba (diff)
parentdfa1d3b28f1a0abf1e11c76a94f7a65bf98d29af (diff)
downloadsqlalchemy-520c5cd0f4ed02b7595d9d83ed573688223fffc2.tar.gz
Merge "CAST the elements in ARRAYs when using psycopg2"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/array.py11
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py16
-rw-r--r--lib/sqlalchemy/sql/crud.py9
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py2
-rw-r--r--lib/sqlalchemy/sql/type_api.py1
-rw-r--r--lib/sqlalchemy/testing/__init__.py1
-rw-r--r--lib/sqlalchemy/testing/config.py5
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py4
-rw-r--r--lib/sqlalchemy/testing/schema.py41
9 files changed, 78 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 91bb89ea9..c2d99845f 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -331,12 +331,6 @@ class ARRAY(sqltypes.ARRAY):
)
@util.memoized_property
- def _require_cast(self):
- return self._against_native_enum or isinstance(
- self.item_type, sqltypes.JSON
- )
-
- @util.memoized_property
def _against_native_enum(self):
return (
isinstance(self.item_type, sqltypes.Enum)
@@ -344,10 +338,7 @@ class ARRAY(sqltypes.ARRAY):
)
def bind_expression(self, bindvalue):
- if self._require_cast:
- return expression.cast(bindvalue, self)
- else:
- return bindvalue
+ return bindvalue
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index a52eacd8b..1969eb844 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -450,6 +450,7 @@ from ... import processors
from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
+from ...sql import elements
from ...util import collections_abc
@@ -597,7 +598,20 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
class PGCompiler_psycopg2(PGCompiler):
- pass
+ def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw):
+
+ text = super(PGCompiler_psycopg2, self).visit_bindparam(
+ bindparam, skip_bind_expression=skip_bind_expression, **kw
+ )
+ # note that if the type has a bind_expression(), we will get a
+ # double compile here
+ if not skip_bind_expression and bindparam.type._is_array:
+ text += "::%s" % (
+ elements.TypeClause(bindparam.type)._compiler_dispatch(
+ self, skip_bind_expression=skip_bind_expression, **kw
+ ),
+ )
+ return text
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 8b4950aa3..174a1c131 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -943,6 +943,9 @@ def _get_stmt_parameter_tuples_params(
# add it to values() in an "as-is" state,
# coercing right side to bound param
+ # note one of the main use cases for this is array slice
+ # updates on PostgreSQL, as the left side is also an expression.
+
col_expr = compiler.process(
k, include_table=compile_state.include_table_with_column_exprs
)
@@ -952,6 +955,12 @@ def _get_stmt_parameter_tuples_params(
elements.BindParameter(None, v, type_=k.type), **kw
)
else:
+ if v._is_bind_parameter and v.type._isnull:
+ # either unique parameter, or other bound parameters that
+ # were passed in directly
+ # set type to that of the column unconditionally
+ v = v._with_binary_element_type(k.type)
+
v = compiler.process(v.self_group(), **kw)
values.append((k, col_expr, v))
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index d075ef77d..816423d1b 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2675,6 +2675,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
__visit_name__ = "ARRAY"
+ _is_array = True
+
zero_indexes = False
"""If True, Python zero-based indexes should be interpreted as one-based
on the SQL expression side."""
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 46751cb22..9752750c5 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -47,6 +47,7 @@ class TypeEngine(Traversible):
_isnull = False
_is_tuple_type = False
_is_table_value = False
+ _is_array = False
class Comparator(operators.ColumnOperators):
"""Base class for custom comparison operations defined at the
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index adbb8f643..a3ce24226 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -42,6 +42,7 @@ from .assertions import startswith_
from .assertions import uses_deprecated
from .config import async_test
from .config import combinations
+from .config import combinations_list
from .config import db
from .config import fixture
from .config import requirements as requires
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 750671f9f..6589e5097 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -89,6 +89,11 @@ def combinations(*comb, **kw):
return _fixture_functions.combinations(*comb, **kw)
+def combinations_list(arg_iterable, **kw):
+ "As combination, but takes a single iterable"
+ return combinations(*arg_iterable, **kw)
+
+
def fixture(*arg, **kw):
return _fixture_functions.fixture(*arg, **kw)
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 4eaaecebb..388d71c73 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -578,7 +578,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
"i": lambda obj: obj,
"r": repr,
"s": str,
- "n": operator.attrgetter("__name__"),
+ "n": lambda obj: obj.__name__
+ if hasattr(obj, "__name__")
+ else type(obj).__name__,
}
def combinations(self, *arg_sets, **kw):
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
index 22b1f7b77..fee021cff 100644
--- a/lib/sqlalchemy/testing/schema.py
+++ b/lib/sqlalchemy/testing/schema.py
@@ -5,11 +5,14 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import sys
+
from . import config
from . import exclusions
from .. import event
from .. import schema
from .. import types as sqltypes
+from ..util import OrderedDict
__all__ = ["Table", "Column"]
@@ -162,3 +165,41 @@ def _truncate_name(dialect, name):
)
else:
return name
+
+
+def pep435_enum(name):
+ # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+ __members__ = OrderedDict()
+
+ def __init__(self, name, value, alias=None):
+ self.name = name
+ self.value = value
+ self.__members__[name] = self
+ value_to_member[value] = self
+ setattr(self.__class__, name, self)
+ if alias:
+ self.__members__[alias] = self
+ setattr(self.__class__, alias, self)
+
+ value_to_member = {}
+
+ @classmethod
+ def get(cls, value):
+ return value_to_member[value]
+
+ someenum = type(
+ name,
+ (object,),
+ {"__members__": __members__, "__init__": __init__, "get": get},
+ )
+
+ # getframe() trick for pickling I don't understand courtesy
+ # Python namedtuple()
+ try:
+ module = sys._getframe(1).f_globals.get("__name__", "__main__")
+ except (AttributeError, ValueError):
+ pass
+ if module is not None:
+ someenum.__module__ = module
+
+ return someenum