diff options
68 files changed, 2397 insertions, 1587 deletions
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 72102ac26..ccf573675 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1638,7 +1638,6 @@ class CursorResult(Result): :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial` """ # noqa: E501 - try: return self.context.rowcount except BaseException as e: diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 529b2ca73..45f6bf20b 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -48,7 +48,7 @@ def connection_memoize(key: str) -> Callable[[_C], _C]: connection.info[key] = val = fn(self, connection) return val - return decorated # type: ignore[return-value] + return decorated # type: ignore class _TConsSubject(Protocol): diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 33cf3f745..c7a6e2ca0 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -76,7 +76,6 @@ if TYPE_CHECKING: "expunge", "expunge_all", "flush", - "get", "get_bind", "is_modified", "invalidate", @@ -204,6 +203,49 @@ class async_scoped_session: await self.registry().close() self.registry.clear() + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + + """ # noqa: E501 + + # this was proxied but Mypy is requiring the return type to be + # clarified + + # work around: + # https://github.com/python/typing/discussions/1143 + return_value = await self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + return return_value + # START PROXY METHODS async_scoped_session # code within this block is **programmatically, @@ -632,43 +674,6 @@ class async_scoped_session: return await self._proxied.flush(objects=objects) - async def get( - self, - entity: _EntityBindKey[_O], - ident: _PKIdentityArgument, - *, - options: Optional[Sequence[ORMOption]] = None, - populate_existing: bool = False, - with_for_update: Optional[ForUpdateArg] = None, - identity_token: Optional[Any] = None, - execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - ) -> Optional[_O]: - r"""Return an instance based on the given primary key identifier, - or ``None`` if not found. - - .. container:: class_bases - - Proxied for the :class:`_asyncio.AsyncSession` class on - behalf of the :class:`_asyncio.scoping.async_scoped_session` class. - - .. seealso:: - - :meth:`_orm.Session.get` - main documentation for get - - - - """ # noqa: E501 - - return await self._proxied.get( - entity, - ident, - options=options, - populate_existing=populate_existing, - with_for_update=with_for_update, - identity_token=identity_token, - execution_options=execution_options, - ) - def get_bind( self, mapper: Optional[_EntityBindKey[_O]] = None, diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index be872804e..7200414a1 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -832,6 +832,8 @@ from ..util.typing import Protocol if TYPE_CHECKING: + from ..orm._typing import _ORMColumnExprArgument + from ..orm.interfaces import MapperProperty from ..orm.util import AliasedInsp from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _DMLColumnArgument @@ -840,7 +842,6 @@ if TYPE_CHECKING: from ..sql.operators import OperatorType from ..sql.roles import ColumnsClauseRole - _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _T_con = TypeVar("_T_con", bound=Any, contravariant=True) @@ -1289,7 +1290,7 @@ class Comparator(interfaces.PropComparator[_T]): ): self.expression = expression - def __clause_element__(self) -> ColumnsClauseRole: + def __clause_element__(self) -> _ORMColumnExprArgument[_T]: expr = self.expression if is_has_clause_element(expr): ret_expr = expr.__clause_element__() @@ -1298,10 +1299,15 @@ class Comparator(interfaces.PropComparator[_T]): assert isinstance(expr, ColumnElement) ret_expr = expr + if TYPE_CHECKING: + # see test_hybrid->test_expression_isnt_clause_element + # that exercises the usual place this is caught if not + # true + assert isinstance(ret_expr, ColumnElement) return ret_expr - @util.non_memoized_property - def property(self) -> Any: + @util.ro_non_memoized_property + def property(self) -> Optional[interfaces.MapperProperty[_T]]: return None def adapt_to_entity( @@ -1325,7 +1331,7 @@ class ExprComparator(Comparator[_T]): def __getattr__(self, key: str) -> Any: return getattr(self.expression, key) - @util.non_memoized_property + @util.ro_non_memoized_property def info(self) -> _InfoType: return self.hybrid.info @@ -1339,8 +1345,8 @@ class ExprComparator(Comparator[_T]): else: return [(self.expression, value)] - @util.non_memoized_property - def property(self) -> Any: + @util.ro_non_memoized_property + def property(self) -> Optional[MapperProperty[_T]]: return self.expression.property # type: ignore def operate( diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 72448fbdc..b1138a4ad 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -25,6 +25,7 @@ from ..orm import exc as orm_exc from ..orm import instrumentation as orm_instrumentation from ..orm.instrumentation import _default_dict_getter from ..orm.instrumentation import _default_manager_getter +from ..orm.instrumentation import _default_opt_manager_getter from ..orm.instrumentation import _default_state_getter from ..orm.instrumentation import ClassManager from ..orm.instrumentation import InstrumentationFactory @@ -140,7 +141,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): hierarchy = util.class_hierarchy(cls) factories = set() for member in hierarchy: - manager = self.manager_of_class(member) + manager = self.opt_manager_of_class(member) if manager is not None: factories.add(manager.factory) else: @@ -161,17 +162,34 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): del self._state_finders[class_] del self._dict_finders[class_] - def manager_of_class(self, cls): - if cls is None: - return None + def opt_manager_of_class(self, cls): try: - finder = self._manager_finders.get(cls, _default_manager_getter) + finder = self._manager_finders.get( + cls, _default_opt_manager_getter + ) except TypeError: # due to weakref lookup on invalid object return None else: return finder(cls) + def manager_of_class(self, cls): + try: + finder = self._manager_finders.get(cls, _default_manager_getter) + except TypeError: + # due to weakref lookup on invalid object + raise orm_exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) + else: + manager = finder(cls) + if manager is None: + raise orm_exc.UnmappedClassError( + cls, + f"Can't locate an instrumentation manager for class {cls}", + ) + return manager + def state_of(self, instance): if instance is None: raise AttributeError("None has no persistent state.") @@ -384,6 +402,7 @@ def _install_instrumented_lookups(): instance_state=_instrumentation_factory.state_of, instance_dict=_instrumentation_factory.dict_of, manager_of_class=_instrumentation_factory.manager_of_class, + opt_manager_of_class=_instrumentation_factory.opt_manager_of_class, ) ) @@ -395,16 +414,19 @@ def _reinstall_default_lookups(): instance_state=_default_state_getter, instance_dict=_default_dict_getter, manager_of_class=_default_manager_getter, + opt_manager_of_class=_default_opt_manager_getter, ) ) _instrumentation_factory._extended = False def _install_lookups(lookups): - global instance_state, instance_dict, manager_of_class + global instance_state, instance_dict + global manager_of_class, opt_manager_of_class instance_state = lookups["instance_state"] instance_dict = lookups["instance_dict"] manager_of_class = lookups["manager_of_class"] + opt_manager_of_class = lookups["opt_manager_of_class"] orm_base.instance_state = ( attributes.instance_state ) = orm_instrumentation.instance_state = instance_state @@ -414,3 +436,6 @@ def _install_lookups(lookups): orm_base.manager_of_class = ( attributes.manager_of_class ) = orm_instrumentation.manager_of_class = manager_of_class + orm_base.opt_manager_of_class = ( + attributes.opt_manager_of_class + ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index 6b06c0d6b..01c740fe4 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -34,6 +34,7 @@ from typing import Any from typing import Callable from typing import Dict from typing import Generic +from typing import Optional from typing import overload from typing import Type from typing import TypeVar @@ -43,6 +44,9 @@ from . import exc from .util.typing import Literal _T = TypeVar("_T", bound=Any) +_F = TypeVar("_F", bound=Callable[..., Any]) + +_IN = TypeVar("_IN", bound="Inspectable[Any]") _registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {} @@ -53,11 +57,22 @@ class Inspectable(Generic[_T]): This allows typing to set up a linkage between an object that can be inspected and the type of inspection it returns. + Unfortunately we cannot at the moment get all classes that are + returned by inspection to suit this interface as we get into + MRO issues. + """ + __slots__ = () + @overload -def inspect(subject: Inspectable[_T], raiseerr: bool = True) -> _T: +def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: + ... + + +@overload +def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: ... @@ -108,9 +123,9 @@ def inspect(subject: Any, raiseerr: bool = True) -> Any: def _inspects( - *types: type, -) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]: - def decorate(fn_or_cls: Callable[[Any], Any]) -> Callable[[Any], Any]: + *types: Type[Any], +) -> Callable[[_F], _F]: + def decorate(fn_or_cls: _F) -> _F: for type_ in types: if type_ in _registrars: raise AssertionError( @@ -122,7 +137,10 @@ def _inspects( return decorate -def _self_inspects(cls: Type[_T]) -> Type[_T]: +_TT = TypeVar("_TT", bound="Type[Any]") + + +def _self_inspects(cls: _TT) -> _TT: if cls in _registrars: raise AssertionError("Type %s is already " "registered" % cls) _registrars[cls] = True diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 7690c05de..457ad5c5a 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -9,20 +9,19 @@ from __future__ import annotations import typing from typing import Any +from typing import Callable from typing import Collection -from typing import Dict -from typing import List from typing import Optional from typing import overload -from typing import Set from typing import Type +from typing import TYPE_CHECKING from typing import Union -from . import mapper as mapperlib +from . import mapperlib as mapperlib +from ._typing import _O from .base import Mapped from .descriptor_props import Composite from .descriptor_props import Synonym -from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn from .query import AliasOption @@ -37,11 +36,29 @@ from .. import sql from .. import util from ..exc import InvalidRequestError from ..sql.base import SchemaEventTarget -from ..sql.selectable import Alias +from ..sql.schema import SchemaConst from ..sql.selectable import FromClause -from ..sql.type_api import TypeEngine from ..util.typing import Literal +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ORMColumnExprArgument + from .descriptor_props import _CompositeAttrType + from .interfaces import PropComparator + from .query import Query + from .relationships import _LazyLoadArgumentType + from .relationships import _ORMBackrefArgument + from .relationships import _ORMColCollectionArgument + from .relationships import _ORMOrderByArgument + from .relationships import _RelationshipJoinConditionArgument + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + from ..sql._typing import _TypeEngineArgument + from ..sql.schema import _ServerDefaultType + from ..sql.schema import FetchedValue + from ..sql.selectable import Alias + from ..sql.selectable import Subquery + _T = typing.TypeVar("_T") @@ -61,7 +78,7 @@ SynonymProperty = Synonym "for entities to be matched up to a query that is established " "via :meth:`.Query.from_statement` and now does nothing.", ) -def contains_alias(alias) -> AliasOption: +def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption: r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. @@ -70,134 +87,36 @@ def contains_alias(alias) -> AliasOption: return AliasOption(alias) -# see test/ext/mypy/plain_files/mapped_column.py for mapped column -# typing tests - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[None] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[None] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[True] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Optional[_T]]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[True] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Optional[_T]]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[False] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[False] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: Literal[True] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: Literal[True] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __name: str, - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[Any], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[Any], SchemaEventTarget] + ] = None, *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = SchemaConst.NULL_UNSPECIFIED, + primary_key: Optional[bool] = False, + deferred: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + **dialect_kwargs: Any, +) -> MappedColumn[Any]: r"""construct a new ORM-mapped :class:`_schema.Column` construct. The :func:`_orm.mapped_column` function provides an ORM-aware and @@ -363,12 +282,45 @@ def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": """ - return MappedColumn(*args, **kw) + return MappedColumn( + __name_pos, + __type_pos, + *args, + name=name, + type_=type_, + autoincrement=autoincrement, + default=default, + doc=doc, + key=key, + index=index, + unique=unique, + info=info, + nullable=nullable, + onupdate=onupdate, + primary_key=primary_key, + server_default=server_default, + server_onupdate=server_onupdate, + quote=quote, + comment=comment, + system=system, + deferred=deferred, + **dialect_kwargs, + ) def column_property( - column: sql.ColumnElement[_T], *additional_columns, **kwargs -) -> "ColumnProperty[_T]": + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + descriptor: Optional[Any] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: r"""Provide a column-level property for use with a mapping. Column-based properties can normally be applied to the mapper's @@ -452,13 +404,25 @@ def column_property( expressions """ - return ColumnProperty(column, *additional_columns, **kwargs) + return ColumnProperty( + column, + *additional_columns, + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + descriptor=descriptor, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) @overload def composite( class_: Type[_T], - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[_T]: ... @@ -466,7 +430,7 @@ def composite( @overload def composite( - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[Any]: ... @@ -474,7 +438,7 @@ def composite( def composite( class_: Any = None, - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. @@ -529,13 +493,13 @@ def composite( def with_loader_criteria( - entity_or_base, - where_criteria, - loader_only=False, - include_aliases=False, - propagate_to_loaders=True, - track_closure_variables=True, -) -> "LoaderCriteriaOption": + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, +) -> LoaderCriteriaOption: """Add additional WHERE criteria to the load for all occurrences of a particular entity. @@ -711,180 +675,40 @@ def with_loader_criteria( ) -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Type[Set] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Set[Any]]: - ... - - -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Type[List] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[Any]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Literal[False] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[_T]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Literal[True] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[List] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[Set] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Set[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[Dict[Any, Any]] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Dict[Any, _T]]: - ... - - -@overload -def relationship( - argument: _RelationshipArgumentType[_T], - secondary=..., - *, - uselist: Literal[None] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=None, - back_populates=None, - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]] = ..., - secondary=..., - *, - uselist: Literal[True] = ..., - collection_class: Any = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload def relationship( - argument: Literal[None] = ..., - secondary=..., - *, - uselist: Optional[bool] = ..., - collection_class: Any = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -def relationship( - argument: Optional[_RelationshipArgumentType[_T]] = None, - secondary=None, + argument: Optional[_RelationshipArgumentType[Any]] = None, + secondary: Optional[FromClause] = None, *, uselist: Optional[bool] = None, - collection_class: Optional[Type[Collection]] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[_ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: bool = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[Type[PropComparator[Any]]] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, **kw: Any, ) -> Relationship[Any]: """Provide a relationship between two mapped classes. @@ -1098,13 +922,6 @@ def relationship( :ref:`error_qzyx` - usage example - :param bake_queries=True: - Legacy parameter, not used. - - .. versionchanged:: 1.4.23 the "lambda caching" system is no longer - used by loader strategies and the ``bake_queries`` parameter - has no effect. - :param cascade: A comma-separated list of cascade rules which determines how Session operations should be "cascaded" from parent to child. @@ -1701,18 +1518,42 @@ def relationship( primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, back_populates=back_populates, + order_by=order_by, + backref=backref, + overlaps=overlaps, + post_update=post_update, + cascade=cascade, + viewonly=viewonly, + lazy=lazy, + passive_deletes=passive_deletes, + passive_updates=passive_updates, + active_history=active_history, + enable_typechecks=enable_typechecks, + foreign_keys=foreign_keys, + remote_side=remote_side, + join_depth=join_depth, + comparator_factory=comparator_factory, + single_parent=single_parent, + innerjoin=innerjoin, + distinct_target_key=distinct_target_key, + load_on_pending=load_on_pending, + query_class=query_class, + info=info, + omit_join=omit_join, + sync_backref=sync_backref, **kw, ) def synonym( - name, - map_column=None, - descriptor=None, - comparator_factory=None, - doc=None, - info=None, -) -> "Synonym[Any]": + name: str, + *, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> Synonym[Any]: """Denote an attribute name as a synonym to a mapped property, in that the attribute will mirror the value and expression behavior of another attribute. @@ -1951,8 +1792,8 @@ def deferred(*columns, **kw): def query_expression( - default_expr: sql.ColumnElement[_T] = sql.null(), -) -> "Mapped[_T]": + default_expr: _ORMColumnExprArgument[_T] = sql.null(), +) -> Mapped[_T]: """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -2010,33 +1851,33 @@ def clear_mappers(): @overload def aliased( - element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> "AliasedClass[_T]": + element: _EntityType[_O], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedClass[_O]: ... @overload def aliased( - element: "FromClause", - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> "Alias": + element: FromClause, + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> FromClause: ... def aliased( - element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> Union["AliasedClass[_T]", "Alias"]: + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> Union[AliasedClass[_O], FromClause]: """Produce an alias of the given element, usually an :class:`.AliasedClass` instance. @@ -2233,9 +2074,7 @@ def with_polymorphic( ) -def join( - left, right, onclause=None, isouter=False, full=False, join_to_left=None -): +def join(left, right, onclause=None, isouter=False, full=False): r"""Produce an inner join between left and right clauses. :func:`_orm.join` is an extension to the core join interface @@ -2270,16 +2109,11 @@ def join( See :ref:`orm_queryguide_joins` for information on modern usage of ORM level joins. - .. deprecated:: 0.8 - - the ``join_to_left`` parameter is deprecated, and will be removed - in a future release. The parameter has no effect. - """ return _ORMJoin(left, right, onclause, isouter, full) -def outerjoin(left, right, onclause=None, full=False, join_to_left=None): +def outerjoin(left, right, onclause=None, full=False): """Produce a left outer join between left and right clauses. This is the "outer join" version of the :func:`_orm.join` function, diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 4250cdbe1..339844f14 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -2,6 +2,7 @@ from __future__ import annotations import operator from typing import Any +from typing import Callable from typing import Dict from typing import Optional from typing import Tuple @@ -10,7 +11,9 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.orm.interfaces import UserDefinedOption +from ..sql import roles +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnElement from ..util.typing import Protocol from ..util.typing import TypeGuard @@ -18,8 +21,12 @@ if TYPE_CHECKING: from .attributes import AttributeImpl from .attributes import CollectionAttributeImpl from .base import PassiveFlag + from .decl_api import registry as _registry_type from .descriptor_props import _CompositeClassProto + from .interfaces import MapperProperty + from .interfaces import UserDefinedOption from .mapper import Mapper + from .relationships import Relationship from .state import InstanceState from .util import AliasedClass from .util import AliasedInsp @@ -27,21 +34,39 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) + +# I would have preferred this were bound=object however it seems +# to not travel in all situations when defined in that way. _O = TypeVar("_O", bound=Any) """The 'ORM mapped object' type. -I would have preferred this were bound=object however it seems -to not travel in all situations when defined in that way. + """ +if TYPE_CHECKING: + _RegistryType = _registry_type + _InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"] -_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"] +_EntityType = Union[ + Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]" +] _InstanceDict = Dict[str, Any] _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]] +_ORMColumnExprArgument = Union[ + ColumnElement[_T], + _HasClauseElement, + roles.ExpressionElementRole[_T], +] + +# somehow Protocol didn't want to work for this one +_ORMAdapterProto = Callable[ + [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T] +] + class _LoaderCallable(Protocol): def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any: @@ -60,10 +85,28 @@ def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]: if TYPE_CHECKING: + def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]: + ... + + def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: + ... + + def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: + ... + + def prop_is_relationship( + prop: MapperProperty[Any], + ) -> TypeGuard[Relationship[Any]]: + ... + def is_collection_impl( impl: AttributeImpl, ) -> TypeGuard[CollectionAttributeImpl]: ... else: + insp_is_mapper_property = operator.attrgetter("is_property") + insp_is_mapper = operator.attrgetter("is_mapper") + insp_is_aliased_class = operator.attrgetter("is_aliased_class") is_collection_impl = operator.attrgetter("collection") + prop_is_relationship = operator.attrgetter("_is_relationship") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 33ce96a19..41d944c57 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -44,7 +44,7 @@ from .base import instance_dict as instance_dict from .base import instance_state as instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED -from .base import manager_of_class +from .base import manager_of_class as manager_of_class from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa from .base import NO_AUTOFLUSH @@ -52,6 +52,7 @@ from .base import NO_CHANGE # noqa from .base import NO_RAISE from .base import NO_VALUE from .base import NON_PERSISTENT_OK # noqa +from .base import opt_manager_of_class as opt_manager_of_class from .base import PASSIVE_CLASS_MISMATCH # noqa from .base import PASSIVE_NO_FETCH from .base import PASSIVE_NO_FETCH_RELATED # noqa @@ -74,6 +75,7 @@ from ..sql import traversals from ..sql import visitors if TYPE_CHECKING: + from .interfaces import MapperProperty from .state import InstanceState from ..sql.dml import _DMLColumnElement from ..sql.elements import ColumnElement @@ -146,7 +148,7 @@ class QueryableAttribute( self._of_type = of_type self._extra_criteria = extra_criteria - manager = manager_of_class(class_) + manager = opt_manager_of_class(class_) # manager is None in the case of AliasedClass if manager: # propagate existing event listeners from @@ -370,7 +372,7 @@ class QueryableAttribute( return "%s.%s" % (self.class_.__name__, self.key) @util.memoized_property - def property(self): + def property(self) -> MapperProperty[_T]: """Return the :class:`.MapperProperty` associated with this :class:`.QueryableAttribute`. diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 3fa855a4b..054d52d83 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -26,24 +26,25 @@ from typing import TypeVar from typing import Union from . import exc +from ._typing import insp_is_mapper from .. import exc as sa_exc from .. import inspection from .. import util from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly -from ..util.typing import Concatenate from ..util.typing import Literal -from ..util.typing import ParamSpec from ..util.typing import Self if typing.TYPE_CHECKING: from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute + from .instrumentation import ClassManager from .mapper import Mapper from .state import InstanceState from ..sql._typing import _InfoType + _T = TypeVar("_T", bound=Any) _O = TypeVar("_O", bound=object) @@ -246,21 +247,15 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_Fn = TypeVar("_Fn", bound=Callable) -_Args = ParamSpec("_Args") +_F = TypeVar("_F", bound=Callable) _Self = TypeVar("_Self") def _assertions( *assertions: Any, -) -> Callable[ - [Callable[Concatenate[_Self, _Fn, _Args], _Self]], - Callable[Concatenate[_Self, _Fn, _Args], _Self], -]: +) -> Callable[[_F], _F]: @util.decorator - def generate( - fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs - ) -> _Self: + def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self: for assertion in assertions: assertion(self, fn.__name__) fn(self, *args, **kw) @@ -269,13 +264,13 @@ def _assertions( return generate -# these can be replaced by sqlalchemy.ext.instrumentation -# if augmented class instrumentation is enabled. -def manager_of_class(cls): - return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None) +if TYPE_CHECKING: + def manager_of_class(cls: Type[Any]) -> ClassManager: + ... -if TYPE_CHECKING: + def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]: + ... def instance_state(instance: _O) -> InstanceState[_O]: ... @@ -284,6 +279,20 @@ if TYPE_CHECKING: ... else: + # these can be replaced by sqlalchemy.ext.instrumentation + # if augmented class instrumentation is enabled. + + def manager_of_class(cls): + try: + return cls.__dict__[DEFAULT_MANAGER_ATTR] + except KeyError as ke: + raise exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) from ke + + def opt_manager_of_class(cls): + return cls.__dict__.get(DEFAULT_MANAGER_ATTR) + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) instance_dict = operator.attrgetter("__dict__") @@ -458,11 +467,12 @@ else: _state_mapper = util.dottedgetter("manager.mapper") -@inspection._inspects(type) -def _inspect_mapped_class(class_, configure=False): +def _inspect_mapped_class( + class_: Type[_O], configure: bool = False +) -> Optional[Mapper[_O]]: try: - class_manager = manager_of_class(class_) - if not class_manager.is_mapped: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: return None mapper = class_manager.mapper except exc.NO_STATE: @@ -473,7 +483,28 @@ def _inspect_mapped_class(class_, configure=False): return mapper -def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: +@inspection._inspects(type) +def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]: + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + return None + else: + return mapper + + +def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: + insp = inspection.inspect(arg, raiseerr=False) + if insp_is_mapper(insp): + return insp + + raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}") + + +def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]: """Given a class, return the primary :class:`_orm.Mapper` associated with the key. @@ -502,8 +533,8 @@ def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: class InspectionAttr: - """A base class applied to all ORM objects that can be returned - by the :func:`_sa.inspect` function. + """A base class applied to all ORM objects and attributes that are + related to things that can be returned by the :func:`_sa.inspect` function. The attributes defined here allow the usage of simple boolean checks to test basic facts about the object returned. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 419da65f7..4fee2d383 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -63,11 +63,14 @@ from ..sql.visitors import InternalTraversal if TYPE_CHECKING: from ._typing import _InternalEntityType + from .mapper import Mapper + from .query import Query from ..sql.compiler import _CompilerStackEntry from ..sql.dml import _DMLTableElement from ..sql.elements import ColumnElement from ..sql.selectable import _LabelConventionCallable from ..sql.selectable import SelectBase + from ..sql.type_api import TypeEngine _path_registry = PathRegistry.root @@ -211,6 +214,9 @@ class ORMCompileState(CompileState): _for_refresh_state = False _render_for_subquery = False + attributes: Dict[Any, Any] + global_attributes: Dict[Any, Any] + statement: Union[Select, FromStatement] select_statement: Union[Select, FromStatement] _entities: List[_QueryEntity] @@ -1930,7 +1936,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): assert right_mapper adapter = ORMAdapter( - right, equivalents=right_mapper._equivalent_columns + inspect(right), equivalents=right_mapper._equivalent_columns ) # if an alias() on the right side was generated, @@ -2075,14 +2081,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _column_descriptions( - query_or_select_stmt, compile_state=None, legacy=False + query_or_select_stmt: Union[Query, Select, FromStatement], + compile_state: Optional[ORMSelectCompileState] = None, + legacy: bool = False, ) -> List[ORMColumnDescription]: if compile_state is None: compile_state = ORMSelectCompileState._create_entities_collection( query_or_select_stmt, legacy=legacy ) ctx = compile_state - return [ + d = [ { "name": ent._label_name, "type": ent.type, @@ -2093,17 +2101,10 @@ def _column_descriptions( else None, } for ent, insp_ent in [ - ( - _ent, - ( - inspect(_ent.entity_zero) - if _ent.entity_zero is not None - else None - ), - ) - for _ent in ctx._entities + (_ent, _ent.entity_zero) for _ent in ctx._entities ] ] + return d def _legacy_filter_by_entity_zero(query_or_augmented_select): @@ -2157,6 +2158,11 @@ class _QueryEntity: _null_column_type = False use_id_for_hash = False + _label_name: Optional[str] + type: Union[Type[Any], TypeEngine[Any]] + expr: Union[_InternalEntityType, ColumnElement[Any]] + entity_zero: Optional[_InternalEntityType] + def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() @@ -2234,6 +2240,13 @@ class _MapperEntity(_QueryEntity): "_polymorphic_discriminator", ) + expr: _InternalEntityType + mapper: Mapper[Any] + entity_zero: _InternalEntityType + is_aliased_class: bool + path: PathRegistry + _label_name: str + def __init__( self, compile_state, entity, entities_collection, is_current_entities ): @@ -2389,6 +2402,13 @@ class _BundleEntity(_QueryEntity): "supports_single_entity", ) + _entities: List[_QueryEntity] + bundle: Bundle + type: Type[Any] + _label_name: str + supports_single_entity: bool + expr: Bundle + def __init__( self, compile_state, diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 70507015b..0c990f809 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -50,7 +50,7 @@ from ..util import hybridproperty from ..util import typing as compat_typing if typing.TYPE_CHECKING: - from .state import InstanceState # noqa + from .state import InstanceState _T = TypeVar("_T", bound=Any) @@ -280,7 +280,7 @@ class declared_attr(interfaces._MappedAttribute[_T]): # for the span of the declarative scan_attributes() phase. # to achieve this we look at the class manager that's configured. cls = owner - manager = attributes.manager_of_class(cls) + manager = attributes.opt_manager_of_class(cls) if manager is None: if not re.match(r"^__.+__$", self.fget.__name__): # if there is no manager at all, then this class hasn't been @@ -1294,8 +1294,8 @@ def as_declarative(**kw): @inspection._inspects( DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept ) -def _inspect_decl_meta(cls): - mp = _inspect_mapped_class(cls) +def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]: + mp: Mapper[Any] = _inspect_mapped_class(cls) if mp is None: if _DeferredMapperConfig.has_cls(cls): _DeferredMapperConfig.raise_unmapped_for_cls(cls) diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 804d05ce1..9c79a4172 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -12,6 +12,8 @@ import collections from typing import Any from typing import Dict from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING import weakref from . import attributes @@ -42,6 +44,10 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +if TYPE_CHECKING: + from ._typing import _O + from ._typing import _RegistryType + def _declared_mapping_info(cls): # deferred mapping @@ -121,7 +127,7 @@ def _dive_for_cls_manager(cls): return None for base in cls.__mro__: - manager = attributes.manager_of_class(base) + manager = attributes.opt_manager_of_class(base) if manager: return manager return None @@ -171,7 +177,7 @@ class _MapperConfig: @classmethod def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw): - manager = attributes.manager_of_class(cls) + manager = attributes.opt_manager_of_class(cls) if manager and manager.class_ is cls_: raise exc.InvalidRequestError( "Class %r already has been " "instrumented declaratively" % cls @@ -191,7 +197,12 @@ class _MapperConfig: return cfg_cls(registry, cls_, dict_, table, mapper_kw) - def __init__(self, registry, cls_, mapper_kw): + def __init__( + self, + registry: _RegistryType, + cls_: Type[Any], + mapper_kw: Dict[str, Any], + ): self.cls = util.assert_arg_type(cls_, type, "cls_") self.classname = cls_.__name__ self.properties = util.OrderedDict() @@ -206,7 +217,7 @@ class _MapperConfig: init_method=registry.constructor, ) else: - manager = attributes.manager_of_class(self.cls) + manager = attributes.opt_manager_of_class(self.cls) if not manager or not manager.is_mapped: raise exc.InvalidRequestError( "Class %s has no primary mapper configured. Configure " diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 8beac472e..4738d8c2c 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -122,7 +122,11 @@ class DescriptorProperty(MapperProperty[_T]): _CompositeAttrType = Union[ - str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]" + str, + "Column[_T]", + "MappedColumn[_T]", + "InstrumentedAttribute[_T]", + "Mapped[_T]", ] diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index c531e7cf1..331c224ee 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -11,6 +11,9 @@ from __future__ import annotations from typing import Any +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING import weakref from . import instrumentation @@ -27,6 +30,10 @@ from .. import exc from .. import util from ..util.compat import inspect_getfullargspec +if TYPE_CHECKING: + from ._typing import _O + from .instrumentation import ClassManager + class InstrumentationEvents(event.Events): """Events related to class instrumentation events. @@ -214,7 +221,7 @@ class InstanceEvents(event.Events): if issubclass(target, mapperlib.Mapper): return instrumentation.ClassManager else: - manager = instrumentation.manager_of_class(target) + manager = instrumentation.opt_manager_of_class(target) if manager: return manager else: @@ -613,8 +620,8 @@ class _EventsHold(event.RefCollection): class _InstanceEventsHold(_EventsHold): all_holds = weakref.WeakKeyDictionary() - def resolve(self, class_): - return instrumentation.manager_of_class(class_) + def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: + return instrumentation.opt_manager_of_class(class_) class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents): pass diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index 00829ecbb..529a7cd01 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -203,7 +203,10 @@ def _default_unmapped(cls) -> Optional[str]: try: mappers = base.manager_of_class(cls).mappers - except (TypeError,) + NO_STATE: + except ( + UnmappedClassError, + TypeError, + ) + NO_STATE: mappers = {} name = _safe_cls_name(cls) diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 0d4b630da..88ceacd07 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -33,10 +33,13 @@ alternate instrumentation forms. from __future__ import annotations from typing import Any +from typing import Callable from typing import Dict from typing import Generic from typing import Optional from typing import Set +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar import weakref @@ -53,7 +56,9 @@ from ..util import HasMemoized from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _RegistryType from .attributes import InstrumentedAttribute + from .decl_base import _MapperConfig from .mapper import Mapper from .state import InstanceState from ..event import dispatcher @@ -72,6 +77,11 @@ class _ExpiredAttributeLoaderProto(Protocol): ... +class _ManagerFactory(Protocol): + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: + ... + + class ClassManager( HasMemoized, Dict[str, "InstrumentedAttribute[Any]"], @@ -90,12 +100,12 @@ class ClassManager( expired_attribute_loader: _ExpiredAttributeLoaderProto "previously known as deferred_scalar_loader" - init_method = None + init_method: Optional[Callable[..., None]] - factory = None + factory: Optional[_ManagerFactory] - declarative_scan = None - registry = None + declarative_scan: Optional[weakref.ref[_MapperConfig]] = None + registry: Optional[_RegistryType] = None @property @util.deprecated( @@ -122,11 +132,13 @@ class ClassManager( self.local_attrs = {} self.originals = {} self._finalized = False + self.factory = None + self.init_method = None self._bases = [ mgr for mgr in [ - manager_of_class(base) + opt_manager_of_class(base) for base in self.class_.__bases__ if isinstance(base, type) ] @@ -139,7 +151,7 @@ class ClassManager( self.dispatch._events._new_classmanager_instance(class_, self) for basecls in class_.__mro__: - mgr = manager_of_class(basecls) + mgr = opt_manager_of_class(basecls) if mgr is not None: self.dispatch._update(mgr.dispatch) @@ -155,16 +167,18 @@ class ClassManager( def _update_state( self, - finalize=False, - mapper=None, - registry=None, - declarative_scan=None, - expired_attribute_loader=None, - init_method=None, - ): + finalize: bool = False, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[ + _ExpiredAttributeLoaderProto + ] = None, + init_method: Optional[Callable[..., None]] = None, + ) -> None: if mapper: - self.mapper = mapper + self.mapper = mapper # type: ignore[assignment] if registry: registry._add_manager(self) if declarative_scan: @@ -350,7 +364,7 @@ class ClassManager( def subclass_managers(self, recursive): for cls in self.class_.__subclasses__(): - mgr = manager_of_class(cls) + mgr = opt_manager_of_class(cls) if mgr is not None and mgr is not self: yield mgr if recursive: @@ -374,7 +388,7 @@ class ClassManager( self._reset_memoizations() del self[key] for cls in self.class_.__subclasses__(): - manager = manager_of_class(cls) + manager = opt_manager_of_class(cls) if manager: manager.uninstrument_attribute(key, True) @@ -523,7 +537,7 @@ class _SerializeManager: manager.dispatch.pickle(state, d) def __call__(self, state, inst, state_dict): - state.manager = manager = manager_of_class(self.class_) + state.manager = manager = opt_manager_of_class(self.class_) if manager is None: raise exc.UnmappedInstanceError( inst, @@ -546,9 +560,9 @@ class _SerializeManager: class InstrumentationFactory: """Factory for new ClassManager instances.""" - def create_manager_for_cls(self, class_): + def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]: assert class_ is not None - assert manager_of_class(class_) is None + assert opt_manager_of_class(class_) is None # give a more complicated subclass # a chance to do what it wants here @@ -557,6 +571,8 @@ class InstrumentationFactory: if factory is None: factory = ClassManager manager = factory(class_) + else: + assert manager is not None self._check_conflicts(class_, factory) @@ -564,11 +580,15 @@ class InstrumentationFactory: return manager - def _locate_extended_factory(self, class_): + def _locate_extended_factory( + self, class_: Type[_O] + ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]: """Overridden by a subclass to do an extended lookup.""" return None, None - def _check_conflicts(self, class_, factory): + def _check_conflicts( + self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]] + ): """Overridden by a subclass to test for conflicting factories.""" return @@ -590,24 +610,25 @@ instance_state = _default_state_getter = base.instance_state instance_dict = _default_dict_getter = base.instance_dict manager_of_class = _default_manager_getter = base.manager_of_class +opt_manager_of_class = _default_opt_manager_getter = base.opt_manager_of_class def register_class( - class_, - finalize=True, - mapper=None, - registry=None, - declarative_scan=None, - expired_attribute_loader=None, - init_method=None, -): + class_: Type[_O], + finalize: bool = True, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None, + init_method: Optional[Callable[..., None]] = None, +) -> ClassManager[_O]: """Register class instrumentation. Returns the existing or newly created class manager. """ - manager = manager_of_class(class_) + manager = opt_manager_of_class(class_) if manager is None: manager = _instrumentation_factory.create_manager_for_cls(class_) manager._update_state( diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index abc1300d8..0ca62b7e3 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -21,10 +21,15 @@ from __future__ import annotations import collections import typing from typing import Any +from typing import Callable from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterator from typing import List from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TypeVar @@ -45,7 +50,6 @@ from .base import NotExtension as NotExtension from .base import ONETOMANY as ONETOMANY from .base import SQLORMOperations from .. import ColumnElement -from .. import inspect from .. import inspection from .. import util from ..sql import operators @@ -53,19 +57,47 @@ from ..sql import roles from ..sql import visitors from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey -from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.type_api import TypeEngine from ..util.typing import TypedDict + if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _ORMAdapterProto + from ._typing import _ORMColumnExprArgument + from .attributes import InstrumentedAttribute + from .context import _MapperEntity + from .context import ORMCompileState from .decl_api import RegistryType + from .loading import _PopulatorDict + from .mapper import Mapper + from .path_registry import AbstractEntityRegistry + from .path_registry import PathRegistry + from .query import Query + from .session import Session + from .state import InstanceState + from .strategy_options import _LoadElement + from .util import AliasedInsp + from .util import CascadeOptions + from .util import ORMAdapter + from ..engine.result import Result + from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _DMLColumnArgument from ..sql._typing import _InfoType + from ..sql._typing import _PropagateAttrsType + from ..sql.operators import OperatorType + from ..sql.util import ColumnAdapter + from ..sql.visitors import _TraverseInternalsType _T = TypeVar("_T", bound=Any) +_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") + class ORMStatementRole(roles.StatementRole): __slots__ = () @@ -91,7 +123,9 @@ class ORMFromClauseRole(roles.StrictFromClauseRole): class ORMColumnDescription(TypedDict): name: str - type: Union[Type, TypeEngine] + # TODO: add python_type and sql_type here; combining them + # into "type" is a bad idea + type: Union[Type[Any], TypeEngine[Any]] aliased: bool expr: _ColumnsClauseArgument entity: Optional[_ColumnsClauseArgument] @@ -102,10 +136,10 @@ class _IntrospectsAnnotations: def declarative_scan( self, - registry: "RegistryType", - cls: type, + registry: RegistryType, + cls: Type[Any], key: str, - annotation: Optional[type], + annotation: Optional[Type[Any]], is_dataclass_field: Optional[bool], ) -> None: """Perform class-specific initializaton at early declarative scanning @@ -124,12 +158,12 @@ class _MapsColumns(_MappedAttribute[_T]): __slots__ = () @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: """return a MapperProperty to be assigned to the declarative mapping""" raise NotImplementedError() @property - def columns_to_assign(self) -> List[Column]: + def columns_to_assign(self) -> List[Column[_T]]: """A list of Column objects that should be declaratively added to the new Table object. @@ -139,7 +173,10 @@ class _MapsColumns(_MappedAttribute[_T]): @inspection._self_inspects class MapperProperty( - HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots + HasCacheKey, + _MappedAttribute[_T], + InspectionAttrInfo, + util.MemoizedSlots, ): """Represent a particular class attribute mapped by :class:`_orm.Mapper`. @@ -160,12 +197,12 @@ class MapperProperty( "info", ) - _cache_key_traversal = [ + _cache_key_traversal: _TraverseInternalsType = [ ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), ("key", visitors.ExtendedInternalTraversal.dp_string), ] - cascade = frozenset() + cascade: Optional[CascadeOptions] = None """The set of 'cascade' attribute names. This collection is checked before the 'cascade_iterator' method is called. @@ -184,14 +221,20 @@ class MapperProperty( """The :class:`_orm.PropComparator` instance that implements SQL expression construction on behalf of this mapped attribute.""" - @property - def _links_to_entity(self): - """True if this MapperProperty refers to a mapped entity. + key: str + """name of class attribute""" - Should only be True for Relationship, False for all others. + parent: Mapper[Any] + """the :class:`.Mapper` managing this property.""" - """ - raise NotImplementedError() + _is_relationship = False + + _links_to_entity: bool + """True if this MapperProperty refers to a mapped entity. + + Should only be True for Relationship, False for all others. + + """ def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined @@ -217,7 +260,14 @@ class MapperProperty( """ return {} - def setup(self, context, query_entity, path, adapter, **kwargs): + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: PathRegistry, + adapter: Optional[ColumnAdapter], + **kwargs: Any, + ) -> None: """Called by Query for the purposes of constructing a SQL statement. Each MapperProperty associated with the target mapper processes the @@ -227,16 +277,30 @@ class MapperProperty( """ def create_row_processor( - self, context, query_entity, path, mapper, result, adapter, populators - ): + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: PathRegistry, + mapper: Mapper[Any], + result: Result, + adapter: Optional[ColumnAdapter], + populators: _PopulatorDict, + ) -> None: """Produce row processing functions and append to the given set of populators lists. """ def cascade_iterator( - self, type_, state, dict_, visited_states, halt_on=None - ): + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: """Iterate through instances related to the given instance for a particular 'cascade', starting with this MapperProperty. @@ -251,7 +315,7 @@ class MapperProperty( return iter(()) - def set_parent(self, parent, init): + def set_parent(self, parent: Mapper[Any], init: bool) -> None: """Set the parent mapper that references this MapperProperty. This method is overridden by some subclasses to perform extra @@ -260,7 +324,7 @@ class MapperProperty( """ self.parent = parent - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: """Hook called by the Mapper to the property to initiate instrumentation of the class attribute managed by this MapperProperty. @@ -280,11 +344,11 @@ class MapperProperty( """ - def __init__(self): + def __init__(self) -> None: self._configure_started = False self._configure_finished = False - def init(self): + def init(self) -> None: """Called after all mappers are created to assemble relationships between mappers and perform other post-mapper-creation initialization steps. @@ -296,7 +360,7 @@ class MapperProperty( self._configure_finished = True @property - def class_attribute(self): + def class_attribute(self) -> InstrumentedAttribute[_T]: """Return the class-bound descriptor corresponding to this :class:`.MapperProperty`. @@ -319,9 +383,9 @@ class MapperProperty( """ - return getattr(self.parent.class_, self.key) + return getattr(self.parent.class_, self.key) # type: ignore - def do_init(self): + def do_init(self) -> None: """Perform subclass-specific initialization post-mapper-creation steps. @@ -330,7 +394,7 @@ class MapperProperty( """ - def post_instrument_class(self, mapper): + def post_instrument_class(self, mapper: Mapper[Any]) -> None: """Perform instrumentation adjustments that need to occur after init() has completed. @@ -347,21 +411,21 @@ class MapperProperty( def merge( self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, - _recursive, - _resolve_conflict_map, - ): + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Set[InstanceState[Any]], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: """Merge the attribute represented by this ``MapperProperty`` from source to destination object. """ - def __repr__(self): + def __repr__(self) -> str: return "<%s at 0x%x; %s>" % ( self.__class__.__name__, id(self), @@ -452,21 +516,28 @@ class PropComparator(SQLORMOperations[_T]): """ - __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" + __slots__ = "prop", "_parententity", "_adapt_to_entity" __visit_name__ = "orm_prop_comparator" + _parententity: _InternalEntityType[Any] + _adapt_to_entity: Optional[AliasedInsp[Any]] + def __init__( self, - prop, - parentmapper, - adapt_to_entity=None, + prop: MapperProperty[_T], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, ): - self.prop = self.property = prop + self.prop = prop self._parententity = adapt_to_entity or parentmapper self._adapt_to_entity = adapt_to_entity - def __clause_element__(self): + @util.ro_non_memoized_property + def property(self) -> Optional[MapperProperty[_T]]: + return self.prop + + def __clause_element__(self) -> _ORMColumnExprArgument[_T]: raise NotImplementedError("%r" % self) def _bulk_update_tuples( @@ -480,22 +551,24 @@ class PropComparator(SQLORMOperations[_T]): """ - return [(self.__clause_element__(), value)] + return [(cast("_DMLColumnArgument", self.__clause_element__()), value)] - def adapt_to_entity(self, adapt_to_entity): + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> PropComparator[_T]: """Return a copy of this PropComparator which will use the given :class:`.AliasedInsp` to produce corresponding expressions. """ return self.__class__(self.prop, self._parententity, adapt_to_entity) - @property - def _parentmapper(self): + @util.ro_non_memoized_property + def _parentmapper(self) -> Mapper[Any]: """legacy; this is renamed to _parententity to be compatible with QueryableAttribute.""" - return inspect(self._parententity).mapper + return self._parententity.mapper - @property - def _propagate_attrs(self): + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: # this suits the case in coercions where we don't actually # call ``__clause_element__()`` but still need to get # resolved._propagate_attrs. See #6558. @@ -507,12 +580,14 @@ class PropComparator(SQLORMOperations[_T]): ) def _criterion_exists( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[Any]: return self.prop.comparator._criterion_exists(criterion, **kwargs) - @property - def adapter(self): + @util.ro_non_memoized_property + def adapter(self) -> Optional[_ORMAdapterProto[_T]]: """Produce a callable that adapts column expressions to suit an aliased version of this comparator. @@ -522,20 +597,20 @@ class PropComparator(SQLORMOperations[_T]): else: return self._adapt_to_entity._adapt_element - @util.non_memoized_property + @util.ro_non_memoized_property def info(self) -> _InfoType: - return self.property.info + return self.prop.info @staticmethod - def _any_op(a, b, **kwargs): + def _any_op(a: Any, b: Any, **kwargs: Any) -> Any: return a.any(b, **kwargs) @staticmethod - def _has_op(left, other, **kwargs): + def _has_op(left: Any, other: Any, **kwargs: Any) -> Any: return left.has(other, **kwargs) @staticmethod - def _of_type_op(a, class_): + def _of_type_op(a: Any, class_: Any) -> Any: return a.of_type(class_) any_op = cast(operators.OperatorType, _any_op) @@ -545,16 +620,16 @@ class PropComparator(SQLORMOperations[_T]): if typing.TYPE_CHECKING: def operate( - self, op: operators.OperatorType, *other: Any, **kwargs: Any - ) -> "SQLCoreOperations[Any]": + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... def reverse_operate( - self, op: operators.OperatorType, other: Any, **kwargs: Any - ) -> "SQLCoreOperations[Any]": + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... - def of_type(self, class_) -> "SQLORMOperations[_T]": + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: r"""Redefine this object in terms of a polymorphic subclass, :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` construct. @@ -578,9 +653,11 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.of_type_op, class_) + return self.operate(PropComparator.of_type_op, class_) # type: ignore - def and_(self, *criteria) -> "SQLORMOperations[_T]": + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> ColumnElement[bool]: """Add additional criteria to the ON clause that's represented by this relationship attribute. @@ -606,10 +683,12 @@ class PropComparator(SQLORMOperations[_T]): :func:`.with_loader_criteria` """ - return self.operate(operators.and_, *criteria) + return self.operate(operators.and_, *criteria) # type: ignore def any( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[bool]: r"""Return a SQL expression representing true if this element references a member which meets the given criterion. @@ -626,10 +705,14 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.any_op, criterion, **kwargs) + return self.operate( # type: ignore + PropComparator.any_op, criterion, **kwargs + ) def has( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[bool]: r"""Return a SQL expression representing true if this element references a member which meets the given criterion. @@ -646,7 +729,9 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.has_op, criterion, **kwargs) + return self.operate( # type: ignore + PropComparator.has_op, criterion, **kwargs + ) class StrategizedProperty(MapperProperty[_T]): @@ -674,23 +759,30 @@ class StrategizedProperty(MapperProperty[_T]): "strategy_key", ) inherit_cache = True - strategy_wildcard_key = None + strategy_wildcard_key: ClassVar[str] strategy_key: Tuple[Any, ...] - def _memoized_attr__wildcard_token(self): + _strategies: Dict[Tuple[Any, ...], LoaderStrategy] + + def _memoized_attr__wildcard_token(self) -> Tuple[str]: return ( f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", ) - def _memoized_attr__default_path_loader_key(self): + def _memoized_attr__default_path_loader_key( + self, + ) -> Tuple[str, Tuple[str]]: return ( "loader", (f"{self.strategy_wildcard_key}:{path_registry._DEFAULT_TOKEN}",), ) - def _get_context_loader(self, context, path): - load = None + def _get_context_loader( + self, context: ORMCompileState, path: AbstractEntityRegistry + ) -> Optional[_LoadElement]: + + load: Optional[_LoadElement] = None search_path = path[self] @@ -714,7 +806,7 @@ class StrategizedProperty(MapperProperty[_T]): return load - def _get_strategy(self, key): + def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy: try: return self._strategies[key] except KeyError: @@ -768,11 +860,13 @@ class StrategizedProperty(MapperProperty[_T]): ): self.strategy.init_class_attribute(mapper) - _all_strategies = collections.defaultdict(dict) + _all_strategies: collections.defaultdict[ + Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]] + ] = collections.defaultdict(dict) @classmethod - def strategy_for(cls, **kw): - def decorate(dec_cls): + def strategy_for(cls, **kw: Any) -> Callable[[_TLS], _TLS]: + def decorate(dec_cls: _TLS) -> _TLS: # ensure each subclass of the strategy has its # own _strategy_keys collection if "_strategy_keys" not in dec_cls.__dict__: @@ -785,7 +879,9 @@ class StrategizedProperty(MapperProperty[_T]): return decorate @classmethod - def _strategy_lookup(cls, requesting_property, *key): + def _strategy_lookup( + cls, requesting_property: MapperProperty[Any], *key: Any + ) -> Type[LoaderStrategy]: requesting_property.parent._with_polymorphic_mappers for prop_cls in cls.__mro__: @@ -984,10 +1080,10 @@ class MapperOption(ORMOption): """ - def process_query(self, query): + def process_query(self, query: Query[Any]) -> None: """Apply a modification to the given :class:`_query.Query`.""" - def process_query_conditionally(self, query): + def process_query_conditionally(self, query: Query[Any]) -> None: """same as process_query(), except that this option may not apply to the given query. @@ -1034,7 +1130,11 @@ class LoaderStrategy: "strategy_opts", ) - def __init__(self, parent, strategy_key): + _strategy_keys: ClassVar[List[Tuple[Any, ...]]] + + def __init__( + self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...] + ): self.parent_property = parent self.is_class_level = False self.parent = self.parent_property.parent @@ -1042,12 +1142,18 @@ class LoaderStrategy: self.strategy_key = strategy_key self.strategy_opts = dict(strategy_key) - def init_class_attribute(self, mapper): + def init_class_attribute(self, mapper: Mapper[Any]) -> None: pass def setup_query( - self, compile_state, query_entity, path, loadopt, adapter, **kwargs - ): + self, + compile_state: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: """Establish column and other state for a given QueryContext. This method fulfills the contract specified by MapperProperty.setup(). @@ -1059,15 +1165,15 @@ class LoaderStrategy: def create_row_processor( self, - context, - query_entity, - path, - loadopt, - mapper, - result, - adapter, - populators, - ): + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + mapper: Mapper[Any], + result: Result, + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: """Establish row processing functions for a given QueryContext. This method fulfills the contract specified by diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index ae083054c..d9949eb7a 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -16,7 +16,9 @@ as well as some of the attribute loading strategies. from __future__ import annotations from typing import Any +from typing import Dict from typing import Iterable +from typing import List from typing import Mapping from typing import Optional from typing import Sequence @@ -65,6 +67,9 @@ _O = TypeVar("_O", bound=object) _new_runid = util.counter() +_PopulatorDict = Dict[str, List[Tuple[str, Any]]] + + def instances(cursor, context): """Return a :class:`.Result` given an ORM query context. @@ -383,7 +388,7 @@ def get_from_identity( mapper: Mapper[_O], key: _IdentityKeyType[_O], passive: PassiveFlag, -) -> Union[Optional[_O], LoaderCallableStatus]: +) -> Union[LoaderCallableStatus, Optional[_O]]: """Look up the given key in the given session's identity map, check the object for expired state if found. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index abe11cc68..b37c080ea 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -23,12 +23,23 @@ import sys import threading from typing import Any from typing import Callable +from typing import cast +from typing import Collection +from typing import Deque +from typing import Dict from typing import Generic +from typing import Iterable from typing import Iterator +from typing import List +from typing import Mapping from typing import Optional +from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from . import attributes @@ -39,8 +50,8 @@ from . import properties from . import util as orm_util from ._typing import _O from .base import _class_to_mapper +from .base import _parse_mapper_argument from .base import _state_mapper -from .base import class_mapper from .base import PassiveFlag from .base import state_str from .interfaces import _MappedAttribute @@ -58,6 +69,8 @@ from .. import log from .. import schema from .. import sql from .. import util +from ..event import dispatcher +from ..event import EventTarget from ..sql import base as sql_base from ..sql import coercions from ..sql import expression @@ -65,26 +78,68 @@ from ..sql import operators from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.schema import Table from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import HasMemoized +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _IdentityKeyType from ._typing import _InstanceDict + from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .decl_api import registry + from .dependency import DependencyProcessor + from .descriptor_props import Composite + from .descriptor_props import Synonym + from .events import MapperEvents from .instrumentation import ClassManager + from .path_registry import AbstractEntityRegistry + from .path_registry import CachingEntityRegistry + from .properties import ColumnProperty + from .relationships import Relationship from .state import InstanceState + from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement from ..sql.schema import Column + from ..sql.schema import Table + from ..sql.selectable import FromClause + from ..sql.selectable import TableClause + from ..sql.util import ColumnAdapter + from ..util import OrderedSet -_mapper_registries = weakref.WeakKeyDictionary() +_T = TypeVar("_T", bound=Any) +_MP = TypeVar("_MP", bound="MapperProperty[Any]") -def _all_registries(): +_WithPolymorphicArg = Union[ + Literal["*"], + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ], + Sequence[Union["Mapper[Any]", Type[Any]]], +] + + +_mapper_registries: weakref.WeakKeyDictionary[ + _RegistryType, bool +] = weakref.WeakKeyDictionary() + + +def _all_registries() -> Set[registry]: with _CONFIGURE_MUTEX: return set(_mapper_registries) -def _unconfigured_mappers(): +def _unconfigured_mappers() -> Iterator[Mapper[Any]]: for reg in _all_registries(): for mapper in reg._mappers_to_configure(): yield mapper @@ -107,9 +162,11 @@ _CONFIGURE_MUTEX = threading.RLock() class Mapper( ORMFromClauseRole, ORMEntityColumnsClauseRole, - sql_base.MemoizedHasCacheKey, + MemoizedHasCacheKey, InspectionAttr, log.Identified, + inspection.Inspectable["Mapper[_O]"], + EventTarget, Generic[_O], ): """Defines an association between a Python class and a database table or @@ -123,18 +180,11 @@ class Mapper( """ + dispatch: dispatcher[Mapper[_O]] + _dispose_called = False _ready_for_configure = False - class_: Type[_O] - """The class to which this :class:`_orm.Mapper` is mapped.""" - - _identity_class: Type[_O] - - always_refresh: bool - allow_partial_pks: bool - version_id_col: Optional[ColumnElement[Any]] - @util.deprecated_params( non_primary=( "1.3", @@ -148,33 +198,39 @@ class Mapper( def __init__( self, class_: Type[_O], - local_table=None, - properties=None, - primary_key=None, - non_primary=False, - inherits=None, - inherit_condition=None, - inherit_foreign_keys=None, - always_refresh=False, - version_id_col=None, - version_id_generator=None, - polymorphic_on=None, - _polymorphic_map=None, - polymorphic_identity=None, - concrete=False, - with_polymorphic=None, - polymorphic_load=None, - allow_partial_pks=True, - batch=True, - column_prefix=None, - include_properties=None, - exclude_properties=None, - passive_updates=True, - passive_deletes=False, - confirm_deleted_rows=True, - eager_defaults=False, - legacy_is_orphan=False, - _compiled_cache_size=100, + local_table: Optional[FromClause] = None, + properties: Optional[Mapping[str, MapperProperty[Any]]] = None, + primary_key: Optional[Iterable[_ORMColumnExprArgument[Any]]] = None, + non_primary: bool = False, + inherits: Optional[Union[Mapper[Any], Type[Any]]] = None, + inherit_condition: Optional[_ColumnExpressionArgument[bool]] = None, + inherit_foreign_keys: Optional[ + Sequence[_ORMColumnExprArgument[Any]] + ] = None, + always_refresh: bool = False, + version_id_col: Optional[_ORMColumnExprArgument[Any]] = None, + version_id_generator: Optional[ + Union[Literal[False], Callable[[Any], Any]] + ] = None, + polymorphic_on: Optional[ + Union[_ORMColumnExprArgument[Any], str, MapperProperty[Any]] + ] = None, + _polymorphic_map: Optional[Dict[Any, Mapper[Any]]] = None, + polymorphic_identity: Optional[Any] = None, + concrete: bool = False, + with_polymorphic: Optional[_WithPolymorphicArg] = None, + polymorphic_load: Optional[Literal["selectin", "inline"]] = None, + allow_partial_pks: bool = True, + batch: bool = True, + column_prefix: Optional[str] = None, + include_properties: Optional[Sequence[str]] = None, + exclude_properties: Optional[Sequence[str]] = None, + passive_updates: bool = True, + passive_deletes: bool = False, + confirm_deleted_rows: bool = True, + eager_defaults: bool = False, + legacy_is_orphan: bool = False, + _compiled_cache_size: int = 100, ): r"""Direct constructor for a new :class:`_orm.Mapper` object. @@ -593,8 +649,6 @@ class Mapper( self.class_.__name__, ) - self.class_manager = None - self._primary_key_argument = util.to_list(primary_key) self.non_primary = non_primary @@ -623,17 +677,36 @@ class Mapper( self.concrete = concrete self.single = False - self.inherits = inherits + + if inherits is not None: + self.inherits = _parse_mapper_argument(inherits) + else: + self.inherits = None + if local_table is not None: self.local_table = coercions.expect( roles.StrictFromClauseRole, local_table ) + elif self.inherits: + # note this is a new flow as of 2.0 so that + # .local_table need not be Optional + self.local_table = self.inherits.local_table + self.single = True else: - self.local_table = None + raise sa_exc.ArgumentError( + f"Mapper[{self.class_.__name__}(None)] has None for a " + "primary table argument and does not specify 'inherits'" + ) + + if inherit_condition is not None: + self.inherit_condition = coercions.expect( + roles.OnClauseRole, inherit_condition + ) + else: + self.inherit_condition = None - self.inherit_condition = inherit_condition self.inherit_foreign_keys = inherit_foreign_keys - self._init_properties = properties or {} + self._init_properties = dict(properties) if properties else {} self._delete_orphans = [] self.batch = batch self.eager_defaults = eager_defaults @@ -694,7 +767,10 @@ class Mapper( # while a configure_mappers() is occurring (and defer a # configure_mappers() until construction succeeds) with _CONFIGURE_MUTEX: - self.dispatch._events._new_mapper_instance(class_, self) + + cast("MapperEvents", self.dispatch._events)._new_mapper_instance( + class_, self + ) self._configure_inheritance() self._configure_class_instrumentation() self._configure_properties() @@ -704,16 +780,21 @@ class Mapper( self._log("constructed") self._expire_memoizations() - # major attributes initialized at the classlevel so that - # they can be Sphinx-documented. + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + # ### BEGIN + # ATTRIBUTE DECLARATIONS START HERE is_mapper = True """Part of the inspection API.""" represents_outer_join = False + registry: _RegistryType + @property - def mapper(self): + def mapper(self) -> Mapper[_O]: """Part of the inspection API. Returns self. @@ -721,9 +802,6 @@ class Mapper( """ return self - def _gen_cache_key(self, anon_map, bindparams): - return (self,) - @property def entity(self): r"""Part of the inspection API. @@ -733,49 +811,109 @@ class Mapper( """ return self.class_ - local_table = None - """The :class:`_expression.Selectable` which this :class:`_orm.Mapper` - manages. + class_: Type[_O] + """The class to which this :class:`_orm.Mapper` is mapped.""" + + _identity_class: Type[_O] + + _delete_orphans: List[Tuple[str, Type[Any]]] + _dependency_processors: List[DependencyProcessor] + _memoized_values: Dict[Any, Callable[[], Any]] + _inheriting_mappers: util.WeakSequence[Mapper[Any]] + _all_tables: Set[Table] + + _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] + _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] + + _props: util.OrderedDict[str, MapperProperty[Any]] + _init_properties: Dict[str, MapperProperty[Any]] + + _columntoproperty: _ColumnMapping + + _set_polymorphic_identity: Optional[Callable[[InstanceState[_O]], None]] + _validate_polymorphic_identity: Optional[ + Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None] + ] + + tables: Sequence[Table] + """A sequence containing the collection of :class:`_schema.Table` objects + which this :class:`_orm.Mapper` is aware of. + + If the mapper is mapped to a :class:`_expression.Join`, or an + :class:`_expression.Alias` + representing a :class:`_expression.Select`, the individual + :class:`_schema.Table` + objects that comprise the full construct will be represented here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + validators: util.immutabledict[str, Tuple[str, Dict[str, Any]]] + """An immutable dictionary of attributes which have been decorated + using the :func:`_orm.validates` decorator. + + The dictionary contains string attribute names as keys + mapped to the actual validation method. + + """ + + always_refresh: bool + allow_partial_pks: bool + version_id_col: Optional[ColumnElement[Any]] + + with_polymorphic: Optional[ + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ] + ] + + version_id_generator: Optional[Union[Literal[False], Callable[[Any], Any]]] + + local_table: FromClause + """The immediate :class:`_expression.FromClause` which this + :class:`_orm.Mapper` refers towards. - Typically is an instance of :class:`_schema.Table` or - :class:`_expression.Alias`. - May also be ``None``. + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. The "local" table is the selectable that the :class:`_orm.Mapper` is directly responsible for managing from an attribute access and flush perspective. For - non-inheriting mappers, the local table is the same as the - "mapped" table. For joined-table inheritance mappers, local_table - will be the particular sub-table of the overall "join" which - this :class:`_orm.Mapper` represents. If this mapper is a - single-table inheriting mapper, local_table will be ``None``. + non-inheriting mappers, :attr:`.Mapper.local_table` will be the same + as :attr:`.Mapper.persist_selectable`. For inheriting mappers, + :attr:`.Mapper.local_table` refers to the specific portion of + :attr:`.Mapper.persist_selectable` that includes the columns to which + this :class:`.Mapper` is loading/persisting, such as a particular + :class:`.Table` within a join. .. seealso:: :attr:`_orm.Mapper.persist_selectable`. + :attr:`_orm.Mapper.selectable`. + """ - persist_selectable = None - """The :class:`_expression.Selectable` to which this :class:`_orm.Mapper` + persist_selectable: FromClause + """The :class:`_expression.FromClause` to which this :class:`_orm.Mapper` is mapped. - Typically an instance of :class:`_schema.Table`, - :class:`_expression.Join`, or :class:`_expression.Alias`. - - The :attr:`_orm.Mapper.persist_selectable` is separate from - :attr:`_orm.Mapper.selectable` in that the former represents columns - that are mapped on this class or its superclasses, whereas the - latter may be a "polymorphic" selectable that contains additional columns - which are in fact mapped on subclasses only. + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. - "persist selectable" is the "thing the mapper writes to" and - "selectable" is the "thing the mapper selects from". - - :attr:`_orm.Mapper.persist_selectable` is also separate from - :attr:`_orm.Mapper.local_table`, which represents the set of columns that - are locally mapped on this class directly. + The :attr:`_orm.Mapper.persist_selectable` is similar to + :attr:`.Mapper.local_table`, but represents the :class:`.FromClause` that + represents the inheriting class hierarchy overall in an inheritance + scenario. + :attr.`.Mapper.persist_selectable` is also separate from the + :attr:`.Mapper.selectable` attribute, the latter of which may be an + alternate subquery used for selecting columns. + :attr.`.Mapper.persist_selectable` is oriented towards columns that + will be written on a persist operation. .. seealso:: @@ -785,16 +923,15 @@ class Mapper( """ - inherits = None + inherits: Optional[Mapper[Any]] """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper` inherits from, if any. - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - """ - configured = False + inherit_condition: Optional[ColumnElement[bool]] + + configured: bool = False """Represent ``True`` if this :class:`_orm.Mapper` has been configured. This is a *read only* attribute determined during mapper construction. @@ -806,7 +943,7 @@ class Mapper( """ - concrete = None + concrete: bool """Represent ``True`` if this :class:`_orm.Mapper` is a concrete inheritance mapper. @@ -815,21 +952,6 @@ class Mapper( """ - tables = None - """An iterable containing the collection of :class:`_schema.Table` objects - which this :class:`_orm.Mapper` is aware of. - - If the mapper is mapped to a :class:`_expression.Join`, or an - :class:`_expression.Alias` - representing a :class:`_expression.Select`, the individual - :class:`_schema.Table` - objects that comprise the full construct will be represented here. - - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - - """ - primary_key: Tuple[Column[Any], ...] """An iterable containing the collection of :class:`_schema.Column` objects @@ -854,14 +976,6 @@ class Mapper( """ - class_: Type[_O] - """The Python class which this :class:`_orm.Mapper` maps. - - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - - """ - class_manager: ClassManager[_O] """The :class:`.ClassManager` which maintains event listeners and class-bound descriptors for this :class:`_orm.Mapper`. @@ -871,7 +985,7 @@ class Mapper( """ - single = None + single: bool """Represent ``True`` if this :class:`_orm.Mapper` is a single table inheritance mapper. @@ -882,7 +996,7 @@ class Mapper( """ - non_primary = None + non_primary: bool """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary" mapper, e.g. a mapper that is used only to select rows but not for persistence management. @@ -892,7 +1006,7 @@ class Mapper( """ - polymorphic_on = None + polymorphic_on: Optional[ColumnElement[Any]] """The :class:`_schema.Column` or SQL expression specified as the ``polymorphic_on`` argument for this :class:`_orm.Mapper`, within an inheritance scenario. @@ -906,7 +1020,7 @@ class Mapper( """ - polymorphic_map = None + polymorphic_map: Dict[Any, Mapper[Any]] """A mapping of "polymorphic identity" identifiers mapped to :class:`_orm.Mapper` instances, within an inheritance scenario. @@ -922,7 +1036,7 @@ class Mapper( """ - polymorphic_identity = None + polymorphic_identity: Optional[Any] """Represent an identifier which is matched against the :attr:`_orm.Mapper.polymorphic_on` column during result row loading. @@ -935,7 +1049,7 @@ class Mapper( """ - base_mapper = None + base_mapper: Mapper[Any] """The base-most :class:`_orm.Mapper` in an inheritance chain. In a non-inheriting scenario, this attribute will always be this @@ -948,7 +1062,7 @@ class Mapper( """ - columns = None + columns: ReadOnlyColumnCollection[str, Column[Any]] """A collection of :class:`_schema.Column` or other scalar expression objects maintained by this :class:`_orm.Mapper`. @@ -965,25 +1079,16 @@ class Mapper( """ - validators = None - """An immutable dictionary of attributes which have been decorated - using the :func:`_orm.validates` decorator. - - The dictionary contains string attribute names as keys - mapped to the actual validation method. - - """ - - c = None + c: ReadOnlyColumnCollection[str, Column[Any]] """A synonym for :attr:`_orm.Mapper.columns`.""" - @property + @util.non_memoized_property @util.deprecated("1.3", "Use .persist_selectable") def mapped_table(self): return self.persist_selectable @util.memoized_property - def _path_registry(self) -> PathRegistry: + def _path_registry(self) -> CachingEntityRegistry: return PathRegistry.per_mapper(self) def _configure_inheritance(self): @@ -994,8 +1099,6 @@ class Mapper( self._inheriting_mappers = util.WeakSequence() if self.inherits: - if isinstance(self.inherits, type): - self.inherits = class_mapper(self.inherits, configure=False) if not issubclass(self.class_, self.inherits.class_): raise sa_exc.ArgumentError( "Class '%s' does not inherit from '%s'" @@ -1011,11 +1114,9 @@ class Mapper( "only allowed from a %s mapper" % (np, self.class_.__name__, np) ) - # inherit_condition is optional. - if self.local_table is None: - self.local_table = self.inherits.local_table + + if self.single: self.persist_selectable = self.inherits.persist_selectable - self.single = True elif self.local_table is not self.inherits.local_table: if self.concrete: self.persist_selectable = self.local_table @@ -1068,6 +1169,7 @@ class Mapper( self.local_table.description, ) ) from afe + assert self.inherits.persist_selectable is not None self.persist_selectable = sql.join( self.inherits.persist_selectable, self.local_table, @@ -1149,6 +1251,7 @@ class Mapper( else: self._all_tables = set() self.base_mapper = self + assert self.local_table is not None self.persist_selectable = self.local_table if self.polymorphic_identity is not None: self.polymorphic_map[self.polymorphic_identity] = self @@ -1160,21 +1263,34 @@ class Mapper( % self ) - def _set_with_polymorphic(self, with_polymorphic): + def _set_with_polymorphic( + self, with_polymorphic: Optional[_WithPolymorphicArg] + ) -> None: if with_polymorphic == "*": self.with_polymorphic = ("*", None) elif isinstance(with_polymorphic, (tuple, list)): if isinstance(with_polymorphic[0], (str, tuple, list)): - self.with_polymorphic = with_polymorphic + self.with_polymorphic = cast( + """Tuple[ + Union[ + Literal["*"], + Sequence[Union["Mapper[Any]", Type[Any]]], + ], + Optional["FromClause"], + ]""", + with_polymorphic, + ) else: self.with_polymorphic = (with_polymorphic, None) elif with_polymorphic is not None: - raise sa_exc.ArgumentError("Invalid setting for with_polymorphic") + raise sa_exc.ArgumentError( + f"Invalid setting for with_polymorphic: {with_polymorphic!r}" + ) else: self.with_polymorphic = None if self.with_polymorphic and self.with_polymorphic[1] is not None: - self.with_polymorphic = ( + self.with_polymorphic = ( # type: ignore self.with_polymorphic[0], coercions.expect( roles.StrictFromClauseRole, @@ -1191,6 +1307,7 @@ class Mapper( if self.with_polymorphic is None: self._set_with_polymorphic((subcl,)) elif self.with_polymorphic[0] != "*": + assert isinstance(self.with_polymorphic[0], tuple) self._set_with_polymorphic( (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1]) ) @@ -1241,7 +1358,7 @@ class Mapper( # we expect that declarative has applied the class manager # already and set up a registry. if this is None, # this raises as of 2.0. - manager = attributes.manager_of_class(self.class_) + manager = attributes.opt_manager_of_class(self.class_) if self.non_primary: if not manager or not manager.is_mapped: @@ -1251,6 +1368,8 @@ class Mapper( "Mapper." % self.class_ ) self.class_manager = manager + + assert manager.registry is not None self.registry = manager.registry self._identity_class = manager.mapper._identity_class manager.registry._add_non_primary_mapper(self) @@ -1275,7 +1394,7 @@ class Mapper( manager = instrumentation.register_class( self.class_, mapper=self, - expired_attribute_loader=util.partial( + expired_attribute_loader=util.partial( # type: ignore loading.load_scalar_attributes, self ), # finalize flag means instrument the __init__ method @@ -1284,6 +1403,8 @@ class Mapper( ) self.class_manager = manager + + assert manager.registry is not None self.registry = manager.registry # The remaining members can be added by any mapper, @@ -1315,15 +1436,25 @@ class Mapper( {name: (method, validation_opts)} ) - def _set_dispose_flags(self): + def _set_dispose_flags(self) -> None: self.configured = True self._ready_for_configure = True self._dispose_called = True self.__dict__.pop("_configure_failed", None) - def _configure_pks(self): - self.tables = sql_util.find_tables(self.persist_selectable) + def _configure_pks(self) -> None: + self.tables = cast( + "List[Table]", sql_util.find_tables(self.persist_selectable) + ) + for t in self.tables: + if not isinstance(t, Table): + raise sa_exc.ArgumentError( + f"ORM mappings can only be made against schema-level " + f"Table objects, not TableClause; got " + f"tableclause {t.name !r}" + ) + self._all_tables.update(t for t in self.tables if isinstance(t, Table)) self._pks_by_table = {} self._cols_by_table = {} @@ -1335,16 +1466,16 @@ class Mapper( pk_cols = util.column_set(c for c in all_cols if c.primary_key) # identify primary key columns which are also mapped by this mapper. - tables = set(self.tables + [self.persist_selectable]) - self._all_tables.update(tables) - for t in tables: - if t.primary_key and pk_cols.issuperset(t.primary_key): + for fc in set(self.tables).union([self.persist_selectable]): + if fc.primary_key and pk_cols.issuperset(fc.primary_key): # ordering is important since it determines the ordering of # mapper.primary_key (and therefore query.get()) - self._pks_by_table[t] = util.ordered_column_set( - t.primary_key - ).intersection(pk_cols) - self._cols_by_table[t] = util.ordered_column_set(t.c).intersection( + self._pks_by_table[fc] = util.ordered_column_set( # type: ignore # noqa: E501 + fc.primary_key + ).intersection( + pk_cols + ) + self._cols_by_table[fc] = util.ordered_column_set(fc.c).intersection( # type: ignore # noqa: E501 all_cols ) @@ -1386,10 +1517,15 @@ class Mapper( self.primary_key = self.inherits.primary_key else: # determine primary key from argument or persist_selectable pks + primary_key: Collection[ColumnElement[Any]] + if self._primary_key_argument: primary_key = [ - self.persist_selectable.corresponding_column(c) - for c in self._primary_key_argument + cc if cc is not None else c + for cc, c in ( + (self.persist_selectable.corresponding_column(c), c) + for c in self._primary_key_argument + ) ] else: # if heuristically determined PKs, reduce to the minimal set @@ -1413,7 +1549,7 @@ class Mapper( # determine cols that aren't expressed within our tables; mark these # as "read only" properties which are refreshed upon INSERT/UPDATE - self._readonly_props = set( + self._readonly_props = { self._columntoproperty[col] for col in self._columntoproperty if self._columntoproperty[col] not in self._identity_key_props @@ -1421,12 +1557,12 @@ class Mapper( not hasattr(col, "table") or col.table not in self._cols_by_table ) - ) + } - def _configure_properties(self): + def _configure_properties(self) -> None: # TODO: consider using DedupeColumnCollection - self.columns = self.c = sql_base.ColumnCollection() + self.columns = self.c = sql_base.ColumnCollection() # type: ignore # object attribute names mapped to MapperProperty objects self._props = util.OrderedDict() @@ -1454,7 +1590,6 @@ class Mapper( continue column_key = (self.column_prefix or "") + column.key - if self._should_exclude( column.key, column_key, @@ -1542,6 +1677,7 @@ class Mapper( col = self.polymorphic_on if isinstance(col, schema.Column) and ( self.with_polymorphic is None + or self.with_polymorphic[1] is None or self.with_polymorphic[1].corresponding_column(col) is None ): @@ -1763,8 +1899,8 @@ class Mapper( self.columns.add(col, key) for col in prop.columns + prop._orig_columns: - for col in col.proxy_set: - self._columntoproperty[col] = prop + for proxy_col in col.proxy_set: + self._columntoproperty[proxy_col] = prop prop.key = key @@ -2033,7 +2169,9 @@ class Mapper( self._check_configure() return iter(self._props.values()) - def _mappers_from_spec(self, spec, selectable): + def _mappers_from_spec( + self, spec: Any, selectable: Optional[FromClause] + ) -> Sequence[Mapper[Any]]: """given a with_polymorphic() argument, return the set of mappers it represents. @@ -2044,7 +2182,7 @@ class Mapper( if spec == "*": mappers = list(self.self_and_descendants) elif spec: - mappers = set() + mapper_set = set() for m in util.to_list(spec): m = _class_to_mapper(m) if not m.isa(self): @@ -2053,10 +2191,10 @@ class Mapper( ) if selectable is None: - mappers.update(m.iterate_to_root()) + mapper_set.update(m.iterate_to_root()) else: - mappers.add(m) - mappers = [m for m in self.self_and_descendants if m in mappers] + mapper_set.add(m) + mappers = [m for m in self.self_and_descendants if m in mapper_set] else: mappers = [] @@ -2067,7 +2205,9 @@ class Mapper( mappers = [m for m in mappers if m.local_table in tables] return mappers - def _selectable_from_mappers(self, mappers, innerjoin): + def _selectable_from_mappers( + self, mappers: Iterable[Mapper[Any]], innerjoin: bool + ) -> FromClause: """given a list of mappers (assumed to be within this mapper's inheritance hierarchy), construct an outerjoin amongst those mapper's mapped tables. @@ -2098,13 +2238,13 @@ class Mapper( def _single_table_criterion(self): if self.single and self.inherits and self.polymorphic_on is not None: return self.polymorphic_on._annotate({"parentmapper": self}).in_( - m.polymorphic_identity for m in self.self_and_descendants + [m.polymorphic_identity for m in self.self_and_descendants] ) else: return None @HasMemoized.memoized_attribute - def _with_polymorphic_mappers(self): + def _with_polymorphic_mappers(self) -> Sequence[Mapper[Any]]: self._check_configure() if not self.with_polymorphic: @@ -2124,8 +2264,8 @@ class Mapper( """ self._check_configure() - @HasMemoized.memoized_attribute - def _with_polymorphic_selectable(self): + @HasMemoized_ro_memoized_attribute + def _with_polymorphic_selectable(self) -> FromClause: if not self.with_polymorphic: return self.persist_selectable @@ -2143,7 +2283,7 @@ class Mapper( """ - @HasMemoized.memoized_attribute + @HasMemoized_ro_memoized_attribute def _insert_cols_evaluating_none(self): return dict( ( @@ -2250,7 +2390,7 @@ class Mapper( @HasMemoized.memoized_instancemethod def __clause_element__(self): - annotations = { + annotations: Dict[str, Any] = { "entity_namespace": self, "parententity": self, "parentmapper": self, @@ -2290,7 +2430,7 @@ class Mapper( ) @property - def selectable(self): + def selectable(self) -> FromClause: """The :class:`_schema.FromClause` construct this :class:`_orm.Mapper` selects from by default. @@ -2302,8 +2442,11 @@ class Mapper( return self._with_polymorphic_selectable def _with_polymorphic_args( - self, spec=None, selectable=False, innerjoin=False - ): + self, + spec: Any = None, + selectable: Union[Literal[False, None], FromClause] = False, + innerjoin: bool = False, + ) -> Tuple[Sequence[Mapper[Any]], FromClause]: if selectable not in (None, False): selectable = coercions.expect( roles.StrictFromClauseRole, selectable, allow_select=True @@ -2357,7 +2500,7 @@ class Mapper( ] @HasMemoized.memoized_attribute - def _polymorphic_adapter(self): + def _polymorphic_adapter(self) -> Optional[sql_util.ColumnAdapter]: if self.with_polymorphic: return sql_util.ColumnAdapter( self.selectable, equivalents=self._equivalent_columns @@ -2394,7 +2537,7 @@ class Mapper( yield c @HasMemoized.memoized_attribute - def attrs(self) -> util.ReadOnlyProperties["MapperProperty"]: + def attrs(self) -> util.ReadOnlyProperties[MapperProperty[Any]]: """A namespace of all :class:`.MapperProperty` objects associated this mapper. @@ -2432,7 +2575,7 @@ class Mapper( return util.ReadOnlyProperties(self._props) @HasMemoized.memoized_attribute - def all_orm_descriptors(self): + def all_orm_descriptors(self) -> util.ReadOnlyProperties[InspectionAttr]: """A namespace of all :class:`.InspectionAttr` attributes associated with the mapped class. @@ -2503,7 +2646,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") - def synonyms(self): + def synonyms(self) -> util.ReadOnlyProperties[Synonym[Any]]: """Return a namespace of all :class:`.Synonym` properties maintained by this :class:`_orm.Mapper`. @@ -2523,7 +2666,7 @@ class Mapper( return self.class_ @HasMemoized.memoized_attribute - def column_attrs(self): + def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]: """Return a namespace of all :class:`.ColumnProperty` properties maintained by this :class:`_orm.Mapper`. @@ -2536,9 +2679,9 @@ class Mapper( """ return self._filter_properties(properties.ColumnProperty) - @util.preload_module("sqlalchemy.orm.relationships") @HasMemoized.memoized_attribute - def relationships(self): + @util.preload_module("sqlalchemy.orm.relationships") + def relationships(self) -> util.ReadOnlyProperties[Relationship[Any]]: """A namespace of all :class:`.Relationship` properties maintained by this :class:`_orm.Mapper`. @@ -2567,7 +2710,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") - def composites(self): + def composites(self) -> util.ReadOnlyProperties[Composite[Any]]: """Return a namespace of all :class:`.Composite` properties maintained by this :class:`_orm.Mapper`. @@ -2582,7 +2725,9 @@ class Mapper( util.preloaded.orm_descriptor_props.Composite ) - def _filter_properties(self, type_): + def _filter_properties( + self, type_: Type[_MP] + ) -> util.ReadOnlyProperties[_MP]: self._check_configure() return util.ReadOnlyProperties( util.OrderedDict( @@ -2610,7 +2755,7 @@ class Mapper( ) @HasMemoized.memoized_attribute - def _equivalent_columns(self): + def _equivalent_columns(self) -> _EquivalentColumnMap: """Create a map of all equivalent columns, based on the determination of column pairs that are equated to one another based on inherit condition. This is designed @@ -2630,18 +2775,18 @@ class Mapper( } """ - result = util.column_dict() + result: _EquivalentColumnMap = {} def visit_binary(binary): if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: - result[binary.left] = util.column_set((binary.right,)) + result[binary.left] = {binary.right} if binary.right in result: result[binary.right].add(binary.left) else: - result[binary.right] = util.column_set((binary.left,)) + result[binary.right] = {binary.left} for mapper in self.base_mapper.self_and_descendants: if mapper.inherit_condition is not None: @@ -2711,13 +2856,13 @@ class Mapper( return False - def common_parent(self, other): + def common_parent(self, other: Mapper[Any]) -> bool: """Return true if the given mapper shares a common inherited parent as this mapper.""" return self.base_mapper is other.base_mapper - def is_sibling(self, other): + def is_sibling(self, other: Mapper[Any]) -> bool: """return true if the other mapper is an inheriting sibling to this one. common parent but different branch @@ -2728,7 +2873,9 @@ class Mapper( and not other.isa(self) ) - def _canload(self, state, allow_subtypes): + def _canload( + self, state: InstanceState[Any], allow_subtypes: bool + ) -> bool: s = self.primary_mapper() if self.polymorphic_on is not None or allow_subtypes: return _state_mapper(state).isa(s) @@ -2738,19 +2885,19 @@ class Mapper( def isa(self, other: Mapper[Any]) -> bool: """Return True if the this mapper inherits from the given mapper.""" - m = self + m: Optional[Mapper[Any]] = self while m and m is not other: m = m.inherits return bool(m) - def iterate_to_root(self): - m = self + def iterate_to_root(self) -> Iterator[Mapper[Any]]: + m: Optional[Mapper[Any]] = self while m: yield m m = m.inherits @HasMemoized.memoized_attribute - def self_and_descendants(self): + def self_and_descendants(self) -> Sequence[Mapper[Any]]: """The collection including this mapper and all descendant mappers. This includes not just the immediately inheriting mappers but @@ -2765,7 +2912,7 @@ class Mapper( stack.extend(item._inheriting_mappers) return util.WeakSequence(descendants) - def polymorphic_iterator(self): + def polymorphic_iterator(self) -> Iterator[Mapper[Any]]: """Iterate through the collection including this mapper and all descendant mappers. @@ -2778,18 +2925,18 @@ class Mapper( """ return iter(self.self_and_descendants) - def primary_mapper(self): + def primary_mapper(self) -> Mapper[Any]: """Return the primary mapper corresponding to this mapper's class key (class).""" return self.class_manager.mapper @property - def primary_base_mapper(self): + def primary_base_mapper(self) -> Mapper[Any]: return self.class_manager.mapper.base_mapper def _result_has_identity_key(self, result, adapter=None): - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key if adapter: pk_cols = [adapter.columns[c] for c in pk_cols] rk = result.keys() @@ -2799,25 +2946,35 @@ class Mapper( else: return True - def identity_key_from_row(self, row, identity_token=None, adapter=None): + def identity_key_from_row( + self, + row: Optional[Union[Row, RowMapping]], + identity_token: Optional[Any] = None, + adapter: Optional[ColumnAdapter] = None, + ) -> _IdentityKeyType[_O]: """Return an identity-map key for use in storing/retrieving an item from the identity map. - :param row: A :class:`.Row` instance. The columns which are - mapped by this :class:`_orm.Mapper` should be locatable in the row, - preferably via the :class:`_schema.Column` - object directly (as is the case - when a :func:`_expression.select` construct is executed), or - via string names of the form ``<tablename>_<colname>``. + :param row: A :class:`.Row` or :class:`.RowMapping` produced from a + result set that selected from the ORM mapped primary key columns. + + .. versionchanged:: 2.0 + :class:`.Row` or :class:`.RowMapping` are accepted + for the "row" argument """ - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key if adapter: pk_cols = [adapter.columns[c] for c in pk_cols] + if hasattr(row, "_mapping"): + mapping = row._mapping # type: ignore + else: + mapping = cast("Mapping[Any, Any]", row) + return ( self._identity_class, - tuple(row[column] for column in pk_cols), + tuple(mapping[column] for column in pk_cols), # type: ignore identity_token, ) @@ -2852,12 +3009,12 @@ class Mapper( """ state = attributes.instance_state(instance) - return self._identity_key_from_state(state, attributes.PASSIVE_OFF) + return self._identity_key_from_state(state, PassiveFlag.PASSIVE_OFF) def _identity_key_from_state( self, state: InstanceState[_O], - passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE, + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, ) -> _IdentityKeyType[_O]: dict_ = state.dict manager = state.manager @@ -2884,7 +3041,7 @@ class Mapper( """ state = attributes.instance_state(instance) identity_key = self._identity_key_from_state( - state, attributes.PASSIVE_OFF + state, PassiveFlag.PASSIVE_OFF ) return identity_key[1] @@ -2913,14 +3070,14 @@ class Mapper( @HasMemoized.memoized_attribute def _all_pk_cols(self): - collection = set() + collection: Set[ColumnClause[Any]] = set() for table in self.tables: collection.update(self._pks_by_table[table]) return collection @HasMemoized.memoized_attribute def _should_undefer_in_wildcard(self): - cols = set(self.primary_key) + cols: Set[ColumnElement[Any]] = set(self.primary_key) if self.polymorphic_on is not None: cols.add(self.polymorphic_on) return cols @@ -2951,11 +3108,11 @@ class Mapper( state = attributes.instance_state(obj) dict_ = attributes.instance_dict(obj) return self._get_committed_state_attr_by_column( - state, dict_, column, passive=attributes.PASSIVE_OFF + state, dict_, column, passive=PassiveFlag.PASSIVE_OFF ) def _get_committed_state_attr_by_column( - self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE + self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE ): prop = self._columntoproperty[column] @@ -2978,7 +3135,7 @@ class Mapper( col_attribute_names = set(attribute_names).intersection( state.mapper.column_attrs.keys() ) - tables = set( + tables: Set[FromClause] = set( chain( *[ sql_util.find_tables(c, check_columns=True) @@ -3002,7 +3159,7 @@ class Mapper( state, state.dict, leftcol, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if leftval in orm_util._none_set: raise _OptGetColumnsNotAvailable() @@ -3014,7 +3171,7 @@ class Mapper( state, state.dict, rightcol, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if rightval in orm_util._none_set: raise _OptGetColumnsNotAvailable() @@ -3022,7 +3179,7 @@ class Mapper( None, rightval, type_=binary.right.type ) - allconds = [] + allconds: List[ColumnElement[bool]] = [] start = False @@ -3035,6 +3192,9 @@ class Mapper( elif not isinstance(mapper.local_table, expression.TableClause): return None if start and not mapper.single: + assert mapper.inherits + assert not mapper.concrete + assert mapper.inherit_condition is not None allconds.append(mapper.inherit_condition) tables.add(mapper.local_table) @@ -3043,11 +3203,13 @@ class Mapper( # descendant-most class should all be present and joined to each # other. try: - allconds[0] = visitors.cloned_traverse( + _traversed = visitors.cloned_traverse( allconds[0], {}, {"binary": visit_binary} ) except _OptGetColumnsNotAvailable: return None + else: + allconds[0] = _traversed cond = sql.and_(*allconds) @@ -3145,6 +3307,8 @@ class Mapper( for pk in self.primary_key ] + in_expr: ColumnElement[Any] + if len(primary_key) > 1: in_expr = sql.tuple_(*primary_key) else: @@ -3209,11 +3373,22 @@ class Mapper( traverse all objects without relying on cascades. """ - visited_states = set() + visited_states: Set[InstanceState[Any]] = set() prp, mpp = object(), object() assert state.mapper.isa(self) + # this is actually a recursive structure, fully typing it seems + # a little too difficult for what it's worth here + visitables: Deque[ + Tuple[ + Deque[Any], + object, + Optional[InstanceState[Any]], + Optional[_InstanceDict], + ] + ] + visitables = deque( [(deque(state.mapper._props.values()), prp, state, state.dict)] ) @@ -3226,8 +3401,10 @@ class Mapper( if item_type is prp: prop = iterator.popleft() - if type_ not in prop.cascade: + if not prop.cascade or type_ not in prop.cascade: continue + assert parent_state is not None + assert parent_dict is not None queue = deque( prop.cascade_iterator( type_, @@ -3267,7 +3444,7 @@ class Mapper( @HasMemoized.memoized_attribute def _sorted_tables(self): - table_to_mapper = {} + table_to_mapper: Dict[Table, Mapper[Any]] = {} for mapper in self.base_mapper.self_and_descendants: for t in mapper.tables: @@ -3316,9 +3493,9 @@ class Mapper( ret[t] = table_to_mapper[t] return ret - def _memo(self, key, callable_): + def _memo(self, key: Any, callable_: Callable[[], _T]) -> _T: if key in self._memoized_values: - return self._memoized_values[key] + return cast(_T, self._memoized_values[key]) else: self._memoized_values[key] = value = callable_() return value @@ -3328,14 +3505,22 @@ class Mapper( """memoized map of tables to collections of columns to be synchronized upwards to the base mapper.""" - result = util.defaultdict(list) + result: util.defaultdict[ + Table, + List[ + Tuple[ + Mapper[Any], + List[Tuple[ColumnElement[Any], ColumnElement[Any]]], + ] + ], + ] = util.defaultdict(list) for table in self._sorted_tables: cols = set(table.c) for m in self.iterate_to_root(): if m._inherits_equated_pairs and cols.intersection( reduce( - set.union, + set.union, # type: ignore [l.proxy_set for l, r in m._inherits_equated_pairs], ) ): @@ -3440,7 +3625,7 @@ def _configure_registries(registries, cascade): else: return - Mapper.dispatch._for_class(Mapper).before_configured() + Mapper.dispatch._for_class(Mapper).before_configured() # type: ignore # noqa: E501 # initialize properties on all mappers # note that _mapper_registry is unordered, which # may randomly conceal/reveal issues related to @@ -3449,7 +3634,7 @@ def _configure_registries(registries, cascade): _do_configure_registries(registries, cascade) finally: _already_compiling = False - Mapper.dispatch._for_class(Mapper).after_configured() + Mapper.dispatch._for_class(Mapper).after_configured() # type: ignore @util.preload_module("sqlalchemy.orm.decl_api") @@ -3480,7 +3665,7 @@ def _do_configure_registries(registries, cascade): "Original exception was: %s" % (mapper, mapper._configure_failed) ) - e._configure_failed = mapper._configure_failed + e._configure_failed = mapper._configure_failed # type: ignore raise e if not mapper.configured: @@ -3636,7 +3821,7 @@ def _event_on_init(state, args, kwargs): instrumenting_mapper._set_polymorphic_identity(state) -class _ColumnMapping(dict): +class _ColumnMapping(Dict["ColumnElement[Any]", "MapperProperty[Any]"]): """Error reporting helper for mapper._columntoproperty.""" __slots__ = ("mapper",) diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index e2cf1d5b0..361cea975 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -13,22 +13,70 @@ from __future__ import annotations from functools import reduce from itertools import chain import logging +import operator from typing import Any +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple +from typing import TYPE_CHECKING from typing import Union from . import base as orm_base +from ._typing import insp_is_mapper_property from .. import exc -from .. import inspection from .. import util from ..sql import visitors from ..sql.cache_key import HasCacheKey +if TYPE_CHECKING: + from ._typing import _InternalEntityType + from .interfaces import MapperProperty + from .mapper import Mapper + from .relationships import Relationship + from .util import AliasedInsp + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.elements import BindParameter + from ..sql.visitors import anon_map + from ..util.typing import TypeGuard + + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: + ... + + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: + ... + +else: + is_root = operator.attrgetter("is_root") + is_entity = operator.attrgetter("is_entity") + + +_SerializedPath = List[Any] + +_PathElementType = Union[ + str, "_InternalEntityType[Any]", "MapperProperty[Any]" +] + +# the representation is in fact +# a tuple with alternating: +# [_InternalEntityType[Any], Union[str, MapperProperty[Any]], +# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...] +# this might someday be a tuple of 2-tuples instead, but paths can be +# chopped at odd intervals as well so this is less flexible +_PathRepresentation = Tuple[_PathElementType, ...] + +_OddPathRepresentation = Sequence["_InternalEntityType[Any]"] +_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]] + + log = logging.getLogger(__name__) -def _unreduce_path(path): +def _unreduce_path(path: _SerializedPath) -> PathRegistry: return PathRegistry.deserialize(path) @@ -67,17 +115,18 @@ class PathRegistry(HasCacheKey): is_token = False is_root = False has_entity = False + is_entity = False - path: Tuple - natural_path: Tuple - parent: Union["PathRegistry", None] + path: _PathRepresentation + natural_path: _PathRepresentation + parent: Optional[PathRegistry] + root: RootRegistry - root: "PathRegistry" - _cache_key_traversal = [ + _cache_key_traversal: _CacheKeyTraversalType = [ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) ] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: try: return other is not None and self.path == other._path_for_compare except AttributeError: @@ -87,7 +136,7 @@ class PathRegistry(HasCacheKey): ) return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: try: return other is None or self.path != other._path_for_compare except AttributeError: @@ -98,74 +147,88 @@ class PathRegistry(HasCacheKey): return True @property - def _path_for_compare(self): + def _path_for_compare(self) -> Optional[_PathRepresentation]: return self.path - def set(self, attributes, key, value): + def set(self, attributes: Dict[Any, Any], key: Any, value: Any) -> None: log.debug("set '%s' on path '%s' to '%s'", key, self, value) attributes[(key, self.natural_path)] = value - def setdefault(self, attributes, key, value): + def setdefault( + self, attributes: Dict[Any, Any], key: Any, value: Any + ) -> None: log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value) attributes.setdefault((key, self.natural_path), value) - def get(self, attributes, key, value=None): + def get( + self, attributes: Dict[Any, Any], key: Any, value: Optional[Any] = None + ) -> Any: key = (key, self.natural_path) if key in attributes: return attributes[key] else: return value - def __len__(self): + def __len__(self) -> int: return len(self.path) - def __hash__(self): + def __hash__(self) -> int: return id(self) - def __getitem__(self, key: Any) -> "PathRegistry": + def __getitem__(self, key: Any) -> PathRegistry: raise NotImplementedError() + # TODO: what are we using this for? @property - def length(self): + def length(self) -> int: return len(self.path) - def pairs(self): - path = self.path - for i in range(0, len(path), 2): - yield path[i], path[i + 1] - - def contains_mapper(self, mapper): - for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]: + def pairs( + self, + ) -> Iterator[ + Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]] + ]: + odd_path = cast(_OddPathRepresentation, self.path) + even_path = cast(_EvenPathRepresentation, odd_path) + for i in range(0, len(odd_path), 2): + yield odd_path[i], even_path[i + 1] + + def contains_mapper(self, mapper: Mapper[Any]) -> bool: + _m_path = cast(_OddPathRepresentation, self.path) + for path_mapper in [_m_path[i] for i in range(0, len(_m_path), 2)]: if path_mapper.is_mapper and path_mapper.isa(mapper): return True else: return False - def contains(self, attributes, key): + def contains(self, attributes: Dict[Any, Any], key: Any) -> bool: return (key, self.path) in attributes - def __reduce__(self): + def __reduce__(self) -> Any: return _unreduce_path, (self.serialize(),) @classmethod - def _serialize_path(cls, path): + def _serialize_path(cls, path: _PathRepresentation) -> _SerializedPath: + _m_path = cast(_OddPathRepresentation, path) + _p_path = cast(_EvenPathRepresentation, path) + return list( zip( - [ + tuple( m.class_ if (m.is_mapper or m.is_aliased_class) else str(m) - for m in [path[i] for i in range(0, len(path), 2)] - ], - [ - path[i].key if (path[i].is_property) else str(path[i]) - for i in range(1, len(path), 2) - ] - + [None], + for m in [_m_path[i] for i in range(0, len(_m_path), 2)] + ), + tuple( + p.key if insp_is_mapper_property(p) else str(p) + for p in [_p_path[i] for i in range(1, len(_p_path), 2)] + ) + + (None,), ) ) @classmethod - def _deserialize_path(cls, path): - def _deserialize_mapper_token(mcls): + def _deserialize_path(cls, path: _SerializedPath) -> _PathRepresentation: + def _deserialize_mapper_token(mcls: Any) -> Any: return ( # note: we likely dont want configure=True here however # this is maintained at the moment for backwards compatibility @@ -174,15 +237,15 @@ class PathRegistry(HasCacheKey): else PathToken._intern[mcls] ) - def _deserialize_key_token(mcls, key): + def _deserialize_key_token(mcls: Any, key: Any) -> Any: if key is None: return None elif key in PathToken._intern: return PathToken._intern[key] else: - return orm_base._inspect_mapped_class( - mcls, configure=True - ).attrs[key] + mp = orm_base._inspect_mapped_class(mcls, configure=True) + assert mp is not None + return mp.attrs[key] p = tuple( chain( @@ -199,28 +262,63 @@ class PathRegistry(HasCacheKey): p = p[0:-1] return p - def serialize(self) -> Sequence[Any]: + def serialize(self) -> _SerializedPath: path = self.path return self._serialize_path(path) @classmethod - def deserialize(cls, path: Sequence[Any]) -> PathRegistry: + def deserialize(cls, path: _SerializedPath) -> PathRegistry: assert path is not None p = cls._deserialize_path(path) return cls.coerce(p) + @overload @classmethod - def per_mapper(cls, mapper): + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: + ... + + @overload + @classmethod + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: + ... + + @classmethod + def per_mapper( + cls, mapper: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: if mapper.is_mapper: return CachingEntityRegistry(cls.root, mapper) else: return SlotsEntityRegistry(cls.root, mapper) @classmethod - def coerce(cls, raw): - return reduce(lambda prev, next: prev[next], raw, cls.root) + def coerce(cls, raw: _PathRepresentation) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] + + # can't quite get mypy to appreciate this one :) + return reduce(_red, raw, cls.root) # type: ignore + + def __add__(self, other: PathRegistry) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] - def token(self, token): + return reduce(_red, other.path, self) + + def __str__(self) -> str: + return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.path!r})" + + +class CreatesToken(PathRegistry): + __slots__ = () + + is_aliased_class: bool + is_root: bool + + def token(self, token: str) -> TokenRegistry: if token.endswith(f":{_WILDCARD_TOKEN}"): return TokenRegistry(self, token) elif token.endswith(f":{_DEFAULT_TOKEN}"): @@ -228,34 +326,47 @@ class PathRegistry(HasCacheKey): else: raise exc.ArgumentError(f"invalid token: {token}") - def __add__(self, other): - return reduce(lambda prev, next: prev[next], other.path, self) - - def __str__(self): - return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" - - def __repr__(self): - return f"{self.__class__.__name__}({self.path!r})" - -class RootRegistry(PathRegistry): +class RootRegistry(CreatesToken): """Root registry, defers to mappers so that paths are maintained per-root-mapper. """ + __slots__ = () + inherit_cache = True path = natural_path = () has_entity = False is_aliased_class = False is_root = True + is_unnatural = False + + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... - def __getitem__(self, entity): + def __getitem__( + self, entity: Union[str, _InternalEntityType[Any]] + ) -> Union[TokenRegistry, AbstractEntityRegistry]: if entity in PathToken._intern: + if TYPE_CHECKING: + assert isinstance(entity, str) return TokenRegistry(self, PathToken._intern[entity]) else: - return inspection.inspect(entity)._path_registry + try: + return entity._path_registry # type: ignore + except AttributeError: + raise IndexError( + f"invalid argument for RootRegistry.__getitem__: {entity}" + ) PathRegistry.root = RootRegistry() @@ -264,17 +375,19 @@ PathRegistry.root = RootRegistry() class PathToken(orm_base.InspectionAttr, HasCacheKey, str): """cacheable string token""" - _intern = {} + _intern: Dict[str, PathToken] = {} - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: return (str(self),) @property - def _path_for_compare(self): + def _path_for_compare(self) -> Optional[_PathRepresentation]: return None @classmethod - def intern(cls, strvalue): + def intern(cls, strvalue: str) -> PathToken: if strvalue in cls._intern: return cls._intern[strvalue] else: @@ -287,7 +400,10 @@ class TokenRegistry(PathRegistry): inherit_cache = True - def __init__(self, parent, token): + token: str + parent: CreatesToken + + def __init__(self, parent: CreatesToken, token: str): token = PathToken.intern(token) self.token = token @@ -299,21 +415,33 @@ class TokenRegistry(PathRegistry): is_token = True - def generate_for_superclasses(self): - if not self.parent.is_aliased_class and not self.parent.is_root: - for ent in self.parent.mapper.iterate_to_root(): - yield TokenRegistry(self.parent.parent[ent], self.token) + def generate_for_superclasses(self) -> Iterator[PathRegistry]: + parent = self.parent + if is_root(parent): + yield self + return + + if TYPE_CHECKING: + assert isinstance(parent, AbstractEntityRegistry) + if not parent.is_aliased_class: + for mp_ent in parent.mapper.iterate_to_root(): + yield TokenRegistry(parent.parent[mp_ent], self.token) elif ( - self.parent.is_aliased_class - and self.parent.entity._is_with_polymorphic + parent.is_aliased_class + and cast( + "AliasedInsp[Any]", + parent.entity, + )._is_with_polymorphic ): yield self - for ent in self.parent.entity._with_polymorphic_entities: - yield TokenRegistry(self.parent.parent[ent], self.token) + for ent in cast( + "AliasedInsp[Any]", parent.entity + )._with_polymorphic_entities: + yield TokenRegistry(parent.parent[ent], self.token) else: yield self - def __getitem__(self, entity): + def __getitem__(self, entity: Any) -> Any: try: return self.path[entity] except TypeError as err: @@ -321,23 +449,42 @@ class TokenRegistry(PathRegistry): class PropRegistry(PathRegistry): - is_unnatural = False + __slots__ = ( + "prop", + "parent", + "path", + "natural_path", + "has_entity", + "entity", + "mapper", + "_wildcard_path_loader_key", + "_default_path_loader_key", + "_loader_key", + "is_unnatural", + ) inherit_cache = True - def __init__(self, parent, prop): + prop: MapperProperty[Any] + mapper: Optional[Mapper[Any]] + entity: Optional[_InternalEntityType[Any]] + + def __init__( + self, parent: AbstractEntityRegistry, prop: MapperProperty[Any] + ): # restate this path in terms of the # given MapperProperty's parent. - insp = inspection.inspect(parent[-1]) - natural_parent = parent + insp = cast("_InternalEntityType[Any]", parent[-1]) + natural_parent: AbstractEntityRegistry = parent + self.is_unnatural = False - if not insp.is_aliased_class or insp._use_mapper_path: + if not insp.is_aliased_class or insp._use_mapper_path: # type: ignore parent = natural_parent = parent.parent[prop.parent] elif ( insp.is_aliased_class and insp.with_polymorphic_mappers and prop.parent in insp.with_polymorphic_mappers ): - subclass_entity = parent[-1]._entity_for_mapper(prop.parent) + subclass_entity: _InternalEntityType[Any] = parent[-1]._entity_for_mapper(prop.parent) # type: ignore # noqa: E501 parent = parent.parent[subclass_entity] # when building a path where with_polymorphic() is in use, @@ -388,43 +535,74 @@ class PropRegistry(PathRegistry): self.parent = parent self.path = parent.path + (prop,) self.natural_path = natural_parent.natural_path + (prop,) + self.has_entity = prop._links_to_entity + if prop._is_relationship: + if TYPE_CHECKING: + assert isinstance(prop, Relationship) + self.entity = prop.entity + self.mapper = prop.mapper + else: + self.entity = None + self.mapper = None self._wildcard_path_loader_key = ( "loader", - parent.path + self.prop._wildcard_token, + parent.path + self.prop._wildcard_token, # type: ignore ) self._default_path_loader_key = self.prop._default_path_loader_key self._loader_key = ("loader", self.natural_path) - @util.memoized_property - def has_entity(self): - return self.prop._links_to_entity + @property + def entity_path(self) -> AbstractEntityRegistry: + assert self.entity is not None + return self[self.entity] - @util.memoized_property - def entity(self): - return self.prop.entity + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... - @property - def mapper(self): - return self.prop.mapper + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... - @property - def entity_path(self): - return self[self.entity] + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... - def __getitem__(self, entity): + def __getitem__( + self, entity: Union[int, slice, _InternalEntityType[Any]] + ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]: if isinstance(entity, (int, slice)): return self.path[entity] else: return SlotsEntityRegistry(self, entity) -class AbstractEntityRegistry(PathRegistry): - __slots__ = () +class AbstractEntityRegistry(CreatesToken): + __slots__ = ( + "key", + "parent", + "is_aliased_class", + "path", + "entity", + "natural_path", + ) has_entity = True - - def __init__(self, parent, entity): + is_entity = True + + parent: Union[RootRegistry, PropRegistry] + key: _InternalEntityType[Any] + entity: _InternalEntityType[Any] + is_aliased_class: bool + + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): self.key = entity self.parent = parent self.is_aliased_class = entity.is_aliased_class @@ -447,11 +625,11 @@ class AbstractEntityRegistry(PathRegistry): if parent.path and (self.is_aliased_class or parent.is_unnatural): # this is an infrequent code path used only for loader strategies # that also make use of of_type(). - if entity.mapper.isa(parent.natural_path[-1].entity): + if entity.mapper.isa(parent.natural_path[-1].entity): # type: ignore # noqa: E501 self.natural_path = parent.natural_path + (entity.mapper,) else: self.natural_path = parent.natural_path + ( - parent.natural_path[-1].entity, + parent.natural_path[-1].entity, # type: ignore ) # it seems to make sense that since these paths get mixed up # with statements that are cached or not, we should make @@ -465,19 +643,35 @@ class AbstractEntityRegistry(PathRegistry): self.natural_path = self.path @property - def entity_path(self): + def entity_path(self) -> PathRegistry: return self @property - def mapper(self): - return inspection.inspect(self.entity).mapper + def mapper(self) -> Mapper[Any]: + return self.entity.mapper - def __bool__(self): + def __bool__(self) -> bool: return True - __nonzero__ = __bool__ + @overload + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: + ... + + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... - def __getitem__(self, entity): + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... + + def __getitem__( + self, entity: Any + ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]: if isinstance(entity, (int, slice)): return self.path[entity] elif entity in PathToken._intern: @@ -491,31 +685,40 @@ class SlotsEntityRegistry(AbstractEntityRegistry): # version inherit_cache = True - __slots__ = ( - "key", - "parent", - "is_aliased_class", - "entity", - "path", - "natural_path", - ) + +class _ERDict(Dict[Any, Any]): + def __init__(self, registry: CachingEntityRegistry): + self.registry = registry + + def __missing__(self, key: Any) -> PropRegistry: + self[key] = item = PropRegistry(self.registry, key) + + return item -class CachingEntityRegistry(AbstractEntityRegistry, dict): +class CachingEntityRegistry(AbstractEntityRegistry): # for long lived mapper, return dict based caching # version that creates reference cycles + __slots__ = ("_cache",) + inherit_cache = True - def __getitem__(self, entity): + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): + super().__init__(parent, entity) + self._cache = _ERDict(self) + + def pop(self, key: Any, default: Any) -> Any: + return self._cache.pop(key, default) + + def __getitem__(self, entity: Any) -> Any: if isinstance(entity, (int, slice)): return self.path[entity] elif isinstance(entity, PathToken): return TokenRegistry(self, entity) else: - return dict.__getitem__(self, entity) - - def __missing__(self, key): - self[key] = item = PropRegistry(self, key) - - return item + return self._cache[entity] diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c01825b6d..9f37e8457 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -19,6 +19,8 @@ from typing import cast from typing import List from typing import Optional from typing import Set +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from . import attributes @@ -38,17 +40,22 @@ from .util import _orm_full_deannotate from .. import exc as sa_exc from .. import ForeignKey from .. import log -from .. import sql from .. import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes from ..sql.schema import Column +from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import NoneType +if TYPE_CHECKING: + from ._typing import _ORMColumnExprArgument + from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -78,6 +85,10 @@ class ColumnProperty( inherit_cache = True _links_to_entity = False + columns: List[ColumnElement[Any]] + + _is_polymorphic_discriminator: bool + __slots__ = ( "_orig_columns", "columns", @@ -99,7 +110,19 @@ class ColumnProperty( ) def __init__( - self, column: sql.ColumnElement[_T], *additional_columns, **kwargs + self, + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator]] = None, + descriptor: Optional[Any] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + _instrument: bool = True, ): super(ColumnProperty, self).__init__() columns = (column,) + additional_columns @@ -112,23 +135,24 @@ class ColumnProperty( ) for c in columns ] - self.parent = self.key = None - self.group = kwargs.pop("group", None) - self.deferred = kwargs.pop("deferred", False) - self.raiseload = kwargs.pop("raiseload", False) - self.instrument = kwargs.pop("_instrument", True) - self.comparator_factory = kwargs.pop( - "comparator_factory", self.__class__.Comparator + self.group = group + self.deferred = deferred + self.raiseload = raiseload + self.instrument = _instrument + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator ) - self.descriptor = kwargs.pop("descriptor", None) - self.active_history = kwargs.pop("active_history", False) - self.expire_on_flush = kwargs.pop("expire_on_flush", True) + self.descriptor = descriptor + self.active_history = active_history + self.expire_on_flush = expire_on_flush - if "info" in kwargs: - self.info = kwargs.pop("info") + if info is not None: + self.info = info - if "doc" in kwargs: - self.doc = kwargs.pop("doc") + if doc is not None: + self.doc = doc else: for col in reversed(self.columns): doc = getattr(col, "doc", None) @@ -138,12 +162,6 @@ class ColumnProperty( else: self.doc = None - if kwargs: - raise TypeError( - "%s received unexpected keyword argument(s): %s" - % (self.__class__.__name__, ", ".join(sorted(kwargs.keys()))) - ) - util.set_creation_order(self) self.strategy_key = ( @@ -445,7 +463,10 @@ class MappedColumn( self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys - self._has_nullable = "nullable" in kw + self._has_nullable = "nullable" in kw and kw.get("nullable") not in ( + None, + SchemaConst.NULL_UNSPECIFIED, + ) util.set_creation_order(self) def _copy(self, **kw): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a754bd4f2..395d01a1e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -30,6 +30,7 @@ from typing import Optional from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import exc as orm_exc from . import interfaces @@ -77,6 +78,8 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL if TYPE_CHECKING: from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import Alias + from ..sql.selectable import Subquery __all__ = ["Query", "QueryContext"] @@ -2769,14 +2772,14 @@ class AliasOption(interfaces.LoaderOption): "for entities to be matched up to a query that is established " "via :meth:`.Query.from_statement` and now does nothing.", ) - def __init__(self, alias): + def __init__(self, alias: Union[Alias, Subquery]): r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. """ - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState): pass diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 58c7c4efd..66021c9c2 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -21,7 +21,10 @@ import re import typing from typing import Any from typing import Callable +from typing import Dict from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Type from typing import TypeVar from typing import Union @@ -30,6 +33,7 @@ import weakref from . import attributes from . import strategy_options from .base import _is_mapped_class +from .base import class_mapper from .base import state_str from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY @@ -53,7 +57,9 @@ from ..sql import expression from ..sql import operators from ..sql import roles from ..sql import visitors -from ..sql.elements import SQLCoreOperations +from ..sql._typing import _ColumnExpressionArgument +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnClause from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -61,11 +67,14 @@ from ..sql.util import ClauseAdapter from ..sql.util import join_condition from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +from ..util.typing import Literal if typing.TYPE_CHECKING: + from ._typing import _EntityType from .mapper import Mapper from .util import AliasedClass from .util import AliasedInsp + from ..sql.elements import ColumnElement _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -81,6 +90,34 @@ _RelationshipArgumentType = Union[ Callable[[], "AliasedClass[_T]"], ] +_LazyLoadArgumentType = Literal[ + "select", + "joined", + "selectin", + "subquery", + "raise", + "raise_on_sql", + "noload", + "immediate", + "dynamic", + True, + False, + None, +] + + +_RelationshipJoinConditionArgument = Union[ + str, _ColumnExpressionArgument[bool] +] +_ORMOrderByArgument = Union[ + Literal[False], str, _ColumnExpressionArgument[Any] +] +_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] +_ORMColCollectionArgument = Union[ + str, + Sequence[Union[ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole]], +] + def remote(expr): """Annotate a portion of a primaryjoin expression @@ -144,6 +181,7 @@ class Relationship( inherit_cache = True _links_to_entity = True + _is_relationship = True _persistence_only = dict( passive_deletes=False, @@ -159,38 +197,39 @@ class Relationship( self, argument: Optional[_RelationshipArgumentType[_T]] = None, secondary=None, + *, + uselist=None, + collection_class=None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - uselist=None, + back_populates=None, order_by=False, backref=None, - back_populates=None, + cascade_backrefs=False, overlaps=None, post_update=False, - cascade=False, + cascade="save-update, merge", viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=_persistence_only["passive_deletes"], - passive_updates=_persistence_only["passive_updates"], + lazy: _LazyLoadArgumentType = "select", + passive_deletes=False, + passive_updates=True, + active_history=False, + enable_typechecks=True, + foreign_keys=None, remote_side=None, - enable_typechecks=_persistence_only["enable_typechecks"], join_depth=None, comparator_factory=None, single_parent=False, innerjoin=False, distinct_target_key=None, - doc=None, - active_history=_persistence_only["active_history"], - cascade_backrefs=_persistence_only["cascade_backrefs"], load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, query_class=None, info=None, omit_join=None, sync_backref=None, + doc=None, + bake_queries=True, + _local_remote_pairs=None, _legacy_inactive_history_style=False, ): super(Relationship, self).__init__() @@ -250,7 +289,6 @@ class Relationship( self.omit_join = omit_join self.local_remote_pairs = _local_remote_pairs - self.bake_queries = bake_queries self.load_on_pending = load_on_pending self.comparator_factory = comparator_factory or Relationship.Comparator self.comparator = self.comparator_factory(self, None) @@ -267,12 +305,7 @@ class Relationship( else: self._overlaps = () - if cascade is not False: - self.cascade = cascade - elif self.viewonly: - self.cascade = "none" - else: - self.cascade = "save-update, merge" + self.cascade = cascade self.order_by = order_by @@ -539,9 +572,9 @@ class Relationship( def _criterion_exists( self, - criterion: Optional[SQLCoreOperations[Any]] = None, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> Exists[bool]: + ) -> Exists: if getattr(self, "_of_type", None): info = inspect(self._of_type) target_mapper, to_selectable, is_aliased_class = ( @@ -898,7 +931,12 @@ class Relationship( comparator: Comparator[_T] - def _with_parent(self, instance, alias_secondary=True, from_entity=None): + def _with_parent( + self, + instance: object, + alias_secondary: bool = True, + from_entity: Optional[_EntityType[Any]] = None, + ) -> ColumnElement[bool]: assert instance is not None adapt_source = None if from_entity is not None: @@ -1502,7 +1540,7 @@ class Relationship( argument = argument if isinstance(argument, type): - entity = mapperlib.class_mapper(argument, configure=False) + entity = class_mapper(argument, configure=False) else: try: entity = inspect(argument) @@ -1568,7 +1606,7 @@ class Relationship( """Test that this relationship is legal, warn about inheritance conflicts.""" mapperlib = util.preloaded.orm_mapper - if self.parent.non_primary and not mapperlib.class_mapper( + if self.parent.non_primary and not class_mapper( self.parent.class_, configure=False ).has_property(self.key): raise sa_exc.ArgumentError( @@ -1585,29 +1623,23 @@ class Relationship( ) @property - def cascade(self): + def cascade(self) -> CascadeOptions: """Return the current cascade setting for this :class:`.Relationship`. """ return self._cascade @cascade.setter - def cascade(self, cascade): + def cascade(self, cascade: Union[str, CascadeOptions]): self._set_cascade(cascade) - def _set_cascade(self, cascade): - cascade = CascadeOptions(cascade) + def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]): + cascade = CascadeOptions(cascade_arg) if self.viewonly: - non_viewonly = set(cascade).difference( - CascadeOptions._viewonly_cascades + cascade = CascadeOptions( + cascade.intersection(CascadeOptions._viewonly_cascades) ) - if non_viewonly: - raise sa_exc.ArgumentError( - 'Cascade settings "%s" apply to persistence operations ' - "and should not be combined with a viewonly=True " - "relationship." % (", ".join(sorted(non_viewonly))) - ) if "mapper" in self.__dict__: self._check_cascade_settings(cascade) @@ -1754,8 +1786,8 @@ class Relationship( relationship = Relationship( parent, self.secondary, - pj, - sj, + primaryjoin=pj, + secondaryjoin=sj, foreign_keys=foreign_keys, back_populates=self.key, **kwargs, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5b1d0bb08..74035ec0a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -39,6 +39,7 @@ from . import persistence from . import query from . import state as statelib from ._typing import _O +from ._typing import insp_is_mapper from ._typing import is_composite_class from ._typing import is_user_defined_option from .base import _class_to_mapper @@ -69,12 +70,14 @@ from ..engine.util import TransactionalContext from ..event import dispatcher from ..event import EventTarget from ..inspection import inspect +from ..inspection import Inspectable from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import Select from ..sql import visitors from ..sql.base import CompileState +from ..sql.schema import Table from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import IdentitySet @@ -90,6 +93,7 @@ if typing.TYPE_CHECKING: from .path_registry import PathRegistry from ..engine import Result from ..engine import Row + from ..engine import RowMapping from ..engine.base import Transaction from ..engine.base import TwoPhaseTransaction from ..engine.interfaces import _CoreAnyExecuteParams @@ -103,6 +107,7 @@ if typing.TYPE_CHECKING: from ..sql.base import Executable from ..sql.elements import ClauseElement from ..sql.schema import Table + from ..sql.selectable import TableClause __all__ = [ "Session", @@ -184,7 +189,7 @@ class _SessionClassMethods: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row, RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: """Return an identity key. @@ -2050,9 +2055,12 @@ class Session(_SessionClassMethods, EventTarget): else: self.__binds[key] = bind else: - if insp.is_selectable: + if TYPE_CHECKING: + assert isinstance(insp, Inspectable) + + if isinstance(insp, Table): self.__binds[insp] = bind - elif insp.is_mapper: + elif insp_is_mapper(insp): self.__binds[insp.class_] = bind for _selectable in insp._all_tables: self.__binds[_selectable] = bind @@ -2211,7 +2219,7 @@ class Session(_SessionClassMethods, EventTarget): # we don't have self.bind and either have self.__binds # or we don't have self.__binds (which is legacy). Look at the # mapper and the clause - if mapper is clause is None: + if mapper is None and clause is None: if self.bind: return self.bind else: @@ -2350,7 +2358,10 @@ class Session(_SessionClassMethods, EventTarget): key = mapper.identity_key_from_primary_key( primary_key_identity, identity_token=identity_token ) - return loading.get_from_identity(self, mapper, key, passive) + + # work around: https://github.com/python/typing/discussions/1143 + return_value = loading.get_from_identity(self, mapper, key, passive) + return return_value @util.non_memoized_property @contextlib.contextmanager diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 85e015193..2d85ba7f6 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -2162,10 +2162,11 @@ class JoinedLoader(AbstractRelationshipLoader): else: to_adapt = self._gen_pooled_aliased_class(compile_state) - clauses = inspect(to_adapt)._memo( + to_adapt_insp = inspect(to_adapt) + clauses = to_adapt_insp._memo( ("joinedloader_ormadapter", self), orm_util.ORMAdapter, - to_adapt, + to_adapt_insp, equivalents=self.mapper._equivalent_columns, adapt_required=True, allow_label_resolve=False, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4699781a4..3934de535 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -11,8 +11,16 @@ import re import types import typing from typing import Any +from typing import cast +from typing import Dict +from typing import FrozenSet from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Match from typing import Optional +from typing import Sequence from typing import Tuple from typing import Type from typing import TypeVar @@ -20,32 +28,35 @@ from typing import Union import weakref from . import attributes # noqa -from .base import _class_to_mapper # noqa -from .base import _never_set # noqa -from .base import _none_set # noqa -from .base import attribute_str # noqa -from .base import class_mapper # noqa -from .base import InspectionAttr # noqa -from .base import instance_str # noqa -from .base import object_mapper # noqa -from .base import object_state # noqa -from .base import state_attribute_str # noqa -from .base import state_class_str # noqa -from .base import state_str # noqa +from ._typing import _O +from ._typing import insp_is_aliased_class +from ._typing import insp_is_mapper +from ._typing import prop_is_relationship +from .base import _class_to_mapper as _class_to_mapper +from .base import _never_set as _never_set +from .base import _none_set as _none_set +from .base import attribute_str as attribute_str +from .base import class_mapper as class_mapper +from .base import InspectionAttr as InspectionAttr +from .base import instance_str as instance_str +from .base import object_mapper as object_mapper +from .base import object_state as object_state +from .base import state_attribute_str as state_attribute_str +from .base import state_class_str as state_class_str +from .base import state_str as state_str from .interfaces import CriteriaOption -from .interfaces import MapperProperty # noqa +from .interfaces import MapperProperty as MapperProperty from .interfaces import ORMColumnsClauseRole from .interfaces import ORMEntityColumnsClauseRole from .interfaces import ORMFromClauseRole -from .interfaces import PropComparator # noqa -from .path_registry import PathRegistry # noqa +from .interfaces import PropComparator as PropComparator +from .path_registry import PathRegistry as PathRegistry from .. import event from .. import exc as sa_exc from .. import inspection from .. import sql from .. import util from ..engine.result import result_tuple -from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import lambdas @@ -54,19 +65,39 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection +from ..sql.cache_key import HasCacheKey +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import ColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation from ..util.typing import is_origin_of +from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InternalEntityType + from ._typing import _ORMColumnExprArgument + from .context import _MapperEntity + from .context import ORMCompileState from .mapper import Mapper + from .relationships import Relationship from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _FromClauseArgument + from ..sql._typing import _OnClauseArgument from ..sql._typing import _PropagateAttrsType + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import BindParameter + from ..sql.selectable import _ColumnsClauseElement from ..sql.selectable import Alias + from ..sql.selectable import Subquery + from ..sql.visitors import _ET + from ..sql.visitors import anon_map + from ..sql.visitors import ExternallyTraversible _T = TypeVar("_T", bound=Any) @@ -84,7 +115,7 @@ all_cascades = frozenset( ) -class CascadeOptions(frozenset): +class CascadeOptions(FrozenSet[str]): """Keeps track of the options sent to :paramref:`.relationship.cascade`""" @@ -104,6 +135,13 @@ class CascadeOptions(frozenset): "delete_orphan", ) + save_update: bool + delete: bool + refresh_expire: bool + merge: bool + expunge: bool + delete_orphan: bool + def __new__(cls, value_list): if isinstance(value_list, str) or value_list is None: return cls.from_string(value_list) @@ -127,7 +165,7 @@ class CascadeOptions(frozenset): values.clear() values.discard("all") - self = frozenset.__new__(CascadeOptions, values) + self = super().__new__(cls, values) # type: ignore self.save_update = "save-update" in values self.delete = "delete" in values self.refresh_expire = "refresh-expire" in values @@ -238,7 +276,7 @@ def polymorphic_union( """ - colnames = util.OrderedSet() + colnames: util.OrderedSet[str] = util.OrderedSet() colnamemaps = {} types = {} for key in table_map: @@ -299,13 +337,13 @@ def polymorphic_union( def identity_key( - class_: Optional[Type[Any]] = None, + class_: Optional[Type[_T]] = None, ident: Union[Any, Tuple[Any, ...]] = None, *, - instance: Optional[Any] = None, - row: Optional[Row] = None, + instance: Optional[_T] = None, + row: Optional[Union[Row, RowMapping]] = None, identity_token: Optional[Any] = None, -) -> _IdentityKeyType: +) -> _IdentityKeyType[_T]: r"""Generate "identity key" tuples, as are used as keys in the :attr:`.Session.identity_map` dictionary. @@ -351,7 +389,7 @@ def identity_key( * ``identity_key(class, row=row, identity_token=token)`` This form is similar to the class/tuple form, except is passed a - database result row as a :class:`.Row` object. + database result row as a :class:`.Row` or :class:`.RowMapping` object. E.g.:: @@ -375,7 +413,7 @@ def identity_key( if ident is None: raise sa_exc.ArgumentError("ident or row is required") return mapper.identity_key_from_primary_key( - util.to_list(ident), identity_token=identity_token + tuple(util.to_list(ident)), identity_token=identity_token ) else: return mapper.identity_key_from_row( @@ -394,24 +432,26 @@ class ORMAdapter(sql_util.ColumnAdapter): """ - is_aliased_class = False - aliased_insp = None + is_aliased_class: bool + aliased_insp: Optional[AliasedInsp[Any]] def __init__( self, - entity, - equivalents=None, - adapt_required=False, - allow_label_resolve=True, - anonymize_labels=False, + entity: _InternalEntityType[Any], + equivalents: Optional[_EquivalentColumnMap] = None, + adapt_required: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, ): - info = inspection.inspect(entity) - self.mapper = info.mapper - selectable = info.selectable - if info.is_aliased_class: + self.mapper = entity.mapper + selectable = entity.selectable + if insp_is_aliased_class(entity): self.is_aliased_class = True - self.aliased_insp = info + self.aliased_insp = entity + else: + self.is_aliased_class = False + self.aliased_insp = None sql_util.ColumnAdapter.__init__( self, @@ -428,7 +468,7 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): +class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]): r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -489,19 +529,20 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): def __init__( self, - mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, - # TODO: None for default here? - with_polymorphic_mappers=(), - with_polymorphic_discriminator=None, - base_alias=None, - use_mapper_path=False, - represents_outer_join=False, + mapped_class_or_ac: _EntityType[_O], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]] = None, + with_polymorphic_discriminator: Optional[ColumnElement[Any]] = None, + base_alias: Optional[AliasedInsp[Any]] = None, + use_mapper_path: bool = False, + represents_outer_join: bool = False, ): - insp = inspection.inspect(mapped_class_or_ac) + insp = cast( + "_InternalEntityType[_O]", inspection.inspect(mapped_class_or_ac) + ) mapper = insp.mapper nest_adapters = False @@ -519,6 +560,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): elif insp.is_aliased_class: nest_adapters = True + assert alias is not None self._aliased_insp = AliasedInsp( self, insp, @@ -540,7 +582,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): self.__name__ = f"aliased({mapper.class_.__name__})" @classmethod - def _reconstitute_from_aliased_insp(cls, aliased_insp): + def _reconstitute_from_aliased_insp( + cls, aliased_insp: AliasedInsp[_O] + ) -> AliasedClass[_O]: obj = cls.__new__(cls) obj.__name__ = f"aliased({aliased_insp.mapper.class_.__name__})" obj._aliased_insp = aliased_insp @@ -555,7 +599,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return obj - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: try: _aliased_insp = self.__dict__["_aliased_insp"] except KeyError: @@ -584,7 +628,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return attr - def _get_from_serialized(self, key, mapped_class, aliased_insp): + def _get_from_serialized( + self, key: str, mapped_class: _O, aliased_insp: AliasedInsp[_O] + ) -> Any: # this method is only used in terms of the # sqlalchemy.ext.serializer extension attr = getattr(mapped_class, key) @@ -605,23 +651,25 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return attr - def __repr__(self): + def __repr__(self) -> str: return "<AliasedClass at 0x%x; %s>" % ( id(self), self._aliased_insp._target.__name__, ) - def __str__(self): + def __str__(self) -> str: return str(self._aliased_insp) +@inspection._self_inspects class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, - sql_base.HasCacheKey, + HasCacheKey, InspectionAttr, MemoizedSlots, - Generic[_T], + inspection.Inspectable["AliasedInsp[_O]"], + Generic[_O], ): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -685,19 +733,36 @@ class AliasedInsp( "_nest_adapters", ) + mapper: Mapper[_O] + selectable: FromClause + _adapter: sql_util.ColumnAdapter + with_polymorphic_mappers: Sequence[Mapper[Any]] + _with_polymorphic_entities: Sequence[AliasedInsp[Any]] + + _weak_entity: weakref.ref[AliasedClass[_O]] + """the AliasedClass that refers to this AliasedInsp""" + + _target: Union[_O, AliasedClass[_O]] + """the thing referred towards by the AliasedClass/AliasedInsp. + + In the vast majority of cases, this is the mapped class. However + it may also be another AliasedClass (alias of alias). + + """ + def __init__( self, - entity: _EntityType, - inspected: _InternalEntityType, - selectable, - name, - with_polymorphic_mappers, - polymorphic_on, - _base_alias, - _use_mapper_path, - adapt_on_names, - represents_outer_join, - nest_adapters, + entity: AliasedClass[_O], + inspected: _InternalEntityType[_O], + selectable: FromClause, + name: Optional[str], + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]], + polymorphic_on: Optional[ColumnElement[Any]], + _base_alias: Optional[AliasedInsp[Any]], + _use_mapper_path: bool, + adapt_on_names: bool, + represents_outer_join: bool, + nest_adapters: bool, ): mapped_class_or_ac = inspected.entity @@ -752,23 +817,22 @@ class AliasedInsp( ) if nest_adapters: + # supports "aliased class of aliased class" use case + assert isinstance(inspected, AliasedInsp) self._adapter = inspected._adapter.wrap(self._adapter) self._adapt_on_names = adapt_on_names self._target = mapped_class_or_ac - # self._target = mapper.class_ # mapped_class_or_ac @classmethod def _alias_factory( cls, - element: Union[ - Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]" - ], - alias=None, - name=None, - flat=False, - adapt_on_names=False, - ) -> Union["AliasedClass[_T]", "Alias"]: + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + ) -> Union[AliasedClass[_O], FromClause]: if isinstance(element, FromClause): if adapt_on_names: @@ -793,16 +857,16 @@ class AliasedInsp( @classmethod def _with_polymorphic_factory( cls, - base, - classes, - selectable=False, - flat=False, - polymorphic_on=None, - aliased=False, - innerjoin=False, - adapt_on_names=False, - _use_mapper_path=False, - ): + base: Union[_O, Mapper[_O]], + classes: Iterable[Type[Any]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, + ) -> AliasedClass[_O]: primary_mapper = _class_to_mapper(base) @@ -816,7 +880,9 @@ class AliasedInsp( classes, selectable, innerjoin=innerjoin ) if aliased or flat: + assert selectable is not None selectable = selectable._anonymous_fromclause(flat=flat) + return AliasedClass( base, selectable, @@ -828,7 +894,7 @@ class AliasedInsp( ) @property - def entity(self): + def entity(self) -> AliasedClass[_O]: # to eliminate reference cycles, the AliasedClass is held weakly. # this produces some situations where the AliasedClass gets lost, # particularly when one is created internally and only the AliasedInsp @@ -844,7 +910,7 @@ class AliasedInsp( is_aliased_class = True "always returns True" - def _memoized_method___clause_element__(self): + def _memoized_method___clause_element__(self) -> FromClause: return self.selectable._annotate( { "parentmapper": self.mapper, @@ -856,7 +922,7 @@ class AliasedInsp( ) @property - def entity_namespace(self): + def entity_namespace(self) -> AliasedClass[_O]: return self.entity _cache_key_traversal = [ @@ -866,7 +932,7 @@ class AliasedInsp( ] @property - def class_(self): + def class_(self) -> Type[_O]: """Return the mapped class ultimately represented by this :class:`.AliasedInsp`.""" return self.mapper.class_ @@ -878,7 +944,7 @@ class AliasedInsp( else: return PathRegistry.per_mapper(self) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "entity": self.entity, "mapper": self.mapper, @@ -893,8 +959,8 @@ class AliasedInsp( "nest_adapters": self._nest_adapters, } - def __setstate__(self, state): - self.__init__( + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__init__( # type: ignore state["entity"], state["mapper"], state["alias"], @@ -908,7 +974,7 @@ class AliasedInsp( state["nest_adapters"], ) - def _merge_with(self, other): + def _merge_with(self, other: AliasedInsp[_O]) -> AliasedInsp[_O]: # assert self._is_with_polymorphic # assert other._is_with_polymorphic @@ -929,7 +995,6 @@ class AliasedInsp( classes, None, innerjoin=not other.represents_outer_join ) selectable = selectable._anonymous_fromclause(flat=True) - return AliasedClass( primary_mapper, selectable, @@ -937,10 +1002,13 @@ class AliasedInsp( with_polymorphic_discriminator=other.polymorphic_on, use_mapper_path=other._use_mapper_path, represents_outer_join=other.represents_outer_join, - ) + )._aliased_insp - def _adapt_element(self, elem, key=None): - d = { + def _adapt_element( + self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None + ) -> _ORMColumnExprArgument[_T]: + assert isinstance(elem, ColumnElement) + d: Dict[str, Any] = { "parententity": self, "parentmapper": self.mapper, } @@ -1084,35 +1152,45 @@ class LoaderCriteriaOption(CriteriaOption): ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), ] + root_entity: Optional[Type[Any]] + entity: Optional[_InternalEntityType[Any]] + where_criteria: Union[ColumnElement[bool], lambdas.DeferredLambdaElement] + deferred_where_criteria: bool + include_aliases: bool + propagate_to_loaders: bool + def __init__( self, - entity_or_base, - where_criteria, - loader_only=False, - include_aliases=False, - propagate_to_loaders=True, - track_closure_variables=True, + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, ): - entity = inspection.inspect(entity_or_base, False) + entity = cast( + "_InternalEntityType[Any]", + inspection.inspect(entity_or_base, False), + ) if entity is None: - self.root_entity = entity_or_base + self.root_entity = cast("Type[Any]", entity_or_base) self.entity = None else: self.root_entity = None self.entity = entity if callable(where_criteria): + if self.root_entity is not None: + wrap_entity = self.root_entity + else: + assert entity is not None + wrap_entity = entity.entity + self.deferred_where_criteria = True self.where_criteria = lambdas.DeferredLambdaElement( - where_criteria, + where_criteria, # type: ignore roles.WhereHavingRole, - lambda_args=( - _WrapUserEntity( - self.root_entity - if self.root_entity is not None - else self.entity.entity, - ), - ), + lambda_args=(_WrapUserEntity(wrap_entity),), opts=lambdas.LambdaOptions( track_closure_variables=track_closure_variables ), @@ -1126,22 +1204,27 @@ class LoaderCriteriaOption(CriteriaOption): self.include_aliases = include_aliases self.propagate_to_loaders = propagate_to_loaders - def _all_mappers(self): + def _all_mappers(self) -> Iterator[Mapper[Any]]: + if self.entity: - for ent in self.entity.mapper.self_and_descendants: - yield ent + for mp_ent in self.entity.mapper.self_and_descendants: + yield mp_ent else: + assert self.root_entity stack = list(self.root_entity.__subclasses__()) while stack: subclass = stack.pop(0) - ent = inspection.inspect(subclass, raiseerr=False) + ent = cast( + "_InternalEntityType[Any]", + inspection.inspect(subclass, raiseerr=False), + ) if ent: for mp in ent.mapper.self_and_descendants: yield mp else: stack.extend(subclass.__subclasses__()) - def _should_include(self, compile_state): + def _should_include(self, compile_state: ORMCompileState) -> bool: if ( compile_state.select_statement._annotations.get( "for_loader_criteria", None @@ -1151,21 +1234,29 @@ class LoaderCriteriaOption(CriteriaOption): return False return True - def _resolve_where_criteria(self, ext_info): + def _resolve_where_criteria( + self, ext_info: _InternalEntityType[Any] + ) -> ColumnElement[bool]: if self.deferred_where_criteria: - crit = self.where_criteria._resolve_with_args(ext_info.entity) + crit = cast( + "ColumnElement[bool]", + self.where_criteria._resolve_with_args(ext_info.entity), + ) else: - crit = self.where_criteria + crit = self.where_criteria # type: ignore + assert isinstance(crit, ColumnElement) return sql_util._deep_annotate( crit, {"for_loader_criteria": self}, detect_subquery_cols=True ) def process_compile_state_replaced_entities( - self, compile_state, mapper_entities - ): - return self.process_compile_state(compile_state) + self, + compile_state: ORMCompileState, + mapper_entities: Iterable[_MapperEntity], + ) -> None: + self.process_compile_state(compile_state) - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState) -> None: """Apply a modification to a given :class:`.CompileState`.""" # if options to limit the criteria to immediate query only, @@ -1173,7 +1264,7 @@ class LoaderCriteriaOption(CriteriaOption): self.get_global_criteria(compile_state.global_attributes) - def get_global_criteria(self, attributes): + def get_global_criteria(self, attributes: Dict[Any, Any]) -> None: for mp in self._all_mappers(): load_criteria = attributes.setdefault( ("additional_entity_criteria", mp), [] @@ -1183,14 +1274,14 @@ class LoaderCriteriaOption(CriteriaOption): inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) -inspection._inspects(AliasedInsp)(lambda target: target) @inspection._self_inspects class Bundle( ORMColumnsClauseRole, SupportsCloneAnnotations, - sql_base.MemoizedHasCacheKey, + MemoizedHasCacheKey, + inspection.Inspectable["Bundle"], InspectionAttr, ): """A grouping of SQL expressions that are returned by a :class:`.Query` @@ -1227,7 +1318,11 @@ class Bundle( _propagate_attrs: _PropagateAttrsType = util.immutabledict() - def __init__(self, name, *exprs, **kw): + exprs: List[_ColumnsClauseElement] + + def __init__( + self, name: str, *exprs: _ColumnExpressionArgument[Any], **kw: Any + ): r"""Construct a new :class:`.Bundle`. e.g.:: @@ -1246,37 +1341,43 @@ class Bundle( """ self.name = self._label = name - self.exprs = exprs = [ + coerced_exprs = [ coercions.expect( roles.ColumnsClauseRole, expr, apply_propagate_attrs=self ) for expr in exprs ] + self.exprs = coerced_exprs self.c = self.columns = ColumnCollection( (getattr(col, "key", col._label), col) - for col in [e._annotations.get("bundle", e) for e in exprs] - ) + for col in [e._annotations.get("bundle", e) for e in coerced_exprs] + ).as_readonly() self.single_entity = kw.pop("single_entity", self.single_entity) - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: return (self.__class__, self.name, self.single_entity) + tuple( [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs] ) @property - def mapper(self): + def mapper(self) -> Mapper[Any]: return self.exprs[0]._annotations.get("parentmapper", None) @property - def entity(self): + def entity(self) -> _InternalEntityType[Any]: return self.exprs[0]._annotations.get("parententity", None) @property - def entity_namespace(self): + def entity_namespace( + self, + ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: return self.c - columns = None + columns: ReadOnlyColumnCollection[str, ColumnElement[Any]] + """A namespace of SQL expressions referred to by this :class:`.Bundle`. e.g.:: @@ -1301,7 +1402,7 @@ class Bundle( """ - c = None + c: ReadOnlyColumnCollection[str, ColumnElement[Any]] """An alias for :attr:`.Bundle.columns`.""" def _clone(self): @@ -1400,32 +1501,30 @@ class _ORMJoin(expression.Join): def __init__( self, - left, - right, - onclause=None, - isouter=False, - full=False, - _left_memo=None, - _right_memo=None, - _extra_criteria=(), + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + _left_memo: Optional[Any] = None, + _right_memo: Optional[Any] = None, + _extra_criteria: Sequence[ColumnElement[bool]] = (), ): - left_info = inspection.inspect(left) + left_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(left), + ) - right_info = inspection.inspect(right) + right_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(right), + ) adapt_to = right_info.selectable # used by joined eager loader self._left_memo = _left_memo self._right_memo = _right_memo - # legacy, for string attr name ON clause. if that's removed - # then the "_joined_from_info" concept can go - left_orm_info = getattr(left, "_joined_from_info", left_info) - self._joined_from_info = right_info - if isinstance(onclause, str): - onclause = getattr(left_orm_info.entity, onclause) - # #### - if isinstance(onclause, attributes.QueryableAttribute): on_selectable = onclause.comparator._source_selectable() prop = onclause.property @@ -1477,20 +1576,23 @@ class _ORMJoin(expression.Join): augment_onclause = onclause is None and _extra_criteria expression.Join.__init__(self, left, right, onclause, isouter, full) + assert self.onclause is not None + if augment_onclause: self.onclause &= sql.and_(*_extra_criteria) if ( not prop and getattr(right_info, "mapper", None) - and right_info.mapper.single + and right_info.mapper.single # type: ignore ): + right_info = cast("_InternalEntityType[Any]", right_info) # if single inheritance target and we are using a manual # or implicit ON clause, augment it the same way we'd augment the # WHERE. single_crit = right_info.mapper._single_table_criterion if single_crit is not None: - if right_info.is_aliased_class: + if insp_is_aliased_class(right_info): single_crit = right_info._adapter.traverse(single_crit) self.onclause = self.onclause & single_crit @@ -1525,19 +1627,27 @@ class _ORMJoin(expression.Join): def join( self, - right, - onclause=None, - isouter=False, - full=False, - join_to_left=None, - ): + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ) -> _ORMJoin: return _ORMJoin(self, right, onclause, full=full, isouter=isouter) - def outerjoin(self, right, onclause=None, full=False, join_to_left=None): + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, + ) -> _ORMJoin: return _ORMJoin(self, right, onclause, isouter=True, full=full) -def with_parent(instance, prop, from_entity=None): +def with_parent( + instance: object, + prop: attributes.QueryableAttribute[Any], + from_entity: Optional[_EntityType[Any]] = None, +) -> ColumnElement[bool]: """Create filtering criterion that relates this query's primary entity to the given related instance, using established :func:`_orm.relationship()` @@ -1588,6 +1698,8 @@ def with_parent(instance, prop, from_entity=None): .. versionadded:: 1.2 """ + prop_t: Relationship[Any] + if isinstance(prop, str): raise sa_exc.ArgumentError( "with_parent() accepts class-bound mapped attributes, not strings" @@ -1595,12 +1707,19 @@ def with_parent(instance, prop, from_entity=None): elif isinstance(prop, attributes.QueryableAttribute): if prop._of_type: from_entity = prop._of_type - prop = prop.property + if not prop_is_relationship(prop.property): + raise sa_exc.ArgumentError( + f"Expected relationship property for with_parent(), " + f"got {prop.property}" + ) + prop_t = prop.property + else: + prop_t = prop - return prop._with_parent(instance, from_entity=from_entity) + return prop_t._with_parent(instance, from_entity=from_entity) -def has_identity(object_): +def has_identity(object_: object) -> bool: """Return True if the given object has a database identity. @@ -1616,7 +1735,7 @@ def has_identity(object_): return state.has_identity -def was_deleted(object_): +def was_deleted(object_: object) -> bool: """Return True if the given object was deleted within a session flush. @@ -1633,27 +1752,32 @@ def was_deleted(object_): return state.was_deleted -def _entity_corresponds_to(given, entity): +def _entity_corresponds_to( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: """determine if 'given' corresponds to 'entity', in terms of an entity passed to Query that would match the same entity being referred to elsewhere in the query. """ - if entity.is_aliased_class: - if given.is_aliased_class: + if insp_is_aliased_class(entity): + if insp_is_aliased_class(given): if entity._base_alias() is given._base_alias(): return True return False - elif given.is_aliased_class: + elif insp_is_aliased_class(given): if given._use_mapper_path: return entity in given.with_polymorphic_mappers else: return entity is given + assert insp_is_mapper(given) return entity.common_parent(given) -def _entity_corresponds_to_use_path_impl(given, entity): +def _entity_corresponds_to_use_path_impl( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: """determine if 'given' corresponds to 'entity', in terms of a path of loader options where a mapped attribute is taken to be a member of a parent entity. @@ -1673,13 +1797,13 @@ def _entity_corresponds_to_use_path_impl(given, entity): """ - if given.is_aliased_class: + if insp_is_aliased_class(given): return ( - entity.is_aliased_class + insp_is_aliased_class(entity) and not entity._use_mapper_path and (given is entity or entity in given._with_polymorphic_entities) ) - elif not entity.is_aliased_class: + elif not insp_is_aliased_class(entity): return given.isa(entity.mapper) else: return ( @@ -1688,7 +1812,7 @@ def _entity_corresponds_to_use_path_impl(given, entity): ) -def _entity_isa(given, mapper): +def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: """determine if 'given' "is a" mapper, in terms of the given would load rows of type 'mapper'. @@ -1703,42 +1827,6 @@ def _entity_isa(given, mapper): return given.isa(mapper) -def randomize_unitofwork(): - """Use random-ordering sets within the unit of work in order - to detect unit of work sorting issues. - - This is a utility function that can be used to help reproduce - inconsistent unit of work sorting issues. For example, - if two kinds of objects A and B are being inserted, and - B has a foreign key reference to A - the A must be inserted first. - However, if there is no relationship between A and B, the unit of work - won't know to perform this sorting, and an operation may or may not - fail, depending on how the ordering works out. Since Python sets - and dictionaries have non-deterministic ordering, such an issue may - occur on some runs and not on others, and in practice it tends to - have a great dependence on the state of the interpreter. This leads - to so-called "heisenbugs" where changing entirely irrelevant aspects - of the test program still cause the failure behavior to change. - - By calling ``randomize_unitofwork()`` when a script first runs, the - ordering of a key series of sets within the unit of work implementation - are randomized, so that the script can be minimized down to the - fundamental mapping and operation that's failing, while still reproducing - the issue on at least some runs. - - This utility is also available when running the test suite via the - ``--reversetop`` flag. - - """ - from sqlalchemy.orm import unitofwork, session, mapper, dependency - from sqlalchemy.util import topological - from sqlalchemy.testing.util import RandomSet - - topological.set = ( - unitofwork.set - ) = session.set = mapper.set = dependency.set = RandomSet - - def _getitem(iterable_query, item): """calculate __getitem__ in terms of an iterable query object that also has a slice() method. @@ -1780,16 +1868,21 @@ def _getitem(iterable_query, item): return list(iterable_query[item : item + 1])[0] -def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type): +def _is_mapped_annotation( + raw_annotation: Union[type, str], cls: Type[Any] +) -> bool: annotated = de_stringify_annotation(cls, raw_annotation) return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") -def _cleanup_mapped_str_annotation(annotation): +def _cleanup_mapped_str_annotation(annotation: str) -> str: # fix up an annotation that comes in as the form: # 'Mapped[List[Address]]' so that it instead looks like: # 'Mapped[List["Address"]]' , which will allow us to get # "Address" as a string + + inner: Optional[Match[str]] + mm = re.match(r"^(.+?)\[(.+)\]$", annotation) if mm and mm.group(1) == "Mapped": stack = [] @@ -1839,8 +1932,8 @@ def _extract_mapped_subtype( else: if ( not hasattr(annotated, "__origin__") - or not issubclass(annotated.__origin__, attr_cls) - and not issubclass(attr_cls, annotated.__origin__) + or not issubclass(annotated.__origin__, attr_cls) # type: ignore + and not issubclass(attr_cls, annotated.__origin__) # type: ignore ): our_annotated_str = ( annotated.__name__ @@ -1853,9 +1946,9 @@ def _extract_mapped_subtype( f'"{attr_cls.__name__}[{our_annotated_str}]".' ) - if len(annotated.__args__) != 1: + if len(annotated.__args__) != 1: # type: ignore raise sa_exc.ArgumentError( "Expected sub-type for Mapped[] annotation" ) - return annotated.__args__[0] + return annotated.__args__[0] # type: ignore diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index ea21e01c6..605f75ec4 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -389,7 +389,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: def bindparam( - key: str, + key: Optional[str], value: Any = _NoArg.NO_ARG, type_: Optional[TypeEngine[_T]] = None, unique: bool = False, @@ -521,6 +521,11 @@ def bindparam( key, or if its length is too long and truncation is required. + If omitted, an "anonymous" name is generated for the bound parameter; + when given a value to bind, the end result is equivalent to calling upon + the :func:`.literal` function with a value to bind, particularly + if the :paramref:`.bindparam.unique` parameter is also provided. + :param value: Initial value for this bind param. Will be used at statement execution time as the value for this parameter passed to the diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b0a717a1a..53d29b628 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -2,13 +2,14 @@ from __future__ import annotations import operator from typing import Any +from typing import Callable from typing import Dict +from typing import Set from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.sql.base import Executable from . import roles from .. import util from ..inspection import Inspectable @@ -16,6 +17,7 @@ from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from .base import Executable from .compiler import Compiled from .compiler import DDLCompiler from .compiler import SQLCompiler @@ -27,17 +29,20 @@ if TYPE_CHECKING: from .elements import quoted_name from .elements import SQLCoreOperations from .elements import TextClause + from .lambdas import LambdaElement from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column from .schema import DefaultGenerator from .schema import Sequence + from .schema import Table from .selectable import Alias from .selectable import FromClause from .selectable import Join from .selectable import NamedFromClause from .selectable import ReturnsRows from .selectable import Select + from .selectable import Selectable from .selectable import SelectBase from .selectable import Subquery from .selectable import TableClause @@ -46,7 +51,6 @@ if TYPE_CHECKING: from .type_api import TypeEngine from ..util.typing import TypeGuard - _T = TypeVar("_T", bound=Any) @@ -89,7 +93,11 @@ sets; select(...), insert().returning(...), etc. """ _ColumnExpressionArgument = Union[ - "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] + "ColumnElement[_T]", + _HasClauseElement, + roles.ExpressionElementRole[_T], + Callable[[], "ColumnElement[_T]"], + "LambdaElement", ] """narrower "column expression" argument. @@ -103,6 +111,7 @@ overall which brings in the TextClause object also. """ + _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" @@ -169,6 +178,8 @@ _PropagateAttrsType = util.immutabledict[str, Any] _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] +_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]] + if TYPE_CHECKING: def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: @@ -195,6 +206,9 @@ if TYPE_CHECKING: def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: ... + def is_selectable(t: Any) -> TypeGuard[Selectable]: + ... + def is_select_base( t: Union[Executable, ReturnsRows] ) -> TypeGuard[SelectBase]: @@ -224,6 +238,7 @@ else: is_from_clause = operator.attrgetter("_is_from_clause") is_tuple_type = operator.attrgetter("_is_tuple_type") is_table_value_type = operator.attrgetter("_is_table_value") + is_selectable = operator.attrgetter("is_selectable") is_select_base = operator.attrgetter("_is_select_base") is_select_statement = operator.attrgetter("_is_select_statement") is_table = operator.attrgetter("_is_table") diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f7692dbc2..f81878d55 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -218,7 +218,7 @@ def _generative(fn: _Fn) -> _Fn: """ - @util.decorator + @util.decorator # type: ignore def _generative( fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any ) -> _SelfGenerativeType: @@ -244,7 +244,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: for name in names ] - @util.decorator + @util.decorator # type: ignore def check(fn, *args, **kw): # make pylance happy by not including "self" in the argument # list @@ -260,7 +260,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: raise exc.InvalidRequestError(msg) return fn(self, *args, **kw) - return check + return check # type: ignore def _clone(element, **kw): @@ -1750,15 +1750,14 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._collection.append((k, col)) self._colset.update(c for (k, c) in self._collection) - # https://github.com/python/mypy/issues/12610 self._index.update( - (idx, c) for idx, (k, c) in enumerate(self._collection) # type: ignore # noqa: E501 + (idx, c) for idx, (k, c) in enumerate(self._collection) ) for col in replace_col: self.replace(col) def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: - self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501 + self._populate_separate_keys((col.key, col) for col in iter_) def remove(self, column: _NAMEDCOL) -> None: if column not in self._colset: @@ -1772,9 +1771,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): (k, c) for (k, c) in self._collection if c is not column ] - # https://github.com/python/mypy/issues/12610 self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501 + {idx: col for idx, (k, col) in enumerate(self._collection)} ) # delete higher index del self._index[len(self._collection)] @@ -1827,9 +1825,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._index.clear() - # https://github.com/python/mypy/issues/12610 self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501 + {idx: col for idx, (k, col) in enumerate(self._collection)} ) self._index.update(self._collection) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 4bf45da9c..0659709ab 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -214,6 +214,7 @@ def expect( Type[roles.ExpressionElementRole[Any]], Type[roles.LimitOffsetRole], Type[roles.WhereHavingRole], + Type[roles.OnClauseRole], ], element: Any, **kw: Any, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 938be0f81..c524a2602 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1078,7 +1078,7 @@ class SQLCompiler(Compiled): return list(self.insert_prefetch) + list(self.update_prefetch) @util.memoized_property - def _global_attributes(self): + def _global_attributes(self) -> Dict[Any, Any]: return {} @util.memoized_instancemethod diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 6ac7c2448..052af6ac9 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -14,6 +14,7 @@ from __future__ import annotations import typing from typing import Any from typing import Callable +from typing import Iterable from typing import List from typing import Optional from typing import Sequence as typing_Sequence @@ -1143,7 +1144,7 @@ class SchemaDropper(InvokeDDLBase): def sort_tables( - tables: typing_Sequence["Table"], + tables: Iterable["Table"], skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, extra_dependencies: Optional[ typing_Sequence[Tuple["Table", "Table"]] diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ea0fa7996..34d5127ab 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -293,11 +293,18 @@ class ClauseElement( __visit_name__ = "clause" - _propagate_attrs: _PropagateAttrsType = util.immutabledict() - """like annotations, however these propagate outwards liberally - as SQL constructs are built, and are set up at construction time. + if TYPE_CHECKING: - """ + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. + + """ + ... + + else: + _propagate_attrs = util.EMPTY_DICT @util.ro_memoized_property def description(self) -> Optional[str]: @@ -343,7 +350,9 @@ class ClauseElement( def _from_objects(self) -> List[FromClause]: return [] - def _set_propagate_attrs(self, values): + def _set_propagate_attrs( + self: SelfClauseElement, values: Mapping[str, Any] + ) -> SelfClauseElement: # usually, self._propagate_attrs is empty here. one case where it's # not is a subquery against ORM select, that is then pulled as a # property of an aliased class. should all be good @@ -526,13 +535,10 @@ class ClauseElement( if unique: bind._convert_to_unique() - return cast( - SelfClauseElement, - cloned_traverse( - self, - {"maintain_key": True, "detect_subquery_cols": True}, - {"bindparam": visit_bindparam}, - ), + return cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, ) def compare(self, other, **kw): @@ -730,7 +736,9 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): # redefined with the specific types returned by ColumnElement hierarchies if typing.TYPE_CHECKING: - _propagate_attrs: _PropagateAttrsType + @util.non_memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + ... def operate( self, op: OperatorType, *other: Any, **kwargs: Any @@ -2064,10 +2072,11 @@ class TextClause( roles.OrderByRole, roles.FromClauseRole, roles.SelectStatementRole, - roles.BinaryElementRole[Any], roles.InElementRole, Executable, DQLDMLClauseElement, + roles.BinaryElementRole[Any], + inspection.Inspectable["TextClause"], ): """Represent a literal SQL text fragment. diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index da15c305f..4b220188f 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -444,7 +444,7 @@ class DeferredLambdaElement(LambdaElement): def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) - def _resolve_with_args(self, *lambda_args): + def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement: assert isinstance(self._rec, AnalyzedFunction) tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) @@ -478,7 +478,7 @@ class DeferredLambdaElement(LambdaElement): for deferred_copy_internals in self._transforms: expr = deferred_copy_internals(expr) - return expr + return expr # type: ignore def _copy_internals( self, clone=_clone, deferred_copy_internals=None, **kw diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 577d868fd..231c70a5b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -22,9 +22,7 @@ if TYPE_CHECKING: from .base import _EntityNamespace from .base import ColumnCollection from .base import ReadOnlyColumnCollection - from .elements import ClauseElement from .elements import ColumnClause - from .elements import ColumnElement from .elements import Label from .elements import NamedColumn from .selectable import _SelectIterable @@ -271,7 +269,14 @@ class StatementRole(SQLRole): __slots__ = () _role_name = "Executable SQL or text() construct" - _propagate_attrs: _PropagateAttrsType = util.immutabledict() + if TYPE_CHECKING: + + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + ... + + else: + _propagate_attrs = util.EMPTY_DICT class SelectStatementRole(StatementRole, ReturnsRowsRole): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 92b9cc62c..52ba60a62 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -144,9 +144,9 @@ class SchemaConst(Enum): NULL_UNSPECIFIED = 3 """Symbol indicating the "nullable" keyword was not passed to a Column. - Normally we would expect None to be acceptable for this but some backends - such as that of SQL Server place special signficance on a "nullability" - value of None. + This is used to distinguish between the use case of passing + ``nullable=None`` to a :class:`.Column`, which has special meaning + on some backends such as SQL Server. """ @@ -308,7 +308,9 @@ class HasSchemaAttr(SchemaItem): schema: Optional[str] -class Table(DialectKWArgs, HasSchemaAttr, TableClause): +class Table( + DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"] +): r"""Represent a table in a database. e.g.:: @@ -1318,117 +1320,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): inherit_cache = True key: str - @overload - def __init__( - self, - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload - def __init__( - self, - __name: str, - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload def __init__( self, - __type: _TypeEngineArgument[_T], - *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - @overload - def __init__( - self, - __name: str, - __type: _TypeEngineArgument[_T], + __name_pos: Optional[ + Union[str, _TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[_T], SchemaEventTarget] + ] = None, *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, - doc: Optional[str] = None, - key: Optional[str] = None, - index: Optional[bool] = None, - unique: Optional[bool] = None, - info: Optional[_InfoType] = None, - nullable: Optional[ - Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, - onupdate: Optional[Any] = None, - primary_key: bool = False, - server_default: Optional[_ServerDefaultType] = None, - server_onupdate: Optional[FetchedValue] = None, - quote: Optional[bool] = None, - system: bool = False, - comment: Optional[str] = None, - _proxies: Optional[Any] = None, - **dialect_kwargs: Any, - ): - ... - - def __init__( - self, - *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget], name: Optional[str] = None, type_: Optional[_TypeEngineArgument[_T]] = None, autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", @@ -1440,7 +1340,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): info: Optional[_InfoType] = None, nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] - ] = NULL_UNSPECIFIED, + ] = SchemaConst.NULL_UNSPECIFIED, onupdate: Optional[Any] = None, primary_key: bool = False, server_default: Optional[_ServerDefaultType] = None, @@ -1953,7 +1853,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ # noqa: E501, RST201, RST202 - l_args = list(args) + l_args = [__name_pos, __type_pos] + list(args) del args if l_args: @@ -1963,6 +1863,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "May not pass name positionally and as a keyword." ) name = l_args.pop(0) # type: ignore + elif l_args[0] is None: + l_args.pop(0) if l_args: coltype = l_args[0] @@ -1972,6 +1874,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "May not pass type_ positionally and as a keyword." ) type_ = l_args.pop(0) # type: ignore + elif l_args[0] is None: + l_args.pop(0) if name is not None: name = quoted_name(name, quote) @@ -1989,7 +1893,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): self.primary_key = primary_key self._user_defined_nullable = udn = nullable - if udn is not NULL_UNSPECIFIED: self.nullable = udn else: @@ -5128,7 +5031,7 @@ class MetaData(HasSchemaAttr): def clear(self) -> None: """Clear all Table objects from this MetaData.""" - dict.clear(self.tables) + dict.clear(self.tables) # type: ignore self._schemas.clear() self._fk_memos.clear() diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index aab3c678c..9d4d1d6c7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1223,7 +1223,9 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns = [c for c in self.left.c] + [c for c in self.right.c] + columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ + c for c in self.right.c + ] self.primary_key.extend( # type: ignore sqlutil.reduce_columns( diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 284343154..d08fef60a 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -17,7 +17,9 @@ from typing import AbstractSet from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import Optional @@ -32,15 +34,15 @@ from . import coercions from . import operators from . import roles from . import visitors +from ._typing import is_text_clause from .annotation import _deep_annotate as _deep_annotate from .annotation import _deep_deannotate as _deep_deannotate from .annotation import _shallow_annotate as _shallow_annotate from .base import _expand_cloned from .base import _from_objects -from .base import ColumnSet -from .cache_key import HasCacheKey # noqa -from .ddl import sort_tables # noqa -from .elements import _find_columns +from .cache_key import HasCacheKey as HasCacheKey +from .ddl import sort_tables as sort_tables +from .elements import _find_columns as _find_columns from .elements import _label_reference from .elements import _textual_label_reference from .elements import BindParameter @@ -67,10 +69,13 @@ from ..util.typing import Protocol if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument + from .elements import TextClause from .roles import FromClauseRole from .selectable import _JoinTargetElement from .selectable import _OnClauseElement + from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType from .visitors import ExternallyTraversible @@ -752,7 +757,29 @@ def splice_joins( return ret -def reduce_columns(columns, *clauses, **kw): +@overload +def reduce_columns( + columns: Iterable[ColumnElement[Any]], + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[ColumnElement[Any]]: + ... + + +@overload +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[Union[ColumnElement[Any], TextClause]]: + ... + + +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Collection[Union[ColumnElement[Any], TextClause]]: r"""given a list of columns, return a 'reduced' set based on natural equivalents. @@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw): ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) only_synonyms = kw.pop("only_synonyms", False) - columns = util.ordered_column_set(columns) + column_set = util.OrderedSet(columns) + cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference( + c for c in column_set if is_text_clause(c) # type: ignore + ) omit = util.column_set() - for col in columns: + for col in cset_no_text: for fk in chain(*[c.foreign_keys for c in col.proxy_set]): - for c in columns: + for c in cset_no_text: if c is col: continue try: @@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw): def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)]) + chain( + *[c.proxy_set for c in cset_no_text.difference(omit)] + ) ) if binary.left in cols and binary.right in cols: - for c in reversed(columns): + for c in reversed(cset_no_text): if c.shares_lineage(binary.right) and ( not only_synonyms or c.name == binary.left.name ): @@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw): if clause is not None: visitors.traverse(clause, {}, {"binary": visit_binary}) - return ColumnSet(columns.difference(omit)) + return column_set.difference(omit) def criterion_as_pairs( @@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, adapt_on_names: bool = False, @@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): class _ColumnLookup(Protocol): - def __getitem__( - self, key: ColumnElement[Any] - ) -> Optional[ColumnElement[Any]]: + @overload + def __getitem__(self, key: None) -> None: + ... + + @overload + def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: + ... + + @overload + def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: + ... + + @overload + def __getitem__(self, key: _ET) -> _ET: + ... + + def __getitem__(self, key: Any) -> Any: ... @@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, adapt_required: bool = False, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, @@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter): return ac - def traverse(self, obj): + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload + def traverse(self, obj: _ET) -> _ET: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: return self.columns[obj] def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: @@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter): adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process - def adapt_check_present(self, col): + def adapt_check_present( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: newcol = self.columns[col] if newcol is col and self._corresponding_column(col, True) is None: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7363f9ddc..e0a66fbcf 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -961,12 +961,16 @@ def cloned_traverse( ... +# a bit of controversy here, as the clone of the lead element +# *could* in theory replace with an entirely different kind of element. +# however this is really not how cloned_traverse is ever used internally +# at least. @overload def cloned_traverse( - obj: ExternallyTraversible, + obj: _ET, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: +) -> _ET: ... diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index b90858512..16924a0a1 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -167,14 +167,6 @@ def setup_options(make_option): "when -n<num> is used", ) make_option( - "--reversetop", - action="store_true", - dest="reversetop", - default=False, - help="Use a random-ordering set implementation in the ORM " - "(helps reveal dependency issues)", - ) - make_option( "--requirements", action="callback", type=str, @@ -476,14 +468,6 @@ def _prep_testing_database(options, file_config): @post -def _reverse_topological(options, file_config): - if options.reversetop: - from sqlalchemy.orm.util import randomize_unitofwork - - randomize_unitofwork() - - -@post def _post_setup_options(opt, file_config): from sqlalchemy.testing import config diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 086b008de..ed6945090 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -26,6 +26,7 @@ from typing import Mapping from typing import NoReturn from typing import Optional from typing import overload +from typing import Sequence from typing import Set from typing import Tuple from typing import TypeVar @@ -287,8 +288,8 @@ OrderedDict = dict sort_dictionary = _ordered_dictionary_sort -class WeakSequence: - def __init__(self, __elements=()): +class WeakSequence(Sequence[_T]): + def __init__(self, __elements: Sequence[_T] = ()): # adapted from weakref.WeakKeyDictionary, prevent reference # cycles in the collection itself def _remove(item, selfref=weakref.ref(self)): diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 88deac28f..b02bca28f 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -92,7 +92,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): new = dict.__new__(self.__class__) dict.__init__(new, self) - dict.update(new, __d) + dict.update(new, __d) # type: ignore return new def _union_w_kw( @@ -105,7 +105,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): new = dict.__new__(self.__class__) dict.__init__(new, self) if __d: - dict.update(new, __d) + dict.update(new, __d) # type: ignore dict.update(new, kw) # type: ignore return new @@ -118,7 +118,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): if new is None: new = dict.__new__(self.__class__) dict.__init__(new, self) - dict.update(new, d) + dict.update(new, d) # type: ignore if new is None: return self diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 5c536b675..7c80ef4e0 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -13,7 +13,6 @@ from __future__ import annotations import re from typing import Any from typing import Callable -from typing import cast from typing import Dict from typing import Match from typing import Optional @@ -79,7 +78,7 @@ def warn_deprecated_limited( def deprecated_cls( - version: str, message: str, constructor: str = "__init__" + version: str, message: str, constructor: Optional[str] = "__init__" ) -> Callable[[Type[_T]], Type[_T]]: header = ".. deprecated:: %s %s" % (version, (message or "")) @@ -288,7 +287,9 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]: check_any_kw = spec.varkw - @decorator + # latest mypy has opinions here, not sure if they implemented + # Concatenate or something + @decorator # type: ignore def warned(fn: _F, *args: Any, **kwargs: Any) -> _F: for m in check_defaults: if (defaults[m] is None and kwargs[m] is not None) or ( @@ -332,7 +333,7 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]: for param, (version, message) in specs.items() }, ) - decorated = cast(_F, warned)(fn) + decorated = warned(fn) # type: ignore decorated.__doc__ = doc return decorated # type: ignore[no-any-return] @@ -352,7 +353,7 @@ def _sanitize_restructured_text(text: str) -> str: def _decorate_cls_with_warning( cls: Type[_T], - constructor: str, + constructor: Optional[str], wtype: Type[exc.SADeprecationWarning], message: str, version: str, @@ -418,7 +419,7 @@ def _decorate_with_warning( else: doc_only = "" - @decorator + @decorator # type: ignore def warned(fn: _F, *args: Any, **kwargs: Any) -> _F: skip_warning = not enable_warnings or kwargs.pop( "_sa_skip_warning", False @@ -435,9 +436,9 @@ def _decorate_with_warning( doc = inject_docstring_text(doc, docstring_header, 1) - decorated = cast(_F, warned)(func) + decorated = warned(func) # type: ignore decorated.__doc__ = doc - decorated._sa_warn = lambda: _warn_with_version( + decorated._sa_warn = lambda: _warn_with_version( # type: ignore message, version, wtype, stacklevel=3 ) return decorated # type: ignore[no-any-return] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 9b3692d59..49c5d693a 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -238,6 +238,8 @@ def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]: _Fn = TypeVar("_Fn", bound="Callable[..., Any]") +# this seems to be in flux in recent mypy versions + def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]: """A signature-matching decorator factory.""" diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index 260250b2c..ee3227d77 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -23,6 +23,8 @@ _FN = TypeVar("_FN", bound=Callable[..., Any]) if TYPE_CHECKING: from sqlalchemy.engine import default as engine_default + from sqlalchemy.orm import descriptor_props as orm_descriptor_props + from sqlalchemy.orm import relationships as orm_relationships from sqlalchemy.orm import session as orm_session from sqlalchemy.orm import util as orm_util from sqlalchemy.sql import dml as sql_dml diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index b3f3b9387..d192dc06b 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -187,7 +187,9 @@ def is_union(type_): return is_origin_of(type_, "Union") -def is_origin_of(type_, *names, module=None): +def is_origin_of( + type_: Any, *names: str, module: Optional[str] = None +) -> bool: """return True if the given type has an __origin__ with the given name and optional module.""" @@ -200,7 +202,7 @@ def is_origin_of(type_, *names, module=None): ) -def _get_type_name(type_): +def _get_type_name(type_: Type[Any]) -> str: if compat.py310: return type_.__name__ else: @@ -208,4 +210,4 @@ def _get_type_name(type_): if typ_name is None: typ_name = getattr(type_, "_name", None) - return typ_name + return typ_name # type: ignore diff --git a/pyproject.toml b/pyproject.toml index e727ee1e4..d16f03c03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ incremental = true [[tool.mypy.overrides]] -# ad-hoc ignores +##################################################################### +# modules / packages explicitly not checked by Mypy at all right now. module = [ "sqlalchemy.engine.reflection", # interim, should be strict @@ -86,7 +87,8 @@ module = [ warn_unused_ignores = false ignore_errors = true -# strict checking +################################################ +# modules explicitly for Mypy strict checking [[tool.mypy.overrides]] module = [ @@ -98,6 +100,11 @@ module = [ "sqlalchemy.engine.*", "sqlalchemy.pool.*", + # uncomment, trying to make sure mypy + # is at a baseline + # "sqlalchemy.orm._orm_constructors", + + "sqlalchemy.orm.path_registry", "sqlalchemy.orm.scoping", "sqlalchemy.orm.session", "sqlalchemy.orm.state", @@ -114,7 +121,8 @@ warn_unused_ignores = false ignore_errors = false strict = true -# partial checking +################################################ +# modules explicitly for Mypy non-strict checking [[tool.mypy.overrides]] module = [ @@ -135,6 +143,12 @@ module = [ "sqlalchemy.sql.traversals", "sqlalchemy.sql.util", + "sqlalchemy.orm._orm_constructors", + + "sqlalchemy.orm.interfaces", + "sqlalchemy.orm.mapper", + "sqlalchemy.orm.util", + "sqlalchemy.util.*", ] diff --git a/test/ext/mypy/plain_files/association_proxy_one.py b/test/ext/mypy/plain_files/association_proxy_one.py index c5c897956..e8b57a0c0 100644 --- a/test/ext/mypy/plain_files/association_proxy_one.py +++ b/test/ext/mypy/plain_files/association_proxy_one.py @@ -40,8 +40,8 @@ class Address(Base): u1 = User() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*\[builtins.str\]\] + # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\] reveal_type(User.email_addresses) - # EXPECTED_TYPE: builtins.set\*\[builtins.str\] + # EXPECTED_TYPE: builtins.set\*?\[builtins.str\] reveal_type(u1.email_addresses) diff --git a/test/ext/mypy/plain_files/experimental_relationship.py b/test/ext/mypy/plain_files/experimental_relationship.py index e97a9598b..fe2742072 100644 --- a/test/ext/mypy/plain_files/experimental_relationship.py +++ b/test/ext/mypy/plain_files/experimental_relationship.py @@ -8,7 +8,6 @@ from typing import Set from sqlalchemy import ForeignKey from sqlalchemy import Integer -from sqlalchemy import String from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -42,8 +41,8 @@ class Address(Base): id = mapped_column(Integer, primary_key=True) user_id = mapped_column(ForeignKey("user.id")) - email = mapped_column(String, nullable=False) - email_name = mapped_column("email_name", String, nullable=False) + email: Mapped[str] + email_name: Mapped[str] = mapped_column("email_name") user_style_one: Mapped[User] = relationship() user_style_two: Mapped["User"] = relationship() @@ -56,14 +55,14 @@ if typing.TYPE_CHECKING: # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] reveal_type(User.extra_name) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] reveal_type(Address.email) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] reveal_type(Address.email_name) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[experimental_relationship.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*\[experimental_relationship.Address\]\] + # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\] reveal_type(User.addresses_style_two) diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/ext/mypy/plain_files/hybrid_one.py index 7d97024af..d9f97ebcf 100644 --- a/test/ext/mypy/plain_files/hybrid_one.py +++ b/test/ext/mypy/plain_files/hybrid_one.py @@ -47,7 +47,7 @@ expr2 = Interval.contains(7) expr3 = Interval.intersects(i2) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: builtins.int\* + # EXPECTED_TYPE: builtins.int\*? reveal_type(i1.length) # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] diff --git a/test/ext/mypy/plain_files/hybrid_two.py b/test/ext/mypy/plain_files/hybrid_two.py index 6bfabbd30..ab2970656 100644 --- a/test/ext/mypy/plain_files/hybrid_two.py +++ b/test/ext/mypy/plain_files/hybrid_two.py @@ -69,10 +69,10 @@ expr3 = Interval.radius.in_([0.5, 5.2]) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: builtins.int\* + # EXPECTED_TYPE: builtins.int\*? reveal_type(i1.length) - # EXPECTED_TYPE: builtins.float\* + # EXPECTED_TYPE: builtins.float\*? reveal_type(i2.radius) # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] diff --git a/test/ext/mypy/plain_files/mapped_column.py b/test/ext/mypy/plain_files/mapped_column.py index b20beeb3a..14f4ad845 100644 --- a/test/ext/mypy/plain_files/mapped_column.py +++ b/test/ext/mypy/plain_files/mapped_column.py @@ -14,68 +14,67 @@ class Base(DeclarativeBase): class X(Base): __tablename__ = "x" + # these are fine - pk, column is not null, have the attribute be + # non-optional, fine id: Mapped[int] = mapped_column(primary_key=True) int_id: Mapped[int] = mapped_column(Integer, primary_key=True) - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + # but this is also "fine" because the developer may wish to have the object + # in a pending state with None for the id for some period of time. + # "primary_key=True" will still be interpreted correctly in DDL err_int_id: Mapped[Optional[int]] = mapped_column( Integer, primary_key=True ) - id_name: Mapped[int] = mapped_column("id_name", primary_key=True) - int_id_name: Mapped[int] = mapped_column( - "int_id_name", Integer, primary_key=True - ) - - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + # also fine, X(err_int_id_name) is None when you first make the + # object err_int_id_name: Mapped[Optional[int]] = mapped_column( "err_int_id_name", Integer, primary_key=True ) - # note we arent getting into primary_key=True / nullable=True here. - # leaving that as undefined for now + id_name: Mapped[int] = mapped_column("id_name", primary_key=True) + int_id_name: Mapped[int] = mapped_column( + "int_id_name", Integer, primary_key=True + ) a: Mapped[str] = mapped_column() b: Mapped[Optional[str]] = mapped_column() - # can't detect error because no SQL type is present + # this can't be detected because we don't know the type c: Mapped[str] = mapped_column(nullable=True) d: Mapped[str] = mapped_column(nullable=False) e: Mapped[Optional[str]] = mapped_column(nullable=True) - # can't detect error because no SQL type is present f: Mapped[Optional[str]] = mapped_column(nullable=False) g: Mapped[str] = mapped_column(String) h: Mapped[Optional[str]] = mapped_column(String) - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + # this probably is wrong. however at the moment it seems better to + # decouple the right hand arguments from declaring things about the + # left side since it mostly doesn't work in any case. i: Mapped[str] = mapped_column(String, nullable=True) j: Mapped[str] = mapped_column(String, nullable=False) k: Mapped[Optional[str]] = mapped_column(String, nullable=True) - # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type l: Mapped[Optional[str]] = mapped_column(String, nullable=False) a_name: Mapped[str] = mapped_column("a_name") b_name: Mapped[Optional[str]] = mapped_column("b_name") - # can't detect error because no SQL type is present c_name: Mapped[str] = mapped_column("c_name", nullable=True) d_name: Mapped[str] = mapped_column("d_name", nullable=False) e_name: Mapped[Optional[str]] = mapped_column("e_name", nullable=True) - # can't detect error because no SQL type is present f_name: Mapped[Optional[str]] = mapped_column("f_name", nullable=False) g_name: Mapped[str] = mapped_column("g_name", String) h_name: Mapped[Optional[str]] = mapped_column("h_name", String) - # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types i_name: Mapped[str] = mapped_column("i_name", String, nullable=True) j_name: Mapped[str] = mapped_column("j_name", String, nullable=False) @@ -86,7 +85,6 @@ class X(Base): l_name: Mapped[Optional[str]] = mapped_column( "l_name", - # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type String, nullable=False, ) diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py index 78b0a467c..f9b9b2ffe 100644 --- a/test/ext/mypy/plain_files/sql_operations.py +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -3,6 +3,7 @@ import typing from sqlalchemy import Boolean from sqlalchemy import column from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String @@ -32,6 +33,7 @@ expr7 = c1 + "x" expr8 = c2 + 10 +stmt = select(column("q")).where(lambda: column("g") > 5).where(c2 == 5) if typing.TYPE_CHECKING: diff --git a/test/ext/mypy/plain_files/trad_relationship_uselist.py b/test/ext/mypy/plain_files/trad_relationship_uselist.py index b43dcd594..af7d292be 100644 --- a/test/ext/mypy/plain_files/trad_relationship_uselist.py +++ b/test/ext/mypy/plain_files/trad_relationship_uselist.py @@ -101,45 +101,45 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\] reveal_type(User.addresses_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[Any\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_three) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_three_cast) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[Any\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_four) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] reveal_type(Address.user_style_one_typed) # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] reveal_type(Address.user_style_two_typed) # reveal_type(Address.user_style_six) # reveal_type(Address.user_style_seven) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_eight) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_nine) # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_ten) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*\[builtins.str, trad_relationship_uselist.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\] reveal_type(Address.user_style_ten_typed) diff --git a/test/ext/mypy/plain_files/traditional_relationship.py b/test/ext/mypy/plain_files/traditional_relationship.py index 473ccb282..ce131dd00 100644 --- a/test/ext/mypy/plain_files/traditional_relationship.py +++ b/test/ext/mypy/plain_files/traditional_relationship.py @@ -60,29 +60,29 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.Address\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\] reveal_type(User.addresses_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] reveal_type(Address.user_style_one_typed) # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] reveal_type(Address.user_style_two_typed) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] reveal_type(Address.user_style_three) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] reveal_type(Address.user_style_four) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.User\]\] + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_five) diff --git a/test/ext/mypy/plugin_files/relationship_6255_one.py b/test/ext/mypy/plugin_files/relationship_6255_one.py index 0c8e3c4f6..15961c703 100644 --- a/test/ext/mypy/plugin_files/relationship_6255_one.py +++ b/test/ext/mypy/plugin_files/relationship_6255_one.py @@ -17,7 +17,7 @@ class User(Base): __tablename__ = "user" id = mapped_column(Integer, primary_key=True) - name = mapped_column(String, nullable=True) + name: Mapped[Optional[str]] = mapped_column(String, nullable=True) addresses: Mapped[List["Address"]] = relationship( "Address", back_populates="user" diff --git a/test/ext/mypy/plugin_files/typing_err3.py b/test/ext/mypy/plugin_files/typing_err3.py index 466e636a7..d29909c3c 100644 --- a/test/ext/mypy/plugin_files/typing_err3.py +++ b/test/ext/mypy/plugin_files/typing_err3.py @@ -43,11 +43,11 @@ class Address(Base): @declared_attr def email_address(cls) -> Column[String]: - # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa + # EXPECTED_MYPY: Argument 1 to "Column" has incompatible type "bool"; return Column(True) @declared_attr # EXPECTED_MYPY: Invalid type comment or annotation def thisisweird(cls) -> Column(String): - # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa + # EXPECTED_MYPY: Argument 1 to "Column" has incompatible type "bool"; return Column(False) diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index 43443b7f6..7830fcee6 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -18,6 +18,7 @@ from sqlalchemy.orm.instrumentation import register_class from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not @@ -110,7 +111,13 @@ class DisposeTest(_ExtBase, fixtures.TestBase): class MyClass: __sa_instrumentation_manager__ = MyClassState - assert attributes.manager_of_class(MyClass) is None + assert attributes.opt_manager_of_class(MyClass) is None + + with expect_raises_message( + sa.orm.exc.UnmappedClassError, + r"Can't locate an instrumentation manager for class .*MyClass", + ): + attributes.manager_of_class(MyClass) t = Table( "my_table", @@ -120,7 +127,7 @@ class DisposeTest(_ExtBase, fixtures.TestBase): registry.map_imperatively(MyClass, t) - manager = attributes.manager_of_class(MyClass) + manager = attributes.opt_manager_of_class(MyClass) is_not(manager, None) is_(manager, MyClass.xyz) @@ -128,7 +135,7 @@ class DisposeTest(_ExtBase, fixtures.TestBase): registry.dispose() - manager = attributes.manager_of_class(MyClass) + manager = attributes.opt_manager_of_class(MyClass) is_(manager, None) assert not hasattr(MyClass, "xyz") @@ -532,9 +539,9 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): register_class(Known) k, u = Known(), Unknown() - assert instrumentation.manager_of_class(Unknown) is None - assert instrumentation.manager_of_class(Known) is not None - assert instrumentation.manager_of_class(None) is None + assert instrumentation.opt_manager_of_class(Unknown) is None + assert instrumentation.opt_manager_of_class(Known) is not None + assert instrumentation.opt_manager_of_class(None) is None assert attributes.instance_state(k) is not None assert_raises((AttributeError, KeyError), attributes.instance_state, u) @@ -583,7 +590,10 @@ class FinderTest(_ExtBase, fixtures.ORMTest): ) register_class(A) - ne_(type(manager_of_class(A)), instrumentation.ClassManager) + ne_( + type(attributes.opt_manager_of_class(A)), + instrumentation.ClassManager, + ) def test_nativeext_submanager(self): class Mine(instrumentation.ClassManager): diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 67abc8971..b50cbc2ba 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -391,6 +391,33 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): polymorphic_identity=0, ) + def test_polymorphic_on_not_present_col_partial_wpoly(self): + """fix for partial with_polymorphic(). + + found_during_type_annotation + + """ + t2, t1 = self.tables.t2, self.tables.t1 + Parent = self.classes.Parent + t1t2_join = select(t1.c.x).select_from(t1.join(t2)).alias() + + def go(): + t1t2_join_2 = select(t1.c.q).select_from(t1.join(t2)).alias() + self.mapper_registry.map_imperatively( + Parent, + t2, + polymorphic_on=t1t2_join.c.x, + with_polymorphic=("*", None), + polymorphic_identity=0, + ) + + assert_raises_message( + sa_exc.InvalidRequestError, + "Could not map polymorphic_on column 'x' to the mapped table - " + "polymorphic loads will not function properly", + go, + ) + def test_polymorphic_on_not_present_col(self): t2, t1 = self.tables.t2, self.tables.t1 Parent = self.classes.Parent diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index 51f37f028..5a171e372 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -9,6 +9,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import attributes from sqlalchemy.orm import backref +from sqlalchemy.orm import CascadeOptions from sqlalchemy.orm import class_mapper from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import exc as orm_exc @@ -4217,11 +4218,12 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest): eq_(s.query(Language).count(), 0) -class ViewonlyFlagWarningTest(fixtures.MappedTest): - """test for #4993. +class ViewonlyCascadeUpdate(fixtures.MappedTest): + """Test that cascades are trimmed accordingly when viewonly is set. - In 1.4, this moves to test/orm/test_cascade, deprecation warnings - become errors, will then be for #4994. + Originally #4993 and #4994 this was raising an error for invalid + cascades. in 2.0 this is simplified to just remove the write + cascades, allows the default cascade to be reasonable. """ @@ -4250,21 +4252,17 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): pass @testing.combinations( - ({"delete"}, {"delete"}), + ({"delete"}, {"none"}), ( {"all, delete-orphan"}, - {"delete", "delete-orphan", "merge", "save-update"}, + {"refresh-expire", "expunge"}, ), - ({"save-update, expunge"}, {"save-update"}), + ({"save-update, expunge"}, {"expunge"}), ) - def test_write_cascades(self, setting, settings_that_warn): + def test_write_cascades(self, setting, expected): Order = self.classes.Order - assert_raises_message( - sa_exc.ArgumentError, - 'Cascade settings "%s" apply to persistence ' - "operations" % (", ".join(sorted(settings_that_warn))), - relationship, + r = relationship( Order, primaryjoin=( self.tables.users.c.id == foreign(self.tables.orders.c.user_id) @@ -4272,6 +4270,7 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): cascade=", ".join(sorted(setting)), viewonly=True, ) + eq_(r.cascade, CascadeOptions(expected)) def test_expunge_cascade(self): User, Order, orders, users = ( @@ -4425,23 +4424,6 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest): eq_(umapper.attrs["orders"].cascade, set()) - def test_write_cascade_disallowed_w_viewonly(self): - - Order = self.classes.Order - - assert_raises_message( - sa_exc.ArgumentError, - 'Cascade settings "delete, delete-orphan, merge, save-update" ' - "apply to persistence operations", - relationship, - Order, - primaryjoin=( - self.tables.users.c.id == foreign(self.tables.orders.c.user_id) - ), - cascade="all, delete, delete-orphan", - viewonly=True, - ) - class CollectionCascadesNoBackrefTest(fixtures.TestBase): """test the removal of cascade_backrefs behavior diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index 3b1010300..437129af1 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_warns_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ from sqlalchemy.testing.fixtures import fixture_session @@ -739,9 +740,15 @@ class MiscTest(fixtures.MappedTest): assert instrumentation.manager_of_class(A) is manager instrumentation.unregister_class(A) - assert instrumentation.manager_of_class(A) is None + assert instrumentation.opt_manager_of_class(A) is None assert not hasattr(A, "x") + with expect_raises_message( + sa.orm.exc.UnmappedClassError, + r"Can't locate an instrumentation manager for class .*A", + ): + instrumentation.manager_of_class(A) + assert A.__init__ == object.__init__ def test_compileonattr_rel_backref_a(self): diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index f71ab3032..43a34eae4 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -1374,6 +1374,16 @@ class JoinTest(QueryTest, AssertsCompiledSQL): [User(name="fred")], ) + def test_str_not_accepted_orm_join(self): + User, Address = self.classes.User, self.classes.Address + + with expect_raises_message( + sa.exc.ArgumentError, + "ON clause, typically a SQL expression or ORM " + "relationship attribute expected, got 'addresses'.", + ): + outerjoin(User, Address, "addresses") + def test_aliased_classes(self): User, Address = self.classes.User, self.classes.Address @@ -1409,7 +1419,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): eq_(result, [(user8, address3)]) result = ( - q.select_from(outerjoin(User, AdAlias, "addresses")) + q.select_from(outerjoin(User, AdAlias, User.addresses)) .filter(AdAlias.email_address == "ed@bettyboop.com") .all() ) @@ -1504,7 +1514,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): q = sess.query(Order) q = ( q.add_entity(Item) - .select_from(join(Order, Item, "items")) + .select_from(join(Order, Item, Order.items)) .order_by(Order.id, Item.id) ) result = q.all() @@ -1513,7 +1523,7 @@ class JoinTest(QueryTest, AssertsCompiledSQL): IAlias = aliased(Item) q = ( sess.query(Order, IAlias) - .select_from(join(Order, IAlias, "items")) + .select_from(join(Order, IAlias, Order.items)) .filter(IAlias.description == "item 3") ) result = q.all() @@ -2569,18 +2579,6 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): s.query(Node).join(Node.children)._compile_context, ) - def test_explicit_join_1(self): - Node = self.classes.Node - n1 = aliased(Node) - n2 = aliased(Node) - - self.assert_compile( - join(Node, n1, "children").join(n2, "children"), - "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " - "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id", - use_default_dialect=True, - ) - def test_explicit_join_2(self): Node = self.classes.Node n1 = aliased(Node) @@ -2598,12 +2596,8 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): n1 = aliased(Node) n2 = aliased(Node) - # the join_to_left=False here is unfortunate. the default on this - # flag should be False. self.assert_compile( - join(Node, n1, Node.children).join( - n2, Node.children, join_to_left=False - ), + join(Node, n1, Node.children).join(n2, Node.children), "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id " "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id", use_default_dialect=True, @@ -2646,7 +2640,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) - .select_from(join(Node, n1, "children")) + .select_from(join(Node, n1, n1.children)) .filter(n1.data == "n122") .first() ) @@ -2660,7 +2654,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) - .select_from(join(Node, n1, "children").join(n2, "children")) + .select_from(join(Node, n1, Node.children).join(n2, n1.children)) .filter(n2.data == "n122") .first() ) @@ -2676,7 +2670,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) .select_from( - join(Node, n1, Node.id == n1.parent_id).join(n2, "children") + join(Node, n1, Node.id == n1.parent_id).join(n2, n1.children) ) .filter(n2.data == "n122") .first() @@ -2691,7 +2685,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): node = ( sess.query(Node) - .select_from(join(Node, n1, "parent").join(n2, "parent")) + .select_from(join(Node, n1, Node.parent).join(n2, n1.parent)) .filter( and_(Node.data == "n122", n1.data == "n12", n2.data == "n1") ) @@ -2708,7 +2702,7 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL): eq_( list( sess.query(Node) - .select_from(join(Node, n1, "parent").join(n2, "parent")) + .select_from(join(Node, n1, Node.parent).join(n2, n1.parent)) .filter( and_( Node.data == "n122", n1.data == "n12", n2.data == "n1" @@ -3085,7 +3079,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest): n1 = aliased(Node) eq_( sess.query(Node) - .select_from(join(Node, n1, "children")) + .select_from(join(Node, n1, Node.children)) .filter(n1.data.in_(["n3", "n7"])) .order_by(Node.id) .all(), diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 980c82fbe..d8cc48939 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2,6 +2,7 @@ import logging import logging.handlers import sqlalchemy as sa +from sqlalchemy import column from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer @@ -9,6 +10,7 @@ from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import String +from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import util from sqlalchemy.engine import default @@ -132,6 +134,22 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): ): self.mapper(User, users) + def test_no_table(self): + """test new error condition raised for table=None + + found_during_type_annotation + + """ + + User = self.classes.User + + with expect_raises_message( + sa.exc.ArgumentError, + r"Mapper\[User\(None\)\] has None for a primary table " + r"argument and does not specify 'inherits'", + ): + self.mapper(User, None) + def test_cant_call_legacy_constructor_directly(self): users, User = ( self.tables.users, @@ -341,6 +359,34 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): s, ) + def test_no_tableclause(self): + """It's not tested for a Mapper to have lower-case table() objects + as part of its collection of tables, and in particular these objects + won't report on constraints or primary keys, which while this doesn't + necessarily disqualify them from being part of a mapper, we don't + have assumptions figured out right now to accommodate them. + + found_during_type_annotation + + """ + User = self.classes.User + users = self.tables.users + + address = table( + "address", + column("address_id", Integer), + column("user_id", Integer), + ) + + with expect_raises_message( + sa.exc.ArgumentError, + "ORM mappings can only be made against schema-level Table " + "objects, not TableClause; got tableclause 'address'", + ): + self.mapper_registry.map_imperatively( + User, users.join(address, users.c.id == address.c.user_id) + ) + def test_reconfigure_on_other_mapper(self): """A configure trigger on an already-configured mapper still triggers a check against all mappers.""" @@ -666,7 +712,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): (column_property, (users.c.name,)), (relationship, (Address,)), (composite, (MyComposite, "id", "name")), - (synonym, "foo"), + (synonym, ("foo",)), ]: obj = constructor(info={"x": "y"}, *args) eq_(obj.info, {"x": "y"}) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 96759e388..d6fadc449 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -11,7 +11,6 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes -from sqlalchemy.orm import class_mapper from sqlalchemy.orm import column_property from sqlalchemy.orm import contains_eager from sqlalchemy.orm import defaultload @@ -82,8 +81,7 @@ class PathTest: r = [] for i, item in enumerate(path): if i % 2 == 0: - if isinstance(item, type): - item = class_mapper(item) + item = inspect(item) else: if isinstance(item, str): item = inspect(r[-1]).mapper.attrs[item] diff --git a/test/orm/test_query.py b/test/orm/test_query.py index d0c8f4108..55414364c 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -6389,6 +6389,23 @@ class ParentTest(QueryTest, AssertsCompiledSQL): Order(description="order 5"), ] == o + def test_invalid_property(self): + """Test if with_parent is passed a non-relationship + + found_during_type_annotation + + """ + User, Address = self.classes.User, self.classes.Address + + sess = fixture_session() + u1 = sess.get(User, 7) + with expect_raises_message( + sa_exc.ArgumentError, + r"Expected relationship property for with_parent\(\), " + "got User.name", + ): + with_parent(u1, User.name) + def test_select_from(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index 03c31dc0f..c829582fd 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -5,6 +5,7 @@ from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import Table from sqlalchemy import testing +from sqlalchemy.engine import result from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased @@ -465,8 +466,7 @@ class IdentityKeyTest(_fixtures.FixtureTest): def _cases(): return testing.combinations( - (orm_util,), - (Session,), + (orm_util,), (Session,), argnames="ormutil" ) @_cases() @@ -504,12 +504,29 @@ class IdentityKeyTest(_fixtures.FixtureTest): eq_(key, (User, (u.id,), None)) @_cases() - def test_identity_key_3(self, ormutil): + @testing.combinations("dict", "row", "mapping", argnames="rowtype") + def test_identity_key_3(self, ormutil, rowtype): + """test a real Row works with identity_key. + + this was broken w/ 1.4 future mode as we are assuming a mapping + here. to prevent regressions, identity_key now accepts any of + dict, RowMapping, Row for the "row". + + found_during_type_annotation + + + """ User, users = self.classes.User, self.tables.users self.mapper_registry.map_imperatively(User, users) - row = {users.c.id: 1, users.c.name: "Frank"} + if rowtype == "dict": + row = {users.c.id: 1, users.c.name: "Frank"} + elif rowtype in ("mapping", "row"): + row = result.result_tuple([users.c.id, users.c.name])((1, "Frank")) + if rowtype == "mapping": + row = row._mapping + key = ormutil.identity_key(User, row=row) eq_(key, (User, (1,), None)) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 9fdc51938..7fa39825c 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -776,6 +776,28 @@ class SelectableTest( "table1.col3, table1.colx FROM table1) AS anon_1", ) + def test_reduce_cols_odd_expressions(self): + """test util.reduce_columns() works with text, non-col expressions + in a SELECT. + + found_during_type_annotation + + """ + + stmt = select( + table1.c.col1, + table1.c.col3 * 5, + text("some_expr"), + table2.c.col2, + func.foo(), + ).join(table2) + self.assert_compile( + stmt.reduce_columns(only_synonyms=False), + "SELECT table1.col1, table1.col3 * :col3_1 AS anon_1, " + "some_expr, foo() AS foo_1 FROM table1 JOIN table2 " + "ON table1.col1 = table2.col2", + ) + def test_with_only_generative_no_list(self): s1 = table1.select().scalar_subquery() |
