summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-10-25 09:10:09 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-11-03 18:42:52 -0400
commitb96321ae79a0366c33ca739e6e67aaf5f4420db4 (patch)
treed56cb4cdf58e0b060f1ceb14f468eef21de0688b /test
parent9bae9a931a460ff70172858ff90bcc1defae8e20 (diff)
downloadsqlalchemy-b96321ae79a0366c33ca739e6e67aaf5f4420db4.tar.gz
Support result.close() for all iterator patterns
This change contains new features for 2.0 only as well as some behaviors that will be backported to 1.4. For 1.4 and 2.0: Fixed issue where the underlying DBAPI cursor would not be closed when using :class:`_orm.Query` with :meth:`_orm.Query.yield_per` and direct iteration, if a user-defined exception case were raised within the iteration process, interrupting the iterator. This would lead to the usual MySQL-related issues with server side cursors out of sync. For 1.4 only: A similar scenario can occur when using :term:`2.x` executions with direct use of :class:`.Result`, in that case the end-user code has access to the :class:`.Result` itself and should call :meth:`.Result.close` directly. Version 2.0 will feature context-manager calling patterns to address this use case. However within the 1.4 scope, ensured that ``.close()`` methods are available on all :class:`.Result` implementations including :class:`.ScalarResult`, :class:`.MappingResult`. For 2.0 only: To better support the use case of iterating :class:`.Result` and :class:`.AsyncResult` objects where user-defined exceptions may interrupt the iteration, both objects as well as variants such as :class:`.ScalarResult`, :class:`.MappingResult`, :class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support context manager usage, where the result will be closed at the end of iteration. Corrected various typing issues within the engine and async engine packages. Fixes: #8710 Change-Id: I3166328bfd3900957eb33cbf1061d0495c9df670
Diffstat (limited to 'test')
-rw-r--r--test/base/test_result.py15
-rw-r--r--test/ext/asyncio/test_engine_py3k.py49
-rw-r--r--test/ext/mypy/plain_files/engines.py86
-rw-r--r--test/orm/test_loading.py19
-rw-r--r--test/orm/test_query.py52
-rw-r--r--test/sql/test_resultset.py125
6 files changed, 343 insertions, 3 deletions
diff --git a/test/base/test_result.py b/test/base/test_result.py
index 90938263f..3e6444daa 100644
--- a/test/base/test_result.py
+++ b/test/base/test_result.py
@@ -253,6 +253,21 @@ class ResultTest(fixtures.TestBase):
return res
+ def test_close_attributes(self):
+ """test #8710"""
+ r1 = self._fixture()
+
+ is_false(r1.closed)
+ is_false(r1._soft_closed)
+
+ r1._soft_close()
+ is_false(r1.closed)
+ is_true(r1._soft_closed)
+
+ r1.close()
+ is_true(r1.closed)
+ is_true(r1._soft_closed)
+
def test_class_presented(self):
"""To support different kinds of objects returned vs. rows,
there are two wrapper classes for Result.
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index cdf70ca67..2eebb433d 100644
--- a/test/ext/asyncio/test_engine_py3k.py
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture):
):
await conn.exec_driver_sql("SELECT * FROM users")
+ @async_test
+ async def test_stream_ctxmanager(self, async_engine):
+ async with async_engine.connect() as conn:
+ conn = await conn.execution_options(stream_results=True)
+
+ async with conn.stream(select(self.tables.users)) as result:
+ assert not result._real_result._soft_closed
+ assert not result.closed
+ with expect_raises_message(Exception, "hi"):
+ i = 0
+ async for row in result:
+ if i > 2:
+ raise Exception("hi")
+ i += 1
+ assert result._real_result._soft_closed
+ assert result.closed
+
+ @async_test
+ async def test_stream_scalars_ctxmanager(self, async_engine):
+ async with async_engine.connect() as conn:
+ conn = await conn.execution_options(stream_results=True)
+
+ async with conn.stream_scalars(
+ select(self.tables.users)
+ ) as result:
+ assert not result._real_result._soft_closed
+ assert not result.closed
+ with expect_raises_message(Exception, "hi"):
+ i = 0
+ async for scalar in result:
+ if i > 2:
+ raise Exception("hi")
+ i += 1
+ assert result._real_result._soft_closed
+ assert result.closed
+
@testing.combinations(
(None,), ("scalars",), ("mappings",), argnames="filter_"
)
@@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture):
eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
@testing.combinations(
- (None,), ("scalars",), ("mappings",), argnames="filter_"
+ (None,),
+ ("scalars",),
+ ("stream_scalars",),
+ ("mappings",),
+ argnames="filter_",
)
@async_test
async def test_aiter(self, async_engine, filter_):
users = self.tables.users
async with async_engine.connect() as conn:
- result = await conn.stream(select(users))
+ if filter_ == "stream_scalars":
+ result = await conn.stream_scalars(select(users.c.user_name))
+ else:
+ result = await conn.stream(select(users))
if filter_ == "mappings":
result = result.mappings()
@@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture):
for i in range(1, 20)
],
)
- elif filter_ == "scalars":
+ elif filter_ in ("scalars", "stream_scalars"):
eq_(
rows,
["name%d" % i for i in range(1, 20)],
diff --git a/test/ext/mypy/plain_files/engines.py b/test/ext/mypy/plain_files/engines.py
new file mode 100644
index 000000000..c920ad55d
--- /dev/null
+++ b/test/ext/mypy/plain_files/engines.py
@@ -0,0 +1,86 @@
+from sqlalchemy import create_engine
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import create_async_engine
+
+
+def regular() -> None:
+
+ e = create_engine("sqlite://")
+
+ # EXPECTED_TYPE: Engine
+ reveal_type(e)
+
+ with e.connect() as conn:
+
+ # EXPECTED_TYPE: Connection
+ reveal_type(conn)
+
+ result = conn.execute(text("select * from table"))
+
+ # EXPECTED_TYPE: CursorResult[Any]
+ reveal_type(result)
+
+ with e.begin() as conn:
+
+ # EXPECTED_TYPE: Connection
+ reveal_type(conn)
+
+ result = conn.execute(text("select * from table"))
+
+ # EXPECTED_TYPE: CursorResult[Any]
+ reveal_type(result)
+
+
+async def asyncio() -> None:
+ e = create_async_engine("sqlite://")
+
+ # EXPECTED_TYPE: AsyncEngine
+ reveal_type(e)
+
+ async with e.connect() as conn:
+
+ # EXPECTED_TYPE: AsyncConnection
+ reveal_type(conn)
+
+ result = await conn.execute(text("select * from table"))
+
+ # EXPECTED_TYPE: CursorResult[Any]
+ reveal_type(result)
+
+ # stream with direct await
+ async_result = await conn.stream(text("select * from table"))
+
+ # EXPECTED_TYPE: AsyncResult[Any]
+ reveal_type(async_result)
+
+ # stream with context manager
+ async with conn.stream(
+ text("select * from table")
+ ) as ctx_async_result:
+ # EXPECTED_TYPE: AsyncResult[Any]
+ reveal_type(ctx_async_result)
+
+ # stream_scalars with direct await
+ async_scalar_result = await conn.stream_scalars(
+ text("select * from table")
+ )
+
+ # EXPECTED_TYPE: AsyncScalarResult[Any]
+ reveal_type(async_scalar_result)
+
+ # stream_scalars with context manager
+ async with conn.stream_scalars(
+ text("select * from table")
+ ) as ctx_async_scalar_result:
+ # EXPECTED_TYPE: AsyncScalarResult[Any]
+ reveal_type(ctx_async_scalar_result)
+
+ async with e.begin() as conn:
+
+ # EXPECTED_TYPE: AsyncConnection
+ reveal_type(conn)
+
+ result = await conn.execute(text("select * from table"))
+
+ # EXPECTED_TYPE: CursorResult[Any]
+ reveal_type(result)
diff --git a/test/orm/test_loading.py b/test/orm/test_loading.py
index cc3c3f494..d0b5c9d8f 100644
--- a/test/orm/test_loading.py
+++ b/test/orm/test_loading.py
@@ -6,6 +6,7 @@ from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.orm import loading
from sqlalchemy.orm import relationship
+from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
@@ -152,6 +153,24 @@ class InstancesTest(_fixtures.FixtureTest):
def setup_mappers(cls):
cls._setup_stock_mapping()
+ def test_cursor_close_exception_raised_in_iteration(self):
+ """test #8710"""
+
+ User = self.classes.User
+ s = fixture_session()
+
+ stmt = select(User).execution_options(yield_per=1)
+
+ result = s.execute(stmt)
+ raw_cursor = result.raw
+
+ for row in result:
+ with expect_raises_message(Exception, "whoops"):
+ for row in result:
+ raise Exception("whoops")
+
+ is_true(raw_cursor._soft_closed)
+
def test_cursor_close_w_failed_rowproc(self):
User = self.classes.User
s = fixture_session()
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 463009065..c05fdaf4f 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -5417,6 +5417,58 @@ class YieldTest(_fixtures.FixtureTest):
result.close()
assert_raises(sa.exc.ResourceClosedError, result.all)
+ def test_yield_per_close_on_interrupted_iteration_legacy(self):
+ """test #8710"""
+
+ self._eagerload_mappings()
+
+ User = self.classes.User
+
+ asserted_result = None
+
+ class _Query(Query):
+ def _iter(self):
+ nonlocal asserted_result
+ asserted_result = super(_Query, self)._iter()
+ return asserted_result
+
+ sess = fixture_session(query_cls=_Query)
+
+ with expect_raises_message(Exception, "hi"):
+ for i, row in enumerate(sess.query(User).yield_per(1)):
+ assert not asserted_result._soft_closed
+ assert not asserted_result.closed
+
+ if i > 1:
+ raise Exception("hi")
+
+ assert asserted_result._soft_closed
+ assert not asserted_result.closed
+
+ def test_yield_per_close_on_interrupted_iteration(self):
+ """test #8710"""
+
+ self._eagerload_mappings()
+
+ User = self.classes.User
+
+ sess = fixture_session()
+
+ with expect_raises_message(Exception, "hi"):
+ result = sess.execute(select(User).execution_options(yield_per=1))
+ for i, row in enumerate(result):
+ assert not result._soft_closed
+ assert not result.closed
+
+ if i > 1:
+ raise Exception("hi")
+
+ assert not result._soft_closed
+ assert not result.closed
+ result.close()
+ assert result._soft_closed
+ assert result.closed
+
def test_yield_per_and_execution_options_legacy(self):
self._eagerload_mappings()
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 4f776e300..fa86d75ee 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -53,6 +53,7 @@ from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import in_
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import le_
from sqlalchemy.testing import mock
@@ -2033,6 +2034,89 @@ class CursorResultTest(fixtures.TablesTest):
partition = next(result.partitions())
eq_(len(partition), value)
+ @testing.fixture
+ def autoclose_row_fixture(self, connection):
+ users = self.tables.users
+ connection.execute(
+ users.insert(),
+ [
+ {"user_id": 1, "name": "u1"},
+ {"user_id": 2, "name": "u2"},
+ {"user_id": 3, "name": "u3"},
+ {"user_id": 4, "name": "u4"},
+ {"user_id": 5, "name": "u5"},
+ ],
+ )
+
+ @testing.fixture(params=["plain", "scalars", "mapping"])
+ def result_fixture(self, request, connection):
+ users = self.tables.users
+
+ result_type = request.param
+
+ if result_type == "plain":
+ result = connection.execute(select(users))
+ elif result_type == "scalars":
+ result = connection.scalars(select(users))
+ elif result_type == "mapping":
+ result = connection.execute(select(users)).mappings()
+ else:
+ assert False
+
+ return result
+
+ def test_results_can_close(self, autoclose_row_fixture, result_fixture):
+ """test #8710"""
+
+ r1 = result_fixture
+
+ is_false(r1.closed)
+ is_false(r1._soft_closed)
+
+ r1._soft_close()
+ is_false(r1.closed)
+ is_true(r1._soft_closed)
+
+ r1.close()
+ is_true(r1.closed)
+ is_true(r1._soft_closed)
+
+ def test_autoclose_rows_exhausted_plain(
+ self, connection, autoclose_row_fixture, result_fixture
+ ):
+ result = result_fixture
+
+ assert not result._soft_closed
+ assert not result.closed
+
+ read_iterator = list(result)
+ eq_(len(read_iterator), 5)
+
+ assert result._soft_closed
+ assert not result.closed
+
+ result.close()
+ assert result.closed
+
+ def test_result_ctxmanager(
+ self, connection, autoclose_row_fixture, result_fixture
+ ):
+ """test #8710"""
+
+ result = result_fixture
+
+ with expect_raises_message(Exception, "hi"):
+ with result:
+ assert not result._soft_closed
+ assert not result.closed
+
+ for i, obj in enumerate(result):
+ if i > 2:
+ raise Exception("hi")
+
+ assert result._soft_closed
+ assert result.closed
+
class KeyTargetingTest(fixtures.TablesTest):
run_inserts = "once"
@@ -3113,6 +3197,47 @@ class AlternateCursorResultTest(fixtures.TablesTest):
# buffer of 98, plus buffer of 99 - 89, 10 rows
eq_(len(result.cursor_strategy._rowbuffer), 10)
+ for i, row in enumerate(result):
+ if i == 206:
+ break
+
+ eq_(i, 206)
+
+ def test_iterator_remains_unbroken(self, connection):
+ """test related to #8710.
+
+ demonstrate that we can't close the cursor by catching
+ GeneratorExit inside of our iteration. Leaving the iterable
+ block using break, then picking up again, would be directly
+ impacted by this. So this provides a clear rationale for
+ providing context manager support for result objects.
+
+ """
+ table = self.tables.test
+
+ connection.execute(
+ table.insert(),
+ [{"x": i, "y": "t_%d" % i} for i in range(15, 250)],
+ )
+
+ result = connection.execute(table.select())
+ result = result.yield_per(100)
+ for i, row in enumerate(result):
+ if i == 188:
+ # this will raise GeneratorExit inside the iterator.
+ # so we can't close the DBAPI cursor here, we have plenty
+ # more rows to yield
+ break
+
+ eq_(i, 188)
+
+ # demonstrate getting more rows
+ for i, row in enumerate(result, 188):
+ if i == 206:
+ break
+
+ eq_(i, 206)
+
@testing.combinations(True, False, argnames="close_on_init")
@testing.combinations(
"fetchone", "fetchmany", "fetchall", argnames="fetch_style"