summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-10-13 15:52:12 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-10-15 09:28:49 -0400
commit639cf972f15c8fbf77980b04fff8e5dbc82af7b6 (patch)
tree162aafe94f82df3e34675ba26b5c88ce4f1b2044
parentfec2b6560c14bb28ee7fc9d21028844acf700b04 (diff)
downloadsqlalchemy-639cf972f15c8fbf77980b04fff8e5dbc82af7b6.tar.gz
support bind expressions w/ expanding IN; apply to psycopg2
Fixed issue where "expanding IN" would fail to function correctly with datatypes that use the :meth:`_types.TypeEngine.bind_expression` method, where the method would need to be applied to each element of the IN expression rather than the overall IN expression itself. Fixed issue where IN expressions against a series of array elements, as can be done with PostgreSQL, would fail to function correctly due to multiple issues within the "expanding IN" feature of SQLAlchemy Core that was standardized in version 1.4. The psycopg2 dialect now makes use of the :meth:`_types.TypeEngine.bind_expression` method with :class:`_types.ARRAY` to portably apply the correct casts to elements. The asyncpg dialect was not affected by this issue as it applies bind-level casts at the driver level rather than at the compiler level. as part of this commit the "bind translate" feature has been simplified and also applies to the names in the POSTCOMPILE tag to accommodate for brackets. Fixes: #7177 Change-Id: I08c703adb0a9bd6f5aeee5de3ff6f03cccdccdc5
-rw-r--r--doc/build/changelog/unreleased_14/7177.rst22
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py1
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py15
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py29
-rw-r--r--lib/sqlalchemy/engine/default.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py106
-rw-r--r--test/dialect/postgresql/test_types.py223
-rw-r--r--test/sql/test_external_traversal.py10
-rw-r--r--test/sql/test_type_expressions.py23
9 files changed, 338 insertions, 93 deletions
diff --git a/doc/build/changelog/unreleased_14/7177.rst b/doc/build/changelog/unreleased_14/7177.rst
new file mode 100644
index 000000000..7766c838e
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/7177.rst
@@ -0,0 +1,22 @@
+.. change::
+ :tags: sql, bug, regression
+ :tickets: 7177
+
+ Fixed issue where "expanding IN" would fail to function correctly with
+ datatypes that use the :meth:`_types.TypeEngine.bind_expression` method,
+ where the method would need to be applied to each element of the
+ IN expression rather than the overall IN expression itself.
+
+.. change::
+ :tags: postgresql, bug, regression
+ :tickets: 7177
+
+ Fixed issue where IN expressions against a series of array elements, as can
+ be done with PostgreSQL, would fail to function correctly due to multiple
+ issues within the "expanding IN" feature of SQLAlchemy Core that was
+ standardized in version 1.4. The psycopg2 dialect now makes use of the
+ :meth:`_types.TypeEngine.bind_expression` method with :class:`_types.ARRAY`
+ to portably apply the correct casts to elements. The asyncpg dialect was
+ not affected by this issue as it applies bind-level casts at the driver
+ level rather than at the compiler level.
+
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index dc3da224c..3d195e691 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -362,7 +362,6 @@ class AsyncAdapt_asyncpg_cursor:
if not self._inputsizes:
return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
else:
-
return tuple(
"$%d::%s" % (idx, typ) if typ else "$%d" % idx
for idx, typ in enumerate(
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 2e28b45ca..c1a2cf81d 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2047,6 +2047,15 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
self.drop(bind=bind, checkfirst=checkfirst)
+class _ColonCast(elements.Cast):
+ __visit_name__ = "colon_cast"
+
+ def __init__(self, expression, type_):
+ self.type = type_
+ self.clause = expression
+ self.typeclause = elements.TypeClause(type_)
+
+
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
@@ -2102,6 +2111,12 @@ ischema_names = {
class PGCompiler(compiler.SQLCompiler):
+ def visit_colon_cast(self, element, **kw):
+ return "%s::%s" % (
+ element.clause._compiler_dispatch(self, **kw),
+ element.typeclause._compiler_dispatch(self, **kw),
+ )
+
def visit_array(self, element, **kw):
return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index a71bdf760..4143dd041 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -473,6 +473,8 @@ import logging
import re
from uuid import UUID as _python_UUID
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
@@ -490,7 +492,6 @@ from ... import processors
from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
-from ...sql import elements
from ...util import collections_abc
@@ -556,6 +557,11 @@ class _PGHStore(HSTORE):
return super(_PGHStore, self).result_processor(dialect, coltype)
+class _PGARRAY(PGARRAY):
+ def bind_expression(self, bindvalue):
+ return _ColonCast(bindvalue, self)
+
+
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
return None
@@ -638,25 +644,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
class PGCompiler_psycopg2(PGCompiler):
- def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw):
-
- text = super(PGCompiler_psycopg2, self).visit_bindparam(
- bindparam, skip_bind_expression=skip_bind_expression, **kw
- )
- # note that if the type has a bind_expression(), we will get a
- # double compile here
- if not skip_bind_expression and (
- bindparam.type._is_array or bindparam.type._is_type_decorator
- ):
- typ = bindparam.type._unwrapped_dialect_impl(self.dialect)
-
- if typ._is_array:
- text += "::%s" % (
- elements.TypeClause(typ)._compiler_dispatch(
- self, skip_bind_expression=skip_bind_expression, **kw
- ),
- )
- return text
+ pass
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
@@ -713,6 +701,7 @@ class PGDialect_psycopg2(PGDialect):
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
UUID: _PGUUID,
+ sqltypes.ARRAY: _PGARRAY,
},
)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index eff28e340..75bca1905 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1584,7 +1584,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
from the bind parameter's ``TypeEngine`` objects.
This method only called by those dialects which require it,
- currently cx_oracle.
+ currently cx_oracle, asyncpg and pg8000.
"""
if self.isddl or self.is_text:
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index efcfe0e51..0cd568fcc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -165,11 +165,8 @@ BIND_TEMPLATES = {
"named": ":%(name)s",
}
-BIND_TRANSLATE = {
- "pyformat": re.compile(r"[%\(\)]"),
- "named": re.compile(r"[\:]"),
-}
-_BIND_TRANSLATE_CHARS = {"%": "P", "(": "A", ")": "Z", ":": "C"}
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
OPERATORS = {
# binary
@@ -746,7 +743,6 @@ class SQLCompiler(Compiled):
self.positiontup = []
self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
- self._bind_translate = BIND_TRANSLATE.get(dialect.paramstyle, None)
self.ctes = None
@@ -1113,7 +1109,6 @@ class SQLCompiler(Compiled):
N as a bound parameter.
"""
-
if parameters is None:
parameters = self.construct_params()
@@ -1141,22 +1136,36 @@ class SQLCompiler(Compiled):
replacement_expressions = {}
to_update_sets = {}
+ # notes:
+ # *unescaped* parameter names in:
+ # self.bind_names, self.binds, self._bind_processors
+ #
+ # *escaped* parameter names in:
+ # construct_params(), replacement_expressions
+
for name in (
self.positiontup if self.positional else self.bind_names.values()
):
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if self.escaped_bind_names
+ else name
+ )
parameter = self.binds[name]
if parameter in self.literal_execute_params:
- if name not in replacement_expressions:
- value = parameters.pop(name)
+ if escaped_name not in replacement_expressions:
+ value = parameters.pop(escaped_name)
- replacement_expressions[name] = self.render_literal_bindparam(
+ replacement_expressions[
+ escaped_name
+ ] = self.render_literal_bindparam(
parameter, render_literal_value=value
)
continue
if parameter in self.post_compile_params:
- if name in replacement_expressions:
- to_update = to_update_sets[name]
+ if escaped_name in replacement_expressions:
+ to_update = to_update_sets[escaped_name]
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
@@ -1164,13 +1173,15 @@ class SQLCompiler(Compiled):
# process it. the single name is being replaced with
# individual numbered parameters for each value in the
# param.
- values = parameters.pop(name)
+ values = parameters.pop(escaped_name)
leep = self._literal_execute_expanding_parameter
- to_update, replacement_expr = leep(name, parameter, values)
+ to_update, replacement_expr = leep(
+ escaped_name, parameter, values
+ )
- to_update_sets[name] = to_update
- replacement_expressions[name] = replacement_expr
+ to_update_sets[escaped_name] = to_update
+ replacement_expressions[escaped_name] = replacement_expr
if not parameter.literal_execute:
parameters.update(to_update)
@@ -1200,10 +1211,24 @@ class SQLCompiler(Compiled):
positiontup.append(name)
def process_expanding(m):
- return replacement_expressions[m.group(1)]
+ key = m.group(1)
+ expr = replacement_expressions[key]
+
+ # if POSTCOMPILE included a bind_expression, render that
+ # around each element
+ if m.group(2):
+ tok = m.group(2).split("~~")
+ be_left, be_right = tok[1], tok[3]
+ expr = ", ".join(
+ "%s%s%s" % (be_left, exp, be_right)
+ for exp in expr.split(", ")
+ )
+ return expr
statement = re.sub(
- r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string
+ r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+ process_expanding,
+ self.string,
)
expanded_state = ExpandedState(
@@ -1963,8 +1988,10 @@ class SQLCompiler(Compiled):
self, parameter, values
):
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
if not values:
- if parameter.type._is_tuple_type:
+ if typ_dialect_impl._is_tuple_type:
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + self.visit_empty_set_op_expr(
@@ -1977,7 +2004,7 @@ class SQLCompiler(Compiled):
)
elif isinstance(values[0], (tuple, list)):
- assert parameter.type._is_tuple_type
+ assert typ_dialect_impl._is_tuple_type
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + ", ".join(
@@ -1993,7 +2020,7 @@ class SQLCompiler(Compiled):
for i, tuple_element in enumerate(values)
)
else:
- assert not parameter.type._is_tuple_type
+ assert not typ_dialect_impl._is_tuple_type
replacement_expression = ", ".join(
self.render_literal_value(value, parameter.type)
for value in values
@@ -2008,9 +2035,11 @@ class SQLCompiler(Compiled):
parameter, values
)
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
if not values:
to_update = []
- if parameter.type._is_tuple_type:
+ if typ_dialect_impl._is_tuple_type:
replacement_expression = self.visit_empty_set_op_expr(
parameter.type.types, parameter.expand_op
@@ -2020,7 +2049,10 @@ class SQLCompiler(Compiled):
[parameter.type], parameter.expand_op
)
- elif isinstance(values[0], (tuple, list)):
+ elif (
+ isinstance(values[0], (tuple, list))
+ and not typ_dialect_impl._is_array
+ ):
to_update = [
("%s_%s_%s" % (name, i, j), value)
for i, tuple_element in enumerate(values, 1)
@@ -2299,14 +2331,27 @@ class SQLCompiler(Compiled):
impl = bindparam.type.dialect_impl(self.dialect)
if impl._has_bind_expression:
bind_expression = impl.bind_expression(bindparam)
- return self.process(
+ wrapped = self.process(
bind_expression,
skip_bind_expression=True,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
literal_execute=literal_execute,
+ render_postcompile=render_postcompile,
**kwargs
)
+ if bindparam.expanding:
+ # for postcompile w/ expanding, move the "wrapped" part
+ # of this into the inside
+ m = re.match(
+ r"^(.*)\(\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
+ )
+ wrapped = "([POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
+ m.group(2),
+ m.group(1),
+ m.group(3),
+ )
+ return wrapped
if not literal_binds:
literal_execute = (
@@ -2489,12 +2534,13 @@ class SQLCompiler(Compiled):
positional_names.append(name)
else:
self.positiontup.append(name)
- elif not post_compile and not escaped_from:
- tr_reg = self._bind_translate
- if tr_reg.search(name):
- # i'd rather use translate() here but I can't get it to work
- # in all cases under Python 2, not worth it right now
- new_name = tr_reg.sub(
+ elif not escaped_from:
+
+ if _BIND_TRANSLATE_RE.search(name):
+ # not quite the translate use case as we want to
+ # also get a quick boolean if we even found
+ # unusual characters in the name
+ new_name = _BIND_TRANSLATE_RE.sub(
lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
name,
)
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 92641fcc6..dd0a1be0f 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -1198,6 +1198,45 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
postgresql.ARRAY(Unicode(30), dimensions=3), "VARCHAR(30)[][][]"
)
+ def test_array_in_enum_psycopg2_cast(self):
+ expr = column(
+ "x",
+ postgresql.ARRAY(
+ postgresql.ENUM("one", "two", "three", name="myenum")
+ ),
+ ).in_([["one", "two"], ["three", "four"]])
+
+ self.assert_compile(
+ expr,
+ "x IN ([POSTCOMPILE_x_1~~~~REPL~~::myenum[]~~])",
+ dialect=postgresql.psycopg2.dialect(),
+ )
+
+ self.assert_compile(
+ expr,
+ "x IN (%(x_1_1)s::myenum[], %(x_1_2)s::myenum[])",
+ dialect=postgresql.psycopg2.dialect(),
+ render_postcompile=True,
+ )
+
+ def test_array_in_str_psycopg2_cast(self):
+ expr = column("x", postgresql.ARRAY(String(15))).in_(
+ [["one", "two"], ["three", "four"]]
+ )
+
+ self.assert_compile(
+ expr,
+ "x IN ([POSTCOMPILE_x_1~~~~REPL~~::VARCHAR(15)[]~~])",
+ dialect=postgresql.psycopg2.dialect(),
+ )
+
+ self.assert_compile(
+ expr,
+ "x IN (%(x_1_1)s::VARCHAR(15)[], %(x_1_2)s::VARCHAR(15)[])",
+ dialect=postgresql.psycopg2.dialect(),
+ render_postcompile=True,
+ )
+
def test_array_type_render_str_collate_multidim(self):
self.assert_compile(
postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=2),
@@ -1457,11 +1496,79 @@ class ArrayRoundTripTest(object):
t = Table(
"t",
metadata,
- Column("data", sqltypes.ARRAY(String(50, collation="en_US"))),
+ Column("data", self.ARRAY(String(50, collation="en_US"))),
)
t.create(connection)
+ @testing.fixture
+ def array_in_fixture(self, connection):
+ arrtable = self.tables.arrtable
+
+ connection.execute(
+ arrtable.insert(),
+ [
+ {
+ "id": 1,
+ "intarr": [1, 2, 3],
+ "strarr": [u"one", u"two", u"three"],
+ },
+ {
+ "id": 2,
+ "intarr": [4, 5, 6],
+ "strarr": [u"four", u"five", u"six"],
+ },
+ {"id": 3, "intarr": [1, 5], "strarr": [u"one", u"five"]},
+ {"id": 4, "intarr": [], "strarr": []},
+ ],
+ )
+
+ def test_array_in_int(self, array_in_fixture, connection):
+ """test #7177"""
+
+ arrtable = self.tables.arrtable
+
+ stmt = (
+ select(arrtable.c.intarr)
+ .where(arrtable.c.intarr.in_([[1, 5], [4, 5, 6], [9, 10]]))
+ .order_by(arrtable.c.id)
+ )
+
+ eq_(
+ connection.execute(stmt).all(),
+ [
+ ([4, 5, 6],),
+ ([1, 5],),
+ ],
+ )
+
+ def test_array_in_str(self, array_in_fixture, connection):
+ """test #7177"""
+
+ arrtable = self.tables.arrtable
+
+ stmt = (
+ select(arrtable.c.strarr)
+ .where(
+ arrtable.c.strarr.in_(
+ [
+ [u"one", u"five"],
+ [u"four", u"five", u"six"],
+ [u"nine", u"ten"],
+ ]
+ )
+ )
+ .order_by(arrtable.c.id)
+ )
+
+ eq_(
+ connection.execute(stmt).all(),
+ [
+ (["four", "five", "six"],),
+ (["one", "five"],),
+ ],
+ )
+
def test_array_agg(self, metadata, connection):
values_table = Table("values", metadata, Column("value", Integer))
metadata.create_all(connection)
@@ -2151,6 +2258,9 @@ class _ArrayOfEnum(TypeDecorator):
impl = postgresql.ARRAY
cache_ok = True
+ # note expanding logic is checking _is_array here so that has to
+ # translate through the TypeDecorator
+
def bind_expression(self, bindvalue):
return sa.cast(bindvalue, self)
@@ -2207,56 +2317,93 @@ class ArrayEnum(fixtures.TestBase):
connection,
)
- @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
- @testing.combinations(
- sqltypes.ARRAY,
- postgresql.ARRAY,
- (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
- argnames="array_cls",
- )
- def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
- tbl = Table(
- "enum_table",
- self.metadata,
- Column("id", Integer, primary_key=True),
- Column(
- "enum_col",
- array_cls(enum_cls("foo", "bar", "baz", name="an_enum")),
- ),
- )
-
- if util.py3k:
- from enum import Enum
-
- class MyEnum(Enum):
- a = "aaa"
- b = "bbb"
- c = "ccc"
-
- tbl.append_column(
+ @testing.fixture
+ def array_of_enum_fixture(self, metadata, connection):
+ def go(array_cls, enum_cls):
+ tbl = Table(
+ "enum_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
Column(
- "pyenum_col",
- array_cls(enum_cls(MyEnum)),
+ "enum_col",
+ array_cls(enum_cls("foo", "bar", "baz", name="an_enum")),
),
)
+ if util.py3k:
+ from enum import Enum
+
+ class MyEnum(Enum):
+ a = "aaa"
+ b = "bbb"
+ c = "ccc"
+
+ tbl.append_column(
+ Column(
+ "pyenum_col",
+ array_cls(enum_cls(MyEnum)),
+ ),
+ )
+ else:
+ MyEnum = None
- self.metadata.create_all(connection)
+ metadata.create_all(connection)
+ connection.execute(
+ tbl.insert(),
+ [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}],
+ )
+ return tbl, MyEnum
- connection.execute(
- tbl.insert(), [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}]
+ yield go
+
+ def _enum_combinations(fn):
+ return testing.combinations(
+ sqltypes.Enum, postgresql.ENUM, argnames="enum_cls"
+ )(
+ testing.combinations(
+ sqltypes.ARRAY,
+ postgresql.ARRAY,
+ (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+ argnames="array_cls",
+ )(fn)
)
+ @_enum_combinations
+ def test_array_of_enums_roundtrip(
+ self, array_of_enum_fixture, connection, array_cls, enum_cls
+ ):
+ tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+ # test select back
sel = select(tbl.c.enum_col).order_by(tbl.c.id)
eq_(
connection.execute(sel).fetchall(), [(["foo"],), (["foo", "bar"],)]
)
- if util.py3k:
- connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]})
- sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
- eq_(connection.scalar(sel), [MyEnum.a])
+ @_enum_combinations
+ def test_array_of_enums_expanding_in(
+ self, array_of_enum_fixture, connection, array_cls, enum_cls
+ ):
+ tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+ # test select with WHERE using expanding IN against arrays
+ # #7177
+ sel = (
+ select(tbl.c.enum_col)
+ .where(tbl.c.enum_col.in_([["foo", "bar"], ["bar", "foo"]]))
+ .order_by(tbl.c.id)
+ )
+ eq_(connection.execute(sel).fetchall(), [(["foo", "bar"],)])
+
+ @_enum_combinations
+ @testing.requires.python3
+ def test_array_of_enums_native_roundtrip(
+ self, array_of_enum_fixture, connection, array_cls, enum_cls
+ ):
+ tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
- self.metadata.drop_all(connection)
+ connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]})
+ sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
+ eq_(connection.scalar(sel), [MyEnum.a])
class ArrayJSON(fixtures.TestBase):
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index 3d1b4fe85..0d43448d5 100644
--- a/test/sql/test_external_traversal.py
+++ b/test/sql/test_external_traversal.py
@@ -188,7 +188,10 @@ class TraversalTest(
("clone",), ("pickle",), ("conv_to_unique"), ("none"), argnames="meth"
)
@testing.combinations(
- ("name with space",), ("name with [brackets]",), argnames="name"
+ ("name with space",),
+ ("name with [brackets]",),
+ ("name with~~tildes~~",),
+ argnames="name",
)
def test_bindparam_key_proc_for_copies(self, meth, name):
r"""test :ticket:`6249`.
@@ -199,7 +202,7 @@ class TraversalTest(
Currently, the bind key reg is::
- re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+ re.sub(r"[%\(\) \$\[\]]", "_", name)
and the compiler postcompile reg is::
@@ -218,7 +221,8 @@ class TraversalTest(
expr.right.unique = False
expr.right._convert_to_unique()
- token = re.sub(r"[%\(\) \$]+", "_", name).strip("_")
+ token = re.sub(r"[%\(\) \$\[\]]", "_", name)
+
self.assert_compile(
expr,
'"%(name)s" IN (:%(token)s_1_1, '
diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py
index 51ee0ae62..adcaef39c 100644
--- a/test/sql/test_type_expressions.py
+++ b/test/sql/test_type_expressions.py
@@ -182,6 +182,29 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
"test_table WHERE test_table.y = lower(:y_1)",
)
+ def test_in_binds(self):
+ table = self._fixture()
+
+ self.assert_compile(
+ select(table).where(
+ table.c.y.in_(["hi", "there", "some", "expr"])
+ ),
+ "SELECT test_table.x, lower(test_table.y) AS y FROM "
+ "test_table WHERE test_table.y IN "
+ "([POSTCOMPILE_y_1~~lower(~~REPL~~)~~])",
+ render_postcompile=False,
+ )
+
+ self.assert_compile(
+ select(table).where(
+ table.c.y.in_(["hi", "there", "some", "expr"])
+ ),
+ "SELECT test_table.x, lower(test_table.y) AS y FROM "
+ "test_table WHERE test_table.y IN "
+ "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))",
+ render_postcompile=True,
+ )
+
def test_dialect(self):
table = self._fixture()
dialect = self._dialect_level_fixture()