diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-02-22 21:49:09 -0500 |
---|---|---|
committer | mike bayer <mike_mp@zzzcomputing.com> | 2021-02-25 20:49:52 +0000 |
commit | a5b76f7b07e620ece882137ceb51bace3898fad5 (patch) | |
tree | 03aac63d2e9af851a317503e577e6bea2ff35793 /lib/sqlalchemy/dialects/postgresql/asyncpg.py | |
parent | dc615763d39916e9c037c7c376db1817cdf02764 (diff) | |
download | sqlalchemy-a5b76f7b07e620ece882137ceb51bace3898fad5.tar.gz |
mutex asyncpg / aiomysql connection state changes
Added an ``asyncio.Lock()`` within SQLAlchemy's emulated DBAPI cursor,
local to the connection, for the asyncpg dialect, so that the space between
the call to ``prepare()`` and ``fetch()`` is prevented from allowing
concurrent executions on the connection from causing interface error
exceptions, as well as preventing race conditions when starting a new
transaction. Other PostgreSQL DBAPIs are threadsafe at the connection level
so this intends to provide a similar behavior, outside the realm of server
side cursors.
Apply the same idea to the aiomysql dialect which also would
otherwise be subject to corruption if the connection were used
concurrently.
While this is an issue which can also occur with the threaded
connection libraries, we anticipate asyncio users are more likely
to attempt using the same connection in multiple awaitables
at a time, even though this won't achieve concurrency for that
use case, as the asyncio programming style is very encouraging
of this. As the failure modes are also more complicated under
asyncio, we'd rather not have this being reported.
Fixes: #5967
Change-Id: I3670ba0c8f0b593c587c5aa7c6c61f9e8c5eb93a
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/asyncpg.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 117 |
1 files changed, 68 insertions, 49 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 7ef5e441c..4580421f6 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -122,6 +122,7 @@ from ... import pool from ... import processors from ... import util from ...sql import sqltypes +from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only @@ -369,74 +370,90 @@ class AsyncAdapt_asyncpg_cursor: ) async def _prepare_and_execute(self, operation, parameters): + adapt_connection = self._adapt_connection - if not self._adapt_connection._started: - await self._adapt_connection._start_transaction() - - if parameters is not None: - operation = operation % self._parameter_placeholders(parameters) - else: - parameters = () + async with adapt_connection._execute_mutex: - try: - prepared_stmt, attributes = await self._adapt_connection._prepare( - operation, self._invalidate_schema_cache_asof - ) + if not adapt_connection._started: + await adapt_connection._start_transaction() - if attributes: - self.description = [ - (attr.name, attr.type.oid, None, None, None, None, None) - for attr in attributes - ] + if parameters is not None: + operation = operation % self._parameter_placeholders( + parameters + ) else: - self.description = None + parameters = () - if self.server_side: - self._cursor = await prepared_stmt.cursor(*parameters) - self.rowcount = -1 - else: - self._rows = await prepared_stmt.fetch(*parameters) - status = prepared_stmt.get_statusmsg() + try: + prepared_stmt, attributes = await adapt_connection._prepare( + operation, self._invalidate_schema_cache_asof + ) - reg = re.match(r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status) - if reg: - self.rowcount = int(reg.group(1)) + if attributes: + self.description = [ + ( + attr.name, + attr.type.oid, + None, + None, + None, + None, + None, + ) + for attr in attributes + ] else: + self.description = None + + if self.server_side: + self._cursor = await prepared_stmt.cursor(*parameters) self.rowcount = -1 + else: + self._rows = await prepared_stmt.fetch(*parameters) + status = prepared_stmt.get_statusmsg() - except Exception as error: - self._handle_exception(error) + reg = re.match( + r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status + ) + if reg: + self.rowcount = int(reg.group(1)) + else: + self.rowcount = -1 - def execute(self, operation, parameters=None): - try: - self._adapt_connection.await_( - self._prepare_and_execute(operation, parameters) - ) - except Exception as error: - self._handle_exception(error) + except Exception as error: + self._handle_exception(error) - def executemany(self, operation, seq_of_parameters): + async def _executemany(self, operation, seq_of_parameters): adapt_connection = self._adapt_connection - adapt_connection.await_( - adapt_connection._check_type_cache_invalidation( + async with adapt_connection._execute_mutex: + await adapt_connection._check_type_cache_invalidation( self._invalidate_schema_cache_asof ) - ) - if not adapt_connection._started: - adapt_connection.await_(adapt_connection._start_transaction()) + if not adapt_connection._started: + await adapt_connection._start_transaction() - operation = operation % self._parameter_placeholders( - seq_of_parameters[0] + operation = operation % self._parameter_placeholders( + seq_of_parameters[0] + ) + + try: + return await self._connection.executemany( + operation, seq_of_parameters + ) + except Exception as error: + self._handle_exception(error) + + def execute(self, operation, parameters=None): + self._adapt_connection.await_( + self._prepare_and_execute(operation, parameters) ) - try: - return adapt_connection.await_( - self._connection.executemany(operation, seq_of_parameters) - ) - except Exception as error: - self._handle_exception(error) + def executemany(self, operation, seq_of_parameters): + return self._adapt_connection.await_( + self._executemany(operation, seq_of_parameters) + ) def setinputsizes(self, *inputsizes): self._inputsizes = inputsizes @@ -561,6 +578,7 @@ class AsyncAdapt_asyncpg_connection: "_started", "_prepared_statement_cache", "_invalidate_schema_cache_asof", + "_execute_mutex", ) await_ = staticmethod(await_only) @@ -574,6 +592,7 @@ class AsyncAdapt_asyncpg_connection: self._transaction = None self._started = False self._invalidate_schema_cache_asof = time.time() + self._execute_mutex = asyncio.Lock() if prepared_statement_cache_size: self._prepared_statement_cache = util.LRUCache( |