From 3809a5ecfe785cecbc9d91a8e4e4558e3839c694 Mon Sep 17 00:00:00 2001 From: Alessio Bogon Date: Sun, 15 Sep 2019 11:12:24 -0400 Subject: Query linter option Added "from linting" as a built-in feature to the SQL compiler. This allows the compiler to maintain graph of all the FROM clauses in a particular SELECT statement, linked by criteria in either the WHERE or in JOIN clauses that link these FROM clauses together. If any two FROM clauses have no path between them, a warning is emitted that the query may be producing a cartesian product. As the Core expression language as well as the ORM are built on an "implicit FROMs" model where a particular FROM clause is automatically added if any part of the query refers to it, it is easy for this to happen inadvertently and it is hoped that the new feature helps with this issue. The original recipe is from: https://github.com/sqlalchemy/sqlalchemy/wiki/FromLinter The linter is now enabled for all tests in the test suite as well. This has necessitated that a lot of the queries be adjusted to not include cartesian products. Part of the rationale for the linter to not be enabled for statement compilation only was to reduce the need for adjustment for the many test case statements throughout the test suite that are not real-world statements. This gerrit is adapted from Ib5946e57c9dba6da428c4d1dee6760b3e978dda0. Fixes: #4737 Change-Id: Ic91fd9774379f895d021c3ad564db6062299211c Closes: #4830 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4830 Pull-request-sha: f8a21aa6262d1bcc9ff0d11a2616e41fba97a47a --- test/sql/test_defaults.py | 2 +- test/sql/test_from_linter.py | 277 +++++++++++++++++++++++++++++++++++++++++++ test/sql/test_resultset.py | 55 ++++++--- 3 files changed, 315 insertions(+), 19 deletions(-) create mode 100644 test/sql/test_from_linter.py (limited to 'test/sql') diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index f033abab2..b31b070d8 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -852,7 +852,7 @@ class CTEDefaultTest(fixtures.TablesTest): if b == "select": conn.execute(p.insert().values(s=1)) - stmt = select([p.c.s, cte.c.z]) + stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z) elif b == "insert": sel = select([1, cte.c.z]) stmt = ( diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py new file mode 100644 index 000000000..bf2f06b57 --- /dev/null +++ b/test/sql/test_from_linter.py @@ -0,0 +1,277 @@ +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import sql +from sqlalchemy import true +from sqlalchemy.testing import config +from sqlalchemy.testing import engines +from sqlalchemy.testing import expect_warnings +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + + +def find_unmatching_froms(query, start=None): + compiled = query.compile(linting=sql.COLLECT_CARTESIAN_PRODUCTS) + + return compiled.from_linter.lint(start) + + +class TestFindUnmatchingFroms(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + Table("table_c", metadata, Column("col_c", Integer, primary_key=True)) + Table("table_d", metadata, Column("col_d", Integer, primary_key=True)) + + def setup(self): + self.a = self.tables.table_a + self.b = self.tables.table_b + self.c = self.tables.table_c + self.d = self.tables.table_d + + def test_everything_is_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.d.c.col_d == self.b.c.col_b) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_plain_cartesian(self): + query = select([self.a]).where(self.b.c.col_b == 5) + froms, start = find_unmatching_froms(query, self.a) + assert start == self.a + assert froms == {self.b} + + froms, start = find_unmatching_froms(query, self.b) + assert start == self.b + assert froms == {self.a} + + def test_count_non_eq_comparison_operators(self): + query = select([self.a]).where(self.a.c.col_a > self.b.c.col_b) + froms, start = find_unmatching_froms(query, self.a) + is_(start, None) + is_(froms, None) + + def test_dont_count_non_comparison_operators(self): + query = select([self.a]).where(self.a.c.col_a + self.b.c.col_b == 5) + froms, start = find_unmatching_froms(query, self.a) + assert start == self.a + assert froms == {self.b} + + def test_disconnect_between_ab_cd(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + for start in self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.a, self.b} + + def test_c_and_d_both_disconnected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + + froms, start = find_unmatching_froms(query, self.c) + assert start == self.c + assert froms == {self.a, self.b, self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} + + def test_now_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c.join(self.d, self.c.c.col_c == self.d.c.col_d)) + .where(self.c.c.col_c == self.b.c.col_b) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_disconnected_subquery(self): + subq = ( + select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + ) + stmt = select([self.c]).select_from(subq) + + froms, start = find_unmatching_froms(stmt, self.c) + assert start == self.c + assert froms == {subq} + + froms, start = find_unmatching_froms(stmt, subq) + assert start == subq + assert froms == {self.c} + + def test_now_connect_it(self): + subq = ( + select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + ) + stmt = ( + select([self.c]) + .select_from(subq) + .where(self.c.c.col_c == subq.c.col_a) + ) + + froms, start = find_unmatching_froms(stmt) + assert not froms + + for start in self.c, subq: + froms, start = find_unmatching_froms(stmt, start) + assert not froms + + def test_right_nested_join_without_issue(self): + query = select([self.a]).select_from( + self.a.join( + self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), + self.a.c.col_a == self.b.c.col_b, + ) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_join_on_true(self): + # test that a join(a, b) counts a->b as an edge even if there isn't + # actually a join condition. this essentially allows a cartesian + # product to be added explicitly. + + query = select([self.a]).select_from(self.a.join(self.b, true())) + froms, start = find_unmatching_froms(query) + assert not froms + + def test_right_nested_join_with_an_issue(self): + query = ( + select([self.a]) + .select_from( + self.a.join( + self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), + self.a.c.col_a == self.b.c.col_b, + ) + ) + .where(self.d.c.col_d == 5) + ) + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} + + def test_no_froms(self): + query = select([1]) + + froms, start = find_unmatching_froms(query) + assert not froms + + +class TestLinter(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + + @classmethod + def setup_bind(cls): + # from linting is enabled by default + return config.db + + def test_noop_for_unhandled_objects(self): + with self.bind.connect() as conn: + conn.execute("SELECT 1;").fetchone() + + def test_does_not_modify_query(self): + with self.bind.connect() as conn: + [result] = conn.execute(select([1])).fetchone() + assert result == 1 + + def test_warn_simple(self): + a, b = self.tables("table_a", "table_b") + query = select([a.c.col_a]).where(b.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between FROM " + r'element\(s\) "table_[ab]" ' + r'and FROM element "table_[ba]"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_warn_anon_alias(self): + a, b = self.tables("table_a", "table_b") + + b_alias = b.alias() + query = select([a.c.col_a]).where(b_alias.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between FROM " + r'element\(s\) "table_(?:a|b_1)" ' + r'and FROM element "table_(?:a|b_1)"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_warn_anon_cte(self): + a, b = self.tables("table_a", "table_b") + + b_cte = select([b]).cte() + query = select([a.c.col_a]).where(b_cte.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between " + r"FROM element\(s\) " + r'"(?:anon_1|table_a)" ' + r'and FROM element "(?:anon_1|table_a)"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_no_linting(self): + eng = engines.testing_engine(options={"enable_from_linting": False}) + eng.pool = self.bind.pool # needed for SQLite + a, b = self.tables("table_a", "table_b") + query = select([a.c.col_a]).where(b.c.col_b == 5) + + with eng.connect() as conn: + conn.execute(query) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 794508a32..8aa524d78 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -19,6 +19,7 @@ from sqlalchemy import String from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import type_coerce from sqlalchemy import TypeDecorator from sqlalchemy import util @@ -771,7 +772,11 @@ class ResultProxyTest(fixtures.TablesTest): users.insert().execute(user_id=1, user_name="john") ua = users.alias() u2 = users.alias() - result = select([users.c.user_id, ua.c.user_id]).execute() + result = ( + select([users.c.user_id, ua.c.user_id]) + .select_from(users.join(ua, true())) + .execute() + ) row = result.first() # as of 1.1 issue #3501, we use pure positional @@ -1414,7 +1419,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed2 = self.tables.keyed2 - row = testing.db.execute(select([keyed1, keyed2])).first() + row = testing.db.execute( + select([keyed1, keyed2]).select_from(keyed1.join(keyed2, true())) + ).first() # column access is unambiguous eq_(row[self.tables.keyed2.c.b], "b2") @@ -1446,7 +1453,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 row = testing.db.execute( - select([keyed1, keyed2]).apply_labels() + select([keyed1, keyed2]) + .select_from(keyed1.join(keyed2, true())) + .apply_labels() ).first() # column access is unambiguous @@ -1459,7 +1468,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed4 = self.tables.keyed4 - row = testing.db.execute(select([keyed1, keyed4])).first() + row = testing.db.execute( + select([keyed1, keyed4]).select_from(keyed1.join(keyed4, true())) + ).first() eq_(row.b, "b4") eq_(row.q, "q4") eq_(row.a, "a1") @@ -1470,7 +1481,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed3 = self.tables.keyed3 - row = testing.db.execute(select([keyed1, keyed3])).first() + row = testing.db.execute( + select([keyed1, keyed3]).select_from(keyed1.join(keyed3, true())) + ).first() eq_(row.q, "c1") # prior to 1.4 #4887, this raised an "ambiguous column name 'a'"" @@ -1493,7 +1506,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 row = testing.db.execute( - select([keyed1, keyed2]).apply_labels() + select([keyed1, keyed2]) + .select_from(keyed1.join(keyed2, true())) + .apply_labels() ).first() eq_(row.keyed1_b, "a1") eq_(row.keyed1_a, "a1") @@ -1515,18 +1530,22 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 keyed3 = self.tables.keyed3 - stmt = select( - [ - keyed2.c.a, - keyed3.c.a, - keyed2.c.a, - keyed2.c.a, - keyed3.c.a, - keyed3.c.a, - keyed3.c.d, - keyed3.c.d, - ] - ).apply_labels() + stmt = ( + select( + [ + keyed2.c.a, + keyed3.c.a, + keyed2.c.a, + keyed2.c.a, + keyed3.c.a, + keyed3.c.a, + keyed3.c.d, + keyed3.c.d, + ] + ) + .select_from(keyed2.join(keyed3, true())) + .apply_labels() + ) result = testing.db.execute(stmt) is_false(result._metadata.matched_on_name) -- cgit v1.2.1