diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-01-19 11:47:28 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-01-19 11:47:28 -0500 |
commit | 6135c03230391a2230e2280f2cbb8b02d880db32 (patch) | |
tree | 637038e3f4b28d7a354efddd0ab3713b15d5c876 | |
parent | 3f84a9408064bc2064bc706a369ecb463df17789 (diff) | |
download | sqlalchemy-6135c03230391a2230e2280f2cbb8b02d880db32.tar.gz |
- further fixes and even better tests for this block
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 11 | ||||
-rw-r--r-- | test/orm/test_versioning.py | 28 |
2 files changed, 35 insertions, 4 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 7f81a5c99..e553f399d 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -642,8 +642,10 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) rows += c.rowcount + check_rowcount = True else: if not allow_multirow: + check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ connection, value_params in records: c = cached_connections[connection].\ @@ -661,6 +663,11 @@ def _emit_update_statements(base_mapper, uowtransaction, else: multiparams = [rec[2] for rec in records] + check_rowcount = assert_multirow or ( + assert_singlerow and + len(multiparams) == 1 + ) + c = cached_connections[connection].\ execute(statement, multiparams) @@ -677,9 +684,7 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) - if hasvalue or assert_multirow or ( - assert_singlerow and - len(multiparams)) == 1: + if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 55ce586b5..8348cb588 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -1,7 +1,8 @@ import datetime import sqlalchemy as sa -from sqlalchemy.testing import engines +from sqlalchemy.testing import engines, config from sqlalchemy import testing +from sqlalchemy.testing.mock import patch from sqlalchemy import ( Integer, String, Date, ForeignKey, orm, exc, select, TypeDecorator) from sqlalchemy.testing.schema import Table, Column @@ -12,6 +13,7 @@ from sqlalchemy.testing import ( eq_, assert_raises, assert_raises_message, fixtures) from sqlalchemy.testing.assertsql import CompiledSQL import uuid +from sqlalchemy import util def make_uuid(): @@ -223,6 +225,30 @@ class VersioningTest(fixtures.MappedTest): s1.refresh(f1s1, lockmode='update_nowait') assert f1s1.version_id == f1s2.version_id + def test_update_multi_missing_broken_multi_rowcount(self): + @util.memoized_property + def rowcount(self): + if len(self.context.compiled_parameters) > 1: + return -1 + else: + return self.context.rowcount + + with patch.object( + config.db.dialect, "supports_sane_multi_rowcount", False): + with patch( + "sqlalchemy.engine.result.ResultProxy.rowcount", + rowcount): + + Foo = self.classes.Foo + s1 = self._fixture() + f1s1 = Foo(value='f1 value') + s1.add(f1s1) + s1.commit() + + f1s1.value = 'f2 value' + s1.flush() + eq_(f1s1.version_id, 2) + @testing.emits_warning(r'.*does not support updated rowcount') @engines.close_open_connections def test_noversioncheck(self): |