diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/apply.py')
| -rw-r--r-- | lib/sqlalchemy/ext/mypy/apply.py | 115 |
1 files changed, 70 insertions, 45 deletions
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 293ef2f9a..cf5b4fda2 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -5,10 +5,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import List from typing import Optional from typing import Union -from mypy import nodes from mypy.nodes import ARG_NAMED_OPT from mypy.nodes import Argument from mypy.nodes import AssignmentStmt @@ -17,6 +17,7 @@ from mypy.nodes import ClassDef from mypy.nodes import MDEF from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import RefExpr from mypy.nodes import StrExpr from mypy.nodes import SymbolTableNode from mypy.nodes import TempNode @@ -37,18 +38,18 @@ from . import infer from . import util -def _apply_mypy_mapped_attr( +def apply_mypy_mapped_attr( cls: ClassDef, api: SemanticAnalyzerPluginInterface, item: Union[NameExpr, StrExpr], - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: if isinstance(item, NameExpr): name = item.name elif isinstance(item, StrExpr): name = item.value else: - return + return None for stmt in cls.defs.body: if ( @@ -59,7 +60,7 @@ def _apply_mypy_mapped_attr( break else: util.fail(api, "Can't find mapped attribute {}".format(name), cls) - return + return None if stmt.type is None: util.fail( @@ -68,32 +69,38 @@ def _apply_mypy_mapped_attr( "typing information", stmt, ) - return + return None left_hand_explicit_type = get_proper_type(stmt.type) assert isinstance( left_hand_explicit_type, (Instance, UnionType, UnboundType) ) - cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type)) + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=item.line, + column=item.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) - _apply_type_to_mapped_statement( + apply_type_to_mapped_statement( api, stmt, stmt.lvalues[0], left_hand_explicit_type, None ) -def _re_apply_declarative_assignments( +def re_apply_declarative_assignments( cls: ClassDef, api: SemanticAnalyzerPluginInterface, - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: """For multiple class passes, re-apply our left-hand side types as mypy seems to reset them in place. """ - mapped_attr_lookup = { - name: typ for name, typ in cls_metadata.mapped_attr_names - } + mapped_attr_lookup = {attr.name: attr for attr in attributes} update_cls_metadata = False for stmt in cls.defs.body: @@ -109,28 +116,37 @@ def _re_apply_declarative_assignments( ): left_node = stmt.lvalues[0].node - python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name] + python_type_for_type = mapped_attr_lookup[ + stmt.lvalues[0].name + ].type + + left_node_proper_type = get_proper_type(left_node.type) + # if we have scanned an UnboundType and now there's a more # specific type than UnboundType, call the re-scan so we # can get that set up correctly if ( isinstance(python_type_for_type, UnboundType) - and not isinstance(left_node.type, UnboundType) + and not isinstance(left_node_proper_type, UnboundType) and ( - isinstance(stmt.rvalue.callee, MemberExpr) + isinstance(stmt.rvalue, CallExpr) + and isinstance(stmt.rvalue.callee, MemberExpr) + and isinstance(stmt.rvalue.callee.expr, NameExpr) + and stmt.rvalue.callee.expr.node is not None and stmt.rvalue.callee.expr.node.fullname == "sqlalchemy.orm.attributes.Mapped" and stmt.rvalue.callee.name == "_empty_constructor" and isinstance(stmt.rvalue.args[0], CallExpr) + and isinstance(stmt.rvalue.args[0].callee, RefExpr) ) ): python_type_for_type = ( - infer._infer_type_from_right_hand_nameexpr( + infer.infer_type_from_right_hand_nameexpr( api, stmt, left_node, - left_node.type, + left_node_proper_type, stmt.rvalue.args[0].callee, ) ) @@ -140,21 +156,23 @@ def _re_apply_declarative_assignments( ): continue - # update the DeclClassApplied with the better information - mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type + # update the SQLAlchemyAttribute with the better information + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type + update_cls_metadata = True - left_node.type = api.named_type( - "__sa_Mapped", [python_type_for_type] - ) + if python_type_for_type is not None: + left_node.type = api.named_type( + "__sa_Mapped", [python_type_for_type] + ) if update_cls_metadata: - cls_metadata.mapped_attr_names[:] = [ - (k, v) for k, v in mapped_attr_lookup.items() - ] + util.set_mapped_attributes(cls.info, attributes) -def _apply_type_to_mapped_statement( +def apply_type_to_mapped_statement( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, lvalue: NameExpr, @@ -205,30 +223,36 @@ def _apply_type_to_mapped_statement( # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>) # the original right-hand side is maintained so it gets type checked # internally - column_descriptor = nodes.NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" - mm = nodes.MemberExpr(column_descriptor, "_empty_constructor") - orig_call_expr = stmt.rvalue - stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"]) + stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue) -def _add_additional_orm_attributes( +def add_additional_orm_attributes( cls: ClassDef, api: SemanticAnalyzerPluginInterface, - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: """Apply __init__, __table__ and other attributes to the mapped class.""" - info = util._info_for_cls(cls, api) - if "__init__" not in info.names and cls_metadata.is_mapped: - mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names} + info = util.info_for_cls(cls, api) - for mapped_base in cls_metadata.mapped_mro: - base_cls_metadata = util.DeclClassApplied.deserialize( - mapped_base.type.metadata["_sa_decl_class_applied"], api - ) - for n, t in base_cls_metadata.mapped_attr_names: - mapped_attr_names.setdefault(n, t) + if info is None: + return + + is_base = util.get_is_base(info) + + if "__init__" not in info.names and not is_base: + mapped_attr_names = {attr.name: attr.type for attr in attributes} + + for base in info.mro[1:-1]: + if "sqlalchemy" not in info.metadata: + continue + + base_cls_attributes = util.get_mapped_attributes(base, api) + if base_cls_attributes is None: + continue + + for attr in base_cls_attributes: + mapped_attr_names.setdefault(attr.name, attr.type) arguments = [] for name, typ in mapped_attr_names.items(): @@ -242,13 +266,14 @@ def _add_additional_orm_attributes( kind=ARG_NAMED_OPT, ) ) + add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) - if "__table__" not in info.names and cls_metadata.has_table: + if "__table__" not in info.names and util.get_has_table(info): _apply_placeholder_attr_to_class( api, cls, "sqlalchemy.sql.schema.Table", "__table__" ) - if cls_metadata.is_mapped: + if not is_base: _apply_placeholder_attr_to_class( api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" ) |
