summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-01-27 15:07:17 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-25 10:16:36 -0400
commit57877461c1bd3b43a9d833fbca873d59db36b6f7 (patch)
treee6fa8a914f8275ee2a7b99b4772cff14438c6b7f /lib/sqlalchemy/sql
parent6f02d5edd88fe2475629438b0730181a2b00c5fe (diff)
downloadsqlalchemy-57877461c1bd3b43a9d833fbca873d59db36b6f7.tar.gz
generalize conditional DDL throughout schema / DDL
Expanded on the "conditional DDL" system implemented by the :class:`_schema.DDLElement` class to be directly available on :class:`_schema.SchemaItem` constructs such as :class:`_schema.Index`, :class:`_schema.ForeignKeyConstraint`, etc. such that the conditional logic for generating these elements is included within the default DDL emitting process. This system can also be accommodated by a future release of Alembic to support conditional DDL elements within all schema-management systems. Fixes: #7631 Change-Id: I9457524d7f66f49696187cf7d2b37dbb44f0e20b
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py5
-rw-r--r--lib/sqlalchemy/sql/ddl.py206
-rw-r--r--lib/sqlalchemy/sql/schema.py71
3 files changed, 221 insertions, 61 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5ba52ae51..aa98ff256 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -4889,10 +4889,7 @@ class DDLCompiler(Compiled):
for p in (
self.process(constraint)
for constraint in constraints
- if (
- constraint._create_rule is None
- or constraint._create_rule(self)
- )
+ if (constraint._should_create_for_compiler(self))
and (
not self.dialect.supports_alter
or not getattr(constraint, "use_alter", False)
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 7acb69beb..4d57ad869 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -12,10 +12,11 @@ to invoke them for a create/drop call.
from __future__ import annotations
import typing
+from typing import Any
from typing import Callable
from typing import List
from typing import Optional
-from typing import Sequence
+from typing import Sequence as typing_Sequence
from typing import Tuple
from . import roles
@@ -26,11 +27,20 @@ from .elements import ClauseElement
from .. import exc
from .. import util
from ..util import topological
-
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from .compiler import Compiled
+ from .compiler import DDLCompiler
+ from .elements import BindParameter
from .schema import ForeignKeyConstraint
+ from .schema import SchemaItem
from .schema import Table
+ from ..engine.base import _CompiledCacheType
+ from ..engine.base import Connection
+ from ..engine.interfaces import _SchemaTranslateMapType
+ from ..engine.interfaces import CacheStats
+ from ..engine.interfaces import Dialect
class _DDLCompiles(ClauseElement):
@@ -43,10 +53,70 @@ class _DDLCompiles(ClauseElement):
return dialect.ddl_compiler(dialect, self, **kw)
- def _compile_w_cache(self, *arg, **kw):
+ def _compile_w_cache(
+ self,
+ dialect: Dialect,
+ *,
+ compiled_cache: Optional[_CompiledCacheType],
+ column_keys: List[str],
+ for_executemany: bool = False,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ **kw: Any,
+ ) -> Tuple[
+ Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats
+ ]:
raise NotImplementedError()
+class DDLIfCallable(Protocol):
+ def __call__(
+ self,
+ ddl: "DDLElement",
+ target: "SchemaItem",
+ bind: Optional["Connection"],
+ tables: Optional[List["Table"]] = None,
+ state: Optional[Any] = None,
+ *,
+ dialect: Dialect,
+ compiler: Optional[DDLCompiler] = ...,
+ checkfirst: bool,
+ ) -> bool:
+ ...
+
+
+class DDLIf(typing.NamedTuple):
+ dialect: Optional[str]
+ callable_: Optional[DDLIfCallable]
+ state: Optional[Any]
+
+ def _should_execute(self, ddl, target, bind, compiler=None, **kw):
+ if bind is not None:
+ dialect = bind.dialect
+ elif compiler is not None:
+ dialect = compiler.dialect
+ else:
+ assert False, "compiler or dialect is required"
+
+ if isinstance(self.dialect, str):
+ if self.dialect != dialect.name:
+ return False
+ elif isinstance(self.dialect, (tuple, list, set)):
+ if dialect.name not in self.dialect:
+ return False
+ if self.callable_ is not None and not self.callable_(
+ ddl,
+ target,
+ bind,
+ state=self.state,
+ dialect=dialect,
+ compiler=compiler,
+ **kw,
+ ):
+ return False
+
+ return True
+
+
SelfDDLElement = typing.TypeVar("SelfDDLElement", bound="DDLElement")
@@ -80,10 +150,8 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
"""
- target = None
- on = None
- dialect = None
- callable_ = None
+ _ddl_if: Optional[DDLIf] = None
+ target: Optional["SchemaItem"] = None
def _execute_on_connection(
self, connection, distilled_params, execution_options
@@ -93,7 +161,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
)
@_generative
- def against(self: SelfDDLElement, target) -> SelfDDLElement:
+ def against(self: SelfDDLElement, target: SchemaItem) -> SelfDDLElement:
"""Return a copy of this :class:`_schema.DDLElement` which will include
the given target.
@@ -125,13 +193,15 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
processing the DDL string.
"""
-
self.target = target
return self
@_generative
def execute_if(
- self: SelfDDLElement, dialect=None, callable_=None, state=None
+ self: SelfDDLElement,
+ dialect: Optional[str] = None,
+ callable_: Optional[DDLIfCallable] = None,
+ state: Optional[Any] = None,
) -> SelfDDLElement:
r"""Return a callable that will execute this
:class:`_ddl.DDLElement` conditionally within an event handler.
@@ -155,7 +225,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
DDL('something').execute_if(dialect=('postgresql', 'mysql'))
:param callable\_: A callable, which will be invoked with
- four positional arguments as well as optional keyword
+ three positional arguments as well as optional keyword
arguments:
:ddl:
@@ -168,13 +238,22 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
explicitly.
:bind:
- The :class:`_engine.Connection` being used for DDL execution
+ The :class:`_engine.Connection` being used for DDL execution.
+ May be None if this construct is being created inline within
+ a table, in which case ``compiler`` will be present.
:tables:
Optional keyword argument - a list of Table objects which are to
be created/ dropped within a MetaData.create_all() or drop_all()
method call.
+ :dialect: keyword argument, but always present - the
+ :class:`.Dialect` involved in the operation.
+
+ :compiler: keyword argument. Will be ``None`` for an engine
+ level DDL invocation, but will refer to a :class:`.DDLCompiler`
+ if this DDL element is being created inline within a table.
+
:state:
Optional keyword argument - will be the ``state`` argument
passed to this function.
@@ -192,35 +271,30 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
.. seealso::
+ :meth:`.SchemaItem.ddl_if`
+
:class:`.DDLEvents`
:ref:`event_toplevel`
"""
- self.dialect = dialect
- self.callable_ = callable_
- self.state = state
+ self._ddl_if = DDLIf(dialect, callable_, state)
return self
def _should_execute(self, target, bind, **kw):
- if isinstance(self.dialect, str):
- if self.dialect != bind.engine.name:
- return False
- elif isinstance(self.dialect, (tuple, list, set)):
- if bind.engine.name not in self.dialect:
- return False
- if self.callable_ is not None and not self.callable_(
- self, target, bind, state=self.state, **kw
- ):
- return False
+ if self._ddl_if is None:
+ return True
+ else:
+ return self._ddl_if._should_execute(self, target, bind, **kw)
- return True
+ def _invoke_with(self, bind):
+ if self._should_execute(self.target, bind):
+ return bind.execute(self)
def __call__(self, target, bind, **kw):
"""Execute the DDL as a ddl_listener."""
- if self._should_execute(target, bind, **kw):
- return bind.execute(self.against(target))
+ self.against(target)._invoke_with(bind)
def _generate(self):
s = self.__class__.__new__(self.__class__)
@@ -330,9 +404,10 @@ class _CreateDropBase(DDLElement):
if_exists=False,
if_not_exists=False,
):
- self.element = element
+ self.element = self.target = element
self.if_exists = if_exists
self.if_not_exists = if_not_exists
+ self._ddl_if = getattr(element, "_ddl_if", None)
@property
def stringify_dialect(self):
@@ -358,11 +433,19 @@ class CreateSchema(_CreateDropBase):
__visit_name__ = "create_schema"
- def __init__(self, name, quote=None, **kw):
+ def __init__(
+ self,
+ name,
+ quote=None,
+ if_exists=False,
+ if_not_exists=False,
+ ):
"""Create a new :class:`.CreateSchema` construct."""
self.quote = quote
- super(CreateSchema, self).__init__(name, **kw)
+ self.element = name
+ self.if_exists = if_exists
+ self.if_not_exists = if_not_exists
class DropSchema(_CreateDropBase):
@@ -374,12 +457,22 @@ class DropSchema(_CreateDropBase):
__visit_name__ = "drop_schema"
- def __init__(self, name, quote=None, cascade=False, **kw):
+ def __init__(
+ self,
+ name,
+ quote=None,
+ cascade=False,
+ if_exists=False,
+ if_not_exists=False,
+ ):
"""Create a new :class:`.DropSchema` construct."""
self.quote = quote
self.cascade = cascade
- super(DropSchema, self).__init__(name, **kw)
+ self.quote = quote
+ self.element = name
+ self.if_exists = if_exists
+ self.if_not_exists = if_not_exists
class CreateTable(_CreateDropBase):
@@ -427,6 +520,11 @@ class _DropView(_CreateDropBase):
__visit_name__ = "drop_view"
+class CreateConstraint(_DDLCompiles):
+ def __init__(self, element):
+ self.element = element
+
+
class CreateColumn(_DDLCompiles):
"""Represent a :class:`_schema.Column`
as rendered in a CREATE TABLE statement,
@@ -784,15 +882,10 @@ class SchemaGenerator(DDLBase):
# e.g., don't omit any foreign key constraints
include_foreign_key_constraints = None
- self.connection.execute(
- # fmt: off
- CreateTable(
- table,
- include_foreign_key_constraints= # noqa
- include_foreign_key_constraints, # noqa
- )
- # fmt: on
- )
+ CreateTable(
+ table,
+ include_foreign_key_constraints=include_foreign_key_constraints,
+ )._invoke_with(self.connection)
if hasattr(table, "indexes"):
for index in table.indexes:
@@ -800,11 +893,11 @@ class SchemaGenerator(DDLBase):
if self.dialect.supports_comments and not self.dialect.inline_comments:
if table.comment is not None:
- self.connection.execute(SetTableComment(table))
+ SetTableComment(table)._invoke_with(self.connection)
for column in table.columns:
if column.comment is not None:
- self.connection.execute(SetColumnComment(column))
+ SetColumnComment(column)._invoke_with(self.connection)
table.dispatch.after_create(
table,
@@ -817,17 +910,17 @@ class SchemaGenerator(DDLBase):
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
return
- self.connection.execute(AddConstraint(constraint))
+ AddConstraint(constraint)._invoke_with(self.connection)
def visit_sequence(self, sequence, create_ok=False):
if not create_ok and not self._can_create_sequence(sequence):
return
- self.connection.execute(CreateSequence(sequence))
+ CreateSequence(sequence)._invoke_with(self.connection)
def visit_index(self, index, create_ok=False):
if not create_ok and not self._can_create_index(index):
return
- self.connection.execute(CreateIndex(index))
+ CreateIndex(index)._invoke_with(self.connection)
class SchemaDropper(DDLBase):
@@ -964,7 +1057,7 @@ class SchemaDropper(DDLBase):
if not drop_ok and not self._can_drop_index(index):
return
- self.connection.execute(DropIndex(index))
+ DropIndex(index)(index, self.connection)
def visit_table(
self,
@@ -984,7 +1077,7 @@ class SchemaDropper(DDLBase):
_is_metadata_operation=_is_metadata_operation,
)
- self.connection.execute(DropTable(table))
+ DropTable(table)._invoke_with(self.connection)
# traverse client side defaults which may refer to server-side
# sequences. noting that some of these client side defaults may also be
@@ -1009,19 +1102,21 @@ class SchemaDropper(DDLBase):
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
return
- self.connection.execute(DropConstraint(constraint))
+ DropConstraint(constraint)._invoke_with(self.connection)
def visit_sequence(self, sequence, drop_ok=False):
if not drop_ok and not self._can_drop_sequence(sequence):
return
- self.connection.execute(DropSequence(sequence))
+ DropSequence(sequence)._invoke_with(self.connection)
def sort_tables(
- tables: Sequence["Table"],
+ tables: typing_Sequence["Table"],
skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None,
- extra_dependencies: Optional[Sequence[Tuple["Table", "Table"]]] = None,
+ extra_dependencies: Optional[
+ typing_Sequence[Tuple["Table", "Table"]]
+ ] = None,
) -> List["Table"]:
"""Sort a collection of :class:`_schema.Table` objects based on
dependency.
@@ -1082,16 +1177,17 @@ def sort_tables(
"""
if skip_fn is not None:
+ fixed_skip_fn = skip_fn
def _skip_fn(fkc):
for fk in fkc.elements:
- if skip_fn(fk):
+ if fixed_skip_fn(fk):
return True
else:
return None
else:
- _skip_fn = None
+ _skip_fn = None # type: ignore
return [
t
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 540b62e8a..dfe82432d 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -184,6 +184,61 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
_use_schema_map = True
+SelfHasConditionalDDL = TypeVar(
+ "SelfHasConditionalDDL", bound="HasConditionalDDL"
+)
+
+
+class HasConditionalDDL:
+ """define a class that includes the :meth:`.HasConditionalDDL.ddl_if`
+ method, allowing for conditional rendering of DDL.
+
+ Currently applies to constraints and indexes.
+
+ .. versionadded:: 2.0
+
+
+ """
+
+ _ddl_if: Optional[ddl.DDLIf] = None
+
+ def ddl_if(
+ self: SelfHasConditionalDDL,
+ dialect: Optional[str] = None,
+ callable_: Optional[ddl.DDLIfCallable] = None,
+ state: Optional[Any] = None,
+ ) -> SelfHasConditionalDDL:
+ r"""apply a conditional DDL rule to this schema item.
+
+ These rules work in a similar manner to the
+ :meth:`.DDLElement.execute_if` callable, with the added feature that
+ the criteria may be checked within the DDL compilation phase for a
+ construct such as :class:`.CreateTable`.
+ :meth:`.HasConditionalDDL.ddl_if` currently applies towards the
+ :class:`.Index` construct as well as all :class:`.Constraint`
+ constructs.
+
+ :param dialect: string name of a dialect, or a tuple of string names
+ to indicate multiple dialect types.
+
+ :param callable\_: a callable that is constructed using the same form
+ as that described in :paramref:`.DDLElement.execute_if.callable_`.
+
+ :param state: any arbitrary object that will be passed to the
+ callable, if present.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :ref:`schema_ddl_ddl_if` - background and usage examples
+
+
+ """
+ self._ddl_if = ddl.DDLIf(dialect, callable_, state)
+ return self
+
+
class HasSchemaAttr(SchemaItem):
"""schema item that includes a top-level schema name"""
@@ -3355,7 +3410,7 @@ class DefaultClause(FetchedValue):
return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
-class Constraint(DialectKWArgs, SchemaItem):
+class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem):
"""A table-level SQL constraint.
:class:`_schema.Constraint` serves as the base class for the series of
@@ -3424,6 +3479,16 @@ class Constraint(DialectKWArgs, SchemaItem):
util.set_creation_order(self)
self._validate_dialect_kwargs(dialect_kw)
+ def _should_create_for_compiler(self, compiler, **kw):
+ if self._create_rule is not None and not self._create_rule(compiler):
+ return False
+ elif self._ddl_if is not None:
+ return self._ddl_if._should_execute(
+ ddl.CreateConstraint(self), self, None, compiler=compiler, **kw
+ )
+ else:
+ return True
+
@property
def table(self):
try:
@@ -4292,7 +4357,9 @@ class UniqueConstraint(ColumnCollectionConstraint):
__visit_name__ = "unique_constraint"
-class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
+class Index(
+ DialectKWArgs, ColumnCollectionMixin, HasConditionalDDL, SchemaItem
+):
"""A table-level INDEX.
Defines a composite (one or more column) INDEX.