diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 30 |
2 files changed, 28 insertions, 7 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 539af2507..4861214c4 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1389,7 +1389,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _setup_ins_pk_from_empty(self): getter = self.compiled._inserted_primary_key_from_lastrowid_getter - return [getter(None, param) for param in self.compiled_parameters] def _setup_ins_pk_from_implicit_returning(self, result, rows): @@ -1664,7 +1663,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): - key_getter = self.compiled._key_getters_for_crud_column[2] + key_getter = self.compiled._within_exec_param_key_getter scalar_defaults = {} @@ -1702,7 +1701,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): del self.current_parameters def _process_executesingle_defaults(self): - key_getter = self.compiled._key_getters_for_crud_column[2] + key_getter = self.compiled._within_exec_param_key_getter self.current_parameters = ( compiled_parameters ) = self.compiled_parameters[0] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8a3f26425..9cf4d8397 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1255,15 +1255,28 @@ class SQLCompiler(Compiled): ) @util.memoized_property + def _within_exec_param_key_getter(self): + getter = self._key_getters_for_crud_column[2] + if self.escaped_bind_names: + + def _get(obj): + key = getter(obj) + return self.escaped_bind_names.get(key, key) + + return _get + else: + return getter + + @util.memoized_property @util.preload_module("sqlalchemy.engine.result") def _inserted_primary_key_from_lastrowid_getter(self): result = util.preloaded.engine_result - key_getter = self._key_getters_for_crud_column[2] + param_key_getter = self._within_exec_param_key_getter table = self.statement.table getters = [ - (operator.methodcaller("get", key_getter(col), None), col) + (operator.methodcaller("get", param_key_getter(col), None), col) for col in table.primary_key ] @@ -1279,6 +1292,12 @@ class SQLCompiler(Compiled): row_fn = result.result_tuple([col.key for col in table.primary_key]) def get(lastrowid, parameters): + """given cursor.lastrowid value and the parameters used for INSERT, + return a "row" that represents the primary key, either by + using the "lastrowid" or by extracting values from the parameters + that were sent along with the INSERT. + + """ if proc is not None: lastrowid = proc(lastrowid) @@ -1297,7 +1316,7 @@ class SQLCompiler(Compiled): def _inserted_primary_key_from_returning_getter(self): result = util.preloaded.engine_result - key_getter = self._key_getters_for_crud_column[2] + param_key_getter = self._within_exec_param_key_getter table = self.statement.table ret = {col: idx for idx, col in enumerate(self.returning)} @@ -1305,7 +1324,10 @@ class SQLCompiler(Compiled): getters = [ (operator.itemgetter(ret[col]), True) if col in ret - else (operator.methodcaller("get", key_getter(col), None), False) + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) for col in table.primary_key ] |
