summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-02-11 17:00:47 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2019-02-12 16:55:48 -0500
commitc1f310df44033d943413170de878ce95fafa387e (patch)
tree1c053d5a0bf8610393ba38bbb19a576383da357b /lib/sqlalchemy
parentbb7b353d6f97184d2689c8c682bab5caac4ec1e7 (diff)
downloadsqlalchemy-c1f310df44033d943413170de878ce95fafa387e.tar.gz
Allow SQL expression for ORM primary keys
A SQL expression can now be assigned to a primary key attribute for an ORM flush in the same manner as ordinary attributes as described in :ref:`flush_embedded_sql_expressions` where the expression will be evaulated and then returned to the ORM using RETURNING, or in the case of pysqlite, works using the cursor.lastrowid attribute.Requires either a database that supports RETURNING (e.g. Postgresql, Oracle, SQL Server) or pysqlite. Fixes: #3133 Fixes: #4494 Change-Id: I83da8357354de002cb04fa4a553f2a2f90c5157d
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/persistence.py37
-rw-r--r--lib/sqlalchemy/sql/crud.py7
2 files changed, 38 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index c90c8d91e..6345ee28a 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -507,7 +507,7 @@ def _collect_insert_commands(
and hasattr(value, "__clause_element__")
or isinstance(value, sql.ClauseElement)
):
- value_params[col.key] = (
+ value_params[col] = (
value.__clause_element__()
if hasattr(value, "__clause_element__")
else value
@@ -525,7 +525,7 @@ def _collect_insert_commands(
for colkey in (
mapper._insert_cols_as_none[table]
.difference(params)
- .difference(value_params)
+ .difference([c.key for c in value_params])
):
params[colkey] = None
@@ -932,6 +932,7 @@ def _emit_update_statements(
c,
c.context.compiled_parameters[0],
value_params,
+ True,
)
rows += c.rowcount
check_rowcount = assert_singlerow
@@ -963,6 +964,7 @@ def _emit_update_statements(
c,
c.context.compiled_parameters[0],
value_params,
+ True,
)
rows += c.rowcount
else:
@@ -998,6 +1000,7 @@ def _emit_update_statements(
c,
c.context.compiled_parameters[0],
value_params,
+ True,
)
if check_rowcount:
@@ -1086,6 +1089,7 @@ def _emit_insert_statements(
c,
last_inserted_params,
value_params,
+ False,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
@@ -1117,14 +1121,16 @@ def _emit_insert_statements(
)
primary_key = result.context.inserted_primary_key
-
if primary_key is not None:
# set primary key attributes
for pk, col in zip(
primary_key, mapper._pks_by_table[table]
):
prop = mapper_rec._columntoproperty[col]
- if state_dict.get(prop.key) is None:
+ if pk is not None and (
+ col in value_params
+ or state_dict.get(prop.key) is None
+ ):
state_dict[prop.key] = pk
if bookkeeping:
if state:
@@ -1137,6 +1143,7 @@ def _emit_insert_statements(
result,
result.context.compiled_parameters[0],
value_params,
+ False,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
@@ -1461,7 +1468,15 @@ def _postfetch_post_update(
def _postfetch(
- mapper, uowtransaction, table, state, dict_, result, params, value_params
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ dict_,
+ result,
+ params,
+ value_params,
+ isupdate,
):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
@@ -1511,6 +1526,18 @@ def _postfetch(
state, uowtransaction, load_evt_attrs
)
+ if isupdate and value_params:
+ # explicitly suit the use case specified by
+ # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
+ # database which are set to themselves in order to do a version bump.
+ postfetch_cols.extend(
+ [
+ col
+ for col in value_params
+ if col.primary_key and col not in returning_cols
+ ]
+ )
+
if postfetch_cols:
state._expire_attributes(
state.dict,
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index cc72073ca..6c9b8ee5b 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -409,7 +409,12 @@ def _append_param_parameter(
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
else:
- compiler.postfetch.append(c)
+ # postfetch specifically means, "we can SELECT the row we just
+ # inserted by primary key to get back the server generated
+ # defaults". so by definition this can't be used to get the primary
+ # key value back, because we need to have it ahead of time.
+ if not c.primary_key:
+ compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
values.append((c, value))