summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2021-04-29 19:53:02 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2021-04-29 19:53:02 +0000
commitdc5485b7ecdbe1cbed34fcb8d748fbe975aee140 (patch)
tree7f0456f166b53fecf881c6e214b69dc7db4944e3 /lib/sqlalchemy
parent28493bf4bc35a4802b57b02a8b389cec7b6dcbb6 (diff)
parentaba308868544b21bafa0b3435701ddc908654b0a (diff)
downloadsqlalchemy-dc5485b7ecdbe1cbed34fcb8d748fbe975aee140.tar.gz
Merge "Use non-subquery form for empty IN"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py5
-rw-r--r--lib/sqlalchemy/sql/coercions.py11
-rw-r--r--lib/sqlalchemy/sql/compiler.py50
-rw-r--r--lib/sqlalchemy/sql/elements.py10
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py187
5 files changed, 199 insertions, 64 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 59d40fef0..66a556ae0 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -1299,6 +1299,11 @@ class SQLiteCompiler(compiler.SQLCompiler):
self.process(binary.right, **kw),
)
+ def visit_empty_set_op_expr(self, type_, expand_op):
+ # slightly old SQLite versions don't seem to be able to handle
+ # the empty set impl
+ return self.visit_empty_set_expr(type_)
+
def visit_empty_set_expr(self, element_types):
return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % (
", ".join("1" for type_ in element_types or [INTEGER()]),
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index b7aba9d74..820fc1bf1 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -561,14 +561,9 @@ class InElementImpl(RoleImpl):
return element.self_group(against=operator)
elif isinstance(element, elements.BindParameter):
- if not element.expanding:
- # coercing to expanding at the moment to work with the
- # lambda system. not sure if this is the right approach.
- # is there a valid use case to send a single non-expanding
- # param to IN? check for ARRAY type?
- element = element._clone(maintain_key=True)
- element.expanding = True
-
+ # previously we were adding expanding flags here but
+ # we now do this in the compiler where we have more context
+ # see compiler.py -> _render_in_expr_w_bindparam
return element
else:
return element
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index f3ae8c44f..57ffdf86b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1904,6 +1904,45 @@ class SQLCompiler(Compiled):
binary, override_operator=operators.match_op
)
+ def visit_in_op_binary(self, binary, operator, **kw):
+ return self._render_in_expr_w_bindparam(binary, operator, **kw)
+
+ def visit_not_in_op_binary(self, binary, operator, **kw):
+ return self._render_in_expr_w_bindparam(binary, operator, **kw)
+
+ def _render_in_expr_w_bindparam(self, binary, operator, **kw):
+ opstring = OPERATORS[operator]
+
+ if isinstance(binary.right, elements.BindParameter):
+ if not binary.right.expanding or not binary.right.expand_op:
+ # note that by cloning here, we rely upon the
+ # _cache_key_bind_match dictionary to resolve
+ # clones of bindparam() objects to the ones that are
+ # present in our cache key.
+ binary.right = binary.right._clone(maintain_key=True)
+ binary.right.expanding = True
+ binary.right.expand_op = operator
+
+ return self._generate_generic_binary(binary, opstring, **kw)
+
+ def visit_empty_set_op_expr(self, type_, expand_op):
+ if expand_op is operators.not_in_op:
+ if len(type_) > 1:
+ return "(%s)) OR (1 = 1" % (
+ ", ".join("NULL" for element in type_)
+ )
+ else:
+ return "NULL) OR (1 = 1"
+ elif expand_op is operators.in_op:
+ if len(type_) > 1:
+ return "(%s)) AND (1 != 1" % (
+ ", ".join("NULL" for element in type_)
+ )
+ else:
+ return "NULL) AND (1 != 1"
+ else:
+ return self.visit_empty_set_expr(type_)
+
def visit_empty_set_expr(self, element_types):
raise NotImplementedError(
"Dialect '%s' does not support empty set expression."
@@ -1960,12 +1999,12 @@ class SQLCompiler(Compiled):
to_update = []
if parameter.type._is_tuple_type:
- replacement_expression = self.visit_empty_set_expr(
- parameter.type.types
+ replacement_expression = self.visit_empty_set_op_expr(
+ parameter.type.types, parameter.expand_op
)
else:
- replacement_expression = self.visit_empty_set_expr(
- [parameter.type]
+ replacement_expression = self.visit_empty_set_op_expr(
+ [parameter.type], parameter.expand_op
)
elif isinstance(values[0], (tuple, list)):
@@ -3900,6 +3939,9 @@ class StrSQLCompiler(SQLCompiler):
for t in extra_froms
)
+ def visit_empty_set_op_expr(self, type_, expand_op):
+ return self.visit_empty_set_expr(type_)
+
def visit_empty_set_expr(self, type_):
return "SELECT 1 WHERE 1!=1"
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 696f3b249..e27b97802 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -1411,7 +1411,17 @@ class BindParameter(roles.InElementRole, ColumnElement):
self.callable = callable_
self.isoutparam = isoutparam
self.required = required
+
+ # indicate an "expanding" parameter; the compiler sets this
+ # automatically in the compiler _render_in_expr_w_bindparam method
+ # for an IN expression
self.expanding = expanding
+
+ # this is another hint to help w/ expanding and is typically
+ # set in the compiler _render_in_expr_w_bindparam method for an
+ # IN expression
+ self.expand_op = None
+
self.literal_execute = literal_execute
if _is_crud:
self._is_crud = True
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index 8f3412929..8133c2105 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -1016,163 +1016,246 @@ class ExpandingBoundInTest(fixtures.TablesTest):
with config.db.connect() as conn:
eq_(conn.execute(select, params).fetchall(), result)
- def test_multiple_empty_sets(self):
+ def test_multiple_empty_sets_bindparam(self):
# test that any anonymous aliasing used by the dialect
# is fine with duplicates
table = self.tables.some_table
-
stmt = (
select(table.c.id)
- .where(table.c.x.in_(bindparam("q", expanding=True)))
- .where(table.c.y.in_(bindparam("p", expanding=True)))
+ .where(table.c.x.in_(bindparam("q")))
+ .where(table.c.y.in_(bindparam("p")))
.order_by(table.c.id)
)
-
self._assert_result(stmt, [], params={"q": [], "p": []})
- @testing.requires.tuple_in_w_empty
- def test_empty_heterogeneous_tuples(self):
+ def test_multiple_empty_sets_direct(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
table = self.tables.some_table
-
stmt = (
select(table.c.id)
- .where(
- tuple_(table.c.x, table.c.z).in_(
- bindparam("q", expanding=True)
- )
- )
+ .where(table.c.x.in_([]))
+ .where(table.c.y.in_([]))
.order_by(table.c.id)
)
+ self._assert_result(stmt, [])
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
self._assert_result(stmt, [], params={"q": []})
@testing.requires.tuple_in_w_empty
- def test_empty_homogeneous_tuples(self):
+ def test_empty_heterogeneous_tuples_direct(self):
table = self.tables.some_table
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)])
+ go([], [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
stmt = (
select(table.c.id)
- .where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam("q", expanding=True)
- )
- )
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
.order_by(table.c.id)
)
-
self._assert_result(stmt, [], params={"q": []})
- def test_bound_in_scalar(self):
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_direct(self):
table = self.tables.some_table
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)])
+ go([], [])
+
+ def test_bound_in_scalar_bindparam(self):
+ table = self.tables.some_table
stmt = (
select(table.c.id)
- .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .where(table.c.x.in_(bindparam("q")))
.order_by(table.c.id)
)
-
self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
- @testing.requires.tuple_in
- def test_bound_in_two_tuple(self):
+ def test_bound_in_scalar_direct(self):
table = self.tables.some_table
-
stmt = (
select(table.c.id)
- .where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam("q", expanding=True)
- )
- )
+ .where(table.c.x.in_([2, 3, 4]))
.order_by(table.c.id)
)
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
self._assert_result(
stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
)
@testing.requires.tuple_in
- def test_bound_in_heterogeneous_two_tuple(self):
+ def test_bound_in_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_bindparam(self):
table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_direct(self):
+ table = self.tables.some_table
stmt = (
select(table.c.id)
.where(
tuple_(table.c.x, table.c.z).in_(
- bindparam("q", expanding=True)
+ [(2, "z2"), (3, "z3"), (4, "z4")]
)
)
.order_by(table.c.id)
)
-
self._assert_result(
stmt,
[(2,), (3,), (4,)],
- params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
)
@testing.requires.tuple_in
- def test_bound_in_heterogeneous_two_tuple_text(self):
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam(self):
+ # note this becomes ARRAY if we dont use expanding
+ # explicitly right now
stmt = text(
"select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
).bindparams(bindparam("q", expanding=True))
-
self._assert_result(
stmt,
[(2,), (3,), (4,)],
params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
)
- def test_empty_set_against_integer(self):
+ def test_empty_set_against_integer_bindparam(self):
table = self.tables.some_table
-
stmt = (
select(table.c.id)
- .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .where(table.c.x.in_(bindparam("q")))
.order_by(table.c.id)
)
-
self._assert_result(stmt, [], params={"q": []})
- def test_empty_set_against_integer_negation(self):
+ def test_empty_set_against_integer_direct(self):
table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+ def test_empty_set_against_integer_negation_bindparam(self):
+ table = self.tables.some_table
stmt = (
select(table.c.id)
- .where(table.c.x.not_in(bindparam("q", expanding=True)))
+ .where(table.c.x.not_in(bindparam("q")))
.order_by(table.c.id)
)
-
self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
- def test_empty_set_against_string(self):
+ def test_empty_set_against_integer_negation_direct(self):
table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+ def test_empty_set_against_string_bindparam(self):
+ table = self.tables.some_table
stmt = (
select(table.c.id)
- .where(table.c.z.in_(bindparam("q", expanding=True)))
+ .where(table.c.z.in_(bindparam("q")))
.order_by(table.c.id)
)
-
self._assert_result(stmt, [], params={"q": []})
- def test_empty_set_against_string_negation(self):
+ def test_empty_set_against_string_direct(self):
table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+ def test_empty_set_against_string_negation_bindparam(self):
+ table = self.tables.some_table
stmt = (
select(table.c.id)
- .where(table.c.z.not_in(bindparam("q", expanding=True)))
+ .where(table.c.z.not_in(bindparam("q")))
.order_by(table.c.id)
)
-
self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
- def test_null_in_empty_set_is_false(self, connection):
+ def test_empty_set_against_string_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+ def test_null_in_empty_set_is_false_bindparam(self, connection):
+ stmt = select(
+ case(
+ [
+ (
+ null().in_(bindparam("foo", value=())),
+ true(),
+ )
+ ],
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+ def test_null_in_empty_set_is_false_direct(self, connection):
stmt = select(
case(
[
(
- null().in_(bindparam("foo", value=(), expanding=True)),
+ null().in_([]),
true(),
)
],