diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-03-07 12:53:00 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-03-07 16:24:18 -0500 |
| commit | c04870ba7b8098c7d408ad66f60efe7229496fde (patch) | |
| tree | 48ce6b3cbb8225f85499e54dc92e7d58684812f7 /lib/sqlalchemy/sql | |
| parent | 9e627159733da48e2fd2d25de93589eb079a75f4 (diff) | |
| download | sqlalchemy-c04870ba7b8098c7d408ad66f60efe7229496fde.tar.gz | |
Allow SchemaType and Variant to work together
Added support for the :class:`.Variant` and the :class:`.SchemaType`
objects to be compatible with each other. That is, a variant
can be created against a type like :class:`.Enum`, and the instructions
to create constraints and/or database-specific type objects will
propagate correctly as per the variant's dialect mapping.
Also added testing for some potential double-event scenarios
on TypeDecorator but it seems usually this doesn't occur.
Change-Id: I4a7e7c26b4133cd14e870f5bc34a1b2f0f19a14a
Fixes: #2892
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 68 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 22 |
2 files changed, 79 insertions, 11 deletions
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index bb39388ab..8a114ece6 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -15,7 +15,7 @@ import collections import json from . import elements -from .type_api import TypeEngine, TypeDecorator, to_instance +from .type_api import TypeEngine, TypeDecorator, to_instance, Variant from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \ Slice, _literal_as_binds from .. import exc, util, processors @@ -1003,6 +1003,14 @@ class SchemaType(SchemaEventTarget): def _set_parent(self, column): column._on_table_attach(util.portable_instancemethod(self._set_table)) + def _variant_mapping_for_set_table(self, column): + if isinstance(column.type, Variant): + variant_mapping = column.type.mapping.copy() + variant_mapping['_default'] = column.type.impl + else: + variant_mapping = None + return variant_mapping + def _set_table(self, column, table): if self.inherit_schema: self.schema = table.schema @@ -1010,16 +1018,21 @@ class SchemaType(SchemaEventTarget): if not self._create_events: return + variant_mapping = self._variant_mapping_for_set_table(column) + event.listen( table, "before_create", util.portable_instancemethod( - self._on_table_create) + self._on_table_create, + {"variant_mapping": variant_mapping}) ) event.listen( table, "after_drop", - util.portable_instancemethod(self._on_table_drop) + util.portable_instancemethod( + self._on_table_drop, + {"variant_mapping": variant_mapping}) ) if self.metadata is None: # TODO: what's the difference between self.metadata @@ -1027,12 +1040,16 @@ class SchemaType(SchemaEventTarget): event.listen( table.metadata, "before_create", - util.portable_instancemethod(self._on_metadata_create) + util.portable_instancemethod( + self._on_metadata_create, + {"variant_mapping": variant_mapping}) ) event.listen( table.metadata, "after_drop", - util.portable_instancemethod(self._on_metadata_drop) + util.portable_instancemethod( + self._on_metadata_drop, + {"variant_mapping": variant_mapping}) ) def copy(self, **kw): @@ -1073,25 +1090,48 @@ class SchemaType(SchemaEventTarget): t.drop(bind=bind, checkfirst=checkfirst) def _on_table_create(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_table_create(target, bind, **kw) def _on_table_drop(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_table_drop(target, bind, **kw) def _on_metadata_create(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_metadata_create(target, bind, **kw) def _on_metadata_drop(self, target, bind, **kw): + if not self._is_impl_for_variant(bind.dialect, kw): + return + t = self.dialect_impl(bind.dialect) if t.__class__ is not self.__class__ and isinstance(t, SchemaType): t._on_metadata_drop(target, bind, **kw) + def _is_impl_for_variant(self, dialect, kw): + variant_mapping = kw.pop('variant_mapping', None) + if variant_mapping is None: + return True + + if dialect.name in variant_mapping and \ + variant_mapping[dialect.name] is self: + return True + elif dialect.name not in variant_mapping: + return variant_mapping['_default'] is self + class Enum(String, SchemaType): @@ -1339,7 +1379,9 @@ class Enum(String, SchemaType): to_inspect=[Enum, SchemaType], ) - def _should_create_constraint(self, compiler): + def _should_create_constraint(self, compiler, **kw): + if not self._is_impl_for_variant(compiler.dialect, kw): + return False return not self.native_enum or \ not compiler.dialect.supports_native_enum @@ -1351,11 +1393,14 @@ class Enum(String, SchemaType): if not self.create_constraint: return + variant_mapping = self._variant_mapping_for_set_table(column) + e = schema.CheckConstraint( type_coerce(column, self).in_(self.enums), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint), + self._should_create_constraint, + {"variant_mapping": variant_mapping}), _type_bound=True ) assert e.table is table @@ -1534,7 +1579,9 @@ class Boolean(TypeEngine, SchemaType): self.name = name self._create_events = _create_events - def _should_create_constraint(self, compiler): + def _should_create_constraint(self, compiler, **kw): + if not self._is_impl_for_variant(compiler.dialect, kw): + return False return not compiler.dialect.supports_native_boolean @util.dependencies("sqlalchemy.sql.schema") @@ -1542,11 +1589,14 @@ class Boolean(TypeEngine, SchemaType): if not self.create_constraint: return + variant_mapping = self._variant_mapping_for_set_table(column) + e = schema.CheckConstraint( type_coerce(column, self).in_([0, 1]), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint), + self._should_create_constraint, + {"variant_mapping": variant_mapping}), _type_bound=True ) assert e.table is table diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 2b697480d..d537e49f0 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -858,7 +858,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): return self.impl._type_affinity def _set_parent(self, column): - """Support SchemaEentTarget""" + """Support SchemaEventTarget""" super(TypeDecorator, self)._set_parent(column) @@ -866,7 +866,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): self.impl._set_parent(column) def _set_parent_with_dispatch(self, parent): - """Support SchemaEentTarget""" + """Support SchemaEventTarget""" super(TypeDecorator, self)._set_parent_with_dispatch(parent) @@ -1222,6 +1222,24 @@ class Variant(TypeDecorator): else: return self.impl + def _set_parent(self, column): + """Support SchemaEventTarget""" + + if isinstance(self.impl, SchemaEventTarget): + self.impl._set_parent(column) + for impl in self.mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent(column) + + def _set_parent_with_dispatch(self, parent): + """Support SchemaEventTarget""" + + if isinstance(self.impl, SchemaEventTarget): + self.impl._set_parent_with_dispatch(parent) + for impl in self.mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent_with_dispatch(parent) + def with_variant(self, type_, dialect_name): """Return a new :class:`.Variant` which adds the given type + dialect name to the mapping, in addition to the |
