diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-07-03 16:25:15 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-07-03 22:33:48 -0400 |
| commit | 148711cb8515a19b6177dc07655cc6e652de0553 (patch) | |
| tree | b75505c907d25395d77f45b94919b9a17e9432cf /lib | |
| parent | 4b3f204d07d53ae09b59ce8f33b534f26a605cd4 (diff) | |
| download | sqlalchemy-148711cb8515a19b6177dc07655cc6e652de0553.tar.gz | |
runtime annotation fixes for relationship
* derive uselist=False when fwd ref passed to relationship
This case needs to work whether or not the class name
is a forward ref. we dont allow the colleciton to be a
forward ref so this will work.
* fix issues with MappedCollection
When using string annotations or __future__.annotations,
we need to do more parsing in order to get the target
collection properly
Change-Id: I9e5a1358b62d060a8815826f98190801a9cc0b68
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/clsregistry.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 6 |
5 files changed, 42 insertions, 10 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 4f19ba946..539cf2600 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -87,6 +87,10 @@ from .interfaces import PropComparator as PropComparator from .interfaces import UserDefinedOption as UserDefinedOption from .loading import merge_frozen_result as merge_frozen_result from .loading import merge_result as merge_result +from .mapped_collection import attribute_mapped_collection +from .mapped_collection import column_mapped_collection +from .mapped_collection import mapped_collection +from .mapped_collection import MappedCollection from .mapper import configure_mappers as configure_mappers from .mapper import Mapper as Mapper from .mapper import reconstructor as reconstructor diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index b3fcd29ea..dd79eb1d0 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -463,6 +463,7 @@ class _class_resolver: generic_match = re.match(r"(.+)\[(.+)\]", name) if generic_match: + clsarg = generic_match.group(2).strip("'") raise exc.InvalidRequestError( f"When initializing mapper {self.prop.parent}, " f'expression "relationship({self.arg!r})" seems to be ' @@ -470,7 +471,7 @@ class _class_resolver: "please state the generic argument " "using an annotation, e.g. " f'"{self.prop.key}: Mapped[{generic_match.group(1)}' - f'[{generic_match.group(2)}]] = relationship()"' + f"['{clsarg}']] = relationship()\"" ) from err else: raise exc.InvalidRequestError( diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 630f6898f..77a95a195 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1724,11 +1724,12 @@ class Relationship( self.collection_class = collection_class else: self.uselist = False + if argument.__args__: # type: ignore if issubclass( argument.__origin__, typing.Mapping # type: ignore ): - type_arg = argument.__args__[1] # type: ignore + type_arg = argument.__args__[-1] # type: ignore else: type_arg = argument.__args__[0] # type: ignore if hasattr(type_arg, "__forward_arg__"): @@ -1743,6 +1744,12 @@ class Relationship( elif hasattr(argument, "__forward_arg__"): argument = argument.__forward_arg__ # type: ignore + # we don't allow the collection class to be a + # __forward_arg__ right now, so if we see a forward arg here, + # we know there was no collection class either + if self.collection_class is None: + self.uselist = False + self.argument = argument @util.preload_module("sqlalchemy.orm.mapper") diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 317abe2b4..02080a27f 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1958,8 +1958,12 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any: def _is_mapped_annotation( raw_annotation: _AnnotationScanType, cls: Type[Any] ) -> bool: - annotated = de_stringify_annotation(cls, raw_annotation) - return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") + try: + annotated = de_stringify_annotation(cls, raw_annotation) + except NameError: + return False + else: + return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") def _cleanup_mapped_str_annotation(annotation: str) -> str: @@ -1984,7 +1988,10 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str: # stack: ['Mapped', 'List', 'Address'] if not re.match(r"""^["'].*["']$""", stack[-1]): - stack[-1] = f'"{stack[-1]}"' + stripchars = "\"' " + stack[-1] = ", ".join( + f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",") + ) # stack: ['Mapped', 'List', '"Address"'] annotation = "[".join(stack) + ("]" * (len(stack) - 1)) @@ -2007,6 +2014,7 @@ def _extract_mapped_subtype( Includes error raise scenarios and other options. """ + if raw_annotation is None: if required: @@ -2017,9 +2025,19 @@ def _extract_mapped_subtype( ) return None - annotated = de_stringify_annotation( - cls, raw_annotation, _cleanup_mapped_str_annotation - ) + try: + annotated = de_stringify_annotation( + cls, raw_annotation, _cleanup_mapped_str_annotation + ) + except NameError as ne: + if raiseerr and "Mapped[" in raw_annotation: # type: ignore + raise sa_exc.ArgumentError( + f"Could not interpret annotation {raw_annotation}. " + "Check that it's not using names that might not be imported " + "at the module level. See chained stack trace for more hints." + ) from ne + + annotated = raw_annotation # type: ignore if is_dataclass_field: return annotated diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 653301f1f..45fe63765 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -113,8 +113,10 @@ def de_stringify_annotation( try: annotation = eval(annotation, base_globals, None) - except NameError: - pass + except NameError as err: + raise NameError( + f"Could not de-stringify annotation {annotation}" + ) from err return annotation # type: ignore |
