diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/cyextension/resultproxy.pyx | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ext.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/automap.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/_orm_constructors.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 54 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 76 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/scoping.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 67 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 35 |
16 files changed, 267 insertions, 53 deletions
diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx index e88c8ec0b..96a028d93 100644 --- a/lib/sqlalchemy/cyextension/resultproxy.pyx +++ b/lib/sqlalchemy/cyextension/resultproxy.pyx @@ -3,9 +3,10 @@ import operator cdef int MD_INDEX = 0 # integer index in cursor.description +cdef int _KEY_OBJECTS_ONLY = 1 KEY_INTEGER_ONLY = 0 -KEY_OBJECTS_ONLY = 1 +KEY_OBJECTS_ONLY = _KEY_OBJECTS_ONLY cdef class BaseRow: cdef readonly object _parent @@ -76,7 +77,7 @@ cdef class BaseRow: if mdindex is None: self._parent._raise_for_ambiguous_column_name(rec) elif ( - self._key_style == KEY_OBJECTS_ONLY + self._key_style == _KEY_OBJECTS_ONLY and isinstance(key, int) ): raise KeyError(key) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 255c72042..3ba103802 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2350,8 +2350,9 @@ class PGDDLCompiler(compiler.DDLCompiler): constraint ) elements = [] + kw["include_table"] = False + kw["literal_binds"] = True for expr, name, op in constraint._render_exprs: - kw["include_table"] = False exclude_element = self.sql_compiler.process(expr, **kw) + ( (" " + constraint.ops[expr.key]) if hasattr(expr, "key") and expr.key in constraint.ops diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 0f5efb1de..22604955d 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -164,16 +164,15 @@ class ExcludeConstraint(ColumnCollectionConstraint): :param \*elements: A sequence of two tuples of the form ``(column, operator)`` where - "column" is a SQL expression element or a raw SQL string, most - typically a :class:`_schema.Column` object, - and "operator" is a string - containing the operator to use. In order to specify a column name - when a :class:`_schema.Column` object is not available, - while ensuring + "column" is a SQL expression element or the name of a column as + string, most typically a :class:`_schema.Column` object, + and "operator" is a string containing the operator to use. + In order to specify a column name when a :class:`_schema.Column` + object is not available, while ensuring that any necessary quoting rules take effect, an ad-hoc :class:`_schema.Column` or :func:`_expression.column` - object should be - used. + object should be used. ``column`` may also be a string SQL + expression when passed as :func:`_expression.literal_column` :param name: Optional, the in-database name of this constraint. diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 030015284..1861791b7 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -1188,6 +1188,14 @@ class AutomapBase: .. versionadded:: 1.4 """ + + for mr in cls.__mro__: + if "_sa_automapbase_bookkeeping" in mr.__dict__: + automap_base = cast("Type[AutomapBase]", mr) + break + else: + assert False, "Can't locate automap base in class hierarchy" + glbls = globals() if classname_for_table is None: classname_for_table = glbls["classname_for_table"] @@ -1237,7 +1245,7 @@ class AutomapBase: ] many_to_many = [] - bookkeeping = cls._sa_automapbase_bookkeeping + bookkeeping = automap_base._sa_automapbase_bookkeeping metadata_tables = cls.metadata.tables for table_key in set(metadata_tables).difference( @@ -1278,7 +1286,7 @@ class AutomapBase: mapped_cls = type( newname, - (cls,), + (automap_base,), clsdict, ) map_config = _DeferredMapperConfig.config_for_cls( @@ -1309,7 +1317,7 @@ class AutomapBase: for map_config in table_to_map_config.values(): _relationships_for_fks( - cls, + automap_base, map_config, table_to_map_config, collection_class, @@ -1320,7 +1328,7 @@ class AutomapBase: for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: _m2m_relationship( - cls, + automap_base, lcl_m2m, rem_m2m, m2m_const, @@ -1332,7 +1340,9 @@ class AutomapBase: generate_relationship, ) - for map_config in _DeferredMapperConfig.classes_for_base(cls): + for map_config in _DeferredMapperConfig.classes_for_base( + automap_base + ): map_config.map() _sa_decl_prepare = True diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index d54e1ccb9..69cd7f598 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -120,6 +120,7 @@ from .relationships import foreign as foreign from .relationships import Relationship as Relationship from .relationships import RelationshipProperty as RelationshipProperty from .relationships import remote as remote +from .scoping import QueryPropertyDescriptor as QueryPropertyDescriptor from .scoping import scoped_session as scoped_session from .session import close_all_sessions as close_all_sessions from .session import make_transient as make_transient diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 3bd1db79d..64e7937f1 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -2208,7 +2208,7 @@ def aliased( def with_polymorphic( - base: Union[_O, Mapper[_O]], + base: Union[Type[_O], Mapper[_O]], classes: Union[Literal["*"], Iterable[Type[Any]]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 29d748596..d01aad439 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -55,6 +55,7 @@ from .properties import MappedColumn from .util import _extract_mapped_subtype from .util import _is_mapped_annotation from .util import class_mapper +from .util import de_stringify_annotation from .. import event from .. import exc from .. import util @@ -64,7 +65,6 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType -from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import is_literal from ..util.typing import Protocol diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index b65171c9d..fd28830d9 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -44,6 +44,7 @@ from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator from .util import _none_set +from .util import de_stringify_annotation from .. import event from .. import exc as sa_exc from .. import schema @@ -52,7 +53,6 @@ from .. import util from ..sql import expression from ..sql import operators from ..sql.elements import BindParameter -from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 from ..util.typing import typing_get_args diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 54b96c215..7974d94c5 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -972,17 +972,17 @@ def _instance_processor( if not refresh_state and _polymorphic_from is not None: key = ("loader", path.path) + if key in context.attributes and context.attributes[key].strategy == ( ("selectinload_polymorphic", True), ): - selectin_load_via = mapper._should_selectin_load( - context.attributes[key].local_opts["entities"], - _polymorphic_from, - ) + option_entities = context.attributes[key].local_opts["entities"] else: - selectin_load_via = mapper._should_selectin_load( - None, _polymorphic_from - ) + option_entities = None + selectin_load_via = mapper._should_selectin_load( + option_entities, + _polymorphic_from, + ) if selectin_load_via and selectin_load_via is not _polymorphic_from: # only_load_props goes w/ refresh_state only, and in a refresh @@ -990,8 +990,13 @@ def _instance_processor( # loading does not apply assert only_load_props is None - callable_ = _load_subclass_via_in(context, path, selectin_load_via) - + callable_ = _load_subclass_via_in( + context, + path, + selectin_load_via, + _polymorphic_from, + option_entities, + ) PostLoad.callable_for_path( context, load_path, @@ -1212,17 +1217,42 @@ def _instance_processor( return _instance -def _load_subclass_via_in(context, path, entity): +def _load_subclass_via_in( + context, path, entity, polymorphic_from, option_entities +): mapper = entity.mapper + # TODO: polymorphic_from seems to be a Mapper in all cases. + # this is likely not needed, but as we dont have typing in loading.py + # yet, err on the safe side + polymorphic_from_mapper = polymorphic_from.mapper + not_against_basemost = polymorphic_from_mapper.inherits is not None + zero_idx = len(mapper.base_mapper.primary_key) == 1 - if entity.is_aliased_class: - q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity) + if entity.is_aliased_class or not_against_basemost: + q, enable_opt, disable_opt = mapper._subclass_load_via_in( + entity, polymorphic_from + ) else: q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper def do_load(context, path, states, load_only, effective_entity): + if not option_entities: + # filter out states for those that would have selectinloaded + # from another loader + # TODO: we are currently ignoring the case where the + # "selectin_polymorphic" option is used, as this is much more + # complex / specific / very uncommon API use + states = [ + (s, v) + for s, v in states + if s.mapper._would_selectin_load_only_from_given_mapper(mapper) + ] + + if not states: + return + orig_query = context.query options = (enable_opt,) + orig_query._with_options + (disable_opt,) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c0ff2ed10..2ae6dadcd 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -3698,6 +3698,65 @@ class Mapper( if m is mapper: break + @HasMemoized.memoized_attribute + def _would_selectinload_combinations_cache(self): + return {} + + def _would_selectin_load_only_from_given_mapper(self, super_mapper): + """return True if this mapper would "selectin" polymorphic load based + on the given super mapper, and not from a setting from a subclass. + + given:: + + class A: + ... + + class B(A): + __mapper_args__ = {"polymorphic_load": "selectin"} + + class C(B): + ... + + class D(B): + __mapper_args__ = {"polymorphic_load": "selectin"} + + ``inspect(C)._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns True, because C does selectin loading because of B's setting. + + OTOH, ``inspect(D) + ._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns False, because D does selectin loading because of its own + setting; when we are doing a selectin poly load from B, we want to + filter out D because it would already have its own selectin poly load + set up separately. + + Added as part of #9373. + + """ + cache = self._would_selectinload_combinations_cache + + try: + return cache[super_mapper] + except KeyError: + pass + + # assert that given object is a supermapper, meaning we already + # strong reference it directly or indirectly. this allows us + # to not worry that we are creating new strongrefs to unrelated + # mappers or other objects. + assert self.isa(super_mapper) + + mapper = super_mapper + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == "selectin": + retval = m is super_mapper + break + else: + retval = False + + cache[super_mapper] = retval + return retval + def _should_selectin_load(self, enabled_via_opt, polymorphic_from): if not enabled_via_opt: # common case, takes place for all polymorphic loads @@ -3721,7 +3780,7 @@ class Mapper( return None @util.preload_module("sqlalchemy.orm.strategy_options") - def _subclass_load_via_in(self, entity): + def _subclass_load_via_in(self, entity, polymorphic_from): """Assemble a that can load the columns local to this subclass as a SELECT with IN. @@ -3739,6 +3798,16 @@ class Mapper( disable_opt = strategy_options.Load(entity) enable_opt = strategy_options.Load(entity) + classes_to_include = {self} + m: Optional[Mapper[Any]] = self.inherits + while ( + m is not None + and m is not polymorphic_from + and m.polymorphic_load == "selectin" + ): + classes_to_include.add(m) + m = m.inherits + for prop in self.attrs: # skip prop keys that are not instrumented on the mapped class. @@ -3747,7 +3816,7 @@ class Mapper( if prop.key not in self.class_manager: continue - if prop.parent is self or prop in keep_props: + if prop.parent in classes_to_include or prop in keep_props: # "enable" options, to turn on the properties that we want to # load by default (subject to options from the query) if not isinstance(prop, StrategizedProperty): @@ -3811,7 +3880,8 @@ class Mapper( @HasMemoized.memoized_attribute def _subclass_load_via_in_mapper(self): - return self._subclass_load_via_in(self) + # the default is loading this mapper against the basemost mapper + return self._subclass_load_via_in(self, self.base_mapper) def cascade_iterator( self, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a5f34f3de..4c07bad23 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -41,6 +41,8 @@ from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import RelationshipProperty +from .util import de_stringify_annotation +from .util import de_stringify_union_elements from .. import exc as sa_exc from .. import ForeignKey from .. import log @@ -52,8 +54,6 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..sql.type_api import TypeEngine from ..util.typing import de_optionalize_union_types -from ..util.typing import de_stringify_annotation -from ..util.typing import de_stringify_union_elements from ..util.typing import is_fwd_ref from ..util.typing import is_optional_union from ..util.typing import is_pep593 diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 3832664e5..787c5a4ab 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -76,7 +76,14 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) -class _QueryDescriptorType(Protocol): +class QueryPropertyDescriptor(Protocol): + """Describes the type applied to a class-level + :meth:`_orm.scoped_session.query_property` attribute. + + .. versionadded:: 2.0.5 + + """ + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... @@ -254,17 +261,25 @@ class scoped_session(Generic[_S]): def query_property( self, query_cls: Optional[Type[Query[_T]]] = None - ) -> _QueryDescriptorType: - """return a class property which produces a :class:`_query.Query` - object - against the class and the current :class:`.Session` when called. + ) -> QueryPropertyDescriptor: + """return a class property which produces a legacy + :class:`_query.Query` object against the class and the current + :class:`.Session` when called. + + .. legacy:: The :meth:`_orm.scoped_session.query_property` accessor + is specific to the legacy :class:`.Query` object and is not + considered to be part of :term:`2.0-style` ORM use. e.g.:: + from sqlalchemy.orm import QueryPropertyDescriptor + from sqlalchemy.orm import scoped_session + from sqlalchemy.orm import sessionmaker + Session = scoped_session(sessionmaker()) class MyClass: - query = Session.query_property() + query: QueryPropertyDescriptor = Session.query_property() # after mappers are defined result = MyClass.query.filter(MyClass.name=='foo').all() diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ad9ce2013..d3e36a494 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -9,6 +9,7 @@ from __future__ import annotations import enum +import functools import re import types import typing @@ -46,6 +47,7 @@ from .base import attribute_str as attribute_str # noqa: F401 from .base import class_mapper as class_mapper from .base import InspectionAttr as InspectionAttr from .base import instance_str as instance_str # noqa: F401 +from .base import Mapped from .base import object_mapper as object_mapper from .base import object_state as object_state # noqa: F401 from .base import opt_manager_of_class @@ -79,10 +81,14 @@ from ..sql.elements import ColumnElement from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots -from ..util.typing import de_stringify_annotation -from ..util.typing import eval_name_only +from ..util.typing import de_stringify_annotation as _de_stringify_annotation +from ..util.typing import ( + de_stringify_union_elements as _de_stringify_union_elements, +) +from ..util.typing import eval_name_only as _eval_name_only from ..util.typing import is_origin_of_cls from ..util.typing import Literal +from ..util.typing import Protocol from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: @@ -113,6 +119,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Subquery from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType + from ..util.typing import ArgsTypeProcotol _T = TypeVar("_T", bound=Any) @@ -130,6 +137,58 @@ all_cascades = frozenset( ) +_de_stringify_partial = functools.partial( + functools.partial, locals_=util.immutabledict({"Mapped": Mapped}) +) + +# partial is practically useless as we have to write out the whole +# function and maintain the signature anyway + + +class _DeStringifyAnnotation(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: _AnnotationScanType, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, + ) -> Type[Any]: + ... + + +de_stringify_annotation = cast( + _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation) +) + + +class _DeStringifyUnionElements(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: ArgsTypeProcotol, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + ) -> Type[Any]: + ... + + +de_stringify_union_elements = cast( + _DeStringifyUnionElements, + _de_stringify_partial(_de_stringify_union_elements), +) + + +class _EvalNameOnly(Protocol): + def __call__(self, name: str, module_name: str) -> Any: + ... + + +eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) + + class CascadeOptions(FrozenSet[str]): """Keeps track of the options sent to :paramref:`.relationship.cascade`""" @@ -994,7 +1053,7 @@ class AliasedInsp( @classmethod def _with_polymorphic_factory( cls, - base: Union[_O, Mapper[_O]], + base: Union[Type[_O], Mapper[_O]], classes: Union[Literal["*"], Iterable[_EntityType[Any]]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, @@ -2271,7 +2330,7 @@ def _extract_mapped_subtype( cls, raw_annotation, originating_module, - _cleanup_mapped_str_annotation, + str_cleanup_fn=_cleanup_mapped_str_annotation, ) except _CleanupError as ce: raise sa_exc.ArgumentError( diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 6bf9a5a1f..a828d6a0f 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -11,6 +11,7 @@ import operator from typing import Any from typing import Callable from typing import Dict +from typing import Mapping from typing import Set from typing import Tuple from typing import Type @@ -238,6 +239,9 @@ the DMLColumnRole to be able to accommodate. """ +_DMLKey = TypeVar("_DMLKey", bound=_DMLColumnArgument) +_DMLColumnKeyMapping = Mapping[_DMLKey, Any] + _DDLColumnArgument = Union[str, "Column[Any]", roles.DDLConstraintColumnRole] """DDL column. diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 9042fdff7..dbbf09f1b 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -72,6 +72,7 @@ if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument from ._typing import _DMLColumnArgument + from ._typing import _DMLColumnKeyMapping from ._typing import _DMLTableArgument from ._typing import _T0 # noqa from ._typing import _T1 # noqa @@ -944,7 +945,7 @@ class ValuesBase(UpdateBase): def values( self, *args: Union[ - Dict[_DMLColumnArgument, Any], + _DMLColumnKeyMapping[Any], Sequence[Any], ], **kwargs: Any, diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 9e6df0d35..24d8dd2dc 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -18,6 +18,7 @@ from typing import Dict from typing import ForwardRef from typing import Generic from typing import Iterable +from typing import Mapping from typing import NewType from typing import NoReturn from typing import Optional @@ -123,6 +124,8 @@ def de_stringify_annotation( cls: Type[Any], annotation: _AnnotationScanType, originating_module: str, + locals_: Mapping[str, Any], + *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, ) -> Type[Any]: @@ -150,7 +153,9 @@ def de_stringify_annotation( if str_cleanup_fn: annotation = str_cleanup_fn(annotation, originating_module) - annotation = eval_expression(annotation, originating_module) + annotation = eval_expression( + annotation, originating_module, locals_=locals_ + ) if ( include_generic @@ -162,6 +167,7 @@ def de_stringify_annotation( cls, elem, originating_module, + locals_, str_cleanup_fn=str_cleanup_fn, include_generic=include_generic, ) @@ -183,7 +189,12 @@ def _copy_generic_annotation_with( return annotation.__origin__[elements] # type: ignore -def eval_expression(expression: str, module_name: str) -> Any: +def eval_expression( + expression: str, + module_name: str, + *, + locals_: Optional[Mapping[str, Any]] = None, +) -> Any: try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ except KeyError as ke: @@ -191,8 +202,9 @@ def eval_expression(expression: str, module_name: str) -> Any: f"Module {module_name} isn't present in sys.modules; can't " f"evaluate expression {expression}" ) from ke + try: - annotation = eval(expression, base_globals, None) + annotation = eval(expression, base_globals, locals_) except Exception as err: raise NameError( f"Could not de-stringify annotation {expression!r}" @@ -201,9 +213,14 @@ def eval_expression(expression: str, module_name: str) -> Any: return annotation -def eval_name_only(name: str, module_name: str) -> Any: +def eval_name_only( + name: str, + module_name: str, + *, + locals_: Optional[Mapping[str, Any]] = None, +) -> Any: if "." in name: - return eval_expression(name, module_name) + return eval_expression(name, module_name, locals_=locals_) try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ @@ -237,12 +254,18 @@ def de_stringify_union_elements( cls: Type[Any], annotation: ArgsTypeProcotol, originating_module: str, + locals_: Mapping[str, Any], + *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, ) -> Type[Any]: return make_union_type( *[ de_stringify_annotation( - cls, anno, originating_module, str_cleanup_fn + cls, + anno, + originating_module, + {}, + str_cleanup_fn=str_cleanup_fn, ) for anno in annotation.__args__ ] |
