diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-02-21 10:34:01 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-02-26 22:20:11 -0500 |
| commit | 681055f9fb5230d344a67f47b0c60fc1a5804b3e (patch) | |
| tree | 2cb2b6bda2e36b936c0f50fe202bf40214408bd7 /lib/sqlalchemy | |
| parent | 304d590ad4016d5ec627edd55e9ac5b747c68d2a (diff) | |
| download | sqlalchemy-681055f9fb5230d344a67f47b0c60fc1a5804b3e.tar.gz | |
apply a fixed locals w/ Mapped to all de-stringify
Continued the fix for :ticket:`8853`, allowing the :class:`_orm.Mapped`
name to be fully qualified regardless of whether or not
``from __annotations__ import future`` were present. This issue first fixed
in 2.0.0b3 confirmed that this case worked via the test suite, however the
test suite apparently was not testing the behavior for the name ``Mapped``
not being locally present at all; string resolution has been updated to
ensure the ``Mapped`` symbol is locatable as applies to how the ORM uses
these functions.
Fixes: #8853
Fixes: #9335
Change-Id: Id82d09aee906165a4d77c7da6a0b4177dd675c10
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 65 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 35 |
5 files changed, 95 insertions, 13 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 29d748596..d01aad439 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -55,6 +55,7 @@ from .properties import MappedColumn from .util import _extract_mapped_subtype from .util import _is_mapped_annotation from .util import class_mapper +from .util import de_stringify_annotation from .. import event from .. import exc from .. import util @@ -64,7 +65,6 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType -from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import is_literal from ..util.typing import Protocol diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index b65171c9d..fd28830d9 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -44,6 +44,7 @@ from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator from .util import _none_set +from .util import de_stringify_annotation from .. import event from .. import exc as sa_exc from .. import schema @@ -52,7 +53,6 @@ from .. import util from ..sql import expression from ..sql import operators from ..sql.elements import BindParameter -from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 from ..util.typing import typing_get_args diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a5f34f3de..4c07bad23 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -41,6 +41,8 @@ from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import RelationshipProperty +from .util import de_stringify_annotation +from .util import de_stringify_union_elements from .. import exc as sa_exc from .. import ForeignKey from .. import log @@ -52,8 +54,6 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..sql.type_api import TypeEngine from ..util.typing import de_optionalize_union_types -from ..util.typing import de_stringify_annotation -from ..util.typing import de_stringify_union_elements from ..util.typing import is_fwd_ref from ..util.typing import is_optional_union from ..util.typing import is_pep593 diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ad9ce2013..7966f6cd9 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -9,6 +9,7 @@ from __future__ import annotations import enum +import functools import re import types import typing @@ -46,6 +47,7 @@ from .base import attribute_str as attribute_str # noqa: F401 from .base import class_mapper as class_mapper from .base import InspectionAttr as InspectionAttr from .base import instance_str as instance_str # noqa: F401 +from .base import Mapped from .base import object_mapper as object_mapper from .base import object_state as object_state # noqa: F401 from .base import opt_manager_of_class @@ -79,10 +81,14 @@ from ..sql.elements import ColumnElement from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots -from ..util.typing import de_stringify_annotation -from ..util.typing import eval_name_only +from ..util.typing import de_stringify_annotation as _de_stringify_annotation +from ..util.typing import ( + de_stringify_union_elements as _de_stringify_union_elements, +) +from ..util.typing import eval_name_only as _eval_name_only from ..util.typing import is_origin_of_cls from ..util.typing import Literal +from ..util.typing import Protocol from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: @@ -113,6 +119,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Subquery from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType + from ..util.typing import ArgsTypeProcotol _T = TypeVar("_T", bound=Any) @@ -130,6 +137,58 @@ all_cascades = frozenset( ) +_de_stringify_partial = functools.partial( + functools.partial, locals_=util.immutabledict({"Mapped": Mapped}) +) + +# partial is practically useless as we have to write out the whole +# function and maintain the signature anyway + + +class _DeStringifyAnnotation(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: _AnnotationScanType, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, + ) -> Type[Any]: + ... + + +de_stringify_annotation = cast( + _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation) +) + + +class _DeStringifyUnionElements(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: ArgsTypeProcotol, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + ) -> Type[Any]: + ... + + +de_stringify_union_elements = cast( + _DeStringifyUnionElements, + _de_stringify_partial(_de_stringify_union_elements), +) + + +class _EvalNameOnly(Protocol): + def __call__(self, name: str, module_name: str) -> Any: + ... + + +eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) + + class CascadeOptions(FrozenSet[str]): """Keeps track of the options sent to :paramref:`.relationship.cascade`""" @@ -2271,7 +2330,7 @@ def _extract_mapped_subtype( cls, raw_annotation, originating_module, - _cleanup_mapped_str_annotation, + str_cleanup_fn=_cleanup_mapped_str_annotation, ) except _CleanupError as ce: raise sa_exc.ArgumentError( diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 9e6df0d35..24d8dd2dc 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -18,6 +18,7 @@ from typing import Dict from typing import ForwardRef from typing import Generic from typing import Iterable +from typing import Mapping from typing import NewType from typing import NoReturn from typing import Optional @@ -123,6 +124,8 @@ def de_stringify_annotation( cls: Type[Any], annotation: _AnnotationScanType, originating_module: str, + locals_: Mapping[str, Any], + *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, ) -> Type[Any]: @@ -150,7 +153,9 @@ def de_stringify_annotation( if str_cleanup_fn: annotation = str_cleanup_fn(annotation, originating_module) - annotation = eval_expression(annotation, originating_module) + annotation = eval_expression( + annotation, originating_module, locals_=locals_ + ) if ( include_generic @@ -162,6 +167,7 @@ def de_stringify_annotation( cls, elem, originating_module, + locals_, str_cleanup_fn=str_cleanup_fn, include_generic=include_generic, ) @@ -183,7 +189,12 @@ def _copy_generic_annotation_with( return annotation.__origin__[elements] # type: ignore -def eval_expression(expression: str, module_name: str) -> Any: +def eval_expression( + expression: str, + module_name: str, + *, + locals_: Optional[Mapping[str, Any]] = None, +) -> Any: try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ except KeyError as ke: @@ -191,8 +202,9 @@ def eval_expression(expression: str, module_name: str) -> Any: f"Module {module_name} isn't present in sys.modules; can't " f"evaluate expression {expression}" ) from ke + try: - annotation = eval(expression, base_globals, None) + annotation = eval(expression, base_globals, locals_) except Exception as err: raise NameError( f"Could not de-stringify annotation {expression!r}" @@ -201,9 +213,14 @@ def eval_expression(expression: str, module_name: str) -> Any: return annotation -def eval_name_only(name: str, module_name: str) -> Any: +def eval_name_only( + name: str, + module_name: str, + *, + locals_: Optional[Mapping[str, Any]] = None, +) -> Any: if "." in name: - return eval_expression(name, module_name) + return eval_expression(name, module_name, locals_=locals_) try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ @@ -237,12 +254,18 @@ def de_stringify_union_elements( cls: Type[Any], annotation: ArgsTypeProcotol, originating_module: str, + locals_: Mapping[str, Any], + *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, ) -> Type[Any]: return make_union_type( *[ de_stringify_annotation( - cls, anno, originating_module, str_cleanup_fn + cls, + anno, + originating_module, + {}, + str_cleanup_fn=str_cleanup_fn, ) for anno in annotation.__args__ ] |
