summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-01-05 12:20:46 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-01-11 09:25:19 -0500
commite215db01d48c418e190936e6b36ea49c6eb22072 (patch)
treefc0f3144fd7404128aae44f51ea4dc79619ef4d8 /lib/sqlalchemy
parentf96e24013c80d933cb8171061be3d316215fe585 (diff)
downloadsqlalchemy-e215db01d48c418e190936e6b36ea49c6eb22072.tar.gz
implement second-level type resolution for literals
Added additional rule to the system that determines ``TypeEngine`` implementations from Python literals to apply a second level of adjustment to the type, so that a Python datetime with or without tzinfo can set the ``timezone=True`` parameter on the returned :class:`.DateTime` object, as well as :class:`.Time`. This helps with some round-trip scenarios on type-sensitive PostgreSQL dialects such as asyncpg, psycopg3 (2.0 only). Improved support for asyncpg handling of TIME WITH TIMEZONE, which was not fully implemented. Fixes: #7537 Change-Id: Icdb07db85af5f7f39f1c1ef855fe27609770094b (cherry picked from commit 3b2e28bcb5ba32446a92b62b6862b7c11dabb592)
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py7
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py18
-rw-r--r--lib/sqlalchemy/sql/type_api.py11
-rw-r--r--lib/sqlalchemy/testing/requirements.py33
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py29
5 files changed, 96 insertions, 2 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index fedc0b495..f32192b3c 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -136,7 +136,10 @@ except ImportError:
class AsyncpgTime(sqltypes.Time):
def get_dbapi_type(self, dbapi):
- return dbapi.TIME
+ if self.timezone:
+ return dbapi.TIME_W_TZ
+ else:
+ return dbapi.TIME
class AsyncpgDate(sqltypes.Date):
@@ -818,6 +821,7 @@ class AsyncAdapt_asyncpg_dbapi:
TIMESTAMP = util.symbol("TIMESTAMP")
TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
TIME = util.symbol("TIME")
+ TIME_W_TZ = util.symbol("TIME_W_TZ")
DATE = util.symbol("DATE")
INTERVAL = util.symbol("INTERVAL")
NUMBER = util.symbol("NUMBER")
@@ -843,6 +847,7 @@ _pg_types = {
AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
AsyncAdapt_asyncpg_dbapi.DATE: "date",
AsyncAdapt_asyncpg_dbapi.TIME: "time",
+ AsyncAdapt_asyncpg_dbapi.TIME_W_TZ: "time with time zone",
AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 3f3801ab0..c80b10fcc 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -867,6 +867,13 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
+ def _resolve_for_literal(self, value):
+ with_timezone = value.tzinfo is not None
+ if with_timezone and not self.timezone:
+ return DATETIME_TIMEZONE
+ else:
+ return self
+
@property
def python_type(self):
return dt.datetime
@@ -937,6 +944,13 @@ class Time(_LookupExpressionAdapter, TypeEngine):
def python_type(self):
return dt.time
+ def _resolve_for_literal(self, value):
+ with_timezone = value.tzinfo is not None
+ if with_timezone and not self.timezone:
+ return TIME_TIMEZONE
+ else:
+ return self
+
@util.memoized_property
def _expression_adaptations(self):
# Based on https://www.postgresql.org/docs/current/\
@@ -3254,6 +3268,8 @@ STRINGTYPE = String()
INTEGERTYPE = Integer()
MATCHTYPE = MatchType()
TABLEVALUE = TableValueType()
+DATETIME_TIMEZONE = DateTime(timezone=True)
+TIME_TIMEZONE = Time(timezone=True)
_type_map = {
int: Integer(),
@@ -3296,7 +3312,7 @@ def _resolve_value_to_type(value):
)
return NULLTYPE
else:
- return _result_type
+ return _result_type._resolve_for_literal(value)
# back-assign to type_api
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 49f6cfe20..ecf68e62d 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -545,6 +545,17 @@ class TypeEngine(Traversible):
"""
return Variant(self, {dialect_name: to_instance(type_)})
+ def _resolve_for_literal(self, value):
+ """adjust this type given a literal Python value that will be
+ stored in a bound parameter.
+
+ Used exclusively by _resolve_value_to_type().
+
+ .. versionadded:: 1.4.30 or 2.0
+
+ """
+ return self
+
@util.memoized_property
def _type_affinity(self):
"""Return a rudimental 'affinity' value expressing the general class
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index a0f262a76..1c8858ec1 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -754,6 +754,29 @@ class SuiteRequirements(Requirements):
return exclusions.open()
@property
+ def datetime_timezone(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with tzinfo with DateTime(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def time_timezone(self):
+ """target dialect supports representation of Python
+ datetime.time() with tzinfo with Time(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def datetime_implicit_bound(self):
+ """target dialect when given a datetime object will bind it such
+ that the database server knows the object is a datetime, and not
+ a plain string.
+
+ """
+ return exclusions.open()
+
+ @property
def datetime_microseconds(self):
"""target dialect supports representation of Python
datetime.datetime() with microsecond objects."""
@@ -768,6 +791,16 @@ class SuiteRequirements(Requirements):
return exclusions.closed()
@property
+ def timestamp_microseconds_implicit_bound(self):
+ """target dialect when given a datetime object which also includes
+ a microseconds portion when using the TIMESTAMP data type
+ will bind it such that the database server knows
+ the object is a datetime with microseconds, and not a plain string.
+
+ """
+ return self.timestamp_microseconds
+
+ @property
def datetime_historic(self):
"""target dialect supports representation of Python
datetime.datetime() objects with historic (pre 1970) values."""
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index d62b60809..2fdea5e48 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -41,6 +41,7 @@ from ... import UnicodeText
from ... import util
from ...orm import declarative_base
from ...orm import Session
+from ...util import compat
from ...util import u
@@ -308,6 +309,11 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
Column("decorated_date_data", Decorated),
)
+ @testing.requires.datetime_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
def test_round_trip(self, connection):
date_table = self.tables.date_table
@@ -382,6 +388,15 @@ class DateTimeTest(_DateFixture, fixtures.TablesTest):
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+class DateTimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_timezone",)
+ __backend__ = True
+ datatype = DateTime(timezone=True)
+ data = datetime.datetime(
+ 2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc
+ )
+
+
class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
__requires__ = ("datetime_microseconds",)
__backend__ = True
@@ -395,6 +410,11 @@ class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
datatype = TIMESTAMP
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+ @testing.requires.timestamp_microseconds_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
class TimeTest(_DateFixture, fixtures.TablesTest):
__requires__ = ("time",)
@@ -403,6 +423,13 @@ class TimeTest(_DateFixture, fixtures.TablesTest):
data = datetime.time(12, 57, 18)
+class TimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time_timezone",)
+ __backend__ = True
+ datatype = Time(timezone=True)
+ data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc)
+
+
class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
__requires__ = ("time_microseconds",)
__backend__ = True
@@ -1424,6 +1451,7 @@ __all__ = (
"JSONLegacyStringCastIndexTest",
"DateTest",
"DateTimeTest",
+ "DateTimeTZTest",
"TextTest",
"NumericTest",
"IntegerTest",
@@ -1433,6 +1461,7 @@ __all__ = (
"TimeMicrosecondsTest",
"TimestampMicrosecondsTest",
"TimeTest",
+ "TimeTZTest",
"DateTimeMicrosecondsTest",
"DateHistoricTest",
"StringTest",