summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/cyextension/resultproxy.pyx5
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py3
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py15
-rw-r--r--lib/sqlalchemy/ext/automap.py20
-rw-r--r--lib/sqlalchemy/orm/__init__.py1
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py2
-rw-r--r--lib/sqlalchemy/orm/decl_base.py2
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py2
-rw-r--r--lib/sqlalchemy/orm/loading.py54
-rw-r--r--lib/sqlalchemy/orm/mapper.py76
-rw-r--r--lib/sqlalchemy/orm/properties.py4
-rw-r--r--lib/sqlalchemy/orm/scoping.py27
-rw-r--r--lib/sqlalchemy/orm/util.py67
-rw-r--r--lib/sqlalchemy/sql/_typing.py4
-rw-r--r--lib/sqlalchemy/sql/dml.py3
-rw-r--r--lib/sqlalchemy/util/typing.py35
16 files changed, 267 insertions, 53 deletions
diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx
index e88c8ec0b..96a028d93 100644
--- a/lib/sqlalchemy/cyextension/resultproxy.pyx
+++ b/lib/sqlalchemy/cyextension/resultproxy.pyx
@@ -3,9 +3,10 @@
import operator
cdef int MD_INDEX = 0 # integer index in cursor.description
+cdef int _KEY_OBJECTS_ONLY = 1
KEY_INTEGER_ONLY = 0
-KEY_OBJECTS_ONLY = 1
+KEY_OBJECTS_ONLY = _KEY_OBJECTS_ONLY
cdef class BaseRow:
cdef readonly object _parent
@@ -76,7 +77,7 @@ cdef class BaseRow:
if mdindex is None:
self._parent._raise_for_ambiguous_column_name(rec)
elif (
- self._key_style == KEY_OBJECTS_ONLY
+ self._key_style == _KEY_OBJECTS_ONLY
and isinstance(key, int)
):
raise KeyError(key)
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 255c72042..3ba103802 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2350,8 +2350,9 @@ class PGDDLCompiler(compiler.DDLCompiler):
constraint
)
elements = []
+ kw["include_table"] = False
+ kw["literal_binds"] = True
for expr, name, op in constraint._render_exprs:
- kw["include_table"] = False
exclude_element = self.sql_compiler.process(expr, **kw) + (
(" " + constraint.ops[expr.key])
if hasattr(expr, "key") and expr.key in constraint.ops
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index 0f5efb1de..22604955d 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -164,16 +164,15 @@ class ExcludeConstraint(ColumnCollectionConstraint):
:param \*elements:
A sequence of two tuples of the form ``(column, operator)`` where
- "column" is a SQL expression element or a raw SQL string, most
- typically a :class:`_schema.Column` object,
- and "operator" is a string
- containing the operator to use. In order to specify a column name
- when a :class:`_schema.Column` object is not available,
- while ensuring
+ "column" is a SQL expression element or the name of a column as
+ string, most typically a :class:`_schema.Column` object,
+ and "operator" is a string containing the operator to use.
+ In order to specify a column name when a :class:`_schema.Column`
+ object is not available, while ensuring
that any necessary quoting rules take effect, an ad-hoc
:class:`_schema.Column` or :func:`_expression.column`
- object should be
- used.
+ object should be used. ``column`` may also be a string SQL
+ expression when passed as :func:`_expression.literal_column`
:param name:
Optional, the in-database name of this constraint.
diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py
index 030015284..1861791b7 100644
--- a/lib/sqlalchemy/ext/automap.py
+++ b/lib/sqlalchemy/ext/automap.py
@@ -1188,6 +1188,14 @@ class AutomapBase:
.. versionadded:: 1.4
"""
+
+ for mr in cls.__mro__:
+ if "_sa_automapbase_bookkeeping" in mr.__dict__:
+ automap_base = cast("Type[AutomapBase]", mr)
+ break
+ else:
+ assert False, "Can't locate automap base in class hierarchy"
+
glbls = globals()
if classname_for_table is None:
classname_for_table = glbls["classname_for_table"]
@@ -1237,7 +1245,7 @@ class AutomapBase:
]
many_to_many = []
- bookkeeping = cls._sa_automapbase_bookkeeping
+ bookkeeping = automap_base._sa_automapbase_bookkeeping
metadata_tables = cls.metadata.tables
for table_key in set(metadata_tables).difference(
@@ -1278,7 +1286,7 @@ class AutomapBase:
mapped_cls = type(
newname,
- (cls,),
+ (automap_base,),
clsdict,
)
map_config = _DeferredMapperConfig.config_for_cls(
@@ -1309,7 +1317,7 @@ class AutomapBase:
for map_config in table_to_map_config.values():
_relationships_for_fks(
- cls,
+ automap_base,
map_config,
table_to_map_config,
collection_class,
@@ -1320,7 +1328,7 @@ class AutomapBase:
for lcl_m2m, rem_m2m, m2m_const, table in many_to_many:
_m2m_relationship(
- cls,
+ automap_base,
lcl_m2m,
rem_m2m,
m2m_const,
@@ -1332,7 +1340,9 @@ class AutomapBase:
generate_relationship,
)
- for map_config in _DeferredMapperConfig.classes_for_base(cls):
+ for map_config in _DeferredMapperConfig.classes_for_base(
+ automap_base
+ ):
map_config.map()
_sa_decl_prepare = True
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index d54e1ccb9..69cd7f598 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -120,6 +120,7 @@ from .relationships import foreign as foreign
from .relationships import Relationship as Relationship
from .relationships import RelationshipProperty as RelationshipProperty
from .relationships import remote as remote
+from .scoping import QueryPropertyDescriptor as QueryPropertyDescriptor
from .scoping import scoped_session as scoped_session
from .session import close_all_sessions as close_all_sessions
from .session import make_transient as make_transient
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 3bd1db79d..64e7937f1 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -2208,7 +2208,7 @@ def aliased(
def with_polymorphic(
- base: Union[_O, Mapper[_O]],
+ base: Union[Type[_O], Mapper[_O]],
classes: Union[Literal["*"], Iterable[Type[Any]]],
selectable: Union[Literal[False, None], FromClause] = False,
flat: bool = False,
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index 29d748596..d01aad439 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -55,6 +55,7 @@ from .properties import MappedColumn
from .util import _extract_mapped_subtype
from .util import _is_mapped_annotation
from .util import class_mapper
+from .util import de_stringify_annotation
from .. import event
from .. import exc
from .. import util
@@ -64,7 +65,6 @@ from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
from ..util.typing import _AnnotationScanType
-from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
from ..util.typing import is_literal
from ..util.typing import Protocol
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index b65171c9d..fd28830d9 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -44,6 +44,7 @@ from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
from .util import _none_set
+from .util import de_stringify_annotation
from .. import event
from .. import exc as sa_exc
from .. import schema
@@ -52,7 +53,6 @@ from .. import util
from ..sql import expression
from ..sql import operators
from ..sql.elements import BindParameter
-from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
from ..util.typing import is_pep593
from ..util.typing import typing_get_args
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 54b96c215..7974d94c5 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -972,17 +972,17 @@ def _instance_processor(
if not refresh_state and _polymorphic_from is not None:
key = ("loader", path.path)
+
if key in context.attributes and context.attributes[key].strategy == (
("selectinload_polymorphic", True),
):
- selectin_load_via = mapper._should_selectin_load(
- context.attributes[key].local_opts["entities"],
- _polymorphic_from,
- )
+ option_entities = context.attributes[key].local_opts["entities"]
else:
- selectin_load_via = mapper._should_selectin_load(
- None, _polymorphic_from
- )
+ option_entities = None
+ selectin_load_via = mapper._should_selectin_load(
+ option_entities,
+ _polymorphic_from,
+ )
if selectin_load_via and selectin_load_via is not _polymorphic_from:
# only_load_props goes w/ refresh_state only, and in a refresh
@@ -990,8 +990,13 @@ def _instance_processor(
# loading does not apply
assert only_load_props is None
- callable_ = _load_subclass_via_in(context, path, selectin_load_via)
-
+ callable_ = _load_subclass_via_in(
+ context,
+ path,
+ selectin_load_via,
+ _polymorphic_from,
+ option_entities,
+ )
PostLoad.callable_for_path(
context,
load_path,
@@ -1212,17 +1217,42 @@ def _instance_processor(
return _instance
-def _load_subclass_via_in(context, path, entity):
+def _load_subclass_via_in(
+ context, path, entity, polymorphic_from, option_entities
+):
mapper = entity.mapper
+ # TODO: polymorphic_from seems to be a Mapper in all cases.
+ # this is likely not needed, but as we dont have typing in loading.py
+ # yet, err on the safe side
+ polymorphic_from_mapper = polymorphic_from.mapper
+ not_against_basemost = polymorphic_from_mapper.inherits is not None
+
zero_idx = len(mapper.base_mapper.primary_key) == 1
- if entity.is_aliased_class:
- q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity)
+ if entity.is_aliased_class or not_against_basemost:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in(
+ entity, polymorphic_from
+ )
else:
q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper
def do_load(context, path, states, load_only, effective_entity):
+ if not option_entities:
+ # filter out states for those that would have selectinloaded
+ # from another loader
+ # TODO: we are currently ignoring the case where the
+ # "selectin_polymorphic" option is used, as this is much more
+ # complex / specific / very uncommon API use
+ states = [
+ (s, v)
+ for s, v in states
+ if s.mapper._would_selectin_load_only_from_given_mapper(mapper)
+ ]
+
+ if not states:
+ return
+
orig_query = context.query
options = (enable_opt,) + orig_query._with_options + (disable_opt,)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index c0ff2ed10..2ae6dadcd 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -3698,6 +3698,65 @@ class Mapper(
if m is mapper:
break
+ @HasMemoized.memoized_attribute
+ def _would_selectinload_combinations_cache(self):
+ return {}
+
+ def _would_selectin_load_only_from_given_mapper(self, super_mapper):
+ """return True if this mapper would "selectin" polymorphic load based
+ on the given super mapper, and not from a setting from a subclass.
+
+ given::
+
+ class A:
+ ...
+
+ class B(A):
+ __mapper_args__ = {"polymorphic_load": "selectin"}
+
+ class C(B):
+ ...
+
+ class D(B):
+ __mapper_args__ = {"polymorphic_load": "selectin"}
+
+ ``inspect(C)._would_selectin_load_only_from_given_mapper(inspect(B))``
+ returns True, because C does selectin loading because of B's setting.
+
+ OTOH, ``inspect(D)
+ ._would_selectin_load_only_from_given_mapper(inspect(B))``
+ returns False, because D does selectin loading because of its own
+ setting; when we are doing a selectin poly load from B, we want to
+ filter out D because it would already have its own selectin poly load
+ set up separately.
+
+ Added as part of #9373.
+
+ """
+ cache = self._would_selectinload_combinations_cache
+
+ try:
+ return cache[super_mapper]
+ except KeyError:
+ pass
+
+ # assert that given object is a supermapper, meaning we already
+ # strong reference it directly or indirectly. this allows us
+ # to not worry that we are creating new strongrefs to unrelated
+ # mappers or other objects.
+ assert self.isa(super_mapper)
+
+ mapper = super_mapper
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if m.polymorphic_load == "selectin":
+ retval = m is super_mapper
+ break
+ else:
+ retval = False
+
+ cache[super_mapper] = retval
+ return retval
+
def _should_selectin_load(self, enabled_via_opt, polymorphic_from):
if not enabled_via_opt:
# common case, takes place for all polymorphic loads
@@ -3721,7 +3780,7 @@ class Mapper(
return None
@util.preload_module("sqlalchemy.orm.strategy_options")
- def _subclass_load_via_in(self, entity):
+ def _subclass_load_via_in(self, entity, polymorphic_from):
"""Assemble a that can load the columns local to
this subclass as a SELECT with IN.
@@ -3739,6 +3798,16 @@ class Mapper(
disable_opt = strategy_options.Load(entity)
enable_opt = strategy_options.Load(entity)
+ classes_to_include = {self}
+ m: Optional[Mapper[Any]] = self.inherits
+ while (
+ m is not None
+ and m is not polymorphic_from
+ and m.polymorphic_load == "selectin"
+ ):
+ classes_to_include.add(m)
+ m = m.inherits
+
for prop in self.attrs:
# skip prop keys that are not instrumented on the mapped class.
@@ -3747,7 +3816,7 @@ class Mapper(
if prop.key not in self.class_manager:
continue
- if prop.parent is self or prop in keep_props:
+ if prop.parent in classes_to_include or prop in keep_props:
# "enable" options, to turn on the properties that we want to
# load by default (subject to options from the query)
if not isinstance(prop, StrategizedProperty):
@@ -3811,7 +3880,8 @@ class Mapper(
@HasMemoized.memoized_attribute
def _subclass_load_via_in_mapper(self):
- return self._subclass_load_via_in(self)
+ # the default is loading this mapper against the basemost mapper
+ return self._subclass_load_via_in(self, self.base_mapper)
def cascade_iterator(
self,
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index a5f34f3de..4c07bad23 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -41,6 +41,8 @@ from .interfaces import MapperProperty
from .interfaces import PropComparator
from .interfaces import StrategizedProperty
from .relationships import RelationshipProperty
+from .util import de_stringify_annotation
+from .util import de_stringify_union_elements
from .. import exc as sa_exc
from .. import ForeignKey
from .. import log
@@ -52,8 +54,6 @@ from ..sql.schema import Column
from ..sql.schema import SchemaConst
from ..sql.type_api import TypeEngine
from ..util.typing import de_optionalize_union_types
-from ..util.typing import de_stringify_annotation
-from ..util.typing import de_stringify_union_elements
from ..util.typing import is_fwd_ref
from ..util.typing import is_optional_union
from ..util.typing import is_pep593
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 3832664e5..787c5a4ab 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -76,7 +76,14 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
-class _QueryDescriptorType(Protocol):
+class QueryPropertyDescriptor(Protocol):
+ """Describes the type applied to a class-level
+ :meth:`_orm.scoped_session.query_property` attribute.
+
+ .. versionadded:: 2.0.5
+
+ """
+
def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]:
...
@@ -254,17 +261,25 @@ class scoped_session(Generic[_S]):
def query_property(
self, query_cls: Optional[Type[Query[_T]]] = None
- ) -> _QueryDescriptorType:
- """return a class property which produces a :class:`_query.Query`
- object
- against the class and the current :class:`.Session` when called.
+ ) -> QueryPropertyDescriptor:
+ """return a class property which produces a legacy
+ :class:`_query.Query` object against the class and the current
+ :class:`.Session` when called.
+
+ .. legacy:: The :meth:`_orm.scoped_session.query_property` accessor
+ is specific to the legacy :class:`.Query` object and is not
+ considered to be part of :term:`2.0-style` ORM use.
e.g.::
+ from sqlalchemy.orm import QueryPropertyDescriptor
+ from sqlalchemy.orm import scoped_session
+ from sqlalchemy.orm import sessionmaker
+
Session = scoped_session(sessionmaker())
class MyClass:
- query = Session.query_property()
+ query: QueryPropertyDescriptor = Session.query_property()
# after mappers are defined
result = MyClass.query.filter(MyClass.name=='foo').all()
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index ad9ce2013..d3e36a494 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -9,6 +9,7 @@
from __future__ import annotations
import enum
+import functools
import re
import types
import typing
@@ -46,6 +47,7 @@ from .base import attribute_str as attribute_str # noqa: F401
from .base import class_mapper as class_mapper
from .base import InspectionAttr as InspectionAttr
from .base import instance_str as instance_str # noqa: F401
+from .base import Mapped
from .base import object_mapper as object_mapper
from .base import object_state as object_state # noqa: F401
from .base import opt_manager_of_class
@@ -79,10 +81,14 @@ from ..sql.elements import ColumnElement
from ..sql.elements import KeyedColumnElement
from ..sql.selectable import FromClause
from ..util.langhelpers import MemoizedSlots
-from ..util.typing import de_stringify_annotation
-from ..util.typing import eval_name_only
+from ..util.typing import de_stringify_annotation as _de_stringify_annotation
+from ..util.typing import (
+ de_stringify_union_elements as _de_stringify_union_elements,
+)
+from ..util.typing import eval_name_only as _eval_name_only
from ..util.typing import is_origin_of_cls
from ..util.typing import Literal
+from ..util.typing import Protocol
from ..util.typing import typing_get_origin
if typing.TYPE_CHECKING:
@@ -113,6 +119,7 @@ if typing.TYPE_CHECKING:
from ..sql.selectable import Subquery
from ..sql.visitors import anon_map
from ..util.typing import _AnnotationScanType
+ from ..util.typing import ArgsTypeProcotol
_T = TypeVar("_T", bound=Any)
@@ -130,6 +137,58 @@ all_cascades = frozenset(
)
+_de_stringify_partial = functools.partial(
+ functools.partial, locals_=util.immutabledict({"Mapped": Mapped})
+)
+
+# partial is practically useless as we have to write out the whole
+# function and maintain the signature anyway
+
+
+class _DeStringifyAnnotation(Protocol):
+ def __call__(
+ self,
+ cls: Type[Any],
+ annotation: _AnnotationScanType,
+ originating_module: str,
+ *,
+ str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+ include_generic: bool = False,
+ ) -> Type[Any]:
+ ...
+
+
+de_stringify_annotation = cast(
+ _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation)
+)
+
+
+class _DeStringifyUnionElements(Protocol):
+ def __call__(
+ self,
+ cls: Type[Any],
+ annotation: ArgsTypeProcotol,
+ originating_module: str,
+ *,
+ str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+ ) -> Type[Any]:
+ ...
+
+
+de_stringify_union_elements = cast(
+ _DeStringifyUnionElements,
+ _de_stringify_partial(_de_stringify_union_elements),
+)
+
+
+class _EvalNameOnly(Protocol):
+ def __call__(self, name: str, module_name: str) -> Any:
+ ...
+
+
+eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only))
+
+
class CascadeOptions(FrozenSet[str]):
"""Keeps track of the options sent to
:paramref:`.relationship.cascade`"""
@@ -994,7 +1053,7 @@ class AliasedInsp(
@classmethod
def _with_polymorphic_factory(
cls,
- base: Union[_O, Mapper[_O]],
+ base: Union[Type[_O], Mapper[_O]],
classes: Union[Literal["*"], Iterable[_EntityType[Any]]],
selectable: Union[Literal[False, None], FromClause] = False,
flat: bool = False,
@@ -2271,7 +2330,7 @@ def _extract_mapped_subtype(
cls,
raw_annotation,
originating_module,
- _cleanup_mapped_str_annotation,
+ str_cleanup_fn=_cleanup_mapped_str_annotation,
)
except _CleanupError as ce:
raise sa_exc.ArgumentError(
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 6bf9a5a1f..a828d6a0f 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -11,6 +11,7 @@ import operator
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Mapping
from typing import Set
from typing import Tuple
from typing import Type
@@ -238,6 +239,9 @@ the DMLColumnRole to be able to accommodate.
"""
+_DMLKey = TypeVar("_DMLKey", bound=_DMLColumnArgument)
+_DMLColumnKeyMapping = Mapping[_DMLKey, Any]
+
_DDLColumnArgument = Union[str, "Column[Any]", roles.DDLConstraintColumnRole]
"""DDL column.
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 9042fdff7..dbbf09f1b 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnsClauseArgument
from ._typing import _DMLColumnArgument
+ from ._typing import _DMLColumnKeyMapping
from ._typing import _DMLTableArgument
from ._typing import _T0 # noqa
from ._typing import _T1 # noqa
@@ -944,7 +945,7 @@ class ValuesBase(UpdateBase):
def values(
self,
*args: Union[
- Dict[_DMLColumnArgument, Any],
+ _DMLColumnKeyMapping[Any],
Sequence[Any],
],
**kwargs: Any,
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 9e6df0d35..24d8dd2dc 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -18,6 +18,7 @@ from typing import Dict
from typing import ForwardRef
from typing import Generic
from typing import Iterable
+from typing import Mapping
from typing import NewType
from typing import NoReturn
from typing import Optional
@@ -123,6 +124,8 @@ def de_stringify_annotation(
cls: Type[Any],
annotation: _AnnotationScanType,
originating_module: str,
+ locals_: Mapping[str, Any],
+ *,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
include_generic: bool = False,
) -> Type[Any]:
@@ -150,7 +153,9 @@ def de_stringify_annotation(
if str_cleanup_fn:
annotation = str_cleanup_fn(annotation, originating_module)
- annotation = eval_expression(annotation, originating_module)
+ annotation = eval_expression(
+ annotation, originating_module, locals_=locals_
+ )
if (
include_generic
@@ -162,6 +167,7 @@ def de_stringify_annotation(
cls,
elem,
originating_module,
+ locals_,
str_cleanup_fn=str_cleanup_fn,
include_generic=include_generic,
)
@@ -183,7 +189,12 @@ def _copy_generic_annotation_with(
return annotation.__origin__[elements] # type: ignore
-def eval_expression(expression: str, module_name: str) -> Any:
+def eval_expression(
+ expression: str,
+ module_name: str,
+ *,
+ locals_: Optional[Mapping[str, Any]] = None,
+) -> Any:
try:
base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
except KeyError as ke:
@@ -191,8 +202,9 @@ def eval_expression(expression: str, module_name: str) -> Any:
f"Module {module_name} isn't present in sys.modules; can't "
f"evaluate expression {expression}"
) from ke
+
try:
- annotation = eval(expression, base_globals, None)
+ annotation = eval(expression, base_globals, locals_)
except Exception as err:
raise NameError(
f"Could not de-stringify annotation {expression!r}"
@@ -201,9 +213,14 @@ def eval_expression(expression: str, module_name: str) -> Any:
return annotation
-def eval_name_only(name: str, module_name: str) -> Any:
+def eval_name_only(
+ name: str,
+ module_name: str,
+ *,
+ locals_: Optional[Mapping[str, Any]] = None,
+) -> Any:
if "." in name:
- return eval_expression(name, module_name)
+ return eval_expression(name, module_name, locals_=locals_)
try:
base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
@@ -237,12 +254,18 @@ def de_stringify_union_elements(
cls: Type[Any],
annotation: ArgsTypeProcotol,
originating_module: str,
+ locals_: Mapping[str, Any],
+ *,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
) -> Type[Any]:
return make_union_type(
*[
de_stringify_annotation(
- cls, anno, originating_module, str_cleanup_fn
+ cls,
+ anno,
+ originating_module,
+ {},
+ str_cleanup_fn=str_cleanup_fn,
)
for anno in annotation.__args__
]