diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 104 | ||||
| -rw-r--r-- | lib/sqlalchemy/types.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 |
6 files changed, 127 insertions, 7 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index dc49690a5..3580dae59 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -123,6 +123,7 @@ from .types import Text from .types import TIME from .types import Time from .types import TIMESTAMP +from .types import TupleType from .types import TypeDecorator from .types import Unicode from .types import UnicodeText diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index bcede5d76..96349578c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2024,8 +2024,14 @@ class SQLCompiler(Compiled): [parameter.type], parameter.expand_op ) - elif isinstance(values[0], (tuple, list)): - assert typ_dialect_impl._is_tuple_type + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], util.collections_abc.Sequence) + and not isinstance( + values[0], util.string_types + util.binary_types + ) + ): + replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + ", ".join( @@ -2041,7 +2047,6 @@ class SQLCompiler(Compiled): for i, tuple_element in enumerate(values) ) else: - assert not typ_dialect_impl._is_tuple_type replacement_expression = ", ".join( self.render_literal_value(value, parameter.type) for value in values @@ -2070,10 +2075,14 @@ class SQLCompiler(Compiled): [parameter.type], parameter.expand_op ) - elif ( - isinstance(values[0], (tuple, list)) - and not typ_dialect_impl._is_array + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], util.collections_abc.Sequence) + and not isinstance( + values[0], util.string_types + util.binary_types + ) ): + assert not typ_dialect_impl._is_array to_update = [ ("%s_%s_%s" % (name, i, j), value) for i, tuple_element in enumerate(values, 1) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 77af76d0b..0d7a06e31 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2949,7 +2949,10 @@ class TupleType(TypeEngine): def __init__(self, *types): self._fully_typed = NULLTYPE not in types - self.types = types + self.types = [ + item_type() if isinstance(item_type, type) else item_type + for item_type in types + ] def _resolve_values_to_types(self, value): if self._fully_typed: diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 63502b077..bea8a6075 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -30,11 +30,13 @@ from ... import testing from ... import text from ... import true from ... import tuple_ +from ... import TupleType from ... import union from ... import util from ... import values from ...exc import DatabaseError from ...exc import ProgrammingError +from ...util import collections_abc class CollateTest(fixtures.TablesTest): @@ -1131,6 +1133,41 @@ class ExpandingBoundInTest(fixtures.TablesTest): ) self._assert_result(stmt, []) + def test_typed_str_in(self): + """test related to #7292. + + as a type is given to the bound param, there is no ambiguity + to the type of element. + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", type_=String, expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + def test_untyped_str_in(self): + """test related to #7292. + + for untyped expression, we look at the types of elements. + Test for Sequence to detect tuple in. but not strings or bytes! + as always.... + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + @testing.requires.tuple_in def test_bound_in_two_tuple_bindparam(self): table = self.tables.some_table @@ -1197,6 +1234,73 @@ class ExpandingBoundInTest(fixtures.TablesTest): params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self): + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams( + bindparam( + "q", type_=TupleType(Integer(), String()), expanding=True + ) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + def test_empty_set_against_integer_bindparam(self): table = self.tables.some_table stmt = ( diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index ecc351fc9..df8abdc69 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -36,6 +36,7 @@ __all__ = [ "INTEGER", "DATE", "TIME", + "TupleType", "String", "Integer", "SmallInteger", @@ -103,6 +104,7 @@ from .sql.sqltypes import Text from .sql.sqltypes import TIME from .sql.sqltypes import Time from .sql.sqltypes import TIMESTAMP +from .sql.sqltypes import TupleType from .sql.sqltypes import Unicode from .sql.sqltypes import UnicodeText from .sql.sqltypes import VARBINARY diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 327f76715..8a18a584a 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -53,6 +53,7 @@ from .compat import b from .compat import b64decode from .compat import b64encode from .compat import binary_type +from .compat import binary_types from .compat import byte_buffer from .compat import callable from .compat import cmp |
