diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 97 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_insert.py | 37 |
4 files changed, 132 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 86f00d944..a6c30b7dc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1793,7 +1793,7 @@ class SQLCompiler(Compiled): text += " " + returning_clause if insert_stmt.select is not None: - text += " %s" % self.process(insert_stmt.select, **kw) + text += " %s" % self.process(self._insert_from_select, **kw) elif not crud_params and supports_default_values: text += " DEFAULT VALUES" elif insert_stmt._has_multi_parameters: diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 1c1f661d2..831d05be1 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -89,18 +89,15 @@ def _get_crud_params(compiler, stmt, **kw): _col_bind_name, _getattr_col_key, values, kw) if compiler.isinsert and stmt.select_names: - # for an insert from select, we can only use names that - # are given, so only select for those names. - cols = (stmt.table.c[_column_as_key(name)] - for name in stmt.select_names) + _scan_insert_from_select_cols( + compiler, stmt, parameters, + _getattr_col_key, _column_as_key, + _col_bind_name, check_columns, values, kw) else: - # iterate through all table columns to maintain - # ordering, even for those cols that aren't included - cols = stmt.table.columns - - _scan_cols( - compiler, stmt, cols, parameters, - _getattr_col_key, _col_bind_name, check_columns, values, kw) + _scan_cols( + compiler, stmt, parameters, + _getattr_col_key, _column_as_key, + _col_bind_name, check_columns, values, kw) if parameters and stmt_parameters: check = set(parameters).intersection( @@ -118,13 +115,17 @@ def _get_crud_params(compiler, stmt, **kw): return values -def _create_bind_param(compiler, col, value, required=False, name=None): +def _create_bind_param( + compiler, col, value, process=True, required=False, name=None): if name is None: name = col.key bindparam = elements.BindParameter(name, value, type_=col.type, required=required) bindparam._is_crud = True - return bindparam._compiler_dispatch(compiler) + if process: + bindparam = bindparam._compiler_dispatch(compiler) + return bindparam + def _key_getters_for_crud_column(compiler): if compiler.isupdate and compiler.statement._extra_froms: @@ -162,14 +163,52 @@ def _key_getters_for_crud_column(compiler): return _column_as_key, _getattr_col_key, _col_bind_name +def _scan_insert_from_select_cols( + compiler, stmt, parameters, _getattr_col_key, + _column_as_key, _col_bind_name, check_columns, values, kw): + + need_pks, implicit_returning, \ + implicit_return_defaults, postfetch_lastrowid = \ + _get_returning_modifiers(compiler, stmt) + + cols = [stmt.table.c[_column_as_key(name)] + for name in stmt.select_names] + + compiler._insert_from_select = stmt.select + + add_select_cols = [] + if stmt.include_insert_from_select_defaults: + col_set = set(cols) + for col in stmt.table.columns: + if col not in col_set and col.default: + cols.append(col) + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + parameters.pop(col_key) + values.append((c, None)) + else: + _append_param_insert_select_hasdefault( + compiler, stmt, c, add_select_cols, kw) + + if add_select_cols: + values.extend(add_select_cols) + compiler._insert_from_select = compiler._insert_from_select._generate() + compiler._insert_from_select._raw_columns += tuple( + expr for col, expr in add_select_cols) + + def _scan_cols( - compiler, stmt, cols, parameters, _getattr_col_key, - _col_bind_name, check_columns, values, kw): + compiler, stmt, parameters, _getattr_col_key, + _column_as_key, _col_bind_name, check_columns, values, kw): need_pks, implicit_returning, \ implicit_return_defaults, postfetch_lastrowid = \ _get_returning_modifiers(compiler, stmt) + cols = stmt.table.columns + for c in cols: col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: @@ -196,7 +235,8 @@ def _scan_cols( elif c.default is not None: _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, values, kw) + compiler, stmt, c, implicit_return_defaults, + values, kw) elif c.server_default is not None: if implicit_return_defaults and \ @@ -299,10 +339,8 @@ def _append_param_insert_hasdefault( elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: - values.append( - (c, compiler.process( - c.default.arg.self_group(), **kw)) - ) + proc = compiler.process(c.default.arg.self_group(), **kw) + values.append((c, proc)) if implicit_return_defaults and \ c in implicit_return_defaults: @@ -317,6 +355,25 @@ def _append_param_insert_hasdefault( compiler.prefetch.append(c) +def _append_param_insert_select_hasdefault( + compiler, stmt, c, values, kw): + + if c.default.is_sequence: + if compiler.dialect.supports_sequences and \ + (not c.default.optional or + not compiler.dialect.sequences_optional): + proc = c.default + values.append((c, proc)) + elif c.default.is_clause_element: + proc = c.default.arg.self_group() + values.append((c, proc)) + else: + values.append( + (c, _create_bind_param(compiler, c, None, process=False)) + ) + compiler.prefetch.append(c) + + def _append_param_update( compiler, stmt, c, implicit_return_defaults, values, kw): diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 1934d0776..9f2ce7ce3 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -475,6 +475,7 @@ class Insert(ValuesBase): ValuesBase.__init__(self, table, values, prefixes) self._bind = bind self.select = self.select_names = None + self.include_insert_from_select_defaults = False self.inline = inline self._returning = returning self._validate_dialect_kwargs(dialect_kw) @@ -487,7 +488,7 @@ class Insert(ValuesBase): return () @_generative - def from_select(self, names, select): + def from_select(self, names, select, include_defaults=True): """Return a new :class:`.Insert` construct which represents an ``INSERT...FROM SELECT`` statement. @@ -506,6 +507,21 @@ class Insert(ValuesBase): is not checked before passing along to the database, the database would normally raise an exception if these column lists don't correspond. + :param include_defaults: if True, non-server default values and + SQL expressions as specified on :class:`.Column` objects + (as documented in :ref:`metadata_defaults_toplevel`) not + otherwise specified in the list of names will be rendered + into the INSERT and SELECT statements, so that these values are also + included in the data to be inserted. + + .. note:: A Python-side default that uses a Python callable function + will only be invoked **once** for the whole statement, and **not + per row**. + + .. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders + Python-side and SQL expression column defaults into the + SELECT statement for columns otherwise not included in the + list of column names. .. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT implies that the :paramref:`.insert.inline` flag is set to @@ -514,13 +530,6 @@ class Insert(ValuesBase): deals with an arbitrary number of rows, so the :attr:`.ResultProxy.inserted_primary_key` accessor does not apply. - .. note:: - - A SELECT..INSERT construct in SQL has no VALUES clause. Therefore - :class:`.Column` objects which utilize Python-side defaults - (e.g. as described at :ref:`metadata_defaults_toplevel`) - will **not** take effect when using :meth:`.Insert.from_select`. - .. versionadded:: 0.8.3 """ @@ -533,6 +542,7 @@ class Insert(ValuesBase): self.select_names = names self.inline = True + self.include_insert_from_select_defaults = include_defaults self.select = _interpret_as_select(select) def _copy_internals(self, clone=_clone, **kw): diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 92d3d93e5..c197145c7 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -4,7 +4,7 @@ from .. import exclusions from ..assertions import eq_ from .. import engines -from sqlalchemy import Integer, String, select, util +from sqlalchemy import Integer, String, select, literal_column from ..schema import Table, Column @@ -90,6 +90,13 @@ class InsertBehaviorTest(fixtures.TablesTest): Column('id', Integer, primary_key=True, autoincrement=False), Column('data', String(50)) ) + Table('includes_defaults', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('data', String(50)), + Column('x', Integer, default=5), + Column('y', Integer, + default=literal_column("2", type_=Integer) + 2)) def test_autoclose_on_insert(self): if requirements.returning.enabled: @@ -158,6 +165,34 @@ class InsertBehaviorTest(fixtures.TablesTest): ("data3", ), ("data3", )] ) + @requirements.insert_from_select + def test_insert_from_select_with_defaults(self): + table = self.tables.includes_defaults + config.db.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ] + ) + + config.db.execute( + table.insert(inline=True). + from_select(("id", "data",), + select([table.c.id + 5, table.c.data]). + where(table.c.data.in_(["data2", "data3"])) + ), + ) + + eq_( + config.db.execute( + select([table]).order_by(table.c.data) + ).fetchall(), + [(1, 'data1', 5, 4), (2, 'data2', 5, 4), + (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)] + ) + class ReturningTest(fixtures.TablesTest): run_create_tables = 'each' |