diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/infer.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/infer.py | 155 |
1 files changed, 95 insertions, 60 deletions
diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index f1bda7865..7915c3ae2 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -6,23 +6,26 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from typing import Optional -from typing import Union +from typing import Sequence -from mypy import nodes -from mypy import types from mypy.maptype import map_instance_to_supertype from mypy.messages import format_type from mypy.nodes import AssignmentStmt from mypy.nodes import CallExpr +from mypy.nodes import Expression +from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import RefExpr from mypy.nodes import StrExpr from mypy.nodes import TypeInfo from mypy.nodes import Var from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.subtypes import is_subtype from mypy.types import AnyType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType +from mypy.types import ProperType from mypy.types import TypeOfAny from mypy.types import UnionType @@ -34,8 +37,8 @@ def _infer_type_from_relationship( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Infer the type of mapping from a relationship. E.g.:: @@ -62,7 +65,7 @@ def _infer_type_from_relationship( assert isinstance(stmt.rvalue, CallExpr) target_cls_arg = stmt.rvalue.args[0] - python_type_for_type = None + python_type_for_type: Optional[ProperType] = None if isinstance(target_cls_arg, NameExpr) and isinstance( target_cls_arg.node, TypeInfo @@ -86,7 +89,7 @@ def _infer_type_from_relationship( # isinstance(target_cls_arg, StrExpr) uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist") - collection_cls_arg = util._get_callexpr_kwarg( + collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg( stmt.rvalue, "collection_class" ) type_is_a_collection = False @@ -98,7 +101,7 @@ def _infer_type_from_relationship( if ( uselist_arg is not None - and uselist_arg.fullname == "builtins.True" + and api.parse_bool(uselist_arg) is True and collection_cls_arg is None ): type_is_a_collection = True @@ -107,7 +110,7 @@ def _infer_type_from_relationship( "__builtins__.list", [python_type_for_type] ) elif ( - uselist_arg is None or uselist_arg.fullname == "builtins.True" + uselist_arg is None or api.parse_bool(uselist_arg) is True ) and collection_cls_arg is not None: type_is_a_collection = True if isinstance(collection_cls_arg, CallExpr): @@ -130,7 +133,7 @@ def _infer_type_from_relationship( stmt.rvalue, ) python_type_for_type = None - elif uselist_arg is not None and uselist_arg.fullname == "builtins.False": + elif uselist_arg is not None and api.parse_bool(uselist_arg) is False: if collection_cls_arg is not None: util.fail( api, @@ -159,13 +162,19 @@ def _infer_type_from_relationship( api, node, left_hand_explicit_type ) elif left_hand_explicit_type is not None: - return _infer_type_from_left_and_inferred_right( - api, - node, - left_hand_explicit_type, - python_type_for_type, - type_is_a_collection=type_is_a_collection, - ) + if type_is_a_collection: + assert isinstance(left_hand_explicit_type, Instance) + assert isinstance(python_type_for_type, Instance) + return _infer_collection_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_explicit_type, + python_type_for_type, + ) else: return python_type_for_type @@ -174,8 +183,8 @@ def _infer_type_from_decl_composite_property( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Infer the type of mapping from a CompositeProperty.""" assert isinstance(stmt.rvalue, CallExpr) @@ -206,8 +215,8 @@ def _infer_type_from_decl_column_property( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Infer the type of mapping from a ColumnProperty. This includes mappings against ``column_property()`` as well as the @@ -219,28 +228,26 @@ def _infer_type_from_decl_column_property( if isinstance(first_prop_arg, CallExpr): type_id = names._type_id_for_callee(first_prop_arg.callee) - else: - type_id = None - # look for column_property() / deferred() etc with Column as first - # argument - if type_id is names.COLUMN: - return _infer_type_from_decl_column( - api, stmt, node, left_hand_explicit_type, first_prop_arg - ) - else: - return _infer_type_from_left_hand_type_only( - api, node, left_hand_explicit_type - ) + # look for column_property() / deferred() etc with Column as first + # argument + if type_id is names.COLUMN: + return _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type, first_prop_arg + ) + + return _infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) def _infer_type_from_decl_column( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, - left_hand_explicit_type: Optional[types.Type], + left_hand_explicit_type: Optional[ProperType], right_hand_expression: CallExpr, -) -> Union[Instance, UnionType, None]: +) -> Optional[ProperType]: """Infer the type of mapping from a Column. E.g.:: @@ -277,12 +284,13 @@ def _infer_type_from_decl_column( callee = None for column_arg in right_hand_expression.args[0:2]: - if isinstance(column_arg, nodes.CallExpr): - # x = Column(String(50)) - callee = column_arg.callee - type_args = column_arg.args - break - elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)): + if isinstance(column_arg, CallExpr): + if isinstance(column_arg.callee, RefExpr): + # x = Column(String(50)) + callee = column_arg.callee + type_args: Sequence[Expression] = column_arg.args + break + elif isinstance(column_arg, (NameExpr, MemberExpr)): if isinstance(column_arg.node, TypeInfo): # x = Column(String) callee = column_arg @@ -314,10 +322,7 @@ def _infer_type_from_decl_column( ) else: - python_type_for_type = UnionType( - [python_type_for_type, NoneType()] - ) - return python_type_for_type + return UnionType([python_type_for_type, NoneType()]) else: # it's not TypeEngine, it's typically implicitly typed # like ForeignKey. we can't infer from the right side. @@ -329,10 +334,11 @@ def _infer_type_from_decl_column( def _infer_type_from_left_and_inferred_right( api: SemanticAnalyzerPluginInterface, node: Var, - left_hand_explicit_type: Optional[types.Type], - python_type_for_type: Union[Instance, UnionType], - type_is_a_collection: bool = False, -) -> Optional[Union[Instance, UnionType]]: + left_hand_explicit_type: ProperType, + python_type_for_type: ProperType, + orig_left_hand_type: Optional[ProperType] = None, + orig_python_type_for_type: Optional[ProperType] = None, +) -> Optional[ProperType]: """Validate type when a left hand annotation is present and we also could infer the right hand side:: @@ -340,12 +346,10 @@ def _infer_type_from_left_and_inferred_right( """ - orig_left_hand_type = left_hand_explicit_type - orig_python_type_for_type = python_type_for_type - - if type_is_a_collection and left_hand_explicit_type.args: - left_hand_explicit_type = left_hand_explicit_type.args[0] - python_type_for_type = python_type_for_type.args[0] + if orig_left_hand_type is None: + orig_left_hand_type = left_hand_explicit_type + if orig_python_type_for_type is None: + orig_python_type_for_type = python_type_for_type if not is_subtype(left_hand_explicit_type, python_type_for_type): effective_type = api.named_type( @@ -369,11 +373,40 @@ def _infer_type_from_left_and_inferred_right( return orig_left_hand_type +def _infer_collection_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Instance, + python_type_for_type: Instance, +) -> Optional[ProperType]: + orig_left_hand_type = left_hand_explicit_type + orig_python_type_for_type = python_type_for_type + + if left_hand_explicit_type.args: + left_hand_arg = get_proper_type(left_hand_explicit_type.args[0]) + python_type_arg = get_proper_type(python_type_for_type.args[0]) + else: + left_hand_arg = left_hand_explicit_type + python_type_arg = python_type_for_type + + assert isinstance(left_hand_arg, (Instance, UnionType)) + assert isinstance(python_type_arg, (Instance, UnionType)) + + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_arg, + python_type_arg, + orig_left_hand_type=orig_left_hand_type, + orig_python_type_for_type=orig_python_type_for_type, + ) + + def _infer_type_from_left_hand_type_only( api: SemanticAnalyzerPluginInterface, node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Optional[Union[Instance, UnionType]]: + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: """Determine the type based on explicit annotation only. if no annotation were present, note that we need one there to know @@ -397,8 +430,10 @@ def _infer_type_from_left_hand_type_only( def _extract_python_type_from_typeengine( - api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args -) -> Instance: + api: SemanticAnalyzerPluginInterface, + node: TypeInfo, + type_args: Sequence[Expression], +) -> ProperType: if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: first_arg = type_args[0] if isinstance(first_arg, NameExpr) and isinstance( @@ -426,4 +461,4 @@ def _extract_python_type_from_typeengine( Instance(node, []), type_engine_sym.node, ) - return type_engine.args[-1] + return get_proper_type(type_engine.args[-1]) |