From 26c0e8e1846a4e6ac05c15a1ad188a5655b72edb Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 16 Jul 2022 16:19:15 -0400 Subject: implement column._merge() this takes the user-defined args of one Column and merges them into the not-user-defined args of another Column. Implemented within the pep-593 column transfer operation to begin to make this new feature more robust. work may still be needed for constraints etc. but in theory everything from the left side annotated column should take effect for the right side if not otherwise specified on the right. Change-Id: I57eb37ed6ceb4b60979a35cfc4b63731d990911d --- lib/sqlalchemy/sql/schema.py | 113 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 979b8319e..4ed5b9e6b 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2233,6 +2233,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): server_default = self.server_default server_onupdate = self.server_onupdate if isinstance(server_default, (Computed, Identity)): + # TODO: likely should be copied in all cases args.append(server_default._copy(**kw)) server_default = server_onupdate = None @@ -2243,6 +2244,10 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): if self._user_defined_nullable is not NULL_UNSPECIFIED: column_kwargs["nullable"] = self._user_defined_nullable + # TODO: DefaultGenerator is not copied here! it's just used again + # with _set_parent() pointing to the old column. see the new + # use of _copy() in the new _merge() method + c = self._constructor( name=self.name, type_=type_, @@ -2264,6 +2269,69 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) return self._schema_item_copy(c) + def _merge(self, other: Column[Any]) -> None: + """merge the elements of another column into this one. + + this is used by ORM pep-593 merge and will likely need a lot + of fixes. + + + """ + + if self.primary_key: + other.primary_key = True + + type_ = self.type + if not type_._isnull and other.type._isnull: + if isinstance(type_, SchemaEventTarget): + type_ = type_.copy() + + other.type = type_ + + if isinstance(type_, SchemaEventTarget): + type_._set_parent_with_dispatch(other) + + for impl in type_._variant_mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent_with_dispatch(other) + + if ( + self._user_defined_nullable is not NULL_UNSPECIFIED + and other._user_defined_nullable is NULL_UNSPECIFIED + ): + other.nullable = self.nullable + + if self.default is not None and other.default is None: + new_default = self.default._copy() + new_default._set_parent(other) + + if self.server_default and other.server_default is None: + new_server_default = self.server_default + if isinstance(new_server_default, FetchedValue): + new_server_default = new_server_default._copy() + new_server_default._set_parent(other) + else: + other.server_default = new_server_default + + if self.server_onupdate and other.server_onupdate is None: + new_server_onupdate = self.server_onupdate + new_server_onupdate = new_server_onupdate._copy() + new_server_onupdate._set_parent(other) + + if self.onupdate and other.onupdate is None: + new_onupdate = self.onupdate._copy() + new_onupdate._set_parent(other) + + for const in self.constraints: + if not const._type_bound: + new_const = const._copy() + new_const._set_parent(other) + + for fk in self.foreign_keys: + if not fk.constraint: + new_fk = fk._copy() + new_fk._set_parent(other) + def _make_proxy( self, selectable: FromClause, @@ -2948,6 +3016,9 @@ class DefaultGenerator(Executable, SchemaItem): else: self.column.default = self + def _copy(self) -> DefaultGenerator: + raise NotImplementedError() + def _execute_on_connection( self, connection: Connection, @@ -3077,6 +3148,11 @@ class ScalarElementColumnDefault(ColumnDefault): self.for_update = for_update self.arg = arg + def _copy(self) -> ScalarElementColumnDefault: + return ScalarElementColumnDefault( + arg=self.arg, for_update=self.for_update + ) + # _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"] _SQLExprDefault = Union["ColumnElement[Any]", "TextClause"] @@ -3101,6 +3177,11 @@ class ColumnElementColumnDefault(ColumnDefault): self.for_update = for_update self.arg = arg + def _copy(self) -> ColumnElementColumnDefault: + return ColumnElementColumnDefault( + arg=self.arg, for_update=self.for_update + ) + @util.memoized_property @util.preload_module("sqlalchemy.sql.sqltypes") def _arg_is_typed(self) -> bool: @@ -3132,6 +3213,9 @@ class CallableColumnDefault(ColumnDefault): self.for_update = for_update self.arg = self._maybe_wrap_callable(arg) + def _copy(self) -> CallableColumnDefault: + return CallableColumnDefault(arg=self.arg, for_update=self.for_update) + def _maybe_wrap_callable( self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]] ) -> _CallableColumnDefaultProtocol: @@ -3266,7 +3350,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): nomaxvalue: Optional[bool] = None, cycle: Optional[bool] = None, schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None, - cache: Optional[bool] = None, + cache: Optional[int] = None, order: Optional[bool] = None, data_type: Optional[_TypeEngineArgument[int]] = None, optional: bool = False, @@ -3459,6 +3543,25 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): super(Sequence, self)._set_parent(column) column._on_table_attach(self._set_table) + def _copy(self) -> Sequence: + return Sequence( + name=self.name, + start=self.start, + increment=self.increment, + minvalue=self.minvalue, + maxvalue=self.maxvalue, + nominvalue=self.nominvalue, + nomaxvalue=self.nomaxvalue, + cycle=self.cycle, + schema=self.schema, + cache=self.cache, + order=self.order, + data_type=self.data_type, + optional=self.optional, + metadata=self.metadata, + for_update=self.for_update, + ) + def _set_table(self, column: Column[Any], table: Table) -> None: self._set_metadata(table.metadata) @@ -3522,6 +3625,9 @@ class FetchedValue(SchemaEventTarget): else: return self._clone(for_update) # type: ignore + def _copy(self) -> FetchedValue: + return FetchedValue(self.for_update) + def _clone(self, for_update: bool) -> Any: n = self.__class__.__new__(self.__class__) n.__dict__.update(self.__dict__) @@ -3577,6 +3683,11 @@ class DefaultClause(FetchedValue): self.arg = arg self.reflected = _reflected + def _copy(self) -> DefaultClause: + return DefaultClause( + arg=self.arg, for_update=self.for_update, _reflected=self.reflected + ) + def __repr__(self) -> str: return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) -- cgit v1.2.1