summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/decl_class.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/decl_class.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/decl_class.py150
1 files changed, 82 insertions, 68 deletions
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
index 45d025fc9..23c78aa51 100644
--- a/lib/sqlalchemy/ext/mypy/decl_class.py
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -5,14 +5,15 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import List
from typing import Optional
from typing import Union
-from mypy import nodes
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
+from mypy.nodes import LambdaExpr
from mypy.nodes import ListExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
@@ -42,62 +43,68 @@ from . import names
from . import util
-def _scan_declarative_assignments_and_apply_types(
+def scan_declarative_assignments_and_apply_types(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
is_mixin_scan: bool = False,
-) -> Optional[util.DeclClassApplied]:
+) -> Optional[List[util.SQLAlchemyAttribute]]:
- info = util._info_for_cls(cls, api)
+ info = util.info_for_cls(cls, api)
if info is None:
# this can occur during cached passes
return None
elif cls.fullname.startswith("builtins"):
return None
- elif "_sa_decl_class_applied" in info.metadata:
- cls_metadata = util.DeclClassApplied.deserialize(
- info.metadata["_sa_decl_class_applied"], api
- )
+ mapped_attributes: Optional[
+ List[util.SQLAlchemyAttribute]
+ ] = util.get_mapped_attributes(info, api)
+
+ if mapped_attributes is not None:
# ensure that a class that's mapped is always picked up by
# its mapped() decorator or declarative metaclass before
# it would be detected as an unmapped mixin class
- if not is_mixin_scan:
- assert cls_metadata.is_mapped
+ if not is_mixin_scan:
# mypy can call us more than once. it then *may* have reset the
# left hand side of everything, but not the right that we removed,
# removing our ability to re-scan. but we have the types
# here, so lets re-apply them, or if we have an UnboundType,
# we can re-scan
- apply._re_apply_declarative_assignments(cls, api, cls_metadata)
+ apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
- return cls_metadata
+ return mapped_attributes
- cls_metadata = util.DeclClassApplied(not is_mixin_scan, False, [], [])
+ mapped_attributes = []
if not cls.defs.body:
# when we get a mixin class from another file, the body is
# empty (!) but the names are in the symbol table. so use that.
for sym_name, sym in info.names.items():
- _scan_symbol_table_entry(cls, api, sym_name, sym, cls_metadata)
+ _scan_symbol_table_entry(
+ cls, api, sym_name, sym, mapped_attributes
+ )
else:
- for stmt in util._flatten_typechecking(cls.defs.body):
+ for stmt in util.flatten_typechecking(cls.defs.body):
if isinstance(stmt, AssignmentStmt):
- _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata)
+ _scan_declarative_assignment_stmt(
+ cls, api, stmt, mapped_attributes
+ )
elif isinstance(stmt, Decorator):
- _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata)
- _scan_for_mapped_bases(cls, api, cls_metadata)
+ _scan_declarative_decorator_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ _scan_for_mapped_bases(cls, api)
if not is_mixin_scan:
- apply._add_additional_orm_attributes(cls, api, cls_metadata)
+ apply.add_additional_orm_attributes(cls, api, mapped_attributes)
- info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize()
+ util.set_mapped_attributes(info, mapped_attributes)
- return cls_metadata
+ return mapped_attributes
def _scan_symbol_table_entry(
@@ -105,7 +112,7 @@ def _scan_symbol_table_entry(
api: SemanticAnalyzerPluginInterface,
name: str,
value: SymbolTableNode,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a SymbolTableNode that's in the
type.names dictionary.
@@ -116,7 +123,7 @@ def _scan_symbol_table_entry(
return
left_hand_explicit_type = None
- type_id = names._type_id_for_named_node(value_type.type)
+ type_id = names.type_id_for_named_node(value_type.type)
# type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
err = False
@@ -148,11 +155,11 @@ def _scan_symbol_table_entry(
if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
- if names._has_base_type_id(sym.node, names.TYPEENGINE):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
- infer._extract_python_type_from_typeengine(
+ infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
@@ -178,14 +185,23 @@ def _scan_symbol_table_entry(
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
if left_hand_explicit_type is not None:
- cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+ assert value.node is not None
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=value.node.line,
+ column=value.node.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
def _scan_declarative_decorator_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: Decorator,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a @declared_attr in a declarative
class.
@@ -212,7 +228,7 @@ def _scan_declarative_decorator_stmt(
for dec in stmt.decorators:
if (
isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
- and names._type_id_for_named_node(dec) is names.DECLARED_ATTR
+ and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
):
break
else:
@@ -225,7 +241,7 @@ def _scan_declarative_decorator_stmt(
if isinstance(stmt.func.type, CallableType):
func_type = stmt.func.type.ret_type
if isinstance(func_type, UnboundType):
- type_id = names._type_id_for_unbound_type(func_type, cls, api)
+ type_id = names.type_id_for_unbound_type(func_type, cls, api)
else:
# this does not seem to occur unless the type argument is
# incorrect
@@ -249,10 +265,10 @@ def _scan_declarative_decorator_stmt(
if isinstance(typeengine_arg, UnboundType):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
- if names._has_base_type_id(sym.node, names.TYPEENGINE):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
- infer._extract_python_type_from_typeengine(
+ infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
@@ -291,7 +307,7 @@ def _scan_declarative_decorator_stmt(
# we see everywhere else.
if isinstance(left_hand_explicit_type, UnboundType):
left_hand_explicit_type = get_proper_type(
- util._unbound_to_instance(api, left_hand_explicit_type)
+ util.unbound_to_instance(api, left_hand_explicit_type)
)
left_node.node.type = api.named_type(
@@ -305,23 +321,21 @@ def _scan_declarative_decorator_stmt(
# <attr> : Mapped[<typ>] =
# _sa_Mapped._empty_constructor(lambda: <function body>)
# the function body is maintained so it gets type checked internally
- column_descriptor = nodes.NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
- mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
-
- arg = nodes.LambdaExpr(stmt.func.arguments, stmt.func.body)
- rvalue = CallExpr(
- mm,
- [arg],
- [nodes.ARG_POS],
- ["arg1"],
+ rvalue = util.expr_to_mapped_constructor(
+ LambdaExpr(stmt.func.arguments, stmt.func.body)
)
new_stmt = AssignmentStmt([left_node], rvalue)
new_stmt.type = left_node.node.type
- cls_metadata.mapped_attr_names.append(
- (left_node.name, left_hand_explicit_type)
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=left_node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
)
cls.defs.body[dec_index] = new_stmt
@@ -330,7 +344,7 @@ def _scan_declarative_assignment_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from an assignment statement in a
declarative class.
@@ -356,10 +370,10 @@ def _scan_declarative_assignment_stmt(
if node.name == "__abstract__":
if api.parse_bool(stmt.rvalue) is True:
- cls_metadata.is_mapped = False
+ util.set_is_base(cls.info)
return
elif node.name == "__tablename__":
- cls_metadata.has_table = True
+ util.set_has_table(cls.info)
elif node.name.startswith("__"):
return
elif node.name == "_mypy_mapped_attrs":
@@ -368,7 +382,7 @@ def _scan_declarative_assignment_stmt(
else:
for item in stmt.rvalue.items:
if isinstance(item, (NameExpr, StrExpr)):
- apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata)
+ apply.apply_mypy_mapped_attr(cls, api, item, attributes)
left_hand_mapped_type: Optional[Type] = None
left_hand_explicit_type: Optional[ProperType] = None
@@ -388,7 +402,7 @@ def _scan_declarative_assignment_stmt(
if (
mapped_sym is not None
and mapped_sym.node is not None
- and names._type_id_for_named_node(mapped_sym.node)
+ and names.type_id_for_named_node(mapped_sym.node)
is names.MAPPED
):
left_hand_explicit_type = get_proper_type(
@@ -404,7 +418,7 @@ def _scan_declarative_assignment_stmt(
node_type = get_proper_type(node.type)
if (
isinstance(node_type, Instance)
- and names._type_id_for_named_node(node_type.type) is names.MAPPED
+ and names.type_id_for_named_node(node_type.type) is names.MAPPED
):
# print(node.type)
# sqlalchemy.orm.attributes.Mapped[<python type>]
@@ -426,7 +440,7 @@ def _scan_declarative_assignment_stmt(
stmt.rvalue.callee, RefExpr
):
- python_type_for_type = infer._infer_type_from_right_hand_nameexpr(
+ python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
)
@@ -438,9 +452,17 @@ def _scan_declarative_assignment_stmt(
assert python_type_for_type is not None
- cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=python_type_for_type,
+ info=cls.info,
+ )
+ )
- apply._apply_type_to_mapped_statement(
+ apply.apply_type_to_mapped_statement(
api,
stmt,
lvalue,
@@ -452,7 +474,6 @@ def _scan_declarative_assignment_stmt(
def _scan_for_mapped_bases(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
- cls_metadata: util.DeclClassApplied,
) -> None:
"""Given a class, iterate through its superclass hierarchy to find
all other classes that are considered as ORM-significant.
@@ -462,25 +483,18 @@ def _scan_for_mapped_bases(
"""
- info = util._info_for_cls(cls, api)
+ info = util.info_for_cls(cls, api)
- baseclasses = list(info.bases)
-
- while baseclasses:
- base: Instance = baseclasses.pop(0)
+ if info is None:
+ return
- if base.type.fullname.startswith("builtins"):
+ for base_info in info.mro[1:-1]:
+ if base_info.fullname.startswith("builtins"):
continue
# scan each base for mapped attributes. if they are not already
# scanned (but have all their type info), that means they are unmapped
# mixins
- base_decl_class_applied = (
- _scan_declarative_assignments_and_apply_types(
- base.type.defn, api, is_mixin_scan=True
- )
+ scan_declarative_assignments_and_apply_types(
+ base_info.defn, api, is_mixin_scan=True
)
-
- if base_decl_class_applied is not None:
- cls_metadata.mapped_mro.append(base)
- baseclasses.extend(base.type.bases)