diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2021-01-29 23:22:01 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2021-01-29 23:22:01 +0000 |
| commit | 532026f97f402d6673cd9746f1a7daee99327a68 (patch) | |
| tree | d5d271e550abf531ef225379312b220857d4b315 /lib/sqlalchemy | |
| parent | aff54c0bd8f75d324f1a4a8601a3d6f28739439e (diff) | |
| parent | 09e999808a5272a1426a7c00e3a4f27b2e27b8d7 (diff) | |
| download | sqlalchemy-532026f97f402d6673cd9746f1a7daee99327a68.tar.gz | |
Merge "Use schema._copy_expression() fully in column collection constraints"
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ext.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 24 |
3 files changed, 32 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 1bd1e5258..a93a3477d 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -12,6 +12,7 @@ from ...sql import elements from ...sql import expression from ...sql import functions from ...sql import roles +from ...sql import schema from ...sql.schema import ColumnCollectionConstraint @@ -240,8 +241,14 @@ class ExcludeConstraint(ColumnCollectionConstraint): ) ] - def copy(self, **kw): - elements = [(col, self.operators[col]) for col in self.columns.keys()] + def copy(self, target_table=None, **kw): + elements = [ + ( + schema._copy_expression(expr, self.parent, target_table), + self.operators[expr.name], + ) + for expr in self.columns + ] c = self.__class__( *elements, name=self.name, diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index ded1ce2a8..8582c1f14 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -826,6 +826,7 @@ from ...sql import ColumnElement from ...sql import compiler from ...sql import elements from ...sql import roles +from ...sql import schema from ...types import BLOB # noqa from ...types import BOOLEAN # noqa from ...types import CHAR # noqa @@ -1502,9 +1503,11 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): "on_conflict" ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ - "on_conflict_unique" - ] + col1 = list(constraint)[0] + if isinstance(col1, schema.SchemaItem): + on_conflict_clause = list(constraint)[0].dialect_options[ + "sqlite" + ]["on_conflict_unique"] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index b18a6e365..34bedbc6a 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -92,6 +92,9 @@ def _get_table_key(name, schema): # this should really be in sql/util.py but we'd have to # break an import cycle def _copy_expression(expression, source_table, target_table): + if source_table is None or target_table is None: + return expression + def replace(col): if ( isinstance(col, Column) @@ -3272,7 +3275,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): def __contains__(self, x): return x in self.columns - def copy(self, **kw): + def copy(self, target_table=None, **kw): # ticket #5276 constraint_kwargs = {} for dialect_name in self.dialect_options: @@ -3289,7 +3292,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): name=self.name, deferrable=self.deferrable, initially=self.initially, - *self.columns.keys(), + *[ + _copy_expression(expr, self.parent, target_table) + for expr in self.columns + ], **constraint_kwargs ) return self._schema_item_copy(c) @@ -3393,6 +3399,9 @@ class CheckConstraint(ColumnCollectionConstraint): def copy(self, target_table=None, **kw): if target_table is not None: + # note that target_table is None for the copy process of + # a column-bound CheckConstraint, so this path is not reached + # in that case. sqltext = _copy_expression(self.sqltext, self.table, target_table) else: sqltext = self.sqltext @@ -4864,10 +4873,11 @@ class Computed(FetchedValue, SchemaItem): return self def copy(self, target_table=None, **kw): - if target_table is not None: - sqltext = _copy_expression(self.sqltext, self.table, target_table) - else: - sqltext = self.sqltext + sqltext = _copy_expression( + self.sqltext, + self.column.table if self.column is not None else None, + target_table, + ) g = Computed(sqltext, persisted=self.persisted) return self._schema_item_copy(g) @@ -4998,7 +5008,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): def _as_for_update(self, for_update): return self - def copy(self, target_table=None, **kw): + def copy(self, **kw): i = Identity( always=self.always, on_null=self.on_null, |
