diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-08-04 10:27:59 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-08-05 10:39:39 -0400 |
commit | fce1d954aa57feca9c163f9d8cf66df5e8ce7b65 (patch) | |
tree | 7412139205de0379b5e47e549b87c80bfe618da9 /lib/sqlalchemy/dialects/postgresql/asyncpg.py | |
parent | eeff036db61377b8159757e6cc2a2d83d85bf69e (diff) | |
download | sqlalchemy-fce1d954aa57feca9c163f9d8cf66df5e8ce7b65.tar.gz |
implement PG ranges/multiranges agnostically
Ranges now work using a new Range object,
multiranges as lists of Range objects (this is what
asyncpg does. not sure why psycopg has a "Multirange"
type).
psycopg, psycopg2, and asyncpg are currently supported.
It's not clear how to make ranges work with pg8000, likely
needs string conversion; this is straightforward with the
new archicture and can be added later.
Fixes: #8178
Change-Id: Iab8d8382873d5c14199adbe3f09fd0dc17e2b9f1
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/asyncpg.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index d6385a5d6..38f8fddee 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -119,14 +119,19 @@ client using this setting passed to :func:`_asyncio.create_async_engine`:: """ # noqa +from __future__ import annotations + import collections import collections.abc as collections_abc import decimal import json as _py_json import re import time +from typing import cast +from typing import TYPE_CHECKING from . import json +from . import ranges from .base import _DECIMAL_TYPES from .base import _FLOAT_TYPES from .base import _INT_TYPES @@ -148,6 +153,9 @@ from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only +if TYPE_CHECKING: + from typing import Iterable + class AsyncpgString(sqltypes.String): render_bind_cast = True @@ -278,6 +286,91 @@ class AsyncpgCHAR(sqltypes.CHAR): render_bind_cast = True +class _AsyncpgRange(ranges.AbstractRange): + def bind_processor(self, dialect): + Range = dialect.dbapi.asyncpg.Range + + NoneType = type(None) + + def to_range(value): + if not isinstance(value, (str, NoneType)): + value = Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + empty = value.isempty + value = ranges.Range( + value.lower, + value.upper, + bounds=f"{'[' if empty or value.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and value.upper_inc else ')'}", + empty=empty, + ) + return value + + return to_range + + +class _AsyncpgMultiRange(ranges.AbstractMultiRange): + def bind_processor(self, dialect): + Range = dialect.dbapi.asyncpg.Range + + NoneType = type(None) + + def to_range(value): + if isinstance(value, (str, NoneType)): + return value + + def to_range(value): + if not isinstance(value, (str, NoneType)): + value = Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return [ + to_range(element) + for element in cast("Iterable[ranges.Range]", value) + ] + + return to_range + + def result_processor(self, dialect, coltype): + def to_range_array(value): + def to_range(rvalue): + if rvalue is not None: + empty = rvalue.isempty + rvalue = ranges.Range( + rvalue.lower, + rvalue.upper, + bounds=f"{'[' if empty or rvalue.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and rvalue.upper_inc else ')'}", + empty=empty, + ) + return rvalue + + if value is not None: + value = [to_range(elem) for elem in value] + + return value + + return to_range_array + + class PGExecutionContext_asyncpg(PGExecutionContext): def handle_dbapi_exception(self, e): if isinstance( @@ -828,6 +921,8 @@ class PGDialect_asyncpg(PGDialect): OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, sqltypes.CHAR: AsyncpgCHAR, + ranges.AbstractRange: _AsyncpgRange, + ranges.AbstractMultiRange: _AsyncpgMultiRange, }, ) is_async = True |