summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-01-19 11:47:28 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2015-01-19 11:47:28 -0500
commit6135c03230391a2230e2280f2cbb8b02d880db32 (patch)
tree637038e3f4b28d7a354efddd0ab3713b15d5c876
parent3f84a9408064bc2064bc706a369ecb463df17789 (diff)
downloadsqlalchemy-6135c03230391a2230e2280f2cbb8b02d880db32.tar.gz
- further fixes and even better tests for this block
-rw-r--r--lib/sqlalchemy/orm/persistence.py11
-rw-r--r--test/orm/test_versioning.py28
2 files changed, 35 insertions, 4 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 7f81a5c99..e553f399d 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -642,8 +642,10 @@ def _emit_update_statements(base_mapper, uowtransaction,
c.context.compiled_parameters[0],
value_params)
rows += c.rowcount
+ check_rowcount = True
else:
if not allow_multirow:
+ check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
connection, value_params in records:
c = cached_connections[connection].\
@@ -661,6 +663,11 @@ def _emit_update_statements(base_mapper, uowtransaction,
else:
multiparams = [rec[2] for rec in records]
+ check_rowcount = assert_multirow or (
+ assert_singlerow and
+ len(multiparams) == 1
+ )
+
c = cached_connections[connection].\
execute(statement, multiparams)
@@ -677,9 +684,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
c.context.compiled_parameters[0],
value_params)
- if hasvalue or assert_multirow or (
- assert_singlerow and
- len(multiparams)) == 1:
+ if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index 55ce586b5..8348cb588 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -1,7 +1,8 @@
import datetime
import sqlalchemy as sa
-from sqlalchemy.testing import engines
+from sqlalchemy.testing import engines, config
from sqlalchemy import testing
+from sqlalchemy.testing.mock import patch
from sqlalchemy import (
Integer, String, Date, ForeignKey, orm, exc, select, TypeDecorator)
from sqlalchemy.testing.schema import Table, Column
@@ -12,6 +13,7 @@ from sqlalchemy.testing import (
eq_, assert_raises, assert_raises_message, fixtures)
from sqlalchemy.testing.assertsql import CompiledSQL
import uuid
+from sqlalchemy import util
def make_uuid():
@@ -223,6 +225,30 @@ class VersioningTest(fixtures.MappedTest):
s1.refresh(f1s1, lockmode='update_nowait')
assert f1s1.version_id == f1s2.version_id
+ def test_update_multi_missing_broken_multi_rowcount(self):
+ @util.memoized_property
+ def rowcount(self):
+ if len(self.context.compiled_parameters) > 1:
+ return -1
+ else:
+ return self.context.rowcount
+
+ with patch.object(
+ config.db.dialect, "supports_sane_multi_rowcount", False):
+ with patch(
+ "sqlalchemy.engine.result.ResultProxy.rowcount",
+ rowcount):
+
+ Foo = self.classes.Foo
+ s1 = self._fixture()
+ f1s1 = Foo(value='f1 value')
+ s1.add(f1s1)
+ s1.commit()
+
+ f1s1.value = 'f2 value'
+ s1.flush()
+ eq_(f1s1.version_id, 2)
+
@testing.emits_warning(r'.*does not support updated rowcount')
@engines.close_open_connections
def test_noversioncheck(self):