diff options
| -rw-r--r-- | doc/build/changelog/unreleased_12/4062.rst | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/connectors/pyodbc.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 23 | ||||
| -rw-r--r-- | test/requirements.py | 14 | ||||
| -rw-r--r-- | test/sql/test_rowcount.py | 10 |
7 files changed, 66 insertions, 23 deletions
diff --git a/doc/build/changelog/unreleased_12/4062.rst b/doc/build/changelog/unreleased_12/4062.rst new file mode 100644 index 000000000..3a89a1ad6 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4062.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, mssql, orm + :tickets: 4062 + + Added a new class of "rowcount support" for dialects that is specific to + when "RETURNING", which on SQL Server looks like "OUTPUT inserted", is in + use, as the PyODBC backend isn't able to give us rowcount on an UPDATE or + DELETE statement when OUTPUT is in effect. This primarily affects the ORM + when a flush is updating a row that contains server-calcluated values, + raising an error if the backend does not return the expected row count. + PyODBC now states that it supports rowcount except if OUTPUT.inserted is + present, which is taken into account by the ORM during a flush as to + whether it will look for a rowcount. diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 65fe37212..66acf0072 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -16,6 +16,7 @@ import re class PyODBCConnector(Connector): driver = 'pyodbc' + supports_sane_rowcount_returning = False supports_sane_multi_rowcount = False if util.py2k: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index d1b54ab01..8b72c0001 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -249,6 +249,10 @@ class DefaultDialect(interfaces.Dialect): def dialect_description(self): return self.name + "+" + self.driver + @property + def supports_sane_rowcount_returning(self): + return self.supports_sane_rowcount + @classmethod def get_pool_class(cls, url): return getattr(cls, 'poolclass', pool.QueuePool) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index faacd018e..24c9743d4 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -693,22 +693,28 @@ def _emit_update_statements(base_mapper, uowtransaction, records = list(records) statement = cached_stmt - - # TODO: would be super-nice to not have to determine this boolean - # inside the loop here, in the 99.9999% of the time there's only - # one connection in use - assert_singlerow = connection.dialect.supports_sane_rowcount - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount - allow_multirow = has_all_defaults and not needs_version_id + return_defaults = False if not has_all_pks: statement = statement.return_defaults() + return_defaults = True elif bookkeeping and not has_all_defaults and \ mapper.base_mapper.eager_defaults: statement = statement.return_defaults() + return_defaults = True elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) + return_defaults = True + + assert_singlerow = ( + connection.dialect.supports_sane_rowcount + if not return_defaults + else connection.dialect.supports_sane_rowcount_returning + ) + + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = has_all_defaults and not needs_version_id if hasvalue: for state, state_dict, params, mapper, \ @@ -728,7 +734,7 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) rows += c.rowcount - check_rowcount = True + check_rowcount = assert_singlerow else: if not allow_multirow: check_rowcount = assert_singlerow diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 08a7b1ced..327362bf6 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -193,6 +193,29 @@ class SuiteRequirements(Requirements): return exclusions.open() + + @property + def sane_rowcount(self): + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_sane_rowcount, + "driver doesn't support 'sane' rowcount" + ) + + @property + def sane_multi_rowcount(self): + return exclusions.fails_if( + lambda config: not config.db.dialect.supports_sane_multi_rowcount, + "driver %(driver)s %(doesnt_support)s 'sane' multi row count" + ) + + @property + def sane_rowcount_w_returning(self): + return exclusions.fails_if( + lambda config: + not config.db.dialect.supports_sane_rowcount_returning, + "driver doesn't support 'sane' rowcount when returning is on" + ) + @property def empty_inserts(self): """target platform supports INSERT with no values, i.e. diff --git a/test/requirements.py b/test/requirements.py index 4f01eac9b..0362e28d1 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -559,13 +559,6 @@ class DefaultRequirements(SuiteRequirements): ]) @property - def sane_rowcount(self): - return skip_if( - lambda config: not config.db.dialect.supports_sane_rowcount, - "driver doesn't support 'sane' rowcount" - ) - - @property def emulated_lastrowid(self): """"target dialect retrieves cursor.lastrowid or an equivalent after an insert() construct executes. @@ -594,13 +587,6 @@ class DefaultRequirements(SuiteRequirements): 'sqlite+pysqlcipher') @property - def sane_multi_rowcount(self): - return fails_if( - lambda config: not config.db.dialect.supports_sane_multi_rowcount, - "driver %(driver)s %(doesnt_support)s 'sane' multi row count" - ) - - @property def nullsordering(self): """Target backends that support nulls ordering.""" return fails_on_everything_except('postgresql', 'oracle', 'firebird') diff --git a/test/sql/test_rowcount.py b/test/sql/test_rowcount.py index 16087b94c..3399ba7ec 100644 --- a/test/sql/test_rowcount.py +++ b/test/sql/test_rowcount.py @@ -65,6 +65,15 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): r = employees_table.update(department == 'C').execute(department='C') assert r.rowcount == 3 + @testing.requires.sane_rowcount_w_returning + def test_update_rowcount_return_defaults(self): + department = employees_table.c.department + stmt = employees_table.update(department == 'C').values( + name=employees_table.c.department + 'Z').return_defaults() + + r = stmt.execute() + assert r.rowcount == 3 + def test_raw_sql_rowcount(self): # test issue #3622, make sure eager rowcount is called for text with testing.db.connect() as conn: @@ -117,3 +126,4 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): eq_( r.rowcount, 2 ) + |
