diff options
Diffstat (limited to 'test/ext/asyncio/test_engine_py3k.py')
| -rw-r--r-- | test/ext/asyncio/test_engine_py3k.py | 340 |
1 files changed, 340 insertions, 0 deletions
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py new file mode 100644 index 000000000..ec513cb64 --- /dev/null +++ b/test/ext/asyncio/test_engine_py3k.py @@ -0,0 +1,340 @@ +from sqlalchemy import Column +from sqlalchemy import delete +from sqlalchemy import exc +from sqlalchemy import func +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import union_all +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import exc as asyncio_exc +from sqlalchemy.testing import async_test +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.asyncio import assert_raises_message_async + + +class EngineFixture(fixtures.TablesTest): + __requires__ = ("async_dialect",) + + @testing.fixture + def async_engine(self): + return create_async_engine(testing.db.url) + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column("user_id", Integer, primary_key=True, autoincrement=False), + Column("user_name", String(20)), + ) + + @classmethod + def insert_data(cls, connection): + users = cls.tables.users + with connection.begin(): + connection.execute( + users.insert(), + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(1, 20) + ], + ) + + +class AsyncEngineTest(EngineFixture): + __backend__ = True + + @async_test + async def test_connect_ctxmanager(self, async_engine): + async with async_engine.connect() as conn: + result = await conn.execute(select(1)) + eq_(result.scalar(), 1) + + @async_test + async def test_connect_plain(self, async_engine): + conn = await async_engine.connect() + try: + result = await conn.execute(select(1)) + eq_(result.scalar(), 1) + finally: + await conn.close() + + @async_test + async def test_connection_not_started(self, async_engine): + + conn = async_engine.connect() + testing.assert_raises_message( + asyncio_exc.AsyncContextNotStarted, + "AsyncConnection context has not been started and " + "object has not been awaited.", + conn.begin, + ) + + @async_test + async def test_transaction_commit(self, async_engine): + users = self.tables.users + + async with async_engine.begin() as conn: + await conn.execute(delete(users)) + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 0) + + @async_test + async def test_savepoint_rollback_noctx(self, async_engine): + users = self.tables.users + + async with async_engine.begin() as conn: + + savepoint = await conn.begin_nested() + await conn.execute(delete(users)) + await savepoint.rollback() + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 19) + + @async_test + async def test_savepoint_commit_noctx(self, async_engine): + users = self.tables.users + + async with async_engine.begin() as conn: + + savepoint = await conn.begin_nested() + await conn.execute(delete(users)) + await savepoint.commit() + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 0) + + @async_test + async def test_transaction_rollback(self, async_engine): + users = self.tables.users + + async with async_engine.connect() as conn: + trans = conn.begin() + await trans.start() + await conn.execute(delete(users)) + await trans.rollback() + + async with async_engine.connect() as conn: + eq_(await conn.scalar(select(func.count(users.c.user_id))), 19) + + @async_test + async def test_conn_transaction_not_started(self, async_engine): + + async with async_engine.connect() as conn: + trans = conn.begin() + await assert_raises_message_async( + asyncio_exc.AsyncContextNotStarted, + "AsyncTransaction context has not been started " + "and object has not been awaited.", + trans.rollback(), + ) + + +class AsyncResultTest(EngineFixture): + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_all(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars(1) + + all_ = await result.all() + if filter_ == "mappings": + eq_( + all_, + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(1, 20) + ], + ) + elif filter_ == "scalars": + eq_( + all_, ["name%d" % i for i in range(1, 20)], + ) + else: + eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) + + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_aiter(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars(1) + + rows = [] + + async for row in result: + rows.append(row) + + if filter_ == "mappings": + eq_( + rows, + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(1, 20) + ], + ) + elif filter_ == "scalars": + eq_( + rows, ["name%d" % i for i in range(1, 20)], + ) + else: + eq_(rows, [(i, "name%d" % i) for i in range(1, 20)]) + + @testing.combinations((None,), ("mappings",), argnames="filter_") + @async_test + async def test_keys(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + + eq_(result.keys(), ["user_id", "user_name"]) + + @async_test + async def test_unique_all(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + union_all(select(users), select(users)).order_by( + users.c.user_id + ) + ) + + all_ = await result.unique().all() + eq_(all_, [(i, "name%d" % i) for i in range(1, 20)]) + + @async_test + async def test_columns_all(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + all_ = await result.columns(1).all() + eq_(all_, [("name%d" % i,) for i in range(1, 20)]) + + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_partitions(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream(select(users)) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars(1) + + check_result = [] + async for partition in result.partitions(5): + check_result.append(partition) + + if filter_ == "mappings": + eq_( + check_result, + [ + [ + {"user_id": i, "user_name": "name%d" % i} + for i in range(a, b) + ] + for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)] + ], + ) + elif filter_ == "scalars": + eq_( + check_result, + [ + ["name%d" % i for i in range(a, b)] + for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)] + ], + ) + else: + eq_( + check_result, + [ + [(i, "name%d" % i) for i in range(a, b)] + for (a, b) in [(1, 6), (6, 11), (11, 16), (16, 20)] + ], + ) + + @testing.combinations( + (None,), ("scalars",), ("mappings",), argnames="filter_" + ) + @async_test + async def test_one_success(self, async_engine, filter_): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + select(users).limit(1).order_by(users.c.user_name) + ) + + if filter_ == "mappings": + result = result.mappings() + elif filter_ == "scalars": + result = result.scalars() + u1 = await result.one() + + if filter_ == "mappings": + eq_(u1, {"user_id": 1, "user_name": "name%d" % 1}) + elif filter_ == "scalars": + eq_(u1, 1) + else: + eq_(u1, (1, "name%d" % 1)) + + @async_test + async def test_one_no_result(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + select(users).where(users.c.user_name == "nonexistent") + ) + + async def go(): + await result.one() + + await assert_raises_message_async( + exc.NoResultFound, + "No row was found when one was required", + go(), + ) + + @async_test + async def test_one_multi_result(self, async_engine): + users = self.tables.users + async with async_engine.connect() as conn: + result = await conn.stream( + select(users).where(users.c.user_name.in_(["name3", "name5"])) + ) + + async def go(): + await result.one() + + await assert_raises_message_async( + exc.MultipleResultsFound, + "Multiple rows were found when exactly one was required", + go(), + ) |
