diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-25 09:10:09 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-11-03 18:42:52 -0400 |
| commit | b96321ae79a0366c33ca739e6e67aaf5f4420db4 (patch) | |
| tree | d56cb4cdf58e0b060f1ceb14f468eef21de0688b /test | |
| parent | 9bae9a931a460ff70172858ff90bcc1defae8e20 (diff) | |
| download | sqlalchemy-b96321ae79a0366c33ca739e6e67aaf5f4420db4.tar.gz | |
Support result.close() for all iterator patterns
This change contains new features for 2.0 only as well as some
behaviors that will be backported to 1.4.
For 1.4 and 2.0:
Fixed issue where the underlying DBAPI cursor would not be closed when
using :class:`_orm.Query` with :meth:`_orm.Query.yield_per` and direct
iteration, if a user-defined exception case were raised within the
iteration process, interrupting the iterator. This would lead to the usual
MySQL-related issues with server side cursors out of sync.
For 1.4 only:
A similar scenario can occur when using :term:`2.x` executions with direct
use of :class:`.Result`, in that case the end-user code has access to the
:class:`.Result` itself and should call :meth:`.Result.close` directly.
Version 2.0 will feature context-manager calling patterns to address this
use case. However within the 1.4 scope, ensured that ``.close()`` methods
are available on all :class:`.Result` implementations including
:class:`.ScalarResult`, :class:`.MappingResult`.
For 2.0 only:
To better support the use case of iterating :class:`.Result` and
:class:`.AsyncResult` objects where user-defined exceptions may interrupt
the iteration, both objects as well as variants such as
:class:`.ScalarResult`, :class:`.MappingResult`,
:class:`.AsyncScalarResult`, :class:`.AsyncMappingResult` now support
context manager usage, where the result will be closed at the end of
iteration.
Corrected various typing issues within the engine and async engine
packages.
Fixes: #8710
Change-Id: I3166328bfd3900957eb33cbf1061d0495c9df670
Diffstat (limited to 'test')
| -rw-r--r-- | test/base/test_result.py | 15 | ||||
| -rw-r--r-- | test/ext/asyncio/test_engine_py3k.py | 49 | ||||
| -rw-r--r-- | test/ext/mypy/plain_files/engines.py | 86 | ||||
| -rw-r--r-- | test/orm/test_loading.py | 19 | ||||
| -rw-r--r-- | test/orm/test_query.py | 52 | ||||
| -rw-r--r-- | test/sql/test_resultset.py | 125 |
6 files changed, 343 insertions, 3 deletions
diff --git a/test/base/test_result.py b/test/base/test_result.py index 90938263f..3e6444daa 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -253,6 +253,21 @@ class ResultTest(fixtures.TestBase): return res + def test_close_attributes(self): + """test #8710""" + r1 = self._fixture() + + is_false(r1.closed) + is_false(r1._soft_closed) + + r1._soft_close() + is_false(r1.closed) + is_true(r1._soft_closed) + + r1.close() + is_true(r1.closed) + is_true(r1._soft_closed) + def test_class_presented(self): """To support different kinds of objects returned vs. rows, there are two wrapper classes for Result. diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index cdf70ca67..2eebb433d 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -799,6 +799,42 @@ class AsyncResultTest(EngineFixture): ): await conn.exec_driver_sql("SELECT * FROM users") + @async_test + async def test_stream_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream(select(self.tables.users)) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for row in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + + @async_test + async def test_stream_scalars_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + conn = await conn.execution_options(stream_results=True) + + async with conn.stream_scalars( + select(self.tables.users) + ) as result: + assert not result._real_result._soft_closed + assert not result.closed + with expect_raises_message(Exception, "hi"): + i = 0 + async for scalar in result: + if i > 2: + raise Exception("hi") + i += 1 + assert result._real_result._soft_closed + assert result.closed + @testing.combinations( (None,), ("scalars",), ("mappings",), argnames="filter_" ) @@ -831,13 +867,20 @@ class AsyncResultTest(EngineFixture): eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) @testing.combinations( - (None,), ("scalars",), ("mappings",), argnames="filter_" + (None,), + ("scalars",), + ("stream_scalars",), + ("mappings",), + argnames="filter_", ) @async_test async def test_aiter(self, async_engine, filter_): users = self.tables.users async with async_engine.connect() as conn: - result = await conn.stream(select(users)) + if filter_ == "stream_scalars": + result = await conn.stream_scalars(select(users.c.user_name)) + else: + result = await conn.stream(select(users)) if filter_ == "mappings": result = result.mappings() @@ -857,7 +900,7 @@ class AsyncResultTest(EngineFixture): for i in range(1, 20) ], ) - elif filter_ == "scalars": + elif filter_ in ("scalars", "stream_scalars"): eq_( rows, ["name%d" % i for i in range(1, 20)], diff --git a/test/ext/mypy/plain_files/engines.py b/test/ext/mypy/plain_files/engines.py new file mode 100644 index 000000000..c920ad55d --- /dev/null +++ b/test/ext/mypy/plain_files/engines.py @@ -0,0 +1,86 @@ +from sqlalchemy import create_engine +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine + + +def regular() -> None: + + e = create_engine("sqlite://") + + # EXPECTED_TYPE: Engine + reveal_type(e) + + with e.connect() as conn: + + # EXPECTED_TYPE: Connection + reveal_type(conn) + + result = conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + with e.begin() as conn: + + # EXPECTED_TYPE: Connection + reveal_type(conn) + + result = conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + +async def asyncio() -> None: + e = create_async_engine("sqlite://") + + # EXPECTED_TYPE: AsyncEngine + reveal_type(e) + + async with e.connect() as conn: + + # EXPECTED_TYPE: AsyncConnection + reveal_type(conn) + + result = await conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + # stream with direct await + async_result = await conn.stream(text("select * from table")) + + # EXPECTED_TYPE: AsyncResult[Any] + reveal_type(async_result) + + # stream with context manager + async with conn.stream( + text("select * from table") + ) as ctx_async_result: + # EXPECTED_TYPE: AsyncResult[Any] + reveal_type(ctx_async_result) + + # stream_scalars with direct await + async_scalar_result = await conn.stream_scalars( + text("select * from table") + ) + + # EXPECTED_TYPE: AsyncScalarResult[Any] + reveal_type(async_scalar_result) + + # stream_scalars with context manager + async with conn.stream_scalars( + text("select * from table") + ) as ctx_async_scalar_result: + # EXPECTED_TYPE: AsyncScalarResult[Any] + reveal_type(ctx_async_scalar_result) + + async with e.begin() as conn: + + # EXPECTED_TYPE: AsyncConnection + reveal_type(conn) + + result = await conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) diff --git a/test/orm/test_loading.py b/test/orm/test_loading.py index cc3c3f494..d0b5c9d8f 100644 --- a/test/orm/test_loading.py +++ b/test/orm/test_loading.py @@ -6,6 +6,7 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.orm import loading from sqlalchemy.orm import relationship +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -152,6 +153,24 @@ class InstancesTest(_fixtures.FixtureTest): def setup_mappers(cls): cls._setup_stock_mapping() + def test_cursor_close_exception_raised_in_iteration(self): + """test #8710""" + + User = self.classes.User + s = fixture_session() + + stmt = select(User).execution_options(yield_per=1) + + result = s.execute(stmt) + raw_cursor = result.raw + + for row in result: + with expect_raises_message(Exception, "whoops"): + for row in result: + raise Exception("whoops") + + is_true(raw_cursor._soft_closed) + def test_cursor_close_w_failed_rowproc(self): User = self.classes.User s = fixture_session() diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 463009065..c05fdaf4f 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -5417,6 +5417,58 @@ class YieldTest(_fixtures.FixtureTest): result.close() assert_raises(sa.exc.ResourceClosedError, result.all) + def test_yield_per_close_on_interrupted_iteration_legacy(self): + """test #8710""" + + self._eagerload_mappings() + + User = self.classes.User + + asserted_result = None + + class _Query(Query): + def _iter(self): + nonlocal asserted_result + asserted_result = super(_Query, self)._iter() + return asserted_result + + sess = fixture_session(query_cls=_Query) + + with expect_raises_message(Exception, "hi"): + for i, row in enumerate(sess.query(User).yield_per(1)): + assert not asserted_result._soft_closed + assert not asserted_result.closed + + if i > 1: + raise Exception("hi") + + assert asserted_result._soft_closed + assert not asserted_result.closed + + def test_yield_per_close_on_interrupted_iteration(self): + """test #8710""" + + self._eagerload_mappings() + + User = self.classes.User + + sess = fixture_session() + + with expect_raises_message(Exception, "hi"): + result = sess.execute(select(User).execution_options(yield_per=1)) + for i, row in enumerate(result): + assert not result._soft_closed + assert not result.closed + + if i > 1: + raise Exception("hi") + + assert not result._soft_closed + assert not result.closed + result.close() + assert result._soft_closed + assert result.closed + def test_yield_per_and_execution_options_legacy(self): self._eagerload_mappings() diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 4f776e300..fa86d75ee 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -53,6 +53,7 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import le_ from sqlalchemy.testing import mock @@ -2033,6 +2034,89 @@ class CursorResultTest(fixtures.TablesTest): partition = next(result.partitions()) eq_(len(partition), value) + @testing.fixture + def autoclose_row_fixture(self, connection): + users = self.tables.users + connection.execute( + users.insert(), + [ + {"user_id": 1, "name": "u1"}, + {"user_id": 2, "name": "u2"}, + {"user_id": 3, "name": "u3"}, + {"user_id": 4, "name": "u4"}, + {"user_id": 5, "name": "u5"}, + ], + ) + + @testing.fixture(params=["plain", "scalars", "mapping"]) + def result_fixture(self, request, connection): + users = self.tables.users + + result_type = request.param + + if result_type == "plain": + result = connection.execute(select(users)) + elif result_type == "scalars": + result = connection.scalars(select(users)) + elif result_type == "mapping": + result = connection.execute(select(users)).mappings() + else: + assert False + + return result + + def test_results_can_close(self, autoclose_row_fixture, result_fixture): + """test #8710""" + + r1 = result_fixture + + is_false(r1.closed) + is_false(r1._soft_closed) + + r1._soft_close() + is_false(r1.closed) + is_true(r1._soft_closed) + + r1.close() + is_true(r1.closed) + is_true(r1._soft_closed) + + def test_autoclose_rows_exhausted_plain( + self, connection, autoclose_row_fixture, result_fixture + ): + result = result_fixture + + assert not result._soft_closed + assert not result.closed + + read_iterator = list(result) + eq_(len(read_iterator), 5) + + assert result._soft_closed + assert not result.closed + + result.close() + assert result.closed + + def test_result_ctxmanager( + self, connection, autoclose_row_fixture, result_fixture + ): + """test #8710""" + + result = result_fixture + + with expect_raises_message(Exception, "hi"): + with result: + assert not result._soft_closed + assert not result.closed + + for i, obj in enumerate(result): + if i > 2: + raise Exception("hi") + + assert result._soft_closed + assert result.closed + class KeyTargetingTest(fixtures.TablesTest): run_inserts = "once" @@ -3113,6 +3197,47 @@ class AlternateCursorResultTest(fixtures.TablesTest): # buffer of 98, plus buffer of 99 - 89, 10 rows eq_(len(result.cursor_strategy._rowbuffer), 10) + for i, row in enumerate(result): + if i == 206: + break + + eq_(i, 206) + + def test_iterator_remains_unbroken(self, connection): + """test related to #8710. + + demonstrate that we can't close the cursor by catching + GeneratorExit inside of our iteration. Leaving the iterable + block using break, then picking up again, would be directly + impacted by this. So this provides a clear rationale for + providing context manager support for result objects. + + """ + table = self.tables.test + + connection.execute( + table.insert(), + [{"x": i, "y": "t_%d" % i} for i in range(15, 250)], + ) + + result = connection.execute(table.select()) + result = result.yield_per(100) + for i, row in enumerate(result): + if i == 188: + # this will raise GeneratorExit inside the iterator. + # so we can't close the DBAPI cursor here, we have plenty + # more rows to yield + break + + eq_(i, 188) + + # demonstrate getting more rows + for i, row in enumerate(result, 188): + if i == 206: + break + + eq_(i, 206) + @testing.combinations(True, False, argnames="close_on_init") @testing.combinations( "fetchone", "fetchmany", "fetchall", argnames="fetch_style" |
