summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2018-10-03 10:40:38 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2018-10-03 10:40:38 -0400
commitaa2128427064a2bdeaeff5dc946ecbb3727c90aa (patch)
tree62e207bb41e8569727153c144f4d650d4981d02c
parentffd27cef48241e39725c4e9cd13fd744a2806bdd (diff)
downloadsqlalchemy-aa2128427064a2bdeaeff5dc946ecbb3727c90aa.tar.gz
Support tuples of heterogeneous types for empty expanding IN
Pass a list of all the types for the left side of an IN expression to the visit_empty_set_expr() method, so that the "empty expanding IN" can produce clauses for each element. Fixes: #4271 Change-Id: I2738b9df2292ac01afda37f16d4fa56ae7bf9147
-rw-r--r--doc/build/changelog/migration_13.rst18
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py14
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py13
-rw-r--r--lib/sqlalchemy/engine/default.py5
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py11
-rw-r--r--lib/sqlalchemy/sql/elements.py10
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py42
8 files changed, 105 insertions, 10 deletions
diff --git a/doc/build/changelog/migration_13.rst b/doc/build/changelog/migration_13.rst
index 5a8e3ce05..40363d426 100644
--- a/doc/build/changelog/migration_13.rst
+++ b/doc/build/changelog/migration_13.rst
@@ -413,6 +413,24 @@ backend, such as "SELECT CAST(NULL AS INTEGER) WHERE 1!=1" for Postgresql,
...
SELECT 1 WHERE 1 IN (SELECT CAST(NULL AS INTEGER) WHERE 1!=1)
+The feature also works for tuple-oriented IN statements, where the "empty IN"
+expression will be expanded to support the elements given inside the tuple,
+such as on Postgresql::
+
+ >>> from sqlalchemy import create_engine
+ >>> from sqlalchemy import select, literal_column, tuple_, bindparam
+ >>> e = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
+ >>> with e.connect() as conn:
+ ... conn.execute(
+ ... select([literal_column('1')]).
+ ... where(tuple_(50, "somestring").in_(bindparam('q', expanding=True))),
+ ... q=[]
+ ... )
+ ...
+ SELECT 1 WHERE (%(param_1)s, %(param_2)s)
+ IN (SELECT CAST(NULL AS INTEGER), CAST(NULL AS VARCHAR) WHERE 1!=1)
+
+
:ticket:`4271`
.. _change_3981:
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 45e8c2510..43966d1dc 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1179,8 +1179,18 @@ class MySQLCompiler(compiler.SQLCompiler):
fromhints=from_hints, **kw)
for t in [from_table] + extra_froms)
- def visit_empty_set_expr(self, type_):
- return 'SELECT 1 FROM (SELECT 1) as _empty_set WHERE 1!=1'
+ def visit_empty_set_expr(self, element_types):
+ return (
+ "SELECT %(outer)s FROM (SELECT %(inner)s) "
+ "as _empty_set WHERE 1!=1" % {
+ "inner": ", ".join(
+ "1 AS _in_%s" % idx
+ for idx, type_ in enumerate(element_types)),
+ "outer": ", ".join(
+ "_in_%s" % idx
+ for idx, type_ in enumerate(element_types))
+ }
+ )
class MySQLDDLCompiler(compiler.DDLCompiler):
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 11fcc41d5..5251a000d 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1485,14 +1485,17 @@ class PGCompiler(compiler.SQLCompiler):
if escape else ''
)
- def visit_empty_set_expr(self, type_, **kw):
+ def visit_empty_set_expr(self, element_types):
# cast the empty set to the type we are comparing against. if
# we are comparing against the null type, pick an arbitrary
# datatype for the empty set
- if type_._isnull:
- type_ = INTEGER()
- return 'SELECT CAST(NULL AS %s) WHERE 1!=1' % \
- self.dialect.type_compiler.process(type_, **kw)
+ return 'SELECT %s WHERE 1!=1' % (
+ ", ".join(
+ "CAST(NULL AS %s)" % self.dialect.type_compiler.process(
+ INTEGER() if type_._isnull else type_,
+ ) for type_ in element_types or [INTEGER()]
+ ),
+ )
def render_literal_value(self, value, type_):
value = super(PGCompiler, self).render_literal_value(value, type_)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index f48217a4e..5c96e4240 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -737,7 +737,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
to_update = []
replacement_expressions[name] = (
self.compiled.visit_empty_set_expr(
- type_=parameter.type)
+ parameter._expanding_in_types
+ if parameter._expanding_in_types
+ else [parameter.type]
+ )
)
elif isinstance(values[0], (tuple, list)):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2f68b7e2e..27ee4afc6 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1056,7 +1056,7 @@ class SQLCompiler(Compiled):
self._emit_empty_in_warning()
return self.process(binary.left == binary.left)
- def visit_empty_set_expr(self, type_):
+ def visit_empty_set_expr(self, element_types):
raise NotImplementedError(
"Dialect '%s' does not support empty set expression." %
self.dialect.name
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 5d02f65a1..8149f9731 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -15,7 +15,8 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \
Null, _const_expr, _clause_element_as_expr, \
ClauseList, ColumnElement, TextClause, UnaryExpression, \
collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
- Slice, Visitable, _literal_as_binds, CollectionAggregate
+ Slice, Visitable, _literal_as_binds, CollectionAggregate, \
+ Tuple
from .selectable import SelectBase, Alias, Selectable, ScalarSelect
@@ -145,6 +146,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
elif isinstance(seq_or_selectable, ClauseElement):
if isinstance(seq_or_selectable, BindParameter) and \
seq_or_selectable.expanding:
+
+ if isinstance(expr, Tuple):
+ seq_or_selectable = (
+ seq_or_selectable._with_expanding_in_types(
+ [elem.type for elem in expr]
+ )
+ )
+
return _boolean_compare(
expr, op,
seq_or_selectable,
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index dd16b6862..de3b7992a 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -865,6 +865,7 @@ class BindParameter(ColumnElement):
__visit_name__ = 'bindparam'
_is_crud = False
+ _expanding_in_types = ()
def __init__(self, key, value=NO_ARG, type_=None,
unique=False, required=NO_ARG,
@@ -1134,6 +1135,15 @@ class BindParameter(ColumnElement):
else:
self.type = type_
+ def _with_expanding_in_types(self, types):
+ """Return a copy of this :class:`.BindParameter` in
+ the context of an expanding IN against a tuple.
+
+ """
+ cloned = self._clone()
+ cloned._expanding_in_types = types
+ return cloned
+
def _with_value(self, value):
"""Return a copy of this :class:`.BindParameter` with the given value
set.
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index 78b34f496..73ce02492 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -402,6 +402,34 @@ class ExpandingBoundInTest(fixtures.TablesTest):
params={"q": [], "p": []},
)
+ @testing.requires.tuple_in
+ def test_empty_heterogeneous_tuples(self):
+ table = self.tables.some_table
+
+ stmt = select([table.c.id]).where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam('q', expanding=True))).order_by(table.c.id)
+
+ self._assert_result(
+ stmt,
+ [],
+ params={"q": []},
+ )
+
+ @testing.requires.tuple_in
+ def test_empty_homogeneous_tuples(self):
+ table = self.tables.some_table
+
+ stmt = select([table.c.id]).where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam('q', expanding=True))).order_by(table.c.id)
+
+ self._assert_result(
+ stmt,
+ [],
+ params={"q": []},
+ )
+
def test_bound_in_scalar(self):
table = self.tables.some_table
@@ -428,6 +456,20 @@ class ExpandingBoundInTest(fixtures.TablesTest):
params={"q": [(2, 3), (3, 4), (4, 5)]},
)
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple(self):
+ table = self.tables.some_table
+
+ stmt = select([table.c.id]).where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam('q', expanding=True))).order_by(table.c.id)
+
+ self._assert_result(
+ stmt,
+ [(2, ), (3, ), (4, )],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
def test_empty_set_against_integer(self):
table = self.tables.some_table