summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGord Thompson <gord@gordthompson.com>2021-01-25 11:24:25 -0700
committerMike Bayer <mike_mp@zzzcomputing.com>2021-01-29 16:48:00 -0500
commit09e999808a5272a1426a7c00e3a4f27b2e27b8d7 (patch)
treecdc8f16b86d19fdc2c95496c8ad12e8f226ecb03
parentbead576769e655481cddb4623fad4bfa0aaccdda (diff)
downloadsqlalchemy-09e999808a5272a1426a7c00e3a4f27b2e27b8d7.tar.gz
Use schema._copy_expression() fully in column collection constraints
Fixed issue where using :meth:`_schema.Table.to_metadata` (called :meth:`_schema.Table.tometadata` in 1.3) in conjunction with a PostgreSQL :class:`_postgresql.ExcludeConstraint` that made use of ad-hoc column expressions would fail to copy correctly. Fixes: #5850 Change-Id: I062480afb23f6f60962b7b55bc93f5e4e6ff05e4
-rw-r--r--doc/build/changelog/unreleased_13/5850.rst8
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py11
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py9
-rw-r--r--lib/sqlalchemy/sql/schema.py24
-rw-r--r--test/dialect/postgresql/test_compiler.py59
-rw-r--r--test/sql/test_metadata.py50
6 files changed, 148 insertions, 13 deletions
diff --git a/doc/build/changelog/unreleased_13/5850.rst b/doc/build/changelog/unreleased_13/5850.rst
new file mode 100644
index 000000000..2d73a42fb
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/5850.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: bug, postgresql
+ :tickets: 5850
+
+ Fixed issue where using :meth:`_schema.Table.to_metadata` (called
+ :meth:`_schema.Table.tometadata` in 1.3) in conjunction with a PostgreSQL
+ :class:`_postgresql.ExcludeConstraint` that made use of ad-hoc column
+ expressions would fail to copy correctly. \ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index 1bd1e5258..a93a3477d 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -12,6 +12,7 @@ from ...sql import elements
from ...sql import expression
from ...sql import functions
from ...sql import roles
+from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
@@ -240,8 +241,14 @@ class ExcludeConstraint(ColumnCollectionConstraint):
)
]
- def copy(self, **kw):
- elements = [(col, self.operators[col]) for col in self.columns.keys()]
+ def copy(self, target_table=None, **kw):
+ elements = [
+ (
+ schema._copy_expression(expr, self.parent, target_table),
+ self.operators[expr.name],
+ )
+ for expr in self.columns
+ ]
c = self.__class__(
*elements,
name=self.name,
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index ded1ce2a8..8582c1f14 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -826,6 +826,7 @@ from ...sql import ColumnElement
from ...sql import compiler
from ...sql import elements
from ...sql import roles
+from ...sql import schema
from ...types import BLOB # noqa
from ...types import BOOLEAN # noqa
from ...types import CHAR # noqa
@@ -1502,9 +1503,11 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
"on_conflict"
]
if on_conflict_clause is None and len(constraint.columns) == 1:
- on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][
- "on_conflict_unique"
- ]
+ col1 = list(constraint)[0]
+ if isinstance(col1, schema.SchemaItem):
+ on_conflict_clause = list(constraint)[0].dialect_options[
+ "sqlite"
+ ]["on_conflict_unique"]
if on_conflict_clause is not None:
text += " ON CONFLICT " + on_conflict_clause
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 5ca73dac3..a14a9ea6d 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -92,6 +92,9 @@ def _get_table_key(name, schema):
# this should really be in sql/util.py but we'd have to
# break an import cycle
def _copy_expression(expression, source_table, target_table):
+ if source_table is None or target_table is None:
+ return expression
+
def replace(col):
if (
isinstance(col, Column)
@@ -3189,7 +3192,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
def __contains__(self, x):
return x in self.columns
- def copy(self, **kw):
+ def copy(self, target_table=None, **kw):
# ticket #5276
constraint_kwargs = {}
for dialect_name in self.dialect_options:
@@ -3206,7 +3209,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
name=self.name,
deferrable=self.deferrable,
initially=self.initially,
- *self.columns.keys(),
+ *[
+ _copy_expression(expr, self.parent, target_table)
+ for expr in self.columns
+ ],
**constraint_kwargs
)
return self._schema_item_copy(c)
@@ -3310,6 +3316,9 @@ class CheckConstraint(ColumnCollectionConstraint):
def copy(self, target_table=None, **kw):
if target_table is not None:
+ # note that target_table is None for the copy process of
+ # a column-bound CheckConstraint, so this path is not reached
+ # in that case.
sqltext = _copy_expression(self.sqltext, self.table, target_table)
else:
sqltext = self.sqltext
@@ -4781,10 +4790,11 @@ class Computed(FetchedValue, SchemaItem):
return self
def copy(self, target_table=None, **kw):
- if target_table is not None:
- sqltext = _copy_expression(self.sqltext, self.table, target_table)
- else:
- sqltext = self.sqltext
+ sqltext = _copy_expression(
+ self.sqltext,
+ self.column.table if self.column is not None else None,
+ target_table,
+ )
g = Computed(sqltext, persisted=self.persisted)
return self._schema_item_copy(g)
@@ -4915,7 +4925,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem):
def _as_for_update(self, for_update):
return self
- def copy(self, target_table=None, **kw):
+ def copy(self, **kw):
i = Identity(
always=self.always,
on_null=self.on_null,
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 8c8a84c4e..5f0c5f3a5 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -6,6 +6,7 @@ from sqlalchemy import bindparam
from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy import Computed
+from sqlalchemy import Date
from sqlalchemy import delete
from sqlalchemy import Enum
from sqlalchemy import exc
@@ -807,6 +808,64 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
dialect=postgresql.dialect(),
)
+ @testing.combinations(
+ (True, "deferred"),
+ (False, "immediate"),
+ argnames="deferrable_value, initially_value",
+ )
+ def test_copy_exclude_constraint_adhoc_columns(
+ self, deferrable_value, initially_value
+ ):
+ meta = MetaData()
+ table = Table(
+ "mytable",
+ meta,
+ Column("myid", Integer, Sequence("foo_id_seq"), primary_key=True),
+ Column("valid_from_date", Date(), nullable=True),
+ Column("valid_thru_date", Date(), nullable=True),
+ )
+ cons = ExcludeConstraint(
+ (
+ literal_column(
+ "daterange(valid_from_date, valid_thru_date, '[]')"
+ ),
+ "&&",
+ ),
+ where=column("valid_from_date") <= column("valid_thru_date"),
+ name="ex_mytable_valid_date_range",
+ deferrable=deferrable_value,
+ initially=initially_value,
+ )
+
+ table.append_constraint(cons)
+ expected = (
+ "ALTER TABLE mytable ADD CONSTRAINT ex_mytable_valid_date_range "
+ "EXCLUDE USING gist "
+ "(daterange(valid_from_date, valid_thru_date, '[]') WITH &&) "
+ "WHERE (valid_from_date <= valid_thru_date) "
+ "%s %s"
+ % (
+ "NOT DEFERRABLE" if not deferrable_value else "DEFERRABLE",
+ "INITIALLY %s" % initially_value,
+ )
+ )
+ self.assert_compile(
+ schema.AddConstraint(cons),
+ expected,
+ dialect=postgresql.dialect(),
+ )
+
+ meta2 = MetaData()
+ table2 = table.to_metadata(meta2)
+ cons2 = [
+ c for c in table2.constraints if isinstance(c, ExcludeConstraint)
+ ][0]
+ self.assert_compile(
+ schema.AddConstraint(cons2),
+ expected,
+ dialect=postgresql.dialect(),
+ )
+
def test_exclude_constraint_full(self):
m = MetaData()
room = Column("room", Integer, primary_key=True)
diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py
index 17228dad6..4a592a476 100644
--- a/test/sql/test_metadata.py
+++ b/test/sql/test_metadata.py
@@ -40,6 +40,7 @@ from sqlalchemy.schema import DropIndex
from sqlalchemy.sql import naming
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import _NONE_NAME
+from sqlalchemy.sql.elements import literal_column
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
@@ -762,7 +763,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
eq_(repr(const), exp)
-class ToMetaDataTest(fixtures.TestBase, ComparesTables):
+class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
@testing.requires.check_constraints
def test_copy(self):
# TODO: modernize this test for 2.0
@@ -915,6 +916,53 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables):
a2 = a.to_metadata(m2)
assert b2.c.y.references(a2.c.x)
+ def test_column_collection_constraint_w_ad_hoc_columns(self):
+ """Test ColumnCollectionConstraint that has columns that aren't
+ part of the Table.
+
+ """
+ meta = MetaData()
+
+ uq1 = UniqueConstraint(literal_column("some_name"))
+ cc1 = CheckConstraint(literal_column("some_name") > 5)
+ table = Table(
+ "mytable",
+ meta,
+ Column("myid", Integer, primary_key=True),
+ Column("name", String(40), nullable=True),
+ uq1,
+ cc1,
+ )
+
+ self.assert_compile(
+ schema.AddConstraint(uq1),
+ "ALTER TABLE mytable ADD UNIQUE (some_name)",
+ dialect="default",
+ )
+ self.assert_compile(
+ schema.AddConstraint(cc1),
+ "ALTER TABLE mytable ADD CHECK (some_name > 5)",
+ dialect="default",
+ )
+ meta2 = MetaData()
+ table2 = table.to_metadata(meta2)
+ uq2 = [
+ c for c in table2.constraints if isinstance(c, UniqueConstraint)
+ ][0]
+ cc2 = [
+ c for c in table2.constraints if isinstance(c, CheckConstraint)
+ ][0]
+ self.assert_compile(
+ schema.AddConstraint(uq2),
+ "ALTER TABLE mytable ADD UNIQUE (some_name)",
+ dialect="default",
+ )
+ self.assert_compile(
+ schema.AddConstraint(cc2),
+ "ALTER TABLE mytable ADD CHECK (some_name > 5)",
+ dialect="default",
+ )
+
def test_change_schema(self):
meta = MetaData()