summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/apply.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/apply.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/apply.py115
1 files changed, 70 insertions, 45 deletions
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
index 293ef2f9a..cf5b4fda2 100644
--- a/lib/sqlalchemy/ext/mypy/apply.py
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -5,10 +5,10 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import List
from typing import Optional
from typing import Union
-from mypy import nodes
from mypy.nodes import ARG_NAMED_OPT
from mypy.nodes import Argument
from mypy.nodes import AssignmentStmt
@@ -17,6 +17,7 @@ from mypy.nodes import ClassDef
from mypy.nodes import MDEF
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
@@ -37,18 +38,18 @@ from . import infer
from . import util
-def _apply_mypy_mapped_attr(
+def apply_mypy_mapped_attr(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
item: Union[NameExpr, StrExpr],
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
if isinstance(item, NameExpr):
name = item.name
elif isinstance(item, StrExpr):
name = item.value
else:
- return
+ return None
for stmt in cls.defs.body:
if (
@@ -59,7 +60,7 @@ def _apply_mypy_mapped_attr(
break
else:
util.fail(api, "Can't find mapped attribute {}".format(name), cls)
- return
+ return None
if stmt.type is None:
util.fail(
@@ -68,32 +69,38 @@ def _apply_mypy_mapped_attr(
"typing information",
stmt,
)
- return
+ return None
left_hand_explicit_type = get_proper_type(stmt.type)
assert isinstance(
left_hand_explicit_type, (Instance, UnionType, UnboundType)
)
- cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=item.line,
+ column=item.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
- _apply_type_to_mapped_statement(
+ apply_type_to_mapped_statement(
api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
)
-def _re_apply_declarative_assignments(
+def re_apply_declarative_assignments(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""For multiple class passes, re-apply our left-hand side types as mypy
seems to reset them in place.
"""
- mapped_attr_lookup = {
- name: typ for name, typ in cls_metadata.mapped_attr_names
- }
+ mapped_attr_lookup = {attr.name: attr for attr in attributes}
update_cls_metadata = False
for stmt in cls.defs.body:
@@ -109,28 +116,37 @@ def _re_apply_declarative_assignments(
):
left_node = stmt.lvalues[0].node
- python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name]
+ python_type_for_type = mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type
+
+ left_node_proper_type = get_proper_type(left_node.type)
+
# if we have scanned an UnboundType and now there's a more
# specific type than UnboundType, call the re-scan so we
# can get that set up correctly
if (
isinstance(python_type_for_type, UnboundType)
- and not isinstance(left_node.type, UnboundType)
+ and not isinstance(left_node_proper_type, UnboundType)
and (
- isinstance(stmt.rvalue.callee, MemberExpr)
+ isinstance(stmt.rvalue, CallExpr)
+ and isinstance(stmt.rvalue.callee, MemberExpr)
+ and isinstance(stmt.rvalue.callee.expr, NameExpr)
+ and stmt.rvalue.callee.expr.node is not None
and stmt.rvalue.callee.expr.node.fullname
== "sqlalchemy.orm.attributes.Mapped"
and stmt.rvalue.callee.name == "_empty_constructor"
and isinstance(stmt.rvalue.args[0], CallExpr)
+ and isinstance(stmt.rvalue.args[0].callee, RefExpr)
)
):
python_type_for_type = (
- infer._infer_type_from_right_hand_nameexpr(
+ infer.infer_type_from_right_hand_nameexpr(
api,
stmt,
left_node,
- left_node.type,
+ left_node_proper_type,
stmt.rvalue.args[0].callee,
)
)
@@ -140,21 +156,23 @@ def _re_apply_declarative_assignments(
):
continue
- # update the DeclClassApplied with the better information
- mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type
+ # update the SQLAlchemyAttribute with the better information
+ mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type = python_type_for_type
+
update_cls_metadata = True
- left_node.type = api.named_type(
- "__sa_Mapped", [python_type_for_type]
- )
+ if python_type_for_type is not None:
+ left_node.type = api.named_type(
+ "__sa_Mapped", [python_type_for_type]
+ )
if update_cls_metadata:
- cls_metadata.mapped_attr_names[:] = [
- (k, v) for k, v in mapped_attr_lookup.items()
- ]
+ util.set_mapped_attributes(cls.info, attributes)
-def _apply_type_to_mapped_statement(
+def apply_type_to_mapped_statement(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
lvalue: NameExpr,
@@ -205,30 +223,36 @@ def _apply_type_to_mapped_statement(
# _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
# the original right-hand side is maintained so it gets type checked
# internally
- column_descriptor = nodes.NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
- mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
- orig_call_expr = stmt.rvalue
- stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
+ stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
-def _add_additional_orm_attributes(
+def add_additional_orm_attributes(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Apply __init__, __table__ and other attributes to the mapped class."""
- info = util._info_for_cls(cls, api)
- if "__init__" not in info.names and cls_metadata.is_mapped:
- mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names}
+ info = util.info_for_cls(cls, api)
- for mapped_base in cls_metadata.mapped_mro:
- base_cls_metadata = util.DeclClassApplied.deserialize(
- mapped_base.type.metadata["_sa_decl_class_applied"], api
- )
- for n, t in base_cls_metadata.mapped_attr_names:
- mapped_attr_names.setdefault(n, t)
+ if info is None:
+ return
+
+ is_base = util.get_is_base(info)
+
+ if "__init__" not in info.names and not is_base:
+ mapped_attr_names = {attr.name: attr.type for attr in attributes}
+
+ for base in info.mro[1:-1]:
+ if "sqlalchemy" not in info.metadata:
+ continue
+
+ base_cls_attributes = util.get_mapped_attributes(base, api)
+ if base_cls_attributes is None:
+ continue
+
+ for attr in base_cls_attributes:
+ mapped_attr_names.setdefault(attr.name, attr.type)
arguments = []
for name, typ in mapped_attr_names.items():
@@ -242,13 +266,14 @@ def _add_additional_orm_attributes(
kind=ARG_NAMED_OPT,
)
)
+
add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
- if "__table__" not in info.names and cls_metadata.has_table:
+ if "__table__" not in info.names and util.get_has_table(info):
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.sql.schema.Table", "__table__"
)
- if cls_metadata.is_mapped:
+ if not is_base:
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
)