summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/asyncpg.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/asyncpg.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py117
1 files changed, 68 insertions, 49 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 7ef5e441c..4580421f6 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -122,6 +122,7 @@ from ... import pool
from ... import processors
from ... import util
from ...sql import sqltypes
+from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
@@ -369,74 +370,90 @@ class AsyncAdapt_asyncpg_cursor:
)
async def _prepare_and_execute(self, operation, parameters):
+ adapt_connection = self._adapt_connection
- if not self._adapt_connection._started:
- await self._adapt_connection._start_transaction()
-
- if parameters is not None:
- operation = operation % self._parameter_placeholders(parameters)
- else:
- parameters = ()
+ async with adapt_connection._execute_mutex:
- try:
- prepared_stmt, attributes = await self._adapt_connection._prepare(
- operation, self._invalidate_schema_cache_asof
- )
+ if not adapt_connection._started:
+ await adapt_connection._start_transaction()
- if attributes:
- self.description = [
- (attr.name, attr.type.oid, None, None, None, None, None)
- for attr in attributes
- ]
+ if parameters is not None:
+ operation = operation % self._parameter_placeholders(
+ parameters
+ )
else:
- self.description = None
+ parameters = ()
- if self.server_side:
- self._cursor = await prepared_stmt.cursor(*parameters)
- self.rowcount = -1
- else:
- self._rows = await prepared_stmt.fetch(*parameters)
- status = prepared_stmt.get_statusmsg()
+ try:
+ prepared_stmt, attributes = await adapt_connection._prepare(
+ operation, self._invalidate_schema_cache_asof
+ )
- reg = re.match(r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status)
- if reg:
- self.rowcount = int(reg.group(1))
+ if attributes:
+ self.description = [
+ (
+ attr.name,
+ attr.type.oid,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+ for attr in attributes
+ ]
else:
+ self.description = None
+
+ if self.server_side:
+ self._cursor = await prepared_stmt.cursor(*parameters)
self.rowcount = -1
+ else:
+ self._rows = await prepared_stmt.fetch(*parameters)
+ status = prepared_stmt.get_statusmsg()
- except Exception as error:
- self._handle_exception(error)
+ reg = re.match(
+ r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status
+ )
+ if reg:
+ self.rowcount = int(reg.group(1))
+ else:
+ self.rowcount = -1
- def execute(self, operation, parameters=None):
- try:
- self._adapt_connection.await_(
- self._prepare_and_execute(operation, parameters)
- )
- except Exception as error:
- self._handle_exception(error)
+ except Exception as error:
+ self._handle_exception(error)
- def executemany(self, operation, seq_of_parameters):
+ async def _executemany(self, operation, seq_of_parameters):
adapt_connection = self._adapt_connection
- adapt_connection.await_(
- adapt_connection._check_type_cache_invalidation(
+ async with adapt_connection._execute_mutex:
+ await adapt_connection._check_type_cache_invalidation(
self._invalidate_schema_cache_asof
)
- )
- if not adapt_connection._started:
- adapt_connection.await_(adapt_connection._start_transaction())
+ if not adapt_connection._started:
+ await adapt_connection._start_transaction()
- operation = operation % self._parameter_placeholders(
- seq_of_parameters[0]
+ operation = operation % self._parameter_placeholders(
+ seq_of_parameters[0]
+ )
+
+ try:
+ return await self._connection.executemany(
+ operation, seq_of_parameters
+ )
+ except Exception as error:
+ self._handle_exception(error)
+
+ def execute(self, operation, parameters=None):
+ self._adapt_connection.await_(
+ self._prepare_and_execute(operation, parameters)
)
- try:
- return adapt_connection.await_(
- self._connection.executemany(operation, seq_of_parameters)
- )
- except Exception as error:
- self._handle_exception(error)
+ def executemany(self, operation, seq_of_parameters):
+ return self._adapt_connection.await_(
+ self._executemany(operation, seq_of_parameters)
+ )
def setinputsizes(self, *inputsizes):
self._inputsizes = inputsizes
@@ -561,6 +578,7 @@ class AsyncAdapt_asyncpg_connection:
"_started",
"_prepared_statement_cache",
"_invalidate_schema_cache_asof",
+ "_execute_mutex",
)
await_ = staticmethod(await_only)
@@ -574,6 +592,7 @@ class AsyncAdapt_asyncpg_connection:
self._transaction = None
self._started = False
self._invalidate_schema_cache_asof = time.time()
+ self._execute_mutex = asyncio.Lock()
if prepared_statement_cache_size:
self._prepared_statement_cache = util.LRUCache(