diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 53 |
1 files changed, 34 insertions, 19 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 4339551a3..a3855cc2c 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -685,7 +685,6 @@ import operator import re from . import information_schema as ischema -from ... import engine from ... import exc from ... import schema as sa_schema from ... import sql @@ -693,6 +692,7 @@ from ... import types as sqltypes from ... import util from ...engine import default from ...engine import reflection +from ...engine import result as _result from ...sql import compiler from ...sql import elements from ...sql import expression @@ -1431,8 +1431,9 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): class MSExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False _select_lastrowid = False - _result_proxy = None _lastrowid = None + _rowcount = None + _result_strategy = None def _opt_encode(self, statement): if not self.dialect.supports_unicode_statements: @@ -1500,6 +1501,10 @@ class MSExecutionContext(default.DefaultExecutionContext): """Disable IDENTITY_INSERT if enabled.""" conn = self.root_connection + + if self.isinsert or self.isupdate or self.isdelete: + self._rowcount = self.cursor.rowcount + if self._select_lastrowid: if self.dialect.use_scope_identity: conn._cursor_execute( @@ -1516,10 +1521,13 @@ class MSExecutionContext(default.DefaultExecutionContext): row = self.cursor.fetchall()[0] self._lastrowid = int(row[0]) - if ( + elif ( self.isinsert or self.isupdate or self.isdelete ) and self.compiled.returning: - self._result_proxy = engine.FullyBufferedResultProxy(self) + fbcr = _result.FullyBufferedCursorFetchStrategy + self._result_strategy = fbcr.create_from_buffer( + self.cursor, self.cursor.description, self.cursor.fetchall() + ) if self._enable_identity_insert: conn._cursor_execute( @@ -1537,6 +1545,13 @@ class MSExecutionContext(default.DefaultExecutionContext): def get_lastrowid(self): return self._lastrowid + @property + def rowcount(self): + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + def handle_dbapi_exception(self, e): if self._enable_identity_insert: try: @@ -1551,11 +1566,13 @@ class MSExecutionContext(default.DefaultExecutionContext): except Exception: pass - def get_result_proxy(self): - if self._result_proxy: - return self._result_proxy + def get_result_cursor_strategy(self, result): + if self._result_strategy: + return self._result_strategy else: - return engine.ResultProxy(self) + return super(MSExecutionContext, self).get_result_cursor_strategy( + result + ) class MSSQLCompiler(compiler.SQLCompiler): @@ -2570,7 +2587,7 @@ class MSDialect(default.DefaultDialect): if self.server_version_info < MS_2005_VERSION: return [] - rp = connection.execute( + rp = connection.execution_options(future_result=True).execute( sql.text( "select ind.index_id, ind.is_unique, ind.name " "from sys.indexes as ind join sys.tables as tab on " @@ -2587,13 +2604,13 @@ class MSDialect(default.DefaultDialect): .columns(name=sqltypes.Unicode()) ) indexes = {} - for row in rp: + for row in rp.mappings(): indexes[row["index_id"]] = { "name": row["name"], "unique": row["is_unique"] == 1, "column_names": [], } - rp = connection.execute( + rp = connection.execution_options(future_result=True).execute( sql.text( "select ind_col.index_id, ind_col.object_id, col.name " "from sys.columns as col " @@ -2611,7 +2628,7 @@ class MSDialect(default.DefaultDialect): ) .columns(name=sqltypes.Unicode()) ) - for row in rp: + for row in rp.mappings(): if row["index_id"] in indexes: indexes[row["index_id"]]["column_names"].append(row["name"]) @@ -2657,12 +2674,10 @@ class MSDialect(default.DefaultDialect): [columns], whereclause, order_by=[columns.c.ordinal_position] ) - c = connection.execute(s) + c = connection.execution_options(future_result=True).execute(s) cols = [] - while True: - row = c.fetchone() - if row is None: - break + + for row in c.mappings(): ( name, type_, @@ -2785,9 +2800,9 @@ class MSDialect(default.DefaultDialect): C.c.table_schema == owner, ), ) - c = connection.execute(s) + c = connection.execution_options(future_result=True).execute(s) constraint_name = None - for row in c: + for row in c.mappings(): if "PRIMARY" in row[TC.c.constraint_type.name]: pkeys.append(row[0]) if constraint_name is None: |
