diff options
Diffstat (limited to 'lib/sqlalchemy/ext/mypy/util.py')
-rw-r--r-- | lib/sqlalchemy/ext/mypy/util.py | 179 |
1 files changed, 129 insertions, 50 deletions
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 67c3fa209..614805d77 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,5 +1,4 @@ from typing import Any -from typing import cast from typing import Iterable from typing import Iterator from typing import List @@ -10,12 +9,15 @@ from typing import Type as TypingType from typing import TypeVar from typing import Union +from mypy.nodes import ARG_POS from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context +from mypy.nodes import Expression from mypy.nodes import IfStmt from mypy.nodes import JsonDict +from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import Statement from mypy.nodes import SymbolTableNode @@ -24,10 +26,11 @@ from mypy.plugin import ClassDefContext from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type +from mypy.typeops import map_type_from_supertype from mypy.types import Instance from mypy.types import NoneType -from mypy.types import ProperType from mypy.types import Type +from mypy.types import TypeVarType from mypy.types import UnboundType from mypy.types import UnionType @@ -35,53 +38,117 @@ from mypy.types import UnionType _TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) -class DeclClassApplied: +class SQLAlchemyAttribute: def __init__( self, - is_mapped: bool, - has_table: bool, - mapped_attr_names: Iterable[Tuple[str, ProperType]], - mapped_mro: Iterable[Instance], - ): - self.is_mapped = is_mapped - self.has_table = has_table - self.mapped_attr_names = list(mapped_attr_names) - self.mapped_mro = list(mapped_mro) + name: str, + line: int, + column: int, + typ: Optional[Type], + info: TypeInfo, + ) -> None: + self.name = name + self.line = line + self.column = column + self.type = typ + self.info = info def serialize(self) -> JsonDict: + assert self.type return { - "is_mapped": self.is_mapped, - "has_table": self.has_table, - "mapped_attr_names": [ - (name, type_.serialize()) - for name, type_ in self.mapped_attr_names - ], - "mapped_mro": [type_.serialize() for type_ in self.mapped_mro], + "name": self.name, + "line": self.line, + "column": self.column, + "type": self.type.serialize(), } + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is inherited + from a generic super type.""" + if not isinstance(self.type, TypeVarType): + return + + self.type = map_type_from_supertype(self.type, sub_type, self.info) + @classmethod def deserialize( - cls, data: JsonDict, api: SemanticAnalyzerPluginInterface - ) -> "DeclClassApplied": - - return DeclClassApplied( - is_mapped=data["is_mapped"], - has_table=data["has_table"], - mapped_attr_names=cast( - List[Tuple[str, ProperType]], - [ - (name, deserialize_and_fixup_type(type_, api)) - for name, type_ in data["mapped_attr_names"] - ], - ), - mapped_mro=cast( - List[Instance], - [ - deserialize_and_fixup_type(type_, api) - for type_ in data["mapped_mro"] - ], - ), - ) + cls, + info: TypeInfo, + data: JsonDict, + api: SemanticAnalyzerPluginInterface, + ) -> "SQLAlchemyAttribute": + data = data.copy() + typ = deserialize_and_fixup_type(data.pop("type"), api) + return cls(typ=typ, info=info, **data) + + +def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None: + info.metadata.setdefault("sqlalchemy", {})[key] = data + + +def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]: + return info.metadata.get("sqlalchemy", {}).get(key, None) + + +def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]: + if info.mro: + for base in info.mro: + metadata = _get_info_metadata(base, key) + if metadata is not None: + return metadata + return None + + +def set_is_base(info: TypeInfo) -> None: + _set_info_metadata(info, "is_base", True) + + +def get_is_base(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "is_base") + return is_base is True + + +def has_declarative_base(info: TypeInfo) -> bool: + is_base = _get_info_mro_metadata(info, "is_base") + return is_base is True + + +def set_has_table(info: TypeInfo) -> None: + _set_info_metadata(info, "has_table", True) + + +def get_has_table(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "has_table") + return is_base is True + + +def get_mapped_attributes( + info: TypeInfo, api: SemanticAnalyzerPluginInterface +) -> Optional[List[SQLAlchemyAttribute]]: + mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata( + info, "mapped_attributes" + ) + if mapped_attributes is None: + return None + + attributes: List[SQLAlchemyAttribute] = [] + + for data in mapped_attributes: + attr = SQLAlchemyAttribute.deserialize(info, data, api) + attr.expand_typevar_from_subtype(info) + attributes.append(attr) + + return attributes + + +def set_mapped_attributes( + info: TypeInfo, attributes: List[SQLAlchemyAttribute] +) -> None: + _set_info_metadata( + info, + "mapped_attributes", + [attribute.serialize() for attribute in attributes], + ) def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: @@ -106,14 +173,14 @@ def add_global( @overload -def _get_callexpr_kwarg( +def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, expr_types: None = ... ) -> Optional[Union[CallExpr, NameExpr]]: ... @overload -def _get_callexpr_kwarg( +def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, @@ -122,7 +189,7 @@ def _get_callexpr_kwarg( ... -def _get_callexpr_kwarg( +def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, @@ -142,7 +209,7 @@ def _get_callexpr_kwarg( return None -def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: +def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: for stmt in stmts: if ( isinstance(stmt, IfStmt) @@ -155,7 +222,7 @@ def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: yield stmt -def _unbound_to_instance( +def unbound_to_instance( api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: """Take the UnboundType that we seem to get as the ret_type from a FuncDef @@ -173,10 +240,10 @@ def _unbound_to_instance( if typ.name == "Optional": # convert from "Optional?" to the more familiar # UnionType[..., NoneType()] - return _unbound_to_instance( + return unbound_to_instance( api, UnionType( - [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args] + [unbound_to_instance(api, typ_arg) for typ_arg in typ.args] + [NoneType()] ), ) @@ -193,7 +260,7 @@ def _unbound_to_instance( return Instance( bound_type, [ - _unbound_to_instance(api, arg) + unbound_to_instance(api, arg) if isinstance(arg, UnboundType) else arg for arg in typ.args @@ -203,9 +270,9 @@ def _unbound_to_instance( return typ -def _info_for_cls( +def info_for_cls( cls: ClassDef, api: SemanticAnalyzerPluginInterface -) -> TypeInfo: +) -> Optional[TypeInfo]: if cls.info is CLASSDEF_NO_INFO: sym = api.lookup_qualified(cls.name, cls) if sym is None: @@ -214,3 +281,15 @@ def _info_for_cls( return sym.node return cls.info + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) |