diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-01-24 17:04:27 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-02-13 14:23:04 -0500 |
| commit | e545298e35ea9f126054b337e4b5ba01988b29f7 (patch) | |
| tree | e64aea159111d5921ff01f08b1c4efb667249dfe /lib/sqlalchemy | |
| parent | f1da1623b800cd4de3b71fd1b2ad5ccfde286780 (diff) | |
| download | sqlalchemy-e545298e35ea9f126054b337e4b5ba01988b29f7.tar.gz | |
establish mypy / typing approach for v2.0
large patch to get ORM / typing efforts started.
this is to support adding new test cases to mypy,
support dropping sqlalchemy2-stubs entirely from the
test suite, validate major ORM typing reorganization
to eliminate the need for the mypy plugin.
* New declarative approach which uses annotation
introspection, fixes: #7535
* Mapped[] is now at the base of all ORM constructs
that find themselves in classes, to support direct
typing without plugins
* Mypy plugin updated for new typing structures
* Mypy test suite broken out into "plugin" tests vs.
"plain" tests, and enhanced to better support test
structures where we assert that various objects are
introspected by the type checker as we expect.
As we go forward with typing, we will
add new use cases to "plain" where we can assert that
types are introspected as we expect.
* For typing support, users will be much more exposed to the
class names of things. Add these all to "sqlalchemy" import
space.
* Column(ForeignKey()) no longer needs to be `@declared_attr`
if the FK refers to a remote table
* composite() attributes mapped to a dataclass no longer
need to implement a `__composite_values__()` method
* with_variant() accepts multiple dialect names
Change-Id: I22797c0be73a8fbbd2d6f5e0c0b7258b17fe145d
Fixes: #7535
Fixes: #7551
References: #6810
Diffstat (limited to 'lib/sqlalchemy')
60 files changed, 3706 insertions, 1883 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index c8ec1d825..eadb427d0 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -6,10 +6,56 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from . import util as _util +from .engine import AdaptedConnection as AdaptedConnection +from .engine import BaseCursorResult as BaseCursorResult +from .engine import BaseRow as BaseRow +from .engine import BindTyping as BindTyping +from .engine import BufferedColumnResultProxy as BufferedColumnResultProxy +from .engine import BufferedColumnRow as BufferedColumnRow +from .engine import BufferedRowResultProxy as BufferedRowResultProxy +from .engine import ChunkedIteratorResult as ChunkedIteratorResult +from .engine import Compiled as Compiled +from .engine import Connection as Connection from .engine import create_engine as create_engine from .engine import create_mock_engine as create_mock_engine +from .engine import CreateEnginePlugin as CreateEnginePlugin +from .engine import CursorResult as CursorResult +from .engine import Dialect as Dialect +from .engine import Engine as Engine from .engine import engine_from_config as engine_from_config +from .engine import ExceptionContext as ExceptionContext +from .engine import ExecutionContext as ExecutionContext +from .engine import FrozenResult as FrozenResult +from .engine import FullyBufferedResultProxy as FullyBufferedResultProxy +from .engine import Inspector as Inspector +from .engine import IteratorResult as IteratorResult +from .engine import make_url as make_url +from .engine import MappingResult as MappingResult +from .engine import MergedResult as MergedResult +from .engine import NestedTransaction as NestedTransaction +from .engine import Result as Result +from .engine import result_tuple as result_tuple +from .engine import ResultProxy as ResultProxy +from .engine import RootTransaction as RootTransaction +from .engine import Row as Row +from .engine 
import RowMapping as RowMapping +from .engine import ScalarResult as ScalarResult +from .engine import Transaction as Transaction +from .engine import TwoPhaseTransaction as TwoPhaseTransaction +from .engine import TypeCompiler as TypeCompiler +from .engine import URL as URL from .inspection import inspect as inspect +from .pool import AssertionPool as AssertionPool +from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .pool import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .pool import NullPool as NullPool +from .pool import Pool as Pool +from .pool import PoolProxiedConnection as PoolProxiedConnection +from .pool import QueuePool as QueuePool +from .pool import SingletonThreadPool as SingletonThreadPool +from .pool import StaticPool as StaticPool from .schema import BLANK_SCHEMA as BLANK_SCHEMA from .schema import CheckConstraint as CheckConstraint from .schema import Column as Column @@ -28,67 +74,139 @@ from .schema import PrimaryKeyConstraint as PrimaryKeyConstraint from .schema import Sequence as Sequence from .schema import Table as Table from .schema import UniqueConstraint as UniqueConstraint -from .sql import alias as alias -from .sql import all_ as all_ -from .sql import and_ as and_ -from .sql import any_ as any_ -from .sql import asc as asc -from .sql import between as between -from .sql import bindparam as bindparam -from .sql import case as case -from .sql import cast as cast -from .sql import collate as collate -from .sql import column as column -from .sql import delete as delete -from .sql import desc as desc -from .sql import distinct as distinct -from .sql import except_ as except_ -from .sql import except_all as except_all -from .sql import exists as exists -from .sql import extract as extract -from .sql import false as false -from .sql import func as func -from .sql import funcfilter as funcfilter -from .sql import insert as insert -from .sql import intersect as intersect -from .sql import 
intersect_all as intersect_all -from .sql import join as join -from .sql import label as label -from .sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT -from .sql import ( +from .sql import SelectLabelStyle as SelectLabelStyle +from .sql.expression import Alias as Alias +from .sql.expression import alias as alias +from .sql.expression import AliasedReturnsRows as AliasedReturnsRows +from .sql.expression import all_ as all_ +from .sql.expression import and_ as and_ +from .sql.expression import any_ as any_ +from .sql.expression import asc as asc +from .sql.expression import between as between +from .sql.expression import BinaryExpression as BinaryExpression +from .sql.expression import bindparam as bindparam +from .sql.expression import BindParameter as BindParameter +from .sql.expression import BooleanClauseList as BooleanClauseList +from .sql.expression import CacheKey as CacheKey +from .sql.expression import Case as Case +from .sql.expression import case as case +from .sql.expression import Cast as Cast +from .sql.expression import cast as cast +from .sql.expression import ClauseElement as ClauseElement +from .sql.expression import ClauseList as ClauseList +from .sql.expression import collate as collate +from .sql.expression import CollectionAggregate as CollectionAggregate +from .sql.expression import column as column +from .sql.expression import ColumnClause as ColumnClause +from .sql.expression import ColumnCollection as ColumnCollection +from .sql.expression import ColumnElement as ColumnElement +from .sql.expression import ColumnOperators as ColumnOperators +from .sql.expression import CompoundSelect as CompoundSelect +from .sql.expression import CTE as CTE +from .sql.expression import cte as cte +from .sql.expression import custom_op as custom_op +from .sql.expression import Delete as Delete +from .sql.expression import delete as delete +from .sql.expression import desc as desc +from .sql.expression import distinct as distinct +from .sql.expression 
import except_ as except_ +from .sql.expression import except_all as except_all +from .sql.expression import Executable as Executable +from .sql.expression import Exists as Exists +from .sql.expression import exists as exists +from .sql.expression import Extract as Extract +from .sql.expression import extract as extract +from .sql.expression import false as false +from .sql.expression import False_ as False_ +from .sql.expression import FromClause as FromClause +from .sql.expression import FromGrouping as FromGrouping +from .sql.expression import func as func +from .sql.expression import funcfilter as funcfilter +from .sql.expression import Function as Function +from .sql.expression import FunctionElement as FunctionElement +from .sql.expression import FunctionFilter as FunctionFilter +from .sql.expression import GenerativeSelect as GenerativeSelect +from .sql.expression import Grouping as Grouping +from .sql.expression import HasCTE as HasCTE +from .sql.expression import HasPrefixes as HasPrefixes +from .sql.expression import HasSuffixes as HasSuffixes +from .sql.expression import Insert as Insert +from .sql.expression import insert as insert +from .sql.expression import intersect as intersect +from .sql.expression import intersect_all as intersect_all +from .sql.expression import Join as Join +from .sql.expression import join as join +from .sql.expression import Label as Label +from .sql.expression import label as label +from .sql.expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from .sql.expression import ( LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, ) -from .sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE -from .sql import ( +from .sql.expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from .sql.expression import ( LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) -from .sql import lambda_stmt as lambda_stmt -from .sql import lateral as lateral -from .sql import literal as literal -from .sql import 
literal_column as literal_column -from .sql import modifier as modifier -from .sql import not_ as not_ -from .sql import null as null -from .sql import nulls_first as nulls_first -from .sql import nulls_last as nulls_last -from .sql import nullsfirst as nullsfirst -from .sql import nullslast as nullslast -from .sql import or_ as or_ -from .sql import outerjoin as outerjoin -from .sql import outparam as outparam -from .sql import over as over -from .sql import select as select -from .sql import table as table -from .sql import tablesample as tablesample -from .sql import text as text -from .sql import true as true -from .sql import tuple_ as tuple_ -from .sql import type_coerce as type_coerce -from .sql import union as union -from .sql import union_all as union_all -from .sql import update as update -from .sql import values as values -from .sql import within_group as within_group +from .sql.expression import lambda_stmt as lambda_stmt +from .sql.expression import LambdaElement as LambdaElement +from .sql.expression import Lateral as Lateral +from .sql.expression import lateral as lateral +from .sql.expression import literal as literal +from .sql.expression import literal_column as literal_column +from .sql.expression import modifier as modifier +from .sql.expression import not_ as not_ +from .sql.expression import Null as Null +from .sql.expression import null as null +from .sql.expression import nulls_first as nulls_first +from .sql.expression import nulls_last as nulls_last +from .sql.expression import Operators as Operators +from .sql.expression import or_ as or_ +from .sql.expression import outerjoin as outerjoin +from .sql.expression import outparam as outparam +from .sql.expression import Over as Over +from .sql.expression import over as over +from .sql.expression import quoted_name as quoted_name +from .sql.expression import ReleaseSavepointClause as ReleaseSavepointClause +from .sql.expression import ReturnsRows as ReturnsRows +from .sql.expression import ( 
+ RollbackToSavepointClause as RollbackToSavepointClause, +) +from .sql.expression import SavepointClause as SavepointClause +from .sql.expression import ScalarSelect as ScalarSelect +from .sql.expression import Select as Select +from .sql.expression import select as select +from .sql.expression import Selectable as Selectable +from .sql.expression import SelectBase as SelectBase +from .sql.expression import StatementLambdaElement as StatementLambdaElement +from .sql.expression import Subquery as Subquery +from .sql.expression import table as table +from .sql.expression import TableClause as TableClause +from .sql.expression import TableSample as TableSample +from .sql.expression import tablesample as tablesample +from .sql.expression import TableValuedAlias as TableValuedAlias +from .sql.expression import text as text +from .sql.expression import TextAsFrom as TextAsFrom +from .sql.expression import TextClause as TextClause +from .sql.expression import TextualSelect as TextualSelect +from .sql.expression import true as true +from .sql.expression import True_ as True_ +from .sql.expression import Tuple as Tuple +from .sql.expression import tuple_ as tuple_ +from .sql.expression import type_coerce as type_coerce +from .sql.expression import TypeClause as TypeClause +from .sql.expression import TypeCoerce as TypeCoerce +from .sql.expression import typing as typing +from .sql.expression import UnaryExpression as UnaryExpression +from .sql.expression import union as union +from .sql.expression import union_all as union_all +from .sql.expression import Update as Update +from .sql.expression import update as update +from .sql.expression import UpdateBase as UpdateBase +from .sql.expression import Values as Values +from .sql.expression import values as values +from .sql.expression import ValuesBase as ValuesBase +from .sql.expression import Visitable as Visitable +from .sql.expression import within_group as within_group +from .sql.expression import WithinGroup as 
WithinGroup from .types import ARRAY as ARRAY from .types import BIGINT as BIGINT from .types import BigInteger as BigInteger @@ -133,7 +251,6 @@ from .types import UnicodeText as UnicodeText from .types import VARBINARY as VARBINARY from .types import VARCHAR as VARCHAR - __version__ = "2.0.0b1" diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index e934f9f89..c6bc4b6aa 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -15,45 +15,45 @@ constructor ``create_engine()``. """ -from . import events -from . import util -from .base import Connection -from .base import Engine -from .base import NestedTransaction -from .base import RootTransaction -from .base import Transaction -from .base import TwoPhaseTransaction -from .create import create_engine -from .create import engine_from_config -from .cursor import BaseCursorResult -from .cursor import BufferedColumnResultProxy -from .cursor import BufferedColumnRow -from .cursor import BufferedRowResultProxy -from .cursor import CursorResult -from .cursor import FullyBufferedResultProxy -from .cursor import ResultProxy -from .interfaces import AdaptedConnection -from .interfaces import BindTyping -from .interfaces import Compiled -from .interfaces import CreateEnginePlugin -from .interfaces import Dialect -from .interfaces import ExceptionContext -from .interfaces import ExecutionContext -from .interfaces import TypeCompiler -from .mock import create_mock_engine -from .reflection import Inspector -from .result import ChunkedIteratorResult -from .result import FrozenResult -from .result import IteratorResult -from .result import MappingResult -from .result import MergedResult -from .result import Result -from .result import result_tuple -from .result import ScalarResult -from .row import BaseRow -from .row import Row -from .row import RowMapping -from .url import make_url -from .url import URL -from .util import connection_memoize -from ..sql import ddl 
+from . import events as events +from . import util as util +from .base import Connection as Connection +from .base import Engine as Engine +from .base import NestedTransaction as NestedTransaction +from .base import RootTransaction as RootTransaction +from .base import Transaction as Transaction +from .base import TwoPhaseTransaction as TwoPhaseTransaction +from .create import create_engine as create_engine +from .create import engine_from_config as engine_from_config +from .cursor import BaseCursorResult as BaseCursorResult +from .cursor import BufferedColumnResultProxy as BufferedColumnResultProxy +from .cursor import BufferedColumnRow as BufferedColumnRow +from .cursor import BufferedRowResultProxy as BufferedRowResultProxy +from .cursor import CursorResult as CursorResult +from .cursor import FullyBufferedResultProxy as FullyBufferedResultProxy +from .cursor import ResultProxy as ResultProxy +from .interfaces import AdaptedConnection as AdaptedConnection +from .interfaces import BindTyping as BindTyping +from .interfaces import Compiled as Compiled +from .interfaces import CreateEnginePlugin as CreateEnginePlugin +from .interfaces import Dialect as Dialect +from .interfaces import ExceptionContext as ExceptionContext +from .interfaces import ExecutionContext as ExecutionContext +from .interfaces import TypeCompiler as TypeCompiler +from .mock import create_mock_engine as create_mock_engine +from .reflection import Inspector as Inspector +from .result import ChunkedIteratorResult as ChunkedIteratorResult +from .result import FrozenResult as FrozenResult +from .result import IteratorResult as IteratorResult +from .result import MappingResult as MappingResult +from .result import MergedResult as MergedResult +from .result import Result as Result +from .result import result_tuple as result_tuple +from .result import ScalarResult as ScalarResult +from .row import BaseRow as BaseRow +from .row import Row as Row +from .row import RowMapping as RowMapping +from .url 
import make_url as make_url +from .url import URL as URL +from .util import connection_memoize as connection_memoize +from ..sql import ddl as ddl diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 6fb827989..2f8ce17df 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -6,6 +6,7 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from typing import Any +from typing import Union from . import base from . import url as _url @@ -41,7 +42,7 @@ from ..sql import compiler "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url: "_url.URL", **kwargs: Any) -> "base.Engine": +def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL <database_urls>` as the diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index df7a53ab7..882392e9c 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -24,11 +24,13 @@ methods such as get_table_names, get_columns, etc. use the key 'name'. So for most return values, each record will have a 'name' attribute.. """ - import contextlib +from typing import List +from typing import Optional from .base import Connection from .base import Engine +from .interfaces import ReflectedColumn from .. import exc from .. import inspection from .. import sql @@ -433,7 +435,9 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, view_name, schema, info_cache=self.info_cache ) - def get_columns(self, table_name, schema=None, **kw): + def get_columns( + self, table_name: str, schema: Optional[str] = None, **kw + ) -> List[ReflectedColumn]: """Return information about columns in `table_name`. 
Given a string `table_name` and an optional string `schema`, return diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index e6a826c64..d5119907e 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -361,7 +361,7 @@ class AssociationProxyInstance: prop = orm.class_mapper(owning_class).get_property(target_collection) # this was never asserted before but this should be made clear. - if not isinstance(prop, orm.RelationshipProperty): + if not isinstance(prop, orm.Relationship): raise NotImplementedError( "association proxy to a non-relationship " "intermediary is not supported" @@ -717,8 +717,8 @@ class AssociationProxyInstance: """Produce a proxied 'any' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any` - and/or :meth:`.RelationshipProperty.Comparator.has` + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` operators of the underlying proxied attributes. """ @@ -737,8 +737,8 @@ class AssociationProxyInstance: """Produce a proxied 'has' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any` - and/or :meth:`.RelationshipProperty.Comparator.has` + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` operators of the underlying proxied attributes. """ @@ -859,9 +859,9 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): """Produce a proxied 'contains' expression using EXISTS. 
This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any`, - :meth:`.RelationshipProperty.Comparator.has`, - and/or :meth:`.RelationshipProperty.Comparator.contains` + using the :meth:`.Relationship.Comparator.any`, + :meth:`.Relationship.Comparator.has`, + and/or :meth:`.Relationship.Comparator.contains` operators of the underlying proxied attributes. """ diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py index 5aff4dfe2..470ff6ad8 100644 --- a/lib/sqlalchemy/ext/declarative/extensions.py +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -378,7 +378,7 @@ class DeferredReflection: metadata = mapper.class_.metadata for rel in mapper._props.values(): if ( - isinstance(rel, relationships.RelationshipProperty) + isinstance(rel, relationships.Relationship) and rel.secondary is not None ): if isinstance(rel.secondary, Table): diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 99be194cd..4e244b5b9 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -36,6 +36,7 @@ from mypy.types import UnionType from . import infer from . 
import util +from .names import expr_to_mapped_constructor from .names import NAMED_TYPE_SQLA_MAPPED @@ -117,6 +118,7 @@ def re_apply_declarative_assignments( ): left_node = stmt.lvalues[0].node + python_type_for_type = mapped_attr_lookup[ stmt.lvalues[0].name ].type @@ -142,7 +144,7 @@ def re_apply_declarative_assignments( ) ): - python_type_for_type = ( + new_python_type_for_type = ( infer.infer_type_from_right_hand_nameexpr( api, stmt, @@ -152,19 +154,27 @@ def re_apply_declarative_assignments( ) ) - if python_type_for_type is None or isinstance( - python_type_for_type, UnboundType + if new_python_type_for_type is not None and not isinstance( + new_python_type_for_type, UnboundType ): - continue + python_type_for_type = new_python_type_for_type - # update the SQLAlchemyAttribute with the better information - mapped_attr_lookup[ - stmt.lvalues[0].name - ].type = python_type_for_type + # update the SQLAlchemyAttribute with the better + # information + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type - update_cls_metadata = True + update_cls_metadata = True - if python_type_for_type is not None: + # for some reason if you have a Mapped type explicitly annotated, + # and here you set it again, mypy forgets how to do descriptors. + # no idea. 
100% feeling around in the dark to see what sticks + if ( + not isinstance(left_node.type, Instance) + or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED + ): + assert python_type_for_type is not None left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, [python_type_for_type] ) @@ -202,6 +212,7 @@ def apply_type_to_mapped_statement( assert isinstance(left_node, Var) if left_hand_explicit_type is not None: + lvalue.is_inferred_def = False left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] ) @@ -224,7 +235,7 @@ def apply_type_to_mapped_statement( # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>) # the original right-hand side is maintained so it gets type checked # internally - stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue) + stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue) def add_additional_orm_attributes( diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index c33c30e25..bd6c6f41e 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -337,7 +337,7 @@ def _scan_declarative_decorator_stmt( # <attr> : Mapped[<typ>] = # _sa_Mapped._empty_constructor(lambda: <function body>) # the function body is maintained so it gets type checked internally - rvalue = util.expr_to_mapped_constructor( + rvalue = names.expr_to_mapped_constructor( LambdaExpr(stmt.func.arguments, stmt.func.body) ) diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index 3cd946e04..6a5e99e48 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -42,11 +42,13 @@ def infer_type_from_right_hand_nameexpr( left_hand_explicit_type: Optional[ProperType], infer_from_right_side: RefExpr, ) -> Optional[ProperType]: - type_id = names.type_id_for_callee(infer_from_right_side) - if type_id is None: return None + elif type_id is names.MAPPED: + python_type_for_type = _infer_type_from_mapped( + api, 
stmt, node, left_hand_explicit_type, infer_from_right_side + ) elif type_id is names.COLUMN: python_type_for_type = _infer_type_from_decl_column( api, stmt, node, left_hand_explicit_type @@ -245,7 +247,7 @@ def _infer_type_from_decl_composite_property( node: Var, left_hand_explicit_type: Optional[ProperType], ) -> Optional[ProperType]: - """Infer the type of mapping from a CompositeProperty.""" + """Infer the type of mapping from a Composite.""" assert isinstance(stmt.rvalue, CallExpr) target_cls_arg = stmt.rvalue.args[0] @@ -271,6 +273,38 @@ def _infer_type_from_decl_composite_property( return python_type_for_type +def _infer_type_from_mapped( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + """Infer the type of mapping from a right side expression + that returns Mapped. + + + """ + assert isinstance(stmt.rvalue, CallExpr) + + # (Pdb) print(stmt.rvalue.callee) + # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501 + # (Pdb) stmt.rvalue.callee.node + # <mypy.nodes.FuncDef object at 0x7f8d92fb5940> + # (Pdb) stmt.rvalue.callee.node.type + # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501 + # sqlalchemy.orm.base.Mapped[_T`-1] + # the_mapped_type = stmt.rvalue.callee.node.type.ret_type + + # TODO: look at generic ref and either use that, + # or reconcile w/ what's present, etc. 
+ the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + def _infer_type_from_decl_column_property( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index b6f911979..ad4449e5b 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -12,11 +12,14 @@ from typing import Set from typing import Tuple from typing import Union +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Expression from mypy.nodes import FuncDef from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import OverloadedFuncDef from mypy.nodes import SymbolNode from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo @@ -51,7 +54,7 @@ QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore NAMED_TYPE_BUILTINS_OBJECT = "builtins.object" NAMED_TYPE_BUILTINS_STR = "builtins.str" NAMED_TYPE_BUILTINS_LIST = "builtins.list" -NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped" +NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" _lookup: Dict[str, Tuple[int, Set[str]]] = { "Column": ( @@ -61,11 +64,11 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.sql.Column", }, ), - "RelationshipProperty": ( + "Relationship": ( RELATIONSHIP, { - "sqlalchemy.orm.relationships.RelationshipProperty", - "sqlalchemy.orm.RelationshipProperty", + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.Relationship", }, ), "registry": ( @@ -82,18 +85,18 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.orm.ColumnProperty", }, ), - "SynonymProperty": ( + "Synonym": ( SYNONYM_PROPERTY, { - "sqlalchemy.orm.descriptor_props.SynonymProperty", - "sqlalchemy.orm.SynonymProperty", + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", }, 
), - "CompositeProperty": ( + "Composite": ( COMPOSITE_PROPERTY, { - "sqlalchemy.orm.descriptor_props.CompositeProperty", - "sqlalchemy.orm.CompositeProperty", + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", }, ), "MapperProperty": ( @@ -159,7 +162,10 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { ), "query_expression": ( QUERY_EXPRESSION, - {"sqlalchemy.orm.query_expression"}, + { + "sqlalchemy.orm.query_expression", + "sqlalchemy.orm._orm_constructors.query_expression", + }, ), } @@ -209,7 +215,19 @@ def type_id_for_unbound_type( def type_id_for_callee(callee: Expression) -> Optional[int]: if isinstance(callee, (MemberExpr, NameExpr)): - if isinstance(callee.node, FuncDef): + if isinstance(callee.node, OverloadedFuncDef): + if ( + callee.node.impl + and callee.node.impl.type + and isinstance(callee.node.impl.type, CallableType) + ): + ret_type = get_proper_type(callee.node.impl.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, FuncDef): if callee.node.type and isinstance(callee.node.type, CallableType): ret_type = get_proper_type(callee.node.type.ret_type) @@ -251,3 +269,15 @@ def type_id_for_fullname(fullname: str) -> Optional[int]: return type_id else: return None + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 0a21feb51..c9520fef3 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -40,6 +40,19 @@ from . import decl_class from . import names from . 
import util +try: + import sqlalchemy_stubs # noqa +except ImportError: + pass +else: + import sqlalchemy + + raise ImportError( + f"The SQLAlchemy mypy plugin in SQLAlchemy " + f"{sqlalchemy.__version__} does not work with sqlalchemy-stubs or " + "sqlalchemy2-stubs installed" + ) + class SQLAlchemyPlugin(Plugin): def get_dynamic_class_hook( diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index fa42074c3..741772eac 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -10,24 +10,27 @@ from typing import Type as TypingType from typing import TypeVar from typing import Union -from mypy.nodes import ARG_POS from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import Expression +from mypy.nodes import FuncDef from mypy.nodes import IfStmt from mypy.nodes import JsonDict from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import Statement from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo from mypy.plugin import ClassDefContext from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type from mypy.typeops import map_type_from_supertype +from mypy.types import CallableType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType from mypy.types import Type @@ -231,6 +234,25 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: yield stmt +def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, 
Instance): + return ret_type + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return target_type + elif isinstance(callee.node, TypeInfo): + return callee.node + return None + + def unbound_to_instance( api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: @@ -290,15 +312,3 @@ def info_for_cls( return sym.node return cls.info - - -def expr_to_mapped_constructor(expr: Expression) -> CallExpr: - column_descriptor = NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" - member_expr = MemberExpr(column_descriptor, "_empty_constructor") - return CallExpr( - member_expr, - [expr], - [ARG_POS], - ["arg1"], - ) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 5a327d1a5..5384851b1 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -119,14 +119,28 @@ start numbering at 1 or some other integer, provide ``count_from=1``. """ +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import TypeVar + from ..orm.collections import collection from ..orm.collections import collection_adapter +_T = TypeVar("_T") +OrderingFunc = Callable[[int, Sequence[_T]], int] + __all__ = ["ordering_list"] -def ordering_list(attr, count_from=None, **kw): +def ordering_list( + attr: str, + count_from: Optional[int] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, +) -> Callable[[], "OrderingList"]: """Prepares an :class:`OrderingList` factory for use in mapper definitions. 
Returns an object suitable for use as an argument to a Mapper @@ -157,7 +171,11 @@ def ordering_list(attr, count_from=None, **kw): """ - kw = _unsugar_count_from(count_from=count_from, **kw) + kw = _unsugar_count_from( + count_from=count_from, + ordering_func=ordering_func, + reorder_on_append=reorder_on_append, + ) return lambda: OrderingList(attr, **kw) @@ -207,7 +225,7 @@ def _unsugar_count_from(**kw): return kw -class OrderingList(list): +class OrderingList(List[_T]): """A custom list that manages position information for its children. The :class:`.OrderingList` object is normally set up using the @@ -216,8 +234,15 @@ class OrderingList(list): """ + ordering_attr: str + ordering_func: OrderingFunc + reorder_on_append: bool + def __init__( - self, ordering_attr=None, ordering_func=None, reorder_on_append=False + self, + ordering_attr: Optional[str] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, ): """A custom list that manages position information for its children. @@ -282,7 +307,7 @@ class OrderingList(list): def _set_order_value(self, entity, value): setattr(entity, self.ordering_attr, value) - def reorder(self): + def reorder(self) -> None: """Synchronize ordering for the entire collection. 
Sweeps through the list and ensures that each object has accurate diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 885163ecb..c6a8b6ea7 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -74,6 +74,8 @@ def class_logger(cls: Type[_IT]) -> Type[_IT]: class Identified: + __slots__ = () + logging_name: Optional[str] = None logger: Union[logging.Logger, "InstanceLogger"] @@ -116,6 +118,8 @@ class InstanceLogger: _echo: _EchoFlagType + __slots__ = ("echo", "logger") + def __init__(self, echo: _EchoFlagType, name: str): self.echo = echo self.logger = logging.getLogger(name) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 55f2f3100..bbed93310 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -17,19 +17,27 @@ from . import exc as exc from . import mapper as mapperlib from . import strategy_options as strategy_options from ._orm_constructors import _mapper_fn as mapper +from ._orm_constructors import aliased as aliased from ._orm_constructors import backref as backref from ._orm_constructors import clear_mappers as clear_mappers from ._orm_constructors import column_property as column_property from ._orm_constructors import composite as composite +from ._orm_constructors import CompositeProperty as CompositeProperty from ._orm_constructors import contains_alias as contains_alias from ._orm_constructors import create_session as create_session from ._orm_constructors import deferred as deferred from ._orm_constructors import dynamic_loader as dynamic_loader +from ._orm_constructors import join as join from ._orm_constructors import mapped_column as mapped_column +from ._orm_constructors import MappedColumn as MappedColumn +from ._orm_constructors import outerjoin as outerjoin from ._orm_constructors import query_expression as query_expression from ._orm_constructors import relationship as relationship +from ._orm_constructors import RelationshipProperty as RelationshipProperty 
from ._orm_constructors import synonym as synonym +from ._orm_constructors import SynonymProperty as SynonymProperty from ._orm_constructors import with_loader_criteria as with_loader_criteria +from ._orm_constructors import with_polymorphic as with_polymorphic from .attributes import AttributeEvent as AttributeEvent from .attributes import InstrumentedAttribute as InstrumentedAttribute from .attributes import QueryableAttribute as QueryableAttribute @@ -46,8 +54,8 @@ from .decl_api import declared_attr as declared_attr from .decl_api import has_inherited_table as has_inherited_table from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for -from .descriptor_props import CompositeProperty as CompositeProperty -from .descriptor_props import SynonymProperty as SynonymProperty +from .descriptor_props import Composite as Composite +from .descriptor_props import Synonym as Synonym from .dynamic import AppenderQuery as AppenderQuery from .events import AttributeEvents as AttributeEvents from .events import InstanceEvents as InstanceEvents @@ -81,7 +89,7 @@ from .query import AliasOption as AliasOption from .query import FromStatement as FromStatement from .query import Query as Query from .relationships import foreign as foreign -from .relationships import RelationshipProperty as RelationshipProperty +from .relationships import Relationship as Relationship from .relationships import remote as remote from .scoping import scoped_session as scoped_session from .session import close_all_sessions as close_all_sessions @@ -111,17 +119,13 @@ from .strategy_options import undefer as undefer from .strategy_options import undefer_group as undefer_group from .strategy_options import with_expression as with_expression from .unitofwork import UOWTransaction as UOWTransaction -from .util import aliased as aliased from .util import Bundle as Bundle from .util import CascadeOptions as CascadeOptions -from .util import join as join from .util import 
LoaderCriteriaOption as LoaderCriteriaOption from .util import object_mapper as object_mapper -from .util import outerjoin as outerjoin from .util import polymorphic_union as polymorphic_union from .util import was_deleted as was_deleted from .util import with_parent as with_parent -from .util import with_polymorphic as with_polymorphic from .. import util as _sa_util diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 80607670e..a1f1faa05 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -7,35 +7,52 @@ import typing from typing import Any -from typing import Callable from typing import Collection +from typing import List +from typing import Mapping from typing import Optional from typing import overload +from typing import Set from typing import Type from typing import Union from . import mapper as mapperlib from .base import Mapped -from .descriptor_props import CompositeProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .mapper import Mapper from .properties import ColumnProperty +from .properties import MappedColumn from .query import AliasOption -from .relationships import RelationshipProperty +from .relationships import _RelationshipArgumentType +from .relationships import Relationship from .session import Session +from .util import _ORMJoin +from .util import AliasedClass +from .util import AliasedInsp from .util import LoaderCriteriaOption from .. import sql from .. 
import util from ..exc import InvalidRequestError -from ..sql.schema import Column -from ..sql.schema import SchemaEventTarget +from ..sql.base import SchemaEventTarget +from ..sql.selectable import Alias +from ..sql.selectable import FromClause from ..sql.type_api import TypeEngine from ..util.typing import Literal - -_RC = typing.TypeVar("_RC") _T = typing.TypeVar("_T") +CompositeProperty = Composite +"""Alias for :class:`_orm.Composite`.""" + +RelationshipProperty = Relationship +"""Alias for :class:`_orm.Relationship`.""" + +SynonymProperty = Synonym +"""Alias for :class:`_orm.Synonym`.""" + + @util.deprecated( "1.4", "The :class:`.AliasOption` object is not necessary " @@ -51,35 +68,45 @@ def contains_alias(alias) -> "AliasOption": return AliasOption(alias) +# see test/ext/mypy/plain_files/mapped_column.py for mapped column +# typing tests + + @overload def mapped_column( + __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., + nullable: Literal[None] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped": +) -> "MappedColumn[Any]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[None] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[Any]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[True] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[Optional[_T]]": ... 
@@ -87,45 +114,48 @@ def mapped_column( def mapped_column( __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[False]] = ..., - primary_key: Literal[True] = True, + nullable: Literal[True] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[Optional[_T]]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, nullable: Literal[False] = ..., - primary_key: bool = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( - __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[False] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( - __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: bool = ..., + primary_key: Literal[True] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[_T]": ... @@ -134,55 +164,209 @@ def mapped_column( __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[False]] = ..., - primary_key: Literal[True] = True, + nullable: bool = ..., + primary_key: Literal[True] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[_T]": ... 
@overload def mapped_column( __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Literal[False] = ..., + nullable: bool = ..., primary_key: bool = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[Any]": ... -def mapped_column(*args, **kw) -> "Mapped": - """construct a new ORM-mapped :class:`_schema.Column` construct. +@overload +def mapped_column( + *args: SchemaEventTarget, + nullable: bool = ..., + primary_key: bool = ..., + deferred: bool = ..., + **kw: Any, +) -> "MappedColumn[Any]": + ... + + +def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": + r"""construct a new ORM-mapped :class:`_schema.Column` construct. + + The :func:`_orm.mapped_column` function provides an ORM-aware and + Python-typing-compatible construct which is used with + :ref:`declarative <orm_declarative_mapping>` mappings to indicate an + attribute that's mapped to a Core :class:`_schema.Column` object. It + provides the equivalent feature as mapping an attribute to a + :class:`_schema.Column` object directly when using declarative. + + .. versionadded:: 2.0 - The :func:`_orm.mapped_column` function is shorthand for the construction - of a Core :class:`_schema.Column` object delivered within a - :func:`_orm.column_property` construct, which provides for consistent - typing information to be delivered to the class so that it works under - static type checkers such as mypy and delivers useful information in - IDE related type checkers such as pylance. 
The function can be used - in declarative mappings anywhere that :class:`_schema.Column` is normally - used:: + :func:`_orm.mapped_column` is normally used with explicit typing along with + the :class:`_orm.Mapped` mapped attribute type, where it can derive the SQL + type and nullability for the column automatically, such as:: + from typing import Optional + + from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column class User(Base): __tablename__ = 'user' - id = mapped_column(Integer) - name = mapped_column(String) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + options: Mapped[Optional[str]] = mapped_column() + + In the above example, the ``int`` and ``str`` types are inferred by the + Declarative mapping system to indicate use of the :class:`_types.Integer` + and :class:`_types.String` datatypes, and the presence of ``Optional`` or + not indicates whether or not each non-primary-key column is to be + ``nullable=True`` or ``nullable=False``. + + The above example, when interpreted within a Declarative class, will result + in a table named ``"user"`` which is equivalent to the following:: + + from sqlalchemy import Integer + from sqlalchemy import String + from sqlalchemy import Table + + Table( + 'user', + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String, nullable=False), + Column("options", String, nullable=True), + ) + The :func:`_orm.mapped_column` construct accepts the same arguments as + that of :class:`_schema.Column` directly, including optional "name" + and "type" fields, so the above mapping can be stated more explicitly + as:: - .. 
versionadded:: 2.0
+        from typing import Optional
+
+        from sqlalchemy import Integer
+        from sqlalchemy import String
+        from sqlalchemy.orm import Mapped
+        from sqlalchemy.orm import mapped_column
+
+        class User(Base):
+            __tablename__ = 'user'
+
+            id: Mapped[int] = mapped_column("id", Integer, primary_key=True)
+            name: Mapped[str] = mapped_column("name", String, nullable=False)
+            options: Mapped[Optional[str]] = mapped_column(
+                "options", String, nullable=True
+            )
+
+    Arguments passed to :func:`_orm.mapped_column` always supersede those which
+    would be derived from the type annotation and/or attribute name. To state
+    the above mapping with more specific datatypes for ``id`` and ``options``,
+    and a different column name for ``name``, looks like::
+
+        from sqlalchemy import BigInteger
+
+        class User(Base):
+            __tablename__ = 'user'
+
+            id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True)
+            name: Mapped[str] = mapped_column("user_name")
+            options: Mapped[Optional[str]] = mapped_column(String(50))
+
+    Where again, datatypes and nullable parameters that can be automatically
+    derived may be omitted.
+
+    The datatypes passed to :class:`_orm.Mapped` are mapped to SQL
+    :class:`_types.TypeEngine` types with the following default mapping::
+
+        _type_map = {
+            int: Integer(),
+            float: Float(),
+            bool: Boolean(),
+            decimal.Decimal: Numeric(),
+            dt.date: Date(),
+            dt.datetime: DateTime(),
+            dt.time: Time(),
+            dt.timedelta: Interval(),
+            util.NoneType: NULLTYPE,
+            bytes: LargeBinary(),
+            str: String(),
+        }
+
+    The above mapping may be expanded to include any combination of Python
+    datatypes to SQL types by using the
+    :paramref:`_orm.registry.type_annotation_map` parameter to
+    :class:`_orm.registry`, or as the attribute ``type_annotation_map`` upon
+    the :class:`_orm.DeclarativeBase` base class.
+ + Finally, :func:`_orm.mapped_column` is implicitly used by the Declarative + mapping system for any :class:`_orm.Mapped` annotation that has no + attribute value set up. This is much in the way that Python dataclasses + allow the ``field()`` construct to be optional, only needed when additional + parameters should be associated with the field. Using this functionality, + our original mapping can be stated even more succinctly as:: + + from typing import Optional + + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class User(Base): + __tablename__ = 'user' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + options: Mapped[Optional[str]] + + Above, the ``name`` and ``options`` columns will be evaluated as + ``Column("name", String, nullable=False)`` and + ``Column("options", String, nullable=True)``, respectively. + + :param __name: String name to give to the :class:`_schema.Column`. This + is an optional, positional only argument that if present must be the + first positional argument passed. If omitted, the attribute name to + which the :func:`_orm.mapped_column` is mapped will be used as the SQL + column name. + :param __type: :class:`_types.TypeEngine` type or instance which will + indicate the datatype to be associated with the :class:`_schema.Column`. + This is an optional, positional-only argument that if present must + immediately follow the ``__name`` parameter if present also, or otherwise + be the first positional parameter. If omitted, the ultimate type for + the column may be derived either from the annotated type, or if a + :class:`_schema.ForeignKey` is present, from the datatype of the + referenced column. + :param \*args: Additional positional arguments include constructs such + as :class:`_schema.ForeignKey`, :class:`_schema.CheckConstraint`, + and :class:`_schema.Identity`, which are passed through to the constructed + :class:`_schema.Column`. 
+    :param nullable: Optional bool, whether the column should be "NULL" or
+      "NOT NULL". If omitted, the nullability is derived from the type
+      annotation based on whether or not ``typing.Optional`` is present.
+      ``nullable`` defaults to ``True`` for non-primary key columns, and
+      ``False`` for primary key columns.
+    :param primary_key: optional bool, indicates the :class:`_schema.Column`
+      would be part of the table's primary key or not.
+    :param deferred: Optional bool - this keyword argument is consumed by the
+      ORM declarative process, and is not part of the :class:`_schema.Column`
+      itself; instead, it indicates that this column should be "deferred" for
+      loading as though mapped by :func:`_orm.deferred`.
+    :param \**kw: All remaining keyword arguments are passed through to the
+      constructor for the :class:`_schema.Column`.

    """
-    return column_property(Column(*args, **kw))
+
+    return MappedColumn(*args, **kw)


def column_property(
    column: sql.ColumnElement[_T], *additional_columns, **kwargs
-) -> "Mapped[_T]":
+) -> "ColumnProperty[_T]":
    r"""Provide a column-level property for use with a mapping.

    Column-based properties can normally be applied to the mapper's
@@ -269,22 +453,49 @@ def column_property(
    return ColumnProperty(column, *additional_columns, **kwargs)


-def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]":
+@overload
+def composite(
+    class_: Type[_T],
+    *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+    **kwargs: Any,
+) -> "Composite[_T]":
+    ...
+
+
+@overload
+def composite(
+    *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+    **kwargs: Any,
+) -> "Composite[Any]":
+    ...
+
+
+def composite(
+    class_: Any = None,
+    *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+    **kwargs: Any,
+) -> "Composite[Any]":
    r"""Return a composite column-based property for use with a Mapper.

    See the mapping documentation section :ref:`mapper_composite` for a
    full usage example.
The :class:`.MapperProperty` returned by :func:`.composite` - is the :class:`.CompositeProperty`. + is the :class:`.Composite`. :param class\_: The "composite type" class, or any classmethod or callable which will produce a new instance of the composite object given the column values in order. - :param \*cols: - List of Column objects to be mapped. + :param \*attrs: + List of elements to be mapped, which may include: + + * :class:`_schema.Column` objects + * :func:`_orm.mapped_column` constructs + * string names of other attributes on the mapped class, which may be + any other SQL or object-mapped attribute. This can for + example allow a composite that refers to a many-to-one relationship :param active_history=False: When ``True``, indicates that the "previous" value for a @@ -301,7 +512,7 @@ def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": :func:`~sqlalchemy.orm.deferred`. :param comparator_factory: a class which extends - :class:`.CompositeProperty.Comparator` which provides custom SQL + :class:`.Composite.Comparator` which provides custom SQL clause generation for comparison operations. :param doc: @@ -312,7 +523,7 @@ def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": :attr:`.MapperProperty.info` attribute of this object. """ - return CompositeProperty(class_, *attrs, **kwargs) + return Composite(class_, *attrs, **kwargs) def with_loader_criteria( @@ -500,143 +711,140 @@ def with_loader_criteria( @overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], + secondary=None, + *, + uselist: Literal[False] = None, + collection_class: Literal[None] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[_T]: + ... 
+ + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, uselist: Literal[True] = None, + collection_class: Literal[None] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - order_by=False, - backref=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[Collection[_RC]]: + **kw: Any, +) -> Relationship[List[_T]]: ... 
@overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, - uselist: Optional[bool] = None, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[List] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - order_by=False, - backref=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[_RC]: + **kw: Any, +) -> Relationship[List[_T]]: ... +@overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[Set] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, + back_populates=None, + **kw: Any, +) -> Relationship[Set[_T]]: + ... 
+ + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]], + secondary=None, + *, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[Mapping[Any, Any]] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Mapping[Any, _T]]: + ... + + +@overload +def relationship( + argument: _RelationshipArgumentType[_T], + secondary=None, + *, + uselist: Literal[None] = None, + collection_class: Literal[None] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary=None, + *, + uselist: Literal[True] = None, + collection_class: Any = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def relationship( + argument: Literal[None] = None, + secondary=None, + *, uselist: Optional[bool] = None, - order_by=False, - backref=None, + collection_class: Any = None, + primaryjoin=None, + secondaryjoin=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - 
_legacy_inactive_history_style=False, -) -> Mapped[_RC]: + **kw: Any, +) -> Relationship[Any]: + ... + + +def relationship( + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary=None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[Type[Collection]] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: """Provide a relationship between two mapped classes. This corresponds to a parent-child or associative table relationship. The constructed class is an instance of - :class:`.RelationshipProperty`. + :class:`.Relationship`. A typical :func:`_orm.relationship`, used in a classical mapping:: @@ -897,7 +1105,7 @@ def relationship( examples. :param comparator_factory: - A class which extends :class:`.RelationshipProperty.Comparator` + A class which extends :class:`.Relationship.Comparator` which provides custom SQL clause generation for comparison operations. @@ -1447,42 +1655,15 @@ def relationship( """ - return RelationshipProperty( + return Relationship( argument, - secondary, - primaryjoin, - secondaryjoin, - foreign_keys, - uselist, - order_by, - backref, - back_populates, - overlaps, - post_update, - cascade, - viewonly, - lazy, - collection_class, - passive_deletes, - passive_updates, - remote_side, - enable_typechecks, - join_depth, - comparator_factory, - single_parent, - innerjoin, - distinct_target_key, - doc, - active_history, - cascade_backrefs, - load_on_pending, - bake_queries, - _local_remote_pairs, - query_class, - info, - omit_join, - sync_backref, - _legacy_inactive_history_style, + secondary=secondary, + uselist=uselist, + collection_class=collection_class, + primaryjoin=primaryjoin, + secondaryjoin=secondaryjoin, + back_populates=back_populates, + **kw, ) @@ -1493,7 +1674,7 @@ def synonym( comparator_factory=None, doc=None, info=None, -) -> "Mapped": +) -> "Synonym[Any]": """Denote an attribute name as a synonym to a mapped property, in that the attribute 
will mirror the value and expression behavior of another attribute. @@ -1597,9 +1778,7 @@ def synonym( than can be achieved with synonyms. """ - return SynonymProperty( - name, map_column, descriptor, comparator_factory, doc, info - ) + return Synonym(name, map_column, descriptor, comparator_factory, doc, info) def create_session(bind=None, **kwargs): @@ -1733,7 +1912,9 @@ def deferred(*columns, **kw): return ColumnProperty(deferred=True, *columns, **kw) -def query_expression(default_expr=sql.null()): +def query_expression( + default_expr: sql.ColumnElement[_T] = sql.null(), +) -> "Mapped[_T]": """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -1787,3 +1968,273 @@ def clear_mappers(): """ mapperlib._dispose_registries(mapperlib._all_registries(), False) + + +@overload +def aliased( + element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> "AliasedClass[_T]": + ... + + +@overload +def aliased( + element: "FromClause", + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> "Alias": + ... + + +def aliased( + element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"], + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> Union["AliasedClass[_T]", "Alias"]: + """Produce an alias of the given element, usually an :class:`.AliasedClass` + instance. + + E.g.:: + + my_alias = aliased(MyClass) + + session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) + + The :func:`.aliased` function is used to create an ad-hoc mapping of a + mapped class to a new selectable. By default, a selectable is generated + from the normally mapped selectable (typically a :class:`_schema.Table` + ) using the + :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` + can also be + used to link the class to a new :func:`_expression.select` statement. 
+ Also, the :func:`.with_polymorphic` function is a variant of + :func:`.aliased` that is intended to specify a so-called "polymorphic + selectable", that corresponds to the union of several joined-inheritance + subclasses at once. + + For convenience, the :func:`.aliased` function also accepts plain + :class:`_expression.FromClause` constructs, such as a + :class:`_schema.Table` or + :func:`_expression.select` construct. In those cases, the + :meth:`_expression.FromClause.alias` + method is called on the object and the new + :class:`_expression.Alias` object returned. The returned + :class:`_expression.Alias` is not + ORM-mapped in this case. + + .. seealso:: + + :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` + + :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` + + :ref:`ormtutorial_aliases` - in the legacy :ref:`ormtutorial_toplevel` + + :param element: element to be aliased. Is normally a mapped class, + but for convenience can also be a :class:`_expression.FromClause` + element. + + :param alias: Optional selectable unit to map the element to. This is + usually used to link the object to a subquery, and should be an aliased + select construct as one would produce from the + :meth:`_query.Query.subquery` method or + the :meth:`_expression.Select.subquery` or + :meth:`_expression.Select.alias` methods of the :func:`_expression.select` + construct. + + :param name: optional string name to use for the alias, if not specified + by the ``alias`` parameter. The name, among other things, forms the + attribute name that will be accessible via tuples returned by a + :class:`_query.Query` object. Not supported when creating aliases + of :class:`_sql.Join` objects. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. 
This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. + + :param adapt_on_names: if True, more liberal "matching" will be used when + mapping the mapped columns of the ORM entity to those of the + given selectable - a name-based match will be performed if the + given selectable doesn't otherwise have a column that corresponds + to one on the entity. The use case for this is when associating + an entity with some derived selectable such as one that uses + aggregate functions:: + + class UnitPrice(Base): + __tablename__ = 'unit_price' + ... + unit_id = Column(Integer) + price = Column(Numeric) + + aggregated_unit_price = Session.query( + func.sum(UnitPrice.price).label('price') + ).group_by(UnitPrice.unit_id).subquery() + + aggregated_unit_price = aliased(UnitPrice, + alias=aggregated_unit_price, adapt_on_names=True) + + Above, functions on ``aggregated_unit_price`` which refer to + ``.price`` will return the + ``func.sum(UnitPrice.price).label('price')`` column, as it is + matched on the name "price". Ordinarily, the "price" function + wouldn't have any "column correspondence" to the actual + ``UnitPrice.price`` column as it is not a proxy of the original. + + """ + return AliasedInsp._alias_factory( + element, + alias=alias, + name=name, + flat=flat, + adapt_on_names=adapt_on_names, + ) + + +def with_polymorphic( + base, + classes, + selectable=False, + flat=False, + polymorphic_on=None, + aliased=False, + innerjoin=False, + _use_mapper_path=False, +): + """Produce an :class:`.AliasedClass` construct which specifies + columns for descendant mappers of the given base. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + .. 
seealso:: + + :ref:`with_polymorphic` - full discussion of + :func:`_orm.with_polymorphic`. + + :param base: Base class to be aliased. + + :param classes: a single class or mapper, or list of + class/mappers, which inherit from the base class. + Alternatively, it may also be the string ``'*'``, in which case + all descending mapped classes will be added to the FROM clause. + + :param aliased: when True, the selectable will be aliased. For a + JOIN, this means the JOIN will be SELECTed from inside of a subquery + unless the :paramref:`_orm.with_polymorphic.flat` flag is set to + True, which is recommended for simpler use cases. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. Setting this flag is + recommended as long as the resulting SQL is functional. + + :param selectable: a table or subquery that will + be used in place of the generated FROM clause. This argument is + required if any of the desired classes use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` argument + must represent the full set of tables and columns mapped by every + mapped class. Otherwise, the unaccounted mapped columns will + result in their table being appended directly to the FROM clause + which will usually lead to incorrect results. + + When left at its default value of ``False``, the polymorphic + selectable assigned to the base mapper is used for selecting rows. 
+ However, it may also be passed as ``None``, which will bypass the + configured polymorphic selectable and instead construct an ad-hoc + selectable for the target classes given; for joined table inheritance + this will be a join that includes all target mappers and their + subclasses. + + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only + """ + return AliasedInsp._with_polymorphic_factory( + base, + classes, + selectable=selectable, + flat=flat, + polymorphic_on=polymorphic_on, + aliased=aliased, + innerjoin=innerjoin, + _use_mapper_path=_use_mapper_path, + ) + + +def join( + left, right, onclause=None, isouter=False, full=False, join_to_left=None +): + r"""Produce an inner join between left and right clauses. + + :func:`_orm.join` is an extension to the core join interface + provided by :func:`_expression.join()`, where the + left and right selectables may be not only core selectable + objects such as :class:`_schema.Table`, but also mapped classes or + :class:`.AliasedClass` instances. The "on" clause can + be a SQL expression, or an attribute or string name + referencing a configured :func:`_orm.relationship`. + + :func:`_orm.join` is not commonly needed in modern usage, + as its functionality is encapsulated within that of the + :meth:`_query.Query.join` method, which features a + significant amount of automation beyond :func:`_orm.join` + by itself. 
Explicit usage of :func:`_orm.join` + with :class:`_query.Query` involves usage of the + :meth:`_query.Query.select_from` method, as in:: + + from sqlalchemy.orm import join + session.query(User).\ + select_from(join(User, Address, User.addresses)).\ + filter(Address.email_address=='foo@bar.com') + + In modern SQLAlchemy the above join can be written more + succinctly as:: + + session.query(User).\ + join(User.addresses).\ + filter(Address.email_address=='foo@bar.com') + + See :meth:`_query.Query.join` for information on modern usage + of ORM level joins. + + .. deprecated:: 0.8 + + the ``join_to_left`` parameter is deprecated, and will be removed + in a future release. The parameter has no effect. + + """ + return _ORMJoin(left, right, onclause, isouter, full) + + +def outerjoin(left, right, onclause=None, full=False, join_to_left=None): + """Produce a left outer join between left and right clauses. + + This is the "outer join" version of the :func:`_orm.join` function, + featuring the same behavior except that an OUTER JOIN is generated. + See that function's documentation for other usage details. + + """ + return _ORMJoin(left, right, onclause, True, full) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 5a605b7c6..fbfb2b2ee 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -35,6 +35,7 @@ from .base import instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED from .base import manager_of_class +from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa from .base import NO_AUTOFLUSH from .base import NO_CHANGE # noqa @@ -79,6 +80,7 @@ class QueryableAttribute( traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, + roles.ColumnsClauseRole, sql_base.Immutable, sql_base.MemoizedHasCacheKey, ): @@ -190,7 +192,7 @@ class QueryableAttribute( construct has defined one). 
* If the attribute refers to any other kind of - :class:`.MapperProperty`, including :class:`.RelationshipProperty`, + :class:`.MapperProperty`, including :class:`.Relationship`, the attribute will refer to the :attr:`.MapperProperty.info` dictionary associated with that :class:`.MapperProperty`. @@ -352,7 +354,7 @@ class QueryableAttribute( Return values here will commonly be instances of - :class:`.ColumnProperty` or :class:`.RelationshipProperty`. + :class:`.ColumnProperty` or :class:`.Relationship`. """ diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 7ab4b7737..e6d4a6729 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -12,8 +12,11 @@ import operator import typing from typing import Any +from typing import Callable from typing import Generic +from typing import Optional from typing import overload +from typing import Tuple from typing import TypeVar from typing import Union @@ -22,8 +25,9 @@ from .. import exc as sa_exc from .. import inspection from .. 
import util from ..sql.elements import SQLCoreOperations -from ..util import typing as compat_typing from ..util.langhelpers import TypingOnly +from ..util.typing import Concatenate +from ..util.typing import ParamSpec if typing.TYPE_CHECKING: @@ -32,6 +36,9 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T", bound=Any) +_IdentityKeyType = Tuple[type, Tuple[Any, ...], Optional[str]] + + PASSIVE_NO_RESULT = util.symbol( "PASSIVE_NO_RESULT", """Symbol returned by a loader callable or other attribute/history @@ -236,16 +243,16 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_Fn = typing.TypeVar("_Fn", bound=typing.Callable) -_Args = compat_typing.ParamSpec("_Args") -_Self = typing.TypeVar("_Self") +_Fn = TypeVar("_Fn", bound=Callable) +_Args = ParamSpec("_Args") +_Self = TypeVar("_Self") def _assertions( - *assertions, -) -> typing.Callable[ - [typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self]], - typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self], + *assertions: Any, +) -> Callable[ + [Callable[Concatenate[_Self, _Fn, _Args], _Self]], + Callable[Concatenate[_Self, _Fn, _Args], _Self], ]: @util.decorator def generate( @@ -605,8 +612,8 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): ... -class Mapped(Generic[_T], util.TypingOnly): - """Represent an ORM mapped attribute for typing purposes. +class Mapped(Generic[_T], TypingOnly): + """Represent an ORM mapped attribute on a mapped class. This class represents the complete descriptor interface for any class attribute that will have been :term:`instrumented` by the ORM @@ -650,7 +657,7 @@ class Mapped(Generic[_T], util.TypingOnly): ... @classmethod - def _empty_constructor(cls, arg1: Any) -> "SQLORMOperations[_T]": + def _empty_constructor(cls, arg1: Any) -> "Mapped[_T]": ... 
@overload diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 3bf7ddde8..c24b3c696 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -10,11 +10,14 @@ This system allows specification of classes and expressions used in :func:`_orm.relationship` using strings. """ +import re +from typing import MutableMapping +from typing import Union import weakref from . import attributes from . import interfaces -from .descriptor_props import SynonymProperty +from .descriptor_props import Synonym from .properties import ColumnProperty from .util import class_mapper from .. import exc @@ -22,6 +25,8 @@ from .. import inspection from .. import util from ..sql.schema import _get_table_key +_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] + # strong references to registries which we place in # the _decl_class_registry, which is usually weak referencing. # the internal registries here link to classes with weakrefs and remove @@ -118,7 +123,13 @@ def _key_is_empty(key, decl_class_registry, test): return not test(thing) -class _MultipleClassMarker: +class ClsRegistryToken: + """an object that can be in the registry._class_registry as a value.""" + + __slots__ = () + + +class _MultipleClassMarker(ClsRegistryToken): """refers to multiple classes of the same name within _decl_class_registry. @@ -182,7 +193,7 @@ class _MultipleClassMarker: self.contents.add(weakref.ref(item, self._remove_item)) -class _ModuleMarker: +class _ModuleMarker(ClsRegistryToken): """Refers to a module name within _decl_class_registry. 
@@ -281,7 +292,7 @@ class _GetColumns: desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NOT_EXTENSION: prop = desc.property - if isinstance(prop, SynonymProperty): + if isinstance(prop, Synonym): key = prop.name elif not isinstance(prop, ColumnProperty): raise exc.InvalidRequestError( @@ -372,13 +383,26 @@ class _class_resolver: return self.fallback[key] def _raise_for_name(self, name, err): - raise exc.InvalidRequestError( - "When initializing mapper %s, expression %r failed to " - "locate a name (%r). If this is a class name, consider " - "adding this relationship() to the %r class after " - "both dependent classes have been defined." - % (self.prop.parent, self.arg, name, self.cls) - ) from err + generic_match = re.match(r"(.+)\[(.+)\]", name) + + if generic_match: + raise exc.InvalidRequestError( + f"When initializing mapper {self.prop.parent}, " + f'expression "relationship({self.arg!r})" seems to be ' + "using a generic class as the argument to relationship(); " + "please state the generic argument " + "using an annotation, e.g. " + f'"{self.prop.key}: Mapped[{generic_match.group(1)}' + f'[{generic_match.group(2)}]] = relationship()"' + ) from err + else: + raise exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." + % (self.prop.parent, self.arg, name, self.cls) + ) from err def _resolve_name(self): name = self.arg diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 75ce8216f..ba4225563 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -102,18 +102,20 @@ The owning object and :class:`.CollectionAttributeImpl` are also reachable through the adapter, allowing for some very sophisticated behavior. 
""" - import operator import threading +import typing import weakref -from sqlalchemy.util.compat import inspect_getfullargspec -from . import base from .. import exc as sa_exc from .. import util -from ..sql import coercions -from ..sql import expression -from ..sql import roles +from ..util.compat import inspect_getfullargspec + +if typing.TYPE_CHECKING: + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import mapped_collection + from .mapped_collection import MappedCollection # noqa: F401 __all__ = [ "collection", @@ -126,180 +128,6 @@ __all__ = [ __instrumentation_mutex = threading.Lock() -class _PlainColumnGetter: - """Plain column getter, stores collection of Column objects - directly. - - Serializes to a :class:`._SerializableColumnGetterV2` - which has more expensive __call__() performance - and some rare caveats. - - """ - - def __init__(self, cols): - self.cols = cols - self.composite = len(cols) > 1 - - def __reduce__(self): - return _SerializableColumnGetterV2._reduce_from_cols(self.cols) - - def _cols(self, mapper): - return self.cols - - def __call__(self, value): - state = base.instance_state(value) - m = base._state_mapper(state) - - key = [ - m._get_state_attr_by_column(state, state.dict, col) - for col in self._cols(m) - ] - - if self.composite: - return tuple(key) - else: - return key[0] - - -class _SerializableColumnGetter: - """Column-based getter used in version 0.7.6 only. - - Remains here for pickle compatibility with 0.7.6. 
- - """ - - def __init__(self, colkeys): - self.colkeys = colkeys - self.composite = len(colkeys) > 1 - - def __reduce__(self): - return _SerializableColumnGetter, (self.colkeys,) - - def __call__(self, value): - state = base.instance_state(value) - m = base._state_mapper(state) - key = [ - m._get_state_attr_by_column( - state, state.dict, m.mapped_table.columns[k] - ) - for k in self.colkeys - ] - if self.composite: - return tuple(key) - else: - return key[0] - - -class _SerializableColumnGetterV2(_PlainColumnGetter): - """Updated serializable getter which deals with - multi-table mapped classes. - - Two extremely unusual cases are not supported. - Mappings which have tables across multiple metadata - objects, or which are mapped to non-Table selectables - linked across inheriting mappers may fail to function - here. - - """ - - def __init__(self, colkeys): - self.colkeys = colkeys - self.composite = len(colkeys) > 1 - - def __reduce__(self): - return self.__class__, (self.colkeys,) - - @classmethod - def _reduce_from_cols(cls, cols): - def _table_key(c): - if not isinstance(c.table, expression.TableClause): - return None - else: - return c.table.key - - colkeys = [(c.key, _table_key(c)) for c in cols] - return _SerializableColumnGetterV2, (colkeys,) - - def _cols(self, mapper): - cols = [] - metadata = getattr(mapper.local_table, "metadata", None) - for (ckey, tkey) in self.colkeys: - if tkey is None or metadata is None or tkey not in metadata: - cols.append(mapper.local_table.c[ckey]) - else: - cols.append(metadata.tables[tkey].c[ckey]) - return cols - - -def column_mapped_collection(mapping_spec): - """A dictionary-based collection type with column-based keying. - - Returns a :class:`.MappedCollection` factory with a keying function - generated from mapping_spec, which may be a Column or a sequence - of Columns. - - The key value must be immutable for the lifetime of the object. 
You - can not, for example, map on foreign key values if those key values will - change during the session, i.e. from None to a database-assigned integer - after a session flush. - - """ - cols = [ - coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") - for q in util.to_list(mapping_spec) - ] - keyfunc = _PlainColumnGetter(cols) - return lambda: MappedCollection(keyfunc) - - -class _SerializableAttrGetter: - def __init__(self, name): - self.name = name - self.getter = operator.attrgetter(name) - - def __call__(self, target): - return self.getter(target) - - def __reduce__(self): - return _SerializableAttrGetter, (self.name,) - - -def attribute_mapped_collection(attr_name): - """A dictionary-based collection type with attribute-based keying. - - Returns a :class:`.MappedCollection` factory with a keying based on the - 'attr_name' attribute of entities in the collection, where ``attr_name`` - is the string name of the attribute. - - .. warning:: the key value must be assigned to its final value - **before** it is accessed by the attribute mapped collection. - Additionally, changes to the key attribute are **not tracked** - automatically, which means the key in the dictionary is not - automatically synchronized with the key value on the target object - itself. See the section :ref:`key_collections_mutations` - for an example. - - """ - getter = _SerializableAttrGetter(attr_name) - return lambda: MappedCollection(getter) - - -def mapped_collection(keyfunc): - """A dictionary-based collection type with arbitrary keying. - - Returns a :class:`.MappedCollection` factory with a keying function - generated from keyfunc, a callable that takes an entity and returns a - key value. - - The key value must be immutable for the lifetime of the object. You - can not, for example, map on foreign key values if those key values will - change during the session, i.e. from None to a database-assigned integer - after a session flush. 
- - """ - return lambda: MappedCollection(keyfunc) - - class collection: """Decorators for entity collection classes. @@ -1620,63 +1448,24 @@ __interfaces = { } -class MappedCollection(dict): - """A basic dictionary-based collection class. - - Extends dict with the minimal bag semantics that collection - classes require. ``set`` and ``remove`` are implemented in terms - of a keying function: any callable that takes an object and - returns an object for use as a dictionary key. - - """ - - def __init__(self, keyfunc): - """Create a new collection with keying provided by keyfunc. +def __go(lcls): - keyfunc may be any callable that takes an object and returns an object - for use as a dictionary key. + global mapped_collection, column_mapped_collection + global attribute_mapped_collection, MappedCollection - The keyfunc will be called every time the ORM needs to add a member by - value-only (such as when loading instances from the database) or - remove a member. The usual cautions about dictionary keying apply- - ``keyfunc(object)`` should return the same output for the life of the - collection. Keying based on mutable properties can result in - unreachable instances "lost" in the collection. 
+ from .mapped_collection import mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import MappedCollection - """ - self.keyfunc = keyfunc - - @collection.appender - @collection.internally_instrumented - def set(self, value, _sa_initiator=None): - """Add an item by value, consulting the keyfunc for the key.""" - - key = self.keyfunc(value) - self.__setitem__(key, value, _sa_initiator) - - @collection.remover - @collection.internally_instrumented - def remove(self, value, _sa_initiator=None): - """Remove an item by value, consulting the keyfunc for the key.""" - - key = self.keyfunc(value) - # Let self[key] raise if key is not in this collection - # testlib.pragma exempt:__ne__ - if self[key] != value: - raise sa_exc.InvalidRequestError( - "Can not remove '%s': collection holds '%s' for key '%s'. " - "Possible cause: is the MappedCollection key function " - "based on mutable properties or properties that only obtain " - "values after flush?" % (value, self[key], key) - ) - self.__delitem__(key, _sa_initiator) + # ensure instrumentation is associated with + # these built-in classes; if a user-defined class + # subclasses these and uses @internally_instrumented, + # the superclass is otherwise not instrumented. + # see [ticket:2406]. + _instrument_class(InstrumentedList) + _instrument_class(InstrumentedSet) + _instrument_class(MappedCollection) -# ensure instrumentation is associated with -# these built-in classes; if a user-defined class -# subclasses these and uses @internally_instrumented, -# the superclass is otherwise not instrumented. -# see [ticket:2406]. 
-_instrument_class(MappedCollection) -_instrument_class(InstrumentedList) -_instrument_class(InstrumentedSet) +__go(locals()) diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 8e9cf66e2..34f291864 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -5,16 +5,18 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php import itertools +from typing import List from . import attributes from . import interfaces from . import loading from .base import _is_aliased_class +from .interfaces import ORMColumnDescription from .interfaces import ORMColumnsClauseRole from .path_registry import PathRegistry from .util import _entity_corresponds_to from .util import _ORMJoin -from .util import aliased +from .util import AliasedClass from .util import Bundle from .util import ORMAdapter from .. import exc as sa_exc @@ -1570,7 +1572,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # when we are here, it means join() was called with an indicator # as to an exact left side, which means a path to a - # RelationshipProperty was given, e.g.: + # Relationship was given, e.g.: # # join(RightEntity, LeftEntity.right) # @@ -1725,7 +1727,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): need_adapter = True # make the right hand side target into an ORM entity - right = aliased(right_mapper, right_selectable) + right = AliasedClass(right_mapper, right_selectable) util.warn_deprecated( "An alias is being generated automatically against " @@ -1750,7 +1752,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # test/orm/inheritance/test_relationships.py. There are also # general overlap cases with many-to-many tables where automatic # aliasing is desirable. 
- right = aliased(right, flat=True) + right = AliasedClass(right, flat=True) need_adapter = True util.warn( @@ -1910,7 +1912,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _column_descriptions( query_or_select_stmt, compile_state=None, legacy=False -): +) -> List[ORMColumnDescription]: if compile_state is None: compile_state = ORMSelectCompileState._create_entities_collection( query_or_select_stmt, legacy=legacy diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 59fabb9b6..5ac9966dd 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -11,7 +11,9 @@ import typing from typing import Any from typing import Callable from typing import ClassVar +from typing import Mapping from typing import Optional +from typing import Type from typing import TypeVar from typing import Union import weakref @@ -31,7 +33,7 @@ from .decl_base import _declarative_constructor from .decl_base import _DeferredMapperConfig from .decl_base import _del_attribute from .decl_base import _mapper -from .descriptor_props import SynonymProperty as _orm_synonym +from .descriptor_props import Synonym as _orm_synonym from .mapper import Mapper from .. import exc from .. import inspection @@ -39,14 +41,18 @@ from .. 
import util from ..sql.elements import SQLCoreOperations from ..sql.schema import MetaData from ..sql.selectable import FromClause +from ..sql.type_api import TypeEngine from ..util import hybridmethod from ..util import hybridproperty +from ..util import typing as compat_typing if typing.TYPE_CHECKING: from .state import InstanceState # noqa _T = TypeVar("_T", bound=Any) +_TypeAnnotationMapType = Mapping[Type, Union[Type[TypeEngine], TypeEngine]] + def has_inherited_table(cls): """Given a class, return True if any of the classes it inherits from has a @@ -67,8 +73,22 @@ def has_inherited_table(cls): return False +class _DynamicAttributesType(type): + def __setattr__(cls, key, value): + if "__mapper__" in cls.__dict__: + _add_attribute(cls, key, value) + else: + type.__setattr__(cls, key, value) + + def __delattr__(cls, key): + if "__mapper__" in cls.__dict__: + _del_attribute(cls, key) + else: + type.__delattr__(cls, key) + + class DeclarativeAttributeIntercept( - type, inspection.Inspectable["Mapper[Any]"] + _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] ): """Metaclass that may be used in conjunction with the :class:`_orm.DeclarativeBase` class to support addition of class @@ -76,15 +96,16 @@ class DeclarativeAttributeIntercept( """ - def __setattr__(cls, key, value): - _add_attribute(cls, key, value) - - def __delattr__(cls, key): - _del_attribute(cls, key) +class DeclarativeMeta( + _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] +): + metadata: MetaData + registry: "RegistryType" -class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]): - def __init__(cls, classname, bases, dict_, **kw): + def __init__( + cls, classname: Any, bases: Any, dict_: Any, **kw: Any + ) -> None: # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named # "registry" @@ -103,12 +124,6 @@ class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]): _as_declarative(reg, cls, 
dict_) type.__init__(cls, classname, bases, dict_) - def __setattr__(cls, key, value): - _add_attribute(cls, key, value) - - def __delattr__(cls, key): - _del_attribute(cls, key) - def synonym_for(name, map_column=False): """Decorator that produces an :func:`_orm.synonym` @@ -250,6 +265,9 @@ class declared_attr(interfaces._MappedAttribute[_T]): self._cascading = cascading self.__doc__ = fn.__doc__ + def _collect_return_annotation(self) -> Optional[Type[Any]]: + return util.get_annotations(self.fget).get("return") + def __get__(self, instance, owner) -> InstrumentedAttribute[_T]: # the declared_attr needs to make use of a cache that exists # for the span of the declarative scan_attributes() phase. @@ -409,6 +427,11 @@ def _setup_declarative_base(cls): else: metadata = None + if "type_annotation_map" in cls.__dict__: + type_annotation_map = cls.__dict__["type_annotation_map"] + else: + type_annotation_map = None + reg = cls.__dict__.get("registry", None) if reg is not None: if not isinstance(reg, registry): @@ -416,8 +439,18 @@ def _setup_declarative_base(cls): "Declarative base class has a 'registry' attribute that is " "not an instance of sqlalchemy.orm.registry()" ) + elif type_annotation_map is not None: + raise exc.InvalidRequestError( + "Declarative base class has both a 'registry' attribute and a " + "type_annotation_map entry. Per-base type_annotation_maps " + "are not supported. Please apply the type_annotation_map " + "to this registry directly." + ) + else: - reg = registry(metadata=metadata) + reg = registry( + metadata=metadata, type_annotation_map=type_annotation_map + ) cls.registry = reg cls._sa_registry = reg @@ -476,6 +509,44 @@ class DeclarativeBase( mappings. The superclass makes use of the ``__init_subclass__()`` method to set up new classes and metaclasses aren't used. + When first used, the :class:`_orm.DeclarativeBase` class instantiates a new + :class:`_orm.registry` to be used with the base, assuming one was not + provided explicitly. 
The :class:`_orm.DeclarativeBase` class supports + class-level attributes which act as parameters for the construction of this + registry; such as to indicate a specific :class:`_schema.MetaData` + collection as well as a specific value for + :paramref:`_orm.registry.type_annotation_map`:: + + from typing import Annotated + + from sqlalchemy import BigInteger + from sqlalchemy import MetaData + from sqlalchemy import String + from sqlalchemy.orm import DeclarativeBase + + bigint = Annotated[int, "bigint"] + my_metadata = MetaData() + + class Base(DeclarativeBase): + metadata = my_metadata + type_annotation_map = { + str: String().with_variant(String(255), "mysql", "mariadb"), + bigint: BigInteger() + } + + Class-level attributes which may be specified include: + + :param metadata: optional :class:`_schema.MetaData` collection. + If a :class:`_orm.registry` is constructed automatically, this + :class:`_schema.MetaData` collection will be used to construct it. + Otherwise, the local :class:`_schema.MetaData` collection will supersede + that used by an existing :class:`_orm.registry` passed using the + :paramref:`_orm.DeclarativeBase.registry` parameter. + :param type_annotation_map: optional type annotation map that will be + passed to the :class:`_orm.registry` as + :paramref:`_orm.registry.type_annotation_map`. + :param registry: supply a pre-existing :class:`_orm.registry` directly. + .. versionadded:: 2.0 """ @@ -516,12 +587,13 @@ def add_mapped_attribute(target, key, attr): def declarative_base( - metadata=None, + metadata: Optional[MetaData] = None, mapper=None, cls=object, name="Base", - constructor=_declarative_constructor, - class_registry=None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, metaclass=DeclarativeMeta, ) -> Any: r"""Construct a base class for declarative class definitions. 
@@ -593,6 +665,14 @@ def declarative_base( to share the same registry of class names for simplified inter-base relationships. + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. versionadded:: 2.0 + :param metaclass: Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ compatible callable to use as the meta type of the generated @@ -608,6 +688,7 @@ def declarative_base( metadata=metadata, class_registry=class_registry, constructor=constructor, + type_annotation_map=type_annotation_map, ).generate_base( mapper=mapper, cls=cls, @@ -651,9 +732,10 @@ class registry: def __init__( self, - metadata=None, - class_registry=None, - constructor=_declarative_constructor, + metadata: Optional[MetaData] = None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, ): r"""Construct a new :class:`_orm.registry` @@ -679,6 +761,14 @@ class registry: to share the same registry of class names for simplified inter-base relationships. + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. 
versionadded:: 2.0 + """ lcl_metadata = metadata or MetaData() @@ -690,7 +780,9 @@ class registry: self._non_primary_mappers = weakref.WeakKeyDictionary() self.metadata = lcl_metadata self.constructor = constructor - + self.type_annotation_map = {} + if type_annotation_map is not None: + self.update_type_annotation_map(type_annotation_map) self._dependents = set() self._dependencies = set() @@ -699,6 +791,25 @@ class registry: with mapperlib._CONFIGURE_MUTEX: mapperlib._mapper_registries[self] = True + def update_type_annotation_map( + self, + type_annotation_map: Mapping[ + Type, Union[Type[TypeEngine], TypeEngine] + ], + ) -> None: + """update the :paramref:`_orm.registry.type_annotation_map` with new + values.""" + + self.type_annotation_map.update( + { + sub_type: sqltype + for typ, sqltype in type_annotation_map.items() + for sub_type in compat_typing.expand_unions( + typ, include_union=True, discard_none=True + ) + } + ) + @property def mappers(self): """read only collection of all :class:`_orm.Mapper` objects.""" @@ -1131,6 +1242,9 @@ class registry: return _mapper(self, class_, local_table, kw) +RegistryType = registry + + def as_declarative(**kw): """ Class decorator which will adapt a given class into a diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index fb736806c..342aa772b 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -5,23 +5,34 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php """Internal implementation for declarative.""" + +from __future__ import annotations + import collections +from typing import Any +from typing import Dict +from typing import Tuple import weakref -from sqlalchemy.orm import attributes -from sqlalchemy.orm import instrumentation +from . import attributes from . import clsregistry from . import exc as orm_exc +from . import instrumentation from . 
import mapperlib from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class from .base import InspectionAttr -from .descriptor_props import CompositeProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MappedAttribute +from .interfaces import _MapsColumns from .interfaces import MapperProperty from .mapper import Mapper as mapper from .properties import ColumnProperty +from .properties import MappedColumn +from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc @@ -130,7 +141,7 @@ def _mapper(registry, cls, table, mapper_kw): @util.preload_module("sqlalchemy.orm.decl_api") -def _is_declarative_props(obj): +def _is_declarative_props(obj: Any) -> bool: declared_attr = util.preloaded.orm_decl_api.declared_attr return isinstance(obj, (declared_attr, util.classproperty)) @@ -208,7 +219,7 @@ class _MapperConfig: class _ImperativeMapperConfig(_MapperConfig): - __slots__ = ("dict_", "local_table", "inherits") + __slots__ = ("local_table", "inherits") def __init__( self, @@ -221,7 +232,6 @@ class _ImperativeMapperConfig(_MapperConfig): registry, cls_, mapper_kw ) - self.dict_ = {} self.local_table = self.set_cls_attribute("__table__", table) with mapperlib._CONFIGURE_MUTEX: @@ -277,7 +287,10 @@ class _ImperativeMapperConfig(_MapperConfig): class _ClassScanMapperConfig(_MapperConfig): __slots__ = ( - "dict_", + "registry", + "clsdict_view", + "collected_attributes", + "collected_annotations", "local_table", "persist_selectable", "declared_columns", @@ -299,11 +312,17 @@ class _ClassScanMapperConfig(_MapperConfig): ): super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) - - self.dict_ = dict(dict_) if dict_ else {} + self.registry = registry self.persist_selectable = None - 
self.declared_columns = set() + + self.clsdict_view = ( + util.immutabledict(dict_) if dict_ else util.EMPTY_DICT + ) + self.collected_attributes = {} + self.collected_annotations: Dict[str, Tuple[Any, bool]] = {} + self.declared_columns = util.OrderedSet() self.column_copies = {} + self._setup_declared_events() self._scan_attributes() @@ -407,6 +426,19 @@ class _ClassScanMapperConfig(_MapperConfig): return attribute_is_overridden + _skip_attrs = frozenset( + [ + "__module__", + "__annotations__", + "__doc__", + "__dict__", + "__weakref__", + "_sa_class_manager", + "__dict__", + "__weakref__", + ] + ) + def _cls_attr_resolver(self, cls): """produce a function to iterate the "attributes" of a class, adjusting for SQLAlchemy fields embedded in dataclass fields. @@ -416,31 +448,52 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "__sa_dataclass_metadata_key__", None ) + cls_annotations = util.get_annotations(cls) + + cls_vars = vars(cls) + + skip = self._skip_attrs + + names = util.merge_lists_w_ordering( + [n for n in cls_vars if n not in skip], list(cls_annotations) + ) if sa_dataclass_metadata_key is None: def local_attributes_for_class(): - for name, obj in vars(cls).items(): - yield name, obj, False + return ( + ( + name, + cls_vars.get(name), + cls_annotations.get(name), + False, + ) + for name in names + ) else: - field_names = set() + dataclass_fields = { + field.name: field for field in util.local_dataclass_fields(cls) + } def local_attributes_for_class(): - for field in util.local_dataclass_fields(cls): - if sa_dataclass_metadata_key in field.metadata: - field_names.add(field.name) + for name in names: + field = dataclass_fields.get(name, None) + if field and sa_dataclass_metadata_key in field.metadata: yield field.name, _as_dc_declaredattr( field.metadata, sa_dataclass_metadata_key - ), True - for name, obj in vars(cls).items(): - if name not in field_names: - yield name, obj, False + ), cls_annotations.get(field.name), True + else: + yield name, 
cls_vars.get(name), cls_annotations.get( + name + ), False return local_attributes_for_class def _scan_attributes(self): cls = self.cls - dict_ = self.dict_ + + clsdict_view = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies mapper_args_fn = None table_args = inherited_table_args = None @@ -462,10 +515,16 @@ class _ClassScanMapperConfig(_MapperConfig): if not class_mapped and base is not cls: self._produce_column_copies( - local_attributes_for_class, attribute_is_overridden + local_attributes_for_class, + attribute_is_overridden, ) - for name, obj, is_dataclass in local_attributes_for_class(): + for ( + name, + obj, + annotation, + is_dataclass, + ) in local_attributes_for_class(): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -514,7 +573,12 @@ class _ClassScanMapperConfig(_MapperConfig): elif base is not cls: # we're a mixin, abstract base, or something that is # acting like that for now. - if isinstance(obj, Column): + + if isinstance(obj, (Column, MappedColumn)): + self.collected_annotations[name] = ( + annotation, + False, + ) # already copied columns to the mapped class. 
continue elif isinstance(obj, MapperProperty): @@ -526,8 +590,12 @@ class _ClassScanMapperConfig(_MapperConfig): "field() objects, use a lambda:" ) elif _is_declarative_props(obj): + # tried to get overloads to tell this to + # pylance, no luck + assert obj is not None + if obj._cascading: - if name in dict_: + if name in clsdict_view: # unfortunately, while we can use the user- # defined attribute here to allow a clean # override, if there's another @@ -541,7 +609,7 @@ class _ClassScanMapperConfig(_MapperConfig): "@declared_attr.cascading; " "skipping" % (name, cls) ) - dict_[name] = column_copies[ + collected_attributes[name] = column_copies[ obj ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) @@ -579,19 +647,36 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret = ret.descriptor - dict_[name] = column_copies[obj] = ret + collected_attributes[name] = column_copies[ + obj + ] = ret if ( isinstance(ret, (Column, MapperProperty)) and ret.doc is None ): ret.doc = obj.__doc__ - # here, the attribute is some other kind of property that - # we assume is not part of the declarative mapping. - # however, check for some more common mistakes + + self.collected_annotations[name] = ( + obj._collect_return_annotation(), + False, + ) + elif _is_mapped_annotation(annotation, cls): + self.collected_annotations[name] = ( + annotation, + is_dataclass, + ) + if obj is None: + collected_attributes[name] = MappedColumn() + else: + collected_attributes[name] = obj else: + # here, the attribute is some other kind of + # property that we assume is not part of the + # declarative mapping. however, check for some + # more common mistakes self._warn_for_decl_attributes(base, name, obj) elif is_dataclass and ( - name not in dict_ or dict_[name] is not obj + name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class # and not a superclass. 
this is currently a @@ -606,7 +691,20 @@ class _ClassScanMapperConfig(_MapperConfig): if _is_declarative_props(obj): obj = obj.fget() - dict_[name] = obj + collected_attributes[name] = obj + self.collected_annotations[name] = ( + annotation, + True, + ) + else: + self.collected_annotations[name] = ( + annotation, + False, + ) + if obj is None and _is_mapped_annotation(annotation, cls): + collected_attributes[name] = MappedColumn() + elif name in clsdict_view: + collected_attributes[name] = obj if inherited_table_args and not tablename: table_args = None @@ -618,46 +716,55 @@ class _ClassScanMapperConfig(_MapperConfig): def _warn_for_decl_attributes(self, cls, key, c): if isinstance(c, expression.ColumnClause): util.warn( - "Attribute '%s' on class %s appears to be a non-schema " - "'sqlalchemy.sql.column()' " + f"Attribute '{key}' on class {cls} appears to " + "be a non-schema 'sqlalchemy.sql.column()' " "object; this won't be part of the declarative mapping" - % (key, cls) ) def _produce_column_copies( self, attributes_for_class, attribute_is_overridden ): cls = self.cls - dict_ = self.dict_ + dict_ = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj, is_dataclass in attributes_for_class(): - if isinstance(obj, Column): + for name, obj, annotation, is_dataclass in attributes_for_class(): + if isinstance(obj, (Column, MappedColumn)): if attribute_is_overridden(name, obj): # if column has been overridden # (like by the InstrumentedAttribute of the # superclass), skip continue - elif obj.foreign_keys: - raise exc.InvalidRequestError( - "Columns with foreign keys to other columns " - "must be declared as @declared_attr callables " - "on declarative mixin classes. For dataclass " - "field() objects, use a lambda:." 
- ) elif name not in dict_ and not ( "__table__" in dict_ and (obj.name or name) in dict_["__table__"].c ): + if obj.foreign_keys: + for fk in obj.foreign_keys: + if ( + fk._table_column is not None + and fk._table_column.table is None + ): + raise exc.InvalidRequestError( + "Columns with foreign keys to " + "non-table-bound " + "columns must be declared as " + "@declared_attr callables " + "on declarative mixin classes. " + "For dataclass " + "field() objects, use a lambda:." + ) + column_copies[obj] = copy_ = obj._copy() - copy_._creation_order = obj._creation_order + collected_attributes[name] = copy_ + setattr(cls, name, copy_) - dict_[name] = copy_ def _extract_mappable_attributes(self): cls = self.cls - dict_ = self.dict_ + collected_attributes = self.collected_attributes our_stuff = self.properties @@ -665,13 +772,17 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "_sa_decl_prepare_nocascade", strict=True ) - for k in list(dict_): + for k in list(collected_attributes): if k in ("__table__", "__tablename__", "__mapper_args__"): continue - value = dict_[k] + value = collected_attributes[k] + if _is_declarative_props(value): + # @declared_attr in collected_attributes only occurs here for a + # @declared_attr that's directly on the mapped class; + # for a mixin, these have already been evaluated if value._cascading: util.warn( "Use of @declared_attr.cascading only applies to " @@ -689,13 +800,13 @@ class _ClassScanMapperConfig(_MapperConfig): ): # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) setattr(cls, k, value) if ( isinstance(value, tuple) and len(value) == 1 - and isinstance(value[0], (Column, MapperProperty)) + and isinstance(value[0], (Column, _MappedAttribute)) ): util.warn( "Ignoring declarative-like tuple value of attribute " @@ -703,12 +814,12 @@ class _ClassScanMapperConfig(_MapperConfig): "accidentally placed at 
the end of the line?" % k ) continue - elif not isinstance(value, (Column, MapperProperty)): + elif not isinstance(value, (Column, MapperProperty, _MapsColumns)): # using @declared_attr for some object that - # isn't Column/MapperProperty; remove from the dict_ + # isn't Column/MapperProperty; remove from the clsdict_view # and place the evaluated value onto the class. if not k.startswith("__"): - dict_.pop(k) + collected_attributes.pop(k) self._warn_for_decl_attributes(cls, k, value) if not late_mapped: setattr(cls, k, value) @@ -722,27 +833,37 @@ class _ClassScanMapperConfig(_MapperConfig): "for the MetaData instance when using a " "declarative base class." ) + elif isinstance(value, _IntrospectsAnnotations): + annotation, is_dataclass = self.collected_annotations.get( + k, (None, None) + ) + value.declarative_scan( + self.registry, cls, k, annotation, is_dataclass + ) our_stuff[k] = value def _extract_declared_columns(self): our_stuff = self.properties - # set up attributes in the order they were created - util.sort_dictionary( - our_stuff, key=lambda key: our_stuff[key]._creation_order - ) - # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) for key, c in list(our_stuff.items()): - if isinstance(c, (ColumnProperty, CompositeProperty)): - for col in c.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - if not isinstance(c, CompositeProperty): - name_to_prop_key[col.name].add(key) - declared_columns.add(col) + if isinstance(c, _MapsColumns): + for col in c.columns_to_assign: + if not isinstance(c, Composite): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + + # remove object from the dictionary that will be passed + # as mapper(properties={...}) if it is not a MapperProperty + # (i.e. 
this currently means it's a MappedColumn) + mp_to_assign = c.mapper_property_to_assign + if mp_to_assign: + our_stuff[key] = mp_to_assign + else: + del our_stuff[key] + elif isinstance(c, Column): _undefer_column_name(key, c) name_to_prop_key[c.name].add(key) @@ -769,16 +890,12 @@ class _ClassScanMapperConfig(_MapperConfig): cls = self.cls tablename = self.tablename table_args = self.table_args - dict_ = self.dict_ + clsdict_view = self.clsdict_view declared_columns = self.declared_columns manager = attributes.manager_of_class(cls) - declared_columns = self.declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order - ) - - if "__table__" not in dict_ and table is None: + if "__table__" not in clsdict_view and table is None: if hasattr(cls, "__table_cls__"): table_cls = util.unbound_method_to_callable(cls.__table_cls__) else: @@ -796,11 +913,11 @@ class _ClassScanMapperConfig(_MapperConfig): else: args = table_args - autoload_with = dict_.get("__autoload_with__") + autoload_with = clsdict_view.get("__autoload_with__") if autoload_with: table_kw["autoload_with"] = autoload_with - autoload = dict_.get("__autoload__") + autoload = clsdict_view.get("__autoload__") if autoload: table_kw["autoload"] = True @@ -1095,18 +1212,21 @@ def _add_attribute(cls, key, value): _undefer_column_name(key, value) cls.__table__.append_column(value, replace_existing=True) cls.__mapper__.add_property(key, value) - elif isinstance(value, ColumnProperty): - for col in value.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - cls.__table__.append_column(col, replace_existing=True) - cls.__mapper__.add_property(key, value) + elif isinstance(value, _MapsColumns): + mp = value.mapper_property_to_assign + for col in value.columns_to_assign: + _undefer_column_name(key, col) + cls.__table__.append_column(col, replace_existing=True) + if not mp: + cls.__mapper__.add_property(key, col) + if mp: + cls.__mapper__.add_property(key, mp) elif 
isinstance(value, MapperProperty): cls.__mapper__.add_property(key, value) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) cls.__mapper__.add_property(key, value) else: type.__setattr__(cls, key, value) @@ -1124,7 +1244,7 @@ def _del_attribute(cls, key): ): value = cls.__dict__[key] if isinstance( - value, (Column, ColumnProperty, MapperProperty, QueryableAttribute) + value, (Column, _MapsColumns, MapperProperty, QueryableAttribute) ): raise NotImplementedError( "Can't un-map individual mapped attributes on a mapped class." diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 5e67b64cd..4526a8b33 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -10,14 +10,26 @@ that exist as configurational elements, but don't participate as actively in the load/persist ORM loop. """ +import inspect +import itertools +import operator +import typing from typing import Any -from typing import Type +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple from typing import TypeVar +from typing import Union from . import attributes from . import util as orm_util +from .base import Mapped +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator +from .util import _extract_mapped_subtype from .util import _none_set from .. import event from .. import exc as sa_exc @@ -27,6 +39,9 @@ from .. 
import util from ..sql import expression from ..sql import operators +if typing.TYPE_CHECKING: + from .properties import MappedColumn + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -92,30 +107,48 @@ class DescriptorProperty(MapperProperty[_T]): mapper.class_manager.instrument_attribute(self.key, proxy_attr) -class CompositeProperty(DescriptorProperty[_T]): +class Composite( + _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T] +): """Defines a "composite" mapped attribute, representing a collection of columns as one attribute. - :class:`.CompositeProperty` is constructed using the :func:`.composite` + :class:`.Composite` is constructed using the :func:`.composite` function. + .. versionchanged:: 2.0 Renamed :class:`_orm.CompositeProperty` + to :class:`_orm.Composite`. The old name + :class:`_orm.CompositeProperty` remains as an alias. + .. seealso:: :ref:`mapper_composite` """ - def __init__(self, class_: Type[_T], *attrs, **kwargs): - super(CompositeProperty, self).__init__() + composite_class: Union[type, Callable[..., type]] + attrs: Tuple[ + Union[sql.ColumnElement[Any], "MappedColumn", str, Mapped[Any]], ... 
+ ] + + def __init__(self, class_=None, *attrs, **kwargs): + super().__init__() + + if isinstance(class_, (Mapped, str, sql.ColumnElement)): + self.attrs = (class_,) + attrs + # will initialize within declarative_scan + self.composite_class = None # type: ignore + else: + self.composite_class = class_ + self.attrs = attrs - self.attrs = attrs - self.composite_class = class_ self.active_history = kwargs.get("active_history", False) self.deferred = kwargs.get("deferred", False) self.group = kwargs.get("group", None) self.comparator_factory = kwargs.pop( "comparator_factory", self.__class__.Comparator ) + self._generated_composite_accessor = None if "info" in kwargs: self.info = kwargs.pop("info") @@ -123,11 +156,26 @@ class CompositeProperty(DescriptorProperty[_T]): self._create_descriptor() def instrument_class(self, mapper): - super(CompositeProperty, self).instrument_class(mapper) + super().instrument_class(mapper) self._setup_event_handlers() + def _composite_values_from_instance(self, value): + if self._generated_composite_accessor: + return self._generated_composite_accessor(value) + else: + try: + accessor = value.__composite_values__ + except AttributeError as ae: + raise sa_exc.InvalidRequestError( + f"Composite class {self.composite_class.__name__} is not " + f"a dataclass and does not define a __composite_values__()" + " method; can't get state" + ) from ae + else: + return accessor() + def do_init(self): - """Initialization which occurs after the :class:`.CompositeProperty` + """Initialization which occurs after the :class:`.Composite` has been associated with its parent mapper. 
""" @@ -181,7 +229,8 @@ class CompositeProperty(DescriptorProperty[_T]): setattr(instance, key, None) else: for key, value in zip( - self._attribute_keys, value.__composite_values__() + self._attribute_keys, + self._composite_values_from_instance(value), ): setattr(instance, key, value) @@ -196,18 +245,74 @@ class CompositeProperty(DescriptorProperty[_T]): self.descriptor = property(fget, fset, fdel) + @util.preload_module("sqlalchemy.orm.properties") + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + MappedColumn = util.preloaded.orm_properties.MappedColumn + decl_base = util.preloaded.orm_decl_base + + argument = _extract_mapped_subtype( + annotation, + cls, + key, + MappedColumn, + self.composite_class is None, + is_dataclass_field, + ) + + if argument and self.composite_class is None: + if isinstance(argument, str) or hasattr( + argument, "__forward_arg__" + ): + raise sa_exc.ArgumentError( + f"Can't use forward ref {argument} for composite " + f"class argument" + ) + self.composite_class = argument + insp = inspect.signature(self.composite_class) + for param, attr in itertools.zip_longest( + insp.parameters.values(), self.attrs + ): + if param is None or attr is None: + raise sa_exc.ArgumentError( + f"number of arguments to {self.composite_class.__name__} " + f"class and number of attributes don't match" + ) + if isinstance(attr, MappedColumn): + attr.declarative_scan_for_composite( + registry, cls, key, param.name, param.annotation + ) + elif isinstance(attr, schema.Column): + decl_base._undefer_column_name(param.name, attr) + + if not hasattr(cls, "__composite_values__"): + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + @util.memoized_property def _comparable_elements(self): return 
[getattr(self.parent.class_, prop.key) for prop in self.props] @util.memoized_property + @util.preload_module("orm.properties") def props(self): props = [] + MappedColumn = util.preloaded.orm_properties.MappedColumn + for attr in self.attrs: if isinstance(attr, str): prop = self.parent.get_property(attr, _configure_mappers=False) elif isinstance(attr, schema.Column): prop = self.parent._columntoproperty[attr] + elif isinstance(attr, MappedColumn): + prop = self.parent._columntoproperty[attr.column] elif isinstance(attr, attributes.InstrumentedAttribute): prop = attr.property else: @@ -220,8 +325,22 @@ class CompositeProperty(DescriptorProperty[_T]): return props @property + @util.preload_module("orm.properties") def columns(self): - return [a for a in self.attrs if isinstance(a, schema.Column)] + MappedColumn = util.preloaded.orm_properties.MappedColumn + return [ + a.column if isinstance(a, MappedColumn) else a + for a in self.attrs + if isinstance(a, (schema.Column, MappedColumn)) + ] + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + return self + + @property + def columns_to_assign(self) -> List[schema.Column]: + return [c for c in self.columns if c.table is None] def _setup_arguments_on_columns(self): """Propagate configuration arguments made on this composite @@ -351,9 +470,7 @@ class CompositeProperty(DescriptorProperty[_T]): class CompositeBundle(orm_util.Bundle): def __init__(self, property_, expr): self.property = property_ - super(CompositeProperty.CompositeBundle, self).__init__( - property_.key, *expr - ) + super().__init__(property_.key, *expr) def create_row_processor(self, query, procs, labels): def proc(row): @@ -365,7 +482,7 @@ class CompositeProperty(DescriptorProperty[_T]): class Comparator(PropComparator[_PT]): """Produce boolean, comparison, and other operators for - :class:`.CompositeProperty` attributes. + :class:`.Composite` attributes. 
See the example in :ref:`composite_operations` for an overview of usage , as well as the documentation for :class:`.PropComparator`. @@ -402,7 +519,7 @@ class CompositeProperty(DescriptorProperty[_T]): "proxy_key": self.prop.key, } ) - return CompositeProperty.CompositeBundle(self.prop, clauses) + return Composite.CompositeBundle(self.prop, clauses) def _bulk_update_tuples(self, value): if isinstance(value, sql.elements.BindParameter): @@ -411,7 +528,7 @@ class CompositeProperty(DescriptorProperty[_T]): if value is None: values = [None for key in self.prop._attribute_keys] elif isinstance(value, self.prop.composite_class): - values = value.__composite_values__() + values = self.prop._composite_values_from_instance(value) else: raise sa_exc.ArgumentError( "Can't UPDATE composite attribute %s to %r" @@ -434,7 +551,7 @@ class CompositeProperty(DescriptorProperty[_T]): if other is None: values = [None] * len(self.prop._comparable_elements) else: - values = other.__composite_values__() + values = self.prop._composite_values_from_instance(other) comparisons = [ a == b for a, b in zip(self.prop._comparable_elements, values) ] @@ -477,7 +594,7 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): return comparator_callable def __init__(self): - super(ConcreteInheritedProperty, self).__init__() + super().__init__() def warn(): raise AttributeError( @@ -502,7 +619,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): self.descriptor = NoninheritedConcreteProp() -class SynonymProperty(DescriptorProperty[_T]): +class Synonym(DescriptorProperty[_T]): + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + :class:`.Synonym` is constructed using the :func:`_orm.synonym` + function. + + .. versionchanged:: 2.0 Renamed :class:`_orm.SynonymProperty` + to :class:`_orm.Synonym`. The old name + :class:`_orm.SynonymProperty` remains as an alias. + + .. 
seealso:: + + :ref:`synonyms` - Overview of synonyms + + """ + def __init__( self, name, @@ -512,7 +646,7 @@ class SynonymProperty(DescriptorProperty[_T]): doc=None, info=None, ): - super(SynonymProperty, self).__init__() + super().__init__() self.name = name self.map_column = map_column diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index ade47480d..3d9c61c20 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -28,7 +28,7 @@ from ..engine import result @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="dynamic") +@relationships.Relationship.strategy_for(lazy="dynamic") class DynaLoader(strategies.AbstractRelationshipLoader): def init_class_attribute(self, mapper): self.is_class_level = True diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b9a5aaf51..1f9ec78f7 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -20,7 +20,12 @@ import collections import typing from typing import Any from typing import cast +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type from typing import TypeVar +from typing import Union from . import exc as orm_exc from . import path_registry @@ -41,8 +46,15 @@ from .. 
import util from ..sql import operators from ..sql import roles from ..sql import visitors +from ..sql._typing import _ColumnsClauseElement from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util.typing import TypedDict + +if typing.TYPE_CHECKING: + from .decl_api import RegistryType _T = TypeVar("_T", bound=Any) @@ -85,6 +97,54 @@ class ORMFromClauseRole(roles.StrictFromClauseRole): _role_name = "ORM mapped entity, aliased entity, or FROM expression" +class ORMColumnDescription(TypedDict): + name: str + type: Union[Type, TypeEngine] + aliased: bool + expr: _ColumnsClauseElement + entity: Optional[_ColumnsClauseElement] + + +class _IntrospectsAnnotations: + __slots__ = () + + def declarative_scan( + self, + registry: "RegistryType", + cls: type, + key: str, + annotation: Optional[type], + is_dataclass_field: Optional[bool], + ) -> None: + """Perform class-specific initialization at early declarative scanning + time. + + .. versionadded:: 2.0 + + """ + + +class _MapsColumns(_MappedAttribute[_T]): + """interface for declarative-capable construct that delivers one or more + Column objects to the declarative process to be part of a Table. + """ + + __slots__ = () + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + """return a MapperProperty to be assigned to the declarative mapping""" + raise NotImplementedError() + + @property + def columns_to_assign(self) -> List[Column]: + """A list of Column objects that should be declaratively added to the + new Table object. 
+ + """ + raise NotImplementedError() + + @inspection._self_inspects class MapperProperty( HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots @@ -96,7 +156,7 @@ class MapperProperty( an instance of :class:`.ColumnProperty`, and a reference to another class produced by :func:`_orm.relationship`, represented in the mapping as an instance of - :class:`.RelationshipProperty`. + :class:`.Relationship`. """ @@ -118,7 +178,7 @@ class MapperProperty( This collection is checked before the 'cascade_iterator' method is called. - The collection typically only applies to a RelationshipProperty. + The collection typically only applies to a Relationship. """ @@ -132,7 +192,7 @@ class MapperProperty( def _links_to_entity(self): """True if this MapperProperty refers to a mapped entity. - Should only be True for RelationshipProperty, False for all others. + Should only be True for Relationship, False for all others. """ raise NotImplementedError() @@ -189,7 +249,7 @@ class MapperProperty( Note that the 'cascade' collection on this MapperProperty is checked first for the given type before cascade_iterator is called. - This method typically only applies to RelationshipProperty. + This method typically only applies to Relationship. """ @@ -323,7 +383,7 @@ class PropComparator( be redefined at both the Core and ORM level. :class:`.PropComparator` is the base class of operator redefinition for ORM-level operations, including those of :class:`.ColumnProperty`, - :class:`.RelationshipProperty`, and :class:`.CompositeProperty`. + :class:`.Relationship`, and :class:`.Composite`. User-defined subclasses of :class:`.PropComparator` may be created. 
The built-in Python comparison and math operator methods, such as @@ -339,19 +399,19 @@ class PropComparator( from sqlalchemy.orm.properties import \ ColumnProperty,\ - CompositeProperty,\ - RelationshipProperty + Composite,\ + Relationship class MyColumnComparator(ColumnProperty.Comparator): def __eq__(self, other): return self.__clause_element__() == other - class MyRelationshipComparator(RelationshipProperty.Comparator): + class MyRelationshipComparator(Relationship.Comparator): def any(self, expression): "define the 'any' operation" # ... - class MyCompositeComparator(CompositeProperty.Comparator): + class MyCompositeComparator(Composite.Comparator): def __gt__(self, other): "redefine the 'greater than' operation" @@ -386,9 +446,9 @@ class PropComparator( :class:`.ColumnProperty.Comparator` - :class:`.RelationshipProperty.Comparator` + :class:`.Relationship.Comparator` - :class:`.CompositeProperty.Comparator` + :class:`.Composite.Comparator` :class:`.ColumnOperators` @@ -552,7 +612,7 @@ class PropComparator( given criterion. The usual implementation of ``any()`` is - :meth:`.RelationshipProperty.Comparator.any`. + :meth:`.Relationship.Comparator.any`. :param criterion: an optional ClauseElement formulated against the member class' table or attributes. @@ -570,7 +630,7 @@ class PropComparator( given criterion. The usual implementation of ``has()`` is - :meth:`.RelationshipProperty.Comparator.has`. + :meth:`.Relationship.Comparator.has`. :param criterion: an optional ClauseElement formulated against the member class' table or attributes. @@ -606,10 +666,13 @@ class StrategizedProperty(MapperProperty[_T]): "strategy", "_wildcard_token", "_default_path_loader_key", + "strategy_key", ) inherit_cache = True strategy_wildcard_key = None + strategy_key: Tuple[Any, ...] 
+ def _memoized_attr__wildcard_token(self): return ( f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py new file mode 100644 index 000000000..75abeef4c --- /dev/null +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -0,0 +1,232 @@ +# orm/collections.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import operator +from typing import Any +from typing import Callable +from typing import Dict +from typing import Type +from typing import TypeVar + +from . import base +from .collections import collection +from .. import exc as sa_exc +from .. import util +from ..sql import coercions +from ..sql import expression +from ..sql import roles + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class _PlainColumnGetter: + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. + + """ + + __slots__ = ("cols", "composite") + + def __init__(self, cols): + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__(self): + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper): + return self.cols + + def __call__(self, value): + state = base.instance_state(value) + m = base._state_mapper(state) + + key = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + + if self.composite: + return tuple(key) + else: + return key[0] + + +class _SerializableColumnGetterV2(_PlainColumnGetter): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. 
+ Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + __slots__ = ("colkeys",) + + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols(cls, cols): + def _table_key(c): + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key + + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper): + cols = [] + metadata = getattr(mapper.local_table, "metadata", None) + for (ckey, tkey) in self.colkeys: + if tkey is None or metadata is None or tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + +def column_mapped_collection(mapping_spec): + """A dictionary-based collection type with column-based keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from mapping_spec, which may be a Column or a sequence + of Columns. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + cols = [ + coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) + return _mapped_collection_cls(keyfunc) + + +def attribute_mapped_collection(attr_name: str) -> Type["MappedCollection"]: + """A dictionary-based collection type with attribute-based keying. 
+ + Returns a :class:`.MappedCollection` factory with a keying based on the + 'attr_name' attribute of entities in the collection, where ``attr_name`` + is the string name of the attribute. + + .. warning:: the key value must be assigned to its final value + **before** it is accessed by the attribute mapped collection. + Additionally, changes to the key attribute are **not tracked** + automatically, which means the key in the dictionary is not + automatically synchronized with the key value on the target object + itself. See the section :ref:`key_collections_mutations` + for an example. + + """ + getter = operator.attrgetter(attr_name) + return _mapped_collection_cls(getter) + + +def mapped_collection( + keyfunc: Callable[[Any], _KT] +) -> Type["MappedCollection[_KT, Any]"]: + """A dictionary-based collection type with arbitrary keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from keyfunc, a callable that takes an entity and returns a + key value. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + return _mapped_collection_cls(keyfunc) + + +class MappedCollection(Dict[_KT, _VT]): + """A basic dictionary-based collection class. + + Extends dict with the minimal bag semantics that collection + classes require. ``set`` and ``remove`` are implemented in terms + of a keying function: any callable that takes an object and + returns an object for use as a dictionary key. + + """ + + def __init__(self, keyfunc): + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable that takes an object and returns an object + for use as a dictionary key. 
+ + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + + """ + self.keyfunc = keyfunc + + @classmethod + def _unreduce(cls, keyfunc, values): + mp = MappedCollection(keyfunc) + mp.update(values) + return mp + + def __reduce__(self): + return (MappedCollection._unreduce, (self.keyfunc, dict(self))) + + @collection.appender + @collection.internally_instrumented + def set(self, value, _sa_initiator=None): + """Add an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + self.__setitem__(key, value, _sa_initiator) + + @collection.remover + @collection.internally_instrumented + def remove(self, value, _sa_initiator=None): + """Remove an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ + if self[key] != value: + raise sa_exc.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the MappedCollection key function " + "based on mutable properties or properties that only obtain " + "values after flush?" 
% (value, self[key], key) + ) + self.__delitem__(key, _sa_initiator) + + +def _mapped_collection_cls(keyfunc): + class _MKeyfuncMapped(MappedCollection): + def __init__(self): + super().__init__(keyfunc) + + return _MKeyfuncMapped diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index fdf065488..cd0d1e820 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -580,7 +580,16 @@ class Mapper( self.version_id_prop = version_id_col self.version_id_col = None else: - self.version_id_col = version_id_col + self.version_id_col = ( + coercions.expect( + roles.ColumnArgumentOrKeyRole, + version_id_col, + argname="version_id_col", + ) + if version_id_col is not None + else None + ) + if version_id_generator is False: self.version_id_generator = False elif version_id_generator is None: @@ -2473,7 +2482,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") def synonyms(self): - """Return a namespace of all :class:`.SynonymProperty` + """Return a namespace of all :class:`.Synonym` properties maintained by this :class:`_orm.Mapper`. .. seealso:: @@ -2485,7 +2494,7 @@ class Mapper( """ descriptor_props = util.preloaded.orm_descriptor_props - return self._filter_properties(descriptor_props.SynonymProperty) + return self._filter_properties(descriptor_props.Synonym) @property def entity_namespace(self): @@ -2508,7 +2517,7 @@ class Mapper( @util.preload_module("sqlalchemy.orm.relationships") @HasMemoized.memoized_attribute def relationships(self): - """A namespace of all :class:`.RelationshipProperty` properties + """A namespace of all :class:`.Relationship` properties maintained by this :class:`_orm.Mapper`. .. 
warning:: @@ -2531,13 +2540,13 @@ class Mapper( """ return self._filter_properties( - util.preloaded.orm_relationships.RelationshipProperty + util.preloaded.orm_relationships.Relationship ) @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") def composites(self): - """Return a namespace of all :class:`.CompositeProperty` + """Return a namespace of all :class:`.Composite` properties maintained by this :class:`_orm.Mapper`. .. seealso:: @@ -2548,7 +2557,7 @@ class Mapper( """ return self._filter_properties( - util.preloaded.orm_descriptor_props.CompositeProperty + util.preloaded.orm_descriptor_props.Composite ) def _filter_properties(self, type_): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b035dbef2..f28c45fab 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -13,37 +13,60 @@ mapped attributes. """ from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Set from typing import TypeVar from . import attributes from . import strategy_options -from .descriptor_props import CompositeProperty +from .base import SQLCoreOperations +from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Synonym +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns +from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty -from .relationships import RelationshipProperty +from .relationships import Relationship +from .util import _extract_mapped_subtype from .util import _orm_full_deannotate +from .. import exc as sa_exc +from .. import ForeignKey from .. import log from .. import sql from .. 
import util from ..sql import coercions +from ..sql import operators from ..sql import roles +from ..sql import sqltypes +from ..sql.schema import Column +from ..util.typing import de_optionalize_union_types +from ..util.typing import de_stringify_annotation +from ..util.typing import is_fwd_ref +from ..util.typing import NoneType _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) __all__ = [ "ColumnProperty", - "CompositeProperty", + "Composite", "ConcreteInheritedProperty", - "RelationshipProperty", - "SynonymProperty", + "Relationship", + "Synonym", ] @log.class_logger -class ColumnProperty(StrategizedProperty[_T]): +class ColumnProperty( + _MapsColumns[_T], + StrategizedProperty[_T], + _IntrospectsAnnotations, + log.Identified, +): """Describes an object attribute that corresponds to a table column. Public constructor is the :func:`_orm.column_property` function. @@ -65,7 +88,6 @@ class ColumnProperty(StrategizedProperty[_T]): "active_history", "expire_on_flush", "doc", - "strategy_key", "_creation_order", "_is_polymorphic_discriminator", "_mapped_by_synonym", @@ -84,8 +106,8 @@ class ColumnProperty(StrategizedProperty[_T]): coercions.expect(roles.LabeledColumnExprRole, c) for c in columns ] self.columns = [ - coercions.expect( - roles.LabeledColumnExprRole, _orm_full_deannotate(c) + _orm_full_deannotate( + coercions.expect(roles.LabeledColumnExprRole, c) ) for c in columns ] @@ -130,6 +152,27 @@ class ColumnProperty(StrategizedProperty[_T]): if self.raiseload: self.strategy_key += (("raiseload", True),) + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + column = self.columns[0] + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + return self + + @property + def columns_to_assign(self) -> List[Column]: + return [ + c + for c in self.columns + if isinstance(c, Column) and c.table is None + 
] + def _memoized_attr__renders_in_subqueries(self): return ("deferred", True) not in self.strategy_key or ( self not in self.parent._readonly_props @@ -197,7 +240,7 @@ class ColumnProperty(StrategizedProperty[_T]): ) def do_init(self): - super(ColumnProperty, self).do_init() + super().do_init() if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( self.columns @@ -364,3 +407,135 @@ class ColumnProperty(StrategizedProperty[_T]): if not self.parent or not self.key: return object.__repr__(self) return str(self.parent.class_.__name__) + "." + self.key + + +class MappedColumn( + SQLCoreOperations[_T], + operators.ColumnOperators[SQLCoreOperations], + _IntrospectsAnnotations, + _MapsColumns[_T], +): + """Maps a single :class:`_schema.Column` on a class. + + :class:`_orm.MappedColumn` is a specialization of the + :class:`_orm.ColumnProperty` class and is oriented towards declarative + configuration. + + To construct :class:`_orm.MappedColumn` objects, use the + :func:`_orm.mapped_column` constructor function. + + .. 
versionadded:: 2.0 + + + """ + + __slots__ = ( + "column", + "_creation_order", + "foreign_keys", + "_has_nullable", + "deferred", + ) + + deferred: bool + column: Column[_T] + foreign_keys: Optional[Set[ForeignKey]] + + def __init__(self, *arg, **kw): + self.deferred = kw.pop("deferred", False) + self.column = cast("Column[_T]", Column(*arg, **kw)) + self.foreign_keys = self.column.foreign_keys + self._has_nullable = "nullable" in kw + util.set_creation_order(self) + + def _copy(self, **kw): + new = self.__class__.__new__(self.__class__) + new.column = self.column._copy(**kw) + new.deferred = self.deferred + new.foreign_keys = new.column.foreign_keys + new._has_nullable = self._has_nullable + util.set_creation_order(new) + return new + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + if self.deferred: + return ColumnProperty(self.column, deferred=True) + else: + return None + + @property + def columns_to_assign(self) -> List[Column]: + return [self.column] + + def __clause_element__(self): + return self.column + + def operate(self, op, *other, **kwargs): + return op(self.__clause_element__(), *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) + + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + column = self.column + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + sqltype = column.type + + argument = _extract_mapped_subtype( + annotation, + cls, + key, + MappedColumn, + sqltype._isnull and not self.column.foreign_keys, + is_dataclass_field, + ) + if argument is None: + return + + self._init_column_for_annotation(cls, registry, argument) + + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan_for_composite( + self, registry, cls, key, param_name, param_annotation + ): + decl_base = util.preloaded.orm_decl_base + 
decl_base._undefer_column_name(param_name, self.column) + self._init_column_for_annotation(cls, registry, param_annotation) + + def _init_column_for_annotation(self, cls, registry, argument): + sqltype = self.column.type + + nullable = False + + if hasattr(argument, "__origin__"): + nullable = NoneType in argument.__args__ + + if not self._has_nullable: + self.column.nullable = nullable + + if sqltype._isnull and not self.column.foreign_keys: + sqltype = None + our_type = de_optionalize_union_types(argument) + + if is_fwd_ref(our_type): + our_type = de_stringify_annotation(cls, our_type) + + if registry.type_annotation_map: + sqltype = registry.type_annotation_map.get(our_type) + if sqltype is None: + sqltype = sqltypes._type_map_get(our_type) + + if sqltype is None: + raise sa_exc.ArgumentError( + f"Could not locate SQLAlchemy Core " + f"type for Python type: {our_type}" + ) + self.column.type = sqltype diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 15259f130..61174487a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -21,7 +21,12 @@ database to return iterable result sets. import collections.abc as collections_abc import itertools import operator -import typing +from typing import Any +from typing import Generic +from typing import Iterable +from typing import List +from typing import Optional +from typing import TypeVar from . import exc as orm_exc from . import interfaces @@ -35,8 +40,9 @@ from .context import LABEL_STYLE_LEGACY_ORM from .context import ORMCompileState from .context import ORMFromStatementCompileState from .context import QueryContext +from .interfaces import ORMColumnDescription from .interfaces import ORMColumnsClauseRole -from .util import aliased +from .util import AliasedClass from .util import object_mapper from .util import with_parent from .. import exc as sa_exc @@ -45,16 +51,19 @@ from .. import inspection from .. import log from .. import sql from .. 
import util +from ..engine import Result from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import Select from ..sql import util as sql_util from ..sql import visitors +from ..sql._typing import _FromClauseElement from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative from ..sql.base import Executable +from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg @@ -67,9 +76,12 @@ from ..sql.selectable import SelectBase from ..sql.selectable import SelectStatementGrouping from ..sql.visitors import InternalTraversal -__all__ = ["Query", "QueryContext", "aliased"] -SelfQuery = typing.TypeVar("SelfQuery", bound="Query") +__all__ = ["Query", "QueryContext"] + +_T = TypeVar("_T", bound=Any) + +SelfQuery = TypeVar("SelfQuery", bound="Query") @inspection._self_inspects @@ -80,7 +92,9 @@ class Query( HasPrefixes, HasSuffixes, HasHints, + log.Identified, Executable, + Generic[_T], ): """ORM-level SQL construction object. @@ -1040,7 +1054,7 @@ class Query( for prop in mapper.iterate_properties: if ( - isinstance(prop, relationships.RelationshipProperty) + isinstance(prop, relationships.Relationship) and prop.mapper is entity_zero.mapper ): property = prop # noqa @@ -1064,7 +1078,7 @@ class Query( if alias is not None: # TODO: deprecate - entity = aliased(entity, alias) + entity = AliasedClass(entity, alias) self._raw_columns = list(self._raw_columns) @@ -1992,7 +2006,9 @@ class Query( @_generative @_assertions(_no_clauseelement_condition) - def select_from(self: SelfQuery, *from_obj) -> SelfQuery: + def select_from( + self: SelfQuery, *from_obj: _FromClauseElement + ) -> SelfQuery: r"""Set the FROM clause of this :class:`.Query` explicitly. 
:meth:`.Query.select_from` is often used in conjunction with @@ -2144,7 +2160,7 @@ class Query( self._distinct = True return self - def all(self): + def all(self) -> List[_T]: """Return the results represented by this :class:`_query.Query` as a list. @@ -2183,7 +2199,7 @@ class Query( self._statement = statement return self - def first(self): + def first(self) -> Optional[_T]: """Return the first result of this ``Query`` or None if the result doesn't contain any row. @@ -2209,7 +2225,7 @@ class Query( else: return self.limit(1)._iter().first() - def one_or_none(self): + def one_or_none(self) -> Optional[_T]: """Return at most one result or raise an exception. Returns ``None`` if the query selects @@ -2235,7 +2251,7 @@ class Query( """ return self._iter().one_or_none() - def one(self): + def one(self) -> _T: """Return exactly one result or raise an exception. Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects @@ -2255,7 +2271,7 @@ class Query( """ return self._iter().one() - def scalar(self): + def scalar(self) -> Any: """Return the first element of the first result or None if no rows present. If multiple rows are returned, raises MultipleResultsFound. @@ -2283,7 +2299,7 @@ class Query( except orm_exc.NoResultFound: return None - def __iter__(self): + def __iter__(self) -> Iterable[_T]: return self._iter().__iter__() def _iter(self): @@ -2309,7 +2325,7 @@ class Query( return result - def __str__(self): + def __str__(self) -> str: statement = self._statement_20() try: @@ -2327,7 +2343,7 @@ class Query( return fn(clause=statement, **kw) @property - def column_descriptions(self): + def column_descriptions(self) -> List[ORMColumnDescription]: """Return metadata about the columns which would be returned by this :class:`_query.Query`. 
@@ -2368,7 +2384,7 @@ class Query( return _column_descriptions(self, legacy=True) - def instances(self, result_proxy, context=None): + def instances(self, result_proxy: Result, context=None) -> Any: """Return an ORM result given a :class:`_engine.CursorResult` and :class:`.QueryContext`. @@ -2400,6 +2416,7 @@ class Query( if result._attributes.get("filtered", False): result = result.unique() + # TODO: isn't this supposed to be a list? return result @util.became_legacy_20( @@ -2436,7 +2453,7 @@ class Query( return loading.merge_result(self, iterator, load) - def exists(self): + def exists(self) -> Exists: """A convenience method that turns a query into an EXISTS subquery of the form EXISTS (SELECT 1 FROM ... WHERE ...). diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index c5ea07051..1b8f778c0 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -13,10 +13,15 @@ SQL annotation and aliasing behavior focused on the `primaryjoin` and `secondaryjoin` aspects of :func:`_orm.relationship`. """ +from __future__ import annotations + import collections +from collections import abc import re +import typing from typing import Any from typing import Callable +from typing import Optional from typing import Type from typing import TypeVar from typing import Union @@ -26,11 +31,13 @@ from . import attributes from . 
import strategy_options from .base import _is_mapped_class from .base import state_str +from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE from .interfaces import ONETOMANY from .interfaces import PropComparator from .interfaces import StrategizedProperty +from .util import _extract_mapped_subtype from .util import _orm_annotate from .util import _orm_deannotate from .util import CascadeOptions @@ -53,10 +60,26 @@ from ..sql.util import join_condition from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +if typing.TYPE_CHECKING: + from .mapper import Mapper + from .util import AliasedClass + from .util import AliasedInsp + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) +_RelationshipArgumentType = Union[ + str, + Type[_T], + Callable[[], Type[_T]], + "Mapper[_T]", + "AliasedClass[_T]", + Callable[[], "Mapper[_T]"], + Callable[[], "AliasedClass[_T]"], +] + + def remote(expr): """Annotate a portion of a primaryjoin expression with a 'remote' annotation. @@ -97,7 +120,9 @@ def foreign(expr): @log.class_logger -class RelationshipProperty(StrategizedProperty[_T]): +class Relationship( + _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified +): """Describes an object property that holds a single item or list of items that correspond to a related database table. @@ -107,6 +132,10 @@ class RelationshipProperty(StrategizedProperty[_T]): :ref:`relationship_config_toplevel` + .. versionchanged:: 2.0 Renamed :class:`_orm.RelationshipProperty` + to :class:`_orm.Relationship`. The old name + :class:`_orm.RelationshipProperty` remains as an alias. 
+ """ strategy_wildcard_key = strategy_options._RELATIONSHIP_TOKEN @@ -126,7 +155,7 @@ class RelationshipProperty(StrategizedProperty[_T]): def __init__( self, - argument: Union[str, Type[_T], Callable[[], Type[_T]]], + argument: Optional[_RelationshipArgumentType[_T]] = None, secondary=None, primaryjoin=None, secondaryjoin=None, @@ -162,7 +191,7 @@ class RelationshipProperty(StrategizedProperty[_T]): sync_backref=None, _legacy_inactive_history_style=False, ): - super(RelationshipProperty, self).__init__() + super(Relationship, self).__init__() self.uselist = uselist self.argument = argument @@ -221,9 +250,7 @@ class RelationshipProperty(StrategizedProperty[_T]): self.local_remote_pairs = _local_remote_pairs self.bake_queries = bake_queries self.load_on_pending = load_on_pending - self.comparator_factory = ( - comparator_factory or RelationshipProperty.Comparator - ) + self.comparator_factory = comparator_factory or Relationship.Comparator self.comparator = self.comparator_factory(self, None) util.set_creation_order(self) @@ -288,7 +315,7 @@ class RelationshipProperty(StrategizedProperty[_T]): class Comparator(PropComparator[_PT]): """Produce boolean, comparison, and other operators for - :class:`.RelationshipProperty` attributes. + :class:`.Relationship` attributes. See the documentation for :class:`.PropComparator` for a brief overview of ORM level operator definition. @@ -318,7 +345,7 @@ class RelationshipProperty(StrategizedProperty[_T]): of_type=None, extra_criteria=(), ): - """Construction of :class:`.RelationshipProperty.Comparator` + """Construction of :class:`.Relationship.Comparator` is internal to the ORM's attribute mechanics. """ @@ -340,7 +367,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @util.memoized_property def entity(self): """The target entity referred to by this - :class:`.RelationshipProperty.Comparator`. + :class:`.Relationship.Comparator`. This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp` object. 
@@ -360,7 +387,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @util.memoized_property def mapper(self): """The target :class:`_orm.Mapper` referred to by this - :class:`.RelationshipProperty.Comparator`. + :class:`.Relationship.Comparator`. This is the "target" or "remote" side of the :func:`_orm.relationship`. @@ -411,7 +438,7 @@ class RelationshipProperty(StrategizedProperty[_T]): """ - return RelationshipProperty.Comparator( + return Relationship.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, @@ -427,7 +454,7 @@ class RelationshipProperty(StrategizedProperty[_T]): .. versionadded:: 1.4 """ - return RelationshipProperty.Comparator( + return Relationship.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, @@ -468,7 +495,7 @@ class RelationshipProperty(StrategizedProperty[_T]): many-to-one comparisons: * Comparisons against collections are not supported. - Use :meth:`~.RelationshipProperty.Comparator.contains`. + Use :meth:`~.Relationship.Comparator.contains`. * Compared to a scalar one-to-many, will produce a clause that compares the target columns in the parent to the given target. @@ -479,7 +506,7 @@ class RelationshipProperty(StrategizedProperty[_T]): queries that go beyond simple AND conjunctions of comparisons, such as those which use OR. Use explicit joins, outerjoins, or - :meth:`~.RelationshipProperty.Comparator.has` for + :meth:`~.Relationship.Comparator.has` for more comprehensive non-many-to-one scalar membership tests. * Comparisons against ``None`` given in a one-to-many @@ -613,12 +640,12 @@ class RelationshipProperty(StrategizedProperty[_T]): EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id AND related.x=2) - Because :meth:`~.RelationshipProperty.Comparator.any` uses + Because :meth:`~.Relationship.Comparator.any` uses a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. 
- :meth:`~.RelationshipProperty.Comparator.any` is particularly + :meth:`~.Relationship.Comparator.any` is particularly useful for testing for empty collections:: session.query(MyClass).filter( @@ -631,10 +658,10 @@ class RelationshipProperty(StrategizedProperty[_T]): NOT (EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id)) - :meth:`~.RelationshipProperty.Comparator.any` is only + :meth:`~.Relationship.Comparator.any` is only valid for collections, i.e. a :func:`_orm.relationship` that has ``uselist=True``. For scalar references, - use :meth:`~.RelationshipProperty.Comparator.has`. + use :meth:`~.Relationship.Comparator.has`. """ if not self.property.uselist: @@ -662,15 +689,15 @@ class RelationshipProperty(StrategizedProperty[_T]): EXISTS (SELECT 1 FROM related WHERE related.id==my_table.related_id AND related.x=2) - Because :meth:`~.RelationshipProperty.Comparator.has` uses + Because :meth:`~.Relationship.Comparator.has` uses a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. - :meth:`~.RelationshipProperty.Comparator.has` is only + :meth:`~.Relationship.Comparator.has` is only valid for scalar references, i.e. a :func:`_orm.relationship` that has ``uselist=False``. For collection references, - use :meth:`~.RelationshipProperty.Comparator.any`. + use :meth:`~.Relationship.Comparator.any`. """ if self.property.uselist: @@ -683,7 +710,7 @@ class RelationshipProperty(StrategizedProperty[_T]): """Return a simple expression that tests a collection for containment of a particular item. - :meth:`~.RelationshipProperty.Comparator.contains` is + :meth:`~.Relationship.Comparator.contains` is only valid for a collection, i.e. a :func:`_orm.relationship` that implements one-to-many or many-to-many with ``uselist=True``. 
@@ -700,12 +727,12 @@ class RelationshipProperty(StrategizedProperty[_T]): Where ``<some id>`` is the value of the foreign key attribute on ``other`` which refers to the primary key of its parent object. From this it follows that - :meth:`~.RelationshipProperty.Comparator.contains` is + :meth:`~.Relationship.Comparator.contains` is very useful when used with simple one-to-many operations. For many-to-many operations, the behavior of - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` has more caveats. The association table will be rendered in the statement, producing an "implicit" join, that is, includes multiple tables in the FROM @@ -722,14 +749,14 @@ class RelationshipProperty(StrategizedProperty[_T]): Where ``<some id>`` would be the primary key of ``other``. From the above, it is clear that - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` will **not** work with many-to-many collections when used in queries that move beyond simple AND conjunctions, such as multiple - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` expressions joined by OR. In such cases subqueries or explicit "outer joins" will need to be used instead. - See :meth:`~.RelationshipProperty.Comparator.any` for + See :meth:`~.Relationship.Comparator.any` for a less-performant alternative using EXISTS, or refer to :meth:`_query.Query.outerjoin` as well as :ref:`ormtutorial_joins` @@ -818,7 +845,7 @@ class RelationshipProperty(StrategizedProperty[_T]): * Comparisons against collections are not supported. Use - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` in conjunction with :func:`_expression.not_`. 
* Compared to a scalar one-to-many, will produce a clause that compares the target columns in the parent to @@ -830,7 +857,7 @@ class RelationshipProperty(StrategizedProperty[_T]): queries that go beyond simple AND conjunctions of comparisons, such as those which use OR. Use explicit joins, outerjoins, or - :meth:`~.RelationshipProperty.Comparator.has` in + :meth:`~.Relationship.Comparator.has` in conjunction with :func:`_expression.not_` for more comprehensive non-many-to-one scalar membership tests. @@ -1249,7 +1276,7 @@ class RelationshipProperty(StrategizedProperty[_T]): def _add_reverse_property(self, key): other = self.mapper.get_property(key, _configure_mappers=False) - if not isinstance(other, RelationshipProperty): + if not isinstance(other, Relationship): raise sa_exc.InvalidRequestError( "back_populates on relationship '%s' refers to attribute '%s' " "that is not a relationship. The back_populates parameter " @@ -1269,6 +1296,8 @@ class RelationshipProperty(StrategizedProperty[_T]): self._reverse_property.add(other) other._reverse_property.add(self) + other._setup_entity() + if not other.mapper.common_parent(self.parent): raise sa_exc.ArgumentError( "reverse_property %r on " @@ -1289,48 +1318,18 @@ class RelationshipProperty(StrategizedProperty[_T]): ) @util.memoized_property - @util.preload_module("sqlalchemy.orm.mapper") - def entity(self): + def entity(self) -> Union["Mapper", "AliasedInsp"]: """Return the target mapped entity, which is an inspect() of the class or aliased class that is referred towards. 
""" - - mapperlib = util.preloaded.orm_mapper - - if isinstance(self.argument, str): - argument = self._clsregistry_resolve_name(self.argument)() - - elif callable(self.argument) and not isinstance( - self.argument, (type, mapperlib.Mapper) - ): - argument = self.argument() - else: - argument = self.argument - - if isinstance(argument, type): - return mapperlib.class_mapper(argument, configure=False) - - try: - entity = inspect(argument) - except sa_exc.NoInspectionAvailable: - pass - else: - if hasattr(entity, "mapper"): - return entity - - raise sa_exc.ArgumentError( - "relationship '%s' expects " - "a class or a mapper argument (received: %s)" - % (self.key, type(argument)) - ) + self.parent._check_configure() + return self.entity @util.memoized_property - def mapper(self): + def mapper(self) -> "Mapper": """Return the targeted :class:`_orm.Mapper` for this - :class:`.RelationshipProperty`. - - This is a lazy-initializing static attribute. + :class:`.Relationship`. """ return self.entity.mapper @@ -1338,13 +1337,14 @@ class RelationshipProperty(StrategizedProperty[_T]): def do_init(self): self._check_conflicts() self._process_dependent_arguments() + self._setup_entity() self._setup_registry_dependencies() self._setup_join_conditions() self._check_cascade_settings(self._cascade) self._post_init() self._generate_backref() self._join_condition._warn_for_conflicting_sync_targets() - super(RelationshipProperty, self).do_init() + super(Relationship, self).do_init() self._lazy_strategy = self._get_strategy((("lazy", "select"),)) def _setup_registry_dependencies(self): @@ -1432,6 +1432,84 @@ class RelationshipProperty(StrategizedProperty[_T]): for x in util.to_column_set(self.remote_side) ) + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + argument = _extract_mapped_subtype( + annotation, + cls, + key, + Relationship, + self.argument is None, + is_dataclass_field, + ) + if argument is None: + return + + if hasattr(argument, 
"__origin__"): + + collection_class = argument.__origin__ + if issubclass(collection_class, abc.Collection): + if self.collection_class is None: + self.collection_class = collection_class + else: + self.uselist = False + if argument.__args__: + if issubclass(argument.__origin__, typing.Mapping): + type_arg = argument.__args__[1] + else: + type_arg = argument.__args__[0] + if hasattr(type_arg, "__forward_arg__"): + str_argument = type_arg.__forward_arg__ + argument = str_argument + else: + argument = type_arg + else: + raise sa_exc.ArgumentError( + f"Generic alias {argument} requires an argument" + ) + elif hasattr(argument, "__forward_arg__"): + argument = argument.__forward_arg__ + + self.argument = argument + + @util.preload_module("sqlalchemy.orm.mapper") + def _setup_entity(self, __argument=None): + if "entity" in self.__dict__: + return + + mapperlib = util.preloaded.orm_mapper + + if __argument: + argument = __argument + else: + argument = self.argument + + if isinstance(argument, str): + argument = self._clsregistry_resolve_name(argument)() + elif callable(argument) and not isinstance( + argument, (type, mapperlib.Mapper) + ): + argument = argument() + else: + argument = argument + + if isinstance(argument, type): + entity = mapperlib.class_mapper(argument, configure=False) + else: + try: + entity = inspect(argument) + except sa_exc.NoInspectionAvailable: + entity = None + + if not hasattr(entity, "mapper"): + raise sa_exc.ArgumentError( + "relationship '%s' expects " + "a class or a mapper argument (received: %s)" + % (self.key, type(argument)) + ) + + self.entity = entity # type: ignore self.target = self.entity.persist_selectable def _setup_join_conditions(self): @@ -1502,7 +1580,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @property def cascade(self): """Return the current cascade setting for this - :class:`.RelationshipProperty`. + :class:`.Relationship`. 
""" return self._cascade @@ -1666,7 +1744,7 @@ class RelationshipProperty(StrategizedProperty[_T]): kwargs.setdefault("passive_updates", self.passive_updates) kwargs.setdefault("sync_backref", self.sync_backref) self.back_populates = backref_key - relationship = RelationshipProperty( + relationship = Relationship( parent, self.secondary, pj, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index cf47ee729..6911ab505 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -9,6 +9,15 @@ import contextlib import itertools import sys +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union import weakref from . import attributes @@ -20,12 +29,15 @@ from . import persistence from . import query from . import state as statelib from .base import _class_to_mapper +from .base import _IdentityKeyType from .base import _none_set from .base import _state_mapper from .base import instance_str from .base import object_mapper from .base import object_state from .base import state_str +from .query import Query +from .state import InstanceState from .state_changes import _StateChange from .state_changes import _StateChangeState from .state_changes import _StateChangeStates @@ -34,14 +46,26 @@ from .. import engine from .. import exc as sa_exc from .. import sql from .. 
import util +from ..engine import Connection +from ..engine import Engine from ..engine.util import TransactionalContext from ..inspection import inspect from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import visitors +from ..sql._typing import _ColumnsClauseElement from ..sql.base import CompileState from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .mapper import Mapper + from ..engine import Row + from ..sql._typing import _ExecuteOptions + from ..sql._typing import _ExecuteParams + from ..sql.base import Executable + from ..sql.schema import Table __all__ = [ "Session", @@ -78,23 +102,60 @@ class _SessionClassMethods: "removed in a future release. Please refer to " ":func:`.session.close_all_sessions`.", ) - def close_all(cls): + def close_all(cls) -> None: """Close *all* sessions in memory.""" close_all_sessions() @classmethod + @overload + def identity_key( + cls, + class_: type, + ident: Tuple[Any, ...], + *, + identity_token: Optional[str], + ) -> _IdentityKeyType: + ... + + @classmethod + @overload + def identity_key(cls, *, instance: Any) -> _IdentityKeyType: + ... + + @classmethod + @overload + def identity_key( + cls, class_: type, *, row: "Row", identity_token: Optional[str] + ) -> _IdentityKeyType: + ... + + @classmethod @util.preload_module("sqlalchemy.orm.util") - def identity_key(cls, *args, **kwargs): + def identity_key( + cls, + class_=None, + ident=None, + *, + instance=None, + row=None, + identity_token=None, + ) -> _IdentityKeyType: """Return an identity key. This is an alias of :func:`.util.identity_key`. 
""" - return util.preloaded.orm_util.identity_key(*args, **kwargs) + return util.preloaded.orm_util.identity_key( + class_, + ident, + instance=instance, + row=row, + identity_token=identity_token, + ) @classmethod - def object_session(cls, instance): + def object_session(cls, instance: Any) -> "Session": """Return the :class:`.Session` to which an object belongs. This is an alias of :func:`.object_session`. @@ -142,15 +203,26 @@ class ORMExecuteState(util.MemoizedSlots): "_update_execution_options", ) + session: "Session" + statement: "Executable" + parameters: "_ExecuteParams" + execution_options: "_ExecuteOptions" + local_execution_options: "_ExecuteOptions" + bind_arguments: Dict[str, Any] + _compile_state_cls: Type[context.ORMCompileState] + _starting_event_idx: Optional[int] + _events_todo: List[Any] + _update_execution_options: Optional["_ExecuteOptions"] + def __init__( self, - session, - statement, - parameters, - execution_options, - bind_arguments, - compile_state_cls, - events_todo, + session: "Session", + statement: "Executable", + parameters: "_ExecuteParams", + execution_options: "_ExecuteOptions", + bind_arguments: Dict[str, Any], + compile_state_cls: Type[context.ORMCompileState], + events_todo: List[Any], ): self.session = session self.statement = statement @@ -834,7 +906,7 @@ class SessionTransaction(_StateChange, TransactionalContext): (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED), SessionTransactionState.CLOSED, ) - def commit(self, _to_root=False): + def commit(self, _to_root: bool = False) -> None: if self._state is not SessionTransactionState.PREPARED: with self._expect_state(SessionTransactionState.PREPARED): self._prepare_impl() @@ -981,18 +1053,42 @@ class Session(_SessionClassMethods): _is_asyncio = False + identity_map: identity.IdentityMap + _new: Dict["InstanceState", Any] + _deleted: Dict["InstanceState", Any] + bind: Optional[Union[Engine, Connection]] + __binds: Dict[ + Union[type, "Mapper", "Table"], + 
Union[engine.Engine, engine.Connection], + ] + _flusing: bool + _warn_on_events: bool + _transaction: Optional[SessionTransaction] + _nested_transaction: Optional[SessionTransaction] + hash_key: int + autoflush: bool + expire_on_commit: bool + enable_baked_queries: bool + twophase: bool + _query_cls: Type[Query] + def __init__( self, - bind=None, - autoflush=True, - future=True, - expire_on_commit=True, - twophase=False, - binds=None, - enable_baked_queries=True, - info=None, - query_cls=None, - autocommit=False, + bind: Optional[Union[engine.Engine, engine.Connection]] = None, + autoflush: bool = True, + future: Literal[True] = True, + expire_on_commit: bool = True, + twophase: bool = False, + binds: Optional[ + Dict[ + Union[type, "Mapper", "Table"], + Union[engine.Engine, engine.Connection], + ] + ] = None, + enable_baked_queries: bool = True, + info: Optional[Dict[Any, Any]] = None, + query_cls: Optional[Type[query.Query]] = None, + autocommit: Literal[False] = False, ): r"""Construct a new Session. @@ -1054,7 +1150,8 @@ class Session(_SessionClassMethods): :class:`.sessionmaker` function, and is not sent directly to the constructor for ``Session``. - :param enable_baked_queries: defaults to ``True``. A flag consumed + :param enable_baked_queries: legacy; defaults to ``True``. + A parameter consumed by the :mod:`sqlalchemy.ext.baked` extension to determine if "baked queries" should be cached, as is the normal operation of this extension. When set to ``False``, caching as used by @@ -1331,7 +1428,7 @@ class Session(_SessionClassMethods): else: self._transaction.rollback(_to_root=True) - def commit(self): + def commit(self) -> None: """Flush pending changes and commit the current transaction. If no transaction is in progress, the method will first @@ -1353,7 +1450,7 @@ class Session(_SessionClassMethods): self._transaction.commit(_to_root=True) - def prepare(self): + def prepare(self) -> None: """Prepare the current transaction in progress for two phase commit. 
If no transaction is in progress, this method raises an @@ -1370,7 +1467,11 @@ class Session(_SessionClassMethods): self._transaction.prepare() - def connection(self, bind_arguments=None, execution_options=None): + def connection( + self, + bind_arguments: Optional[Dict[str, Any]] = None, + execution_options: Optional["_ExecuteOptions"] = None, + ) -> "Connection": r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -1425,12 +1526,12 @@ class Session(_SessionClassMethods): def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - _parent_execute_state=None, - _add_event=None, + statement: "Executable", + params: Optional["_ExecuteParams"] = None, + execution_options: "_ExecuteOptions" = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, ): r"""Execute a SQL expression construct. @@ -1936,7 +2037,9 @@ class Session(_SessionClassMethods): % (", ".join(context),), ) - def query(self, *entities, **kwargs): + def query( + self, *entities: "_ColumnsClauseElement", **kwargs: Any + ) -> "Query": """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2391,7 +2494,7 @@ class Session(_SessionClassMethods): if persistent_to_deleted is not None: persistent_to_deleted(self, state) - def add(self, instance, _warn=True): + def add(self, instance: Any, _warn: bool = True) -> None: """Place an object in the ``Session``. 
Its state will be persisted to the database on the next flush diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 07e71d4c0..316aa7ed7 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -34,7 +34,7 @@ from .interfaces import StrategizedProperty from .session import _state_session from .state import InstanceState from .util import _none_set -from .util import aliased +from .util import AliasedClass from .. import event from .. import exc as sa_exc from .. import inspect @@ -564,7 +564,7 @@ class AbstractRelationshipLoader(LoaderStrategy): @log.class_logger -@relationships.RelationshipProperty.strategy_for(do_nothing=True) +@relationships.Relationship.strategy_for(do_nothing=True) class DoNothingLoader(LoaderStrategy): """Relationship loader that makes no change to the object's state. @@ -576,10 +576,10 @@ class DoNothingLoader(LoaderStrategy): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="noload") -@relationships.RelationshipProperty.strategy_for(lazy=None) +@relationships.Relationship.strategy_for(lazy="noload") +@relationships.Relationship.strategy_for(lazy=None) class NoLoader(AbstractRelationshipLoader): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` with "lazy=None". 
""" @@ -617,13 +617,13 @@ class NoLoader(AbstractRelationshipLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy=True) -@relationships.RelationshipProperty.strategy_for(lazy="select") -@relationships.RelationshipProperty.strategy_for(lazy="raise") -@relationships.RelationshipProperty.strategy_for(lazy="raise_on_sql") -@relationships.RelationshipProperty.strategy_for(lazy="baked_select") +@relationships.Relationship.strategy_for(lazy=True) +@relationships.Relationship.strategy_for(lazy="select") +@relationships.Relationship.strategy_for(lazy="raise") +@relationships.Relationship.strategy_for(lazy="raise_on_sql") +@relationships.Relationship.strategy_for(lazy="baked_select") class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` with "lazy=True", that is loads when first accessed. """ @@ -1214,7 +1214,7 @@ class PostLoader(AbstractRelationshipLoader): ) -@relationships.RelationshipProperty.strategy_for(lazy="immediate") +@relationships.Relationship.strategy_for(lazy="immediate") class ImmediateLoader(PostLoader): __slots__ = () @@ -1250,7 +1250,7 @@ class ImmediateLoader(PostLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="subquery") +@relationships.Relationship.strategy_for(lazy="subquery") class SubqueryLoader(PostLoader): __slots__ = ("join_depth",) @@ -1906,10 +1906,10 @@ class SubqueryLoader(PostLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="joined") -@relationships.RelationshipProperty.strategy_for(lazy=False) +@relationships.Relationship.strategy_for(lazy="joined") +@relationships.Relationship.strategy_for(lazy=False) class JoinedLoader(AbstractRelationshipLoader): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` using joined eager loading. 
""" @@ -2628,7 +2628,7 @@ class JoinedLoader(AbstractRelationshipLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="selectin") +@relationships.Relationship.strategy_for(lazy="selectin") class SelectInLoader(PostLoader, util.MemoizedSlots): __slots__ = ( "join_depth", @@ -2721,7 +2721,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): ) def _init_for_join(self): - self._parent_alias = aliased(self.parent.class_) + self._parent_alias = AliasedClass(self.parent.class_) pa_insp = inspect(self._parent_alias) pk_cols = [ pa_insp._adapt_element(col) for col in self.parent.primary_key diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 0f993b86c..3f093e543 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1808,7 +1808,7 @@ class _AttributeStrategyLoad(_LoadElement): assert pwpi if not pwpi.is_aliased_class: pwpi = inspect( - orm_util.with_polymorphic( + orm_util.AliasedInsp._with_polymorphic_factory( pwpi.mapper.base_mapper, pwpi.mapper, aliased=True, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 75f711007..45c578355 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -5,13 +5,22 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php - import re import types +import typing +from typing import Any +from typing import Generic +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union import weakref from . 
import attributes # noqa from .base import _class_to_mapper # noqa +from .base import _IdentityKeyType from .base import _never_set # noqa from .base import _none_set # noqa from .base import attribute_str # noqa @@ -45,8 +54,17 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection +from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots +from ..util.typing import de_stringify_annotation +from ..util.typing import is_origin_of + +if typing.TYPE_CHECKING: + from .mapper import Mapper + from ..engine import Row + from ..sql.selectable import Alias +_T = TypeVar("_T", bound=Any) all_cascades = frozenset( ( @@ -276,7 +294,28 @@ def polymorphic_union( return sql.union_all(*result).alias(aliasname) -def identity_key(*args, **kwargs): +@overload +def identity_key( + class_: type, ident: Tuple[Any, ...], *, identity_token: Optional[str] +) -> _IdentityKeyType: + ... + + +@overload +def identity_key(*, instance: Any) -> _IdentityKeyType: + ... + + +@overload +def identity_key( + class_: type, *, row: "Row", identity_token: Optional[str] +) -> _IdentityKeyType: + ... + + +def identity_key( + class_=None, ident=None, *, instance=None, row=None, identity_token=None +) -> _IdentityKeyType: r"""Generate "identity key" tuples, as are used as keys in the :attr:`.Session.identity_map` dictionary. @@ -340,29 +379,11 @@ def identity_key(*args, **kwargs): .. 
versionadded:: 1.2 added identity_token """ - if args: - row = None - largs = len(args) - if largs == 1: - class_ = args[0] - try: - row = kwargs.pop("row") - except KeyError: - ident = kwargs.pop("ident") - elif largs in (2, 3): - class_, ident = args - else: - raise sa_exc.ArgumentError( - "expected up to three positional arguments, " "got %s" % largs - ) - - identity_token = kwargs.pop("identity_token", None) - if kwargs: - raise sa_exc.ArgumentError( - "unknown keyword arguments: %s" % ", ".join(kwargs) - ) + if class_ is not None: mapper = class_mapper(class_) if row is None: + if ident is None: + raise sa_exc.ArgumentError("ident or row is required") return mapper.identity_key_from_primary_key( util.to_list(ident), identity_token=identity_token ) @@ -370,14 +391,11 @@ def identity_key(*args, **kwargs): return mapper.identity_key_from_row( row, identity_token=identity_token ) - else: - instance = kwargs.pop("instance") - if kwargs: - raise sa_exc.ArgumentError( - "unknown keyword arguments: %s" % ", ".join(kwargs.keys) - ) + elif instance is not None: mapper = object_mapper(instance) return mapper.identity_key_from_instance(instance) + else: + raise sa_exc.ArgumentError("class or instance is required") class ORMAdapter(sql_util.ColumnAdapter): @@ -420,7 +438,7 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass: +class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): r"""Represents an "aliased" form of a mapped class for usage with Query. 
The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -481,7 +499,7 @@ class AliasedClass: def __init__( self, - mapped_class_or_ac, + mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], alias=None, name=None, flat=False, @@ -611,6 +629,7 @@ class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, sql_base.HasCacheKey, + roles.HasFromClauseElement, InspectionAttr, MemoizedSlots, ): @@ -747,6 +766,73 @@ class AliasedInsp( self._target = mapped_class_or_ac # self._target = mapper.class_ # mapped_class_or_ac + @classmethod + def _alias_factory( + cls, + element: Union[ + Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]" + ], + alias=None, + name=None, + flat=False, + adapt_on_names=False, + ) -> Union["AliasedClass[_T]", "Alias"]: + + if isinstance(element, FromClause): + if adapt_on_names: + raise sa_exc.ArgumentError( + "adapt_on_names only applies to ORM elements" + ) + if name: + return element.alias(name=name, flat=flat) + else: + return coercions.expect( + roles.AnonymizedFromClauseRole, element, flat=flat + ) + else: + return AliasedClass( + element, + alias=alias, + flat=flat, + name=name, + adapt_on_names=adapt_on_names, + ) + + @classmethod + def _with_polymorphic_factory( + cls, + base, + classes, + selectable=False, + flat=False, + polymorphic_on=None, + aliased=False, + innerjoin=False, + _use_mapper_path=False, + ): + + primary_mapper = _class_to_mapper(base) + + if selectable not in (None, False) and flat: + raise sa_exc.ArgumentError( + "the 'flat' and 'selectable' arguments cannot be passed " + "simultaneously to with_polymorphic()" + ) + + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, selectable, innerjoin=innerjoin + ) + if aliased or flat: + selectable = selectable._anonymous_fromclause(flat=flat) + return AliasedClass( + base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path, + 
represents_outer_join=not innerjoin, + ) + @property def entity(self): # to eliminate reference cycles, the AliasedClass is held weakly. @@ -1107,215 +1193,6 @@ inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) inspection._inspects(AliasedInsp)(lambda target: target) -def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False): - """Produce an alias of the given element, usually an :class:`.AliasedClass` - instance. - - E.g.:: - - my_alias = aliased(MyClass) - - session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) - - The :func:`.aliased` function is used to create an ad-hoc mapping of a - mapped class to a new selectable. By default, a selectable is generated - from the normally mapped selectable (typically a :class:`_schema.Table` - ) using the - :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` - can also be - used to link the class to a new :func:`_expression.select` statement. - Also, the :func:`.with_polymorphic` function is a variant of - :func:`.aliased` that is intended to specify a so-called "polymorphic - selectable", that corresponds to the union of several joined-inheritance - subclasses at once. - - For convenience, the :func:`.aliased` function also accepts plain - :class:`_expression.FromClause` constructs, such as a - :class:`_schema.Table` or - :func:`_expression.select` construct. In those cases, the - :meth:`_expression.FromClause.alias` - method is called on the object and the new - :class:`_expression.Alias` object returned. The returned - :class:`_expression.Alias` is not - ORM-mapped in this case. - - .. seealso:: - - :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` - - :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` - - :ref:`ormtutorial_aliases` - in the legacy :ref:`ormtutorial_toplevel` - - :param element: element to be aliased. 
Is normally a mapped class, - but for convenience can also be a :class:`_expression.FromClause` - element. - - :param alias: Optional selectable unit to map the element to. This is - usually used to link the object to a subquery, and should be an aliased - select construct as one would produce from the - :meth:`_query.Query.subquery` method or - the :meth:`_expression.Select.subquery` or - :meth:`_expression.Select.alias` methods of the :func:`_expression.select` - construct. - - :param name: optional string name to use for the alias, if not specified - by the ``alias`` parameter. The name, among other things, forms the - attribute name that will be accessible via tuples returned by a - :class:`_query.Query` object. Not supported when creating aliases - of :class:`_sql.Join` objects. - - :param flat: Boolean, will be passed through to the - :meth:`_expression.FromClause.alias` call so that aliases of - :class:`_expression.Join` objects will alias the individual tables - inside the join, rather than creating a subquery. This is generally - supported by all modern databases with regards to right-nested joins - and generally produces more efficient queries. - - :param adapt_on_names: if True, more liberal "matching" will be used when - mapping the mapped columns of the ORM entity to those of the - given selectable - a name-based match will be performed if the - given selectable doesn't otherwise have a column that corresponds - to one on the entity. The use case for this is when associating - an entity with some derived selectable such as one that uses - aggregate functions:: - - class UnitPrice(Base): - __tablename__ = 'unit_price' - ... 
- unit_id = Column(Integer) - price = Column(Numeric) - - aggregated_unit_price = Session.query( - func.sum(UnitPrice.price).label('price') - ).group_by(UnitPrice.unit_id).subquery() - - aggregated_unit_price = aliased(UnitPrice, - alias=aggregated_unit_price, adapt_on_names=True) - - Above, functions on ``aggregated_unit_price`` which refer to - ``.price`` will return the - ``func.sum(UnitPrice.price).label('price')`` column, as it is - matched on the name "price". Ordinarily, the "price" function - wouldn't have any "column correspondence" to the actual - ``UnitPrice.price`` column as it is not a proxy of the original. - - """ - if isinstance(element, expression.FromClause): - if adapt_on_names: - raise sa_exc.ArgumentError( - "adapt_on_names only applies to ORM elements" - ) - if name: - return element.alias(name=name, flat=flat) - else: - return coercions.expect( - roles.AnonymizedFromClauseRole, element, flat=flat - ) - else: - return AliasedClass( - element, - alias=alias, - flat=flat, - name=name, - adapt_on_names=adapt_on_names, - ) - - -def with_polymorphic( - base, - classes, - selectable=False, - flat=False, - polymorphic_on=None, - aliased=False, - innerjoin=False, - _use_mapper_path=False, -): - """Produce an :class:`.AliasedClass` construct which specifies - columns for descendant mappers of the given base. - - Using this method will ensure that each descendant mapper's - tables are included in the FROM clause, and will allow filter() - criterion to be used against those tables. The resulting - instances will also have those columns already loaded so that - no "post fetch" of those columns will be required. - - .. seealso:: - - :ref:`with_polymorphic` - full discussion of - :func:`_orm.with_polymorphic`. - - :param base: Base class to be aliased. - - :param classes: a single class or mapper, or list of - class/mappers, which inherit from the base class. 
- Alternatively, it may also be the string ``'*'``, in which case - all descending mapped classes will be added to the FROM clause. - - :param aliased: when True, the selectable will be aliased. For a - JOIN, this means the JOIN will be SELECTed from inside of a subquery - unless the :paramref:`_orm.with_polymorphic.flat` flag is set to - True, which is recommended for simpler use cases. - - :param flat: Boolean, will be passed through to the - :meth:`_expression.FromClause.alias` call so that aliases of - :class:`_expression.Join` objects will alias the individual tables - inside the join, rather than creating a subquery. This is generally - supported by all modern databases with regards to right-nested joins - and generally produces more efficient queries. Setting this flag is - recommended as long as the resulting SQL is functional. - - :param selectable: a table or subquery that will - be used in place of the generated FROM clause. This argument is - required if any of the desired classes use concrete table - inheritance, since SQLAlchemy currently cannot generate UNIONs - among tables automatically. If used, the ``selectable`` argument - must represent the full set of tables and columns mapped by every - mapped class. Otherwise, the unaccounted mapped columns will - result in their table being appended directly to the FROM clause - which will usually lead to incorrect results. - - When left at its default value of ``False``, the polymorphic - selectable assigned to the base mapper is used for selecting rows. - However, it may also be passed as ``None``, which will bypass the - configured polymorphic selectable and instead construct an ad-hoc - selectable for the target classes given; for joined table inheritance - this will be a join that includes all target mappers and their - subclasses. - - :param polymorphic_on: a column to be used as the "discriminator" - column for the given selectable. 
If not given, the polymorphic_on - attribute of the base classes' mapper will be used, if any. This - is useful for mappings that don't have polymorphic loading - behavior by default. - - :param innerjoin: if True, an INNER JOIN will be used. This should - only be specified if querying for one specific subtype only - """ - primary_mapper = _class_to_mapper(base) - - if selectable not in (None, False) and flat: - raise sa_exc.ArgumentError( - "the 'flat' and 'selectable' arguments cannot be passed " - "simultaneously to with_polymorphic()" - ) - - mappers, selectable = primary_mapper._with_polymorphic_args( - classes, selectable, innerjoin=innerjoin - ) - if aliased or flat: - selectable = selectable._anonymous_fromclause(flat=flat) - return AliasedClass( - base, - selectable, - with_polymorphic_mappers=mappers, - with_polymorphic_discriminator=polymorphic_on, - use_mapper_path=_use_mapper_path, - represents_outer_join=not innerjoin, - ) - - @inspection._self_inspects class Bundle( ORMColumnsClauseRole, @@ -1667,62 +1544,6 @@ class _ORMJoin(expression.Join): return _ORMJoin(self, right, onclause, isouter=True, full=full) -def join( - left, right, onclause=None, isouter=False, full=False, join_to_left=None -): - r"""Produce an inner join between left and right clauses. - - :func:`_orm.join` is an extension to the core join interface - provided by :func:`_expression.join()`, where the - left and right selectables may be not only core selectable - objects such as :class:`_schema.Table`, but also mapped classes or - :class:`.AliasedClass` instances. The "on" clause can - be a SQL expression, or an attribute or string name - referencing a configured :func:`_orm.relationship`. - - :func:`_orm.join` is not commonly needed in modern usage, - as its functionality is encapsulated within that of the - :meth:`_query.Query.join` method, which features a - significant amount of automation beyond :func:`_orm.join` - by itself. 
Explicit usage of :func:`_orm.join` - with :class:`_query.Query` involves usage of the - :meth:`_query.Query.select_from` method, as in:: - - from sqlalchemy.orm import join - session.query(User).\ - select_from(join(User, Address, User.addresses)).\ - filter(Address.email_address=='foo@bar.com') - - In modern SQLAlchemy the above join can be written more - succinctly as:: - - session.query(User).\ - join(User.addresses).\ - filter(Address.email_address=='foo@bar.com') - - See :meth:`_query.Query.join` for information on modern usage - of ORM level joins. - - .. deprecated:: 0.8 - - the ``join_to_left`` parameter is deprecated, and will be removed - in a future release. The parameter has no effect. - - """ - return _ORMJoin(left, right, onclause, isouter, full) - - -def outerjoin(left, right, onclause=None, full=False, join_to_left=None): - """Produce a left outer join between left and right clauses. - - This is the "outer join" version of the :func:`_orm.join` function, - featuring the same behavior except that an OUTER JOIN is generated. - See that function's documentation for other usage details. 
- - """ - return _ORMJoin(left, right, onclause, True, full) - - def with_parent(instance, prop, from_entity=None): """Create filtering criterion that relates this query's primary entity to the given related instance, using established @@ -1964,3 +1785,56 @@ def _getitem(iterable_query, item): return list(iterable_query)[-1] else: return list(iterable_query[item : item + 1])[0] + + +def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type): + annotated = de_stringify_annotation(cls, raw_annotation) + return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") + + +def _extract_mapped_subtype( + raw_annotation: Union[type, str], + cls: type, + key: str, + attr_cls: type, + required: bool, + is_dataclass_field: bool, +) -> Optional[Union[type, str]]: + + if raw_annotation is None: + + if required: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{attr_cls.__name__}" construct are None or not present' + ) + return None + + annotated = de_stringify_annotation(cls, raw_annotation) + + if is_dataclass_field: + return annotated + else: + if ( + not hasattr(annotated, "__origin__") + or not issubclass(annotated.__origin__, attr_cls) + and not issubclass(attr_cls, annotated.__origin__) + ): + our_annotated_str = ( + annotated.__name__ + if not isinstance(annotated, str) + else repr(annotated) + ) + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" should use the ' + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]".' 
+ ) + + if len(annotated.__args__) != 1: + raise sa_exc.ArgumentError( + "Expected sub-type for Mapped[] annotation" + ) + + return annotated.__args__[0] diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 38059856e..bc2f93d57 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -22,36 +22,17 @@ from .base import _AdhocProxiedConnection from .base import _ConnectionFairy from .base import _ConnectionRecord from .base import _finalize_fairy -from .base import Pool -from .base import PoolProxiedConnection -from .base import reset_commit -from .base import reset_none -from .base import reset_rollback -from .impl import AssertionPool -from .impl import AsyncAdaptedQueuePool -from .impl import FallbackAsyncAdaptedQueuePool -from .impl import NullPool -from .impl import QueuePool -from .impl import SingletonThreadPool -from .impl import StaticPool - - -__all__ = [ - "Pool", - "PoolProxiedConnection", - "reset_commit", - "reset_none", - "reset_rollback", - "clear_managers", - "manage", - "AssertionPool", - "NullPool", - "QueuePool", - "AsyncAdaptedQueuePool", - "FallbackAsyncAdaptedQueuePool", - "SingletonThreadPool", - "StaticPool", -] - -# as these are likely to be used in various test suites, debugging -# setups, keep them in the sqlalchemy.pool namespace +from .base import Pool as Pool +from .base import PoolProxiedConnection as PoolProxiedConnection +from .base import reset_commit as reset_commit +from .base import reset_none as reset_none +from .base import reset_rollback as reset_rollback +from .impl import AssertionPool as AssertionPool +from .impl import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .impl import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .impl import NullPool as NullPool +from .impl import QueuePool as QueuePool +from .impl import SingletonThreadPool as SingletonThreadPool +from .impl import StaticPool as StaticPool diff --git 
a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index c596dee5a..b2ca1cfef 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -9,50 +9,54 @@ """ -from .sql.base import SchemaVisitor # noqa -from .sql.ddl import _CreateDropBase # noqa -from .sql.ddl import _DDLCompiles # noqa -from .sql.ddl import _DropView # noqa -from .sql.ddl import AddConstraint # noqa -from .sql.ddl import CreateColumn # noqa -from .sql.ddl import CreateIndex # noqa -from .sql.ddl import CreateSchema # noqa -from .sql.ddl import CreateSequence # noqa -from .sql.ddl import CreateTable # noqa -from .sql.ddl import DDL # noqa -from .sql.ddl import DDLBase # noqa -from .sql.ddl import DDLElement # noqa -from .sql.ddl import DropColumnComment # noqa -from .sql.ddl import DropConstraint # noqa -from .sql.ddl import DropIndex # noqa -from .sql.ddl import DropSchema # noqa -from .sql.ddl import DropSequence # noqa -from .sql.ddl import DropTable # noqa -from .sql.ddl import DropTableComment # noqa -from .sql.ddl import SetColumnComment # noqa -from .sql.ddl import SetTableComment # noqa -from .sql.ddl import sort_tables # noqa -from .sql.ddl import sort_tables_and_constraints # noqa -from .sql.naming import conv # noqa -from .sql.schema import _get_table_key # noqa -from .sql.schema import BLANK_SCHEMA # noqa -from .sql.schema import CheckConstraint # noqa -from .sql.schema import Column # noqa -from .sql.schema import ColumnCollectionConstraint # noqa -from .sql.schema import ColumnCollectionMixin # noqa -from .sql.schema import ColumnDefault # noqa -from .sql.schema import Computed # noqa -from .sql.schema import Constraint # noqa -from .sql.schema import DefaultClause # noqa -from .sql.schema import DefaultGenerator # noqa -from .sql.schema import FetchedValue # noqa -from .sql.schema import ForeignKey # noqa -from .sql.schema import ForeignKeyConstraint # noqa -from .sql.schema import Identity # noqa -from .sql.schema import Index # noqa -from .sql.schema import MetaData 
# noqa -from .sql.schema import PrimaryKeyConstraint # noqa -from .sql.schema import SchemaItem # noqa -from .sql.schema import Sequence # noqa -from .sql.schema import Table # noqa -from .sql.schema import UniqueConstraint # noqa +from .sql.base import SchemaVisitor as SchemaVisitor +from .sql.ddl import _CreateDropBase as _CreateDropBase +from .sql.ddl import _DDLCompiles as _DDLCompiles +from .sql.ddl import _DropView as _DropView +from .sql.ddl import AddConstraint as AddConstraint +from .sql.ddl import CreateColumn as CreateColumn +from .sql.ddl import CreateIndex as CreateIndex +from .sql.ddl import CreateSchema as CreateSchema +from .sql.ddl import CreateSequence as CreateSequence +from .sql.ddl import CreateTable as CreateTable +from .sql.ddl import DDL as DDL +from .sql.ddl import DDLBase as DDLBase +from .sql.ddl import DDLElement as DDLElement +from .sql.ddl import DropColumnComment as DropColumnComment +from .sql.ddl import DropConstraint as DropConstraint +from .sql.ddl import DropIndex as DropIndex +from .sql.ddl import DropSchema as DropSchema +from .sql.ddl import DropSequence as DropSequence +from .sql.ddl import DropTable as DropTable +from .sql.ddl import DropTableComment as DropTableComment +from .sql.ddl import SetColumnComment as SetColumnComment +from .sql.ddl import SetTableComment as SetTableComment +from .sql.ddl import sort_tables as sort_tables +from .sql.ddl import ( + sort_tables_and_constraints as sort_tables_and_constraints, +) +from .sql.naming import conv as conv +from .sql.schema import _get_table_key as _get_table_key +from .sql.schema import BLANK_SCHEMA as BLANK_SCHEMA +from .sql.schema import CheckConstraint as CheckConstraint +from .sql.schema import Column as Column +from .sql.schema import ( + ColumnCollectionConstraint as ColumnCollectionConstraint, +) +from .sql.schema import ColumnCollectionMixin as ColumnCollectionMixin +from .sql.schema import ColumnDefault as ColumnDefault +from .sql.schema import Computed as Computed 
+from .sql.schema import Constraint as Constraint +from .sql.schema import DefaultClause as DefaultClause +from .sql.schema import DefaultGenerator as DefaultGenerator +from .sql.schema import FetchedValue as FetchedValue +from .sql.schema import ForeignKey as ForeignKey +from .sql.schema import ForeignKeyConstraint as ForeignKeyConstraint +from .sql.schema import Identity as Identity +from .sql.schema import Index as Index +from .sql.schema import MetaData as MetaData +from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from .sql.schema import SchemaItem as SchemaItem +from .sql.schema import Sequence as Sequence +from .sql.schema import Table as Table +from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 2f84370aa..169ddf3db 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -75,6 +75,7 @@ from .expression import quoted_name as quoted_name from .expression import Select as Select from .expression import select as select from .expression import Selectable as Selectable +from .expression import SelectLabelStyle as SelectLabelStyle from .expression import StatementLambdaElement as StatementLambdaElement from .expression import Subquery as Subquery from .expression import table as table diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 4b67c12f0..d3cf207da 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -6,11 +6,11 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from typing import Any -from typing import Type from typing import Union from . import coercions from . 
import roles +from ._typing import _ColumnsClauseElement from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -21,6 +21,8 @@ from .selectable import Select from .selectable import TableClause from .selectable import TableSample from .selectable import Values +from ..util.typing import _LiteralStar +from ..util.typing import Literal def alias(selectable, name=None, flat=False): @@ -279,7 +281,9 @@ def outerjoin(left, right, onclause=None, full=False): return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: Union[roles.ColumnsClauseRole, Type]) -> "Select": +def select( + *entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement] +) -> "Select": r"""Construct a new :class:`_expression.Select`. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b5b0efb21..4d2dd2688 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,9 +1,21 @@ from typing import Any from typing import Mapping from typing import Sequence +from typing import Type from typing import Union +from . import roles +from ..inspection import Inspectable +from ..util import immutabledict + _SingleExecuteParams = Mapping[str, Any] _MultiExecuteParams = Sequence[_SingleExecuteParams] _ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams] _ExecuteOptions = Mapping[str, Any] +_ImmutableExecuteOptions = immutabledict[str, Any] +_ColumnsClauseElement = Union[ + roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] +] +_FromClauseElement = Union[ + roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement] +] diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f4fe7afab..5828f9369 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -21,6 +21,7 @@ from typing import TypeVar from . import roles from . 
import visitors +from ._typing import _ImmutableExecuteOptions from .cache_key import HasCacheKey # noqa from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa @@ -832,9 +833,8 @@ class Executable(roles.StatementRole, Generative): """ - supports_execution = True - _execution_options = util.immutabledict() - _bind = None + supports_execution: bool = True + _execution_options: _ImmutableExecuteOptions = util.immutabledict() _with_options = () _with_context_options = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9cf4d8397..bf78b4231 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -889,7 +889,7 @@ class SQLCompiler(Compiled): def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + r"\[_POSITION\]", lambda m: str(next(poscount)), self.string ) @util.memoized_property diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 18931ce67..f622023b0 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -10,6 +10,11 @@ to invoke them for a create/drop call. """ import typing +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from . import roles from .base import _generative @@ -21,6 +26,11 @@ from .. import util from ..util import topological +if typing.TYPE_CHECKING: + from .schema import ForeignKeyConstraint + from .schema import Table + + class _DDLCompiles(ClauseElement): _hierarchy_supports_caching = False """disable cache warnings for all _DDLCompiles subclasses. 
""" @@ -1007,10 +1017,10 @@ class SchemaDropper(DDLBase): def sort_tables( - tables, - skip_fn=None, - extra_dependencies=None, -): + tables: Sequence["Table"], + skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, + extra_dependencies: Optional[Sequence[Tuple["Table", "Table"]]] = None, +) -> List["Table"]: """Sort a collection of :class:`_schema.Table` objects based on dependency. @@ -1051,7 +1061,7 @@ def sort_tables( :param tables: a sequence of :class:`_schema.Table` objects. :param skip_fn: optional callable which will be passed a - :class:`_schema.ForeignKey` object; if it returns True, this + :class:`_schema.ForeignKeyConstraint` object; if it returns True, this constraint will not be considered as a dependency. Note this is **different** from the same parameter in :func:`.sort_tables_and_constraints`, which is diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0ed5bd986..22195cd7c 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -136,6 +136,7 @@ from .selectable import ScalarSelect as ScalarSelect from .selectable import Select as Select from .selectable import Selectable as Selectable from .selectable import SelectBase as SelectBase +from .selectable import SelectLabelStyle as SelectLabelStyle from .selectable import Subquery as Subquery from .selectable import TableClause as TableClause from .selectable import TableSample as TableSample diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index 00a2b1d89..15a1566a6 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -14,7 +14,7 @@ import re from . 
import events # noqa from .elements import _NONE_NAME -from .elements import conv +from .elements import conv as conv from .schema import CheckConstraint from .schema import Column from .schema import Constraint diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 787a1c25e..b41ef7a5d 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -4,10 +4,17 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +import typing +from sqlalchemy.util.langhelpers import TypingOnly from .. import util +if typing.TYPE_CHECKING: + from .elements import ClauseElement + from .selectable import FromClause + + class SQLRole: """Define a "role" within a SQL statement structure. @@ -284,3 +291,25 @@ class DDLReferredColumnRole(DDLConstraintColumnRole): _role_name = ( "String column name or Column object for DDL foreign key constraint" ) + + +class HasClauseElement(TypingOnly): + """indicates a class that has a __clause_element__() method""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + def __clause_element__(self) -> "ClauseElement": + ... + + +class HasFromClauseElement(HasClauseElement, TypingOnly): + """indicates a class that has a __clause_element__() method""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + def __clause_element__(self) -> "FromClause": + ... diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index a04fad05d..9387ae030 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -31,9 +31,12 @@ as components in SQL expressions. 
import collections import typing from typing import Any +from typing import Dict +from typing import List from typing import MutableMapping from typing import Optional from typing import overload +from typing import Sequence as _typing_Sequence from typing import Type from typing import TypeVar from typing import Union @@ -52,6 +55,7 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import quoted_name +from .elements import SQLCoreOperations from .elements import TextClause from .selectable import TableClause from .type_api import to_instance @@ -64,9 +68,12 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from .type_api import TypeEngine + from ..engine import Connection + from ..engine import Engine _T = TypeVar("_T", bound="Any") _ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement] +_TAB = TypeVar("_TAB", bound="Table") RETAIN_SCHEMA = util.symbol("retain_schema") @@ -188,313 +195,6 @@ class Table(DialectKWArgs, SchemaItem, TableClause): :ref:`metadata_describing` - Introduction to database metadata - Constructor arguments are as follows: - - :param name: The name of this table as represented in the database. - - The table name, along with the value of the ``schema`` parameter, - forms a key which uniquely identifies this :class:`_schema.Table` - within - the owning :class:`_schema.MetaData` collection. - Additional calls to :class:`_schema.Table` with the same name, - metadata, - and schema name will return the same :class:`_schema.Table` object. - - Names which contain no upper case characters - will be treated as case insensitive names, and will not be quoted - unless they are a reserved word or contain special characters. - A name with any number of upper case characters is considered - to be case sensitive, and will be sent as quoted. 
- - To enable unconditional quoting for the table name, specify the flag - ``quote=True`` to the constructor, or use the :class:`.quoted_name` - construct to specify the name. - - :param metadata: a :class:`_schema.MetaData` - object which will contain this - table. The metadata is used as a point of association of this table - with other tables which are referenced via foreign key. It also - may be used to associate this table with a particular - :class:`.Connection` or :class:`.Engine`. - - :param \*args: Additional positional arguments are used primarily - to add the list of :class:`_schema.Column` - objects contained within this - table. Similar to the style of a CREATE TABLE statement, other - :class:`.SchemaItem` constructs may be added here, including - :class:`.PrimaryKeyConstraint`, and - :class:`_schema.ForeignKeyConstraint`. - - :param autoload: Defaults to ``False``, unless - :paramref:`_schema.Table.autoload_with` - is set in which case it defaults to ``True``; - :class:`_schema.Column` objects - for this table should be reflected from the database, possibly - augmenting objects that were explicitly specified. - :class:`_schema.Column` and other objects explicitly set on the - table will replace corresponding reflected objects. - - .. deprecated:: 1.4 - - The autoload parameter is deprecated and will be removed in - version 2.0. Please use the - :paramref:`_schema.Table.autoload_with` parameter, passing an - engine or connection. - - .. seealso:: - - :ref:`metadata_reflection_toplevel` - - :param autoload_replace: Defaults to ``True``; when using - :paramref:`_schema.Table.autoload` - in conjunction with :paramref:`_schema.Table.extend_existing`, - indicates - that :class:`_schema.Column` objects present in the already-existing - :class:`_schema.Table` - object should be replaced with columns of the same - name retrieved from the autoload process. 
When ``False``, columns - already present under existing names will be omitted from the - reflection process. - - Note that this setting does not impact :class:`_schema.Column` objects - specified programmatically within the call to :class:`_schema.Table` - that - also is autoloading; those :class:`_schema.Column` objects will always - replace existing columns of the same name when - :paramref:`_schema.Table.extend_existing` is ``True``. - - .. seealso:: - - :paramref:`_schema.Table.autoload` - - :paramref:`_schema.Table.extend_existing` - - :param autoload_with: An :class:`_engine.Engine` or - :class:`_engine.Connection` object, - or a :class:`_reflection.Inspector` object as returned by - :func:`_sa.inspect` - against one, with which this :class:`_schema.Table` - object will be reflected. - When set to a non-None value, the autoload process will take place - for this table against the given engine or connection. - - :param extend_existing: When ``True``, indicates that if this - :class:`_schema.Table` is already present in the given - :class:`_schema.MetaData`, - apply further arguments within the constructor to the existing - :class:`_schema.Table`. - - If :paramref:`_schema.Table.extend_existing` or - :paramref:`_schema.Table.keep_existing` are not set, - and the given name - of the new :class:`_schema.Table` refers to a :class:`_schema.Table` - that is - already present in the target :class:`_schema.MetaData` collection, - and - this :class:`_schema.Table` - specifies additional columns or other constructs - or flags that modify the table's state, an - error is raised. The purpose of these two mutually-exclusive flags - is to specify what action should be taken when a - :class:`_schema.Table` - is specified that matches an existing :class:`_schema.Table`, - yet specifies - additional constructs. 
- - :paramref:`_schema.Table.extend_existing` - will also work in conjunction - with :paramref:`_schema.Table.autoload` to run a new reflection - operation against the database, even if a :class:`_schema.Table` - of the same name is already present in the target - :class:`_schema.MetaData`; newly reflected :class:`_schema.Column` - objects - and other options will be added into the state of the - :class:`_schema.Table`, potentially overwriting existing columns - and options of the same name. - - As is always the case with :paramref:`_schema.Table.autoload`, - :class:`_schema.Column` objects can be specified in the same - :class:`_schema.Table` - constructor, which will take precedence. Below, the existing - table ``mytable`` will be augmented with :class:`_schema.Column` - objects - both reflected from the database, as well as the given - :class:`_schema.Column` - named "y":: - - Table("mytable", metadata, - Column('y', Integer), - extend_existing=True, - autoload_with=engine - ) - - .. seealso:: - - :paramref:`_schema.Table.autoload` - - :paramref:`_schema.Table.autoload_replace` - - :paramref:`_schema.Table.keep_existing` - - - :param implicit_returning: True by default - indicates that - RETURNING can be used by default to fetch newly inserted primary key - values, for backends which support this. Note that - :func:`_sa.create_engine` also provides an ``implicit_returning`` - flag. - - :param include_columns: A list of strings indicating a subset of - columns to be loaded via the ``autoload`` operation; table columns who - aren't present in this list will not be represented on the resulting - ``Table`` object. Defaults to ``None`` which indicates all columns - should be reflected. - - :param resolve_fks: Whether or not to reflect :class:`_schema.Table` - objects - related to this one via :class:`_schema.ForeignKey` objects, when - :paramref:`_schema.Table.autoload` or - :paramref:`_schema.Table.autoload_with` is - specified. Defaults to True. 
Set to False to disable reflection of - related tables as :class:`_schema.ForeignKey` - objects are encountered; may be - used either to save on SQL calls or to avoid issues with related tables - that can't be accessed. Note that if a related table is already present - in the :class:`_schema.MetaData` collection, or becomes present later, - a - :class:`_schema.ForeignKey` object associated with this - :class:`_schema.Table` will - resolve to that table normally. - - .. versionadded:: 1.3 - - .. seealso:: - - :paramref:`.MetaData.reflect.resolve_fks` - - - :param info: Optional data dictionary which will be populated into the - :attr:`.SchemaItem.info` attribute of this object. - - :param keep_existing: When ``True``, indicates that if this Table - is already present in the given :class:`_schema.MetaData`, ignore - further arguments within the constructor to the existing - :class:`_schema.Table`, and return the :class:`_schema.Table` - object as - originally created. This is to allow a function that wishes - to define a new :class:`_schema.Table` on first call, but on - subsequent calls will return the same :class:`_schema.Table`, - without any of the declarations (particularly constraints) - being applied a second time. - - If :paramref:`_schema.Table.extend_existing` or - :paramref:`_schema.Table.keep_existing` are not set, - and the given name - of the new :class:`_schema.Table` refers to a :class:`_schema.Table` - that is - already present in the target :class:`_schema.MetaData` collection, - and - this :class:`_schema.Table` - specifies additional columns or other constructs - or flags that modify the table's state, an - error is raised. The purpose of these two mutually-exclusive flags - is to specify what action should be taken when a - :class:`_schema.Table` - is specified that matches an existing :class:`_schema.Table`, - yet specifies - additional constructs. - - .. 
seealso:: - - :paramref:`_schema.Table.extend_existing` - - :param listeners: A list of tuples of the form ``(<eventname>, <fn>)`` - which will be passed to :func:`.event.listen` upon construction. - This alternate hook to :func:`.event.listen` allows the establishment - of a listener function specific to this :class:`_schema.Table` before - the "autoload" process begins. Historically this has been intended - for use with the :meth:`.DDLEvents.column_reflect` event, however - note that this event hook may now be associated with the - :class:`_schema.MetaData` object directly:: - - def listen_for_reflect(table, column_info): - "handle the column reflection event" - # ... - - t = Table( - 'sometable', - autoload_with=engine, - listeners=[ - ('column_reflect', listen_for_reflect) - ]) - - .. seealso:: - - :meth:`_events.DDLEvents.column_reflect` - - :param must_exist: When ``True``, indicates that this Table must already - be present in the given :class:`_schema.MetaData` collection, else - an exception is raised. - - :param prefixes: - A list of strings to insert after CREATE in the CREATE TABLE - statement. They will be separated by spaces. - - :param quote: Force quoting of this table's name on or off, corresponding - to ``True`` or ``False``. When left at its default of ``None``, - the column identifier will be quoted according to whether the name is - case sensitive (identifiers with at least one upper case character are - treated as case sensitive), or if it's a reserved word. This flag - is only needed to force quoting of a reserved word which is not known - by the SQLAlchemy dialect. - - .. note:: setting this flag to ``False`` will not provide - case-insensitive behavior for table reflection; table reflection - will always search for a mixed-case name in a case sensitive - fashion. Case insensitive names are specified in SQLAlchemy only - by stating the name with all lower case characters. 
- - :param quote_schema: same as 'quote' but applies to the schema identifier. - - :param schema: The schema name for this table, which is required if - the table resides in a schema other than the default selected schema - for the engine's database connection. Defaults to ``None``. - - If the owning :class:`_schema.MetaData` of this :class:`_schema.Table` - specifies its - own :paramref:`_schema.MetaData.schema` parameter, - then that schema name will - be applied to this :class:`_schema.Table` - if the schema parameter here is set - to ``None``. To set a blank schema name on a :class:`_schema.Table` - that - would otherwise use the schema set on the owning - :class:`_schema.MetaData`, - specify the special symbol :attr:`.BLANK_SCHEMA`. - - .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to - allow a :class:`_schema.Table` - to have a blank schema name even when the - parent :class:`_schema.MetaData` specifies - :paramref:`_schema.MetaData.schema`. - - The quoting rules for the schema name are the same as those for the - ``name`` parameter, in that quoting is applied for reserved words or - case-sensitive names; to enable unconditional quoting for the schema - name, specify the flag ``quote_schema=True`` to the constructor, or use - the :class:`.quoted_name` construct to specify the name. - - :param comment: Optional string that will render an SQL comment on table - creation. - - .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment` - parameter - to :class:`_schema.Table`. - - :param \**kw: Additional keyword arguments not mentioned above are - dialect specific, and passed in the form ``<dialectname>_<argname>``. - See the documentation regarding an individual dialect at - :ref:`dialect_toplevel` for detail on documented arguments. 
- """ __visit_name__ = "table" @@ -547,13 +247,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): else: return (self,) - @util.deprecated_params( - mustexist=( - "1.4", - "Deprecated alias of :paramref:`_schema.Table.must_exist`", - ), - ) - def __new__(cls, *args, **kw): + if not typing.TYPE_CHECKING: + # typing tools seem to be inconsistent in how they handle + # __new__, so suggest this pattern for classes that use + # __new__. apply typing to the __init__ method normally + @util.deprecated_params( + mustexist=( + "1.4", + "Deprecated alias of :paramref:`_schema.Table.must_exist`", + ), + ) + def __new__(cls, *args: Any, **kw: Any) -> Any: + return cls._new(*args, **kw) + + @classmethod + def _new(cls, *args, **kw): if not args and not kw: # python3k pickle seems to call this return object.__new__(cls) @@ -607,14 +315,323 @@ class Table(DialectKWArgs, SchemaItem, TableClause): with util.safe_reraise(): metadata._remove_table(name, schema) - def __init__(self, *args, **kw): - """Constructor for :class:`_schema.Table`. + def __init__( + self, + name: str, + metadata: "MetaData", + *args: SchemaItem, + **kw: Any, + ): + r"""Constructor for :class:`_schema.Table`. - This method is a no-op. See the top-level - documentation for :class:`_schema.Table` - for constructor arguments. - """ + :param name: The name of this table as represented in the database. + + The table name, along with the value of the ``schema`` parameter, + forms a key which uniquely identifies this :class:`_schema.Table` + within + the owning :class:`_schema.MetaData` collection. + Additional calls to :class:`_schema.Table` with the same name, + metadata, + and schema name will return the same :class:`_schema.Table` object. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word or contain special characters. 
+ A name with any number of upper case characters is considered + to be case sensitive, and will be sent as quoted. + + To enable unconditional quoting for the table name, specify the flag + ``quote=True`` to the constructor, or use the :class:`.quoted_name` + construct to specify the name. + + :param metadata: a :class:`_schema.MetaData` + object which will contain this + table. The metadata is used as a point of association of this table + with other tables which are referenced via foreign key. It also + may be used to associate this table with a particular + :class:`.Connection` or :class:`.Engine`. + + :param \*args: Additional positional arguments are used primarily + to add the list of :class:`_schema.Column` + objects contained within this + table. Similar to the style of a CREATE TABLE statement, other + :class:`.SchemaItem` constructs may be added here, including + :class:`.PrimaryKeyConstraint`, and + :class:`_schema.ForeignKeyConstraint`. + + :param autoload: Defaults to ``False``, unless + :paramref:`_schema.Table.autoload_with` + is set in which case it defaults to ``True``; + :class:`_schema.Column` objects + for this table should be reflected from the database, possibly + augmenting objects that were explicitly specified. + :class:`_schema.Column` and other objects explicitly set on the + table will replace corresponding reflected objects. + + .. deprecated:: 1.4 + + The autoload parameter is deprecated and will be removed in + version 2.0. Please use the + :paramref:`_schema.Table.autoload_with` parameter, passing an + engine or connection. + + .. 
seealso:: + + :ref:`metadata_reflection_toplevel` + + :param autoload_replace: Defaults to ``True``; when using + :paramref:`_schema.Table.autoload` + in conjunction with :paramref:`_schema.Table.extend_existing`, + indicates + that :class:`_schema.Column` objects present in the already-existing + :class:`_schema.Table` + object should be replaced with columns of the same + name retrieved from the autoload process. When ``False``, columns + already present under existing names will be omitted from the + reflection process. + + Note that this setting does not impact :class:`_schema.Column` objects + specified programmatically within the call to :class:`_schema.Table` + that + also is autoloading; those :class:`_schema.Column` objects will always + replace existing columns of the same name when + :paramref:`_schema.Table.extend_existing` is ``True``. + + .. seealso:: + + :paramref:`_schema.Table.autoload` + + :paramref:`_schema.Table.extend_existing` + + :param autoload_with: An :class:`_engine.Engine` or + :class:`_engine.Connection` object, + or a :class:`_reflection.Inspector` object as returned by + :func:`_sa.inspect` + against one, with which this :class:`_schema.Table` + object will be reflected. + When set to a non-None value, the autoload process will take place + for this table against the given engine or connection. + + :param extend_existing: When ``True``, indicates that if this + :class:`_schema.Table` is already present in the given + :class:`_schema.MetaData`, + apply further arguments within the constructor to the existing + :class:`_schema.Table`. 
+ + If :paramref:`_schema.Table.extend_existing` or + :paramref:`_schema.Table.keep_existing` are not set, + and the given name + of the new :class:`_schema.Table` refers to a :class:`_schema.Table` + that is + already present in the target :class:`_schema.MetaData` collection, + and + this :class:`_schema.Table` + specifies additional columns or other constructs + or flags that modify the table's state, an + error is raised. The purpose of these two mutually-exclusive flags + is to specify what action should be taken when a + :class:`_schema.Table` + is specified that matches an existing :class:`_schema.Table`, + yet specifies + additional constructs. + + :paramref:`_schema.Table.extend_existing` + will also work in conjunction + with :paramref:`_schema.Table.autoload` to run a new reflection + operation against the database, even if a :class:`_schema.Table` + of the same name is already present in the target + :class:`_schema.MetaData`; newly reflected :class:`_schema.Column` + objects + and other options will be added into the state of the + :class:`_schema.Table`, potentially overwriting existing columns + and options of the same name. + + As is always the case with :paramref:`_schema.Table.autoload`, + :class:`_schema.Column` objects can be specified in the same + :class:`_schema.Table` + constructor, which will take precedence. Below, the existing + table ``mytable`` will be augmented with :class:`_schema.Column` + objects + both reflected from the database, as well as the given + :class:`_schema.Column` + named "y":: + + Table("mytable", metadata, + Column('y', Integer), + extend_existing=True, + autoload_with=engine + ) + + .. seealso:: + + :paramref:`_schema.Table.autoload` + + :paramref:`_schema.Table.autoload_replace` + + :paramref:`_schema.Table.keep_existing` + + + :param implicit_returning: True by default - indicates that + RETURNING can be used by default to fetch newly inserted primary key + values, for backends which support this. 
Note that + :func:`_sa.create_engine` also provides an ``implicit_returning`` + flag. + + :param include_columns: A list of strings indicating a subset of + columns to be loaded via the ``autoload`` operation; table columns who + aren't present in this list will not be represented on the resulting + ``Table`` object. Defaults to ``None`` which indicates all columns + should be reflected. + + :param resolve_fks: Whether or not to reflect :class:`_schema.Table` + objects + related to this one via :class:`_schema.ForeignKey` objects, when + :paramref:`_schema.Table.autoload` or + :paramref:`_schema.Table.autoload_with` is + specified. Defaults to True. Set to False to disable reflection of + related tables as :class:`_schema.ForeignKey` + objects are encountered; may be + used either to save on SQL calls or to avoid issues with related tables + that can't be accessed. Note that if a related table is already present + in the :class:`_schema.MetaData` collection, or becomes present later, + a + :class:`_schema.ForeignKey` object associated with this + :class:`_schema.Table` will + resolve to that table normally. + + .. versionadded:: 1.3 + + .. seealso:: + + :paramref:`.MetaData.reflect.resolve_fks` + + + :param info: Optional data dictionary which will be populated into the + :attr:`.SchemaItem.info` attribute of this object. + + :param keep_existing: When ``True``, indicates that if this Table + is already present in the given :class:`_schema.MetaData`, ignore + further arguments within the constructor to the existing + :class:`_schema.Table`, and return the :class:`_schema.Table` + object as + originally created. This is to allow a function that wishes + to define a new :class:`_schema.Table` on first call, but on + subsequent calls will return the same :class:`_schema.Table`, + without any of the declarations (particularly constraints) + being applied a second time. 
+ + If :paramref:`_schema.Table.extend_existing` or + :paramref:`_schema.Table.keep_existing` are not set, + and the given name + of the new :class:`_schema.Table` refers to a :class:`_schema.Table` + that is + already present in the target :class:`_schema.MetaData` collection, + and + this :class:`_schema.Table` + specifies additional columns or other constructs + or flags that modify the table's state, an + error is raised. The purpose of these two mutually-exclusive flags + is to specify what action should be taken when a + :class:`_schema.Table` + is specified that matches an existing :class:`_schema.Table`, + yet specifies + additional constructs. + + .. seealso:: + + :paramref:`_schema.Table.extend_existing` + + :param listeners: A list of tuples of the form ``(<eventname>, <fn>)`` + which will be passed to :func:`.event.listen` upon construction. + This alternate hook to :func:`.event.listen` allows the establishment + of a listener function specific to this :class:`_schema.Table` before + the "autoload" process begins. Historically this has been intended + for use with the :meth:`.DDLEvents.column_reflect` event, however + note that this event hook may now be associated with the + :class:`_schema.MetaData` object directly:: + + def listen_for_reflect(table, column_info): + "handle the column reflection event" + # ... + + t = Table( + 'sometable', + autoload_with=engine, + listeners=[ + ('column_reflect', listen_for_reflect) + ]) + + .. seealso:: + + :meth:`_events.DDLEvents.column_reflect` + + :param must_exist: When ``True``, indicates that this Table must already + be present in the given :class:`_schema.MetaData` collection, else + an exception is raised. + + :param prefixes: + A list of strings to insert after CREATE in the CREATE TABLE + statement. They will be separated by spaces. + + :param quote: Force quoting of this table's name on or off, corresponding + to ``True`` or ``False``. 
When left at its default of ``None``, + the column identifier will be quoted according to whether the name is + case sensitive (identifiers with at least one upper case character are + treated as case sensitive), or if it's a reserved word. This flag + is only needed to force quoting of a reserved word which is not known + by the SQLAlchemy dialect. + + .. note:: setting this flag to ``False`` will not provide + case-insensitive behavior for table reflection; table reflection + will always search for a mixed-case name in a case sensitive + fashion. Case insensitive names are specified in SQLAlchemy only + by stating the name with all lower case characters. + + :param quote_schema: same as 'quote' but applies to the schema identifier. + + :param schema: The schema name for this table, which is required if + the table resides in a schema other than the default selected schema + for the engine's database connection. Defaults to ``None``. + + If the owning :class:`_schema.MetaData` of this :class:`_schema.Table` + specifies its + own :paramref:`_schema.MetaData.schema` parameter, + then that schema name will + be applied to this :class:`_schema.Table` + if the schema parameter here is set + to ``None``. To set a blank schema name on a :class:`_schema.Table` + that + would otherwise use the schema set on the owning + :class:`_schema.MetaData`, + specify the special symbol :attr:`.BLANK_SCHEMA`. + + .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to + allow a :class:`_schema.Table` + to have a blank schema name even when the + parent :class:`_schema.MetaData` specifies + :paramref:`_schema.MetaData.schema`. + + The quoting rules for the schema name are the same as those for the + ``name`` parameter, in that quoting is applied for reserved words or + case-sensitive names; to enable unconditional quoting for the schema + name, specify the flag ``quote_schema=True`` to the constructor, or use + the :class:`.quoted_name` construct to specify the name. 
+ + :param comment: Optional string that will render an SQL comment on table + creation. + + .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment` + parameter + to :class:`_schema.Table`. + + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form ``<dialectname>_<argname>``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ # noqa E501 + # __init__ is overridden to prevent __new__ from # calling the superclass constructor. @@ -1203,7 +1220,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) -> None: ... - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): r""" Construct a new ``Column`` object. @@ -2179,18 +2196,18 @@ class ForeignKey(DialectKWArgs, SchemaItem): def __init__( self, - column, - _constraint=None, - use_alter=False, - name=None, - onupdate=None, - ondelete=None, - deferrable=None, - initially=None, - link_to_name=False, - match=None, - info=None, - **dialect_kw, + column: Union[str, Column, SQLCoreOperations], + _constraint: Optional["ForeignKeyConstraint"] = None, + use_alter: bool = False, + name: Optional[str] = None, + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[bool] = None, + link_to_name: bool = False, + match: Optional[str] = None, + info: Optional[Dict[Any, Any]] = None, + **dialect_kw: Any, ): r""" Construct a column-level FOREIGN KEY. @@ -2337,7 +2354,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): ) return self._schema_item_copy(fk) - def _get_colspec(self, schema=None, table_name=None): + def _get_colspec(self, schema=None, table_name=None, _is_copy=False): """Return a string based 'column specification' for this :class:`_schema.ForeignKey`. 
@@ -2357,6 +2374,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): else: return "%s.%s" % (table_name, colname) elif self._table_column is not None: + if self._table_column.table is None: + if _is_copy: + raise exc.InvalidRequestError( + f"Can't copy ForeignKey object which refers to " + f"non-table bound Column {self._table_column!r}" + ) + else: + return self._table_column.key return "%s.%s" % ( self._table_column.table.fullname, self._table_column.key, @@ -3858,6 +3883,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): if target_table is not None and x._table_key() == x.parent.table.key else None, + _is_copy=True, ) for x in self.elements ], @@ -4331,10 +4357,10 @@ class MetaData(SchemaItem): def __init__( self, - schema=None, - quote_schema=None, - naming_convention=None, - info=None, + schema: Optional[str] = None, + quote_schema: Optional[bool] = None, + naming_convention: Optional[Dict[str, str]] = None, + info: Optional[Dict[Any, Any]] = None, ): """Create a new MetaData object. @@ -4465,7 +4491,7 @@ class MetaData(SchemaItem): self._sequences = {} self._fk_memos = collections.defaultdict(list) - tables = None + tables: Dict[str, Table] """A dictionary of :class:`_schema.Table` objects keyed to their name or "table key". 
@@ -4483,10 +4509,10 @@ class MetaData(SchemaItem): """ - def __repr__(self): + def __repr__(self) -> str: return "MetaData()" - def __contains__(self, table_or_key): + def __contains__(self, table_or_key: Union[str, Table]) -> bool: if not isinstance(table_or_key, str): table_or_key = table_or_key.key return table_or_key in self.tables @@ -4530,20 +4556,20 @@ class MetaData(SchemaItem): self._schemas = state["schemas"] self._fk_memos = state["fk_memos"] - def clear(self): + def clear(self) -> None: """Clear all Table objects from this MetaData.""" dict.clear(self.tables) self._schemas.clear() self._fk_memos.clear() - def remove(self, table): + def remove(self, table: Table) -> None: """Remove the given Table object from this MetaData.""" self._remove_table(table.name, table.schema) @property - def sorted_tables(self): + def sorted_tables(self) -> List[Table]: """Returns a list of :class:`_schema.Table` objects sorted in order of foreign key dependency. @@ -4599,14 +4625,14 @@ class MetaData(SchemaItem): def reflect( self, - bind, - schema=None, - views=False, - only=None, - extend_existing=False, - autoload_replace=True, - resolve_fks=True, - **dialect_kwargs, + bind: Union["Engine", "Connection"], + schema: Optional[str] = None, + views: bool = False, + only: Optional[_typing_Sequence[str]] = None, + extend_existing: bool = False, + autoload_replace: bool = True, + resolve_fks: bool = True, + **dialect_kwargs: Any, ): r"""Load all available table definitions from the database. @@ -4754,7 +4780,12 @@ class MetaData(SchemaItem): except exc.UnreflectableTableError as uerr: util.warn("Skipping table %s: %s" % (name, uerr)) - def create_all(self, bind, tables=None, checkfirst=True): + def create_all( + self, + bind: Union["Engine", "Connection"], + tables: Optional[_typing_Sequence[Table]] = None, + checkfirst: bool = True, + ): """Create all tables stored in this metadata. 
Conditional by default, will not attempt to recreate tables already @@ -4777,7 +4808,12 @@ class MetaData(SchemaItem): ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables ) - def drop_all(self, bind, tables=None, checkfirst=True): + def drop_all( + self, + bind: Union["Engine", "Connection"], + tables: Optional[_typing_Sequence[Table]] = None, + checkfirst: bool = True, + ): """Drop all tables stored in this metadata. Conditional by default, will not attempt to drop tables not present in diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e1bbcffec..b0985f75d 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -12,14 +12,13 @@ SQL tables and derived rowsets. """ import collections +from enum import Enum import itertools from operator import attrgetter import typing from typing import Any as TODO_Any from typing import Optional from typing import Tuple -from typing import Type -from typing import Union from . import cache_key from . import coercions @@ -28,6 +27,7 @@ from . import roles from . import traversals from . import type_api from . import visitors +from ._typing import _ColumnsClauseElement from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -847,8 +847,11 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.alias(name=name) -LABEL_STYLE_NONE = util.symbol( - "LABEL_STYLE_NONE", +class SelectLabelStyle(Enum): + """Label style constants that may be passed to + :meth:`_sql.Select.set_label_style`.""" + + LABEL_STYLE_NONE = 0 """Label style indicating no automatic labeling should be applied to the columns clause of a SELECT statement. @@ -867,11 +870,9 @@ LABEL_STYLE_NONE = util.symbol( .. 
versionadded:: 1.4 -""", # noqa E501 -) + """ # noqa E501 -LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol( - "LABEL_STYLE_TABLENAME_PLUS_COL", + LABEL_STYLE_TABLENAME_PLUS_COL = 1 """Label style indicating all columns should be labeled as ``<tablename>_<columnname>`` when generating the columns clause of a SELECT statement, to disambiguate same-named columns referenced from different @@ -897,12 +898,9 @@ LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol( .. versionadded:: 1.4 -""", # noqa E501 -) + """ # noqa: E501 - -LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol( - "LABEL_STYLE_DISAMBIGUATE_ONLY", + LABEL_STYLE_DISAMBIGUATE_ONLY = 2 """Label style indicating that columns with a name that conflicts with an existing name should be labeled with a semi-anonymizing label when generating the columns clause of a SELECT statement. @@ -924,17 +922,24 @@ LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol( .. versionadded:: 1.4 -""", # noqa: E501, -) + """ # noqa: E501 + LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY + """The default label style, refers to + :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. -LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY -"""The default label style, refers to -:data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. + .. versionadded:: 1.4 -.. 
versionadded:: 1.4 + """ -""" + +( + LABEL_STYLE_NONE, + LABEL_STYLE_TABLENAME_PLUS_COL, + LABEL_STYLE_DISAMBIGUATE_ONLY, +) = list(SelectLabelStyle) + +LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY class Join(roles.DMLTableRole, FromClause): @@ -2870,10 +2875,12 @@ class SelectStatementGrouping(GroupedElement, SelectBase): else: return self - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: return self._label_style - def set_label_style(self, label_style): + def set_label_style( + self, label_style: SelectLabelStyle + ) -> "SelectStatementGrouping": return SelectStatementGrouping( self.element.set_label_style(label_style) ) @@ -3018,7 +3025,7 @@ class GenerativeSelect(SelectBase): ) return self - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: """ Retrieve the current label style. @@ -3027,14 +3034,16 @@ class GenerativeSelect(SelectBase): """ return self._label_style - def set_label_style(self, style): + def set_label_style( + self: SelfGenerativeSelect, style: SelectLabelStyle + ) -> SelfGenerativeSelect: """Return a new selectable with the specified label style. There are three "label styles" available, - :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`, - :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`, and - :data:`_sql.LABEL_STYLE_NONE`. The default style is - :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`. + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY`, + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`, and + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_NONE`. The default style is + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`. 
In modern SQLAlchemy, there is not generally a need to change the labeling style, as per-expression labels are more effectively used by @@ -4131,7 +4140,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: Union[roles.ColumnsClauseRole, Type]): + def __init__(self, *entities: _ColumnsClauseElement): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index dd29b2c3a..6b878dc70 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -13,6 +13,7 @@ import typing from typing import Any from typing import Callable from typing import Generic +from typing import Optional from typing import Tuple from typing import Type from typing import TypeVar @@ -21,7 +22,7 @@ from typing import Union from .base import SchemaEventTarget from .cache_key import NO_CACHE from .operators import ColumnOperators -from .visitors import Traversible +from .visitors import Visitable from .. import exc from .. import util @@ -52,7 +53,7 @@ _CT = TypeVar("_CT", bound=Any) SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine") -class TypeEngine(Traversible, Generic[_T]): +class TypeEngine(Visitable, Generic[_T]): """The ultimate base class for all SQL datatypes. Common subclasses of :class:`.TypeEngine` include @@ -573,7 +574,7 @@ class TypeEngine(Traversible, Generic[_T]): raise NotImplementedError() def with_variant( - self: SelfTypeEngine, type_: "TypeEngine", dialect_name: str + self: SelfTypeEngine, type_: "TypeEngine", *dialect_names: str ) -> SelfTypeEngine: r"""Produce a copy of this type object that will utilize the given type when applied to the dialect of the given name. 
@@ -586,7 +587,7 @@ class TypeEngine(Traversible, Generic[_T]): string_type = String() string_type = string_type.with_variant( - mysql.VARCHAR(collation='foo'), 'mysql' + mysql.VARCHAR(collation='foo'), 'mysql', 'mariadb' ) The variant mapping indicates that when this type is @@ -602,16 +603,20 @@ class TypeEngine(Traversible, Generic[_T]): :param type\_: a :class:`.TypeEngine` that will be selected as a variant from the originating type, when a dialect of the given name is in use. - :param dialect_name: base name of the dialect which uses - this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + :param \*dialect_names: one or more base names of the dialect which + uses this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + + .. versionchanged:: 2.0 multiple dialect names can be specified + for one variant. """ - if dialect_name in self._variant_mapping: - raise exc.ArgumentError( - "Dialect '%s' is already present in " - "the mapping for this %r" % (dialect_name, self) - ) + for dialect_name in dialect_names: + if dialect_name in self._variant_mapping: + raise exc.ArgumentError( + "Dialect '%s' is already present in " + "the mapping for this %r" % (dialect_name, self) + ) new_type = self.copy() if isinstance(type_, type): type_ = type_() @@ -620,8 +625,9 @@ class TypeEngine(Traversible, Generic[_T]): "can't pass a type that already has variants as a " "dialect-level type to with_variant()" ) + new_type._variant_mapping = self._variant_mapping.union( - {dialect_name: type_} + {dialect_name: type_ for dialect_name in dialect_names} ) return new_type @@ -919,7 +925,7 @@ class ExternalType: """ - cache_ok = None + cache_ok: Optional[bool] = None """Indicate if statements using this :class:`.ExternalType` are "safe to cache". 
@@ -1357,6 +1363,8 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): _is_type_decorator = True + impl: Union[TypeEngine[Any], Type[TypeEngine[Any]]] + def __init__(self, *args, **kwargs): """Construct a :class:`.TypeDecorator`. diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 268a56421..c1ca670da 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -6,6 +6,11 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php import collections +import typing +from typing import Any +from typing import Iterable +from typing import Tuple +from typing import Union from .. import util @@ -20,10 +25,15 @@ any_async = False _current = None ident = "main" -_fixture_functions = None # installed by plugin_base +if typing.TYPE_CHECKING: + from .plugin.plugin_base import FixtureFunctions + _fixture_functions: FixtureFunctions +else: + _fixture_functions = None # installed by plugin_base -def combinations(*comb, **kw): + +def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str): r"""Deliver multiple versions of a test based on positional combinations. This is a facade over pytest.mark.parametrize. 
@@ -89,25 +99,32 @@ def combinations(*comb, **kw): return _fixture_functions.combinations(*comb, **kw) -def combinations_list(arg_iterable, **kw): +def combinations_list( + arg_iterable: Iterable[ + Tuple[ + Any, + ] + ], + **kw, +): "As combination, but takes a single iterable" return combinations(*arg_iterable, **kw) -def fixture(*arg, **kw): +def fixture(*arg: Any, **kw: Any) -> Any: return _fixture_functions.fixture(*arg, **kw) -def get_current_test_name(): +def get_current_test_name() -> str: return _fixture_functions.get_current_test_name() -def mark_base_test_class(): +def mark_base_test_class() -> Any: return _fixture_functions.mark_base_test_class() class _AddToMarker: - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: return getattr(_fixture_functions.add_to_marker, attr) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ecc20f163..7228e5afe 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -20,6 +20,7 @@ from .util import drop_all_tables_from_metadata from .. import event from .. 
import util from ..orm import declarative_base +from ..orm import DeclarativeBase from ..orm import registry from ..schema import sort_tables_and_constraints @@ -82,6 +83,21 @@ class TestBase: yield reg reg.dispose() + @config.fixture + def decl_base(self, metadata): + _md = metadata + + class Base(DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + @config.fixture() def future_connection(self, future_engine, connection): # integrate the future_engine and connection fixtures so diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 0b4451b3c..52e42bb97 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -19,6 +19,7 @@ import logging import os import re import sys +from typing import Any from sqlalchemy.testing import asyncio @@ -738,7 +739,7 @@ class FixtureFunctions(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def mark_base_test_class(self): + def mark_base_test_class(self) -> Any: raise NotImplementedError() @abc.abstractproperty diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 410ab26ed..41e5d6772 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1326,6 +1326,18 @@ class SuiteRequirements(Requirements): return exclusions.only_if(check) @property + def no_sqlalchemy2_stubs(self): + def check(config): + try: + __import__("sqlalchemy-stubs.ext.mypy") + except ImportError: + return False + else: + return True + + return exclusions.skip_if(check) + + @property def python38(self): return exclusions.only_if( lambda: util.py38, "Python 3.8 or above required" diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 91d15aae0..85bbca20f 100644 --- a/lib/sqlalchemy/util/__init__.py +++ 
b/lib/sqlalchemy/util/__init__.py @@ -6,131 +6,135 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php -from collections import defaultdict -from functools import partial -from functools import update_wrapper +from collections import defaultdict as defaultdict +from functools import partial as partial +from functools import update_wrapper as update_wrapper -from ._collections import coerce_generator_arg -from ._collections import coerce_to_immutabledict -from ._collections import column_dict -from ._collections import column_set -from ._collections import EMPTY_DICT -from ._collections import EMPTY_SET -from ._collections import FacadeDict -from ._collections import flatten_iterator -from ._collections import has_dupes -from ._collections import has_intersection -from ._collections import IdentitySet -from ._collections import ImmutableContainer -from ._collections import immutabledict -from ._collections import ImmutableProperties -from ._collections import LRUCache -from ._collections import ordered_column_set -from ._collections import OrderedDict -from ._collections import OrderedIdentitySet -from ._collections import OrderedProperties -from ._collections import OrderedSet -from ._collections import PopulateDict -from ._collections import Properties -from ._collections import ScopedRegistry -from ._collections import sort_dictionary -from ._collections import ThreadLocalRegistry -from ._collections import to_column_set -from ._collections import to_list -from ._collections import to_set -from ._collections import unique_list -from ._collections import UniqueAppender -from ._collections import update_copy -from ._collections import WeakPopulateDict -from ._collections import WeakSequence -from ._preloaded import preload_module -from ._preloaded import preloaded -from .compat import arm -from .compat import b -from .compat import b64decode -from .compat import b64encode -from .compat import cmp -from .compat import cpython -from .compat 
import dataclass_fields -from .compat import decode_backslashreplace -from .compat import dottedgetter -from .compat import has_refcount_gc -from .compat import inspect_getfullargspec -from .compat import local_dataclass_fields -from .compat import next -from .compat import osx -from .compat import py38 -from .compat import py39 -from .compat import pypy -from .compat import win32 -from .concurrency import asyncio -from .concurrency import await_fallback -from .concurrency import await_only -from .concurrency import greenlet_spawn -from .concurrency import is_exit_exception -from .deprecations import became_legacy_20 -from .deprecations import deprecated -from .deprecations import deprecated_cls -from .deprecations import deprecated_params -from .deprecations import deprecated_property -from .deprecations import inject_docstring_text -from .deprecations import moved_20 -from .deprecations import warn_deprecated -from .langhelpers import add_parameter_text -from .langhelpers import as_interface -from .langhelpers import asbool -from .langhelpers import asint -from .langhelpers import assert_arg_type -from .langhelpers import attrsetter -from .langhelpers import bool_or_str -from .langhelpers import chop_traceback -from .langhelpers import class_hierarchy -from .langhelpers import classproperty -from .langhelpers import clsname_as_plain_name -from .langhelpers import coerce_kw_type -from .langhelpers import constructor_copy -from .langhelpers import constructor_key -from .langhelpers import counter -from .langhelpers import create_proxy_methods -from .langhelpers import decode_slice -from .langhelpers import decorator -from .langhelpers import dictlike_iteritems -from .langhelpers import duck_type_collection -from .langhelpers import ellipses_string -from .langhelpers import EnsureKWArg -from .langhelpers import format_argspec_init -from .langhelpers import format_argspec_plus -from .langhelpers import generic_repr -from .langhelpers import get_callable_argspec -from 
.langhelpers import get_cls_kwargs -from .langhelpers import get_func_kwargs -from .langhelpers import getargspec_init -from .langhelpers import has_compiled_ext -from .langhelpers import HasMemoized -from .langhelpers import hybridmethod -from .langhelpers import hybridproperty -from .langhelpers import iterate_attributes -from .langhelpers import map_bits -from .langhelpers import md5_hex -from .langhelpers import memoized_instancemethod -from .langhelpers import memoized_property -from .langhelpers import MemoizedSlots -from .langhelpers import method_is_overridden -from .langhelpers import methods_equivalent -from .langhelpers import monkeypatch_proxied_specials -from .langhelpers import NoneType -from .langhelpers import only_once -from .langhelpers import PluginLoader -from .langhelpers import portable_instancemethod -from .langhelpers import quoted_token_parser -from .langhelpers import safe_reraise -from .langhelpers import set_creation_order -from .langhelpers import string_or_unprintable -from .langhelpers import symbol -from .langhelpers import TypingOnly -from .langhelpers import unbound_method_to_callable -from .langhelpers import walk_subclasses -from .langhelpers import warn -from .langhelpers import warn_exception -from .langhelpers import warn_limited -from .langhelpers import wrap_callable +from ._collections import coerce_generator_arg as coerce_generator_arg +from ._collections import coerce_to_immutabledict as coerce_to_immutabledict +from ._collections import column_dict as column_dict +from ._collections import column_set as column_set +from ._collections import EMPTY_DICT as EMPTY_DICT +from ._collections import EMPTY_SET as EMPTY_SET +from ._collections import FacadeDict as FacadeDict +from ._collections import flatten_iterator as flatten_iterator +from ._collections import has_dupes as has_dupes +from ._collections import has_intersection as has_intersection +from ._collections import IdentitySet as IdentitySet +from ._collections import 
ImmutableContainer as ImmutableContainer +from ._collections import immutabledict as immutabledict +from ._collections import ImmutableProperties as ImmutableProperties +from ._collections import LRUCache as LRUCache +from ._collections import merge_lists_w_ordering as merge_lists_w_ordering +from ._collections import ordered_column_set as ordered_column_set +from ._collections import OrderedDict as OrderedDict +from ._collections import OrderedIdentitySet as OrderedIdentitySet +from ._collections import OrderedProperties as OrderedProperties +from ._collections import OrderedSet as OrderedSet +from ._collections import PopulateDict as PopulateDict +from ._collections import Properties as Properties +from ._collections import ScopedRegistry as ScopedRegistry +from ._collections import sort_dictionary as sort_dictionary +from ._collections import ThreadLocalRegistry as ThreadLocalRegistry +from ._collections import to_column_set as to_column_set +from ._collections import to_list as to_list +from ._collections import to_set as to_set +from ._collections import unique_list as unique_list +from ._collections import UniqueAppender as UniqueAppender +from ._collections import update_copy as update_copy +from ._collections import WeakPopulateDict as WeakPopulateDict +from ._collections import WeakSequence as WeakSequence +from ._preloaded import preload_module as preload_module +from ._preloaded import preloaded as preloaded +from .compat import arm as arm +from .compat import b as b +from .compat import b64decode as b64decode +from .compat import b64encode as b64encode +from .compat import cmp as cmp +from .compat import cpython as cpython +from .compat import dataclass_fields as dataclass_fields +from .compat import decode_backslashreplace as decode_backslashreplace +from .compat import dottedgetter as dottedgetter +from .compat import has_refcount_gc as has_refcount_gc +from .compat import inspect_getfullargspec as inspect_getfullargspec +from .compat import 
local_dataclass_fields as local_dataclass_fields +from .compat import osx as osx +from .compat import py38 as py38 +from .compat import py39 as py39 +from .compat import pypy as pypy +from .compat import win32 as win32 +from .concurrency import await_fallback as await_fallback +from .concurrency import await_only as await_only +from .concurrency import greenlet_spawn as greenlet_spawn +from .concurrency import is_exit_exception as is_exit_exception +from .deprecations import became_legacy_20 as became_legacy_20 +from .deprecations import deprecated as deprecated +from .deprecations import deprecated_cls as deprecated_cls +from .deprecations import deprecated_params as deprecated_params +from .deprecations import deprecated_property as deprecated_property +from .deprecations import moved_20 as moved_20 +from .deprecations import warn_deprecated as warn_deprecated +from .langhelpers import add_parameter_text as add_parameter_text +from .langhelpers import as_interface as as_interface +from .langhelpers import asbool as asbool +from .langhelpers import asint as asint +from .langhelpers import assert_arg_type as assert_arg_type +from .langhelpers import attrsetter as attrsetter +from .langhelpers import bool_or_str as bool_or_str +from .langhelpers import chop_traceback as chop_traceback +from .langhelpers import class_hierarchy as class_hierarchy +from .langhelpers import classproperty as classproperty +from .langhelpers import clsname_as_plain_name as clsname_as_plain_name +from .langhelpers import coerce_kw_type as coerce_kw_type +from .langhelpers import constructor_copy as constructor_copy +from .langhelpers import constructor_key as constructor_key +from .langhelpers import counter as counter +from .langhelpers import create_proxy_methods as create_proxy_methods +from .langhelpers import decode_slice as decode_slice +from .langhelpers import decorator as decorator +from .langhelpers import dictlike_iteritems as dictlike_iteritems +from .langhelpers import 
duck_type_collection as duck_type_collection +from .langhelpers import ellipses_string as ellipses_string +from .langhelpers import EnsureKWArg as EnsureKWArg +from .langhelpers import format_argspec_init as format_argspec_init +from .langhelpers import format_argspec_plus as format_argspec_plus +from .langhelpers import generic_repr as generic_repr +from .langhelpers import get_annotations as get_annotations +from .langhelpers import get_callable_argspec as get_callable_argspec +from .langhelpers import get_cls_kwargs as get_cls_kwargs +from .langhelpers import get_func_kwargs as get_func_kwargs +from .langhelpers import getargspec_init as getargspec_init +from .langhelpers import has_compiled_ext as has_compiled_ext +from .langhelpers import HasMemoized as HasMemoized +from .langhelpers import hybridmethod as hybridmethod +from .langhelpers import hybridproperty as hybridproperty +from .langhelpers import inject_docstring_text as inject_docstring_text +from .langhelpers import iterate_attributes as iterate_attributes +from .langhelpers import map_bits as map_bits +from .langhelpers import md5_hex as md5_hex +from .langhelpers import memoized_instancemethod as memoized_instancemethod +from .langhelpers import memoized_property as memoized_property +from .langhelpers import MemoizedSlots as MemoizedSlots +from .langhelpers import method_is_overridden as method_is_overridden +from .langhelpers import methods_equivalent as methods_equivalent +from .langhelpers import ( + monkeypatch_proxied_specials as monkeypatch_proxied_specials, +) +from .langhelpers import NoneType as NoneType +from .langhelpers import only_once as only_once +from .langhelpers import PluginLoader as PluginLoader +from .langhelpers import portable_instancemethod as portable_instancemethod +from .langhelpers import quoted_token_parser as quoted_token_parser +from .langhelpers import safe_reraise as safe_reraise +from .langhelpers import set_creation_order as set_creation_order +from .langhelpers 
import string_or_unprintable as string_or_unprintable +from .langhelpers import symbol as symbol +from .langhelpers import TypingOnly as TypingOnly +from .langhelpers import ( + unbound_method_to_callable as unbound_method_to_callable, +) +from .langhelpers import walk_subclasses as walk_subclasses +from .langhelpers import warn as warn +from .langhelpers import warn_exception as warn_exception +from .langhelpers import warn_limited as warn_limited +from .langhelpers import wrap_callable as wrap_callable diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 3e4ef1310..850986802 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -34,19 +34,27 @@ from ._has_cy import HAS_CYEXTENSION from .typing import Literal if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_collections import immutabledict - from ._py_collections import IdentitySet - from ._py_collections import ImmutableContainer - from ._py_collections import ImmutableDictBase - from ._py_collections import OrderedSet - from ._py_collections import unique_list # noqa + from ._py_collections import immutabledict as immutabledict + from ._py_collections import IdentitySet as IdentitySet + from ._py_collections import ImmutableContainer as ImmutableContainer + from ._py_collections import ImmutableDictBase as ImmutableDictBase + from ._py_collections import OrderedSet as OrderedSet + from ._py_collections import unique_list as unique_list else: - from sqlalchemy.cyextension.immutabledict import ImmutableContainer - from sqlalchemy.cyextension.immutabledict import ImmutableDictBase - from sqlalchemy.cyextension.immutabledict import immutabledict - from sqlalchemy.cyextension.collections import IdentitySet - from sqlalchemy.cyextension.collections import OrderedSet - from sqlalchemy.cyextension.collections import unique_list # noqa + from sqlalchemy.cyextension.immutabledict import ( + ImmutableContainer as ImmutableContainer, 
+ ) + from sqlalchemy.cyextension.immutabledict import ( + ImmutableDictBase as ImmutableDictBase, + ) + from sqlalchemy.cyextension.immutabledict import ( + immutabledict as immutabledict, + ) + from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet + from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet + from sqlalchemy.cyextension.collections import ( # noqa + unique_list as unique_list, + ) _T = TypeVar("_T", bound=Any) @@ -57,6 +65,62 @@ _VT = TypeVar("_VT", bound=Any) EMPTY_SET: FrozenSet[Any] = frozenset() +def merge_lists_w_ordering(a, b): + """merge two lists, maintaining ordering as much as possible. + + this is to reconcile vars(cls) with cls.__annotations__. + + Example:: + + >>> a = ['__tablename__', 'id', 'x', 'created_at'] + >>> b = ['id', 'name', 'data', 'y', 'created_at'] + >>> merge_lists_w_ordering(a, b) + ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] + + This is not necessarily the ordering that things had on the class, + in this case the class is:: + + class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + data: Mapped[Optional[str]] + x = Column(Integer) + y: Mapped[int] + created_at: Mapped[datetime.datetime] = mapped_column() + + But things are *mostly* ordered. + + The algorithm could also be done by creating a partial ordering for + all items in both lists and then using topological_sort(), but that + is too much overhead. 
+ + Background on how I came up with this is at: + https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae + + """ + overlap = set(a).intersection(b) + + result = [] + + current, other = iter(a), iter(b) + + while True: + for element in current: + if element in overlap: + overlap.discard(element) + other, current = current, other + break + + result.append(element) + else: + result.extend(other) + break + + return result + + def coerce_to_immutabledict(d): if not d: return EMPTY_DICT diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 0f4befbb1..62cffa556 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -39,7 +39,6 @@ arm = "aarch" in platform.machine().lower() has_refcount_gc = bool(cpython) dottedgetter = operator.attrgetter -next = next # noqa class FullArgSpec(typing.NamedTuple): diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 57ef23006..6b94a2294 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -16,15 +16,17 @@ except ImportError as e: pass else: have_greenlet = True - from ._concurrency_py3k import await_only - from ._concurrency_py3k import await_fallback - from ._concurrency_py3k import greenlet_spawn - from ._concurrency_py3k import is_exit_exception - from ._concurrency_py3k import AsyncAdaptedLock - from ._concurrency_py3k import _util_async_run # noqa F401 + from ._concurrency_py3k import await_only as await_only + from ._concurrency_py3k import await_fallback as await_fallback + from ._concurrency_py3k import greenlet_spawn as greenlet_spawn + from ._concurrency_py3k import is_exit_exception as is_exit_exception + from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock from ._concurrency_py3k import ( - _util_async_run_coroutine_function, - ) # noqa F401, E501 + _util_async_run as _util_async_run, + ) # noqa F401 + from ._concurrency_py3k import ( + _util_async_run_coroutine_function as 
_util_async_run_coroutine_function, # noqa F401, E501 + ) if not have_greenlet: diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 565cbafe2..7c2586166 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -13,6 +13,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Optional +from typing import Tuple from typing import TypeVar from . import compat @@ -209,7 +210,10 @@ def became_legacy_20(api_name, alternative=None, **kw): return deprecated("2.0", message=message, warning=warning_cls, **kw) -def deprecated_params(**specs): +_C = TypeVar("_C", bound=Callable[..., Any]) + + +def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: """Decorates a function to warn on use of certain parameters. e.g. :: diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 9401c249f..ed879894d 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -30,6 +30,7 @@ from typing import FrozenSet from typing import Generic from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence @@ -54,6 +55,30 @@ _HP = TypeVar("_HP", bound="hybridproperty") _HM = TypeVar("_HM", bound="hybridmethod") +if compat.py310: + + def get_annotations(obj: Any) -> Mapping[str, Any]: + return inspect.get_annotations(obj) + +else: + + def get_annotations(obj: Any) -> Mapping[str, Any]: + # it's been observed that cls.__annotations__ can be non present. 
+ # it's not clear what causes this, running under tox py37/38 it + # happens, running straight pytest it doesnt + + # https://docs.python.org/3/howto/annotations.html#annotations-howto + if isinstance(obj, type): + ann = obj.__dict__.get("__annotations__", None) + else: + ann = getattr(obj, "__annotations__", None) + + if ann is None: + return _collections.EMPTY_DICT + else: + return cast("Mapping[str, Any]", ann) + + def md5_hex(x: Any) -> str: x = x.encode("utf-8") m = hashlib.md5() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 62a9f6c8a..56ea4d0e0 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -1,6 +1,10 @@ +import sys import typing from typing import Any from typing import Callable # noqa +from typing import cast +from typing import Dict +from typing import ForwardRef from typing import Generic from typing import overload from typing import Type @@ -13,21 +17,36 @@ from . import compat _T = TypeVar("_T", bound=Any) -if typing.TYPE_CHECKING or not compat.py38: - from typing_extensions import Literal # noqa F401 - from typing_extensions import Protocol # noqa F401 - from typing_extensions import TypedDict # noqa F401 +if compat.py310: + # why they took until py310 to put this in stdlib is beyond me, + # I've been wanting it since py27 + from types import NoneType else: - from typing import Literal # noqa F401 - from typing import Protocol # noqa F401 - from typing import TypedDict # noqa F401 + NoneType = type(None) # type: ignore + +if typing.TYPE_CHECKING or compat.py310: + from typing import Annotated as Annotated +else: + from typing_extensions import Annotated as Annotated # noqa F401 + +if typing.TYPE_CHECKING or compat.py38: + from typing import Literal as Literal + from typing import Protocol as Protocol + from typing import TypedDict as TypedDict +else: + from typing_extensions import Literal as Literal # noqa F401 + from typing_extensions import Protocol as Protocol # noqa F401 + 
def de_stringify_annotation(
    cls: Type[Any], annotation: Union[str, Type[Any]]
) -> Union[str, Type[Any]]:
    """Resolve annotations that may be string based into real objects.

    This is particularly important if a module defines "from __future__
    import annotations", as everything inside of __annotations__ is a
    string. We want to at least have generic containers like ``Mapped``,
    ``Union``, ``List``, etc.

    Names that cannot be resolved are returned unchanged as strings.
    """
    # looked at typing.get_type_hints(), looked at pydantic.  We need much
    # less here, and we here try to not use any private typing internals
    # or construct ForwardRef objects which is documented as something
    # that should be avoided.

    # an unevaluated ForwardRef is reduced to its raw string form first
    if is_fwd_ref(annotation):
        ref = cast(ForwardRef, annotation)
        if not ref.__forward_evaluated__:
            annotation = ref.__forward_arg__

    if not isinstance(annotation, str):
        return annotation

    module = sys.modules.get(cls.__module__, None)
    base_globals: Dict[str, Any] = getattr(module, "__dict__", {})
    try:
        # evaluated against the globals of the class's defining module,
        # same scope the annotation was written in
        return eval(annotation, base_globals, None)
    except NameError:
        # best effort: an unresolvable name stays a plain string
        return annotation


def is_fwd_ref(type_):
    """Return True if *type_* is a ``typing.ForwardRef`` instance."""
    return isinstance(type_, ForwardRef)


def de_optionalize_union_types(type_):
    """Given a type, filter out ``Union`` types that include ``NoneType``
    to not include the ``NoneType``.

    """
    if not is_optional(type_):
        return type_

    remaining = set(type_.__args__)
    remaining.discard(NoneType)
    return make_union_type(*remaining)


def make_union_type(*types):
    """Make a Union type.

    This is needed by :func:`.de_optionalize_union_types` which removes
    ``NoneType`` from a ``Union``.

    """
    # subscript Union with the whole tuple at once; cast silences the
    # type checker, which rejects dynamic subscription of Union
    return cast(Any, Union)[types]


def expand_unions(type_, include_union=False, discard_none=False):
    """Return a type as a tuple of individual types, expanding for
    ``Union`` types."""

    if not is_union(type_):
        return (type_,)

    members = set(type_.__args__)
    if discard_none:
        members.discard(NoneType)

    expanded = tuple(members)
    return (type_,) + expanded if include_union else expanded


def is_optional(type_):
    """Return True if *type_* is an ``Optional`` / ``Union`` construct."""
    return is_origin_of(type_, "Optional", "Union")


def is_union(type_):
    """Return True if *type_* is a ``Union`` construct."""
    return is_origin_of(type_, "Union")


def is_origin_of(type_, *names, module=None):
    """return True if the given type has an __origin__ with the given name
    and optional module."""

    origin = getattr(type_, "__origin__", None)
    if origin is None:
        return False
    if _get_type_name(origin) not in names:
        return False
    return module is None or origin.__module__.startswith(module)


def _get_type_name(type_):
    """Return the display name of a type, working around typing objects
    that expose ``_name`` rather than ``__name__`` before py310."""
    if compat.py310:
        return type_.__name__
    name = getattr(type_, "__name__", None)
    if name is None:
        name = getattr(type_, "_name", None)
    return name
