summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/asyncpg.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-08-04 10:27:59 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-08-05 10:39:39 -0400
commitfce1d954aa57feca9c163f9d8cf66df5e8ce7b65 (patch)
tree7412139205de0379b5e47e549b87c80bfe618da9 /lib/sqlalchemy/dialects/postgresql/asyncpg.py
parenteeff036db61377b8159757e6cc2a2d83d85bf69e (diff)
downloadsqlalchemy-fce1d954aa57feca9c163f9d8cf66df5e8ce7b65.tar.gz
implement PG ranges/multiranges agnostically
Ranges now work using a new Range object, multiranges as lists of Range objects (this is what asyncpg does. not sure why psycopg has a "Multirange" type). psycopg, psycopg2, and asyncpg are currently supported. It's not clear how to make ranges work with pg8000, likely needs string conversion; this is straightforward with the new archicture and can be added later. Fixes: #8178 Change-Id: Iab8d8382873d5c14199adbe3f09fd0dc17e2b9f1
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/asyncpg.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py95
1 files changed, 95 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index d6385a5d6..38f8fddee 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -119,14 +119,19 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
""" # noqa
+from __future__ import annotations
+
import collections
import collections.abc as collections_abc
import decimal
import json as _py_json
import re
import time
+from typing import cast
+from typing import TYPE_CHECKING
from . import json
+from . import ranges
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
@@ -148,6 +153,9 @@ from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
+if TYPE_CHECKING:
+ from typing import Iterable
+
class AsyncpgString(sqltypes.String):
render_bind_cast = True
@@ -278,6 +286,91 @@ class AsyncpgCHAR(sqltypes.CHAR):
render_bind_cast = True
+class _AsyncpgRange(ranges.AbstractRange):
+ def bind_processor(self, dialect):
+ Range = dialect.dbapi.asyncpg.Range
+
+ NoneType = type(None)
+
+ def to_range(value):
+ if not isinstance(value, (str, NoneType)):
+ value = Range(
+ value.lower,
+ value.upper,
+ lower_inc=value.bounds[0] == "[",
+ upper_inc=value.bounds[1] == "]",
+ empty=value.empty,
+ )
+ return value
+
+ return to_range
+
+ def result_processor(self, dialect, coltype):
+ def to_range(value):
+ if value is not None:
+ empty = value.isempty
+ value = ranges.Range(
+ value.lower,
+ value.upper,
+ bounds=f"{'[' if empty or value.lower_inc else '('}" # type: ignore # noqa: E501
+ f"{']' if not empty and value.upper_inc else ')'}",
+ empty=empty,
+ )
+ return value
+
+ return to_range
+
+
+class _AsyncpgMultiRange(ranges.AbstractMultiRange):
+ def bind_processor(self, dialect):
+ Range = dialect.dbapi.asyncpg.Range
+
+ NoneType = type(None)
+
+ def to_range(value):
+ if isinstance(value, (str, NoneType)):
+ return value
+
+ def to_range(value):
+ if not isinstance(value, (str, NoneType)):
+ value = Range(
+ value.lower,
+ value.upper,
+ lower_inc=value.bounds[0] == "[",
+ upper_inc=value.bounds[1] == "]",
+ empty=value.empty,
+ )
+ return value
+
+ return [
+ to_range(element)
+ for element in cast("Iterable[ranges.Range]", value)
+ ]
+
+ return to_range
+
+ def result_processor(self, dialect, coltype):
+ def to_range_array(value):
+ def to_range(rvalue):
+ if rvalue is not None:
+ empty = rvalue.isempty
+ rvalue = ranges.Range(
+ rvalue.lower,
+ rvalue.upper,
+ bounds=f"{'[' if empty or rvalue.lower_inc else '('}" # type: ignore # noqa: E501
+ f"{']' if not empty and rvalue.upper_inc else ')'}",
+ empty=empty,
+ )
+ return rvalue
+
+ if value is not None:
+ value = [to_range(elem) for elem in value]
+
+ return value
+
+ return to_range_array
+
+
class PGExecutionContext_asyncpg(PGExecutionContext):
def handle_dbapi_exception(self, e):
if isinstance(
@@ -828,6 +921,8 @@ class PGDialect_asyncpg(PGDialect):
OID: AsyncpgOID,
REGCLASS: AsyncpgREGCLASS,
sqltypes.CHAR: AsyncpgCHAR,
+ ranges.AbstractRange: _AsyncpgRange,
+ ranges.AbstractMultiRange: _AsyncpgMultiRange,
},
)
is_async = True