diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 206 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 71 |
3 files changed, 221 insertions, 61 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5ba52ae51..aa98ff256 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -4889,10 +4889,7 @@ class DDLCompiler(Compiled): for p in ( self.process(constraint) for constraint in constraints - if ( - constraint._create_rule is None - or constraint._create_rule(self) - ) + if (constraint._should_create_for_compiler(self)) and ( not self.dialect.supports_alter or not getattr(constraint, "use_alter", False) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 7acb69beb..4d57ad869 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -12,10 +12,11 @@ to invoke them for a create/drop call. from __future__ import annotations import typing +from typing import Any from typing import Callable from typing import List from typing import Optional -from typing import Sequence +from typing import Sequence as typing_Sequence from typing import Tuple from . import roles @@ -26,11 +27,20 @@ from .elements import ClauseElement from .. import exc from .. import util from ..util import topological - +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from .compiler import Compiled + from .compiler import DDLCompiler + from .elements import BindParameter from .schema import ForeignKeyConstraint + from .schema import SchemaItem from .schema import Table + from ..engine.base import _CompiledCacheType + from ..engine.base import Connection + from ..engine.interfaces import _SchemaTranslateMapType + from ..engine.interfaces import CacheStats + from ..engine.interfaces import Dialect class _DDLCompiles(ClauseElement): @@ -43,10 +53,70 @@ class _DDLCompiles(ClauseElement): return dialect.ddl_compiler(dialect, self, **kw) - def _compile_w_cache(self, *arg, **kw): + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[_CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[ + Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats + ]: raise NotImplementedError() +class DDLIfCallable(Protocol): + def __call__( + self, + ddl: "DDLElement", + target: "SchemaItem", + bind: Optional["Connection"], + tables: Optional[List["Table"]] = None, + state: Optional[Any] = None, + *, + dialect: Dialect, + compiler: Optional[DDLCompiler] = ..., + checkfirst: bool, + ) -> bool: + ... + + +class DDLIf(typing.NamedTuple): + dialect: Optional[str] + callable_: Optional[DDLIfCallable] + state: Optional[Any] + + def _should_execute(self, ddl, target, bind, compiler=None, **kw): + if bind is not None: + dialect = bind.dialect + elif compiler is not None: + dialect = compiler.dialect + else: + assert False, "compiler or dialect is required" + + if isinstance(self.dialect, str): + if self.dialect != dialect.name: + return False + elif isinstance(self.dialect, (tuple, list, set)): + if dialect.name not in self.dialect: + return False + if self.callable_ is not None and not self.callable_( + ddl, + target, + bind, + state=self.state, + dialect=dialect, + compiler=compiler, + **kw, + ): + return False + + return True + + SelfDDLElement = typing.TypeVar("SelfDDLElement", bound="DDLElement") @@ -80,10 +150,8 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): """ - target = None - on = None - dialect = None - callable_ = None + _ddl_if: Optional[DDLIf] = None + target: Optional["SchemaItem"] = None def _execute_on_connection( self, connection, distilled_params, execution_options @@ -93,7 +161,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): ) @_generative - def against(self: SelfDDLElement, target) -> SelfDDLElement: + def against(self: SelfDDLElement, target: SchemaItem) -> SelfDDLElement: """Return a copy of this :class:`_schema.DDLElement` which will include the given target. @@ -125,13 +193,15 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): processing the DDL string. """ - self.target = target return self @_generative def execute_if( - self: SelfDDLElement, dialect=None, callable_=None, state=None + self: SelfDDLElement, + dialect: Optional[str] = None, + callable_: Optional[DDLIfCallable] = None, + state: Optional[Any] = None, ) -> SelfDDLElement: r"""Return a callable that will execute this :class:`_ddl.DDLElement` conditionally within an event handler. @@ -155,7 +225,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): DDL('something').execute_if(dialect=('postgresql', 'mysql')) :param callable\_: A callable, which will be invoked with - four positional arguments as well as optional keyword + three positional arguments as well as optional keyword arguments: :ddl: @@ -168,13 +238,22 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): explicitly. :bind: - The :class:`_engine.Connection` being used for DDL execution + The :class:`_engine.Connection` being used for DDL execution. + May be None if this construct is being created inline within + a table, in which case ``compiler`` will be present. :tables: Optional keyword argument - a list of Table objects which are to be created/ dropped within a MetaData.create_all() or drop_all() method call. + :dialect: keyword argument, but always present - the + :class:`.Dialect` involved in the operation. + + :compiler: keyword argument. Will be ``None`` for an engine + level DDL invocation, but will refer to a :class:`.DDLCompiler` + if this DDL element is being created inline within a table. + :state: Optional keyword argument - will be the ``state`` argument passed to this function. @@ -192,35 +271,30 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): .. seealso:: + :meth:`.SchemaItem.ddl_if` + :class:`.DDLEvents` :ref:`event_toplevel` """ - self.dialect = dialect - self.callable_ = callable_ - self.state = state + self._ddl_if = DDLIf(dialect, callable_, state) return self def _should_execute(self, target, bind, **kw): - if isinstance(self.dialect, str): - if self.dialect != bind.engine.name: - return False - elif isinstance(self.dialect, (tuple, list, set)): - if bind.engine.name not in self.dialect: - return False - if self.callable_ is not None and not self.callable_( - self, target, bind, state=self.state, **kw - ): - return False + if self._ddl_if is None: + return True + else: + return self._ddl_if._should_execute(self, target, bind, **kw) - return True + def _invoke_with(self, bind): + if self._should_execute(self.target, bind): + return bind.execute(self) def __call__(self, target, bind, **kw): """Execute the DDL as a ddl_listener.""" - if self._should_execute(target, bind, **kw): - return bind.execute(self.against(target)) + self.against(target)._invoke_with(bind) def _generate(self): s = self.__class__.__new__(self.__class__) @@ -330,9 +404,10 @@ class _CreateDropBase(DDLElement): if_exists=False, if_not_exists=False, ): - self.element = element + self.element = self.target = element self.if_exists = if_exists self.if_not_exists = if_not_exists + self._ddl_if = getattr(element, "_ddl_if", None) @property def stringify_dialect(self): @@ -358,11 +433,19 @@ class CreateSchema(_CreateDropBase): __visit_name__ = "create_schema" - def __init__(self, name, quote=None, **kw): + def __init__( + self, + name, + quote=None, + if_exists=False, + if_not_exists=False, + ): """Create a new :class:`.CreateSchema` construct.""" self.quote = quote - super(CreateSchema, self).__init__(name, **kw) + self.element = name + self.if_exists = if_exists + self.if_not_exists = if_not_exists class DropSchema(_CreateDropBase): @@ -374,12 +457,22 @@ class DropSchema(_CreateDropBase): __visit_name__ = "drop_schema" - def __init__(self, name, quote=None, cascade=False, **kw): + def __init__( + self, + name, + quote=None, + cascade=False, + if_exists=False, + if_not_exists=False, + ): """Create a new :class:`.DropSchema` construct.""" self.quote = quote self.cascade = cascade - super(DropSchema, self).__init__(name, **kw) + self.quote = quote + self.element = name + self.if_exists = if_exists + self.if_not_exists = if_not_exists class CreateTable(_CreateDropBase): @@ -427,6 +520,11 @@ class _DropView(_CreateDropBase): __visit_name__ = "drop_view" +class CreateConstraint(_DDLCompiles): + def __init__(self, element): + self.element = element + + class CreateColumn(_DDLCompiles): """Represent a :class:`_schema.Column` as rendered in a CREATE TABLE statement, @@ -784,15 +882,10 @@ class SchemaGenerator(DDLBase): # e.g., don't omit any foreign key constraints include_foreign_key_constraints = None - self.connection.execute( - # fmt: off - CreateTable( - table, - include_foreign_key_constraints= # noqa - include_foreign_key_constraints, # noqa - ) - # fmt: on - ) + CreateTable( + table, + include_foreign_key_constraints=include_foreign_key_constraints, + )._invoke_with(self.connection) if hasattr(table, "indexes"): for index in table.indexes: @@ -800,11 +893,11 @@ class SchemaGenerator(DDLBase): if self.dialect.supports_comments and not self.dialect.inline_comments: if table.comment is not None: - self.connection.execute(SetTableComment(table)) + SetTableComment(table)._invoke_with(self.connection) for column in table.columns: if column.comment is not None: - self.connection.execute(SetColumnComment(column)) + SetColumnComment(column)._invoke_with(self.connection) table.dispatch.after_create( table, @@ -817,17 +910,17 @@ class SchemaGenerator(DDLBase): def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: return - self.connection.execute(AddConstraint(constraint)) + AddConstraint(constraint)._invoke_with(self.connection) def visit_sequence(self, sequence, create_ok=False): if not create_ok and not self._can_create_sequence(sequence): return - self.connection.execute(CreateSequence(sequence)) + CreateSequence(sequence)._invoke_with(self.connection) def visit_index(self, index, create_ok=False): if not create_ok and not self._can_create_index(index): return - self.connection.execute(CreateIndex(index)) + CreateIndex(index)._invoke_with(self.connection) class SchemaDropper(DDLBase): @@ -964,7 +1057,7 @@ class SchemaDropper(DDLBase): if not drop_ok and not self._can_drop_index(index): return - self.connection.execute(DropIndex(index)) + DropIndex(index)(index, self.connection) def visit_table( self, @@ -984,7 +1077,7 @@ class SchemaDropper(DDLBase): _is_metadata_operation=_is_metadata_operation, ) - self.connection.execute(DropTable(table)) + DropTable(table)._invoke_with(self.connection) # traverse client side defaults which may refer to server-side # sequences. noting that some of these client side defaults may also be @@ -1009,19 +1102,21 @@ class SchemaDropper(DDLBase): def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: return - self.connection.execute(DropConstraint(constraint)) + DropConstraint(constraint)._invoke_with(self.connection) def visit_sequence(self, sequence, drop_ok=False): if not drop_ok and not self._can_drop_sequence(sequence): return - self.connection.execute(DropSequence(sequence)) + DropSequence(sequence)._invoke_with(self.connection) def sort_tables( - tables: Sequence["Table"], + tables: typing_Sequence["Table"], skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, - extra_dependencies: Optional[Sequence[Tuple["Table", "Table"]]] = None, + extra_dependencies: Optional[ + typing_Sequence[Tuple["Table", "Table"]] + ] = None, ) -> List["Table"]: """Sort a collection of :class:`_schema.Table` objects based on dependency. @@ -1082,16 +1177,17 @@ def sort_tables( """ if skip_fn is not None: + fixed_skip_fn = skip_fn def _skip_fn(fkc): for fk in fkc.elements: - if skip_fn(fk): + if fixed_skip_fn(fk): return True else: return None else: - _skip_fn = None + _skip_fn = None # type: ignore return [ t diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 540b62e8a..dfe82432d 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -184,6 +184,61 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): _use_schema_map = True +SelfHasConditionalDDL = TypeVar( + "SelfHasConditionalDDL", bound="HasConditionalDDL" +) + + +class HasConditionalDDL: + """define a class that includes the :meth:`.HasConditionalDDL.ddl_if` + method, allowing for conditional rendering of DDL. + + Currently applies to constraints and indexes. + + .. versionadded:: 2.0 + + + """ + + _ddl_if: Optional[ddl.DDLIf] = None + + def ddl_if( + self: SelfHasConditionalDDL, + dialect: Optional[str] = None, + callable_: Optional[ddl.DDLIfCallable] = None, + state: Optional[Any] = None, + ) -> SelfHasConditionalDDL: + r"""apply a conditional DDL rule to this schema item. + + These rules work in a similar manner to the + :meth:`.DDLElement.execute_if` callable, with the added feature that + the criteria may be checked within the DDL compilation phase for a + construct such as :class:`.CreateTable`. + :meth:`.HasConditionalDDL.ddl_if` currently applies towards the + :class:`.Index` construct as well as all :class:`.Constraint` + constructs. + + :param dialect: string name of a dialect, or a tuple of string names + to indicate multiple dialect types. + + :param callable\_: a callable that is constructed using the same form + as that described in :paramref:`.DDLElement.execute_if.callable_`. + + :param state: any arbitrary object that will be passed to the + callable, if present. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`schema_ddl_ddl_if` - background and usage examples + + + """ + self._ddl_if = ddl.DDLIf(dialect, callable_, state) + return self + + class HasSchemaAttr(SchemaItem): """schema item that includes a top-level schema name""" @@ -3355,7 +3410,7 @@ class DefaultClause(FetchedValue): return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) -class Constraint(DialectKWArgs, SchemaItem): +class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem): """A table-level SQL constraint. :class:`_schema.Constraint` serves as the base class for the series of @@ -3424,6 +3479,16 @@ class Constraint(DialectKWArgs, SchemaItem): util.set_creation_order(self) self._validate_dialect_kwargs(dialect_kw) + def _should_create_for_compiler(self, compiler, **kw): + if self._create_rule is not None and not self._create_rule(compiler): + return False + elif self._ddl_if is not None: + return self._ddl_if._should_execute( + ddl.CreateConstraint(self), self, None, compiler=compiler, **kw + ) + else: + return True + @property def table(self): try: @@ -4292,7 +4357,9 @@ class UniqueConstraint(ColumnCollectionConstraint): __visit_name__ = "unique_constraint" -class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): +class Index( + DialectKWArgs, ColumnCollectionMixin, HasConditionalDDL, SchemaItem +): """A table-level INDEX. Defines a composite (one or more column) INDEX. |
