summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2017-03-07 12:53:00 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2017-03-07 16:24:18 -0500
commitc04870ba7b8098c7d408ad66f60efe7229496fde (patch)
tree48ce6b3cbb8225f85499e54dc92e7d58684812f7 /lib/sqlalchemy/sql
parent9e627159733da48e2fd2d25de93589eb079a75f4 (diff)
downloadsqlalchemy-c04870ba7b8098c7d408ad66f60efe7229496fde.tar.gz
Allow SchemaType and Variant to work together
Added support for the :class:`.Variant` and the :class:`.SchemaType` objects to be compatible with each other. That is, a variant can be created against a type like :class:`.Enum`, and the instructions to create constraints and/or database-specific type objects will propagate correctly as per the variant's dialect mapping. Also added testing for some potential double-event scenarios on TypeDecorator but it seems usually this doesn't occur. Change-Id: I4a7e7c26b4133cd14e870f5bc34a1b2f0f19a14a Fixes: #2892
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py68
-rw-r--r--lib/sqlalchemy/sql/type_api.py22
2 files changed, 79 insertions, 11 deletions
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index bb39388ab..8a114ece6 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -15,7 +15,7 @@ import collections
import json
from . import elements
-from .type_api import TypeEngine, TypeDecorator, to_instance
+from .type_api import TypeEngine, TypeDecorator, to_instance, Variant
from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \
Slice, _literal_as_binds
from .. import exc, util, processors
@@ -1003,6 +1003,14 @@ class SchemaType(SchemaEventTarget):
def _set_parent(self, column):
column._on_table_attach(util.portable_instancemethod(self._set_table))
+ def _variant_mapping_for_set_table(self, column):
+ if isinstance(column.type, Variant):
+ variant_mapping = column.type.mapping.copy()
+ variant_mapping['_default'] = column.type.impl
+ else:
+ variant_mapping = None
+ return variant_mapping
+
def _set_table(self, column, table):
if self.inherit_schema:
self.schema = table.schema
@@ -1010,16 +1018,21 @@ class SchemaType(SchemaEventTarget):
if not self._create_events:
return
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
event.listen(
table,
"before_create",
util.portable_instancemethod(
- self._on_table_create)
+ self._on_table_create,
+ {"variant_mapping": variant_mapping})
)
event.listen(
table,
"after_drop",
- util.portable_instancemethod(self._on_table_drop)
+ util.portable_instancemethod(
+ self._on_table_drop,
+ {"variant_mapping": variant_mapping})
)
if self.metadata is None:
# TODO: what's the difference between self.metadata
@@ -1027,12 +1040,16 @@ class SchemaType(SchemaEventTarget):
event.listen(
table.metadata,
"before_create",
- util.portable_instancemethod(self._on_metadata_create)
+ util.portable_instancemethod(
+ self._on_metadata_create,
+ {"variant_mapping": variant_mapping})
)
event.listen(
table.metadata,
"after_drop",
- util.portable_instancemethod(self._on_metadata_drop)
+ util.portable_instancemethod(
+ self._on_metadata_drop,
+ {"variant_mapping": variant_mapping})
)
def copy(self, **kw):
@@ -1073,25 +1090,48 @@ class SchemaType(SchemaEventTarget):
t.drop(bind=bind, checkfirst=checkfirst)
def _on_table_create(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
t = self.dialect_impl(bind.dialect)
if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_table_create(target, bind, **kw)
def _on_table_drop(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
t = self.dialect_impl(bind.dialect)
if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_table_drop(target, bind, **kw)
def _on_metadata_create(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
t = self.dialect_impl(bind.dialect)
if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_metadata_create(target, bind, **kw)
def _on_metadata_drop(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
t = self.dialect_impl(bind.dialect)
if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
t._on_metadata_drop(target, bind, **kw)
+ def _is_impl_for_variant(self, dialect, kw):
+ variant_mapping = kw.pop('variant_mapping', None)
+ if variant_mapping is None:
+ return True
+
+ if dialect.name in variant_mapping and \
+ variant_mapping[dialect.name] is self:
+ return True
+ elif dialect.name not in variant_mapping:
+ return variant_mapping['_default'] is self
+
class Enum(String, SchemaType):
@@ -1339,7 +1379,9 @@ class Enum(String, SchemaType):
to_inspect=[Enum, SchemaType],
)
- def _should_create_constraint(self, compiler):
+ def _should_create_constraint(self, compiler, **kw):
+ if not self._is_impl_for_variant(compiler.dialect, kw):
+ return False
return not self.native_enum or \
not compiler.dialect.supports_native_enum
@@ -1351,11 +1393,14 @@ class Enum(String, SchemaType):
if not self.create_constraint:
return
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
e = schema.CheckConstraint(
type_coerce(column, self).in_(self.enums),
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
- self._should_create_constraint),
+ self._should_create_constraint,
+ {"variant_mapping": variant_mapping}),
_type_bound=True
)
assert e.table is table
@@ -1534,7 +1579,9 @@ class Boolean(TypeEngine, SchemaType):
self.name = name
self._create_events = _create_events
- def _should_create_constraint(self, compiler):
+ def _should_create_constraint(self, compiler, **kw):
+ if not self._is_impl_for_variant(compiler.dialect, kw):
+ return False
return not compiler.dialect.supports_native_boolean
@util.dependencies("sqlalchemy.sql.schema")
@@ -1542,11 +1589,14 @@ class Boolean(TypeEngine, SchemaType):
if not self.create_constraint:
return
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
e = schema.CheckConstraint(
type_coerce(column, self).in_([0, 1]),
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
- self._should_create_constraint),
+ self._should_create_constraint,
+ {"variant_mapping": variant_mapping}),
_type_bound=True
)
assert e.table is table
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 2b697480d..d537e49f0 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -858,7 +858,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
return self.impl._type_affinity
def _set_parent(self, column):
- """Support SchemaEentTarget"""
+ """Support SchemaEventTarget"""
super(TypeDecorator, self)._set_parent(column)
@@ -866,7 +866,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
self.impl._set_parent(column)
def _set_parent_with_dispatch(self, parent):
- """Support SchemaEentTarget"""
+ """Support SchemaEventTarget"""
super(TypeDecorator, self)._set_parent_with_dispatch(parent)
@@ -1222,6 +1222,24 @@ class Variant(TypeDecorator):
else:
return self.impl
+ def _set_parent(self, column):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent(column)
+ for impl in self.mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent(column)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent_with_dispatch(parent)
+ for impl in self.mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent_with_dispatch(parent)
+
def with_variant(self, type_, dialect_name):
"""Return a new :class:`.Variant` which adds the given
type + dialect name to the mapping, in addition to the