diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2020-01-22 22:42:42 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-01-22 22:42:42 +0000 |
| commit | 5675345fc11966d449130ff4e4327a5a4bece0c2 (patch) | |
| tree | cb896a06ca43eb4eb4e0594355bb7831261a4e46 /test/sql | |
| parent | 4bf97e41b2d18b86fc7c0bba6acd50e2b58a4a70 (diff) | |
| parent | 3809a5ecfe785cecbc9d91a8e4e4558e3839c694 (diff) | |
| download | sqlalchemy-5675345fc11966d449130ff4e4327a5a4bece0c2.tar.gz | |
Merge "Query linter option"
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_defaults.py | 2 | ||||
| -rw-r--r-- | test/sql/test_from_linter.py | 277 | ||||
| -rw-r--r-- | test/sql/test_resultset.py | 55 |
3 files changed, 315 insertions, 19 deletions
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index f033abab2..b31b070d8 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -852,7 +852,7 @@ class CTEDefaultTest(fixtures.TablesTest): if b == "select": conn.execute(p.insert().values(s=1)) - stmt = select([p.c.s, cte.c.z]) + stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z) elif b == "insert": sel = select([1, cte.c.z]) stmt = ( diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py new file mode 100644 index 000000000..bf2f06b57 --- /dev/null +++ b/test/sql/test_from_linter.py @@ -0,0 +1,277 @@ +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import sql +from sqlalchemy import true +from sqlalchemy.testing import config +from sqlalchemy.testing import engines +from sqlalchemy.testing import expect_warnings +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + + +def find_unmatching_froms(query, start=None): + compiled = query.compile(linting=sql.COLLECT_CARTESIAN_PRODUCTS) + + return compiled.from_linter.lint(start) + + +class TestFindUnmatchingFroms(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + Table("table_c", metadata, Column("col_c", Integer, primary_key=True)) + Table("table_d", metadata, Column("col_d", Integer, primary_key=True)) + + def setup(self): + self.a = self.tables.table_a + self.b = self.tables.table_b + self.c = self.tables.table_c + self.d = self.tables.table_d + + def test_everything_is_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.d.c.col_d == self.b.c.col_b) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_plain_cartesian(self): + query = select([self.a]).where(self.b.c.col_b == 5) + froms, start = find_unmatching_froms(query, self.a) + assert start == self.a + assert froms == {self.b} + + froms, start = find_unmatching_froms(query, self.b) + assert start == self.b + assert froms == {self.a} + + def test_count_non_eq_comparison_operators(self): + query = select([self.a]).where(self.a.c.col_a > self.b.c.col_b) + froms, start = find_unmatching_froms(query, self.a) + is_(start, None) + is_(froms, None) + + def test_dont_count_non_comparison_operators(self): + query = select([self.a]).where(self.a.c.col_a + self.b.c.col_b == 5) + froms, start = find_unmatching_froms(query, self.a) + assert start == self.a + assert froms == {self.b} + + def test_disconnect_between_ab_cd(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + for start in self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.a, self.b} + + def test_c_and_d_both_disconnected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + + froms, start = find_unmatching_froms(query, self.c) + assert start == self.c + assert froms == {self.a, self.b, self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} + + def test_now_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c.join(self.d, self.c.c.col_c == self.d.c.col_d)) + .where(self.c.c.col_c == self.b.c.col_b) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_disconnected_subquery(self): + subq = ( + select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + ) + stmt = select([self.c]).select_from(subq) + + froms, start = find_unmatching_froms(stmt, self.c) + assert start == self.c + assert froms == {subq} + + froms, start = find_unmatching_froms(stmt, subq) + assert start == subq + assert froms == {self.c} + + def test_now_connect_it(self): + subq = ( + select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + ) + stmt = ( + select([self.c]) + .select_from(subq) + .where(self.c.c.col_c == subq.c.col_a) + ) + + froms, start = find_unmatching_froms(stmt) + assert not froms + + for start in self.c, subq: + froms, start = find_unmatching_froms(stmt, start) + assert not froms + + def test_right_nested_join_without_issue(self): + query = select([self.a]).select_from( + self.a.join( + self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), + self.a.c.col_a == self.b.c.col_b, + ) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_join_on_true(self): + # test that a join(a, b) counts a->b as an edge even if there isn't + # actually a join condition. this essentially allows a cartesian + # product to be added explicitly. + + query = select([self.a]).select_from(self.a.join(self.b, true())) + froms, start = find_unmatching_froms(query) + assert not froms + + def test_right_nested_join_with_an_issue(self): + query = ( + select([self.a]) + .select_from( + self.a.join( + self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), + self.a.c.col_a == self.b.c.col_b, + ) + ) + .where(self.d.c.col_d == 5) + ) + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} + + def test_no_froms(self): + query = select([1]) + + froms, start = find_unmatching_froms(query) + assert not froms + + +class TestLinter(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + + @classmethod + def setup_bind(cls): + # from linting is enabled by default + return config.db + + def test_noop_for_unhandled_objects(self): + with self.bind.connect() as conn: + conn.execute("SELECT 1;").fetchone() + + def test_does_not_modify_query(self): + with self.bind.connect() as conn: + [result] = conn.execute(select([1])).fetchone() + assert result == 1 + + def test_warn_simple(self): + a, b = self.tables("table_a", "table_b") + query = select([a.c.col_a]).where(b.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between FROM " + r'element\(s\) "table_[ab]" ' + r'and FROM element "table_[ba]"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_warn_anon_alias(self): + a, b = self.tables("table_a", "table_b") + + b_alias = b.alias() + query = select([a.c.col_a]).where(b_alias.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between FROM " + r'element\(s\) "table_(?:a|b_1)" ' + r'and FROM element "table_(?:a|b_1)"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_warn_anon_cte(self): + a, b = self.tables("table_a", "table_b") + + b_cte = select([b]).cte() + query = select([a.c.col_a]).where(b_cte.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between " + r"FROM element\(s\) " + r'"(?:anon_1|table_a)" ' + r'and FROM element "(?:anon_1|table_a)"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_no_linting(self): + eng = engines.testing_engine(options={"enable_from_linting": False}) + eng.pool = self.bind.pool # needed for SQLite + a, b = self.tables("table_a", "table_b") + query = select([a.c.col_a]).where(b.c.col_b == 5) + + with eng.connect() as conn: + conn.execute(query) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 794508a32..8aa524d78 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -19,6 +19,7 @@ from sqlalchemy import String from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import type_coerce from sqlalchemy import TypeDecorator from sqlalchemy import util @@ -771,7 +772,11 @@ class ResultProxyTest(fixtures.TablesTest): users.insert().execute(user_id=1, user_name="john") ua = users.alias() u2 = users.alias() - result = select([users.c.user_id, ua.c.user_id]).execute() + result = ( + select([users.c.user_id, ua.c.user_id]) + .select_from(users.join(ua, true())) + .execute() + ) row = result.first() # as of 1.1 issue #3501, we use pure positional @@ -1414,7 +1419,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed2 = self.tables.keyed2 - row = testing.db.execute(select([keyed1, keyed2])).first() + row = testing.db.execute( + select([keyed1, keyed2]).select_from(keyed1.join(keyed2, true())) + ).first() # column access is unambiguous eq_(row[self.tables.keyed2.c.b], "b2") @@ -1446,7 +1453,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 row = testing.db.execute( - select([keyed1, keyed2]).apply_labels() + select([keyed1, keyed2]) + .select_from(keyed1.join(keyed2, true())) + .apply_labels() ).first() # column access is unambiguous @@ -1459,7 +1468,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed4 = self.tables.keyed4 - row = testing.db.execute(select([keyed1, keyed4])).first() + row = testing.db.execute( + select([keyed1, keyed4]).select_from(keyed1.join(keyed4, true())) + ).first() eq_(row.b, "b4") eq_(row.q, "q4") eq_(row.a, "a1") @@ -1470,7 +1481,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed3 = self.tables.keyed3 - row = testing.db.execute(select([keyed1, keyed3])).first() + row = testing.db.execute( + select([keyed1, keyed3]).select_from(keyed1.join(keyed3, true())) + ).first() eq_(row.q, "c1") # prior to 1.4 #4887, this raised an "ambiguous column name 'a'"" @@ -1493,7 +1506,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 row = testing.db.execute( - select([keyed1, keyed2]).apply_labels() + select([keyed1, keyed2]) + .select_from(keyed1.join(keyed2, true())) + .apply_labels() ).first() eq_(row.keyed1_b, "a1") eq_(row.keyed1_a, "a1") @@ -1515,18 +1530,22 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 keyed3 = self.tables.keyed3 - stmt = select( - [ - keyed2.c.a, - keyed3.c.a, - keyed2.c.a, - keyed2.c.a, - keyed3.c.a, - keyed3.c.a, - keyed3.c.d, - keyed3.c.d, - ] - ).apply_labels() + stmt = ( + select( + [ + keyed2.c.a, + keyed3.c.a, + keyed2.c.a, + keyed2.c.a, + keyed3.c.a, + keyed3.c.a, + keyed3.c.d, + keyed3.c.d, + ] + ) + .select_from(keyed2.join(keyed3, true())) + .apply_labels() + ) result = testing.db.execute(stmt) is_false(result._metadata.matched_on_name) |
