summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2016-09-20 11:56:05 -0400
committerGerrit Code Review <gerrit2@ln3.zzzcomputing.com>2016-09-20 11:56:05 -0400
commit5af3c903368e9a437a6ceafce5dc993433420cc7 (patch)
tree65a957c6154b6d43a001241330d756bf70d8efc3
parentb9a7a74d5e729408fcac86fe2919aa423c59d863 (diff)
parentf8ecdf47f0975b8b4e357fde2008d9aae8c50239 (diff)
downloadsqlalchemy-5af3c903368e9a437a6ceafce5dc993433420cc7.tar.gz
Merge "Allow SQL expressions to be set on PK columns"
-rw-r--r--doc/build/changelog/changelog_11.rst10
-rw-r--r--lib/sqlalchemy/orm/persistence.py28
-rw-r--r--test/orm/test_naturalpks.py37
-rw-r--r--test/orm/test_versioning.py32
4 files changed, 98 insertions, 9 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst
index 0d6c6b159..efd2d33f9 100644
--- a/doc/build/changelog/changelog_11.rst
+++ b/doc/build/changelog/changelog_11.rst
@@ -44,6 +44,16 @@
.. change::
:tags: bug, orm
+ :tickets: 3801
+
+ An UPDATE emitted from the ORM flush process can now accommodate a
+ SQL expression element for a column within the primary key of an
+ object, if the target database supports RETURNING in order to provide
+ the new value, or if the PK value is set "to itself" for the purposes
+ of bumping some other trigger / onupdate on the column.
+
+ .. change::
+ :tags: bug, orm
:tickets: 3788
Fixed bug where the "simple many-to-one" condition that allows lazy
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 0b029f466..56b028375 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -506,6 +506,7 @@ def _collect_update_commands(
elif not (params or value_params):
continue
+ has_all_pks = True
if bulk:
pk_params = dict(
(propkey_to_col[propkey]._label, state_dict.get(propkey))
@@ -530,7 +531,8 @@ def _collect_update_commands(
else:
# else, use the old value to locate the row
pk_params[col._label] = history.deleted[0]
- params[col.key] = history.added[0]
+ if col in value_params:
+ has_all_pks = False
else:
pk_params[col._label] = history.unchanged[0]
if pk_params[col._label] is None:
@@ -542,7 +544,7 @@ def _collect_update_commands(
params.update(pk_params)
yield (
state, state_dict, params, mapper,
- connection, value_params, has_all_defaults)
+ connection, value_params, has_all_defaults, has_all_pks)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction,
cached_stmt = base_mapper._memo(('update', table), update_stmt)
- for (connection, paramkeys, hasvalue, has_all_defaults), \
+ for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
records in groupby(
update,
lambda rec: (
rec[4], # connection
set(rec[2]), # set of parameter keys
bool(rec[5]), # whether or not we have "value" parameters
- rec[6] # has_all_defaults
+ rec[6], # has_all_defaults
+ rec[7] # has all pks
)
):
rows = 0
@@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction,
connection.dialect.supports_sane_multi_rowcount
allow_multirow = has_all_defaults and not needs_version_id
- if bookkeeping and not has_all_defaults and \
+ if not has_all_pks:
+ statement = statement.return_defaults()
+ elif bookkeeping and not has_all_defaults and \
mapper.base_mapper.eager_defaults:
statement = statement.return_defaults()
elif mapper.version_id_col is not None:
@@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
if hasvalue:
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
c = connection.execute(
statement.values(value_params),
params)
@@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, has_all_defaults, \
+ has_all_pks in records:
c = cached_connections[connection].\
execute(statement, params)
@@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows += c.rowcount
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
if bookkeeping:
_postfetch(
mapper,
@@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table,
row = result.context.returned_defaults
if row is not None:
for col in returning_cols:
- if col.primary_key:
+ # pk cols returned from insert are handled
+ # distinctly, don't step on the values here
+ if col.primary_key and result.context.isinsert:
continue
dict_[mapper._columntoproperty[col].key] = row[col]
if refresh_flush:
diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py
index 60387ddce..6780967c9 100644
--- a/test/orm/test_naturalpks.py
+++ b/test/orm/test_naturalpks.py
@@ -122,6 +122,43 @@ class NaturalPKTest(fixtures.MappedTest):
assert sess.query(User).get('jack') is None
assert sess.query(User).get('ed').fullname == 'jack'
+ @testing.requires.returning
+ def test_update_to_sql_expr(self):
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = create_session()
+ u1 = User(username='jack', fullname='jack')
+
+ sess.add(u1)
+ sess.flush()
+
+ u1.username = User.username + ' jones'
+
+ sess.flush()
+
+ eq_(u1.username, 'jack jones')
+
+ def test_update_to_self_sql_expr(self):
+ # SQL expression where the PK won't actually change,
+ # such as to bump a server side trigger
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = create_session()
+ u1 = User(username='jack', fullname='jack')
+
+ sess.add(u1)
+ sess.flush()
+
+ u1.username = User.username + ''
+
+ sess.flush()
+
+ eq_(u1.username, 'jack')
+
def test_flush_new_pk_after_expire(self):
User, users = self.classes.User, self.tables.users
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index 07b090c60..40b373097 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -1020,6 +1020,38 @@ class ServerVersioningTest(fixtures.MappedTest):
)
self.assert_sql_execution(testing.db, sess.flush, *statements)
+ def test_sql_expr_bump(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(value='f1')
+ sess.add(f1)
+ sess.flush()
+
+ eq_(f1.version_id, 1)
+
+ f1.id = self.classes.Foo.id + 0
+
+ sess.flush()
+
+ eq_(f1.version_id, 2)
+
+ @testing.requires.returning
+ def test_sql_expr_w_mods_bump(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(id=2, value='f1')
+ sess.add(f1)
+ sess.flush()
+
+ eq_(f1.version_id, 1)
+
+ f1.id = self.classes.Foo.id + 3
+
+ sess.flush()
+
+ eq_(f1.id, 5)
+ eq_(f1.version_id, 2)
+
def test_multi_update(self):
sess = self._fixture()