diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/decl_class.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/decl_class.py | 122 |
1 files changed, 68 insertions, 54 deletions
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index 40f1f0c0f..8fac36342 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -6,7 +6,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from typing import Optional -from typing import Type +from typing import Union from mypy import nodes from mypy.nodes import AssignmentStmt @@ -14,18 +14,24 @@ from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Decorator from mypy.nodes import ListExpr +from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import PlaceholderNode from mypy.nodes import RefExpr from mypy.nodes import StrExpr +from mypy.nodes import SymbolNode from mypy.nodes import SymbolTableNode from mypy.nodes import TempNode from mypy.nodes import TypeInfo from mypy.nodes import Var from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type from mypy.types import TypeOfAny from mypy.types import UnboundType from mypy.types import UnionType @@ -37,7 +43,9 @@ from . import util def _scan_declarative_assignments_and_apply_types( - cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + is_mixin_scan: bool = False, ) -> Optional[util.DeclClassApplied]: info = util._info_for_cls(cls, api) @@ -94,16 +102,17 @@ def _scan_symbol_table_entry( name: str, value: SymbolTableNode, cls_metadata: util.DeclClassApplied, -): +) -> None: """Extract mapping information from a SymbolTableNode that's in the type.names dictionary. """ - if not isinstance(value.type, Instance): + value_type = get_proper_type(value.type) + if not isinstance(value_type, Instance): return left_hand_explicit_type = None - type_id = names._type_id_for_named_node(value.type.type) + type_id = names._type_id_for_named_node(value_type.type) # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) err = False @@ -118,22 +127,24 @@ def _scan_symbol_table_entry( names.SYNONYM_PROPERTY, names.COLUMN_PROPERTY, }: - if value.type.args: - left_hand_explicit_type = value.type.args[0] + if value_type.args: + left_hand_explicit_type = get_proper_type(value_type.args[0]) else: err = True elif type_id is names.COLUMN: - if not value.type.args: + if not value_type.args: err = True else: - typeengine_arg = value.type.args[0] + typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type( + value_type.args[0] + ) if isinstance(typeengine_arg, Instance): typeengine_arg = typeengine_arg.type if isinstance(typeengine_arg, (UnboundType, TypeInfo)): sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) - if sym is not None: - if names._mro_has_id(sym.node.mro, names.TYPEENGINE): + if sym is not None and isinstance(sym.node, TypeInfo): + if names._has_base_type_id(sym.node, names.TYPEENGINE): left_hand_explicit_type = UnionType( [ @@ -148,7 +159,7 @@ def _scan_symbol_table_entry( api, "Column type should be a TypeEngine " "subclass not '{}'".format(sym.node.fullname), - value.type, + value_type, ) if err: @@ -158,7 +169,7 @@ def _scan_symbol_table_entry( "one of: Mapped[<python type>], relationship[<target class>], " "Column[<TypeEngine>], MapperProperty[<python type>]" ) - util.fail(api, msg.format(name, cls.name)) + util.fail(api, msg.format(name, cls.name), cls) left_hand_explicit_type = AnyType(TypeOfAny.special_form) @@ -171,7 +182,7 @@ def _scan_declarative_decorator_stmt( api: SemanticAnalyzerPluginInterface, stmt: Decorator, cls_metadata: util.DeclClassApplied, -): +) -> None: """Extract mapping information from a @declared_attr in a declarative class. @@ -195,16 +206,19 @@ def _scan_declarative_decorator_stmt( """ for dec in stmt.decorators: - if names._type_id_for_named_node(dec) is names.DECLARED_ATTR: + if ( + isinstance(dec, (NameExpr, MemberExpr, SymbolNode)) + and names._type_id_for_named_node(dec) is names.DECLARED_ATTR + ): break else: return dec_index = cls.defs.body.index(stmt) - left_hand_explicit_type = None + left_hand_explicit_type: Optional[ProperType] = None - if stmt.func.type is not None: + if isinstance(stmt.func.type, CallableType): func_type = stmt.func.type.ret_type if isinstance(func_type, UnboundType): type_id = names._type_id_for_unbound_type(func_type, cls, api) @@ -225,30 +239,28 @@ def _scan_declarative_decorator_stmt( } and func_type.args ): - left_hand_explicit_type = func_type.args[0] + left_hand_explicit_type = get_proper_type(func_type.args[0]) elif type_id is names.COLUMN and func_type.args: typeengine_arg = func_type.args[0] if isinstance(typeengine_arg, UnboundType): sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) - if sym is not None and names._mro_has_id( - sym.node.mro, names.TYPEENGINE - ): - - left_hand_explicit_type = UnionType( - [ - infer._extract_python_type_from_typeengine( - api, sym.node, [] - ), - NoneType(), - ] - ) - else: - util.fail( - api, - "Column type should be a TypeEngine " - "subclass not '{}'".format(sym.node.fullname), - func_type, - ) + if sym is not None and isinstance(sym.node, TypeInfo): + if names._has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer._extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + func_type, + ) if left_hand_explicit_type is None: # no type on the decorated function. our option here is to @@ -274,8 +286,8 @@ def _scan_declarative_decorator_stmt( # of converting it to the regular Instance/TypeInfo/UnionType structures # we see everywhere else. if isinstance(left_hand_explicit_type, UnboundType): - left_hand_explicit_type = util._unbound_to_instance( - api, left_hand_explicit_type + left_hand_explicit_type = get_proper_type( + util._unbound_to_instance(api, left_hand_explicit_type) ) left_node.node.type = api.named_type( @@ -315,7 +327,7 @@ def _scan_declarative_assignment_stmt( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, cls_metadata: util.DeclClassApplied, -): +) -> None: """Extract mapping information from an assignment statement in a declarative class. @@ -339,7 +351,7 @@ def _scan_declarative_assignment_stmt( assert isinstance(node, Var) if node.name == "__abstract__": - if stmt.rvalue.fullname == "builtins.True": + if api.parse_bool(stmt.rvalue) is True: cls_metadata.is_mapped = False return elif node.name == "__tablename__": @@ -354,7 +366,8 @@ def _scan_declarative_assignment_stmt( if isinstance(item, (NameExpr, StrExpr)): apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata) - left_hand_mapped_type: Type = None + left_hand_mapped_type: Optional[Type] = None + left_hand_explicit_type: Optional[ProperType] = None if node.is_inferred or node.type is None: if isinstance(stmt.type, UnboundType): @@ -370,32 +383,33 @@ def _scan_declarative_assignment_stmt( mapped_sym = api.lookup_qualified("Mapped", cls) if ( mapped_sym is not None + and mapped_sym.node is not None and names._type_id_for_named_node(mapped_sym.node) is names.MAPPED ): - left_hand_explicit_type = stmt.type.args[0] + left_hand_explicit_type = get_proper_type( + stmt.type.args[0] + ) left_hand_mapped_type = stmt.type # TODO: do we need to convert from unbound for this case? # left_hand_explicit_type = util._unbound_to_instance( # api, left_hand_explicit_type # ) - - else: - left_hand_explicit_type = None else: + node_type = get_proper_type(node.type) if ( - isinstance(node.type, Instance) - and names._type_id_for_named_node(node.type.type) is names.MAPPED + isinstance(node_type, Instance) + and names._type_id_for_named_node(node_type.type) is names.MAPPED ): # print(node.type) # sqlalchemy.orm.attributes.Mapped[<python type>] - left_hand_explicit_type = node.type.args[0] - left_hand_mapped_type = node.type + left_hand_explicit_type = get_proper_type(node_type.args[0]) + left_hand_mapped_type = node_type else: # print(node.type) # <python type> - left_hand_explicit_type = node.type + left_hand_explicit_type = node_type left_hand_mapped_type = None if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None: @@ -440,10 +454,10 @@ def _scan_declarative_assignment_stmt( else: return - cls_metadata.mapped_attr_names.append((node.name, python_type_for_type)) - assert python_type_for_type is not None + cls_metadata.mapped_attr_names.append((node.name, python_type_for_type)) + apply._apply_type_to_mapped_statement( api, stmt, @@ -485,6 +499,6 @@ def _scan_for_mapped_bases( ) ) - if base_decl_class_applied not in (None, False): + if base_decl_class_applied is not None: cls_metadata.mapped_mro.append(base) baseclasses.extend(base.type.bases) |