diff options
| -rw-r--r-- | doc/build/changelog/changelog_11.rst | 13 | ||||
| -rw-r--r-- | doc/build/changelog/migration_11.rst | 59 | ||||
| -rw-r--r-- | doc/build/core/selectable.rst | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 94 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 47 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 288 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 2 | ||||
| -rw-r--r-- | test/sql/test_cte.py | 152 | ||||
| -rw-r--r-- | test/sql/test_insert.py | 5 |
12 files changed, 503 insertions, 177 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index 2473a02a2..273bffb83 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -22,6 +22,19 @@ :version: 1.1.0b1 .. change:: + :tags: feature, sql + :tickets: 2551 + + CTE functionality has been expanded to support all DML, allowing + INSERT, UPDATE, and DELETE statements to both specify their own + WITH clause, as well as for these statements themselves to be + CTE expressions when they include a RETURNING clause. + + .. seealso:: + + :ref:`change_2551` + + .. change:: :tags: bug, orm :tickets: 3641 diff --git a/doc/build/changelog/migration_11.rst b/doc/build/changelog/migration_11.rst index 3d65ede80..7eb8e800f 100644 --- a/doc/build/changelog/migration_11.rst +++ b/doc/build/changelog/migration_11.rst @@ -529,6 +529,65 @@ remains unchanged. New Features and Improvements - Core ==================================== +.. _change_2551: + +CTE Support for INSERT, UPDATE, DELETE +-------------------------------------- + +One of the most widely requested features is support for common table +expressions (CTE) that work with INSERT, UPDATE, DELETE, and is now implemented. +An INSERT/UPDATE/DELETE can both draw from a WITH clause that's stated at the +top of the SQL, as well as can be used as a CTE itself in the context of +a larger statement. + +As part of this change, an INSERT from SELECT that includes a CTE will now +render the CTE at the top of the entire statement, rather than nested +in the SELECT statement as was the case in 1.0. + +Below is an example that renders UPDATE, INSERT and SELECT all in one +statement:: + + >>> from sqlalchemy import table, column, select, literal, exists + >>> orders = table( + ... 'orders', + ... column('region'), + ... column('amount'), + ... column('product'), + ... column('quantity') + ... ) + >>> + >>> upsert = ( + ... orders.update() + ... .where(orders.c.region == 'Region1') + ... .values(amount=1.0, product='Product1', quantity=1) + ... .returning(*(orders.c._all_columns)).cte('upsert')) + >>> + >>> insert = orders.insert().from_select( + ... orders.c.keys(), + ... select([ + ... literal('Region1'), literal(1.0), + ... literal('Product1'), literal(1) + ... ]).where(~exists(upsert.select())) + ... ) + >>> + >>> print(insert) # note formatting added for clarity + WITH upsert AS + (UPDATE orders SET amount=:amount, product=:product, quantity=:quantity + WHERE orders.region = :region_1 + RETURNING orders.region, orders.amount, orders.product, orders.quantity + ) + INSERT INTO orders (region, amount, product, quantity) + SELECT + :param_1 AS anon_1, :param_2 AS anon_2, + :param_3 AS anon_3, :param_4 AS anon_4 + WHERE NOT ( + EXISTS ( + SELECT upsert.region, upsert.amount, + upsert.product, upsert.quantity + FROM upsert)) + +:ticket:`2551` + .. _change_3216: The ``.autoincrement`` directive is no longer implicitly enabled for a composite primary key column diff --git a/doc/build/core/selectable.rst b/doc/build/core/selectable.rst index e73ce7b64..a582ab4dc 100644 --- a/doc/build/core/selectable.rst +++ b/doc/build/core/selectable.rst @@ -57,6 +57,9 @@ elements are themselves :class:`.ColumnElement` subclasses). :members: :inherited-members: +.. autoclass:: HasCTE + :members: + .. autoclass:: HasPrefixes :members: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8a25f570a..ad7b9130b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -478,8 +478,6 @@ class Query(object): """Return the full SELECT statement represented by this :class:`.Query` represented as a common table expression (CTE). - .. versionadded:: 0.7.6 - Parameters and usage are the same as those of the :meth:`.SelectBase.cte` method; see that method for further details. @@ -528,7 +526,7 @@ class Query(object): .. seealso:: - :meth:`.SelectBase.cte` + :meth:`.HasCTE.cte` """ return self.enable_eagerloads(False).\ diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cc9a49a91..a2fc0fe68 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -418,6 +418,11 @@ class SQLCompiler(Compiled): self.truncated_names = {} Compiled.__init__(self, dialect, statement, **kwargs) + if ( + self.isinsert or self.isupdate or self.isdelete + ) and statement._returning: + self.returning = statement._returning + if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() @@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and self._is_toplevel_select(select): + if self.ctes and toplevel: text = self._render_cte_clause() + text if select._suffixes: @@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled): else: return text - def _is_toplevel_select(self, select): - """Return True if the stack is placed at the given select, and - is also the outermost SELECT, meaning there is either no stack - before this one, or the enclosing stack is a topmost INSERT. - - """ - return ( - self.stack[-1]['selectable'] is select and - ( - len(self.stack) == 1 or self.isinsert and len(self.stack) == 2 - and self.statement is self.stack[0]['selectable'] - ) - ) - def _setup_select_hints(self, select): byfrom = dict([ (from_, hinttext % { @@ -1876,14 +1867,16 @@ class SQLCompiler(Compiled): ) return dialect_hints, table_text - def visit_insert(self, insert_stmt, **kw): + def visit_insert(self, insert_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set(), "asfrom_froms": set(), "selectable": insert_stmt}) - self.isinsert = True - crud_params = crud._get_crud_params(self, insert_stmt, **kw) + crud_params = crud._setup_crud_params( + self, insert_stmt, crud.ISINSERT, **kw) if not crud_params and \ not self.dialect.supports_default_values and \ @@ -1929,12 +1922,13 @@ class SQLCompiler(Compiled): for c in crud_params_single]) if self.returning or insert_stmt._returning: - self.returning = self.returning or insert_stmt._returning returning_clause = self.returning_clause( - insert_stmt, self.returning) + insert_stmt, self.returning or insert_stmt._returning) if self.returning_precedes_values: text += " " + returning_clause + else: + returning_clause = None if insert_stmt.select is not None: text += " %s" % self.process(self._insert_from_select, **kw) @@ -1953,12 +1947,18 @@ class SQLCompiler(Compiled): text += " VALUES (%s)" % \ ', '.join([c[1] for c in crud_params]) - if self.returning and not self.returning_precedes_values: + if returning_clause and not self.returning_precedes_values: text += " " + returning_clause + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def update_limit_clause(self, update_stmt): """Provide a hook for MySQL to add LIMIT to the UPDATE""" @@ -1972,8 +1972,8 @@ class SQLCompiler(Compiled): MySQL overrides this. """ - return from_table._compiler_dispatch(self, asfrom=True, - iscrud=True, **kw) + kw['asfrom'] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) def update_from_clause(self, update_stmt, from_table, extra_froms, @@ -1990,14 +1990,14 @@ class SQLCompiler(Compiled): fromhints=from_hints, **kw) for t in extra_froms) - def visit_update(self, update_stmt, **kw): + def visit_update(self, update_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set([update_stmt.table]), "asfrom_froms": set([update_stmt.table]), "selectable": update_stmt}) - self.isupdate = True - extra_froms = update_stmt._extra_froms text = "UPDATE " @@ -2009,7 +2009,8 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - crud_params = crud._get_crud_params(self, update_stmt, **kw) + crud_params = crud._setup_crud_params( + self, update_stmt, crud.ISUPDATE, **kw) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( @@ -2029,11 +2030,9 @@ class SQLCompiler(Compiled): ) if self.returning or update_stmt._returning: - if not self.returning: - self.returning = update_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) if extra_froms: extra_from_text = self.update_from_clause( @@ -2053,23 +2052,33 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if self.returning and not self.returning_precedes_values: + if (self.returning or update_stmt._returning) and \ + not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) + + if self.ctes and toplevel: + text = self._render_cte_clause() + text self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text @util.memoized_property def _key_getters_for_crud_column(self): - return crud._key_getters_for_crud_column(self) + return crud._key_getters_for_crud_column(self, self.statement) + + def visit_delete(self, delete_stmt, asfrom=False, **kw): + toplevel = not self.stack - def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), "asfrom_froms": set([delete_stmt.table]), "selectable": delete_stmt}) - self.isdelete = True + + crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw) text = "DELETE " @@ -2088,7 +2097,6 @@ class SQLCompiler(Compiled): text += table_text if delete_stmt._returning: - self.returning = delete_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) @@ -2098,13 +2106,19 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - if self.returning and not self.returning_precedes_values: + if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index a01b72e61..58cd80995 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -25,6 +25,41 @@ values present. """) +ISINSERT = util.symbol('ISINSERT') +ISUPDATE = util.symbol('ISUPDATE') +ISDELETE = util.symbol('ISDELETE') + + +def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): + restore_isinsert = compiler.isinsert + restore_isupdate = compiler.isupdate + restore_isdelete = compiler.isdelete + + should_restore = ( + restore_isinsert or restore_isupdate or restore_isdelete + ) or len(compiler.stack) > 1 + + if local_stmt_type is ISINSERT: + compiler.isupdate = False + compiler.isinsert = True + elif local_stmt_type is ISUPDATE: + compiler.isupdate = True + compiler.isinsert = False + elif local_stmt_type is ISDELETE: + if not should_restore: + compiler.isdelete = True + else: + assert False, "ISINSERT, ISUPDATE, or ISDELETE expected" + + try: + if local_stmt_type in (ISINSERT, ISUPDATE): + return _get_crud_params(compiler, stmt, **kw) + finally: + if should_restore: + compiler.isinsert = restore_isinsert + compiler.isupdate = restore_isupdate + compiler.isdelete = restore_isdelete + def _get_crud_params(compiler, stmt, **kw): """create a set of tuples representing column/string pairs for use @@ -59,7 +94,7 @@ def _get_crud_params(compiler, stmt, **kw): # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account _column_as_key, _getattr_col_key, _col_bind_name = \ - _key_getters_for_crud_column(compiler) + _key_getters_for_crud_column(compiler, stmt) # if we have statement parameters - set defaults in the # compiled params @@ -128,15 +163,15 @@ def _create_bind_param( return bindparam -def _key_getters_for_crud_column(compiler): - if compiler.isupdate and compiler.statement._extra_froms: +def _key_getters_for_crud_column(compiler, stmt): + if compiler.isupdate and stmt._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. # the "main" table of the statement remains unqualified, # allowing the most compatibility with a non-multi-table # statement. - _et = set(compiler.statement._extra_froms) + _et = set(stmt._extra_froms) def _column_as_key(key): str_key = elements._column_as_key(key) @@ -609,7 +644,9 @@ def _get_returning_modifiers(compiler, stmt): stmt.table.implicit_returning and stmt._return_defaults) else: - implicit_return_defaults = False + # this line is unused, currently we are always + # isinsert or isupdate + implicit_return_defaults = False # pragma: no cover if implicit_return_defaults: if stmt._return_defaults is True: diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 7b506f9db..8f368dcdb 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,15 +9,18 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ -from .base import Executable, _generative, _from_objects, DialectKWArgs +from .base import Executable, _generative, _from_objects, DialectKWArgs, \ + ColumnCollection from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ _column_as_key -from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes +from .selectable import _interpret_as_from, _interpret_as_select, \ + HasPrefixes, HasCTE from .. import util from .. import exc -class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): +class UpdateBase( + HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. """ diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index a63eac2f8..36f7f7fe1 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -46,8 +46,8 @@ from .base import ColumnCollection, Generative, Executable, \ from .selectable import Alias, Join, Select, Selectable, TableClause, \ CompoundSelect, CTE, FromClause, FromGrouping, SelectBase, \ - alias, GenerativeSelect, \ - subquery, HasPrefixes, HasSuffixes, Exists, ScalarSelect, TextAsFrom + alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \ + Exists, ScalarSelect, TextAsFrom from .dml import Insert, Update, Delete, UpdateBase, ValuesBase diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c3906c2f2..fcd22a786 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1195,6 +1195,15 @@ class CTE(Generative, HasSuffixes, Alias): self._suffixes = _suffixes super(CTE, self).__init__(selectable, name=name) + @util.dependencies("sqlalchemy.sql.dml") + def _populate_column_collection(self, dml): + if isinstance(self.element, dml.UpdateBase): + for col in self.element._returning: + col._make_proxy(self) + else: + for col in self.element.columns._all_columns: + col._make_proxy(self) + def alias(self, name=None, flat=False): return CTE( self.original, @@ -1223,6 +1232,164 @@ class CTE(Generative, HasSuffixes, Alias): ) +class HasCTE(object): + """Mixin that declares a class to include CTE support. + + .. versionadded:: 1.1 + + """ + + def cte(self, name=None, recursive=False): + """Return a new :class:`.CTE`, or Common Table Expression instance. + + Common table expressions are a SQL standard whereby SELECT + statements can draw upon secondary statements specified along + with the primary statement, using a clause called "WITH". + Special semantics regarding UNION can also be employed to + allow "recursive" queries, where a SELECT statement can draw + upon the set of rows that have previously been selected. + + CTEs can also be applied to DML constructs UPDATE, INSERT + and DELETE on some databases, both as a source of CTE rows + when combined with RETURNING, as well as a consumer of + CTE rows. + + SQLAlchemy detects :class:`.CTE` objects, which are treated + similarly to :class:`.Alias` objects, as special elements + to be delivered to the FROM clause of the statement as well + as to a WITH clause at the top of the statement. + + .. versionchanged:: 1.1 Added support for UPDATE/INSERT/DELETE as + CTE, CTEs added to UPDATE/INSERT/DELETE. + + :param name: name given to the common table expression. Like + :meth:`._FromClause.alias`, the name can be left as ``None`` + in which case an anonymous symbol will be used at query + compile time. + :param recursive: if ``True``, will render ``WITH RECURSIVE``. + A recursive common table expression is intended to be used in + conjunction with UNION ALL in order to derive rows + from those already selected. + + The following examples include two from Postgresql's documentation at + http://www.postgresql.org/docs/current/static/queries-with.html, + as well as additional examples. + + Example 1, non recursive:: + + from sqlalchemy import (Table, Column, String, Integer, + MetaData, select, func) + + metadata = MetaData() + + orders = Table('orders', metadata, + Column('region', String), + Column('amount', Integer), + Column('product', String), + Column('quantity', Integer) + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + + top_regions = select([regional_sales.c.region]).\\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + statement = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + result = conn.execute(statement).fetchall() + + Example 2, WITH RECURSIVE:: + + from sqlalchemy import (Table, Column, String, Integer, + MetaData, select, func) + + metadata = MetaData() + + parts = Table('parts', metadata, + Column('part', String), + Column('sub_part', String), + Column('quantity', Integer), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\\ + where(parts.c.part=='our part').\\ + cte(recursive=True) + + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union_all( + select([ + parts_alias.c.sub_part, + parts_alias.c.part, + parts_alias.c.quantity + ]). + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + statement = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ]).\\ + group_by(included_parts.c.sub_part) + + result = conn.execute(statement).fetchall() + + Example 3, an upsert using UPDATE and INSERT with CTEs:: + + orders = table( + 'orders', + column('region'), + column('amount'), + column('product'), + column('quantity') + ) + + upsert = ( + orders.update() + .where(orders.c.region == 'Region1') + .values(amount=1.0, product='Product1', quantity=1) + .returning(*(orders.c._all_columns)).cte('upsert')) + + insert = orders.insert().from_select( + orders.c.keys(), + select([ + literal('Region1'), literal(1.0), + literal('Product1'), literal(1) + ).where(exists(upsert.select())) + ) + + connection.execute(insert) + + .. seealso:: + + :meth:`.orm.query.Query.cte` - ORM version of + :meth:`.HasCTE.cte`. + + """ + return CTE(self, name=name, recursive=recursive) + + class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' @@ -1497,7 +1664,7 @@ class ForUpdateArg(ClauseElement): self.of = None -class SelectBase(Executable, FromClause): +class SelectBase(HasCTE, Executable, FromClause): """Base class for SELECT statements. @@ -1531,125 +1698,6 @@ class SelectBase(Executable, FromClause): """ return self.as_scalar().label(name) - def cte(self, name=None, recursive=False): - """Return a new :class:`.CTE`, or Common Table Expression instance. - - Common table expressions are a SQL standard whereby SELECT - statements can draw upon secondary statements specified along - with the primary statement, using a clause called "WITH". - Special semantics regarding UNION can also be employed to - allow "recursive" queries, where a SELECT statement can draw - upon the set of rows that have previously been selected. - - SQLAlchemy detects :class:`.CTE` objects, which are treated - similarly to :class:`.Alias` objects, as special elements - to be delivered to the FROM clause of the statement as well - as to a WITH clause at the top of the statement. - - .. versionadded:: 0.7.6 - - :param name: name given to the common table expression. Like - :meth:`._FromClause.alias`, the name can be left as ``None`` - in which case an anonymous symbol will be used at query - compile time. - :param recursive: if ``True``, will render ``WITH RECURSIVE``. - A recursive common table expression is intended to be used in - conjunction with UNION ALL in order to derive rows - from those already selected. - - The following examples illustrate two examples from - Postgresql's documentation at - http://www.postgresql.org/docs/8.4/static/queries-with.html. - - Example 1, non recursive:: - - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) - - metadata = MetaData() - - orders = Table('orders', metadata, - Column('region', String), - Column('amount', Integer), - Column('product', String), - Column('quantity', Integer) - ) - - regional_sales = select([ - orders.c.region, - func.sum(orders.c.amount).label('total_sales') - ]).group_by(orders.c.region).cte("regional_sales") - - - top_regions = select([regional_sales.c.region]).\\ - where( - regional_sales.c.total_sales > - select([ - func.sum(regional_sales.c.total_sales)/10 - ]) - ).cte("top_regions") - - statement = select([ - orders.c.region, - orders.c.product, - func.sum(orders.c.quantity).label("product_units"), - func.sum(orders.c.amount).label("product_sales") - ]).where(orders.c.region.in_( - select([top_regions.c.region]) - )).group_by(orders.c.region, orders.c.product) - - result = conn.execute(statement).fetchall() - - Example 2, WITH RECURSIVE:: - - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) - - metadata = MetaData() - - parts = Table('parts', metadata, - Column('part', String), - Column('sub_part', String), - Column('quantity', Integer), - ) - - included_parts = select([ - parts.c.sub_part, - parts.c.part, - parts.c.quantity]).\\ - where(parts.c.part=='our part').\\ - cte(recursive=True) - - - incl_alias = included_parts.alias() - parts_alias = parts.alias() - included_parts = included_parts.union_all( - select([ - parts_alias.c.sub_part, - parts_alias.c.part, - parts_alias.c.quantity - ]). - where(parts_alias.c.part==incl_alias.c.sub_part) - ) - - statement = select([ - included_parts.c.sub_part, - func.sum(included_parts.c.quantity). - label('total_quantity') - ]).\\ - group_by(included_parts.c.sub_part) - - result = conn.execute(statement).fetchall() - - - .. seealso:: - - :meth:`.orm.query.Query.cte` - ORM version of - :meth:`.SelectBase.cte`. - - """ - return CTE(self, name=name, recursive=recursive) - @_generative @util.deprecated('0.6', message="``autocommit()`` is deprecated. Use " diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bb5a96256..21f9f68fb 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -296,6 +296,8 @@ class AssertsCompiledSQL(object): dialect = config.db.dialect elif dialect == 'default': dialect = default.DefaultDialect() + elif dialect == 'default_enhanced': + dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index b59914afc..aa674403e 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,6 +1,6 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message -from sqlalchemy.sql import table, column, select, func, literal +from sqlalchemy.sql import table, column, select, func, literal, exists, and_ from sqlalchemy.dialects import mssql from sqlalchemy.engine import default from sqlalchemy.exc import CompileError @@ -8,7 +8,7 @@ from sqlalchemy.exc import CompileError class CTETest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = 'default_enhanced' def test_nonrecursive(self): orders = table('orders', @@ -492,3 +492,151 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): 'regional_sales WHERE "order"."order" > regional_sales."order"', dialect='postgresql' ) + + def test_upsert_from_select(self): + orders = table( + 'orders', + column('region'), + column('amount'), + column('product'), + column('quantity') + ) + + upsert = ( + orders.update() + .where(orders.c.region == 'Region1') + .values(amount=1.0, product='Product1', quantity=1) + .returning(*(orders.c._all_columns)).cte('upsert')) + + insert = orders.insert().from_select( + orders.c.keys(), + select([ + literal('Region1'), literal(1.0), + literal('Product1'), literal(1) + ]).where(~exists(upsert.select())) + ) + + self.assert_compile( + insert, + "WITH upsert AS (UPDATE orders SET amount=:amount, " + "product=:product, quantity=:quantity " + "WHERE orders.region = :region_1 " + "RETURNING orders.region, orders.amount, " + "orders.product, orders.quantity) " + "INSERT INTO orders (region, amount, product, quantity) " + "SELECT :param_1 AS anon_1, :param_2 AS anon_2, " + ":param_3 AS anon_3, :param_4 AS anon_4 WHERE NOT (EXISTS " + "(SELECT upsert.region, upsert.amount, upsert.product, " + "upsert.quantity FROM upsert))" + ) + + def test_pg_example_one(self): + products = table('products', column('id'), column('date')) + products_log = table('products_log', column('id'), column('date')) + + moved_rows = products.delete().where(and_( + products.c.date >= 'dateone', + products.c.date < 'datetwo')).returning(*products.c).\ + cte('moved_rows') + + stmt = products_log.insert().from_select( + products_log.c, moved_rows.select()) + self.assert_compile( + stmt, + "WITH moved_rows AS " + "(DELETE FROM products WHERE products.date >= :date_1 " + "AND products.date < :date_2 " + "RETURNING products.id, products.date) " + "INSERT INTO products_log (id, date) " + "SELECT moved_rows.id, moved_rows.date FROM moved_rows" + ) + + def test_pg_example_two(self): + products = table('products', column('id'), column('price')) + + t = products.update().values(price='someprice').\ + returning(*products.c).cte('t') + stmt = t.select() + + self.assert_compile( + stmt, + "WITH t AS " + "(UPDATE products SET price=:price " + "RETURNING products.id, products.price) " + "SELECT t.id, t.price " + "FROM t" + ) + + def test_pg_example_three(self): + + parts = table( + 'parts', + column('part'), + column('sub_part'), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part]).\ + where(parts.c.part == 'our part').\ + cte("included_parts", recursive=True) + + pr = included_parts.alias('pr') + p = parts.alias('p') + included_parts = included_parts.union_all( + select([ + p.c.sub_part, + p.c.part]). + where(p.c.part == pr.c.sub_part) + ) + stmt = parts.delete().where( + parts.c.part.in_(select([included_parts.c.part]))).returning( + parts.c.part) + + # the outer RETURNING is a bonus over what PG's docs have + self.assert_compile( + stmt, + "WITH RECURSIVE included_parts(sub_part, part) AS " + "(SELECT parts.sub_part AS sub_part, parts.part AS part " + "FROM parts " + "WHERE parts.part = :part_1 " + "UNION ALL SELECT p.sub_part AS sub_part, p.part AS part " + "FROM parts AS p, included_parts AS pr " + "WHERE p.part = pr.sub_part) " + "DELETE FROM parts WHERE parts.part IN " + "(SELECT included_parts.part FROM included_parts) " + "RETURNING parts.part" + ) + + def test_insert_in_the_cte(self): + products = table('products', column('id'), column('price')) + + cte = products.insert().values(id=1, price=27.0).\ + returning(*products.c).cte('pd') + + stmt = select([cte]) + + self.assert_compile( + stmt, + "WITH pd AS " + "(INSERT INTO products (id, price) VALUES (:id, :price) " + "RETURNING products.id, products.price) " + "SELECT pd.id, pd.price " + "FROM pd" + ) + + def test_update_pulls_from_cte(self): + products = table('products', column('id'), column('price')) + + cte = products.select().cte('pd') + + stmt = products.update().where(products.c.price == cte.c.price) + + self.assert_compile( + stmt, + "WITH pd AS " + "(SELECT products.id AS id, products.price AS price " + "FROM products) " + "UPDATE products SET id=:id, price=:price FROM pd " + "WHERE products.price = pd.price" + ) diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index ea4de032c..513757d5b 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -188,9 +188,10 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): from_select(("otherid", "othername"), sel) self.assert_compile( ins, - "INSERT INTO myothertable (otherid, othername) WITH anon_1 AS " + "WITH anon_1 AS " "(SELECT mytable.name AS name FROM mytable " "WHERE mytable.name = :name_1) " + "INSERT INTO myothertable (otherid, othername) " "SELECT mytable.myid, mytable.name FROM mytable, anon_1 " "WHERE mytable.name = anon_1.name", checkparams={"name_1": "bar"} @@ -205,9 +206,9 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): self.assert_compile( ins, - "INSERT INTO mytable (myid, name, description) " "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, " "mytable.description AS description FROM mytable) " + "INSERT INTO mytable (myid, name, description) " "SELECT c.myid, c.name, c.description FROM c" ) |
