summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/infer.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/infer.py155
1 files changed, 95 insertions, 60 deletions
diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py
index f1bda7865..7915c3ae2 100644
--- a/lib/sqlalchemy/ext/mypy/infer.py
+++ b/lib/sqlalchemy/ext/mypy/infer.py
@@ -6,23 +6,26 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from typing import Optional
-from typing import Union
+from typing import Sequence
-from mypy import nodes
-from mypy import types
from mypy.maptype import map_instance_to_supertype
from mypy.messages import format_type
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
+from mypy.nodes import Expression
+from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.subtypes import is_subtype
from mypy.types import AnyType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
+from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnionType
@@ -34,8 +37,8 @@ def _infer_type_from_relationship(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Infer the type of mapping from a relationship.
E.g.::
@@ -62,7 +65,7 @@ def _infer_type_from_relationship(
assert isinstance(stmt.rvalue, CallExpr)
target_cls_arg = stmt.rvalue.args[0]
- python_type_for_type = None
+ python_type_for_type: Optional[ProperType] = None
if isinstance(target_cls_arg, NameExpr) and isinstance(
target_cls_arg.node, TypeInfo
@@ -86,7 +89,7 @@ def _infer_type_from_relationship(
# isinstance(target_cls_arg, StrExpr)
uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
- collection_cls_arg = util._get_callexpr_kwarg(
+ collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg(
stmt.rvalue, "collection_class"
)
type_is_a_collection = False
@@ -98,7 +101,7 @@ def _infer_type_from_relationship(
if (
uselist_arg is not None
- and uselist_arg.fullname == "builtins.True"
+ and api.parse_bool(uselist_arg) is True
and collection_cls_arg is None
):
type_is_a_collection = True
@@ -107,7 +110,7 @@ def _infer_type_from_relationship(
"__builtins__.list", [python_type_for_type]
)
elif (
- uselist_arg is None or uselist_arg.fullname == "builtins.True"
+ uselist_arg is None or api.parse_bool(uselist_arg) is True
) and collection_cls_arg is not None:
type_is_a_collection = True
if isinstance(collection_cls_arg, CallExpr):
@@ -130,7 +133,7 @@ def _infer_type_from_relationship(
stmt.rvalue,
)
python_type_for_type = None
- elif uselist_arg is not None and uselist_arg.fullname == "builtins.False":
+ elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
if collection_cls_arg is not None:
util.fail(
api,
@@ -159,13 +162,19 @@ def _infer_type_from_relationship(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
- return _infer_type_from_left_and_inferred_right(
- api,
- node,
- left_hand_explicit_type,
- python_type_for_type,
- type_is_a_collection=type_is_a_collection,
- )
+ if type_is_a_collection:
+ assert isinstance(left_hand_explicit_type, Instance)
+ assert isinstance(python_type_for_type, Instance)
+ return _infer_collection_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
else:
return python_type_for_type
@@ -174,8 +183,8 @@ def _infer_type_from_decl_composite_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Infer the type of mapping from a CompositeProperty."""
assert isinstance(stmt.rvalue, CallExpr)
@@ -206,8 +215,8 @@ def _infer_type_from_decl_column_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Infer the type of mapping from a ColumnProperty.
This includes mappings against ``column_property()`` as well as the
@@ -219,28 +228,26 @@ def _infer_type_from_decl_column_property(
if isinstance(first_prop_arg, CallExpr):
type_id = names._type_id_for_callee(first_prop_arg.callee)
- else:
- type_id = None
- # look for column_property() / deferred() etc with Column as first
- # argument
- if type_id is names.COLUMN:
- return _infer_type_from_decl_column(
- api, stmt, node, left_hand_explicit_type, first_prop_arg
- )
- else:
- return _infer_type_from_left_hand_type_only(
- api, node, left_hand_explicit_type
- )
+ # look for column_property() / deferred() etc with Column as first
+ # argument
+ if type_id is names.COLUMN:
+ return _infer_type_from_decl_column(
+ api, stmt, node, left_hand_explicit_type, first_prop_arg
+ )
+
+ return _infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
def _infer_type_from_decl_column(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
+ left_hand_explicit_type: Optional[ProperType],
right_hand_expression: CallExpr,
-) -> Union[Instance, UnionType, None]:
+) -> Optional[ProperType]:
"""Infer the type of mapping from a Column.
E.g.::
@@ -277,12 +284,13 @@ def _infer_type_from_decl_column(
callee = None
for column_arg in right_hand_expression.args[0:2]:
- if isinstance(column_arg, nodes.CallExpr):
- # x = Column(String(50))
- callee = column_arg.callee
- type_args = column_arg.args
- break
- elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)):
+ if isinstance(column_arg, CallExpr):
+ if isinstance(column_arg.callee, RefExpr):
+ # x = Column(String(50))
+ callee = column_arg.callee
+ type_args: Sequence[Expression] = column_arg.args
+ break
+ elif isinstance(column_arg, (NameExpr, MemberExpr)):
if isinstance(column_arg.node, TypeInfo):
# x = Column(String)
callee = column_arg
@@ -314,10 +322,7 @@ def _infer_type_from_decl_column(
)
else:
- python_type_for_type = UnionType(
- [python_type_for_type, NoneType()]
- )
- return python_type_for_type
+ return UnionType([python_type_for_type, NoneType()])
else:
# it's not TypeEngine, it's typically implicitly typed
# like ForeignKey. we can't infer from the right side.
@@ -329,10 +334,11 @@ def _infer_type_from_decl_column(
def _infer_type_from_left_and_inferred_right(
api: SemanticAnalyzerPluginInterface,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
- python_type_for_type: Union[Instance, UnionType],
- type_is_a_collection: bool = False,
-) -> Optional[Union[Instance, UnionType]]:
+ left_hand_explicit_type: ProperType,
+ python_type_for_type: ProperType,
+ orig_left_hand_type: Optional[ProperType] = None,
+ orig_python_type_for_type: Optional[ProperType] = None,
+) -> Optional[ProperType]:
"""Validate type when a left hand annotation is present and we also
could infer the right hand side::
@@ -340,12 +346,10 @@ def _infer_type_from_left_and_inferred_right(
"""
- orig_left_hand_type = left_hand_explicit_type
- orig_python_type_for_type = python_type_for_type
-
- if type_is_a_collection and left_hand_explicit_type.args:
- left_hand_explicit_type = left_hand_explicit_type.args[0]
- python_type_for_type = python_type_for_type.args[0]
+ if orig_left_hand_type is None:
+ orig_left_hand_type = left_hand_explicit_type
+ if orig_python_type_for_type is None:
+ orig_python_type_for_type = python_type_for_type
if not is_subtype(left_hand_explicit_type, python_type_for_type):
effective_type = api.named_type(
@@ -369,11 +373,40 @@ def _infer_type_from_left_and_inferred_right(
return orig_left_hand_type
+def _infer_collection_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Instance,
+ python_type_for_type: Instance,
+) -> Optional[ProperType]:
+ orig_left_hand_type = left_hand_explicit_type
+ orig_python_type_for_type = python_type_for_type
+
+ if left_hand_explicit_type.args:
+ left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
+ python_type_arg = get_proper_type(python_type_for_type.args[0])
+ else:
+ left_hand_arg = left_hand_explicit_type
+ python_type_arg = python_type_for_type
+
+ assert isinstance(left_hand_arg, (Instance, UnionType))
+ assert isinstance(python_type_arg, (Instance, UnionType))
+
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_arg,
+ python_type_arg,
+ orig_left_hand_type=orig_left_hand_type,
+ orig_python_type_for_type=orig_python_type_for_type,
+ )
+
+
def _infer_type_from_left_hand_type_only(
api: SemanticAnalyzerPluginInterface,
node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Optional[Union[Instance, UnionType]]:
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
"""Determine the type based on explicit annotation only.
if no annotation were present, note that we need one there to know
@@ -397,8 +430,10 @@ def _infer_type_from_left_hand_type_only(
def _extract_python_type_from_typeengine(
- api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args
-) -> Instance:
+ api: SemanticAnalyzerPluginInterface,
+ node: TypeInfo,
+ type_args: Sequence[Expression],
+) -> ProperType:
if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
first_arg = type_args[0]
if isinstance(first_arg, NameExpr) and isinstance(
@@ -426,4 +461,4 @@ def _extract_python_type_from_typeengine(
Instance(node, []),
type_engine_sym.node,
)
- return type_engine.args[-1]
+ return get_proper_type(type_engine.args[-1])