diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/asyncpg.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 117 |
1 files changed, 68 insertions, 49 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 7ef5e441c..4580421f6 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -122,6 +122,7 @@ from ... import pool from ... import processors from ... import util from ...sql import sqltypes +from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only @@ -369,74 +370,90 @@ class AsyncAdapt_asyncpg_cursor: ) async def _prepare_and_execute(self, operation, parameters): + adapt_connection = self._adapt_connection - if not self._adapt_connection._started: - await self._adapt_connection._start_transaction() - - if parameters is not None: - operation = operation % self._parameter_placeholders(parameters) - else: - parameters = () + async with adapt_connection._execute_mutex: - try: - prepared_stmt, attributes = await self._adapt_connection._prepare( - operation, self._invalidate_schema_cache_asof - ) + if not adapt_connection._started: + await adapt_connection._start_transaction() - if attributes: - self.description = [ - (attr.name, attr.type.oid, None, None, None, None, None) - for attr in attributes - ] + if parameters is not None: + operation = operation % self._parameter_placeholders( + parameters + ) else: - self.description = None + parameters = () - if self.server_side: - self._cursor = await prepared_stmt.cursor(*parameters) - self.rowcount = -1 - else: - self._rows = await prepared_stmt.fetch(*parameters) - status = prepared_stmt.get_statusmsg() + try: + prepared_stmt, attributes = await adapt_connection._prepare( + operation, self._invalidate_schema_cache_asof + ) - reg = re.match(r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status) - if reg: - self.rowcount = int(reg.group(1)) + if attributes: + self.description = [ + ( + attr.name, + attr.type.oid, + None, + None, + None, + None, + None, + ) + for attr in attributes + ] else: + self.description = None + + if self.server_side: + self._cursor = await prepared_stmt.cursor(*parameters) self.rowcount = -1 + else: + self._rows = await prepared_stmt.fetch(*parameters) + status = prepared_stmt.get_statusmsg() - except Exception as error: - self._handle_exception(error) + reg = re.match( + r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status + ) + if reg: + self.rowcount = int(reg.group(1)) + else: + self.rowcount = -1 - def execute(self, operation, parameters=None): - try: - self._adapt_connection.await_( - self._prepare_and_execute(operation, parameters) - ) - except Exception as error: - self._handle_exception(error) + except Exception as error: + self._handle_exception(error) - def executemany(self, operation, seq_of_parameters): + async def _executemany(self, operation, seq_of_parameters): adapt_connection = self._adapt_connection - adapt_connection.await_( - adapt_connection._check_type_cache_invalidation( + async with adapt_connection._execute_mutex: + await adapt_connection._check_type_cache_invalidation( self._invalidate_schema_cache_asof ) - ) - if not adapt_connection._started: - adapt_connection.await_(adapt_connection._start_transaction()) + if not adapt_connection._started: + await adapt_connection._start_transaction() - operation = operation % self._parameter_placeholders( - seq_of_parameters[0] + operation = operation % self._parameter_placeholders( + seq_of_parameters[0] + ) + + try: + return await self._connection.executemany( + operation, seq_of_parameters + ) + except Exception as error: + self._handle_exception(error) + + def execute(self, operation, parameters=None): + self._adapt_connection.await_( + self._prepare_and_execute(operation, parameters) ) - try: - return adapt_connection.await_( - self._connection.executemany(operation, seq_of_parameters) - ) - except Exception as error: - self._handle_exception(error) + def executemany(self, operation, seq_of_parameters): + return self._adapt_connection.await_( + self._executemany(operation, seq_of_parameters) + ) def setinputsizes(self, *inputsizes): self._inputsizes = inputsizes @@ -561,6 +578,7 @@ class AsyncAdapt_asyncpg_connection: "_started", "_prepared_statement_cache", "_invalidate_schema_cache_asof", + "_execute_mutex", ) await_ = staticmethod(await_only) @@ -574,6 +592,7 @@ class AsyncAdapt_asyncpg_connection: self._transaction = None self._started = False self._invalidate_schema_cache_asof = time.time() + self._execute_mutex = asyncio.Lock() if prepared_statement_cache_size: self._prepared_statement_cache = util.LRUCache( |