summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-01-05 08:48:36 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-01-05 21:56:18 -0500
commitc12012452e11093a06e3c83c19fe4d794f5bb21e (patch)
tree1681e269d80f0e74da7eeb27bd177f967318c601
parent640cd8a70f8a664b7834c5f74ec322fdea644043 (diff)
downloadsqlalchemy-c12012452e11093a06e3c83c19fe4d794f5bb21e.tar.gz
Remove special rule for TypeDecorator of TypeDecorator
Removing this check for "TypeDecorator" in impl seems to not break anything and allows TypeDecorator.with_variant() to work correctly. The line has been traced back to 2007 and does not appear to have relevance today. Fixed bug where making use of the :meth:`.TypeEngine.with_variant` method on a :class:`.TypeDecorator` type would fail to take into account the dialect-specific mappings in use, due to a rule in :class:`.TypeDecorator` that was instead attempting to check for chains of :class:`.TypeDecorator` instances. Fixes: #5816 Change-Id: Ic86d9d985810e3050f15972b4841108acca2fa3e
-rw-r--r--doc/build/changelog/unreleased_13/5816.rst10
-rw-r--r--lib/sqlalchemy/sql/type_api.py6
-rw-r--r--test/sql/test_types.py209
3 files changed, 218 insertions, 7 deletions
diff --git a/doc/build/changelog/unreleased_13/5816.rst b/doc/build/changelog/unreleased_13/5816.rst
new file mode 100644
index 000000000..5049622a8
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/5816.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 5816
+
+ Fixed bug where making use of the :meth:`.TypeEngine.with_variant` method
+ on a :class:`.TypeDecorator` type would fail to take into account the
+ dialect-specific mappings in use, due to a rule in :class:`.TypeDecorator`
+ that was instead attempting to check for chains of :class:`.TypeDecorator`
+ instances.
+
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 59e0f18dd..462a8763b 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -1082,8 +1082,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
In most cases this returns a dialect-adapted form of
the :class:`.TypeEngine` type represented by ``self.impl``.
- Makes usage of :meth:`dialect_impl` but also traverses
- into wrapped :class:`.TypeDecorator` instances.
+ Makes usage of :meth:`dialect_impl`.
Behavior can be customized here by overriding
:meth:`load_dialect_impl`.
@@ -1091,8 +1090,6 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
adapted = dialect.type_descriptor(self)
if not isinstance(adapted, type(self)):
return adapted
- elif isinstance(self.impl, TypeDecorator):
- return self.impl.type_engine(dialect)
else:
return self.load_dialect_impl(dialect)
@@ -1117,7 +1114,6 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
method.
"""
-
# some dialects have a lookup for a TypeDecorator subclass directly.
# postgresql.INTERVAL being the main example
typ = self.dialect_impl(dialect)
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 77aefc190..0e1147800 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -584,6 +584,9 @@ class _UserDefinedTypeFixture(object):
def copy(self):
return MyUnicodeType(self.impl.length)
+ class MyDecOfDec(types.TypeDecorator):
+ impl = MyNewIntType
+
Table(
"users",
metadata,
@@ -596,6 +599,7 @@ class _UserDefinedTypeFixture(object):
Column("goofy7", MyNewUnicodeType(50), nullable=False),
Column("goofy8", MyNewIntType, nullable=False),
Column("goofy9", MyNewIntSubClass, nullable=False),
+ Column("goofy10", MyDecOfDec, nullable=False),
)
@@ -614,6 +618,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
goofy7=util.u("jack"),
goofy8=12,
goofy9=12,
+ goofy10=12,
),
)
connection.execute(
@@ -626,6 +631,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
goofy7=util.u("lala"),
goofy8=15,
goofy9=15,
+ goofy10=15,
),
)
connection.execute(
@@ -638,6 +644,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
goofy7=util.u("fred"),
goofy8=9,
goofy9=9,
+ goofy10=9,
),
)
@@ -665,7 +672,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
for col in row[3], row[4]:
assert isinstance(col, util.text_type)
- def test_plain_in(self, connection):
+ def test_plain_in_typedec(self, connection):
users = self.tables.users
self._data_fixture(connection)
@@ -677,7 +684,19 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
result = connection.execute(stmt, {"goofy": [15, 9]})
eq_(result.fetchall(), [(3, 1500), (4, 900)])
- def test_expanding_in(self, connection):
+ def test_plain_in_typedec_of_typedec(self, connection):
+ users = self.tables.users
+ self._data_fixture(connection)
+
+ stmt = (
+ select(users.c.user_id, users.c.goofy10)
+ .where(users.c.goofy10.in_([15, 9]))
+ .order_by(users.c.user_id)
+ )
+ result = connection.execute(stmt, {"goofy": [15, 9]})
+ eq_(result.fetchall(), [(3, 1500), (4, 900)])
+
+ def test_expanding_in_typedec(self, connection):
users = self.tables.users
self._data_fixture(connection)
@@ -689,6 +708,18 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
result = connection.execute(stmt, {"goofy": [15, 9]})
eq_(result.fetchall(), [(3, 1500), (4, 900)])
+ def test_expanding_in_typedec_of_typedec(self, connection):
+ users = self.tables.users
+ self._data_fixture(connection)
+
+ stmt = (
+ select(users.c.user_id, users.c.goofy10)
+ .where(users.c.goofy10.in_(bindparam("goofy", expanding=True)))
+ .order_by(users.c.user_id)
+ )
+ result = connection.execute(stmt, {"goofy": [15, 9]})
+ eq_(result.fetchall(), [(3, 1500), (4, 900)])
+
class UserDefinedTest(
_UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL
@@ -1177,6 +1208,172 @@ class TypeCoerceCastTest(fixtures.TablesTest):
)
+class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL):
+ __backend__ = True
+
+ @testing.fixture
+ def variant_roundtrip(self, metadata, connection):
+ def run(datatype, data, assert_data):
+ t = Table(
+ "t",
+ metadata,
+ Column("data", datatype),
+ )
+ t.create(connection)
+
+ connection.execute(t.insert(), [{"data": elem} for elem in data])
+ eq_(
+ connection.execute(select(t).order_by(t.c.data)).all(),
+ [(elem,) for elem in assert_data],
+ )
+
+ eq_(
+ # test an IN, which in 1.4 is an expanding
+ connection.execute(
+ select(t).where(t.c.data.in_(data)).order_by(t.c.data)
+ ).all(),
+ [(elem,) for elem in assert_data],
+ )
+
+ return run
+
+ def test_type_decorator_variant_one_roundtrip(self, variant_roundtrip):
+ class Foo(TypeDecorator):
+ impl = String(50)
+
+ if testing.against("postgresql"):
+ data = [5, 6, 10]
+ else:
+ data = ["five", "six", "ten"]
+ variant_roundtrip(
+ Foo().with_variant(Integer, "postgresql"), data, data
+ )
+
+ def test_type_decorator_variant_two(self, variant_roundtrip):
+ class UTypeOne(types.UserDefinedType):
+ def get_col_spec(self):
+ return "VARCHAR(50)"
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value + "UONE"
+
+ return process
+
+ class UTypeTwo(types.UserDefinedType):
+ def get_col_spec(self):
+ return "VARCHAR(50)"
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value + "UTWO"
+
+ return process
+
+ variant = UTypeOne()
+ for db in ["postgresql", "mysql", "mariadb"]:
+ variant = variant.with_variant(UTypeTwo(), db)
+
+ class Foo(TypeDecorator):
+ impl = variant
+
+ if testing.against("postgresql"):
+ data = assert_data = [5, 6, 10]
+ elif testing.against("mysql") or testing.against("mariadb"):
+ data = ["five", "six", "ten"]
+ assert_data = ["fiveUTWO", "sixUTWO", "tenUTWO"]
+ else:
+ data = ["five", "six", "ten"]
+ assert_data = ["fiveUONE", "sixUONE", "tenUONE"]
+
+ variant_roundtrip(
+ Foo().with_variant(Integer, "postgresql"), data, assert_data
+ )
+
+ def test_type_decorator_variant_three(self, variant_roundtrip):
+ class Foo(TypeDecorator):
+ impl = String
+
+ if testing.against("postgresql"):
+ data = ["five", "six", "ten"]
+ else:
+ data = [5, 6, 10]
+
+ variant_roundtrip(
+ Integer().with_variant(Foo(), "postgresql"), data, data
+ )
+
+ def test_type_decorator_compile_variant_one(self):
+ class Foo(TypeDecorator):
+ impl = String
+
+ self.assert_compile(
+ Foo().with_variant(Integer, "sqlite"),
+ "INTEGER",
+ dialect=dialects.sqlite.dialect(),
+ )
+
+ self.assert_compile(
+ Foo().with_variant(Integer, "sqlite"),
+ "VARCHAR",
+ dialect=dialects.postgresql.dialect(),
+ )
+
+ def test_type_decorator_compile_variant_two(self):
+ class UTypeOne(types.UserDefinedType):
+ def get_col_spec(self):
+ return "UTYPEONE"
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value + "UONE"
+
+ return process
+
+ class UTypeTwo(types.UserDefinedType):
+ def get_col_spec(self):
+ return "UTYPETWO"
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value + "UTWO"
+
+ return process
+
+ variant = UTypeOne().with_variant(UTypeTwo(), "postgresql")
+
+ class Foo(TypeDecorator):
+ impl = variant
+
+ self.assert_compile(
+ Foo().with_variant(Integer, "sqlite"),
+ "INTEGER",
+ dialect=dialects.sqlite.dialect(),
+ )
+
+ self.assert_compile(
+ Foo().with_variant(Integer, "sqlite"),
+ "UTYPETWO",
+ dialect=dialects.postgresql.dialect(),
+ )
+
+ def test_type_decorator_compile_variant_three(self):
+ class Foo(TypeDecorator):
+ impl = String
+
+ self.assert_compile(
+ Integer().with_variant(Foo(), "postgresql"),
+ "INTEGER",
+ dialect=dialects.sqlite.dialect(),
+ )
+
+ self.assert_compile(
+ Integer().with_variant(Foo(), "postgresql"),
+ "VARCHAR",
+ dialect=dialects.postgresql.dialect(),
+ )
+
+
class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
def setup(self):
class UTypeOne(types.UserDefinedType):
@@ -2539,6 +2736,9 @@ class ExpressionTest(
def process_result_value(self, value, dialect):
return value + "BIND_OUT"
+ class MyDecOfDec(types.TypeDecorator):
+ impl = MyTypeDec
+
Table(
"test",
metadata,
@@ -2547,6 +2747,7 @@ class ExpressionTest(
Column("atimestamp", Date),
Column("avalue", MyCustomType),
Column("bvalue", MyTypeDec(50)),
+ Column("cvalue", MyDecOfDec(50)),
)
@classmethod
@@ -2560,6 +2761,7 @@ class ExpressionTest(
"atimestamp": datetime.date(2007, 10, 15),
"avalue": 25,
"bvalue": "foo",
+ "cvalue": "foo",
},
)
@@ -2579,6 +2781,7 @@ class ExpressionTest(
datetime.date(2007, 10, 15),
25,
"BIND_INfooBIND_OUT",
+ "BIND_INfooBIND_OUT",
)
],
)
@@ -2617,6 +2820,7 @@ class ExpressionTest(
datetime.date(2007, 10, 15),
25,
"BIND_INfooBIND_OUT",
+ "BIND_INfooBIND_OUT",
)
],
)
@@ -2635,6 +2839,7 @@ class ExpressionTest(
datetime.date(2007, 10, 15),
25,
"BIND_INfooBIND_OUT",
+ "BIND_INfooBIND_OUT",
)
],
)