diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e6f16b698..614f9413b 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -26,7 +26,7 @@ values present. """) -def _get_crud_params(compiler, stmt, **kw): +def _get_crud_params(compiler, stmt, keep_order=False, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -64,12 +64,12 @@ def _get_crud_params(compiler, stmt, **kw): # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: - parameters = {} + parameters = util.OrderedDict() else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = util.OrderedDict((_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or + key not in stmt_parameters) # create a list of column assignment clauses as tuples values = [] @@ -97,7 +97,7 @@ def _get_crud_params(compiler, stmt, **kw): _scan_cols( compiler, stmt, parameters, _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + _col_bind_name, check_columns, values, kw, keep_order=keep_order) if parameters and stmt_parameters: check = set(parameters).intersection( @@ -202,7 +202,7 @@ def _scan_insert_from_select_cols( def _scan_cols( compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): + _column_as_key, _col_bind_name, check_columns, values, kw, keep_order): need_pks, implicit_returning, \ implicit_return_defaults, postfetch_lastrowid = \ @@ -210,6 +210,16 @@ def _scan_cols( cols = stmt.table.columns + if keep_order: + # Order columns with parameters first, preserving their original order, + # and then the rest of the columns + keys = tuple(parameters.keys()) if parameters else tuple() + table_cols = tuple(cols) + cols = sorted(table_cols, + key=(lambda x: keys.index(_getattr_col_key(x)) + if _getattr_col_key(x) in keys + else len(keys) + table_cols.index(x))) + for c in cols: col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: |
