diff options
-rw-r--r-- | doc/build/changelog/unreleased_14/7321.rst | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/mypy/decl_class.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/ext/mypy/util.py | 5 | ||||
-rw-r--r-- | test/ext/mypy/files/issue_7321.py | 21 | ||||
-rw-r--r-- | test/ext/mypy/files/issue_7321_part2.py | 27 |
5 files changed, 83 insertions, 1 deletions
diff --git a/doc/build/changelog/unreleased_14/7321.rst b/doc/build/changelog/unreleased_14/7321.rst new file mode 100644 index 000000000..08cca4344 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7321.rst @@ -0,0 +1,16 @@ +.. change:: + :tags: bug, mypy + :tickets: 7321 + + Fixed Mypy crash which would occur when using Mypy plugin against code + which made use of :class:`_orm.declared_attr` methods for non-mapped names + like ``__mapper_args__``, ``__table_args__``, or other dunder names, as the + plugin would try to interpret these as mapped attributes which would then + be later mis-handled. As part of this change, the decorated function is + still converted by the plugin into a generic assignment statement (e.g. + ``__mapper_args__: Any``) so that the argument signature can continue to be + annotated in the same way one would for any other ``@classmethod`` without + Mypy complaining about the wrong argument type for a method that isn't + explicitly ``@classmethod``. + + diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index b85ec0f69..0d7462d5b 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -241,7 +241,20 @@ def _scan_declarative_decorator_stmt( left_hand_explicit_type: Optional[ProperType] = None - if isinstance(stmt.func.type, CallableType): + if util.name_is_dunder(stmt.name): + # for dunder names like __table_args__, __tablename__, + # __mapper_args__ etc., rewrite these as simple assignment + # statements; otherwise mypy doesn't like if the decorated + # function has an annotation like ``cls: Type[Foo]`` because + # it isn't @classmethod + any_ = AnyType(TypeOfAny.special_form) + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + new_stmt = AssignmentStmt([left_node], TempNode(any_)) + new_stmt.type = left_node.node.type + cls.defs.body[dec_index] = new_stmt + return + elif isinstance(stmt.func.type, CallableType): func_type = stmt.func.type.ret_type if isinstance(func_type, UnboundType): type_id = names.type_id_for_unbound_type(func_type, cls, api) diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index a3825f175..4d55cb728 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,3 +1,4 @@ +import re from typing import Any from typing import Iterable from typing import Iterator @@ -82,6 +83,10 @@ class SQLAlchemyAttribute: return cls(typ=typ, info=info, **data) +def name_is_dunder(name): + return bool(re.match(r"^__.+?__$", name)) + + def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None: info.metadata.setdefault("sqlalchemy", {})[key] = data diff --git a/test/ext/mypy/files/issue_7321.py b/test/ext/mypy/files/issue_7321.py new file mode 100644 index 000000000..6a40b9dda --- /dev/null +++ b/test/ext/mypy/files/issue_7321.py @@ -0,0 +1,21 @@ +from typing import Any + +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import declared_attr + + +Base = declarative_base() + + +class Foo(Base): + @declared_attr + def __tablename__(cls) -> str: + return "name" + + @declared_attr + def __mapper_args__(cls) -> dict[Any, Any]: + return {} + + @declared_attr + def __table_args__(cls) -> dict[Any, Any]: + return {} diff --git a/test/ext/mypy/files/issue_7321_part2.py b/test/ext/mypy/files/issue_7321_part2.py new file mode 100644 index 000000000..f53add1da --- /dev/null +++ b/test/ext/mypy/files/issue_7321_part2.py @@ -0,0 +1,27 @@ +from typing import Any +from typing import Type + +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import declared_attr + + +Base = declarative_base() + + +class Foo(Base): + # no mypy error emitted regarding the + # Type[Foo] part + @declared_attr + def __tablename__(cls: Type["Foo"]) -> str: + return "name" + + @declared_attr + def __mapper_args__(cls: Type["Foo"]) -> dict[Any, Any]: + return {} + + # this was a workaround that works if there's no plugin present, make + # sure that doesn't crash anything + @classmethod + @declared_attr + def __table_args__(cls: Type["Foo"]) -> dict[Any, Any]: + return {} |