summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/decl_class.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/decl_class.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/decl_class.py122
1 files changed, 68 insertions, 54 deletions
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
index 40f1f0c0f..8fac36342 100644
--- a/lib/sqlalchemy/ext/mypy/decl_class.py
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -6,7 +6,7 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from typing import Optional
-from typing import Type
+from typing import Union
from mypy import nodes
from mypy.nodes import AssignmentStmt
@@ -14,18 +14,24 @@ from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import ListExpr
+from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import PlaceholderNode
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
+from mypy.nodes import SymbolNode
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
@@ -37,7 +43,9 @@ from . import util
def _scan_declarative_assignments_and_apply_types(
- cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ is_mixin_scan: bool = False,
) -> Optional[util.DeclClassApplied]:
info = util._info_for_cls(cls, api)
@@ -94,16 +102,17 @@ def _scan_symbol_table_entry(
name: str,
value: SymbolTableNode,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""Extract mapping information from a SymbolTableNode that's in the
type.names dictionary.
"""
- if not isinstance(value.type, Instance):
+ value_type = get_proper_type(value.type)
+ if not isinstance(value_type, Instance):
return
left_hand_explicit_type = None
- type_id = names._type_id_for_named_node(value.type.type)
+ type_id = names._type_id_for_named_node(value_type.type)
# type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
err = False
@@ -118,22 +127,24 @@ def _scan_symbol_table_entry(
names.SYNONYM_PROPERTY,
names.COLUMN_PROPERTY,
}:
- if value.type.args:
- left_hand_explicit_type = value.type.args[0]
+ if value_type.args:
+ left_hand_explicit_type = get_proper_type(value_type.args[0])
else:
err = True
elif type_id is names.COLUMN:
- if not value.type.args:
+ if not value_type.args:
err = True
else:
- typeengine_arg = value.type.args[0]
+ typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
+ value_type.args[0]
+ )
if isinstance(typeengine_arg, Instance):
typeengine_arg = typeengine_arg.type
if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
- if sym is not None:
- if names._mro_has_id(sym.node.mro, names.TYPEENGINE):
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names._has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
@@ -148,7 +159,7 @@ def _scan_symbol_table_entry(
api,
"Column type should be a TypeEngine "
"subclass not '{}'".format(sym.node.fullname),
- value.type,
+ value_type,
)
if err:
@@ -158,7 +169,7 @@ def _scan_symbol_table_entry(
"one of: Mapped[<python type>], relationship[<target class>], "
"Column[<TypeEngine>], MapperProperty[<python type>]"
)
- util.fail(api, msg.format(name, cls.name))
+ util.fail(api, msg.format(name, cls.name), cls)
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
@@ -171,7 +182,7 @@ def _scan_declarative_decorator_stmt(
api: SemanticAnalyzerPluginInterface,
stmt: Decorator,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""Extract mapping information from a @declared_attr in a declarative
class.
@@ -195,16 +206,19 @@ def _scan_declarative_decorator_stmt(
"""
for dec in stmt.decorators:
- if names._type_id_for_named_node(dec) is names.DECLARED_ATTR:
+ if (
+ isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
+ and names._type_id_for_named_node(dec) is names.DECLARED_ATTR
+ ):
break
else:
return
dec_index = cls.defs.body.index(stmt)
- left_hand_explicit_type = None
+ left_hand_explicit_type: Optional[ProperType] = None
- if stmt.func.type is not None:
+ if isinstance(stmt.func.type, CallableType):
func_type = stmt.func.type.ret_type
if isinstance(func_type, UnboundType):
type_id = names._type_id_for_unbound_type(func_type, cls, api)
@@ -225,30 +239,28 @@ def _scan_declarative_decorator_stmt(
}
and func_type.args
):
- left_hand_explicit_type = func_type.args[0]
+ left_hand_explicit_type = get_proper_type(func_type.args[0])
elif type_id is names.COLUMN and func_type.args:
typeengine_arg = func_type.args[0]
if isinstance(typeengine_arg, UnboundType):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
- if sym is not None and names._mro_has_id(
- sym.node.mro, names.TYPEENGINE
- ):
-
- left_hand_explicit_type = UnionType(
- [
- infer._extract_python_type_from_typeengine(
- api, sym.node, []
- ),
- NoneType(),
- ]
- )
- else:
- util.fail(
- api,
- "Column type should be a TypeEngine "
- "subclass not '{}'".format(sym.node.fullname),
- func_type,
- )
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names._has_base_type_id(sym.node, names.TYPEENGINE):
+ left_hand_explicit_type = UnionType(
+ [
+ infer._extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ func_type,
+ )
if left_hand_explicit_type is None:
# no type on the decorated function. our option here is to
@@ -274,8 +286,8 @@ def _scan_declarative_decorator_stmt(
# of converting it to the regular Instance/TypeInfo/UnionType structures
# we see everywhere else.
if isinstance(left_hand_explicit_type, UnboundType):
- left_hand_explicit_type = util._unbound_to_instance(
- api, left_hand_explicit_type
+ left_hand_explicit_type = get_proper_type(
+ util._unbound_to_instance(api, left_hand_explicit_type)
)
left_node.node.type = api.named_type(
@@ -315,7 +327,7 @@ def _scan_declarative_assignment_stmt(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
cls_metadata: util.DeclClassApplied,
-):
+) -> None:
"""Extract mapping information from an assignment statement in a
declarative class.
@@ -339,7 +351,7 @@ def _scan_declarative_assignment_stmt(
assert isinstance(node, Var)
if node.name == "__abstract__":
- if stmt.rvalue.fullname == "builtins.True":
+ if api.parse_bool(stmt.rvalue) is True:
cls_metadata.is_mapped = False
return
elif node.name == "__tablename__":
@@ -354,7 +366,8 @@ def _scan_declarative_assignment_stmt(
if isinstance(item, (NameExpr, StrExpr)):
apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata)
- left_hand_mapped_type: Type = None
+ left_hand_mapped_type: Optional[Type] = None
+ left_hand_explicit_type: Optional[ProperType] = None
if node.is_inferred or node.type is None:
if isinstance(stmt.type, UnboundType):
@@ -370,32 +383,33 @@ def _scan_declarative_assignment_stmt(
mapped_sym = api.lookup_qualified("Mapped", cls)
if (
mapped_sym is not None
+ and mapped_sym.node is not None
and names._type_id_for_named_node(mapped_sym.node)
is names.MAPPED
):
- left_hand_explicit_type = stmt.type.args[0]
+ left_hand_explicit_type = get_proper_type(
+ stmt.type.args[0]
+ )
left_hand_mapped_type = stmt.type
# TODO: do we need to convert from unbound for this case?
# left_hand_explicit_type = util._unbound_to_instance(
# api, left_hand_explicit_type
# )
-
- else:
- left_hand_explicit_type = None
else:
+ node_type = get_proper_type(node.type)
if (
- isinstance(node.type, Instance)
- and names._type_id_for_named_node(node.type.type) is names.MAPPED
+ isinstance(node_type, Instance)
+ and names._type_id_for_named_node(node_type.type) is names.MAPPED
):
# print(node.type)
# sqlalchemy.orm.attributes.Mapped[<python type>]
- left_hand_explicit_type = node.type.args[0]
- left_hand_mapped_type = node.type
+ left_hand_explicit_type = get_proper_type(node_type.args[0])
+ left_hand_mapped_type = node_type
else:
# print(node.type)
# <python type>
- left_hand_explicit_type = node.type
+ left_hand_explicit_type = node_type
left_hand_mapped_type = None
if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
@@ -440,10 +454,10 @@ def _scan_declarative_assignment_stmt(
else:
return
- cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
-
assert python_type_for_type is not None
+ cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
+
apply._apply_type_to_mapped_statement(
api,
stmt,
@@ -485,6 +499,6 @@ def _scan_for_mapped_bases(
)
)
- if base_decl_class_applied not in (None, False):
+ if base_decl_class_applied is not None:
cls_metadata.mapped_mro.append(base)
baseclasses.extend(base.type.bases)