summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-07-03 16:25:15 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-07-03 22:33:48 -0400
commit148711cb8515a19b6177dc07655cc6e652de0553 (patch)
treeb75505c907d25395d77f45b94919b9a17e9432cf /lib
parent4b3f204d07d53ae09b59ce8f33b534f26a605cd4 (diff)
downloadsqlalchemy-148711cb8515a19b6177dc07655cc6e652de0553.tar.gz
runtime annotation fixes for relationship
* derive uselist=False when fwd ref passed to relationship This case needs to work whether or not the class name is a forward ref. we dont allow the colleciton to be a forward ref so this will work. * fix issues with MappedCollection When using string annotations or __future__.annotations, we need to do more parsing in order to get the target collection properly Change-Id: I9e5a1358b62d060a8815826f98190801a9cc0b68
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/__init__.py4
-rw-r--r--lib/sqlalchemy/orm/clsregistry.py3
-rw-r--r--lib/sqlalchemy/orm/relationships.py9
-rw-r--r--lib/sqlalchemy/orm/util.py30
-rw-r--r--lib/sqlalchemy/util/typing.py6
5 files changed, 42 insertions, 10 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 4f19ba946..539cf2600 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -87,6 +87,10 @@ from .interfaces import PropComparator as PropComparator
from .interfaces import UserDefinedOption as UserDefinedOption
from .loading import merge_frozen_result as merge_frozen_result
from .loading import merge_result as merge_result
+from .mapped_collection import attribute_mapped_collection
+from .mapped_collection import column_mapped_collection
+from .mapped_collection import mapped_collection
+from .mapped_collection import MappedCollection
from .mapper import configure_mappers as configure_mappers
from .mapper import Mapper as Mapper
from .mapper import reconstructor as reconstructor
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py
index b3fcd29ea..dd79eb1d0 100644
--- a/lib/sqlalchemy/orm/clsregistry.py
+++ b/lib/sqlalchemy/orm/clsregistry.py
@@ -463,6 +463,7 @@ class _class_resolver:
generic_match = re.match(r"(.+)\[(.+)\]", name)
if generic_match:
+ clsarg = generic_match.group(2).strip("'")
raise exc.InvalidRequestError(
f"When initializing mapper {self.prop.parent}, "
f'expression "relationship({self.arg!r})" seems to be '
@@ -470,7 +471,7 @@ class _class_resolver:
"please state the generic argument "
"using an annotation, e.g. "
f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
- f'[{generic_match.group(2)}]] = relationship()"'
+ f"['{clsarg}']] = relationship()\""
) from err
else:
raise exc.InvalidRequestError(
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 630f6898f..77a95a195 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -1724,11 +1724,12 @@ class Relationship(
self.collection_class = collection_class
else:
self.uselist = False
+
if argument.__args__: # type: ignore
if issubclass(
argument.__origin__, typing.Mapping # type: ignore
):
- type_arg = argument.__args__[1] # type: ignore
+ type_arg = argument.__args__[-1] # type: ignore
else:
type_arg = argument.__args__[0] # type: ignore
if hasattr(type_arg, "__forward_arg__"):
@@ -1743,6 +1744,12 @@ class Relationship(
elif hasattr(argument, "__forward_arg__"):
argument = argument.__forward_arg__ # type: ignore
+ # we don't allow the collection class to be a
+ # __forward_arg__ right now, so if we see a forward arg here,
+ # we know there was no collection class either
+ if self.collection_class is None:
+ self.uselist = False
+
self.argument = argument
@util.preload_module("sqlalchemy.orm.mapper")
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 317abe2b4..02080a27f 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1958,8 +1958,12 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
def _is_mapped_annotation(
raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
- annotated = de_stringify_annotation(cls, raw_annotation)
- return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
+ try:
+ annotated = de_stringify_annotation(cls, raw_annotation)
+ except NameError:
+ return False
+ else:
+ return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
def _cleanup_mapped_str_annotation(annotation: str) -> str:
@@ -1984,7 +1988,10 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
# stack: ['Mapped', 'List', 'Address']
if not re.match(r"""^["'].*["']$""", stack[-1]):
- stack[-1] = f'"{stack[-1]}"'
+ stripchars = "\"' "
+ stack[-1] = ", ".join(
+ f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",")
+ )
# stack: ['Mapped', 'List', '"Address"']
annotation = "[".join(stack) + ("]" * (len(stack) - 1))
@@ -2007,6 +2014,7 @@ def _extract_mapped_subtype(
Includes error raise scenarios and other options.
"""
+
if raw_annotation is None:
if required:
@@ -2017,9 +2025,19 @@ def _extract_mapped_subtype(
)
return None
- annotated = de_stringify_annotation(
- cls, raw_annotation, _cleanup_mapped_str_annotation
- )
+ try:
+ annotated = de_stringify_annotation(
+ cls, raw_annotation, _cleanup_mapped_str_annotation
+ )
+ except NameError as ne:
+ if raiseerr and "Mapped[" in raw_annotation: # type: ignore
+ raise sa_exc.ArgumentError(
+ f"Could not interpret annotation {raw_annotation}. "
+ "Check that it's not using names that might not be imported "
+ "at the module level. See chained stack trace for more hints."
+ ) from ne
+
+ annotated = raw_annotation # type: ignore
if is_dataclass_field:
return annotated
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 653301f1f..45fe63765 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -113,8 +113,10 @@ def de_stringify_annotation(
try:
annotation = eval(annotation, base_globals, None)
- except NameError:
- pass
+ except NameError as err:
+ raise NameError(
+ f"Could not de-stringify annotation {annotation}"
+ ) from err
return annotation # type: ignore