summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/mypy/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py179
1 files changed, 129 insertions, 50 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
index 67c3fa209..614805d77 100644
--- a/lib/sqlalchemy/ext/mypy/util.py
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -1,5 +1,4 @@
from typing import Any
-from typing import cast
from typing import Iterable
from typing import Iterator
from typing import List
@@ -10,12 +9,15 @@ from typing import Type as TypingType
from typing import TypeVar
from typing import Union
+from mypy.nodes import ARG_POS
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
+from mypy.nodes import Expression
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
+from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
@@ -24,10 +26,11 @@ from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.typeops import map_type_from_supertype
from mypy.types import Instance
from mypy.types import NoneType
-from mypy.types import ProperType
from mypy.types import Type
+from mypy.types import TypeVarType
from mypy.types import UnboundType
from mypy.types import UnionType
@@ -35,53 +38,117 @@ from mypy.types import UnionType
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
-class DeclClassApplied:
+class SQLAlchemyAttribute:
def __init__(
self,
- is_mapped: bool,
- has_table: bool,
- mapped_attr_names: Iterable[Tuple[str, ProperType]],
- mapped_mro: Iterable[Instance],
- ):
- self.is_mapped = is_mapped
- self.has_table = has_table
- self.mapped_attr_names = list(mapped_attr_names)
- self.mapped_mro = list(mapped_mro)
+ name: str,
+ line: int,
+ column: int,
+ typ: Optional[Type],
+ info: TypeInfo,
+ ) -> None:
+ self.name = name
+ self.line = line
+ self.column = column
+ self.type = typ
+ self.info = info
def serialize(self) -> JsonDict:
+ assert self.type
return {
- "is_mapped": self.is_mapped,
- "has_table": self.has_table,
- "mapped_attr_names": [
- (name, type_.serialize())
- for name, type_ in self.mapped_attr_names
- ],
- "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
+ "name": self.name,
+ "line": self.line,
+ "column": self.column,
+ "type": self.type.serialize(),
}
+ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
+ """Expands type vars in the context of a subtype when an attribute is inherited
+ from a generic super type."""
+ if not isinstance(self.type, TypeVarType):
+ return
+
+ self.type = map_type_from_supertype(self.type, sub_type, self.info)
+
@classmethod
def deserialize(
- cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
- ) -> "DeclClassApplied":
-
- return DeclClassApplied(
- is_mapped=data["is_mapped"],
- has_table=data["has_table"],
- mapped_attr_names=cast(
- List[Tuple[str, ProperType]],
- [
- (name, deserialize_and_fixup_type(type_, api))
- for name, type_ in data["mapped_attr_names"]
- ],
- ),
- mapped_mro=cast(
- List[Instance],
- [
- deserialize_and_fixup_type(type_, api)
- for type_ in data["mapped_mro"]
- ],
- ),
- )
+ cls,
+ info: TypeInfo,
+ data: JsonDict,
+ api: SemanticAnalyzerPluginInterface,
+ ) -> "SQLAlchemyAttribute":
+ data = data.copy()
+ typ = deserialize_and_fixup_type(data.pop("type"), api)
+ return cls(typ=typ, info=info, **data)
+
+
+def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
+ info.metadata.setdefault("sqlalchemy", {})[key] = data
+
+
+def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ return info.metadata.get("sqlalchemy", {}).get(key, None)
+
+
+def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ if info.mro:
+ for base in info.mro:
+ metadata = _get_info_metadata(base, key)
+ if metadata is not None:
+ return metadata
+ return None
+
+
+def set_is_base(info: TypeInfo) -> None:
+ _set_info_metadata(info, "is_base", True)
+
+
+def get_is_base(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "is_base")
+ return is_base is True
+
+
+def has_declarative_base(info: TypeInfo) -> bool:
+ is_base = _get_info_mro_metadata(info, "is_base")
+ return is_base is True
+
+
+def set_has_table(info: TypeInfo) -> None:
+ _set_info_metadata(info, "has_table", True)
+
+
+def get_has_table(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "has_table")
+ return is_base is True
+
+
+def get_mapped_attributes(
+ info: TypeInfo, api: SemanticAnalyzerPluginInterface
+) -> Optional[List[SQLAlchemyAttribute]]:
+ mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
+ info, "mapped_attributes"
+ )
+ if mapped_attributes is None:
+ return None
+
+ attributes: List[SQLAlchemyAttribute] = []
+
+ for data in mapped_attributes:
+ attr = SQLAlchemyAttribute.deserialize(info, data, api)
+ attr.expand_typevar_from_subtype(info)
+ attributes.append(attr)
+
+ return attributes
+
+
+def set_mapped_attributes(
+ info: TypeInfo, attributes: List[SQLAlchemyAttribute]
+) -> None:
+ _set_info_metadata(
+ info,
+ "mapped_attributes",
+ [attribute.serialize() for attribute in attributes],
+ )
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
@@ -106,14 +173,14 @@ def add_global(
@overload
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]:
...
@overload
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
@@ -122,7 +189,7 @@ def _get_callexpr_kwarg(
...
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
@@ -142,7 +209,7 @@ def _get_callexpr_kwarg(
return None
-def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
+def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
for stmt in stmts:
if (
isinstance(stmt, IfStmt)
@@ -155,7 +222,7 @@ def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
yield stmt
-def _unbound_to_instance(
+def unbound_to_instance(
api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
"""Take the UnboundType that we seem to get as the ret_type from a FuncDef
@@ -173,10 +240,10 @@ def _unbound_to_instance(
if typ.name == "Optional":
# convert from "Optional?" to the more familiar
# UnionType[..., NoneType()]
- return _unbound_to_instance(
+ return unbound_to_instance(
api,
UnionType(
- [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ [NoneType()]
),
)
@@ -193,7 +260,7 @@ def _unbound_to_instance(
return Instance(
bound_type,
[
- _unbound_to_instance(api, arg)
+ unbound_to_instance(api, arg)
if isinstance(arg, UnboundType)
else arg
for arg in typ.args
@@ -203,9 +270,9 @@ def _unbound_to_instance(
return typ
-def _info_for_cls(
+def info_for_cls(
cls: ClassDef, api: SemanticAnalyzerPluginInterface
-) -> TypeInfo:
+) -> Optional[TypeInfo]:
if cls.info is CLASSDEF_NO_INFO:
sym = api.lookup_qualified(cls.name, cls)
if sym is None:
@@ -214,3 +281,15 @@ def _info_for_cls(
return sym.node
return cls.info
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+ column_descriptor = NameExpr("__sa_Mapped")
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
+ member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+ return CallExpr(
+ member_expr,
+ [expr],
+ [ARG_POS],
+ ["arg1"],
+ )