diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/exc.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 79 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/config.py | 29 |
5 files changed, 101 insertions, 28 deletions
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index fa46a46c4..c1f1a9c1c 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -125,6 +125,15 @@ class ArgumentError(SQLAlchemyError): """ +class DuplicateColumnError(ArgumentError): + """a Column is being added to a Table that would replace another + Column, without appropriate parameters to allow this in place. + + .. versionadded:: 2.0.0b4 + + """ + + class ObjectNotExecutableError(ArgumentError): """Raised when an object is passed to .execute() that can't be executed as SQL. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index fc80334e8..0b96e5bbf 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1969,7 +1969,11 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): # delete higher index del self._index[len(self._collection)] - def replace(self, column: _NAMEDCOL) -> None: + def replace( + self, + column: _NAMEDCOL, + extra_remove: Optional[Iterable[_NAMEDCOL]] = None, + ) -> None: """add the given column to this collection, removing unaliased versions of this column as well as existing columns with the same key. @@ -1986,7 +1990,10 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """ - remove_col = set() + if extra_remove: + remove_col = set(extra_remove) + else: + remove_col = set() # remove up to two columns based on matches of name as well as key if column.name in self._index and column.key != column.name: other = self._index[column.name][1] diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 2d04b28a8..cb28564d1 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -872,6 +872,7 @@ class Table( allow_replacements=extend_existing or keep_existing or autoload_with, + all_names={}, ) def _autoload( @@ -936,9 +937,13 @@ class Table( schema = kwargs.pop("schema", None) _extend_on = kwargs.pop("_extend_on", None) _reflect_info = kwargs.pop("_reflect_info", None) + # these arguments are only used with _init() - kwargs.pop("extend_existing", False) - kwargs.pop("keep_existing", False) + extend_existing = kwargs.pop("extend_existing", False) + keep_existing = kwargs.pop("keep_existing", False) + + assert extend_existing + assert not keep_existing if schema and schema != self.schema: raise exc.ArgumentError( @@ -987,8 +992,9 @@ class Table( _reflect_info=_reflect_info, ) + all_names = {c.name: c for c in self.c} self._extra_kwargs(**kwargs) - self._init_items(*args) + self._init_items(*args, allow_replacements=True, all_names=all_names) def _extra_kwargs(self, **kwargs: Any) -> None: self._validate_dialect_kwargs(kwargs) @@ -1070,9 +1076,18 @@ class Table( .. versionadded:: 1.4.0 """ - column._set_parent_with_dispatch( - self, allow_replacements=replace_existing - ) + try: + column._set_parent_with_dispatch( + self, + allow_replacements=replace_existing, + all_names={c.name: c for c in self.c}, + ) + except exc.DuplicateColumnError as de: + raise exc.DuplicateColumnError( + f"{de.args[0]} Specify replace_existing=True to " + "Table.append_column() to replace an " + "existing column." + ) from de def append_constraint(self, constraint: Union[Index, Constraint]) -> None: """Append a :class:`_schema.Constraint` to this @@ -2099,10 +2114,12 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg] ) - def _set_parent( + def _set_parent( # type: ignore[override] self, parent: SchemaEventTarget, - allow_replacements: bool = True, + *, + all_names: Dict[str, Column[Any]], + allow_replacements: bool, **kw: Any, ) -> None: table = parent @@ -2125,19 +2142,32 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): % (self.key, existing.description) ) + extra_remove = None + existing_col = None + conflicts_on = "" + if self.key in table._columns: - col = table._columns[self.key] - if col is not self: + existing_col = table._columns[self.key] + if self.key == self.name: + conflicts_on = "name" + else: + conflicts_on = "key" + elif self.name in all_names: + existing_col = all_names[self.name] + extra_remove = {existing_col} + conflicts_on = "name" + + if existing_col is not None: + if existing_col is not self: if not allow_replacements: - util.warn_deprecated( - "A column with name '%s' is already present " - "in table '%s'. Please use method " - ":meth:`_schema.Table.append_column` with the " - "parameter ``replace_existing=True`` to replace an " - "existing column." % (self.key, table.name), - "1.4", + raise exc.DuplicateColumnError( + f"A column with {conflicts_on} " + f"""'{ + self.key if conflicts_on == 'key' else self.name + }' """ + f"is already present in table '{table.name}'." ) - for fk in col.foreign_keys: + for fk in existing_col.foreign_keys: table.foreign_keys.remove(fk) if fk.constraint in table.constraints: # this might have been removed @@ -2145,8 +2175,17 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): # and more than one col being replaced table.constraints.remove(fk.constraint) - table._columns.replace(self) - + if extra_remove and existing_col is not None and self.key == self.name: + util.warn( + f'Column with user-specified key "{existing_col.key}" is ' + "being replaced with " + f'plain named column "{self.name}", ' + f'key "{existing_col.key}" is being removed. If this is a ' + "reflection operation, specify autoload_replace=False to " + "prevent this replacement." + ) + table._columns.replace(self, extra_remove=extra_remove) + all_names[self.name] = self self.table = table if self.primary_key: diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 3a028f002..993fc4954 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -54,6 +54,7 @@ from .config import db from .config import fixture from .config import requirements as requires from .config import skip_test +from .config import Variation from .config import variation from .exclusions import _is_excluded from .exclusions import _server_version diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index a75c36776..957876579 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -14,11 +14,13 @@ import typing from typing import Any from typing import Callable from typing import Iterable +from typing import NoReturn from typing import Optional from typing import Tuple from typing import TypeVar from typing import Union +from .util import fail from .. import util requirements = None @@ -128,21 +130,36 @@ def combinations_list( return combinations(*arg_iterable, **kw) -class _variation_base: - __slots__ = ("name", "argname") +class Variation: + __slots__ = ("_name", "_argname") def __init__(self, case, argname, case_names): - self.name = case - self.argname = argname + self._name = case + self._argname = argname for casename in case_names: setattr(self, casename, casename == case) + if typing.TYPE_CHECKING: + + def __getattr__(self, key: str) -> bool: + ... + + @property + def name(self): + return self._name + def __bool__(self): - return self.name == self.argname + return self._name == self._argname def __nonzero__(self): return not self.__bool__() + def __str__(self): + return f"{self._argname}={self._name!r}" + + def fail(self) -> NoReturn: + fail(f"Unknown {self}") + def variation(argname, cases): """a helper around testing.combinations that provides a single namespace @@ -193,7 +210,7 @@ def variation(argname, cases): typ = type( argname, - (_variation_base,), + (Variation,), { "__slots__": tuple(case_names), }, |
