summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py21
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py5
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py104
-rw-r--r--lib/sqlalchemy/types.py2
-rw-r--r--lib/sqlalchemy/util/__init__.py1
6 files changed, 127 insertions, 7 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index dc49690a5..3580dae59 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -123,6 +123,7 @@ from .types import Text
from .types import TIME
from .types import Time
from .types import TIMESTAMP
+from .types import TupleType
from .types import TypeDecorator
from .types import Unicode
from .types import UnicodeText
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index bcede5d76..96349578c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -2024,8 +2024,14 @@ class SQLCompiler(Compiled):
[parameter.type], parameter.expand_op
)
- elif isinstance(values[0], (tuple, list)):
- assert typ_dialect_impl._is_tuple_type
+ elif typ_dialect_impl._is_tuple_type or (
+ typ_dialect_impl._isnull
+ and isinstance(values[0], util.collections_abc.Sequence)
+ and not isinstance(
+ values[0], util.string_types + util.binary_types
+ )
+ ):
+
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + ", ".join(
@@ -2041,7 +2047,6 @@ class SQLCompiler(Compiled):
for i, tuple_element in enumerate(values)
)
else:
- assert not typ_dialect_impl._is_tuple_type
replacement_expression = ", ".join(
self.render_literal_value(value, parameter.type)
for value in values
@@ -2070,10 +2075,14 @@ class SQLCompiler(Compiled):
[parameter.type], parameter.expand_op
)
- elif (
- isinstance(values[0], (tuple, list))
- and not typ_dialect_impl._is_array
+ elif typ_dialect_impl._is_tuple_type or (
+ typ_dialect_impl._isnull
+ and isinstance(values[0], util.collections_abc.Sequence)
+ and not isinstance(
+ values[0], util.string_types + util.binary_types
+ )
):
+ assert not typ_dialect_impl._is_array
to_update = [
("%s_%s_%s" % (name, i, j), value)
for i, tuple_element in enumerate(values, 1)
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 77af76d0b..0d7a06e31 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -2949,7 +2949,10 @@ class TupleType(TypeEngine):
def __init__(self, *types):
self._fully_typed = NULLTYPE not in types
- self.types = types
+ self.types = [
+ item_type() if isinstance(item_type, type) else item_type
+ for item_type in types
+ ]
def _resolve_values_to_types(self, value):
if self._fully_typed:
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index 63502b077..bea8a6075 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -30,11 +30,13 @@ from ... import testing
from ... import text
from ... import true
from ... import tuple_
+from ... import TupleType
from ... import union
from ... import util
from ... import values
from ...exc import DatabaseError
from ...exc import ProgrammingError
+from ...util import collections_abc
class CollateTest(fixtures.TablesTest):
@@ -1131,6 +1133,41 @@ class ExpandingBoundInTest(fixtures.TablesTest):
)
self._assert_result(stmt, [])
+ def test_typed_str_in(self):
+ """test related to #7292.
+
+ as a type is given to the bound param, there is no ambiguity
+ to the type of element.
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", type_=String, expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ def test_untyped_str_in(self):
+ """test related to #7292.
+
+ for untyped expression, we look at the types of elements.
+ Test for Sequence to detect tuple in. but not strings or bytes!
+ as always....
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
@testing.requires.tuple_in
def test_bound_in_two_tuple_bindparam(self):
table = self.tables.some_table
@@ -1197,6 +1234,73 @@ class ExpandingBoundInTest(fixtures.TablesTest):
params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
)
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self):
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(
+ bindparam(
+ "q", type_=TupleType(Integer(), String()), expanding=True
+ )
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self):
+ # note this becomes ARRAY if we dont use expanding
+ # explicitly right now
+
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
def test_empty_set_against_integer_bindparam(self):
table = self.tables.some_table
stmt = (
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index ecc351fc9..df8abdc69 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -36,6 +36,7 @@ __all__ = [
"INTEGER",
"DATE",
"TIME",
+ "TupleType",
"String",
"Integer",
"SmallInteger",
@@ -103,6 +104,7 @@ from .sql.sqltypes import Text
from .sql.sqltypes import TIME
from .sql.sqltypes import Time
from .sql.sqltypes import TIMESTAMP
+from .sql.sqltypes import TupleType
from .sql.sqltypes import Unicode
from .sql.sqltypes import UnicodeText
from .sql.sqltypes import VARBINARY
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 327f76715..8a18a584a 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -53,6 +53,7 @@ from .compat import b
from .compat import b64decode
from .compat import b64encode
from .compat import binary_type
+from .compat import binary_types
from .compat import byte_buffer
from .compat import callable
from .compat import cmp