summaryrefslogtreecommitdiff
path: root/test/sql
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-01-22 22:42:42 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-01-22 22:42:42 +0000
commit5675345fc11966d449130ff4e4327a5a4bece0c2 (patch)
treecb896a06ca43eb4eb4e0594355bb7831261a4e46 /test/sql
parent4bf97e41b2d18b86fc7c0bba6acd50e2b58a4a70 (diff)
parent3809a5ecfe785cecbc9d91a8e4e4558e3839c694 (diff)
downloadsqlalchemy-5675345fc11966d449130ff4e4327a5a4bece0c2.tar.gz
Merge "Query linter option"
Diffstat (limited to 'test/sql')
-rw-r--r--test/sql/test_defaults.py2
-rw-r--r--test/sql/test_from_linter.py277
-rw-r--r--test/sql/test_resultset.py55
3 files changed, 315 insertions, 19 deletions
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index f033abab2..b31b070d8 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -852,7 +852,7 @@ class CTEDefaultTest(fixtures.TablesTest):
if b == "select":
conn.execute(p.insert().values(s=1))
- stmt = select([p.c.s, cte.c.z])
+ stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z)
elif b == "insert":
sel = select([1, cte.c.z])
stmt = (
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
new file mode 100644
index 000000000..bf2f06b57
--- /dev/null
+++ b/test/sql/test_from_linter.py
@@ -0,0 +1,277 @@
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import sql
+from sqlalchemy import true
+from sqlalchemy.testing import config
+from sqlalchemy.testing import engines
+from sqlalchemy.testing import expect_warnings
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+
+
+def find_unmatching_froms(query, start=None):
+ compiled = query.compile(linting=sql.COLLECT_CARTESIAN_PRODUCTS)
+
+ return compiled.from_linter.lint(start)
+
+
+class TestFindUnmatchingFroms(fixtures.TablesTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
+ Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+ Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
+ Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
+
+ def setup(self):
+ self.a = self.tables.table_a
+ self.b = self.tables.table_b
+ self.c = self.tables.table_c
+ self.d = self.tables.table_d
+
+ def test_everything_is_connected(self):
+ query = (
+ select([self.a])
+ .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+ .select_from(self.c)
+ .select_from(self.d)
+ .where(self.d.c.col_d == self.b.c.col_b)
+ .where(self.c.c.col_c == self.d.c.col_d)
+ .where(self.c.c.col_c == 5)
+ )
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+ for start in self.a, self.b, self.c, self.d:
+ froms, start = find_unmatching_froms(query, start)
+ assert not froms
+
+ def test_plain_cartesian(self):
+ query = select([self.a]).where(self.b.c.col_b == 5)
+ froms, start = find_unmatching_froms(query, self.a)
+ assert start == self.a
+ assert froms == {self.b}
+
+ froms, start = find_unmatching_froms(query, self.b)
+ assert start == self.b
+ assert froms == {self.a}
+
+ def test_count_non_eq_comparison_operators(self):
+ query = select([self.a]).where(self.a.c.col_a > self.b.c.col_b)
+ froms, start = find_unmatching_froms(query, self.a)
+ is_(start, None)
+ is_(froms, None)
+
+ def test_dont_count_non_comparison_operators(self):
+ query = select([self.a]).where(self.a.c.col_a + self.b.c.col_b == 5)
+ froms, start = find_unmatching_froms(query, self.a)
+ assert start == self.a
+ assert froms == {self.b}
+
+ def test_disconnect_between_ab_cd(self):
+ query = (
+ select([self.a])
+ .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+ .select_from(self.c)
+ .select_from(self.d)
+ .where(self.c.c.col_c == self.d.c.col_d)
+ .where(self.c.c.col_c == 5)
+ )
+ for start in self.a, self.b:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.c, self.d}
+ for start in self.c, self.d:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.a, self.b}
+
+ def test_c_and_d_both_disconnected(self):
+ query = (
+ select([self.a])
+ .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+ .where(self.c.c.col_c == 5)
+ .where(self.d.c.col_d == 10)
+ )
+ for start in self.a, self.b:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.c, self.d}
+
+ froms, start = find_unmatching_froms(query, self.c)
+ assert start == self.c
+ assert froms == {self.a, self.b, self.d}
+
+ froms, start = find_unmatching_froms(query, self.d)
+ assert start == self.d
+ assert froms == {self.a, self.b, self.c}
+
+ def test_now_connected(self):
+ query = (
+ select([self.a])
+ .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+ .select_from(self.c.join(self.d, self.c.c.col_c == self.d.c.col_d))
+ .where(self.c.c.col_c == self.b.c.col_b)
+ .where(self.c.c.col_c == 5)
+ .where(self.d.c.col_d == 10)
+ )
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+ for start in self.a, self.b, self.c, self.d:
+ froms, start = find_unmatching_froms(query, start)
+ assert not froms
+
+ def test_disconnected_subquery(self):
+ subq = (
+ select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
+ )
+ stmt = select([self.c]).select_from(subq)
+
+ froms, start = find_unmatching_froms(stmt, self.c)
+ assert start == self.c
+ assert froms == {subq}
+
+ froms, start = find_unmatching_froms(stmt, subq)
+ assert start == subq
+ assert froms == {self.c}
+
+ def test_now_connect_it(self):
+ subq = (
+ select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
+ )
+ stmt = (
+ select([self.c])
+ .select_from(subq)
+ .where(self.c.c.col_c == subq.c.col_a)
+ )
+
+ froms, start = find_unmatching_froms(stmt)
+ assert not froms
+
+ for start in self.c, subq:
+ froms, start = find_unmatching_froms(stmt, start)
+ assert not froms
+
+ def test_right_nested_join_without_issue(self):
+ query = select([self.a]).select_from(
+ self.a.join(
+ self.b.join(self.c, self.b.c.col_b == self.c.c.col_c),
+ self.a.c.col_a == self.b.c.col_b,
+ )
+ )
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+ for start in self.a, self.b, self.c:
+ froms, start = find_unmatching_froms(query, start)
+ assert not froms
+
+ def test_join_on_true(self):
+ # test that a join(a, b) counts a->b as an edge even if there isn't
+ # actually a join condition. this essentially allows a cartesian
+ # product to be added explicitly.
+
+ query = select([self.a]).select_from(self.a.join(self.b, true()))
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+ def test_right_nested_join_with_an_issue(self):
+ query = (
+ select([self.a])
+ .select_from(
+ self.a.join(
+ self.b.join(self.c, self.b.c.col_b == self.c.c.col_c),
+ self.a.c.col_a == self.b.c.col_b,
+ )
+ )
+ .where(self.d.c.col_d == 5)
+ )
+
+ for start in self.a, self.b, self.c:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.d}
+
+ froms, start = find_unmatching_froms(query, self.d)
+ assert start == self.d
+ assert froms == {self.a, self.b, self.c}
+
+ def test_no_froms(self):
+ query = select([1])
+
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+
+class TestLinter(fixtures.TablesTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
+ Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+
+ @classmethod
+ def setup_bind(cls):
+ # from linting is enabled by default
+ return config.db
+
+ def test_noop_for_unhandled_objects(self):
+ with self.bind.connect() as conn:
+ conn.execute("SELECT 1;").fetchone()
+
+ def test_does_not_modify_query(self):
+ with self.bind.connect() as conn:
+ [result] = conn.execute(select([1])).fetchone()
+ assert result == 1
+
+ def test_warn_simple(self):
+ a, b = self.tables("table_a", "table_b")
+ query = select([a.c.col_a]).where(b.c.col_b == 5)
+
+ with expect_warnings(
+ r"SELECT statement has a cartesian product between FROM "
+ r'element\(s\) "table_[ab]" '
+ r'and FROM element "table_[ba]"'
+ ):
+ with self.bind.connect() as conn:
+ conn.execute(query)
+
+ def test_warn_anon_alias(self):
+ a, b = self.tables("table_a", "table_b")
+
+ b_alias = b.alias()
+ query = select([a.c.col_a]).where(b_alias.c.col_b == 5)
+
+ with expect_warnings(
+ r"SELECT statement has a cartesian product between FROM "
+ r'element\(s\) "table_(?:a|b_1)" '
+ r'and FROM element "table_(?:a|b_1)"'
+ ):
+ with self.bind.connect() as conn:
+ conn.execute(query)
+
+ def test_warn_anon_cte(self):
+ a, b = self.tables("table_a", "table_b")
+
+ b_cte = select([b]).cte()
+ query = select([a.c.col_a]).where(b_cte.c.col_b == 5)
+
+ with expect_warnings(
+ r"SELECT statement has a cartesian product between "
+ r"FROM element\(s\) "
+ r'"(?:anon_1|table_a)" '
+ r'and FROM element "(?:anon_1|table_a)"'
+ ):
+ with self.bind.connect() as conn:
+ conn.execute(query)
+
+ def test_no_linting(self):
+ eng = engines.testing_engine(options={"enable_from_linting": False})
+ eng.pool = self.bind.pool # needed for SQLite
+ a, b = self.tables("table_a", "table_b")
+ query = select([a.c.col_a]).where(b.c.col_b == 5)
+
+ with eng.connect() as conn:
+ conn.execute(query)
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 794508a32..8aa524d78 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -19,6 +19,7 @@ from sqlalchemy import String
from sqlalchemy import table
from sqlalchemy import testing
from sqlalchemy import text
+from sqlalchemy import true
from sqlalchemy import type_coerce
from sqlalchemy import TypeDecorator
from sqlalchemy import util
@@ -771,7 +772,11 @@ class ResultProxyTest(fixtures.TablesTest):
users.insert().execute(user_id=1, user_name="john")
ua = users.alias()
u2 = users.alias()
- result = select([users.c.user_id, ua.c.user_id]).execute()
+ result = (
+ select([users.c.user_id, ua.c.user_id])
+ .select_from(users.join(ua, true()))
+ .execute()
+ )
row = result.first()
# as of 1.1 issue #3501, we use pure positional
@@ -1414,7 +1419,9 @@ class KeyTargetingTest(fixtures.TablesTest):
keyed1 = self.tables.keyed1
keyed2 = self.tables.keyed2
- row = testing.db.execute(select([keyed1, keyed2])).first()
+ row = testing.db.execute(
+ select([keyed1, keyed2]).select_from(keyed1.join(keyed2, true()))
+ ).first()
# column access is unambiguous
eq_(row[self.tables.keyed2.c.b], "b2")
@@ -1446,7 +1453,9 @@ class KeyTargetingTest(fixtures.TablesTest):
keyed2 = self.tables.keyed2
row = testing.db.execute(
- select([keyed1, keyed2]).apply_labels()
+ select([keyed1, keyed2])
+ .select_from(keyed1.join(keyed2, true()))
+ .apply_labels()
).first()
# column access is unambiguous
@@ -1459,7 +1468,9 @@ class KeyTargetingTest(fixtures.TablesTest):
keyed1 = self.tables.keyed1
keyed4 = self.tables.keyed4
- row = testing.db.execute(select([keyed1, keyed4])).first()
+ row = testing.db.execute(
+ select([keyed1, keyed4]).select_from(keyed1.join(keyed4, true()))
+ ).first()
eq_(row.b, "b4")
eq_(row.q, "q4")
eq_(row.a, "a1")
@@ -1470,7 +1481,9 @@ class KeyTargetingTest(fixtures.TablesTest):
keyed1 = self.tables.keyed1
keyed3 = self.tables.keyed3
- row = testing.db.execute(select([keyed1, keyed3])).first()
+ row = testing.db.execute(
+ select([keyed1, keyed3]).select_from(keyed1.join(keyed3, true()))
+ ).first()
eq_(row.q, "c1")
# prior to 1.4 #4887, this raised an "ambiguous column name 'a'""
@@ -1493,7 +1506,9 @@ class KeyTargetingTest(fixtures.TablesTest):
keyed2 = self.tables.keyed2
row = testing.db.execute(
- select([keyed1, keyed2]).apply_labels()
+ select([keyed1, keyed2])
+ .select_from(keyed1.join(keyed2, true()))
+ .apply_labels()
).first()
eq_(row.keyed1_b, "a1")
eq_(row.keyed1_a, "a1")
@@ -1515,18 +1530,22 @@ class KeyTargetingTest(fixtures.TablesTest):
keyed2 = self.tables.keyed2
keyed3 = self.tables.keyed3
- stmt = select(
- [
- keyed2.c.a,
- keyed3.c.a,
- keyed2.c.a,
- keyed2.c.a,
- keyed3.c.a,
- keyed3.c.a,
- keyed3.c.d,
- keyed3.c.d,
- ]
- ).apply_labels()
+ stmt = (
+ select(
+ [
+ keyed2.c.a,
+ keyed3.c.a,
+ keyed2.c.a,
+ keyed2.c.a,
+ keyed3.c.a,
+ keyed3.c.a,
+ keyed3.c.d,
+ keyed3.c.d,
+ ]
+ )
+ .select_from(keyed2.join(keyed3, true()))
+ .apply_labels()
+ )
result = testing.db.execute(stmt)
is_false(result._metadata.matched_on_name)