summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py97
1 files changed, 77 insertions, 20 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 1c1f661d2..831d05be1 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -89,18 +89,15 @@ def _get_crud_params(compiler, stmt, **kw):
_col_bind_name, _getattr_col_key, values, kw)
if compiler.isinsert and stmt.select_names:
- # for an insert from select, we can only use names that
- # are given, so only select for those names.
- cols = (stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names)
+ _scan_insert_from_select_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
else:
- # iterate through all table columns to maintain
- # ordering, even for those cols that aren't included
- cols = stmt.table.columns
-
- _scan_cols(
- compiler, stmt, cols, parameters,
- _getattr_col_key, _col_bind_name, check_columns, values, kw)
+ _scan_cols(
+ compiler, stmt, parameters,
+ _getattr_col_key, _column_as_key,
+ _col_bind_name, check_columns, values, kw)
if parameters and stmt_parameters:
check = set(parameters).intersection(
@@ -118,13 +115,17 @@ def _get_crud_params(compiler, stmt, **kw):
return values
-def _create_bind_param(compiler, col, value, required=False, name=None):
+def _create_bind_param(
+ compiler, col, value, process=True, required=False, name=None):
if name is None:
name = col.key
bindparam = elements.BindParameter(name, value,
type_=col.type, required=required)
bindparam._is_crud = True
- return bindparam._compiler_dispatch(compiler)
+ if process:
+ bindparam = bindparam._compiler_dispatch(compiler)
+ return bindparam
+
def _key_getters_for_crud_column(compiler):
if compiler.isupdate and compiler.statement._extra_froms:
@@ -162,14 +163,52 @@ def _key_getters_for_crud_column(compiler):
return _column_as_key, _getattr_col_key, _col_bind_name
+def _scan_insert_from_select_cols(
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
+
+ need_pks, implicit_returning, \
+ implicit_return_defaults, postfetch_lastrowid = \
+ _get_returning_modifiers(compiler, stmt)
+
+ cols = [stmt.table.c[_column_as_key(name)]
+ for name in stmt.select_names]
+
+ compiler._insert_from_select = stmt.select
+
+ add_select_cols = []
+ if stmt.include_insert_from_select_defaults:
+ col_set = set(cols)
+ for col in stmt.table.columns:
+ if col not in col_set and col.default:
+ cols.append(col)
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ parameters.pop(col_key)
+ values.append((c, None))
+ else:
+ _append_param_insert_select_hasdefault(
+ compiler, stmt, c, add_select_cols, kw)
+
+ if add_select_cols:
+ values.extend(add_select_cols)
+ compiler._insert_from_select = compiler._insert_from_select._generate()
+ compiler._insert_from_select._raw_columns += tuple(
+ expr for col, expr in add_select_cols)
+
+
def _scan_cols(
- compiler, stmt, cols, parameters, _getattr_col_key,
- _col_bind_name, check_columns, values, kw):
+ compiler, stmt, parameters, _getattr_col_key,
+ _column_as_key, _col_bind_name, check_columns, values, kw):
need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid = \
_get_returning_modifiers(compiler, stmt)
+ cols = stmt.table.columns
+
for c in cols:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
@@ -196,7 +235,8 @@ def _scan_cols(
elif c.default is not None:
_append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults, values, kw)
+ compiler, stmt, c, implicit_return_defaults,
+ values, kw)
elif c.server_default is not None:
if implicit_return_defaults and \
@@ -299,10 +339,8 @@ def _append_param_insert_hasdefault(
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
- values.append(
- (c, compiler.process(
- c.default.arg.self_group(), **kw))
- )
+ proc = compiler.process(c.default.arg.self_group(), **kw)
+ values.append((c, proc))
if implicit_return_defaults and \
c in implicit_return_defaults:
@@ -317,6 +355,25 @@ def _append_param_insert_hasdefault(
compiler.prefetch.append(c)
+def _append_param_insert_select_hasdefault(
+ compiler, stmt, c, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and \
+ (not c.default.optional or
+ not compiler.dialect.sequences_optional):
+ proc = c.default
+ values.append((c, proc))
+ elif c.default.is_clause_element:
+ proc = c.default.arg.self_group()
+ values.append((c, proc))
+ else:
+ values.append(
+ (c, _create_bind_param(compiler, c, None, process=False))
+ )
+ compiler.prefetch.append(c)
+
+
def _append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw):