diff options
Diffstat (limited to 'lib')
42 files changed, 2939 insertions, 475 deletions
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 29dd6aff9..afba17075 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -46,6 +46,7 @@ from .result import MergedResult as MergedResult from .result import Result as Result from .result import result_tuple as result_tuple from .result import ScalarResult as ScalarResult +from .result import TupleResult as TupleResult from .row import BaseRow as BaseRow from .row import Row as Row from .row import RowMapping as RowMapping diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index a325da929..fe3bfa1ad 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -18,8 +18,10 @@ from typing import Mapping from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Tuple from typing import Type +from typing import TypeVar from typing import Union from .interfaces import _IsolationLevel @@ -45,12 +47,10 @@ if typing.TYPE_CHECKING: from . import ScalarResult from .interfaces import _AnyExecuteParams from .interfaces import _AnyMultiExecuteParams - from .interfaces import _AnySingleExecuteParams from .interfaces import _CoreAnyExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPIAnyExecuteParams - from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _ExecuteOptionsParameter @@ -65,21 +65,21 @@ if typing.TYPE_CHECKING: from ..pool import PoolProxiedConnection from ..sql import Executable from ..sql._typing import _InfoType - from ..sql.base import SchemaVisitor from ..sql.compiler import Compiled from ..sql.ddl import ExecutableDDLElement from ..sql.ddl import SchemaDropper from ..sql.ddl import SchemaGenerator from ..sql.functions import FunctionElement - from ..sql.schema import ColumnDefault from ..sql.schema import DefaultGenerator from ..sql.schema import HasSchemaAttr from ..sql.schema import SchemaItem + from ..sql.selectable import TypedReturnsRows """Defines :class:`_engine.Connection` and :class:`_engine.Engine`. """ +_T = TypeVar("_T", bound=Any) _EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT @@ -1142,10 +1142,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._dbapi_connection = None self.__can_reconnect = False + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: + ... + def scalar( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. @@ -1170,10 +1191,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options or NO_OPTIONS, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: + ... + + @overload def scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + ... + + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: """Executes and returns a scalar result set, which yields scalar values @@ -1190,14 +1232,37 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ - return self.execute(statement, parameters, execution_options).scalars() + return self.execute( + statement, parameters, execution_options=execution_options + ).scalars() + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + ... def execute( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1246,7 +1311,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): func: FunctionElement[Any], distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a sql.FunctionElement object.""" return self._execute_clauseelement( @@ -1317,7 +1382,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ddl: ExecutableDDLElement, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a schema.DDL object.""" execution_options = ddl._execution_options.merge_with( @@ -1414,7 +1479,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): elem: Executable, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a sql.ClauseElement object.""" execution_options = elem._execution_options.merge_with( @@ -1487,7 +1552,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): compiled: Compiled, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a sql.Compiled object. TODO: why do we have this? likely deprecate or remove @@ -1537,7 +1602,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): statement: str, parameters: Optional[_DBAPIAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1614,7 +1679,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options: _ExecuteOptions, *args: Any, **kw: Any, - ) -> CursorResult: + ) -> CursorResult[Any]: """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.CursorResult`.""" diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index ccf573675..ff69666b7 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -24,6 +24,7 @@ from typing import Optional from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from .result import MergedResult @@ -55,11 +56,12 @@ if typing.TYPE_CHECKING: from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType - from .result import _KeyMapType from .result import _KeyType from .result import _ProcessorsType from ..sql.type_api import _ResultProcessorType +_T = TypeVar("_T", bound=Any) + # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. MD_INDEX: Literal[0] = 0 # integer index in cursor.description @@ -214,7 +216,9 @@ class CursorResultMetaData(ResultMetaData): return md def __init__( - self, parent: CursorResult, cursor_description: _DBAPICursorDescription + self, + parent: CursorResult[Any], + cursor_description: _DBAPICursorDescription, ): context = parent.context self._tuplefilter = None @@ -1158,7 +1162,7 @@ class _NoResultMetaData(ResultMetaData): _NO_RESULT_METADATA = _NoResultMetaData() -class CursorResult(Result): +class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. .. versionchanged:: 1.4 The :class:`.CursorResult`` @@ -1179,6 +1183,15 @@ class CursorResult(Result): """ + __slots__ = ( + "context", + "dialect", + "cursor", + "cursor_strategy", + "_echo", + "connection", + ) + _metadata: Union[CursorResultMetaData, _NoResultMetaData] _no_result_metadata = _NO_RESULT_METADATA _soft_closed: bool = False @@ -1231,7 +1244,6 @@ class CursorResult(Result): make_row = _make_row_2 else: make_row = _make_row - self._set_memoized_attribute("_row_getter", make_row) else: @@ -1726,12 +1738,12 @@ class CursorResult(Result): def _raw_row_iterator(self): return self._fetchiter_impl() - def merge(self, *others: Result) -> MergedResult: + def merge(self, *others: Result[Any]) -> MergedResult[Any]: merged_result = super().merge(*others) setup_rowcounts = not self._metadata.returns_rows if setup_rowcounts: merged_result.rowcount = sum( - cast(CursorResult, result).rowcount + cast("CursorResult[Any]", result).rowcount for result in (self,) + others ) return merged_result diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c6571f68b..9c6ff758f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -62,13 +62,9 @@ if typing.TYPE_CHECKING: from .base import Connection from .base import Engine - from .characteristics import ConnectionCharacteristic - from .interfaces import _AnyMultiExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams - from .interfaces import _DBAPIAnyExecuteParams from .interfaces import _DBAPIMultiExecuteParams - from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _IsolationLevel from .interfaces import _MutableCoreSingleExecuteParams @@ -83,15 +79,11 @@ if typing.TYPE_CHECKING: from ..sql.compiler import Compiled from ..sql.compiler import Linting from ..sql.compiler import ResultColumnsEntry - from ..sql.compiler import TypeCompiler from ..sql.dml import DMLState from ..sql.dml import UpdateBase from ..sql.elements import BindParameter - from ..sql.roles import ColumnsClauseRole from ..sql.schema import Column - from ..sql.schema import ColumnDefault from ..sql.type_api import _BindProcessorType - from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine # When we're handed literal SQL, ensure it's a SELECT query @@ -781,7 +773,7 @@ class DefaultExecutionContext(ExecutionContext): result_column_struct: Optional[ Tuple[List[ResultColumnsEntry], bool, bool, bool] ] = None - returned_default_rows: Optional[List[Row]] = None + returned_default_rows: Optional[Sequence[Row[Any]]] = None execution_options: _ExecuteOptions = util.EMPTY_DICT @@ -1385,7 +1377,9 @@ class DefaultExecutionContext(ExecutionContext): if cursor_description is None: strategy = _cursor._NO_CURSOR_DML - result = _cursor.CursorResult(self, strategy, cursor_description) + result: _cursor.CursorResult[Any] = _cursor.CursorResult( + self, strategy, cursor_description + ) if self.isinsert: if self._is_implicit_returning: diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index ef10946a8..4093d3e0e 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -28,7 +28,6 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from .base import Connection - from .interfaces import _CoreAnyExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPIAnyExecuteParams @@ -273,7 +272,7 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]): multiparams: _CoreMultiExecuteParams, params: _CoreSingleExecuteParams, execution_options: _ExecuteOptions, - result: Result, + result: Result[Any], ) -> None: """Intercept high level execute() events after execute. diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 54fe21d74..641024603 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -2422,7 +2422,7 @@ class ExecutionContext: def _get_cache_stats(self) -> str: raise NotImplementedError() - def _setup_result_proxy(self) -> CursorResult: + def _setup_result_proxy(self) -> CursorResult[Any]: raise NotImplementedError() def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int: diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 71320a583..55d36a1d5 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -28,6 +28,7 @@ from typing import overload from typing import Sequence from typing import Set from typing import Tuple +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -70,6 +71,8 @@ _RawRowType = Tuple[Any, ...] """represents the kind of row we get from a DBAPI cursor""" _R = TypeVar("_R", bound=_RowData) +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) _InterimRowType = Union[_R, _RawRowType] """a catchall "anything" kind of return type that can be applied @@ -141,7 +144,7 @@ class ResultMetaData: def _getter( self, key: Any, raiseerr: bool = True - ) -> Optional[Callable[[Row], Any]]: + ) -> Optional[Callable[[Row[Any]], Any]]: index = self._index_for_key(key, raiseerr) @@ -270,7 +273,7 @@ class SimpleResultMetaData(ResultMetaData): _tuplefilter=_tuplefilter, ) - def _contains(self, value: Any, row: Row) -> bool: + def _contains(self, value: Any, row: Row[Any]) -> bool: return value in row._data def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: @@ -335,7 +338,7 @@ class SimpleResultMetaData(ResultMetaData): def result_tuple( fields: Sequence[str], extra: Optional[Any] = None -) -> Callable[[Iterable[Any]], Row]: +) -> Callable[[Iterable[Any]], Row[Any]]: parent = SimpleResultMetaData(fields, extra) return functools.partial( Row, parent, parent._processors, parent._keymap, Row._default_key_style @@ -355,7 +358,9 @@ SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal[Any]") class ResultInternal(InPlaceGenerative, Generic[_R]): - _real_result: Optional[Result] = None + __slots__ = () + + _real_result: Optional[Result[Any]] = None _generate_rows: bool = True _row_logging_fn: Optional[Callable[[Any], Any]] @@ -367,20 +372,20 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): _source_supports_scalars: bool - def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]: + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: raise NotImplementedError() def _fetchone_impl( self, hard_close: bool = False - ) -> Optional[_InterimRowType[Row]]: + ) -> Optional[_InterimRowType[Row[Any]]]: raise NotImplementedError() def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: raise NotImplementedError() - def _fetchall_impl(self) -> List[_InterimRowType[Row]]: + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: raise NotImplementedError() def _soft_close(self, hard: bool = False) -> None: @@ -388,8 +393,10 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): @HasMemoized_ro_memoized_attribute def _row_getter(self) -> Optional[Callable[..., _R]]: - real_result: Result = ( - self._real_result if self._real_result else cast(Result, self) + real_result: Result[Any] = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) ) if real_result._source_supports_scalars: @@ -404,7 +411,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): keymap: _KeyMapType, key_style: Any, scalar_obj: Any, - ) -> Row: + ) -> Row[Any]: return _proc( metadata, processors, keymap, key_style, (scalar_obj,) ) @@ -429,7 +436,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): fixed_tf = tf - def make_row(row: _InterimRowType[Row]) -> _R: + def make_row(row: _InterimRowType[Row[Any]]) -> _R: return _make_row_orig(fixed_tf(row)) else: @@ -447,7 +454,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if fns: _make_row = make_row - def make_row(row: _InterimRowType[Row]) -> _R: + def make_row(row: _InterimRowType[Row[Any]]) -> _R: interim_row = _make_row(row) for fn in fns: interim_row = fn(interim_row) @@ -465,7 +472,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def iterrows(self: Result) -> Iterator[_R]: + def iterrows(self: Result[Any]) -> Iterator[_R]: for raw_row in self._fetchiter_impl(): obj: _InterimRowType[Any] = ( make_row(raw_row) if make_row else raw_row @@ -480,7 +487,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: - def iterrows(self: Result) -> Iterator[_R]: + def iterrows(self: Result[Any]) -> Iterator[_R]: for raw_row in self._fetchiter_impl(): row: _InterimRowType[Any] = ( make_row(raw_row) if make_row else raw_row @@ -546,7 +553,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def onerow(self: Result) -> Union[_NoRow, _R]: + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: _onerow = self._fetchone_impl while True: row = _onerow() @@ -567,7 +574,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: - def onerow(self: Result) -> Union[_NoRow, _R]: + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: row = self._fetchone_impl() if row is None: return _NO_ROW @@ -627,7 +634,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result - else cast(Result, self) + else cast("Result[Any]", self) ) if real_result._yield_per: num_required = num = real_result._yield_per @@ -667,7 +674,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result - else cast(Result, self) + else cast("Result[Any]", self) ) num = real_result._yield_per @@ -799,7 +806,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): self: SelfResultInternal, indexes: Sequence[_KeyIndexType] ) -> SelfResultInternal: real_result = ( - self._real_result if self._real_result else cast(Result, self) + self._real_result + if self._real_result + else cast("Result[Any]", self) ) if not real_result._source_supports_scalars or len(indexes) != 1: @@ -817,7 +826,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result is not None - else cast(Result, self) + else cast("Result[Any]", self) ) if not strategy and self._metadata._unique_filters: @@ -836,6 +845,8 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): class _WithKeys: + __slots__ = () + _metadata: ResultMetaData # used mainly to share documentation on the keys method. @@ -859,10 +870,10 @@ class _WithKeys: return self._metadata.keys -SelfResult = TypeVar("SelfResult", bound="Result") +SelfResult = TypeVar("SelfResult", bound="Result[Any]") -class Result(_WithKeys, ResultInternal[Row]): +class Result(_WithKeys, ResultInternal[Row[_TP]]): """Represent a set of database results. .. versionadded:: 1.4 The :class:`.Result` object provides a completely @@ -887,7 +898,9 @@ class Result(_WithKeys, ResultInternal[Row]): """ - _row_logging_fn: Optional[Callable[[Row], Row]] = None + __slots__ = ("_metadata", "__dict__") + + _row_logging_fn: Optional[Callable[[Row[Any]], Row[Any]]] = None _source_supports_scalars: bool = False @@ -1011,6 +1024,15 @@ class Result(_WithKeys, ResultInternal[Row]): appropriate :class:`.ColumnElement` objects which correspond to a given statement construct. + .. versionchanged:: 2.0 Due to a bug in 1.4, the + :meth:`.Result.columns` method had an incorrect behavior where + calling upon the method with just one index would cause the + :class:`.Result` object to yield scalar values rather than + :class:`.Row` objects. In version 2.0, this behavior has been + corrected such that calling upon :meth:`.Result.columns` with + a single index will produce a :class:`.Result` object that continues + to yield :class:`.Row` objects, which include only a single column. + E.g.:: statement = select(table.c.x, table.c.y, table.c.z) @@ -1040,6 +1062,20 @@ class Result(_WithKeys, ResultInternal[Row]): """ return self._column_slices(col_expressions) + @overload + def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self: Result[Tuple[_T]], index: Literal[0] + ) -> ScalarResult[_T]: + ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: + ... + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: """Return a :class:`_result.ScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -1067,7 +1103,7 @@ class Result(_WithKeys, ResultInternal[Row]): def _getter( self, key: _KeyIndexType, raiseerr: bool = True - ) -> Optional[Callable[[Row], Any]]: + ) -> Optional[Callable[[Row[Any]], Any]]: """return a callable that will retrieve the given key from a :class:`.Row`. @@ -1105,6 +1141,43 @@ class Result(_WithKeys, ResultInternal[Row]): return MappingResult(self) + @property + def t(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`.Result.t` attribute is a synonym for calling the + :meth:`.Result.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`.Result` object at runtime, + however annotates as returning a :class:`.TupleResult` object + that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`.Row` objects + to by typed, for those cases where the statement invoked itself + included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_result.TupleResult` type at typing time. + + .. seealso:: + + :attr:`.Result.t` - shorter synonym + + :attr:`.Row.t` - :class:`.Row` version + + """ + + return self # type: ignore + def _raw_row_iterator(self) -> Iterator[_RowData]: """Return a safe iterator that yields raw row data. @@ -1114,13 +1187,15 @@ class Result(_WithKeys, ResultInternal[Row]): """ raise NotImplementedError() - def __iter__(self) -> Iterator[Row]: + def __iter__(self) -> Iterator[Row[_TP]]: return self._iter_impl() - def __next__(self) -> Row: + def __next__(self) -> Row[_TP]: return self._next_impl() - def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]: + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[Row[_TP]]]: """Iterate through sub-lists of rows of the size given. Each list will be of the size given, excluding the last list to @@ -1158,12 +1233,12 @@ class Result(_WithKeys, ResultInternal[Row]): else: break - def fetchall(self) -> List[Row]: + def fetchall(self) -> Sequence[Row[_TP]]: """A synonym for the :meth:`_engine.Result.all` method.""" return self._allrows() - def fetchone(self) -> Optional[Row]: + def fetchone(self) -> Optional[Row[_TP]]: """Fetch one row. When all rows are exhausted, returns None. @@ -1185,7 +1260,7 @@ class Result(_WithKeys, ResultInternal[Row]): else: return row - def fetchmany(self, size: Optional[int] = None) -> List[Row]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -1202,7 +1277,7 @@ class Result(_WithKeys, ResultInternal[Row]): return self._manyrow_getter(self, size) - def all(self) -> List[Row]: + def all(self) -> Sequence[Row[_TP]]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -1216,7 +1291,7 @@ class Result(_WithKeys, ResultInternal[Row]): return self._allrows() - def first(self) -> Optional[Row]: + def first(self) -> Optional[Row[_TP]]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -1252,7 +1327,7 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self) -> Optional[Row]: + def one_or_none(self) -> Optional[Row[_TP]]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -1276,6 +1351,14 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=False, scalar=False ) + @overload + def scalar_one(self: Result[Tuple[_T]]) -> _T: + ... + + @overload + def scalar_one(self) -> Any: + ... + def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -1293,6 +1376,14 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=True, scalar=True ) + @overload + def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: + ... + def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. @@ -1310,7 +1401,7 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=False, scalar=True ) - def one(self) -> Row: + def one(self) -> Row[_TP]: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -1341,6 +1432,14 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=True, scalar=False ) + @overload + def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar(self) -> Any: + ... + def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -1359,7 +1458,7 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=False, raise_for_none=False, scalar=True ) - def freeze(self) -> FrozenResult: + def freeze(self) -> FrozenResult[_TP]: """Return a callable object that will produce copies of this :class:`.Result` when invoked. @@ -1382,7 +1481,7 @@ class Result(_WithKeys, ResultInternal[Row]): return FrozenResult(self) - def merge(self, *others: Result) -> MergedResult: + def merge(self, *others: Result[Any]) -> MergedResult[_TP]: """Merge this :class:`.Result` with other compatible result objects. @@ -1405,9 +1504,17 @@ class FilterResult(ResultInternal[_R]): """ - _post_creational_filter: Optional[Callable[[Any], Any]] = None + __slots__ = ( + "_real_result", + "_post_creational_filter", + "_metadata", + "_unique_filter_state", + "__dict__", + ) + + _post_creational_filter: Optional[Callable[[Any], Any]] - _real_result: Result + _real_result: Result[Any] def _soft_close(self, hard: bool = False) -> None: self._real_result._soft_close(hard=hard) @@ -1416,20 +1523,20 @@ class FilterResult(ResultInternal[_R]): def _attributes(self) -> Dict[Any, Any]: return self._real_result._attributes - def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]: + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: return self._real_result._fetchiter_impl() def _fetchone_impl( self, hard_close: bool = False - ) -> Optional[_InterimRowType[Row]]: + ) -> Optional[_InterimRowType[Row[Any]]]: return self._real_result._fetchone_impl(hard_close=hard_close) - def _fetchall_impl(self) -> List[_InterimRowType[Row]]: + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: return self._real_result._fetchall_impl() def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: return self._real_result._fetchmany_impl(size=size) @@ -1452,11 +1559,13 @@ class ScalarResult(FilterResult[_R]): """ + __slots__ = () + _generate_rows = False _post_creational_filter: Optional[Callable[[Any], Any]] - def __init__(self, real_result: Result, index: _KeyIndexType): + def __init__(self, real_result: Result[Any], index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -1480,7 +1589,7 @@ class ScalarResult(FilterResult[_R]): self._unique_filter_state = (set(), strategy) return self - def partitions(self, size: Optional[int] = None) -> Iterator[List[_R]]: + def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_R]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1498,12 +1607,12 @@ class ScalarResult(FilterResult[_R]): else: break - def fetchall(self) -> List[_R]: + def fetchall(self) -> Sequence[_R]: """A synonym for the :meth:`_engine.ScalarResult.all` method.""" return self._allrows() - def fetchmany(self, size: Optional[int] = None) -> List[_R]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1513,7 +1622,7 @@ class ScalarResult(FilterResult[_R]): """ return self._manyrow_getter(self, size) - def all(self) -> List[_R]: + def all(self) -> Sequence[_R]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1567,6 +1676,177 @@ class ScalarResult(FilterResult[_R]): ) +SelfTupleResult = TypeVar("SelfTupleResult", bound="TupleResult[Any]") + + +class TupleResult(FilterResult[_R], util.TypingOnly): + """a :class:`.Result` that's typed as returning plain Python tuples + instead of rows. + + Since :class:`.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`.Result` is still + used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_result.Result.partitions` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_result.Result.fetchone` except that + tuple values, rather than :class:`_result.Row` + objects, are returned. + + """ + ... + + def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_result.Result.fetchmany` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_result.Result.all` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def __iter__(self) -> Iterator[_R]: + ... + + def __next__(self) -> _R: + ... + + def first(self) -> Optional[_R]: + """Fetch the first object or None if no object is present. + + Equivalent to :meth:`_result.Result.first` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + + """ + ... + + def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_result.Result.one_or_none` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_result.Result.one` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + @overload + def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: + ... + + @overload + def scalar_one(self) -> Any: + ... + + def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one`. + + .. seealso:: + + :meth:`.Result.one` + + :meth:`.Result.scalars` + + """ + ... + + @overload + def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: + ... + + def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one_or_none`. + + .. seealso:: + + :meth:`.Result.one_or_none` + + :meth:`.Result.scalars` + + """ + ... + + @overload + def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar(self) -> Any: + ... + + def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns None if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or None if no rows remain. + + """ + ... + + SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult") @@ -1579,11 +1859,13 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): """ + __slots__ = () + _generate_rows = True _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result: Result): + def __init__(self, result: Result[Any]): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata @@ -1610,7 +1892,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): def partitions( self, size: Optional[int] = None - ) -> Iterator[List[RowMapping]]: + ) -> Iterator[Sequence[RowMapping]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1628,7 +1910,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): else: break - def fetchall(self) -> List[RowMapping]: + def fetchall(self) -> Sequence[RowMapping]: """A synonym for the :meth:`_engine.MappingResult.all` method.""" return self._allrows() @@ -1648,7 +1930,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): else: return row - def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1659,7 +1941,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): return self._manyrow_getter(self, size) - def all(self) -> List[RowMapping]: + def all(self) -> Sequence[RowMapping]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1714,7 +1996,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): ) -class FrozenResult: +class FrozenResult(Generic[_TP]): """Represents a :class:`.Result` object in a "frozen" state suitable for caching. @@ -1755,7 +2037,7 @@ class FrozenResult: data: Sequence[Any] - def __init__(self, result: Result): + def __init__(self, result: Result[_TP]): self.metadata = result._metadata._for_freeze() self._source_supports_scalars = result._source_supports_scalars self._attributes = result._attributes @@ -1771,7 +2053,9 @@ class FrozenResult: else: return [list(row) for row in self.data] - def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult: + def with_new_rows( + self, tuple_data: Sequence[Row[_TP]] + ) -> FrozenResult[_TP]: fr = FrozenResult.__new__(FrozenResult) fr.metadata = self.metadata fr._attributes = self._attributes @@ -1783,14 +2067,16 @@ class FrozenResult: fr.data = tuple_data return fr - def __call__(self) -> Result: - result = IteratorResult(self.metadata, iter(self.data)) + def __call__(self) -> Result[_TP]: + result: IteratorResult[_TP] = IteratorResult( + self.metadata, iter(self.data) + ) result._attributes = self._attributes result._source_supports_scalars = self._source_supports_scalars return result -class IteratorResult(Result): +class IteratorResult(Result[_TP]): """A :class:`.Result` that gets data from a Python iterator of :class:`.Row` objects or similar row-like data. @@ -1833,7 +2119,7 @@ class IteratorResult(Result): def _fetchone_impl( self, hard_close: bool = False - ) -> Optional[_InterimRowType[Row]]: + ) -> Optional[_InterimRowType[Row[Any]]]: if self._hard_closed: self._raise_hard_closed() @@ -1844,7 +2130,7 @@ class IteratorResult(Result): else: return row - def _fetchall_impl(self) -> List[_InterimRowType[Row]]: + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: if self._hard_closed: self._raise_hard_closed() try: @@ -1854,23 +2140,23 @@ class IteratorResult(Result): def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: if self._hard_closed: self._raise_hard_closed() return list(itertools.islice(self.iterator, 0, size)) -def null_result() -> IteratorResult: +def null_result() -> IteratorResult[Any]: return IteratorResult(SimpleResultMetaData([]), iter([])) SelfChunkedIteratorResult = TypeVar( - "SelfChunkedIteratorResult", bound="ChunkedIteratorResult" + "SelfChunkedIteratorResult", bound="ChunkedIteratorResult[Any]" ) -class ChunkedIteratorResult(IteratorResult): +class ChunkedIteratorResult(IteratorResult[_TP]): """An :class:`.IteratorResult` that works from an iterator-producing callable. The given ``chunks`` argument is a function that is given a number of rows @@ -1922,13 +2208,13 @@ class ChunkedIteratorResult(IteratorResult): def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: if self.dynamic_yield_per: self.iterator = itertools.chain.from_iterable(self.chunks(size)) return super()._fetchmany_impl(size=size) -class MergedResult(IteratorResult): +class MergedResult(IteratorResult[_TP]): """A :class:`_engine.Result` that is merged from any number of :class:`_engine.Result` objects. @@ -1942,7 +2228,7 @@ class MergedResult(IteratorResult): rowcount: Optional[int] def __init__( - self, cursor_metadata: ResultMetaData, results: Sequence[Result] + self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]] ): self._results = results super(MergedResult, self).__init__( diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 4ba39b55d..7c9eacb78 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -16,6 +16,7 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Generic from typing import Iterator from typing import List from typing import Mapping @@ -24,12 +25,14 @@ from typing import Optional from typing import overload from typing import Sequence from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from ..sql import util as sql_util from ..util._has_cy import HAS_CYEXTENSION -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: +if TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_row import BaseRow as BaseRow from ._py_row import KEY_INTEGER_ONLY from ._py_row import KEY_OBJECTS_ONLY @@ -38,13 +41,16 @@ else: from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .result import _KeyType from .result import RMKeyView from ..sql.type_api import _ResultProcessorType +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) -class Row(BaseRow, typing.Sequence[Any]): + +class Row(BaseRow, Sequence[Any], Generic[_TP]): """Represent a single result row. The :class:`.Row` object represents a row of a database result. It is @@ -82,6 +88,37 @@ class Row(BaseRow, typing.Sequence[Any]): def __delattr__(self, name: str) -> NoReturn: raise AttributeError("can't delete attribute") + def tuple(self) -> _TP: + """Return a 'tuple' form of this :class:`.Row`. + + At runtime, this method returns "self"; the :class:`.Row` object is + already a named tuple. However, at the typing level, if this + :class:`.Row` is typed, the "tuple" return type will be a :pep:`484` + ``Tuple`` datatype that contains typing information about individual + elements, supporting typed unpacking and attribute access. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.tuples` + + """ + return self # type: ignore + + @property + def t(self) -> _TP: + """a synonym for :attr:`.Row.tuple` + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.t` + + """ + return self # type: ignore + @property def _mapping(self) -> RowMapping: """Return a :class:`.RowMapping` for this :class:`.Row`. @@ -107,7 +144,7 @@ class Row(BaseRow, typing.Sequence[Any]): def _filter_on_values( self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]] - ) -> Row: + ) -> Row[Any]: return Row( self._parent, filters, @@ -116,7 +153,7 @@ class Row(BaseRow, typing.Sequence[Any]): self._data, ) - if not typing.TYPE_CHECKING: + if not TYPE_CHECKING: def _special_name_accessor(name: str) -> Any: """Handle ambiguous names such as "count" and "index" """ @@ -151,7 +188,7 @@ class Row(BaseRow, typing.Sequence[Any]): __hash__ = BaseRow.__hash__ - if typing.TYPE_CHECKING: + if TYPE_CHECKING: @overload def __getitem__(self, index: int) -> Any: @@ -299,7 +336,7 @@ class RowMapping(BaseRow, typing.Mapping[str, Any]): _default_key_style = KEY_OBJECTS_ONLY - if typing.TYPE_CHECKING: + if TYPE_CHECKING: def __getitem__(self, key: _KeyType) -> Any: ... diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index fb05f512e..95549ada6 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -12,8 +12,10 @@ from typing import Generator from typing import NoReturn from typing import Optional from typing import overload +from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import exc as async_exc @@ -50,6 +52,9 @@ if TYPE_CHECKING: from ...pool import PoolProxiedConnection from ...sql._typing import _InfoType from ...sql.base import Executable + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) class _SyncConnectionCallable(Protocol): @@ -407,7 +412,7 @@ class AsyncConnection( statement: str, parameters: Optional[_DBAPIAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: r"""Executes a driver-level SQL string and return buffered :class:`_engine.Result`. @@ -423,12 +428,33 @@ class AsyncConnection( return await _ensure_sync_result(result, self.exec_driver_sql) + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncResult[_T]: + ... + + @overload async def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> AsyncResult: + ) -> AsyncResult[Any]: + ... + + async def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncResult[Any]: """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object.""" @@ -436,7 +462,7 @@ class AsyncConnection( self._proxied.execute, statement, parameters, - util.EMPTY_DICT.merge_with( + execution_options=util.EMPTY_DICT.merge_with( execution_options, {"stream_results": True} ), _require_await=True, @@ -446,12 +472,33 @@ class AsyncConnection( assert False, "server side result expected" return AsyncResult(result) + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: + ... + + @overload async def execute( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: + ... + + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: r"""Executes a SQL statement construct and return a buffered :class:`_engine.Result`. @@ -487,15 +534,36 @@ class AsyncConnection( self._proxied.execute, statement, parameters, - execution_options, + execution_options=execution_options, _require_await=True, ) return await _ensure_sync_result(result, self.execute) + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Optional[_T]: + ... + + @overload async def scalar( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: + ... + + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. @@ -508,13 +576,36 @@ class AsyncConnection( first row returned. """ - result = await self.execute(statement, parameters, execution_options) + result = await self.execute( + statement, parameters, execution_options=execution_options + ) return result.scalar() + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + ... + async def scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: r"""Executes a SQL statement construct and returns a scalar objects. @@ -528,13 +619,36 @@ class AsyncConnection( .. versionadded:: 1.4.24 """ - result = await self.execute(statement, parameters, execution_options) + result = await self.execute( + statement, parameters, execution_options=execution_options + ) return result.scalars() + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncScalarResult[_T]: + ... + + @overload async def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncScalarResult[Any]: + ... + + async def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> AsyncScalarResult[Any]: r"""Executes a SQL statement and returns a streaming scalar result @@ -549,7 +663,9 @@ class AsyncConnection( .. versionadded:: 1.4.24 """ - result = await self.stream(statement, parameters, execution_options) + result = await self.stream( + statement, parameters, execution_options=execution_options + ) return result.scalars() async def run_sync( diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index d0337554c..ff3dcf417 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -9,12 +9,15 @@ from __future__ import annotations import operator from typing import Any from typing import AsyncIterator -from typing import List from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar from . import exc as async_exc +from ... import util from ...engine.result import _NO_ROW from ...engine.result import _R from ...engine.result import FilterResult @@ -24,6 +27,7 @@ from ...engine.result import ResultMetaData from ...engine.row import Row from ...engine.row import RowMapping from ...util.concurrency import greenlet_spawn +from ...util.typing import Literal if TYPE_CHECKING: from ...engine import CursorResult @@ -32,9 +36,14 @@ if TYPE_CHECKING: from ...engine.result import _UniqueFilterType from ...engine.result import RMKeyView +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + class AsyncCommon(FilterResult[_R]): - _real_result: Result + __slots__ = () + + _real_result: Result[Any] _metadata: ResultMetaData async def close(self) -> None: @@ -43,10 +52,10 @@ class AsyncCommon(FilterResult[_R]): await greenlet_spawn(self._real_result.close) -SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult") +SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]") -class AsyncResult(AsyncCommon[Row]): +class AsyncResult(AsyncCommon[Row[_TP]]): """An asyncio wrapper around a :class:`_result.Result` object. The :class:`_asyncio.AsyncResult` only applies to statement executions that @@ -67,11 +76,16 @@ class AsyncResult(AsyncCommon[Row]): """ - def __init__(self, real_result: Result): + __slots__ = () + + _real_result: Result[_TP] + + def __init__(self, real_result: Result[_TP]): self._real_result = real_result self._metadata = real_result._metadata self._unique_filter_state = real_result._unique_filter_state + self._post_creational_filter = None # BaseCursorResult pre-generates the "_row_getter". Use that # if available rather than building a second one @@ -80,6 +94,43 @@ class AsyncResult(AsyncCommon[Row]): "_row_getter", real_result.__dict__["_row_getter"] ) + @property + def t(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`.AsyncResult.t` attribute is a synonym for calling the + :meth:`.AsyncResult.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`.AsyncResult` object at runtime, + however annotates as returning a :class:`.AsyncTupleResult` object + that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`.Row` objects + to by typed, for those cases where the statement invoked itself + included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_result.AsyncTupleResult` type at typing time. + + .. seealso:: + + :attr:`.AsyncResult.t` - shorter synonym + + :attr:`.Row.t` - :class:`.Row` version + + """ + + return self # type: ignore + def keys(self) -> RMKeyView: """Return the :meth:`_engine.Result.keys` collection from the underlying :class:`_engine.Result`. @@ -115,7 +166,7 @@ class AsyncResult(AsyncCommon[Row]): async def partitions( self, size: Optional[int] = None - ) -> AsyncIterator[List[Row]]: + ) -> AsyncIterator[Sequence[Row[_TP]]]: """Iterate through sub-lists of rows of the size given. An async iterator is returned:: @@ -141,7 +192,16 @@ class AsyncResult(AsyncCommon[Row]): else: break - async def fetchone(self) -> Optional[Row]: + async def fetchall(self) -> Sequence[Row[_TP]]: + """A synonym for the :meth:`.AsyncResult.all` method. + + .. versionadded:: 2.0 + + """ + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[Row[_TP]]: """Fetch one row. When all rows are exhausted, returns None. @@ -163,7 +223,9 @@ class AsyncResult(AsyncCommon[Row]): else: return row - async def fetchmany(self, size: Optional[int] = None) -> List[Row]: + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[Row[_TP]]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -184,7 +246,7 @@ class AsyncResult(AsyncCommon[Row]): return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self) -> List[Row]: + async def all(self) -> Sequence[Row[_TP]]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -196,17 +258,17 @@ class AsyncResult(AsyncCommon[Row]): return await greenlet_spawn(self._allrows) - def __aiter__(self) -> AsyncResult: + def __aiter__(self) -> AsyncResult[_TP]: return self - async def __anext__(self) -> Row: + async def __anext__(self) -> Row[_TP]: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self) -> Optional[Row]: + async def first(self) -> Optional[Row[_TP]]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -229,7 +291,7 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self) -> Optional[Row]: + async def one_or_none(self) -> Optional[Row[_TP]]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -251,6 +313,14 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, False, False) + @overload + async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: + ... + + @overload + async def scalar_one(self) -> Any: + ... + async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -266,6 +336,16 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, True, True) + @overload + async def scalar_one_or_none( + self: AsyncResult[Tuple[_T]], + ) -> Optional[_T]: + ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: + ... + async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. @@ -281,7 +361,7 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, False, True) - async def one(self) -> Row: + async def one(self) -> Row[_TP]: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -312,6 +392,14 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, True, False) + @overload + async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + async def scalar(self) -> Any: + ... + async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -328,7 +416,7 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, False, False, True) - async def freeze(self) -> FrozenResult: + async def freeze(self) -> FrozenResult[_TP]: """Return a callable object that will produce copies of this :class:`_asyncio.AsyncResult` when invoked. @@ -351,7 +439,7 @@ class AsyncResult(AsyncCommon[Row]): return await greenlet_spawn(FrozenResult, self) - def merge(self, *others: AsyncResult) -> MergedResult: + def merge(self, *others: AsyncResult[_TP]) -> MergedResult[_TP]: """Merge this :class:`_asyncio.AsyncResult` with other compatible result objects. @@ -370,6 +458,20 @@ class AsyncResult(AsyncCommon[Row]): (self._real_result,) + tuple(o._real_result for o in others), ) + @overload + def scalars( + self: AsyncResult[Tuple[_T]], index: Literal[0] + ) -> AsyncScalarResult[_T]: + ... + + @overload + def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: + ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: + ... + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: """Return an :class:`_asyncio.AsyncScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -423,9 +525,11 @@ class AsyncScalarResult(AsyncCommon[_R]): """ + __slots__ = () + _generate_rows = False - def __init__(self, real_result: Result, index: _KeyIndexType): + def __init__(self, real_result: Result[Any], index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -452,7 +556,7 @@ class AsyncScalarResult(AsyncCommon[_R]): async def partitions( self, size: Optional[int] = None - ) -> AsyncIterator[List[_R]]: + ) -> AsyncIterator[Sequence[_R]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that @@ -470,12 +574,12 @@ class AsyncScalarResult(AsyncCommon[_R]): else: break - async def fetchall(self) -> List[_R]: + async def fetchall(self) -> Sequence[_R]: """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" return await greenlet_spawn(self._allrows) - async def fetchmany(self, size: Optional[int] = None) -> List[_R]: + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: """Fetch many objects. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that @@ -485,7 +589,7 @@ class AsyncScalarResult(AsyncCommon[_R]): """ return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self) -> List[_R]: + async def all(self) -> Sequence[_R]: """Return all scalar values in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that @@ -555,11 +659,13 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): """ + __slots__ = () + _generate_rows = True _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result: Result): + def __init__(self, result: Result[Any]): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata @@ -602,7 +708,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): async def partitions( self, size: Optional[int] = None - ) -> AsyncIterator[List[RowMapping]]: + ) -> AsyncIterator[Sequence[RowMapping]]: """Iterate through sub-lists of elements of the size given. @@ -621,7 +727,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): else: break - async def fetchall(self) -> List[RowMapping]: + async def fetchall(self) -> Sequence[RowMapping]: """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" return await greenlet_spawn(self._allrows) @@ -641,7 +747,9 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): else: return row - async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[RowMapping]: """Fetch many rows. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that @@ -652,7 +760,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self) -> List[RowMapping]: + async def all(self) -> Sequence[RowMapping]: """Return all rows in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that @@ -705,11 +813,186 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): return await greenlet_spawn(self._only_one_row, True, True, False) -_RT = TypeVar("_RT", bound="Result") +SelfAsyncTupleResult = TypeVar( + "SelfAsyncTupleResult", bound="AsyncTupleResult[Any]" +) + + +class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): + """a :class:`.AsyncResult` that's typed as returning plain Python tuples + instead of rows. + + Since :class:`.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`.AsyncResult` is + still used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_result.Result.partitions` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_result.Result.fetchone` except that + tuple values, rather than :class:`_result.Row` + objects, are returned. + + """ + ... + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_result.Result.fetchmany` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_result.Result.all` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def __aiter__(self) -> AsyncIterator[_R]: + ... + + async def __anext__(self) -> _R: + ... + + async def first(self) -> Optional[_R]: + """Fetch the first object or None if no object is present. + + Equivalent to :meth:`_result.Result.first` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + + """ + ... + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_result.Result.one_or_none` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_result.Result.one` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + @overload + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: + ... + + @overload + async def scalar_one(self) -> Any: + ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one`. + + .. seealso:: + + :meth:`.Result.one` + + :meth:`.Result.scalars` + + """ + ... + + @overload + async def scalar_one_or_none( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: + ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: + ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one_or_none`. + + .. seealso:: + + :meth:`.Result.one_or_none` + + :meth:`.Result.scalars` + + """ + ... + + @overload + async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + async def scalar(self) -> Any: + ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns None if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or None if no rows remain. + + """ + ... + + +_RT = TypeVar("_RT", bound="Result[Any]") async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: - cursor_result: CursorResult + cursor_result: CursorResult[Any] try: is_cursor = result._is_cursor diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index c7a6e2ca0..22a060a0d 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -12,10 +12,12 @@ from typing import Callable from typing import Iterable from typing import Iterator from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from .session import async_sessionmaker @@ -37,9 +39,9 @@ if TYPE_CHECKING: from ...engine import Engine from ...engine import Result from ...engine import Row + from ...engine import RowMapping from ...engine.interfaces import _CoreAnyExecuteParams from ...engine.interfaces import _CoreSingleExecuteParams - from ...engine.interfaces import _ExecuteOptions from ...engine.interfaces import _ExecuteOptionsParameter from ...engine.result import ScalarResult from ...orm._typing import _IdentityKeyType @@ -52,6 +54,9 @@ if TYPE_CHECKING: from ...sql.base import Executable from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateArg + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) @create_proxy_methods( @@ -480,6 +485,32 @@ class async_scoped_session: return await self._proxied.delete(instance) + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + async def execute( self, statement: Executable, @@ -488,7 +519,7 @@ class async_scoped_session: execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Result: + ) -> Result[Any]: r"""Execute a statement and return a buffered :class:`_engine.Result` object. @@ -916,6 +947,30 @@ class async_scoped_session: return await self._proxied.rollback() + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + async def scalar( self, statement: Executable, @@ -947,6 +1002,30 @@ class async_scoped_session: **kw, ) + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + async def scalars( self, statement: Executable, @@ -984,6 +1063,19 @@ class async_scoped_session: **kw, ) + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: + ... + + @overload async def stream( self, statement: Executable, @@ -992,7 +1084,18 @@ class async_scoped_session: execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult: + ) -> AsyncResult[Any]: + ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: r"""Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -1012,6 +1115,30 @@ class async_scoped_session: **kw, ) + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: + ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + ... + async def stream_scalars( self, statement: Executable, @@ -1323,7 +1450,7 @@ class async_scoped_session: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: r"""Return an identity key. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 1422f99a3..f2a69e9cd 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -12,10 +12,12 @@ from typing import Iterable from typing import Iterator from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import engine @@ -39,11 +41,10 @@ if TYPE_CHECKING: from ...engine import Engine from ...engine import Result from ...engine import Row + from ...engine import RowMapping from ...engine import ScalarResult - from ...engine import Transaction from ...engine.interfaces import _CoreAnyExecuteParams from ...engine.interfaces import _CoreSingleExecuteParams - from ...engine.interfaces import _ExecuteOptions from ...engine.interfaces import _ExecuteOptionsParameter from ...event import dispatcher from ...orm._typing import _IdentityKeyType @@ -59,9 +60,12 @@ if TYPE_CHECKING: from ...sql.base import Executable from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateArg + from ...sql.selectable import TypedReturnsRows _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] +_T = TypeVar("_T", bound=Any) + class _SyncSessionCallable(Protocol): def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any: @@ -257,6 +261,32 @@ class AsyncSession(ReversibleProxy[Session]): return await greenlet_spawn(fn, self.sync_session, *arg, **kw) + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + async def execute( self, statement: Executable, @@ -265,7 +295,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Result: + ) -> Result[Any]: """Execute a statement and return a buffered :class:`_engine.Result` object. @@ -292,6 +322,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return await _ensure_sync_result(result, self.execute) + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + async def scalar( self, statement: Executable, @@ -326,6 +380,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return result + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + async def scalars( self, statement: Executable, @@ -391,6 +469,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return result_obj + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: + ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + ... + async def stream( self, statement: Executable, @@ -399,7 +501,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult: + ) -> AsyncResult[Any]: """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -423,6 +525,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return AsyncResult(result) + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: + ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + ... + async def stream_scalars( self, statement: Executable, @@ -1215,7 +1341,7 @@ class AsyncSession(ReversibleProxy[Session]): ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: r"""Return an identity key. diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index b1138a4ad..c14b466eb 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -23,6 +23,7 @@ from ..orm import base as orm_base from ..orm import collections from ..orm import exc as orm_exc from ..orm import instrumentation as orm_instrumentation +from ..orm import util as orm_util from ..orm.instrumentation import _default_dict_getter from ..orm.instrumentation import _default_manager_getter from ..orm.instrumentation import _default_opt_manager_getter @@ -437,5 +438,7 @@ def _install_lookups(lookups): attributes.manager_of_class ) = orm_instrumentation.manager_of_class = manager_of_class orm_base.opt_manager_of_class = ( + orm_util.opt_manager_of_class + ) = ( attributes.opt_manager_of_class ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 457ad5c5a..48615b174 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -38,6 +38,7 @@ from ..exc import InvalidRequestError from ..sql.base import SchemaEventTarget from ..sql.schema import SchemaConst from ..sql.selectable import FromClause +from ..util.typing import Annotated from ..util.typing import Literal if TYPE_CHECKING: @@ -45,6 +46,7 @@ if TYPE_CHECKING: from ._typing import _ORMColumnExprArgument from .descriptor_props import _CompositeAttrType from .interfaces import PropComparator + from .mapper import Mapper from .query import Query from .relationships import _LazyLoadArgumentType from .relationships import _ORMBackrefArgument @@ -1849,9 +1851,27 @@ def clear_mappers(): mapperlib._dispose_registries(mapperlib._all_registries(), False) +# I would really like a way to get the Type[] here that shows up +# in a different way in typing tools, however there is no current method +# that is accepted by mypy (subclass of Type[_O] works in pylance, rejected +# by mypy). +AliasedType = Annotated[Type[_O], "aliased"] + + +@overload +def aliased( + element: Type[_O], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedType[_O]: + ... + + @overload def aliased( - element: _EntityType[_O], + element: Union[AliasedClass[_O], Mapper[_O], AliasedInsp[_O]], alias: Optional[Union[Alias, Subquery]] = None, name: Optional[str] = None, flat: bool = False, @@ -1877,7 +1897,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> Union[AliasedClass[_O], FromClause]: +) -> Union[AliasedClass[_O], FromClause, AliasedType[_O]]: """Produce an alias of the given element, usually an :class:`.AliasedClass` instance. @@ -1885,7 +1905,8 @@ def aliased( my_alias = aliased(MyClass) - session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) + stmt = select(MyClass, my_alias).filter(MyClass.id > my_alias.id) + result = session.execute(stmt) The :func:`.aliased` function is used to create an ad-hoc mapping of a mapped class to a new selectable. By default, a selectable is generated @@ -1911,6 +1932,9 @@ def aliased( .. seealso:: + :class:`.AsAliased` - a :pep:`484` typed version of + :func:`_orm.aliased` + :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 41d944c57..619af6510 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -70,6 +70,7 @@ from .. import exc from .. import inspection from .. import util from ..sql import base as sql_base +from ..sql import cache_key from ..sql import roles from ..sql import traversals from ..sql import visitors @@ -99,10 +100,8 @@ class QueryableAttribute( traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, - roles.ColumnsClauseRole, - roles.ExpressionElementRole[_T], sql_base.Immutable, - sql_base.MemoizedHasCacheKey, + cache_key.MemoizedHasCacheKey, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 054d52d83..367a5332d 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -30,6 +30,7 @@ from ._typing import insp_is_mapper from .. import exc as sa_exc from .. import inspection from .. import util +from ..sql import roles from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly @@ -483,19 +484,6 @@ def _inspect_mapped_class( return mapper -@inspection._inspects(type) -def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]: - try: - class_manager = opt_manager_of_class(class_) - if class_manager is None or not class_manager.is_mapped: - return None - mapper = class_manager.mapper - except exc.NO_STATE: - return None - else: - return mapper - - def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: insp = inspection.inspect(arg, raiseerr=False) if insp_is_mapper(insp): @@ -691,7 +679,7 @@ class ORMDescriptor(Generic[_T], TypingOnly): ... -class Mapped(ORMDescriptor[_T], TypingOnly): +class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly): """Represent an ORM mapped attribute on a mapped class. This class represents the complete descriptor interface for any class diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 4fee2d383..05287cbcf 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -17,6 +17,7 @@ from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import attributes @@ -48,14 +49,15 @@ from ..sql.base import _select_iterables from ..sql.base import CacheableOptions from ..sql.base import CompileState from ..sql.base import Executable +from ..sql.base import Generative from ..sql.base import Options from ..sql.dml import UpdateBase from ..sql.elements import GroupedElement from ..sql.elements import TextClause +from ..sql.selectable import ExecutableReturnsRows from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL -from ..sql.selectable import ReturnsRows from ..sql.selectable import Select from ..sql.selectable import SelectLabelStyle from ..sql.selectable import SelectState @@ -72,6 +74,7 @@ if TYPE_CHECKING: from ..sql.selectable import SelectBase from ..sql.type_api import TypeEngine +_T = TypeVar("_T", bound=Any) _path_registry = PathRegistry.root _EMPTY_DICT = util.immutabledict() @@ -574,7 +577,7 @@ class ORMFromStatementCompileState(ORMCompileState): return None -class FromStatement(GroupedElement, ReturnsRows, Executable): +class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): """Core construct that represents a load of ORM objects from various :class:`.ReturnsRows` and other classes including: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 0ca62b7e3..6a5690be2 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -61,7 +61,6 @@ from ..sql.schema import Column from ..sql.type_api import TypeEngine from ..util.typing import TypedDict - if typing.TYPE_CHECKING: from ._typing import _EntityType from ._typing import _IdentityKeyType @@ -106,12 +105,12 @@ class ORMStatementRole(roles.StatementRole): ) -class ORMColumnsClauseRole(roles.ColumnsClauseRole): +class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]): __slots__ = () _role_name = "ORM mapped entity, aliased entity, or Column expression" -class ORMEntityColumnsClauseRole(ORMColumnsClauseRole): +class ORMEntityColumnsClauseRole(ORMColumnsClauseRole[_T]): __slots__ = () _role_name = "ORM mapped or aliased entity" @@ -127,8 +126,8 @@ class ORMColumnDescription(TypedDict): # into "type" is a bad idea type: Union[Type[Any], TypeEngine[Any]] aliased: bool - expr: _ColumnsClauseArgument - entity: Optional[_ColumnsClauseArgument] + expr: _ColumnsClauseArgument[Any] + entity: Optional[_ColumnsClauseArgument[Any]] class _IntrospectsAnnotations: @@ -282,7 +281,7 @@ class MapperProperty( query_entity: _MapperEntity, path: PathRegistry, mapper: Mapper[Any], - result: Result, + result: Result[Any], adapter: Optional[ColumnAdapter], populators: _PopulatorDict, ) -> None: @@ -1170,7 +1169,7 @@ class LoaderStrategy: path: AbstractEntityRegistry, loadopt: Optional[_LoadElement], mapper: Mapper[Any], - result: Result, + result: Result[Any], adapter: Optional[ORMAdapter], populators: _PopulatorDict, ) -> None: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index b37c080ea..083035093 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -96,7 +96,6 @@ if TYPE_CHECKING: from .descriptor_props import Synonym from .events import MapperEvents from .instrumentation import ClassManager - from .path_registry import AbstractEntityRegistry from .path_registry import CachingEntityRegistry from .properties import ColumnProperty from .relationships import Relationship @@ -108,10 +107,10 @@ if TYPE_CHECKING: from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement + from ..sql.elements import KeyedColumnElement from ..sql.schema import Column from ..sql.schema import Table from ..sql.selectable import FromClause - from ..sql.selectable import TableClause from ..sql.util import ColumnAdapter from ..util import OrderedSet @@ -161,7 +160,7 @@ _CONFIGURE_MUTEX = threading.RLock() @log.class_logger class Mapper( ORMFromClauseRole, - ORMEntityColumnsClauseRole, + ORMEntityColumnsClauseRole[_O], MemoizedHasCacheKey, InspectionAttr, log.Identified, @@ -1006,7 +1005,7 @@ class Mapper( """ - polymorphic_on: Optional[ColumnElement[Any]] + polymorphic_on: Optional[KeyedColumnElement[Any]] """The :class:`_schema.Column` or SQL expression specified as the ``polymorphic_on`` argument for this :class:`_orm.Mapper`, within an inheritance scenario. @@ -1699,10 +1698,10 @@ class Mapper( instrument = True key = getattr(col, "key", None) if key: - if self._should_exclude(col.key, col.key, False, col): + if self._should_exclude(key, key, False, col): raise sa_exc.InvalidRequestError( "Cannot exclude or override the " - "discriminator column %r" % col.key + "discriminator column %r" % key ) else: self.polymorphic_on = col = col.label("_sa_polymorphic_on") @@ -2948,7 +2947,7 @@ class Mapper( def identity_key_from_row( self, - row: Optional[Union[Row, RowMapping]], + row: Optional[Union[Row[Any], RowMapping]], identity_token: Optional[Any] = None, adapter: Optional[ColumnAdapter] = None, ) -> _IdentityKeyType[_O]: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9f37e8457..0ca0559b4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -54,7 +54,7 @@ from ..util.typing import NoneType if TYPE_CHECKING: from ._typing import _ORMColumnExprArgument from ..sql._typing import _InfoType - from ..sql.elements import ColumnElement + from ..sql.elements import KeyedColumnElement _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -85,7 +85,8 @@ class ColumnProperty( inherit_cache = True _links_to_entity = False - columns: List[ColumnElement[Any]] + columns: List[KeyedColumnElement[Any]] + _orig_columns: List[KeyedColumnElement[Any]] _is_polymorphic_discriminator: bool diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 395d01a1e..5bd302b21 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -27,6 +27,8 @@ from typing import Generic from typing import Iterable from typing import List from typing import Optional +from typing import overload +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar @@ -36,6 +38,7 @@ from . import exc as orm_exc from . import interfaces from . import loading from . import util as orm_util +from ._typing import _O from .base import _assertions from .context import _column_descriptions from .context import _determine_last_joined_entity @@ -56,6 +59,7 @@ from .. import log from .. import sql from .. import util from ..engine import Result +from ..engine import Row from ..sql import coercions from ..sql import expression from ..sql import roles @@ -63,10 +67,12 @@ from ..sql import Select from ..sql import util as sql_util from ..sql import visitors from ..sql._typing import _FromClauseArgument +from ..sql._typing import _TP from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative from ..sql.base import Executable +from ..sql.base import Generative from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements @@ -75,10 +81,33 @@ from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util.typing import Literal if TYPE_CHECKING: + from ._typing import _EntityType + from .session import Session + from ..engine.result import ScalarResult + from ..engine.row import Row + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _MAYBE_ENTITY + from ..sql._typing import _no_kw + from ..sql._typing import _NOT_ENTITY + from ..sql._typing import _PropagateAttrsType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import _SetupJoinsElement from ..sql.selectable import Alias + from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import ScalarSelect from ..sql.selectable import Subquery __all__ = ["Query", "QueryContext"] @@ -97,6 +126,7 @@ class Query( HasSuffixes, HasHints, log.Identified, + Generative, Executable, Generic[_T], ): @@ -159,9 +189,15 @@ class Query( # mirrors that of ClauseElement, used to propagate the "orm" # plugin as well as the "subject" of the plugin, e.g. the mapper # we are querying against. - _propagate_attrs = util.immutabledict() + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + return util.EMPTY_DICT - def __init__(self, entities, session=None): + def __init__( + self, + entities: Sequence[_ColumnsClauseArgument[Any]], + session: Optional[Session] = None, + ): """Construct a :class:`_query.Query` directly. E.g.:: @@ -207,6 +243,36 @@ class Query( for ent in util.to_list(entities) ] + @overload + def tuples(self: Query[Row[_TP]]) -> Query[_TP]: + ... + + @overload + def tuples(self: Query[_O]) -> Query[Tuple[_O]]: + ... + + def tuples(self) -> Query[Any]: + """return a tuple-typed form of this :class:`.Query`. + + This method invokes the :meth:`.Query.only_return_tuples` + method with a value of ``True``, which by itself ensures that this + :class:`.Query` will always return :class:`.Row` objects, even + if the query is made against a single entity. It then also + at the typing level will return a "typed" query, if possible, + that will type result rows as ``Tuple`` objects with typed + elements. + + This method can be compared to the :meth:`.Result.tuples` method, + which returns "self", but from a typing perspective returns an object + that will yield typed ``Tuple`` objects for results. Typing + takes effect only if this :class:`.Query` object is a typed + query object already. + + .. versionadded:: 2.0 + + """ + return self.only_return_tuples(True) + def _entity_from_pre_ent_zero(self): if not self._raw_columns: return None @@ -582,20 +648,52 @@ class Query( return self.enable_eagerloads(False).statement.label(name) + @overload + def as_scalar( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[_MAYBE_ENTITY]: + ... + + @overload + def as_scalar( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def as_scalar(self) -> ScalarSelect[Any]: + ... + @util.deprecated( "1.4", "The :meth:`_query.Query.as_scalar` method is deprecated and will be " "removed in a future release. Please refer to " ":meth:`_query.Query.scalar_subquery`.", ) - def as_scalar(self): + def as_scalar(self) -> ScalarSelect[Any]: """Return the full SELECT statement represented by this :class:`_query.Query`, converted to a scalar subquery. """ return self.scalar_subquery() - def scalar_subquery(self): + @overload + def scalar_subquery( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: + ... + + @overload + def scalar_subquery( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def scalar_subquery(self) -> ScalarSelect[Any]: """Return the full SELECT statement represented by this :class:`_query.Query`, converted to a scalar subquery. @@ -630,16 +728,31 @@ class Query( .statement ) - @_generative - def only_return_tuples(self: SelfQuery, value) -> SelfQuery: - """When set to True, the query results will always be a tuple. + @overload + def only_return_tuples( + self: Query[_O], value: Literal[True] + ) -> RowReturningQuery[Tuple[_O]]: + ... - This is specifically for single element queries. The default is False. + @overload + def only_return_tuples( + self: Query[_O], value: Literal[False] + ) -> Query[_O]: + ... - .. versionadded:: 1.2.5 + @_generative + def only_return_tuples(self, value: bool) -> Query[Any]: + """When set to True, the query results will always be a + :class:`.Row` object. + + This can change a query that normally returns a single entity + as a scalar to return a :class:`.Row` result in all cases. .. seealso:: + :meth:`.Query.tuples` - returns tuples, but also at the typing + level will type results as ``Tuple``. + :meth:`_query.Query.is_single_entity` """ @@ -1077,7 +1190,11 @@ class Query( return self.filter(with_parent(instance, property, entity_zero.entity)) @_generative - def add_entity(self: SelfQuery, entity, alias=None) -> SelfQuery: + def add_entity( + self, + entity: _EntityType[Any], + alias: Optional[Union[Alias, Subquery]] = None, + ) -> Query[Any]: """add a mapped entity to the list of result columns to be returned.""" @@ -1209,8 +1326,107 @@ class Query( except StopIteration: return None + @overload + def with_entities( + self, _entity: _EntityType[_O], **kwargs: Any + ) -> ScalarInstanceQuery[_O]: + ... + + @overload + def with_entities( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.with_entities + + @overload + def with_entities( + self: SelfQuery, *entities: _ColumnsClauseArgument[Any] + ) -> SelfQuery: + ... + @_generative - def with_entities(self: SelfQuery, *entities) -> SelfQuery: + def with_entities( + self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any + ) -> SelfQuery: r"""Return a new :class:`_query.Query` replacing the SELECT list with the given entities. @@ -1234,12 +1450,14 @@ class Query( limit(1) """ + if __kw: + raise _no_kw() _MemoizedSelectEntities._generate_for_statement(self) self._set_entities(entities) return self @_generative - def add_columns(self: SelfQuery, *column) -> SelfQuery: + def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]: """Add one or more column expressions to the list of result columns to be returned.""" @@ -1262,7 +1480,7 @@ class Query( "is deprecated and will be removed in a " "future release. Please use :meth:`_query.Query.add_columns`", ) - def add_column(self, column): + def add_column(self, column) -> Query[Any]: """Add a column expression to the list of result columns to be returned. @@ -1472,7 +1690,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def filter(self: SelfQuery, *criterion) -> SelfQuery: + def filter( + self: SelfQuery, *criterion: _ColumnExpressionArgument[bool] + ) -> SelfQuery: r"""Apply the given filtering criterion to a copy of this :class:`_query.Query`, using SQL expressions. @@ -1556,7 +1776,7 @@ class Query( return self._raw_columns[0] - def filter_by(self, **kwargs): + def filter_by(self: SelfQuery, **kwargs: Any) -> SelfQuery: r"""Apply the given filtering criterion to a copy of this :class:`_query.Query`, using keyword expressions. @@ -1597,7 +1817,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def order_by(self: SelfQuery, *clauses) -> SelfQuery: + def order_by( + self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfQuery: """Apply one or more ORDER BY criteria to the query and return the newly resulting :class:`_query.Query`. @@ -1635,7 +1857,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def group_by(self: SelfQuery, *clauses) -> SelfQuery: + def group_by( + self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfQuery: """Apply one or more GROUP BY criterion to the query and return the newly resulting :class:`_query.Query`. @@ -1667,7 +1891,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def having(self: SelfQuery, criterion) -> SelfQuery: + def having( + self: SelfQuery, *having: _ColumnExpressionArgument[bool] + ) -> SelfQuery: r"""Apply a HAVING criterion to the query and return the newly resulting :class:`_query.Query`. @@ -1684,17 +1910,17 @@ class Query( """ - self._having_criteria += ( - coercions.expect( - roles.WhereHavingRole, criterion, apply_propagate_attrs=self - ), - ) + for criterion in having: + having_criteria = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._having_criteria += (having_criteria,) return self def _set_op(self, expr_fn, *q): return self._from_selectable(expr_fn(*([self] + list(q))).subquery()) - def union(self, *q): + def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce a UNION of this Query against one or more queries. e.g.:: @@ -1733,7 +1959,7 @@ class Query( """ return self._set_op(expression.union, *q) - def union_all(self, *q): + def union_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce a UNION ALL of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1742,7 +1968,7 @@ class Query( """ return self._set_op(expression.union_all, *q) - def intersect(self, *q): + def intersect(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an INTERSECT of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1751,7 +1977,7 @@ class Query( """ return self._set_op(expression.intersect, *q) - def intersect_all(self, *q): + def intersect_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an INTERSECT ALL of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1760,7 +1986,7 @@ class Query( """ return self._set_op(expression.intersect_all, *q) - def except_(self, *q): + def except_(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an EXCEPT of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1769,7 +1995,7 @@ class Query( """ return self._set_op(expression.except_, *q) - def except_all(self, *q): + def except_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an EXCEPT ALL of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -2194,7 +2420,9 @@ class Query( @_generative @_assertions(_no_clauseelement_condition) - def from_statement(self: SelfQuery, statement) -> SelfQuery: + def from_statement( + self: SelfQuery, statement: ExecutableReturnsRows + ) -> SelfQuery: """Execute the given SELECT statement and return results. This method bypasses all internal statement compilation, and the @@ -2283,7 +2511,7 @@ class Query( :meth:`_query.Query.one_or_none` """ - return self._iter().one() + return self._iter().one() # type: ignore def scalar(self) -> Any: """Return the first element of the first result or None @@ -2316,7 +2544,7 @@ class Query( def __iter__(self) -> Iterable[_T]: return self._iter().__iter__() - def _iter(self): + def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: # new style execution. params = self._params @@ -2837,3 +3065,7 @@ class BulkUpdate(BulkUD): class BulkDelete(BulkUD): """BulkUD which handles DELETEs.""" + + +class RowReturningQuery(Query[Row[_TP]]): + pass diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 93d18b8d7..9220c44c7 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -13,6 +13,7 @@ from typing import Dict from typing import Iterable from typing import Iterator from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type @@ -20,8 +21,6 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from . import exc as orm_exc -from .base import class_mapper from .session import Session from .. import exc as sa_exc from .. import util @@ -33,11 +32,13 @@ from ..util import warn_deprecated from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _IdentityKeyType from .identity import IdentityMap from .interfaces import ORMOption from .mapper import Mapper from .query import Query + from .query import RowReturningQuery from .session import _BindArguments from .session import _EntityBindKey from .session import _PKIdentityArgument @@ -48,19 +49,33 @@ if TYPE_CHECKING: from ..engine import Engine from ..engine import Result from ..engine import Row + from ..engine import RowMapping from ..engine.interfaces import _CoreAnyExecuteParams from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.result import ScalarResult from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateArg + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) class _QueryDescriptorType(Protocol): - def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]: + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... @@ -236,7 +251,7 @@ class scoped_session: self.registry.clear() def query_property( - self, query_cls: Optional[Type[Query[Any]]] = None + self, query_cls: Optional[Type[Query[_T]]] = None ) -> _QueryDescriptorType: """return a class property which produces a :class:`_query.Query` object @@ -264,20 +279,13 @@ class scoped_session: """ class query: - def __get__( - s, instance: Any, owner: Type[Any] - ) -> Optional[Query[Any]]: - try: - mapper = class_mapper(owner) - assert mapper is not None - if query_cls: - # custom query class - return query_cls(mapper, session=self.registry()) - else: - # session's configured query class - return self.registry().query(mapper) - except orm_exc.UnmappedClassError: - return None + def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]: + if query_cls: + # custom query class + return query_cls(owner, session=self.registry()) # type: ignore # noqa: E501 + else: + # session's configured query class + return self.registry().query(owner) return query() @@ -548,6 +556,32 @@ class scoped_session: return self._proxied.delete(instance) + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + def execute( self, statement: Executable, @@ -557,7 +591,7 @@ class scoped_session: bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result: + ) -> Result[Any]: r"""Execute a SQL expression construct. .. container:: class_bases @@ -1430,8 +1464,103 @@ class scoped_session: return self._proxied.merge(instance, load=load, options=options) + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload def query( - self, *entities: _ColumnsClauseArgument, **kwargs: Any + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any ) -> Query[Any]: r"""Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -1559,6 +1688,30 @@ class scoped_session: return self._proxied.rollback() + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + def scalar( self, statement: Executable, @@ -1590,6 +1743,30 @@ class scoped_session: **kw, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + def scalars( self, statement: Executable, @@ -1848,7 +2025,7 @@ class scoped_session: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: r"""Return an identity key. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 74035ec0a..263d56101 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -27,6 +27,7 @@ from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -85,12 +86,16 @@ from ..util.typing import Literal from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InstanceDict + from ._typing import _O + from .context import FromStatement from .interfaces import ORMOption from .interfaces import UserDefinedOption from .mapper import Mapper from .path_registry import PathRegistry + from .query import RowReturningQuery from ..engine import Result from ..engine import Row from ..engine import RowMapping @@ -104,10 +109,23 @@ if typing.TYPE_CHECKING: from ..event import _InstanceLevelDispatch from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _InfoType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole from ..sql.schema import Table - from ..sql.selectable import TableClause + from ..sql.selectable import Select + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) __all__ = [ "Session", @@ -189,7 +207,7 @@ class _SessionClassMethods: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Union[Row, RowMapping]] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: """Return an identity key. @@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots): params: Optional[_CoreAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, bind_arguments: Optional[_BindArguments] = None, - ) -> Result: + ) -> Result[Any]: """Execute the statement represented by this :class:`.ORMExecuteState`, without re-invoking events that have already proceeded. @@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: bool = ..., - ) -> Result: + ) -> Result[Any]: ... def _execute_internal( @@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget): ) for idx, fn in enumerate(events_todo): orm_exec_state._starting_event_idx = idx - fn_result: Optional[Result] = fn(orm_exec_state) + fn_result: Optional[Result[Any]] = fn(orm_exec_state) if fn_result: if _scalar_result: return fn_result.scalar() @@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget): if _scalar_result and not compile_state_cls: if TYPE_CHECKING: params = cast(_CoreSingleExecuteParams, params) - return conn.scalar(statement, params or {}, execution_options) + return conn.scalar( + statement, params or {}, execution_options=execution_options + ) - result: Result = conn.execute( - statement, params or {}, execution_options + result: Result[Any] = conn.execute( + statement, params or {}, execution_options=execution_options ) if compile_state_cls: @@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget): else: return result + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + def execute( self, statement: Executable, @@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result: + ) -> Result[Any]: r"""Execute a SQL expression construct. Returns a :class:`_engine.Result` object representing @@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget): _add_event=_add_event, ) + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + def scalar( self, statement: Executable, @@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget): **kw, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + def scalars( self, statement: Executable, @@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget): f'{", ".join(context)} or this Session.' ) + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload + def query( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + def query( - self, *entities: _ColumnsClauseArgument, **kwargs: Any + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any ) -> Query[Any]: """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget): with_for_update = ForUpdateArg._from_argument(with_for_update) - stmt = sql.select(object_mapper(instance)) + stmt: Select[Any] = sql.select(object_mapper(instance)) if ( loading.load_on_ident( self, diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 58f141997..ab32a3981 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -656,13 +656,13 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): @classmethod def _instance_level_callable_processor( cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any - ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]: + ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]: impl = manager[key].impl if is_collection_impl(impl): fixed_impl = impl def _set_callable( - state: InstanceState[_O], dict_: _InstanceDict, row: Row + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] ) -> None: if "callables" not in state.__dict__: state.callables = {} @@ -674,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): else: def _set_callable( - state: InstanceState[_O], dict_: _InstanceDict, row: Row + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] ) -> None: if "callables" not in state.__dict__: state.callables = {} diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 3934de535..8148793b1 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -28,6 +28,7 @@ from typing import Union import weakref from . import attributes # noqa +from . import exc from ._typing import _O from ._typing import insp_is_aliased_class from ._typing import insp_is_mapper @@ -41,6 +42,7 @@ from .base import InspectionAttr as InspectionAttr from .base import instance_str as instance_str from .base import object_mapper as object_mapper from .base import object_state as object_state +from .base import opt_manager_of_class from .base import state_attribute_str as state_attribute_str from .base import state_class_str as state_class_str from .base import state_str as state_str @@ -68,6 +70,7 @@ from ..sql.base import ColumnCollection from ..sql.cache_key import HasCacheKey from ..sql.cache_key import MemoizedHasCacheKey from ..sql.elements import ColumnElement +from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation @@ -95,9 +98,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import _ColumnsClauseElement from ..sql.selectable import Alias from ..sql.selectable import Subquery - from ..sql.visitors import _ET from ..sql.visitors import anon_map - from ..sql.visitors import ExternallyTraversible _T = TypeVar("_T", bound=Any) @@ -341,7 +342,7 @@ def identity_key( ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[_T] = None, - row: Optional[Union[Row, RowMapping]] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[_T]: r"""Generate "identity key" tuples, as are used as keys in the @@ -468,7 +469,9 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]): +class AliasedClass( + inspection.Inspectable["AliasedInsp[_O]"], ORMColumnsClauseRole[_O] +): r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -663,7 +666,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]): @inspection._self_inspects class AliasedInsp( - ORMEntityColumnsClauseRole, + ORMEntityColumnsClauseRole[_O], ORMFromClauseRole, HasCacheKey, InspectionAttr, @@ -1276,12 +1279,29 @@ class LoaderCriteriaOption(CriteriaOption): inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) +@inspection._inspects(type) +def _inspect_mc( + class_: Type[_O], +) -> Optional[Mapper[_O]]: + + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + + return None + else: + return mapper + + @inspection._self_inspects class Bundle( - ORMColumnsClauseRole, + ORMColumnsClauseRole[_T], SupportsCloneAnnotations, MemoizedHasCacheKey, - inspection.Inspectable["Bundle"], + inspection.Inspectable["Bundle[_T]"], InspectionAttr, ): """A grouping of SQL expressions that are returned by a :class:`.Query` @@ -1373,10 +1393,10 @@ class Bundle( @property def entity_namespace( self, - ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self.c - columns: ReadOnlyColumnCollection[str, ColumnElement[Any]] + columns: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] """A namespace of SQL expressions referred to by this :class:`.Bundle`. @@ -1402,7 +1422,7 @@ class Bundle( """ - c: ReadOnlyColumnCollection[str, ColumnElement[Any]] + c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] """An alias for :attr:`.Bundle.columns`.""" def _clone(self): @@ -1908,9 +1928,10 @@ def _extract_mapped_subtype( raw_annotation: Union[type, str], cls: type, key: str, - attr_cls: type, + attr_cls: Type[Any], required: bool, is_dataclass_field: bool, + superclasses: Optional[Tuple[Type[Any], ...]] = None, ) -> Optional[Union[type, str]]: if raw_annotation is None: @@ -1930,9 +1951,13 @@ def _extract_mapped_subtype( if is_dataclass_field: return annotated else: - if ( - not hasattr(annotated, "__origin__") - or not issubclass(annotated.__origin__, attr_cls) # type: ignore + # TODO: there don't seem to be tests for the failure + # conditions here + if not hasattr(annotated, "__origin__") or ( + not issubclass( + annotated.__origin__, # type: ignore + superclasses if superclasses else attr_cls, + ) and not issubclass(attr_cls, annotated.__origin__) # type: ignore ): our_annotated_str = ( diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 84913225d..c3ebb4596 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -121,7 +121,6 @@ def __go(lcls: Any) -> None: coercions.lambdas = lambdas coercions.schema = schema coercions.selectable = selectable - coercions.traversals = traversals from .annotation import _prepare_annotations from .annotation import Annotated diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 37d44976a..f89e8f578 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -9,12 +9,16 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import overload +from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -34,6 +38,17 @@ if TYPE_CHECKING: from ._typing import _FromClauseArgument from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 + from ._typing import _T8 + from ._typing import _T9 + from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE from .selectable import HasCTE @@ -41,6 +56,9 @@ if TYPE_CHECKING: from .selectable import SelectBase +_T = TypeVar("_T", bound=Any) + + def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -89,7 +107,9 @@ def cte( ) -def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def except_( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -119,7 +139,7 @@ def except_all( def exists( __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ) -> Exists: """Construct a new :class:`_expression.Exists` construct. @@ -162,7 +182,9 @@ def exists( return Exists(__argument) -def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def intersect( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -306,7 +328,129 @@ def outerjoin( return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: _ColumnsClauseArgument) -> Select: +# START OVERLOADED FUNCTIONS select Select 1-10 + +# code within this block is **programmatically, +# statically generated** by tools/generate_tuple_map_overloads.py + + +@overload +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + +@overload +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] +) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], + __ent9: _TCCA[_T9], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: + ... + + +# END OVERLOADED FUNCTIONS select + + +@overload +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + ... + + +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: r"""Construct a new :class:`_expression.Select`. @@ -343,7 +487,11 @@ def select(*entities: _ColumnsClauseArgument) -> Select: given, as well as ORM-mapped classes. """ - + # the keyword args are a necessary element in order for the typing + # to work out w/ the varargs vs. having named "keyword" arguments that + # aren't always present. + if __kw: + raise _no_kw() return Select(*entities) @@ -425,7 +573,9 @@ def tablesample( return TableSample._factory(selectable, sampling, name=name, seed=seed) -def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -445,7 +595,9 @@ def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: return CompoundSelect._create_union(*selects) -def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 53d29b628..1df530dbd 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -5,18 +5,27 @@ from typing import Any from typing import Callable from typing import Dict from typing import Set +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import roles +from .. import exc from .. import util from ..inspection import Inspectable from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from datetime import date + from datetime import datetime + from datetime import time + from datetime import timedelta + from decimal import Decimal + from uuid import UUID + from .base import Executable from .compiler import Compiled from .compiler import DDLCompiler @@ -26,17 +35,15 @@ if TYPE_CHECKING: from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import quoted_name - from .elements import SQLCoreOperations from .elements import TextClause from .lambdas import LambdaElement from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column - from .schema import DefaultGenerator - from .schema import Sequence - from .schema import Table from .selectable import Alias + from .selectable import CTE from .selectable import FromClause from .selectable import Join from .selectable import NamedFromClause @@ -61,6 +68,30 @@ class _HasClauseElement(Protocol): ... +# match column types that are not ORM entities +_NOT_ENTITY = TypeVar( + "_NOT_ENTITY", + int, + str, + "datetime", + "date", + "time", + "timedelta", + "UUID", + float, + "Decimal", +) + +_MAYBE_ENTITY = TypeVar( + "_MAYBE_ENTITY", + roles.ColumnsClauseRole, + Literal["*", 1], + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +) + + # convention: # XYZArgument - something that the end user is passing to a public API method # XYZElement - the internal representation that we use for the thing. @@ -76,9 +107,10 @@ _TextCoercedExpressionArgument = Union[ ] _ColumnsClauseArgument = Union[ - Literal["*", 1], + roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, - Type[Any], + Literal["*", 1], + Type[_T], Inspectable[_HasClauseElement], _HasClauseElement, ] @@ -92,6 +124,24 @@ sets; select(...), insert().returning(...), etc. """ +_TypedColumnClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T] +] + +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_T0 = TypeVar("_T0", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) +_T3 = TypeVar("_T3", bound=Any) +_T4 = TypeVar("_T4", bound=Any) +_T5 = TypeVar("_T5", bound=Any) +_T6 = TypeVar("_T6", bound=Any) +_T7 = TypeVar("_T7", bound=Any) +_T8 = TypeVar("_T8", bound=Any) +_T9 = TypeVar("_T9", bound=Any) + + _ColumnExpressionArgument = Union[ "ColumnElement[_T]", _HasClauseElement, @@ -169,6 +219,7 @@ _DMLTableArgument = Union[ "TableClause", "Join", "Alias", + "CTE", Type[Any], Inspectable[_HasClauseElement], _HasClauseElement, @@ -194,6 +245,11 @@ if TYPE_CHECKING: def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: ... + def is_keyed_column_element( + c: ClauseElement, + ) -> TypeGuard[KeyedColumnElement[Any]]: + ... + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... @@ -216,7 +272,7 @@ if TYPE_CHECKING: def is_select_statement( t: Union[Executable, ReturnsRows] - ) -> TypeGuard[Select]: + ) -> TypeGuard[Select[Any]]: ... def is_table(t: FromClause) -> TypeGuard[TableClause]: @@ -234,6 +290,7 @@ else: is_ddl_compiler = operator.attrgetter("is_ddl") is_named_from_clause = operator.attrgetter("named_with_column") is_column_element = operator.attrgetter("_is_column_element") + is_keyed_column_element = operator.attrgetter("_is_keyed_column_element") is_text_clause = operator.attrgetter("_is_text_clause") is_from_clause = operator.attrgetter("_is_from_clause") is_tuple_type = operator.attrgetter("_is_tuple_type") @@ -260,3 +317,10 @@ def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: return c.is_dml and (c.is_insert or c.is_update) # type: ignore + + +def _no_kw() -> exc.ArgumentError: + return exc.ArgumentError( + "Additional keyword arguments are not accepted by this " + "function/method. The presence of **kw is for pep-484 typing purposes" + ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f81878d55..790edefc6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -62,10 +62,10 @@ if TYPE_CHECKING: from . import coercions from . import elements from . import type_api - from ._typing import _ColumnsClauseArgument from .elements import BindParameter - from .elements import ColumnClause + from .elements import ColumnClause # noqa from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations from .elements import TextClause @@ -74,7 +74,6 @@ if TYPE_CHECKING: from .selectable import FromClause from ..engine import Connection from ..engine import CursorResult - from ..engine import Result from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import _ExecuteOptions @@ -704,8 +703,11 @@ class InPlaceGenerative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @_generative decorator that mutates in place.""" + __slots__ = () + def _generate(self): skip = self._memoized_keys + # note __dict__ needs to be in __slots__ if this is used for k in skip: self.__dict__.pop(k, None) return self @@ -937,7 +939,7 @@ class ExecutableOption(HasCopyInternals): SelfExecutable = TypeVar("SelfExecutable", bound="Executable") -class Executable(roles.StatementRole, Generative): +class Executable(roles.StatementRole): """Mark a :class:`_expression.ClauseElement` as supporting execution. :class:`.Executable` is a superclass for all "statement" types @@ -994,7 +996,7 @@ class Executable(roles.StatementRole, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: ... def _execute_on_scalar( @@ -1253,7 +1255,7 @@ class SchemaVisitor(ClauseVisitor): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) -_COL = TypeVar("_COL", bound="ColumnElement[Any]") +_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") class ColumnCollection(Generic[_COLKEY, _COL_co]): @@ -1505,6 +1507,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) + self._collection[:] = cols self._colset.update(c for k, c in self._collection) self._index.update( diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 0659709ab..9b7231360 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -29,6 +29,7 @@ from typing import Union from . import operators from . import roles from . import visitors +from ._typing import is_from_clause from .base import ExecutableOption from .base import Options from .cache_key import HasCacheKey @@ -38,25 +39,18 @@ from .. import inspection from .. import util from ..util.typing import Literal -if not typing.TYPE_CHECKING: - elements = None - lambdas = None - schema = None - selectable = None - traversals = None - if typing.TYPE_CHECKING: from . import elements from . import lambdas from . import schema from . import selectable - from . import traversals from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument from ._typing import _DDLColumnArgument from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument from .dml import _DMLTableElement + from .elements import BindParameter from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -64,9 +58,7 @@ if typing.TYPE_CHECKING: from .elements import SQLCoreOperations from .schema import Column from .selectable import _ColumnsClauseElement - from .selectable import _JoinTargetElement from .selectable import _JoinTargetProtocol - from .selectable import _OnClauseElement from .selectable import FromClause from .selectable import HasCTE from .selectable import SelectBase @@ -170,6 +162,15 @@ def expect( @overload def expect( + role: Type[roles.LiteralValueRole], + element: Any, + **kw: Any, +) -> BindParameter[Any]: + ... + + +@overload +def expect( role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, @@ -272,7 +273,7 @@ def expect( @overload def expect( role: Type[roles.ColumnsClauseRole], - element: _ColumnsClauseArgument, + element: _ColumnsClauseArgument[Any], **kw: Any, ) -> _ColumnsClauseElement: ... @@ -933,7 +934,7 @@ class GroupByImpl(ByOfImpl, RoleImpl): argname: Optional[str] = None, **kw: Any, ) -> Any: - if isinstance(resolved, roles.StrictFromClauseRole): + if is_from_clause(resolved): return elements.ClauseList(*resolved.c) else: return resolved diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c524a2602..a1b25b8a6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -80,7 +80,6 @@ from ..util.typing import Protocol from ..util.typing import TypedDict if typing.TYPE_CHECKING: - from . import roles from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState @@ -95,7 +94,6 @@ if typing.TYPE_CHECKING: from .elements import ColumnElement from .elements import Label from .functions import Function - from .selectable import Alias from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState from .selectable import CTE @@ -386,7 +384,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): need_result_map_for_nested: bool need_result_map_for_compound: bool select_0: ReturnsRows - insert_from_select: Select + insert_from_select: Select[Any] class ExpandedState(NamedTuple): @@ -2834,15 +2832,31 @@ class SQLCompiler(Compiled): "unique bind parameter of the same name" % name ) elif existing._is_crud or bindparam._is_crud: - raise exc.CompileError( - "bindparam() name '%s' is reserved " - "for automatic usage in the VALUES or SET " - "clause of this " - "insert/update statement. Please use a " - "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (bindparam.key, bindparam.key) - ) + if existing._is_crud and bindparam._is_crud: + # TODO: this condition is not well understood. + # see tests in test/sql/test_update.py + raise exc.CompileError( + "Encountered unsupported case when compiling an " + "INSERT or UPDATE statement. If this is a " + "multi-table " + "UPDATE statement, please provide string-named " + "arguments to the " + "values() method with distinct names; support for " + "multi-table UPDATE statements that " + "target multiple tables for UPDATE is very " + "limited", + ) + else: + raise exc.CompileError( + f"bindparam() name '{bindparam.key}' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using " + "bindparam() " + "with insert() or update() (for example, " + f"'b_{bindparam.key}')." + ) self.binds[bindparam.key] = self.binds[name] = bindparam @@ -3881,7 +3895,7 @@ class SQLCompiler(Compiled): return text def _setup_select_hints( - self, select: Select + self, select: Select[Any] ) -> Tuple[str, _FromHintsType]: byfrom = dict( [ diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e4408cd31..29d7b45d7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -22,6 +22,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -30,8 +31,10 @@ from . import coercions from . import dml from . import elements from . import roles +from .elements import ColumnClause from .schema import default_is_clause_element from .schema import default_is_sequence +from .selectable import TableClause from .. import exc from .. import util from ..util.typing import Literal @@ -41,16 +44,9 @@ if TYPE_CHECKING: from .compiler import SQLCompiler from .dml import _DMLColumnElement from .dml import DMLState - from .dml import Insert - from .dml import Update - from .dml import UpdateDMLState from .dml import ValuesBase - from .elements import ClauseElement - from .elements import ColumnClause from .elements import ColumnElement - from .elements import TextClause from .schema import _SQLExprDefault - from .schema import Column from .selectable import TableClause REQUIRED = util.symbol( @@ -68,12 +64,20 @@ values present. ) +def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: + if not isinstance(c, ColumnClause): + raise exc.CompileError( + f"Can't create DML statement against column expression {c!r}" + ) + return c + + class _CrudParams(NamedTuple): - single_params: List[ - Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + single_params: Sequence[ + Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]] ] all_multi_params: List[ - List[ + Sequence[ Tuple[ ColumnClause[Any], str, @@ -274,7 +278,7 @@ def _get_crud_params( compiler, stmt, compile_state, - cast("List[Tuple[ColumnClause[Any], str, str]]", values), + cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values), cast("Callable[..., str]", _column_as_key), kw, ) @@ -290,7 +294,7 @@ def _get_crud_params( # insert_executemany_returning mode :) values = [ ( - stmt.table.columns[0], + _as_dml_column(stmt.table.columns[0]), compiler.preparer.format_column(stmt.table.columns[0]), "DEFAULT", ) @@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, - initial_values: List[Tuple[ColumnClause[Any], str, str]], + initial_values: Sequence[Tuple[ColumnClause[Any], str, str]], _column_as_key: Callable[..., str], kw: Dict[str, Any], -) -> List[List[Tuple[ColumnClause[Any], str, str]]]: +) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]: values_0 = initial_values values = [initial_values] diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 8307f6400..e0f162fc8 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -22,15 +22,19 @@ from typing import List from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from . import util as sql_util +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_named_from_clause from .base import _entity_namespace_key @@ -42,6 +46,7 @@ from .base import ColumnCollection from .base import CompileState from .base import DialectKWArgs from .base import Executable +from .base import Generative from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement @@ -49,12 +54,13 @@ from .elements import ColumnClause from .elements import ColumnElement from .elements import Null from .selectable import Alias +from .selectable import ExecutableReturnsRows from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import Join -from .selectable import ReturnsRows from .selectable import TableClause +from .selectable import TypedReturnsRows from .sqltypes import NullType from .visitors import InternalTraversal from .. import exc @@ -66,9 +72,19 @@ if TYPE_CHECKING: from ._typing import _ColumnsClauseArgument from ._typing import _DMLColumnArgument from ._typing import _DMLTableArgument - from ._typing import _FromClauseArgument + from ._typing import _T0 # noqa + from ._typing import _T1 # noqa + from ._typing import _T2 # noqa + from ._typing import _T3 # noqa + from ._typing import _T4 # noqa + from ._typing import _T5 # noqa + from ._typing import _T6 # noqa + from ._typing import _T7 # noqa + from ._typing import _TypedColumnClauseArgument as _TCCA # noqa from .base import ReadOnlyColumnCollection from .compiler import SQLCompiler + from .elements import ColumnElement + from .elements import KeyedColumnElement from .selectable import _ColumnsClauseElement from .selectable import _SelectIterable from .selectable import Select @@ -88,6 +104,8 @@ else: isinsert = operator.attrgetter("isinsert") +_T = TypeVar("_T", bound=Any) + _DMLColumnElement = Union[str, ColumnClause[Any]] _DMLTableElement = Union[TableClause, Alias, Join] @@ -185,6 +203,11 @@ class DMLState(CompileState): "%s construct does not support " "multiple parameter sets." % statement.__visit_name__.upper() ) + else: + assert isinstance(statement, Insert) + + # which implies... + # assert isinstance(statement.table, TableClause) for parameters in statement._multi_values: multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ @@ -291,7 +314,9 @@ class UpdateDMLState(DMLState): elif statement._multi_values: self._process_multi_values(statement) self._extra_froms = ef = self._make_extra_froms(statement) - self.is_multitable = mt = ef and self._dict_parameters + + self.is_multitable = mt = ef + self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -317,8 +342,8 @@ class UpdateBase( HasCompileState, DialectKWArgs, HasPrefixes, - ReturnsRows, - Executable, + Generative, + ExecutableReturnsRows, ClauseElement, ): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" @@ -383,8 +408,8 @@ class UpdateBase( @_generative def returning( - self: SelfUpdateBase, *cols: _ColumnsClauseArgument - ) -> SelfUpdateBase: + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> UpdateBase: r"""Add a :term:`RETURNING` or equivalent clause to this statement. e.g.: @@ -454,6 +479,8 @@ class UpdateBase( :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial` """ # noqa: E501 + if __kw: + raise _no_kw() if self._return_defaults: raise exc.InvalidRequestError( "return_defaults() is already configured on this statement" @@ -464,7 +491,7 @@ class UpdateBase( return self def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False + self, column: KeyedColumnElement[Any], require_embedded: bool = False ) -> Optional[ColumnElement[Any]]: return self.exported_columns.corresponding_column( column, require_embedded=require_embedded @@ -628,7 +655,7 @@ class ValuesBase(UpdateBase): _supports_multi_parameters = False - select: Optional[Select] = None + select: Optional[Select[Any]] = None """SELECT statement for INSERT .. FROM SELECT""" _post_values_clause: Optional[ClauseElement] = None @@ -804,11 +831,15 @@ class ValuesBase(UpdateBase): ) elif isinstance(arg, collections_abc.Sequence): - if arg and isinstance(arg[0], (list, dict, tuple)): self._multi_values += (arg,) return self + if TYPE_CHECKING: + # crud.py raises during compilation if this is not the + # case + assert isinstance(self, Insert) + # tuple values arg = {c.key: value for c, value in zip(self.table.c, arg)} @@ -1010,7 +1041,7 @@ class Insert(ValuesBase): def from_select( self: SelfInsert, names: List[str], - select: Select, + select: Select[Any], include_defaults: bool = True, ) -> SelfInsert: """Return a new :class:`_expression.Insert` construct which represents @@ -1073,6 +1104,114 @@ class Insert(ValuesBase): self.select = coercions.expect(roles.DMLSelectRole, select) return self + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningInsert[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + +class ReturningInsert(Insert, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Insert` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Insert.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") @@ -1264,6 +1403,113 @@ class Update(DMLWhereBase, ValuesBase): self._inline = True return self + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningUpdate[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + +class ReturningUpdate(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Update` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Update.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDelete = typing.TypeVar("SelfDelete", bound="Delete") @@ -1297,3 +1543,111 @@ class Delete(DMLWhereBase, UpdateBase): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) + + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningDelete[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + +class ReturningDelete(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Delete` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Delete.returning` method. + + .. versionadded:: 2.0 + + """ diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 34d5127ab..a29561291 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -54,6 +54,7 @@ from .base import _clone from .base import _generative from .base import _NoArg from .base import Executable +from .base import Generative from .base import HasMemoized from .base import Immutable from .base import NO_ARG @@ -94,10 +95,7 @@ if typing.TYPE_CHECKING: from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause - from .selectable import ReturnsRows from .selectable import Select - from .selectable import TableClause - from .sqltypes import Boolean from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -122,7 +120,9 @@ _NT = TypeVar("_NT", bound="_NUMERIC") _NMT = TypeVar("_NMT", bound="_NUMBER") -def literal(value, type_=None): +def literal( + value: Any, type_: Optional[_TypeEngineArgument[_T]] = None +) -> BindParameter[_T]: r"""Return a literal clause, bound to a bind parameter. Literal clauses are created automatically when non- @@ -144,7 +144,9 @@ def literal(value, type_=None): return coercions.expect(roles.LiteralValueRole, value, type_=type_) -def literal_column(text, type_=None): +def literal_column( + text: str, type_: Optional[_TypeEngineArgument[_T]] = None +) -> ColumnClause[_T]: r"""Produce a :class:`.ColumnClause` object that has the :paramref:`_expression.column.is_literal` flag set to True. @@ -316,6 +318,7 @@ class ClauseElement( is_selectable = False is_dml = False _is_column_element = False + _is_keyed_column_element = False _is_table = False _is_textual = False _is_from_clause = False @@ -342,7 +345,7 @@ class ClauseElement( if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any ) -> Iterable[ClauseElement]: ... @@ -455,7 +458,7 @@ class ClauseElement( connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptions, - ) -> Result: + ) -> Result[Any]: if self.supports_execution: if TYPE_CHECKING: assert isinstance(self, Executable) @@ -833,13 +836,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def in_( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... def not_in( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... @@ -1699,6 +1702,14 @@ class ColumnElement( return self._anon_label(label, add_hash=idx) +class KeyedColumnElement(ColumnElement[_T]): + """ColumnElement where ``.key`` is non-None.""" + + _is_keyed_column_element = True + + key: str + + class WrapsColumnExpression(ColumnElement[_T]): """Mixin that defines a :class:`_expression.ColumnElement` as a wrapper with special @@ -1760,7 +1771,7 @@ class WrapsColumnExpression(ColumnElement[_T]): SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]") -class BindParameter(roles.InElementRole, ColumnElement[_T]): +class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): r"""Represent a "bound expression". :class:`.BindParameter` is invoked explicitly using the @@ -2073,6 +2084,7 @@ class TextClause( roles.FromClauseRole, roles.SelectStatementRole, roles.InElementRole, + Generative, Executable, DQLDMLClauseElement, roles.BinaryElementRole[Any], @@ -4160,7 +4172,7 @@ class FunctionFilter(ColumnElement[_T]): ) -class NamedColumn(ColumnElement[_T]): +class NamedColumn(KeyedColumnElement[_T]): is_literal = False table: Optional[FromClause] = None name: str @@ -4502,7 +4514,7 @@ class ColumnClause( self.is_literal = is_literal - def get_children(self, column_tables=False, **kw): + def get_children(self, *, column_tables=False, **kw): # override base get_children() to not return the Table # or selectable that is parent to this column. Traversals # expect the columns of tables and subqueries to be leaf nodes. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 648168235..b827df3df 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -175,7 +175,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: return connection._execute_function( self, distilled_params, execution_options ) @@ -623,7 +623,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): joins_implicitly=joins_implicitly, ) - def select(self) -> "Select": + def select(self) -> Select[Any]: """Produce a :func:`_expression.select` construct against this :class:`.FunctionElement`. @@ -632,7 +632,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): s = select(function_element) """ - s = Select(self) + s: Select[Any] = Select(self) if self._execution_options: s = s.execution_options(**self._execution_options) return s @@ -846,7 +846,7 @@ class _FunctionGenerator: @overload def __call__( - self, *c: Any, type_: TypeEngine[_T], **kwargs: Any + self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any ) -> Function[_T]: ... diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 231c70a5b..09d4b35ad 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -8,8 +8,6 @@ from __future__ import annotations from typing import Any from typing import Generic -from typing import Iterable -from typing import List from typing import Optional from typing import TYPE_CHECKING from typing import TypeVar @@ -19,12 +17,7 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _PropagateAttrsType - from .base import _EntityNamespace - from .base import ColumnCollection - from .base import ReadOnlyColumnCollection - from .elements import ColumnClause from .elements import Label - from .elements import NamedColumn from .selectable import _SelectIterable from .selectable import FromClause from .selectable import Subquery @@ -108,13 +101,21 @@ class TruncatedLabelRole(StringRole, SQLRole): class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): __slots__ = () - _role_name = "Column expression or FROM clause" + _role_name = ( + "Column expression, FROM clause, or other columns clause element" + ) @property def _select_iterable(self) -> _SelectIterable: raise NotImplementedError() +class TypedColumnsClauseRole(Generic[_T], SQLRole): + """element-typed form of ColumnsClauseRole""" + + __slots__ = () + + class LimitOffsetRole(SQLRole): __slots__ = () _role_name = "LIMIT / OFFSET expression" @@ -161,7 +162,7 @@ class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" -class ExpressionElementRole(Generic[_T], SQLRole): +class ExpressionElementRole(TypedColumnsClauseRole[_T]): # note when using generics for ExpressionElementRole, # the generic type needs to be in # sqlalchemy.sql.coercions._impl_lookup mapping also. @@ -212,39 +213,11 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole): named_with_column: bool - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... - - @util.ro_non_memoized_property - def _hide_froms(self) -> Iterable[FromClause]: - ... - - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - ... - class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def description(self) -> str: - ... - class AnonymizedFromClauseRole(StrictFromClauseRole): __slots__ = () @@ -317,16 +290,6 @@ class DMLTableRole(FromClauseRole): __slots__ = () _role_name = "subject table for an INSERT, UPDATE or DELETE" - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def primary_key(self) -> Iterable[NamedColumn[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - class DMLColumnRole(SQLRole): __slots__ = () diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 52ba60a62..27456d2be 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -86,7 +86,6 @@ if typing.TYPE_CHECKING: from ._typing import _InfoType from ._typing import _TextCoercedExpressionArgument from ._typing import _TypeEngineArgument - from .base import ColumnCollection from .base import DedupeColumnCollection from .base import ReadOnlyColumnCollection from .compiler import DDLCompiler @@ -97,9 +96,7 @@ if typing.TYPE_CHECKING: from .visitors import anon_map from ..engine import Connection from ..engine import Engine - from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams - from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import ExecutionContext from ..engine.mock import MockConnection @@ -2609,8 +2606,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): :class:`_schema.Table`. """ - - return table.columns.corresponding_column(self.column) + # our column is a Column, and any subquery etc. proxying us + # would be doing so via another Column, so that's what would + # be returned here + return table.columns.corresponding_column(self.column) # type: ignore @util.memoized_property def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9d4d1d6c7..b08f13f99 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -23,6 +23,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -46,6 +47,8 @@ from . import traversals from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_select_statement from ._typing import is_subquery @@ -103,9 +106,20 @@ if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _FromClauseArgument from ._typing import _JoinTargetArgument + from ._typing import _MAYBE_ENTITY + from ._typing import _NOT_ENTITY from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypedColumnClauseArgument as _TCCA from ._typing import _TypeEngineArgument from .base import _AmbiguousTableNameMap from .base import ExecutableOption @@ -115,14 +129,13 @@ if TYPE_CHECKING: from .dml import Delete from .dml import Insert from .dml import Update + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import TextClause from .functions import Function - from .schema import Column from .schema import ForeignKey from .schema import ForeignKeyConstraint from .type_api import TypeEngine - from .util import ClauseAdapter from .visitors import _CloneCallableType @@ -245,6 +258,14 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() +class ExecutableReturnsRows(Executable, ReturnsRows): + """base for executable statements that return rows.""" + + +class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]): + """base for executable statements that return rows.""" + + SelfSelectable = TypeVar("SelfSelectable", bound="Selectable") @@ -293,8 +314,8 @@ class Selectable(ReturnsRows): ) def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False - ) -> Optional[ColumnElement[Any]]: + self, column: KeyedColumnElement[Any], require_embedded: bool = False + ) -> Optional[KeyedColumnElement[Any]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from the :attr:`_expression.Selectable.exported_columns` @@ -593,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -795,7 +816,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): ) @util.ro_non_memoized_property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`. @@ -817,7 +840,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Any]: + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -833,7 +858,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Any]: + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """ A synonym for :attr:`.FromClause.columns` @@ -1223,7 +1248,7 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ + columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ c for c in self.right.c ] @@ -1458,7 +1483,7 @@ class Join(roles.DMLTableRole, FromClause): "join explicitly." % (a.description, b.description) ) - def select(self) -> "Select": + def select(self) -> Select[Any]: r"""Create a :class:`_expression.Select` from this :class:`_expression.Join`. @@ -2764,6 +2789,7 @@ class Subquery(AliasedReturnsRows): cls, selectable: SelectBase, name: Optional[str] = None ) -> Subquery: """Return a :class:`.Subquery` object.""" + return coercions.expect( roles.SelectStatementRole, selectable ).subquery(name=name) @@ -3216,7 +3242,6 @@ class SelectBase( roles.CompoundElementRole, roles.InElementRole, HasCTE, - Executable, SupportsCloneAnnotations, Selectable, ): @@ -3239,7 +3264,9 @@ class SelectBase( self._reset_memoizations() @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set. @@ -3284,7 +3311,9 @@ class SelectBase( raise NotImplementedError() @property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`, not including @@ -3377,7 +3406,7 @@ class SelectBase( def as_scalar(self): return self.scalar_subquery() - def exists(self): + def exists(self) -> Exists: """Return an :class:`_sql.Exists` representation of this selectable, which can be used as a column expression. @@ -3394,7 +3423,7 @@ class SelectBase( """ return Exists(self) - def scalar_subquery(self): + def scalar_subquery(self) -> ScalarSelect[Any]: """Return a 'scalar' representation of this selectable, which can be used as a column expression. @@ -3607,7 +3636,7 @@ SelfGenerativeSelect = typing.TypeVar( ) -class GenerativeSelect(SelectBase): +class GenerativeSelect(SelectBase, Generative): """Base class for SELECT statements where additional elements can be added. @@ -4128,7 +4157,7 @@ class _CompoundSelectKeyword(Enum): INTERSECT_ALL = "INTERSECT ALL" -class CompoundSelect(HasCompileState, GenerativeSelect): +class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -4293,7 +4322,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self.selects[0]._all_selected_columns @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -4343,7 +4374,10 @@ class SelectState(util.MemoizedSlots, CompileState): ... def __init__( - self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any + self, + statement: Select[Any], + compiler: Optional[SQLCompiler], + **kw: Any, ): self.statement = statement self.from_clauses = statement._from_obj @@ -4369,7 +4403,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def get_column_descriptions( - cls, statement: Select + cls, statement: Select[Any] ) -> List[Dict[str, Any]]: return [ { @@ -4384,12 +4418,14 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def from_statement( - cls, statement: Select, from_statement: ReturnsRows - ) -> Any: + cls, statement: Select[Any], from_statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: cls._plugin_not_implemented() @classmethod - def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]: + def get_columns_clause_froms( + cls, statement: Select[Any] + ) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( element._from_objects for element in statement._raw_columns @@ -4439,7 +4475,7 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement: Select) -> List[FromClause]: + def _get_froms(self, statement: Select[Any]) -> List[FromClause]: ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} @@ -4467,7 +4503,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _normalize_froms( cls, iterable_of_froms: Iterable[FromClause], - check_statement: Optional[Select] = None, + check_statement: Optional[Select[Any]] = None, ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what @@ -4615,7 +4651,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def determine_last_joined_entity( - cls, stmt: Select + cls, stmt: Select[Any] ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] @@ -4623,7 +4659,7 @@ class SelectState(util.MemoizedSlots, CompileState): return None @classmethod - def all_selected_columns(cls, statement: Select) -> _SelectIterable: + def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] def _setup_joins( @@ -4876,7 +4912,7 @@ class _MemoizedSelectEntities( return c # type: ignore @classmethod - def _generate_for_statement(cls, select_stmt: Select) -> None: + def _generate_for_statement(cls, select_stmt: Select[Any]) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -4888,7 +4924,7 @@ class _MemoizedSelectEntities( select_stmt._setup_joins = select_stmt._with_options = () -SelfSelect = typing.TypeVar("SelfSelect", bound="Select") +SelfSelect = typing.TypeVar("SelfSelect", bound="Select[Any]") class Select( @@ -4898,6 +4934,7 @@ class Select( HasCompileState, _SelectFromElements, GenerativeSelect, + TypedReturnsRows[_TP], ): """Represents a ``SELECT`` statement. @@ -4973,7 +5010,7 @@ class Select( _compile_state_factory: Type[SelectState] @classmethod - def _create_raw_select(cls, **kw: Any) -> Select: + def _create_raw_select(cls, **kw: Any) -> Select[Any]: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -4985,7 +5022,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: _ColumnsClauseArgument): + def __init__(self, *entities: _ColumnsClauseArgument[Any]): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the @@ -5013,7 +5050,9 @@ class Select( cols = list(elem._select_iterable) return cols[0].type - def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect: + def filter( + self: SelfSelect, *criteria: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """A synonym for the :meth:`_future.Select.where` method.""" return self.where(*criteria) @@ -5032,7 +5071,28 @@ class Select( return self._raw_columns[0] - def filter_by(self, **kwargs): + if TYPE_CHECKING: + + @overload + def scalar_subquery( + self: Select[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: + ... + + @overload + def scalar_subquery( + self: Select[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def filter_by(self: SelfSelect, **kwargs: Any) -> SelfSelect: r"""apply the given filtering criterion as a WHERE clause to this select. @@ -5046,7 +5106,7 @@ class Select( return self.filter(*clauses) @property - def column_descriptions(self): + def column_descriptions(self) -> Any: """Return a :term:`plugin-enabled` 'column descriptions' structure referring to the columns which are SELECTed by this statement. @@ -5089,7 +5149,9 @@ class Select( meth = SelectState.get_plugin_class(self).get_column_descriptions return meth(self) - def from_statement(self, statement): + def from_statement( + self, statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: """Apply the columns which this :class:`.Select` would select onto another statement. @@ -5410,7 +5472,7 @@ class Select( ) @property - def inner_columns(self): + def inner_columns(self) -> _SelectIterable: """An iterator of all :class:`_expression.ColumnElement` expressions which would be rendered into the columns clause of the resulting SELECT statement. @@ -5487,18 +5549,19 @@ class Select( self._reset_memoizations() - def get_children(self, **kwargs): + def get_children(self, **kw: Any) -> Iterable[ClauseElement]: return itertools.chain( super(Select, self).get_children( - omit_attrs=("_from_obj", "_correlate", "_correlate_except") + omit_attrs=("_from_obj", "_correlate", "_correlate_except"), + **kw, ), self._iterate_from_elements(), ) @_generative def add_columns( - self: SelfSelect, *columns: _ColumnsClauseArgument - ) -> SelfSelect: + self, *columns: _ColumnsClauseArgument[Any] + ) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expressions added to its columns clause. @@ -5523,7 +5586,7 @@ class Select( return self def _set_entities( - self, entities: Iterable[_ColumnsClauseArgument] + self, entities: Iterable[_ColumnsClauseArgument[Any]] ) -> None: self._raw_columns = [ coercions.expect( @@ -5538,7 +5601,7 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect: + def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -5555,9 +5618,7 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns( - self: SelfSelect, only_synonyms: bool = True - ) -> SelfSelect: + def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -5580,20 +5641,115 @@ class Select( all columns that are equivalent to another are removed. """ - return self.with_only_columns( + woc: Select[Any] + woc = self.with_only_columns( *util.preloaded.sql_util.reduce_columns( self._all_selected_columns, only_synonyms=only_synonyms, *(self._where_criteria + self._from_obj), ) ) + return woc + + # START OVERLOADED FUNCTIONS self.with_only_columns Select 8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_sel_v1_overloads.py + + @overload + def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> Select[Tuple[_T0, _T1]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.with_only_columns + + @overload + def with_only_columns( + self, + *columns: _ColumnsClauseArgument[Any], + maintain_column_froms: bool = False, + **__kw: Any, + ) -> Select[Any]: + ... @_generative def with_only_columns( - self: SelfSelect, - *columns: _ColumnsClauseArgument, + self, + *columns: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, - ) -> SelfSelect: + **__kw: Any, + ) -> Select[Any]: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given columns. @@ -5647,6 +5803,9 @@ class Select( """ # noqa: E501 + if __kw: + raise _no_kw() + # memoizations should be cleared here as of # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this # is the case for now. @@ -5915,7 +6074,9 @@ class Select( return self @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -6215,7 +6376,7 @@ class ScalarSelect( by this :class:`_expression.ScalarSelect`. """ - self.element = cast(Select, self.element).where(crit) + self.element = cast("Select[Any]", self.element).where(crit) return self @overload @@ -6269,7 +6430,9 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate(*fromclauses) + self.element = cast("Select[Any]", self.element).correlate( + *fromclauses + ) return self @_generative @@ -6307,7 +6470,7 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate_except( + self.element = cast("Select[Any]", self.element).correlate_except( *fromclauses ) return self @@ -6331,12 +6494,18 @@ class Exists(UnaryExpression[bool]): def __init__( self, __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ): + s: ScalarSelect[Any] + + # TODO: this seems like we should be using coercions for this if __argument is None: s = Select(literal_column("*")).scalar_subquery() - elif isinstance(__argument, (SelectBase, ScalarSelect)): + elif isinstance(__argument, SelectBase): + s = __argument.scalar_subquery() + s._propagate_attrs = __argument._propagate_attrs + elif isinstance(__argument, ScalarSelect): s = __argument else: s = Select(__argument).scalar_subquery() @@ -6358,7 +6527,7 @@ class Exists(UnaryExpression[bool]): element = fn(element) return element.self_group(against=operators.exists) - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: @@ -6452,7 +6621,7 @@ class Exists(UnaryExpression[bool]): SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect") -class TextualSelect(SelectBase): +class TextualSelect(SelectBase, Executable, Generative): """Wrap a :class:`_expression.TextClause` construct within a :class:`_expression.SelectBase` interface. @@ -6503,7 +6672,9 @@ class TextualSelect(SelectBase): self.positional = positional @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d08fef60a..8c45ba410 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -50,6 +50,7 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import Grouping +from .elements import KeyedColumnElement from .elements import Label from .elements import Null from .elements import UnaryExpression @@ -72,9 +73,7 @@ if typing.TYPE_CHECKING: from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument from .elements import TextClause - from .roles import FromClauseRole from .selectable import _JoinTargetElement - from .selectable import _OnClauseElement from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType @@ -569,7 +568,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: "Row", max_chars: int = 300): + def __init__(self, row: "Row[Any]", max_chars: int = 300): self.row = row self.max_chars = max_chars @@ -1068,7 +1067,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): col = col._annotations["adapt_column"] if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.adapt_from_selectables and col not in self.equivalents: for adp in self.adapt_from_selectables: @@ -1078,7 +1077,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return None if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.include_fn and not self.include_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index e0a66fbcf..88586d834 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -450,7 +450,7 @@ class HasTraverseInternals: @util.preload_module("sqlalchemy.sql.traversals") def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[HasTraverseInternals]: r"""Return immediate child :class:`.visitors.HasTraverseInternals` elements of this :class:`.visitors.HasTraverseInternals`. @@ -594,7 +594,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[ExternallyTraversible]: ... diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 49c5d693a..da3fbc718 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -18,6 +18,7 @@ import hashlib import inspect import itertools import operator +import os import re import sys import textwrap @@ -32,6 +33,7 @@ from typing import Generic from typing import Iterator from typing import List from typing import Mapping +from typing import no_type_check from typing import NoReturn from typing import Optional from typing import overload @@ -2106,3 +2108,45 @@ def has_compiled_ext(raise_=False): ) else: return False + + +@no_type_check +def console_scripts( + path: str, options: dict, ignore_output: bool = False +) -> None: + + import subprocess + import shlex + from pathlib import Path + + is_posix = os.name == "posix" + + entrypoint_name = options["entrypoint"] + + for entry in compat.importlib_metadata_get("console_scripts"): + if entry.name == entrypoint_name: + impl = entry + break + else: + raise Exception( + f"Could not find entrypoint console_scripts.{entrypoint_name}" + ) + cmdline_options_str = options.get("options", "") + cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [ + path + ] + + kw = {} + if ignore_output: + kw["stdout"] = kw["stderr"] = subprocess.DEVNULL + + subprocess.run( + [ + sys.executable, + "-c", + "import %s; %s.%s()" % (impl.module, impl.module, impl.attr), + ] + + cmdline_options_list, + cwd=Path(__file__).parent.parent, + **kw, + ) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index d192dc06b..2a215c4f1 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -14,7 +14,7 @@ from typing import Type from typing import TypeVar from typing import Union -from typing_extensions import NotRequired as NotRequired # noqa +from typing_extensions import NotRequired as NotRequired from . import compat |
