summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2021-01-29 23:22:01 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2021-01-29 23:22:01 +0000
commit532026f97f402d6673cd9746f1a7daee99327a68 (patch)
treed5d271e550abf531ef225379312b220857d4b315 /lib/sqlalchemy
parentaff54c0bd8f75d324f1a4a8601a3d6f28739439e (diff)
parent09e999808a5272a1426a7c00e3a4f27b2e27b8d7 (diff)
downloadsqlalchemy-532026f97f402d6673cd9746f1a7daee99327a68.tar.gz
Merge "Use schema._copy_expression() fully in column collection constraints"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py11
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py9
-rw-r--r--lib/sqlalchemy/sql/schema.py24
3 files changed, 32 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index 1bd1e5258..a93a3477d 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -12,6 +12,7 @@ from ...sql import elements
from ...sql import expression
from ...sql import functions
from ...sql import roles
+from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
@@ -240,8 +241,14 @@ class ExcludeConstraint(ColumnCollectionConstraint):
)
]
- def copy(self, **kw):
- elements = [(col, self.operators[col]) for col in self.columns.keys()]
+ def copy(self, target_table=None, **kw):
+ elements = [
+ (
+ schema._copy_expression(expr, self.parent, target_table),
+ self.operators[expr.name],
+ )
+ for expr in self.columns
+ ]
c = self.__class__(
*elements,
name=self.name,
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index ded1ce2a8..8582c1f14 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -826,6 +826,7 @@ from ...sql import ColumnElement
from ...sql import compiler
from ...sql import elements
from ...sql import roles
+from ...sql import schema
from ...types import BLOB # noqa
from ...types import BOOLEAN # noqa
from ...types import CHAR # noqa
@@ -1502,9 +1503,11 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
"on_conflict"
]
if on_conflict_clause is None and len(constraint.columns) == 1:
- on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][
- "on_conflict_unique"
- ]
+ col1 = list(constraint)[0]
+ if isinstance(col1, schema.SchemaItem):
+ on_conflict_clause = list(constraint)[0].dialect_options[
+ "sqlite"
+ ]["on_conflict_unique"]
if on_conflict_clause is not None:
text += " ON CONFLICT " + on_conflict_clause
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index b18a6e365..34bedbc6a 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -92,6 +92,9 @@ def _get_table_key(name, schema):
# this should really be in sql/util.py but we'd have to
# break an import cycle
def _copy_expression(expression, source_table, target_table):
+ if source_table is None or target_table is None:
+ return expression
+
def replace(col):
if (
isinstance(col, Column)
@@ -3272,7 +3275,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
def __contains__(self, x):
return x in self.columns
- def copy(self, **kw):
+ def copy(self, target_table=None, **kw):
# ticket #5276
constraint_kwargs = {}
for dialect_name in self.dialect_options:
@@ -3289,7 +3292,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
name=self.name,
deferrable=self.deferrable,
initially=self.initially,
- *self.columns.keys(),
+ *[
+ _copy_expression(expr, self.parent, target_table)
+ for expr in self.columns
+ ],
**constraint_kwargs
)
return self._schema_item_copy(c)
@@ -3393,6 +3399,9 @@ class CheckConstraint(ColumnCollectionConstraint):
def copy(self, target_table=None, **kw):
if target_table is not None:
+ # note that target_table is None for the copy process of
+ # a column-bound CheckConstraint, so this path is not reached
+ # in that case.
sqltext = _copy_expression(self.sqltext, self.table, target_table)
else:
sqltext = self.sqltext
@@ -4864,10 +4873,11 @@ class Computed(FetchedValue, SchemaItem):
return self
def copy(self, target_table=None, **kw):
- if target_table is not None:
- sqltext = _copy_expression(self.sqltext, self.table, target_table)
- else:
- sqltext = self.sqltext
+ sqltext = _copy_expression(
+ self.sqltext,
+ self.column.table if self.column is not None else None,
+ target_table,
+ )
g = Computed(sqltext, persisted=self.persisted)
return self._schema_item_copy(g)
@@ -4998,7 +5008,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem):
def _as_for_update(self, for_update):
return self
- def copy(self, target_table=None, **kw):
+ def copy(self, **kw):
i = Identity(
always=self.always,
on_null=self.on_null,