summaryrefslogtreecommitdiff
path: root/test/sql/test_cte.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-07-10 11:00:49 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-07-10 11:00:49 -0400
commit0e41673ed4e8551b892c058ffc6a607cf7aba71c (patch)
tree6914091d7cbb58331d147242e6efa83bd1345424 /test/sql/test_cte.py
parentb297b40fca923a03e3c34094e5298d6524944c39 (diff)
downloadsqlalchemy-0e41673ed4e8551b892c058ffc6a607cf7aba71c.tar.gz
- [bug] Fixed more un-intuitivenesses in CTEs
which prevented referring to a CTE in a union of itself without it being aliased. CTEs now render uniquely on name, rendering the outermost CTE of a given name only - all other references are rendered just as the name. This even includes other CTE/SELECTs that refer to different versions of the same CTE object, such as a SELECT or a UNION ALL of that SELECT. We are somewhat loosening the usual link between object identity and lexical identity in this case. A true name conflict between two unrelated CTEs now raises an error.
Diffstat (limited to 'test/sql/test_cte.py')
-rw-r--r--test/sql/test_cte.py179
1 files changed, 159 insertions, 20 deletions
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 36f992a86..59b347ccd 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -1,15 +1,16 @@
from test.lib import fixtures
-from test.lib.testing import AssertsCompiledSQL
+from test.lib.testing import AssertsCompiledSQL, assert_raises_message
from sqlalchemy.sql import table, column, select, func, literal
from sqlalchemy.dialects import mssql
from sqlalchemy.engine import default
+from sqlalchemy.exc import CompileError
class CTETest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = 'default'
def test_nonrecursive(self):
- orders = table('orders',
+ orders = table('orders',
column('region'),
column('amount'),
column('product'),
@@ -17,22 +18,22 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
)
regional_sales = select([
- orders.c.region,
+ orders.c.region,
func.sum(orders.c.amount).label('total_sales')
]).group_by(orders.c.region).cte("regional_sales")
top_regions = select([regional_sales.c.region]).\
where(
- regional_sales.c.total_sales >
+ regional_sales.c.total_sales >
select([
func.sum(regional_sales.c.total_sales)/10
])
).cte("top_regions")
s = select([
- orders.c.region,
- orders.c.product,
- func.sum(orders.c.quantity).label("product_units"),
+ orders.c.region,
+ orders.c.product,
+ func.sum(orders.c.quantity).label("product_units"),
func.sum(orders.c.amount).label("product_sales")
]).where(orders.c.region.in_(
select([top_regions.c.region])
@@ -59,15 +60,15 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
)
def test_recursive(self):
- parts = table('parts',
+ parts = table('parts',
column('part'),
column('sub_part'),
column('quantity'),
)
included_parts = select([
- parts.c.sub_part,
- parts.c.part,
+ parts.c.sub_part,
+ parts.c.part,
parts.c.quantity]).\
where(parts.c.part=='our part').\
cte(recursive=True)
@@ -76,19 +77,19 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
parts_alias = parts.alias()
included_parts = included_parts.union(
select([
- parts_alias.c.part,
- parts_alias.c.sub_part,
+ parts_alias.c.part,
+ parts_alias.c.sub_part,
parts_alias.c.quantity]).\
where(parts_alias.c.part==incl_alias.c.sub_part)
)
s = select([
- included_parts.c.sub_part,
+ included_parts.c.sub_part,
func.sum(included_parts.c.quantity).label('total_quantity')]).\
select_from(included_parts.join(
parts,included_parts.c.part==parts.c.part)).\
group_by(included_parts.c.sub_part)
- self.assert_compile(s,
+ self.assert_compile(s,
"WITH RECURSIVE anon_1(sub_part, part, quantity) "
"AS (SELECT parts.sub_part AS sub_part, parts.part "
"AS part, parts.quantity AS quantity FROM parts "
@@ -104,7 +105,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
# quick check that the "WITH RECURSIVE" varies per
# dialect
- self.assert_compile(s,
+ self.assert_compile(s,
"WITH anon_1(sub_part, part, quantity) "
"AS (SELECT parts.sub_part AS sub_part, parts.part "
"AS part, parts.quantity AS quantity FROM parts "
@@ -119,8 +120,146 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
dialect=mssql.dialect()
)
+ def test_recursive_union_no_alias_one(self):
+ s1 = select([literal(0).label("x")])
+ cte = s1.cte(name="cte", recursive=True)
+ cte = cte.union_all(
+ select([cte.c.x + 1]).where(cte.c.x < 10)
+ )
+ s2 = select([cte])
+ self.assert_compile(s2,
+ "WITH RECURSIVE cte(x) AS "
+ "(SELECT :param_1 AS x UNION ALL "
+ "SELECT cte.x + :x_1 AS anon_1 "
+ "FROM cte WHERE cte.x < :x_2) "
+ "SELECT cte.x FROM cte"
+ )
+
+
+ def test_recursive_union_no_alias_two(self):
+ """
+
+ pg's example:
+
+ WITH RECURSIVE t(n) AS (
+ VALUES (1)
+ UNION ALL
+ SELECT n+1 FROM t WHERE n < 100
+ )
+ SELECT sum(n) FROM t;
+
+ """
+
+ # I know, this is the PG VALUES keyword,
+ # we're cheating here. also yes we need the SELECT,
+ # sorry PG.
+ t = select([func.values(1).label("n")]).cte("t", recursive=True)
+ t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
+ s = select([func.sum(t.c.n)])
+ self.assert_compile(s,
+ "WITH RECURSIVE t(n) AS "
+ "(SELECT values(:values_1) AS n "
+ "UNION ALL SELECT t.n + :n_1 AS anon_1 "
+ "FROM t "
+ "WHERE t.n < :n_2) "
+ "SELECT sum(t.n) AS sum_1 FROM t"
+ )
+
+ def test_recursive_union_no_alias_three(self):
+ # like test one, but let's refer to the CTE
+ # in a sibling CTE.
+
+ s1 = select([literal(0).label("x")])
+ cte = s1.cte(name="cte", recursive=True)
+
+ # can't do it here...
+ #bar = select([cte]).cte('bar')
+ cte = cte.union_all(
+ select([cte.c.x + 1]).where(cte.c.x < 10)
+ )
+ bar = select([cte]).cte('bar')
+
+ s2 = select([cte, bar])
+ self.assert_compile(s2,
+ "WITH RECURSIVE cte(x) AS "
+ "(SELECT :param_1 AS x UNION ALL "
+ "SELECT cte.x + :x_1 AS anon_1 "
+ "FROM cte WHERE cte.x < :x_2), "
+ "bar AS (SELECT cte.x AS x FROM cte) "
+ "SELECT cte.x, bar.x FROM cte, bar"
+ )
+
+
+ def test_recursive_union_no_alias_four(self):
+ # like test one and three, but let's refer
+ # previous version of "cte". here we test
+ # how the compiler resolves multiple instances
+ # of "cte".
+
+ s1 = select([literal(0).label("x")])
+ cte = s1.cte(name="cte", recursive=True)
+
+ bar = select([cte]).cte('bar')
+ cte = cte.union_all(
+ select([cte.c.x + 1]).where(cte.c.x < 10)
+ )
+
+ # outer cte rendered first, then bar, which
+ # includes "inner" cte
+ s2 = select([cte, bar])
+ self.assert_compile(s2,
+ "WITH RECURSIVE cte(x) AS "
+ "(SELECT :param_1 AS x UNION ALL "
+ "SELECT cte.x + :x_1 AS anon_1 "
+ "FROM cte WHERE cte.x < :x_2), "
+ "bar AS (SELECT cte.x AS x FROM cte) "
+ "SELECT cte.x, bar.x FROM cte, bar"
+ )
+
+ # bar rendered, only includes "inner" cte,
+ # "outer" cte isn't present
+ s2 = select([bar])
+ self.assert_compile(s2,
+ "WITH RECURSIVE cte(x) AS "
+ "(SELECT :param_1 AS x), "
+ "bar AS (SELECT cte.x AS x FROM cte) "
+ "SELECT bar.x FROM bar"
+ )
+
+ # bar rendered, but then the "outer"
+ # cte is rendered.
+ s2 = select([bar, cte])
+ self.assert_compile(s2,
+ "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "
+ "cte(x) AS "
+ "(SELECT :param_1 AS x UNION ALL "
+ "SELECT cte.x + :x_1 AS anon_1 "
+ "FROM cte WHERE cte.x < :x_2) "
+
+ "SELECT bar.x, cte.x FROM bar, cte"
+ )
+
+ def test_conflicting_names(self):
+ """test a flat out name conflict."""
+
+ s1 = select([1])
+ c1= s1.cte(name='cte1', recursive=True)
+ s2 = select([1])
+ c2 = s2.cte(name='cte1', recursive=True)
+
+ s = select([c1, c2])
+ assert_raises_message(
+ CompileError,
+ "Multiple, unrelated CTEs found "
+ "with the same name: 'cte1'",
+ s.compile
+ )
+
+
+
+
def test_union(self):
- orders = table('orders',
+ orders = table('orders',
column('region'),
column('amount'),
)
@@ -135,7 +274,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
regional_sales.c.amount > 500
)
- self.assert_compile(s,
+ self.assert_compile(s,
"WITH regional_sales AS "
"(SELECT orders.region AS region, "
"orders.amount AS amount FROM orders) "
@@ -149,7 +288,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
regional_sales.c.amount < 300
)
)
- self.assert_compile(s,
+ self.assert_compile(s,
"WITH regional_sales AS "
"(SELECT orders.region AS region, "
"orders.amount AS amount FROM orders) "
@@ -160,7 +299,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"regional_sales.amount < :amount_2")
def test_reserved_quote(self):
- orders = table('orders',
+ orders = table('orders',
column('order'),
)
s = select([orders.c.order]).cte("regional_sales", recursive=True)
@@ -174,7 +313,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
)
def test_positional_binds(self):
- orders = table('orders',
+ orders = table('orders',
column('order'),
)
s = select([orders.c.order, literal("x")]).cte("regional_sales")