From aba308868544b21bafa0b3435701ddc908654b0a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 28 Apr 2021 18:31:51 -0400 Subject: Use non-subquery form for empty IN Revised the "EMPTY IN" expression to no longer rely upon using a subquery, as this was causing some compatibility and performance problems. The new approach for selected databases takes advantage of using a NULL-returning IN expression combined with the usual "1 != 1" or "1 = 1" expression appended by AND or OR. The expression is now the default for all backends other than SQLite, which still had some compatibility issues regarding tuple "IN" for older SQLite versions. Third party dialects can still override how the "empty set" expression renders by implementing a new compiler method ``def visit_empty_set_op_expr(self, type_, expand_op)``, which takes precedence over the existing ``def visit_empty_set_expr(self, element_types)`` which remains in place. Fixes: #6258 Fixes: #6397 Change-Id: I2df09eb00d2ad3b57039ae48128fdf94641b5e59 --- lib/sqlalchemy/dialects/sqlite/base.py | 5 + lib/sqlalchemy/sql/coercions.py | 11 +- lib/sqlalchemy/sql/compiler.py | 50 +++++++- lib/sqlalchemy/sql/elements.py | 10 ++ lib/sqlalchemy/testing/suite/test_select.py | 187 ++++++++++++++++++++-------- 5 files changed, 199 insertions(+), 64 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 59d40fef0..66a556ae0 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1299,6 +1299,11 @@ class SQLiteCompiler(compiler.SQLCompiler): self.process(binary.right, **kw), ) + def visit_empty_set_op_expr(self, type_, expand_op): + # slightly old SQLite versions don't seem to be able to handle + # the empty set impl + return self.visit_empty_set_expr(type_) + def visit_empty_set_expr(self, element_types): return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % ( ", ".join("1" for type_ in element_types or [INTEGER()]), diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index b7aba9d74..820fc1bf1 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -561,14 +561,9 @@ class InElementImpl(RoleImpl): return element.self_group(against=operator) elif isinstance(element, elements.BindParameter): - if not element.expanding: - # coercing to expanding at the moment to work with the - # lambda system. not sure if this is the right approach. - # is there a valid use case to send a single non-expanding - # param to IN? check for ARRAY type? - element = element._clone(maintain_key=True) - element.expanding = True - + # previously we were adding expanding flags here but + # we now do this in the compiler where we have more context + # see compiler.py -> _render_in_expr_w_bindparam return element else: return element diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6168248ff..e9e05b7e9 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1903,6 +1903,45 @@ class SQLCompiler(Compiled): binary, override_operator=operators.match_op ) + def visit_in_op_binary(self, binary, operator, **kw): + return self._render_in_expr_w_bindparam(binary, operator, **kw) + + def visit_not_in_op_binary(self, binary, operator, **kw): + return self._render_in_expr_w_bindparam(binary, operator, **kw) + + def _render_in_expr_w_bindparam(self, binary, operator, **kw): + opstring = OPERATORS[operator] + + if isinstance(binary.right, elements.BindParameter): + if not binary.right.expanding or not binary.right.expand_op: + # note that by cloning here, we rely upon the + # _cache_key_bind_match dictionary to resolve + # clones of bindparam() objects to the ones that are + # present in our cache key. + binary.right = binary.right._clone(maintain_key=True) + binary.right.expanding = True + binary.right.expand_op = operator + + return self._generate_generic_binary(binary, opstring, **kw) + + def visit_empty_set_op_expr(self, type_, expand_op): + if expand_op is operators.not_in_op: + if len(type_) > 1: + return "(%s)) OR (1 = 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) OR (1 = 1" + elif expand_op is operators.in_op: + if len(type_) > 1: + return "(%s)) AND (1 != 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) AND (1 != 1" + else: + return self.visit_empty_set_expr(type_) + def visit_empty_set_expr(self, element_types): raise NotImplementedError( "Dialect '%s' does not support empty set expression." @@ -1959,12 +1998,12 @@ class SQLCompiler(Compiled): to_update = [] if parameter.type._is_tuple_type: - replacement_expression = self.visit_empty_set_expr( - parameter.type.types + replacement_expression = self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op ) else: - replacement_expression = self.visit_empty_set_expr( - [parameter.type] + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op ) elif isinstance(values[0], (tuple, list)): @@ -3900,6 +3939,9 @@ class StrSQLCompiler(SQLCompiler): for t in extra_froms ) + def visit_empty_set_op_expr(self, type_, expand_op): + return self.visit_empty_set_expr(type_) + def visit_empty_set_expr(self, type_): return "SELECT 1 WHERE 1!=1" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 696f3b249..e27b97802 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1411,7 +1411,17 @@ class BindParameter(roles.InElementRole, ColumnElement): self.callable = callable_ self.isoutparam = isoutparam self.required = required + + # indicate an "expanding" parameter; the compiler sets this + # automatically in the compiler _render_in_expr_w_bindparam method + # for an IN expression self.expanding = expanding + + # this is another hint to help w/ expanding and is typically + # set in the compiler _render_in_expr_w_bindparam method for an + # IN expression + self.expand_op = None + self.literal_execute = literal_execute if _is_crud: self._is_crud = True diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 7b35dc3fa..1614acd3d 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -1016,163 +1016,246 @@ class ExpandingBoundInTest(fixtures.TablesTest): with config.db.connect() as conn: eq_(conn.execute(select, params).fetchall(), result) - def test_multiple_empty_sets(self): + def test_multiple_empty_sets_bindparam(self): # test that any anonymous aliasing used by the dialect # is fine with duplicates table = self.tables.some_table - stmt = ( select(table.c.id) - .where(table.c.x.in_(bindparam("q", expanding=True))) - .where(table.c.y.in_(bindparam("p", expanding=True))) + .where(table.c.x.in_(bindparam("q"))) + .where(table.c.y.in_(bindparam("p"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": [], "p": []}) - @testing.requires.tuple_in_w_empty - def test_empty_heterogeneous_tuples(self): + def test_multiple_empty_sets_direct(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates table = self.tables.some_table - stmt = ( select(table.c.id) - .where( - tuple_(table.c.x, table.c.z).in_( - bindparam("q", expanding=True) - ) - ) + .where(table.c.x.in_([])) + .where(table.c.y.in_([])) .order_by(table.c.id) ) + self._assert_result(stmt, []) + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) self._assert_result(stmt, [], params={"q": []}) @testing.requires.tuple_in_w_empty - def test_empty_homogeneous_tuples(self): + def test_empty_heterogeneous_tuples_direct(self): table = self.tables.some_table + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)]) + go([], []) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where( - tuple_(table.c.x, table.c.y).in_( - bindparam("q", expanding=True) - ) - ) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": []}) - def test_bound_in_scalar(self): + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_direct(self): table = self.tables.some_table + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)]) + go([], []) + + def test_bound_in_scalar_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.x.in_(bindparam("q", expanding=True))) + .where(table.c.x.in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) - @testing.requires.tuple_in - def test_bound_in_two_tuple(self): + def test_bound_in_scalar_direct(self): table = self.tables.some_table - stmt = ( select(table.c.id) - .where( - tuple_(table.c.x, table.c.y).in_( - bindparam("q", expanding=True) - ) - ) + .where(table.c.x.in_([2, 3, 4])) .order_by(table.c.id) ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + @testing.requires.tuple_in + def test_bound_in_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) self._assert_result( stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} ) @testing.requires.tuple_in - def test_bound_in_heterogeneous_two_tuple(self): + def test_bound_in_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_bindparam(self): table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_direct(self): + table = self.tables.some_table stmt = ( select(table.c.id) .where( tuple_(table.c.x, table.c.z).in_( - bindparam("q", expanding=True) + [(2, "z2"), (3, "z3"), (4, "z4")] ) ) .order_by(table.c.id) ) - self._assert_result( stmt, [(2,), (3,), (4,)], - params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) @testing.requires.tuple_in - def test_bound_in_heterogeneous_two_tuple_text(self): + def test_bound_in_heterogeneous_two_tuple_text_bindparam(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now stmt = text( "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" ).bindparams(bindparam("q", expanding=True)) - self._assert_result( stmt, [(2,), (3,), (4,)], params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) - def test_empty_set_against_integer(self): + def test_empty_set_against_integer_bindparam(self): table = self.tables.some_table - stmt = ( select(table.c.id) - .where(table.c.x.in_(bindparam("q", expanding=True))) + .where(table.c.x.in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": []}) - def test_empty_set_against_integer_negation(self): + def test_empty_set_against_integer_direct(self): table = self.tables.some_table + stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + def test_empty_set_against_integer_negation_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.x.not_in(bindparam("q", expanding=True))) + .where(table.c.x.not_in(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) - def test_empty_set_against_string(self): + def test_empty_set_against_integer_negation_direct(self): table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + def test_empty_set_against_string_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.z.in_(bindparam("q", expanding=True))) + .where(table.c.z.in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": []}) - def test_empty_set_against_string_negation(self): + def test_empty_set_against_string_direct(self): table = self.tables.some_table + stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + def test_empty_set_against_string_negation_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.z.not_in(bindparam("q", expanding=True))) + .where(table.c.z.not_in(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) - def test_null_in_empty_set_is_false(self, connection): + def test_empty_set_against_string_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_null_in_empty_set_is_false_bindparam(self, connection): + stmt = select( + case( + [ + ( + null().in_(bindparam("foo", value=())), + true(), + ) + ], + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + def test_null_in_empty_set_is_false_direct(self, connection): stmt = select( case( [ ( - null().in_(bindparam("foo", value=(), expanding=True)), + null().in_([]), true(), ) ], -- cgit v1.2.1