summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/engine/__init__.py1
-rw-r--r--lib/sqlalchemy/engine/base.py89
-rw-r--r--lib/sqlalchemy/engine/cursor.py24
-rw-r--r--lib/sqlalchemy/engine/default.py14
-rw-r--r--lib/sqlalchemy/engine/events.py3
-rw-r--r--lib/sqlalchemy/engine/interfaces.py2
-rw-r--r--lib/sqlalchemy/engine/result.py422
-rw-r--r--lib/sqlalchemy/engine/row.py51
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py132
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py339
-rw-r--r--lib/sqlalchemy/ext/asyncio/scoping.py135
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py136
-rw-r--r--lib/sqlalchemy/ext/instrumentation.py3
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py30
-rw-r--r--lib/sqlalchemy/orm/attributes.py5
-rw-r--r--lib/sqlalchemy/orm/base.py16
-rw-r--r--lib/sqlalchemy/orm/context.py7
-rw-r--r--lib/sqlalchemy/orm/interfaces.py13
-rw-r--r--lib/sqlalchemy/orm/mapper.py13
-rw-r--r--lib/sqlalchemy/orm/properties.py5
-rw-r--r--lib/sqlalchemy/orm/query.py296
-rw-r--r--lib/sqlalchemy/orm/scoping.py219
-rw-r--r--lib/sqlalchemy/orm/session.py211
-rw-r--r--lib/sqlalchemy/orm/state.py6
-rw-r--r--lib/sqlalchemy/orm/util.py53
-rw-r--r--lib/sqlalchemy/sql/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/_selectable_constructors.py166
-rw-r--r--lib/sqlalchemy/sql/_typing.py78
-rw-r--r--lib/sqlalchemy/sql/base.py15
-rw-r--r--lib/sqlalchemy/sql/coercions.py25
-rw-r--r--lib/sqlalchemy/sql/compiler.py40
-rw-r--r--lib/sqlalchemy/sql/crud.py32
-rw-r--r--lib/sqlalchemy/sql/dml.py376
-rw-r--r--lib/sqlalchemy/sql/elements.py36
-rw-r--r--lib/sqlalchemy/sql/functions.py8
-rw-r--r--lib/sqlalchemy/sql/roles.py57
-rw-r--r--lib/sqlalchemy/sql/schema.py9
-rw-r--r--lib/sqlalchemy/sql/selectable.py287
-rw-r--r--lib/sqlalchemy/sql/util.py9
-rw-r--r--lib/sqlalchemy/sql/visitors.py4
-rw-r--r--lib/sqlalchemy/util/langhelpers.py44
-rw-r--r--lib/sqlalchemy/util/typing.py2
42 files changed, 2939 insertions, 475 deletions
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 29dd6aff9..afba17075 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -46,6 +46,7 @@ from .result import MergedResult as MergedResult
from .result import Result as Result
from .result import result_tuple as result_tuple
from .result import ScalarResult as ScalarResult
+from .result import TupleResult as TupleResult
from .row import BaseRow as BaseRow
from .row import Row as Row
from .row import RowMapping as RowMapping
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index a325da929..fe3bfa1ad 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -18,8 +18,10 @@ from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
+from typing import TypeVar
from typing import Union
from .interfaces import _IsolationLevel
@@ -45,12 +47,10 @@ if typing.TYPE_CHECKING:
from . import ScalarResult
from .interfaces import _AnyExecuteParams
from .interfaces import _AnyMultiExecuteParams
- from .interfaces import _AnySingleExecuteParams
from .interfaces import _CoreAnyExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
from .interfaces import _DBAPIAnyExecuteParams
- from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _ExecuteOptionsParameter
@@ -65,21 +65,21 @@ if typing.TYPE_CHECKING:
from ..pool import PoolProxiedConnection
from ..sql import Executable
from ..sql._typing import _InfoType
- from ..sql.base import SchemaVisitor
from ..sql.compiler import Compiled
from ..sql.ddl import ExecutableDDLElement
from ..sql.ddl import SchemaDropper
from ..sql.ddl import SchemaGenerator
from ..sql.functions import FunctionElement
- from ..sql.schema import ColumnDefault
from ..sql.schema import DefaultGenerator
from ..sql.schema import HasSchemaAttr
from ..sql.schema import SchemaItem
+ from ..sql.selectable import TypedReturnsRows
"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
"""
+_T = TypeVar("_T", bound=Any)
_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT
NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT
@@ -1142,10 +1142,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self._dbapi_connection = None
self.__can_reconnect = False
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> Any:
r"""Executes a SQL statement construct and returns a scalar object.
@@ -1170,10 +1191,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
execution_options or NO_OPTIONS,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
def scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[Any]:
+ ...
+
+ def scalars(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
"""Executes and returns a scalar result set, which yields scalar values
@@ -1190,14 +1232,37 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
- return self.execute(statement, parameters, execution_options).scalars()
+ return self.execute(
+ statement, parameters, execution_options=execution_options
+ ).scalars()
+
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[Any]:
+ ...
def execute(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
r"""Executes a SQL statement construct and returns a
:class:`_engine.CursorResult`.
@@ -1246,7 +1311,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
func: FunctionElement[Any],
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a sql.FunctionElement object."""
return self._execute_clauseelement(
@@ -1317,7 +1382,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
ddl: ExecutableDDLElement,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a schema.DDL object."""
execution_options = ddl._execution_options.merge_with(
@@ -1414,7 +1479,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
elem: Executable,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a sql.ClauseElement object."""
execution_options = elem._execution_options.merge_with(
@@ -1487,7 +1552,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
compiled: Compiled,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a sql.Compiled object.
TODO: why do we have this? likely deprecate or remove
@@ -1537,7 +1602,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
statement: str,
parameters: Optional[_DBAPIAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
r"""Executes a SQL statement construct and returns a
:class:`_engine.CursorResult`.
@@ -1614,7 +1679,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
execution_options: _ExecuteOptions,
*args: Any,
**kw: Any,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`_engine.CursorResult`."""
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index ccf573675..ff69666b7 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -24,6 +24,7 @@ from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from .result import MergedResult
@@ -55,11 +56,12 @@ if typing.TYPE_CHECKING:
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
- from .result import _KeyMapType
from .result import _KeyType
from .result import _ProcessorsType
from ..sql.type_api import _ResultProcessorType
+_T = TypeVar("_T", bound=Any)
+
# metadata entry tuple indexes.
# using raw tuple is faster than namedtuple.
MD_INDEX: Literal[0] = 0 # integer index in cursor.description
@@ -214,7 +216,9 @@ class CursorResultMetaData(ResultMetaData):
return md
def __init__(
- self, parent: CursorResult, cursor_description: _DBAPICursorDescription
+ self,
+ parent: CursorResult[Any],
+ cursor_description: _DBAPICursorDescription,
):
context = parent.context
self._tuplefilter = None
@@ -1158,7 +1162,7 @@ class _NoResultMetaData(ResultMetaData):
_NO_RESULT_METADATA = _NoResultMetaData()
-class CursorResult(Result):
+class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
.. versionchanged:: 1.4 The :class:`.CursorResult``
@@ -1179,6 +1183,15 @@ class CursorResult(Result):
"""
+ __slots__ = (
+ "context",
+ "dialect",
+ "cursor",
+ "cursor_strategy",
+ "_echo",
+ "connection",
+ )
+
_metadata: Union[CursorResultMetaData, _NoResultMetaData]
_no_result_metadata = _NO_RESULT_METADATA
_soft_closed: bool = False
@@ -1231,7 +1244,6 @@ class CursorResult(Result):
make_row = _make_row_2
else:
make_row = _make_row
-
self._set_memoized_attribute("_row_getter", make_row)
else:
@@ -1726,12 +1738,12 @@ class CursorResult(Result):
def _raw_row_iterator(self):
return self._fetchiter_impl()
- def merge(self, *others: Result) -> MergedResult:
+ def merge(self, *others: Result[Any]) -> MergedResult[Any]:
merged_result = super().merge(*others)
setup_rowcounts = not self._metadata.returns_rows
if setup_rowcounts:
merged_result.rowcount = sum(
- cast(CursorResult, result).rowcount
+ cast("CursorResult[Any]", result).rowcount
for result in (self,) + others
)
return merged_result
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index c6571f68b..9c6ff758f 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -62,13 +62,9 @@ if typing.TYPE_CHECKING:
from .base import Connection
from .base import Engine
- from .characteristics import ConnectionCharacteristic
- from .interfaces import _AnyMultiExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
- from .interfaces import _DBAPIAnyExecuteParams
from .interfaces import _DBAPIMultiExecuteParams
- from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _IsolationLevel
from .interfaces import _MutableCoreSingleExecuteParams
@@ -83,15 +79,11 @@ if typing.TYPE_CHECKING:
from ..sql.compiler import Compiled
from ..sql.compiler import Linting
from ..sql.compiler import ResultColumnsEntry
- from ..sql.compiler import TypeCompiler
from ..sql.dml import DMLState
from ..sql.dml import UpdateBase
from ..sql.elements import BindParameter
- from ..sql.roles import ColumnsClauseRole
from ..sql.schema import Column
- from ..sql.schema import ColumnDefault
from ..sql.type_api import _BindProcessorType
- from ..sql.type_api import _ResultProcessorType
from ..sql.type_api import TypeEngine
# When we're handed literal SQL, ensure it's a SELECT query
@@ -781,7 +773,7 @@ class DefaultExecutionContext(ExecutionContext):
result_column_struct: Optional[
Tuple[List[ResultColumnsEntry], bool, bool, bool]
] = None
- returned_default_rows: Optional[List[Row]] = None
+ returned_default_rows: Optional[Sequence[Row[Any]]] = None
execution_options: _ExecuteOptions = util.EMPTY_DICT
@@ -1385,7 +1377,9 @@ class DefaultExecutionContext(ExecutionContext):
if cursor_description is None:
strategy = _cursor._NO_CURSOR_DML
- result = _cursor.CursorResult(self, strategy, cursor_description)
+ result: _cursor.CursorResult[Any] = _cursor.CursorResult(
+ self, strategy, cursor_description
+ )
if self.isinsert:
if self._is_implicit_returning:
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
index ef10946a8..4093d3e0e 100644
--- a/lib/sqlalchemy/engine/events.py
+++ b/lib/sqlalchemy/engine/events.py
@@ -28,7 +28,6 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from .base import Connection
- from .interfaces import _CoreAnyExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
from .interfaces import _DBAPIAnyExecuteParams
@@ -273,7 +272,7 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
multiparams: _CoreMultiExecuteParams,
params: _CoreSingleExecuteParams,
execution_options: _ExecuteOptions,
- result: Result,
+ result: Result[Any],
) -> None:
"""Intercept high level execute() events after execute.
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 54fe21d74..641024603 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -2422,7 +2422,7 @@ class ExecutionContext:
def _get_cache_stats(self) -> str:
raise NotImplementedError()
- def _setup_result_proxy(self) -> CursorResult:
+ def _setup_result_proxy(self) -> CursorResult[Any]:
raise NotImplementedError()
def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int:
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 71320a583..55d36a1d5 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -28,6 +28,7 @@ from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -70,6 +71,8 @@ _RawRowType = Tuple[Any, ...]
"""represents the kind of row we get from a DBAPI cursor"""
_R = TypeVar("_R", bound=_RowData)
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
_InterimRowType = Union[_R, _RawRowType]
"""a catchall "anything" kind of return type that can be applied
@@ -141,7 +144,7 @@ class ResultMetaData:
def _getter(
self, key: Any, raiseerr: bool = True
- ) -> Optional[Callable[[Row], Any]]:
+ ) -> Optional[Callable[[Row[Any]], Any]]:
index = self._index_for_key(key, raiseerr)
@@ -270,7 +273,7 @@ class SimpleResultMetaData(ResultMetaData):
_tuplefilter=_tuplefilter,
)
- def _contains(self, value: Any, row: Row) -> bool:
+ def _contains(self, value: Any, row: Row[Any]) -> bool:
return value in row._data
def _index_for_key(self, key: Any, raiseerr: bool = True) -> int:
@@ -335,7 +338,7 @@ class SimpleResultMetaData(ResultMetaData):
def result_tuple(
fields: Sequence[str], extra: Optional[Any] = None
-) -> Callable[[Iterable[Any]], Row]:
+) -> Callable[[Iterable[Any]], Row[Any]]:
parent = SimpleResultMetaData(fields, extra)
return functools.partial(
Row, parent, parent._processors, parent._keymap, Row._default_key_style
@@ -355,7 +358,9 @@ SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal[Any]")
class ResultInternal(InPlaceGenerative, Generic[_R]):
- _real_result: Optional[Result] = None
+ __slots__ = ()
+
+ _real_result: Optional[Result[Any]] = None
_generate_rows: bool = True
_row_logging_fn: Optional[Callable[[Any], Any]]
@@ -367,20 +372,20 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
_source_supports_scalars: bool
- def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]:
+ def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
def _fetchone_impl(
self, hard_close: bool = False
- ) -> Optional[_InterimRowType[Row]]:
+ ) -> Optional[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
- def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+ def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
def _soft_close(self, hard: bool = False) -> None:
@@ -388,8 +393,10 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
@HasMemoized_ro_memoized_attribute
def _row_getter(self) -> Optional[Callable[..., _R]]:
- real_result: Result = (
- self._real_result if self._real_result else cast(Result, self)
+ real_result: Result[Any] = (
+ self._real_result
+ if self._real_result
+ else cast("Result[Any]", self)
)
if real_result._source_supports_scalars:
@@ -404,7 +411,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
keymap: _KeyMapType,
key_style: Any,
scalar_obj: Any,
- ) -> Row:
+ ) -> Row[Any]:
return _proc(
metadata, processors, keymap, key_style, (scalar_obj,)
)
@@ -429,7 +436,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
fixed_tf = tf
- def make_row(row: _InterimRowType[Row]) -> _R:
+ def make_row(row: _InterimRowType[Row[Any]]) -> _R:
return _make_row_orig(fixed_tf(row))
else:
@@ -447,7 +454,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
if fns:
_make_row = make_row
- def make_row(row: _InterimRowType[Row]) -> _R:
+ def make_row(row: _InterimRowType[Row[Any]]) -> _R:
interim_row = _make_row(row)
for fn in fns:
interim_row = fn(interim_row)
@@ -465,7 +472,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- def iterrows(self: Result) -> Iterator[_R]:
+ def iterrows(self: Result[Any]) -> Iterator[_R]:
for raw_row in self._fetchiter_impl():
obj: _InterimRowType[Any] = (
make_row(raw_row) if make_row else raw_row
@@ -480,7 +487,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
else:
- def iterrows(self: Result) -> Iterator[_R]:
+ def iterrows(self: Result[Any]) -> Iterator[_R]:
for raw_row in self._fetchiter_impl():
row: _InterimRowType[Any] = (
make_row(raw_row) if make_row else raw_row
@@ -546,7 +553,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- def onerow(self: Result) -> Union[_NoRow, _R]:
+ def onerow(self: Result[Any]) -> Union[_NoRow, _R]:
_onerow = self._fetchone_impl
while True:
row = _onerow()
@@ -567,7 +574,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
else:
- def onerow(self: Result) -> Union[_NoRow, _R]:
+ def onerow(self: Result[Any]) -> Union[_NoRow, _R]:
row = self._fetchone_impl()
if row is None:
return _NO_ROW
@@ -627,7 +634,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
real_result = (
self._real_result
if self._real_result
- else cast(Result, self)
+ else cast("Result[Any]", self)
)
if real_result._yield_per:
num_required = num = real_result._yield_per
@@ -667,7 +674,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
real_result = (
self._real_result
if self._real_result
- else cast(Result, self)
+ else cast("Result[Any]", self)
)
num = real_result._yield_per
@@ -799,7 +806,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
self: SelfResultInternal, indexes: Sequence[_KeyIndexType]
) -> SelfResultInternal:
real_result = (
- self._real_result if self._real_result else cast(Result, self)
+ self._real_result
+ if self._real_result
+ else cast("Result[Any]", self)
)
if not real_result._source_supports_scalars or len(indexes) != 1:
@@ -817,7 +826,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
real_result = (
self._real_result
if self._real_result is not None
- else cast(Result, self)
+ else cast("Result[Any]", self)
)
if not strategy and self._metadata._unique_filters:
@@ -836,6 +845,8 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
class _WithKeys:
+ __slots__ = ()
+
_metadata: ResultMetaData
# used mainly to share documentation on the keys method.
@@ -859,10 +870,10 @@ class _WithKeys:
return self._metadata.keys
-SelfResult = TypeVar("SelfResult", bound="Result")
+SelfResult = TypeVar("SelfResult", bound="Result[Any]")
-class Result(_WithKeys, ResultInternal[Row]):
+class Result(_WithKeys, ResultInternal[Row[_TP]]):
"""Represent a set of database results.
.. versionadded:: 1.4 The :class:`.Result` object provides a completely
@@ -887,7 +898,9 @@ class Result(_WithKeys, ResultInternal[Row]):
"""
- _row_logging_fn: Optional[Callable[[Row], Row]] = None
+ __slots__ = ("_metadata", "__dict__")
+
+ _row_logging_fn: Optional[Callable[[Row[Any]], Row[Any]]] = None
_source_supports_scalars: bool = False
@@ -1011,6 +1024,15 @@ class Result(_WithKeys, ResultInternal[Row]):
appropriate :class:`.ColumnElement` objects which correspond to
a given statement construct.
+ .. versionchanged:: 2.0 Due to a bug in 1.4, the
+ :meth:`.Result.columns` method had an incorrect behavior where
+ calling upon the method with just one index would cause the
+ :class:`.Result` object to yield scalar values rather than
+ :class:`.Row` objects. In version 2.0, this behavior has been
+ corrected such that calling upon :meth:`.Result.columns` with
+ a single index will produce a :class:`.Result` object that continues
+ to yield :class:`.Row` objects, which include only a single column.
+
E.g.::
statement = select(table.c.x, table.c.y, table.c.z)
@@ -1040,6 +1062,20 @@ class Result(_WithKeys, ResultInternal[Row]):
"""
return self._column_slices(col_expressions)
+ @overload
+ def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self: Result[Tuple[_T]], index: Literal[0]
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
+ ...
+
def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
"""Return a :class:`_result.ScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
@@ -1067,7 +1103,7 @@ class Result(_WithKeys, ResultInternal[Row]):
def _getter(
self, key: _KeyIndexType, raiseerr: bool = True
- ) -> Optional[Callable[[Row], Any]]:
+ ) -> Optional[Callable[[Row[Any]], Any]]:
"""return a callable that will retrieve the given key from a
:class:`.Row`.
@@ -1105,6 +1141,43 @@ class Result(_WithKeys, ResultInternal[Row]):
return MappingResult(self)
+ @property
+ def t(self) -> TupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ The :attr:`.Result.t` attribute is a synonym for calling the
+ :meth:`.Result.tuples` method.
+
+ .. versionadded:: 2.0
+
+ """
+ return self # type: ignore
+
+ def tuples(self) -> TupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ This method returns the same :class:`.Result` object at runtime,
+ however annotates as returning a :class:`.TupleResult` object
+ that will indicate to :pep:`484` typing tools that plain typed
+ ``Tuple`` instances are returned rather than rows. This allows
+ tuple unpacking and ``__getitem__`` access of :class:`.Row` objects
+ to by typed, for those cases where the statement invoked itself
+ included typing information.
+
+ .. versionadded:: 2.0
+
+ :return: the :class:`_result.TupleResult` type at typing time.
+
+ .. seealso::
+
+ :attr:`.Result.t` - shorter synonym
+
+ :attr:`.Row.t` - :class:`.Row` version
+
+ """
+
+ return self # type: ignore
+
def _raw_row_iterator(self) -> Iterator[_RowData]:
"""Return a safe iterator that yields raw row data.
@@ -1114,13 +1187,15 @@ class Result(_WithKeys, ResultInternal[Row]):
"""
raise NotImplementedError()
- def __iter__(self) -> Iterator[Row]:
+ def __iter__(self) -> Iterator[Row[_TP]]:
return self._iter_impl()
- def __next__(self) -> Row:
+ def __next__(self) -> Row[_TP]:
return self._next_impl()
- def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]:
+ def partitions(
+ self, size: Optional[int] = None
+ ) -> Iterator[Sequence[Row[_TP]]]:
"""Iterate through sub-lists of rows of the size given.
Each list will be of the size given, excluding the last list to
@@ -1158,12 +1233,12 @@ class Result(_WithKeys, ResultInternal[Row]):
else:
break
- def fetchall(self) -> List[Row]:
+ def fetchall(self) -> Sequence[Row[_TP]]:
"""A synonym for the :meth:`_engine.Result.all` method."""
return self._allrows()
- def fetchone(self) -> Optional[Row]:
+ def fetchone(self) -> Optional[Row[_TP]]:
"""Fetch one row.
When all rows are exhausted, returns None.
@@ -1185,7 +1260,7 @@ class Result(_WithKeys, ResultInternal[Row]):
else:
return row
- def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
@@ -1202,7 +1277,7 @@ class Result(_WithKeys, ResultInternal[Row]):
return self._manyrow_getter(self, size)
- def all(self) -> List[Row]:
+ def all(self) -> Sequence[Row[_TP]]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
@@ -1216,7 +1291,7 @@ class Result(_WithKeys, ResultInternal[Row]):
return self._allrows()
- def first(self) -> Optional[Row]:
+ def first(self) -> Optional[Row[_TP]]:
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
@@ -1252,7 +1327,7 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=False, raise_for_none=False, scalar=False
)
- def one_or_none(self) -> Optional[Row]:
+ def one_or_none(self) -> Optional[Row[_TP]]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
@@ -1276,6 +1351,14 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=False, scalar=False
)
+ @overload
+ def scalar_one(self: Result[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ def scalar_one(self) -> Any:
+ ...
+
def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
@@ -1293,6 +1376,14 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=True, scalar=True
)
+ @overload
+ def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
@@ -1310,7 +1401,7 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=False, scalar=True
)
- def one(self) -> Row:
+ def one(self) -> Row[_TP]:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
@@ -1341,6 +1432,14 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=True, scalar=False
)
+ @overload
+ def scalar(self: Result[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(self) -> Any:
+ ...
+
def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
@@ -1359,7 +1458,7 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=False, raise_for_none=False, scalar=True
)
- def freeze(self) -> FrozenResult:
+ def freeze(self) -> FrozenResult[_TP]:
"""Return a callable object that will produce copies of this
:class:`.Result` when invoked.
@@ -1382,7 +1481,7 @@ class Result(_WithKeys, ResultInternal[Row]):
return FrozenResult(self)
- def merge(self, *others: Result) -> MergedResult:
+ def merge(self, *others: Result[Any]) -> MergedResult[_TP]:
"""Merge this :class:`.Result` with other compatible result
objects.
@@ -1405,9 +1504,17 @@ class FilterResult(ResultInternal[_R]):
"""
- _post_creational_filter: Optional[Callable[[Any], Any]] = None
+ __slots__ = (
+ "_real_result",
+ "_post_creational_filter",
+ "_metadata",
+ "_unique_filter_state",
+ "__dict__",
+ )
+
+ _post_creational_filter: Optional[Callable[[Any], Any]]
- _real_result: Result
+ _real_result: Result[Any]
def _soft_close(self, hard: bool = False) -> None:
self._real_result._soft_close(hard=hard)
@@ -1416,20 +1523,20 @@ class FilterResult(ResultInternal[_R]):
def _attributes(self) -> Dict[Any, Any]:
return self._real_result._attributes
- def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]:
+ def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]:
return self._real_result._fetchiter_impl()
def _fetchone_impl(
self, hard_close: bool = False
- ) -> Optional[_InterimRowType[Row]]:
+ ) -> Optional[_InterimRowType[Row[Any]]]:
return self._real_result._fetchone_impl(hard_close=hard_close)
- def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+ def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
return self._real_result._fetchall_impl()
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
return self._real_result._fetchmany_impl(size=size)
@@ -1452,11 +1559,13 @@ class ScalarResult(FilterResult[_R]):
"""
+ __slots__ = ()
+
_generate_rows = False
_post_creational_filter: Optional[Callable[[Any], Any]]
- def __init__(self, real_result: Result, index: _KeyIndexType):
+ def __init__(self, real_result: Result[Any], index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
@@ -1480,7 +1589,7 @@ class ScalarResult(FilterResult[_R]):
self._unique_filter_state = (set(), strategy)
return self
- def partitions(self, size: Optional[int] = None) -> Iterator[List[_R]]:
+ def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_result.Result.partitions` except that
@@ -1498,12 +1607,12 @@ class ScalarResult(FilterResult[_R]):
else:
break
- def fetchall(self) -> List[_R]:
+ def fetchall(self) -> Sequence[_R]:
"""A synonym for the :meth:`_engine.ScalarResult.all` method."""
return self._allrows()
- def fetchmany(self, size: Optional[int] = None) -> List[_R]:
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
"""Fetch many objects.
Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1513,7 +1622,7 @@ class ScalarResult(FilterResult[_R]):
"""
return self._manyrow_getter(self, size)
- def all(self) -> List[_R]:
+ def all(self) -> Sequence[_R]:
"""Return all scalar values in a list.
Equivalent to :meth:`_result.Result.all` except that
@@ -1567,6 +1676,177 @@ class ScalarResult(FilterResult[_R]):
)
+SelfTupleResult = TypeVar("SelfTupleResult", bound="TupleResult[Any]")
+
+
+class TupleResult(FilterResult[_R], util.TypingOnly):
+ """a :class:`.Result` that's typed as returning plain Python tuples
+ instead of rows.
+
+ Since :class:`.Row` acts like a tuple in every way already,
+ this class is a typing only class, regular :class:`.Result` is still
+ used at runtime.
+
+ """
+
+ __slots__ = ()
+
+ if TYPE_CHECKING:
+
+ def partitions(
+ self, size: Optional[int] = None
+ ) -> Iterator[Sequence[_R]]:
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def fetchone(self) -> Optional[_R]:
+ """Fetch one tuple.
+
+ Equivalent to :meth:`_result.Result.fetchone` except that
+ tuple values, rather than :class:`_result.Row`
+ objects, are returned.
+
+ """
+ ...
+
+ def fetchall(self) -> Sequence[_R]:
+ """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+ ...
+
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def all(self) -> Sequence[_R]: # noqa: A001
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def __iter__(self) -> Iterator[_R]:
+ ...
+
+ def __next__(self) -> _R:
+ ...
+
+ def first(self) -> Optional[_R]:
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ ...
+
+ def one_or_none(self) -> Optional[_R]:
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def one(self) -> _R:
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ @overload
+ def scalar_one(self: TupleResult[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ def scalar_one(self) -> Any:
+ ...
+
+ def scalar_one(self) -> Any:
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
+ def scalar_one_or_none(self) -> Optional[Any]:
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(self) -> Any:
+ ...
+
+ def scalar(self) -> Any:
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+ :return: a Python scalar value , or None if no rows remain.
+
+ """
+ ...
+
+
SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult")
@@ -1579,11 +1859,13 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
"""
+ __slots__ = ()
+
_generate_rows = True
_post_creational_filter = operator.attrgetter("_mapping")
- def __init__(self, result: Result):
+ def __init__(self, result: Result[Any]):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
@@ -1610,7 +1892,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
def partitions(
self, size: Optional[int] = None
- ) -> Iterator[List[RowMapping]]:
+ ) -> Iterator[Sequence[RowMapping]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_result.Result.partitions` except that
@@ -1628,7 +1910,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
else:
break
- def fetchall(self) -> List[RowMapping]:
+ def fetchall(self) -> Sequence[RowMapping]:
"""A synonym for the :meth:`_engine.MappingResult.all` method."""
return self._allrows()
@@ -1648,7 +1930,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
else:
return row
- def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]:
"""Fetch many objects.
Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1659,7 +1941,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
return self._manyrow_getter(self, size)
- def all(self) -> List[RowMapping]:
+ def all(self) -> Sequence[RowMapping]:
"""Return all scalar values in a list.
Equivalent to :meth:`_result.Result.all` except that
@@ -1714,7 +1996,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
)
-class FrozenResult:
+class FrozenResult(Generic[_TP]):
"""Represents a :class:`.Result` object in a "frozen" state suitable
for caching.
@@ -1755,7 +2037,7 @@ class FrozenResult:
data: Sequence[Any]
- def __init__(self, result: Result):
+ def __init__(self, result: Result[_TP]):
self.metadata = result._metadata._for_freeze()
self._source_supports_scalars = result._source_supports_scalars
self._attributes = result._attributes
@@ -1771,7 +2053,9 @@ class FrozenResult:
else:
return [list(row) for row in self.data]
- def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult:
+ def with_new_rows(
+ self, tuple_data: Sequence[Row[_TP]]
+ ) -> FrozenResult[_TP]:
fr = FrozenResult.__new__(FrozenResult)
fr.metadata = self.metadata
fr._attributes = self._attributes
@@ -1783,14 +2067,16 @@ class FrozenResult:
fr.data = tuple_data
return fr
- def __call__(self) -> Result:
- result = IteratorResult(self.metadata, iter(self.data))
+ def __call__(self) -> Result[_TP]:
+ result: IteratorResult[_TP] = IteratorResult(
+ self.metadata, iter(self.data)
+ )
result._attributes = self._attributes
result._source_supports_scalars = self._source_supports_scalars
return result
-class IteratorResult(Result):
+class IteratorResult(Result[_TP]):
"""A :class:`.Result` that gets data from a Python iterator of
:class:`.Row` objects or similar row-like data.
@@ -1833,7 +2119,7 @@ class IteratorResult(Result):
def _fetchone_impl(
self, hard_close: bool = False
- ) -> Optional[_InterimRowType[Row]]:
+ ) -> Optional[_InterimRowType[Row[Any]]]:
if self._hard_closed:
self._raise_hard_closed()
@@ -1844,7 +2130,7 @@ class IteratorResult(Result):
else:
return row
- def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+ def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
if self._hard_closed:
self._raise_hard_closed()
try:
@@ -1854,23 +2140,23 @@ class IteratorResult(Result):
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
if self._hard_closed:
self._raise_hard_closed()
return list(itertools.islice(self.iterator, 0, size))
-def null_result() -> IteratorResult:
+def null_result() -> IteratorResult[Any]:
return IteratorResult(SimpleResultMetaData([]), iter([]))
SelfChunkedIteratorResult = TypeVar(
- "SelfChunkedIteratorResult", bound="ChunkedIteratorResult"
+ "SelfChunkedIteratorResult", bound="ChunkedIteratorResult[Any]"
)
-class ChunkedIteratorResult(IteratorResult):
+class ChunkedIteratorResult(IteratorResult[_TP]):
"""An :class:`.IteratorResult` that works from an iterator-producing callable.
The given ``chunks`` argument is a function that is given a number of rows
@@ -1922,13 +2208,13 @@ class ChunkedIteratorResult(IteratorResult):
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
if self.dynamic_yield_per:
self.iterator = itertools.chain.from_iterable(self.chunks(size))
return super()._fetchmany_impl(size=size)
-class MergedResult(IteratorResult):
+class MergedResult(IteratorResult[_TP]):
"""A :class:`_engine.Result` that is merged from any number of
:class:`_engine.Result` objects.
@@ -1942,7 +2228,7 @@ class MergedResult(IteratorResult):
rowcount: Optional[int]
def __init__(
- self, cursor_metadata: ResultMetaData, results: Sequence[Result]
+ self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]]
):
self._results = results
super(MergedResult, self).__init__(
diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py
index 4ba39b55d..7c9eacb78 100644
--- a/lib/sqlalchemy/engine/row.py
+++ b/lib/sqlalchemy/engine/row.py
@@ -16,6 +16,7 @@ import typing
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generic
from typing import Iterator
from typing import List
from typing import Mapping
@@ -24,12 +25,14 @@ from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from ..sql import util as sql_util
from ..util._has_cy import HAS_CYEXTENSION
-if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
+if TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_row import BaseRow as BaseRow
from ._py_row import KEY_INTEGER_ONLY
from ._py_row import KEY_OBJECTS_ONLY
@@ -38,13 +41,16 @@ else:
from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY
from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
from .result import _KeyType
from .result import RMKeyView
from ..sql.type_api import _ResultProcessorType
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
-class Row(BaseRow, typing.Sequence[Any]):
+
+class Row(BaseRow, Sequence[Any], Generic[_TP]):
"""Represent a single result row.
The :class:`.Row` object represents a row of a database result. It is
@@ -82,6 +88,37 @@ class Row(BaseRow, typing.Sequence[Any]):
def __delattr__(self, name: str) -> NoReturn:
raise AttributeError("can't delete attribute")
+ def tuple(self) -> _TP:
+ """Return a 'tuple' form of this :class:`.Row`.
+
+ At runtime, this method returns "self"; the :class:`.Row` object is
+ already a named tuple. However, at the typing level, if this
+ :class:`.Row` is typed, the "tuple" return type will be a :pep:`484`
+ ``Tuple`` datatype that contains typing information about individual
+ elements, supporting typed unpacking and attribute access.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.Result.tuples`
+
+ """
+ return self # type: ignore
+
+ @property
+ def t(self) -> _TP:
+ """a synonym for :attr:`.Row.tuple`
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.Result.t`
+
+ """
+ return self # type: ignore
+
@property
def _mapping(self) -> RowMapping:
"""Return a :class:`.RowMapping` for this :class:`.Row`.
@@ -107,7 +144,7 @@ class Row(BaseRow, typing.Sequence[Any]):
def _filter_on_values(
self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]]
- ) -> Row:
+ ) -> Row[Any]:
return Row(
self._parent,
filters,
@@ -116,7 +153,7 @@ class Row(BaseRow, typing.Sequence[Any]):
self._data,
)
- if not typing.TYPE_CHECKING:
+ if not TYPE_CHECKING:
def _special_name_accessor(name: str) -> Any:
"""Handle ambiguous names such as "count" and "index" """
@@ -151,7 +188,7 @@ class Row(BaseRow, typing.Sequence[Any]):
__hash__ = BaseRow.__hash__
- if typing.TYPE_CHECKING:
+ if TYPE_CHECKING:
@overload
def __getitem__(self, index: int) -> Any:
@@ -299,7 +336,7 @@ class RowMapping(BaseRow, typing.Mapping[str, Any]):
_default_key_style = KEY_OBJECTS_ONLY
- if typing.TYPE_CHECKING:
+ if TYPE_CHECKING:
def __getitem__(self, key: _KeyType) -> Any:
...
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index fb05f512e..95549ada6 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -12,8 +12,10 @@ from typing import Generator
from typing import NoReturn
from typing import Optional
from typing import overload
+from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import exc as async_exc
@@ -50,6 +52,9 @@ if TYPE_CHECKING:
from ...pool import PoolProxiedConnection
from ...sql._typing import _InfoType
from ...sql.base import Executable
+ from ...sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
class _SyncConnectionCallable(Protocol):
@@ -407,7 +412,7 @@ class AsyncConnection(
statement: str,
parameters: Optional[_DBAPIAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
r"""Executes a driver-level SQL string and return buffered
:class:`_engine.Result`.
@@ -423,12 +428,33 @@ class AsyncConnection(
return await _ensure_sync_result(result, self.exec_driver_sql)
+ @overload
+ async def stream(
+ self,
+ statement: TypedReturnsRows[_T],
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncResult[_T]:
+ ...
+
+ @overload
async def stream(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> AsyncResult:
+ ) -> AsyncResult[Any]:
+ ...
+
+ async def stream(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncResult[Any]:
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object."""
@@ -436,7 +462,7 @@ class AsyncConnection(
self._proxied.execute,
statement,
parameters,
- util.EMPTY_DICT.merge_with(
+ execution_options=util.EMPTY_DICT.merge_with(
execution_options, {"stream_results": True}
),
_require_await=True,
@@ -446,12 +472,33 @@ class AsyncConnection(
assert False, "server side result expected"
return AsyncResult(result)
+ @overload
+ async def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[_T]:
+ ...
+
+ @overload
async def execute(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
+ ...
+
+ async def execute(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[Any]:
r"""Executes a SQL statement construct and return a buffered
:class:`_engine.Result`.
@@ -487,15 +534,36 @@ class AsyncConnection(
self._proxied.execute,
statement,
parameters,
- execution_options,
+ execution_options=execution_options,
_require_await=True,
)
return await _ensure_sync_result(result, self.execute)
+ @overload
+ async def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
async def scalar(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
+ ...
+
+ async def scalar(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> Any:
r"""Executes a SQL statement construct and returns a scalar object.
@@ -508,13 +576,36 @@ class AsyncConnection(
first row returned.
"""
- result = await self.execute(statement, parameters, execution_options)
+ result = await self.execute(
+ statement, parameters, execution_options=execution_options
+ )
return result.scalar()
+ @overload
+ async def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ async def scalars(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[Any]:
+ ...
+
async def scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
r"""Executes a SQL statement construct and returns a scalar objects.
@@ -528,13 +619,36 @@ class AsyncConnection(
.. versionadded:: 1.4.24
"""
- result = await self.execute(statement, parameters, execution_options)
+ result = await self.execute(
+ statement, parameters, execution_options=execution_options
+ )
return result.scalars()
+ @overload
+ async def stream_scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
async def stream_scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncScalarResult[Any]:
+ ...
+
+ async def stream_scalars(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> AsyncScalarResult[Any]:
r"""Executes a SQL statement and returns a streaming scalar result
@@ -549,7 +663,9 @@ class AsyncConnection(
.. versionadded:: 1.4.24
"""
- result = await self.stream(statement, parameters, execution_options)
+ result = await self.stream(
+ statement, parameters, execution_options=execution_options
+ )
return result.scalars()
async def run_sync(
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index d0337554c..ff3dcf417 100644
--- a/lib/sqlalchemy/ext/asyncio/result.py
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -9,12 +9,15 @@ from __future__ import annotations
import operator
from typing import Any
from typing import AsyncIterator
-from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from . import exc as async_exc
+from ... import util
from ...engine.result import _NO_ROW
from ...engine.result import _R
from ...engine.result import FilterResult
@@ -24,6 +27,7 @@ from ...engine.result import ResultMetaData
from ...engine.row import Row
from ...engine.row import RowMapping
from ...util.concurrency import greenlet_spawn
+from ...util.typing import Literal
if TYPE_CHECKING:
from ...engine import CursorResult
@@ -32,9 +36,14 @@ if TYPE_CHECKING:
from ...engine.result import _UniqueFilterType
from ...engine.result import RMKeyView
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
+
class AsyncCommon(FilterResult[_R]):
- _real_result: Result
+ __slots__ = ()
+
+ _real_result: Result[Any]
_metadata: ResultMetaData
async def close(self) -> None:
@@ -43,10 +52,10 @@ class AsyncCommon(FilterResult[_R]):
await greenlet_spawn(self._real_result.close)
-SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult")
+SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]")
-class AsyncResult(AsyncCommon[Row]):
+class AsyncResult(AsyncCommon[Row[_TP]]):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -67,11 +76,16 @@ class AsyncResult(AsyncCommon[Row]):
"""
- def __init__(self, real_result: Result):
+ __slots__ = ()
+
+ _real_result: Result[_TP]
+
+ def __init__(self, real_result: Result[_TP]):
self._real_result = real_result
self._metadata = real_result._metadata
self._unique_filter_state = real_result._unique_filter_state
+ self._post_creational_filter = None
# BaseCursorResult pre-generates the "_row_getter". Use that
# if available rather than building a second one
@@ -80,6 +94,43 @@ class AsyncResult(AsyncCommon[Row]):
"_row_getter", real_result.__dict__["_row_getter"]
)
+ @property
+ def t(self) -> AsyncTupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ The :attr:`.AsyncResult.t` attribute is a synonym for calling the
+ :meth:`.AsyncResult.tuples` method.
+
+ .. versionadded:: 2.0
+
+ """
+ return self # type: ignore
+
+ def tuples(self) -> AsyncTupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ This method returns the same :class:`.AsyncResult` object at runtime,
+ however annotates as returning a :class:`.AsyncTupleResult` object
+ that will indicate to :pep:`484` typing tools that plain typed
+ ``Tuple`` instances are returned rather than rows. This allows
+ tuple unpacking and ``__getitem__`` access of :class:`.Row` objects
+ to by typed, for those cases where the statement invoked itself
+ included typing information.
+
+ .. versionadded:: 2.0
+
+ :return: the :class:`_result.AsyncTupleResult` type at typing time.
+
+ .. seealso::
+
+ :attr:`.AsyncResult.t` - shorter synonym
+
+ :attr:`.Row.t` - :class:`.Row` version
+
+ """
+
+ return self # type: ignore
+
def keys(self) -> RMKeyView:
"""Return the :meth:`_engine.Result.keys` collection from the
underlying :class:`_engine.Result`.
@@ -115,7 +166,7 @@ class AsyncResult(AsyncCommon[Row]):
async def partitions(
self, size: Optional[int] = None
- ) -> AsyncIterator[List[Row]]:
+ ) -> AsyncIterator[Sequence[Row[_TP]]]:
"""Iterate through sub-lists of rows of the size given.
An async iterator is returned::
@@ -141,7 +192,16 @@ class AsyncResult(AsyncCommon[Row]):
else:
break
- async def fetchone(self) -> Optional[Row]:
+ async def fetchall(self) -> Sequence[Row[_TP]]:
+ """A synonym for the :meth:`.AsyncResult.all` method.
+
+ .. versionadded:: 2.0
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchone(self) -> Optional[Row[_TP]]:
"""Fetch one row.
When all rows are exhausted, returns None.
@@ -163,7 +223,9 @@ class AsyncResult(AsyncCommon[Row]):
else:
return row
- async def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+ async def fetchmany(
+ self, size: Optional[int] = None
+ ) -> Sequence[Row[_TP]]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
@@ -184,7 +246,7 @@ class AsyncResult(AsyncCommon[Row]):
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self) -> List[Row]:
+ async def all(self) -> Sequence[Row[_TP]]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
@@ -196,17 +258,17 @@ class AsyncResult(AsyncCommon[Row]):
return await greenlet_spawn(self._allrows)
- def __aiter__(self) -> AsyncResult:
+ def __aiter__(self) -> AsyncResult[_TP]:
return self
- async def __anext__(self) -> Row:
+ async def __anext__(self) -> Row[_TP]:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self) -> Optional[Row]:
+ async def first(self) -> Optional[Row[_TP]]:
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
@@ -229,7 +291,7 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self) -> Optional[Row]:
+ async def one_or_none(self) -> Optional[Row[_TP]]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
@@ -251,6 +313,14 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
+ @overload
+ async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ async def scalar_one(self) -> Any:
+ ...
+
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
@@ -266,6 +336,16 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, True, True)
+ @overload
+ async def scalar_one_or_none(
+ self: AsyncResult[Tuple[_T]],
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
@@ -281,7 +361,7 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, False, True)
- async def one(self) -> Row:
+ async def one(self) -> Row[_TP]:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
@@ -312,6 +392,14 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
+ @overload
+ async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(self) -> Any:
+ ...
+
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
@@ -328,7 +416,7 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, False, False, True)
- async def freeze(self) -> FrozenResult:
+ async def freeze(self) -> FrozenResult[_TP]:
"""Return a callable object that will produce copies of this
:class:`_asyncio.AsyncResult` when invoked.
@@ -351,7 +439,7 @@ class AsyncResult(AsyncCommon[Row]):
return await greenlet_spawn(FrozenResult, self)
- def merge(self, *others: AsyncResult) -> MergedResult:
+ def merge(self, *others: AsyncResult[_TP]) -> MergedResult[_TP]:
"""Merge this :class:`_asyncio.AsyncResult` with other compatible result
objects.
@@ -370,6 +458,20 @@ class AsyncResult(AsyncCommon[Row]):
(self._real_result,) + tuple(o._real_result for o in others),
)
+ @overload
+ def scalars(
+ self: AsyncResult[Tuple[_T]], index: Literal[0]
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
+ ...
+
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
@@ -423,9 +525,11 @@ class AsyncScalarResult(AsyncCommon[_R]):
"""
+ __slots__ = ()
+
_generate_rows = False
- def __init__(self, real_result: Result, index: _KeyIndexType):
+ def __init__(self, real_result: Result[Any], index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
@@ -452,7 +556,7 @@ class AsyncScalarResult(AsyncCommon[_R]):
async def partitions(
self, size: Optional[int] = None
- ) -> AsyncIterator[List[_R]]:
+ ) -> AsyncIterator[Sequence[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -470,12 +574,12 @@ class AsyncScalarResult(AsyncCommon[_R]):
else:
break
- async def fetchall(self) -> List[_R]:
+ async def fetchall(self) -> Sequence[_R]:
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
return await greenlet_spawn(self._allrows)
- async def fetchmany(self, size: Optional[int] = None) -> List[_R]:
+ async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
"""Fetch many objects.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -485,7 +589,7 @@ class AsyncScalarResult(AsyncCommon[_R]):
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self) -> List[_R]:
+ async def all(self) -> Sequence[_R]:
"""Return all scalar values in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -555,11 +659,13 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
"""
+ __slots__ = ()
+
_generate_rows = True
_post_creational_filter = operator.attrgetter("_mapping")
- def __init__(self, result: Result):
+ def __init__(self, result: Result[Any]):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
@@ -602,7 +708,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
async def partitions(
self, size: Optional[int] = None
- ) -> AsyncIterator[List[RowMapping]]:
+ ) -> AsyncIterator[Sequence[RowMapping]]:
"""Iterate through sub-lists of elements of the size given.
@@ -621,7 +727,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
else:
break
- async def fetchall(self) -> List[RowMapping]:
+ async def fetchall(self) -> Sequence[RowMapping]:
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
return await greenlet_spawn(self._allrows)
@@ -641,7 +747,9 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
else:
return row
- async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+ async def fetchmany(
+ self, size: Optional[int] = None
+ ) -> Sequence[RowMapping]:
"""Fetch many rows.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -652,7 +760,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self) -> List[RowMapping]:
+ async def all(self) -> Sequence[RowMapping]:
"""Return all rows in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -705,11 +813,186 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
return await greenlet_spawn(self._only_one_row, True, True, False)
-_RT = TypeVar("_RT", bound="Result")
+SelfAsyncTupleResult = TypeVar(
+ "SelfAsyncTupleResult", bound="AsyncTupleResult[Any]"
+)
+
+
+class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
+ """a :class:`.AsyncResult` that's typed as returning plain Python tuples
+ instead of rows.
+
+ Since :class:`.Row` acts like a tuple in every way already,
+ this class is a typing only class, regular :class:`.AsyncResult` is
+ still used at runtime.
+
+ """
+
+ __slots__ = ()
+
+ if TYPE_CHECKING:
+
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[Sequence[_R]]:
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def fetchone(self) -> Optional[_R]:
+ """Fetch one tuple.
+
+ Equivalent to :meth:`_result.Result.fetchone` except that
+ tuple values, rather than :class:`_result.Row`
+ objects, are returned.
+
+ """
+ ...
+
+ async def fetchall(self) -> Sequence[_R]:
+ """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+ ...
+
+ async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def all(self) -> Sequence[_R]: # noqa: A001
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def __aiter__(self) -> AsyncIterator[_R]:
+ ...
+
+ async def __anext__(self) -> _R:
+ ...
+
+ async def first(self) -> Optional[_R]:
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ ...
+
+ async def one_or_none(self) -> Optional[_R]:
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def one(self) -> _R:
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ @overload
+ async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ async def scalar_one(self) -> Any:
+ ...
+
+ async def scalar_one(self) -> Any:
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ async def scalar_one_or_none(
+ self: AsyncTupleResult[Tuple[_T]],
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
+ async def scalar_one_or_none(self) -> Optional[Any]:
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(self) -> Any:
+ ...
+
+ async def scalar(self) -> Any:
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+ :return: a Python scalar value , or None if no rows remain.
+
+ """
+ ...
+
+
+_RT = TypeVar("_RT", bound="Result[Any]")
async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
- cursor_result: CursorResult
+ cursor_result: CursorResult[Any]
try:
is_cursor = result._is_cursor
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
index c7a6e2ca0..22a060a0d 100644
--- a/lib/sqlalchemy/ext/asyncio/scoping.py
+++ b/lib/sqlalchemy/ext/asyncio/scoping.py
@@ -12,10 +12,12 @@ from typing import Callable
from typing import Iterable
from typing import Iterator
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from .session import async_sessionmaker
@@ -37,9 +39,9 @@ if TYPE_CHECKING:
from ...engine import Engine
from ...engine import Result
from ...engine import Row
+ from ...engine import RowMapping
from ...engine.interfaces import _CoreAnyExecuteParams
from ...engine.interfaces import _CoreSingleExecuteParams
- from ...engine.interfaces import _ExecuteOptions
from ...engine.interfaces import _ExecuteOptionsParameter
from ...engine.result import ScalarResult
from ...orm._typing import _IdentityKeyType
@@ -52,6 +54,9 @@ if TYPE_CHECKING:
from ...sql.base import Executable
from ...sql.elements import ClauseElement
from ...sql.selectable import ForUpdateArg
+ from ...sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
@create_proxy_methods(
@@ -480,6 +485,32 @@ class async_scoped_session:
return await self._proxied.delete(instance)
+ @overload
+ async def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ async def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
async def execute(
self,
statement: Executable,
@@ -488,7 +519,7 @@ class async_scoped_session:
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -916,6 +947,30 @@ class async_scoped_session:
return await self._proxied.rollback()
+ @overload
+ async def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
async def scalar(
self,
statement: Executable,
@@ -947,6 +1002,30 @@ class async_scoped_session:
**kw,
)
+ @overload
+ async def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ async def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
async def scalars(
self,
statement: Executable,
@@ -984,6 +1063,19 @@ class async_scoped_session:
**kw,
)
+ @overload
+ async def stream(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[_T]:
+ ...
+
+ @overload
async def stream(
self,
statement: Executable,
@@ -992,7 +1084,18 @@ class async_scoped_session:
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> AsyncResult:
+ ) -> AsyncResult[Any]:
+ ...
+
+ async def stream(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[Any]:
r"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -1012,6 +1115,30 @@ class async_scoped_session:
**kw,
)
+ @overload
+ async def stream_scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ async def stream_scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
+ ...
+
async def stream_scalars(
self,
statement: Executable,
@@ -1323,7 +1450,7 @@ class async_scoped_session:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
r"""Return an identity key.
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 1422f99a3..f2a69e9cd 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -12,10 +12,12 @@ from typing import Iterable
from typing import Iterator
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import engine
@@ -39,11 +41,10 @@ if TYPE_CHECKING:
from ...engine import Engine
from ...engine import Result
from ...engine import Row
+ from ...engine import RowMapping
from ...engine import ScalarResult
- from ...engine import Transaction
from ...engine.interfaces import _CoreAnyExecuteParams
from ...engine.interfaces import _CoreSingleExecuteParams
- from ...engine.interfaces import _ExecuteOptions
from ...engine.interfaces import _ExecuteOptionsParameter
from ...event import dispatcher
from ...orm._typing import _IdentityKeyType
@@ -59,9 +60,12 @@ if TYPE_CHECKING:
from ...sql.base import Executable
from ...sql.elements import ClauseElement
from ...sql.selectable import ForUpdateArg
+ from ...sql.selectable import TypedReturnsRows
_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
+_T = TypeVar("_T", bound=Any)
+
class _SyncSessionCallable(Protocol):
def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
@@ -257,6 +261,32 @@ class AsyncSession(ReversibleProxy[Session]):
return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+ @overload
+ async def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ async def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
async def execute(
self,
statement: Executable,
@@ -265,7 +295,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> Result:
+ ) -> Result[Any]:
"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -292,6 +322,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return await _ensure_sync_result(result, self.execute)
+ @overload
+ async def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
async def scalar(
self,
statement: Executable,
@@ -326,6 +380,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return result
+ @overload
+ async def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ async def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
async def scalars(
self,
statement: Executable,
@@ -391,6 +469,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return result_obj
+ @overload
+ async def stream(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[_T]:
+ ...
+
+ @overload
+ async def stream(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[Any]:
+ ...
+
async def stream(
self,
statement: Executable,
@@ -399,7 +501,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> AsyncResult:
+ ) -> AsyncResult[Any]:
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -423,6 +525,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return AsyncResult(result)
+ @overload
+ async def stream_scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ async def stream_scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
+ ...
+
async def stream_scalars(
self,
statement: Executable,
@@ -1215,7 +1341,7 @@ class AsyncSession(ReversibleProxy[Session]):
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
r"""Return an identity key.
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
index b1138a4ad..c14b466eb 100644
--- a/lib/sqlalchemy/ext/instrumentation.py
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -23,6 +23,7 @@ from ..orm import base as orm_base
from ..orm import collections
from ..orm import exc as orm_exc
from ..orm import instrumentation as orm_instrumentation
+from ..orm import util as orm_util
from ..orm.instrumentation import _default_dict_getter
from ..orm.instrumentation import _default_manager_getter
from ..orm.instrumentation import _default_opt_manager_getter
@@ -437,5 +438,7 @@ def _install_lookups(lookups):
attributes.manager_of_class
) = orm_instrumentation.manager_of_class = manager_of_class
orm_base.opt_manager_of_class = (
+ orm_util.opt_manager_of_class
+ ) = (
attributes.opt_manager_of_class
) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 457ad5c5a..48615b174 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -38,6 +38,7 @@ from ..exc import InvalidRequestError
from ..sql.base import SchemaEventTarget
from ..sql.schema import SchemaConst
from ..sql.selectable import FromClause
+from ..util.typing import Annotated
from ..util.typing import Literal
if TYPE_CHECKING:
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
from ._typing import _ORMColumnExprArgument
from .descriptor_props import _CompositeAttrType
from .interfaces import PropComparator
+ from .mapper import Mapper
from .query import Query
from .relationships import _LazyLoadArgumentType
from .relationships import _ORMBackrefArgument
@@ -1849,9 +1851,27 @@ def clear_mappers():
mapperlib._dispose_registries(mapperlib._all_registries(), False)
+# I would really like a way to get the Type[] here that shows up
+# in a different way in typing tools, however there is no current method
+# that is accepted by mypy (subclass of Type[_O] works in pylance, rejected
+# by mypy).
+AliasedType = Annotated[Type[_O], "aliased"]
+
+
+@overload
+def aliased(
+ element: Type[_O],
+ alias: Optional[Union[Alias, Subquery]] = None,
+ name: Optional[str] = None,
+ flat: bool = False,
+ adapt_on_names: bool = False,
+) -> AliasedType[_O]:
+ ...
+
+
@overload
def aliased(
- element: _EntityType[_O],
+ element: Union[AliasedClass[_O], Mapper[_O], AliasedInsp[_O]],
alias: Optional[Union[Alias, Subquery]] = None,
name: Optional[str] = None,
flat: bool = False,
@@ -1877,7 +1897,7 @@ def aliased(
name: Optional[str] = None,
flat: bool = False,
adapt_on_names: bool = False,
-) -> Union[AliasedClass[_O], FromClause]:
+) -> Union[AliasedClass[_O], FromClause, AliasedType[_O]]:
"""Produce an alias of the given element, usually an :class:`.AliasedClass`
instance.
@@ -1885,7 +1905,8 @@ def aliased(
my_alias = aliased(MyClass)
- session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+ stmt = select(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+ result = session.execute(stmt)
The :func:`.aliased` function is used to create an ad-hoc mapping of a
mapped class to a new selectable. By default, a selectable is generated
@@ -1911,6 +1932,9 @@ def aliased(
.. seealso::
+ :class:`.AsAliased` - a :pep:`484` typed version of
+ :func:`_orm.aliased`
+
:ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial`
:ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel`
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 41d944c57..619af6510 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -70,6 +70,7 @@ from .. import exc
from .. import inspection
from .. import util
from ..sql import base as sql_base
+from ..sql import cache_key
from ..sql import roles
from ..sql import traversals
from ..sql import visitors
@@ -99,10 +100,8 @@ class QueryableAttribute(
traversals.HasCopyInternals,
roles.JoinTargetRole,
roles.OnClauseRole,
- roles.ColumnsClauseRole,
- roles.ExpressionElementRole[_T],
sql_base.Immutable,
- sql_base.MemoizedHasCacheKey,
+ cache_key.MemoizedHasCacheKey,
):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 054d52d83..367a5332d 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -30,6 +30,7 @@ from ._typing import insp_is_mapper
from .. import exc as sa_exc
from .. import inspection
from .. import util
+from ..sql import roles
from ..sql.elements import SQLCoreOperations
from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
@@ -483,19 +484,6 @@ def _inspect_mapped_class(
return mapper
-@inspection._inspects(type)
-def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]:
- try:
- class_manager = opt_manager_of_class(class_)
- if class_manager is None or not class_manager.is_mapped:
- return None
- mapper = class_manager.mapper
- except exc.NO_STATE:
- return None
- else:
- return mapper
-
-
def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]:
insp = inspection.inspect(arg, raiseerr=False)
if insp_is_mapper(insp):
@@ -691,7 +679,7 @@ class ORMDescriptor(Generic[_T], TypingOnly):
...
-class Mapped(ORMDescriptor[_T], TypingOnly):
+class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly):
"""Represent an ORM mapped attribute on a mapped class.
This class represents the complete descriptor interface for any class
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 4fee2d383..05287cbcf 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -17,6 +17,7 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import attributes
@@ -48,14 +49,15 @@ from ..sql.base import _select_iterables
from ..sql.base import CacheableOptions
from ..sql.base import CompileState
from ..sql.base import Executable
+from ..sql.base import Generative
from ..sql.base import Options
from ..sql.dml import UpdateBase
from ..sql.elements import GroupedElement
from ..sql.elements import TextClause
+from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
-from ..sql.selectable import ReturnsRows
from ..sql.selectable import Select
from ..sql.selectable import SelectLabelStyle
from ..sql.selectable import SelectState
@@ -72,6 +74,7 @@ if TYPE_CHECKING:
from ..sql.selectable import SelectBase
from ..sql.type_api import TypeEngine
+_T = TypeVar("_T", bound=Any)
_path_registry = PathRegistry.root
_EMPTY_DICT = util.immutabledict()
@@ -574,7 +577,7 @@ class ORMFromStatementCompileState(ORMCompileState):
return None
-class FromStatement(GroupedElement, ReturnsRows, Executable):
+class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
"""Core construct that represents a load of ORM objects from various
:class:`.ReturnsRows` and other classes including:
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 0ca62b7e3..6a5690be2 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -61,7 +61,6 @@ from ..sql.schema import Column
from ..sql.type_api import TypeEngine
from ..util.typing import TypedDict
-
if typing.TYPE_CHECKING:
from ._typing import _EntityType
from ._typing import _IdentityKeyType
@@ -106,12 +105,12 @@ class ORMStatementRole(roles.StatementRole):
)
-class ORMColumnsClauseRole(roles.ColumnsClauseRole):
+class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]):
__slots__ = ()
_role_name = "ORM mapped entity, aliased entity, or Column expression"
-class ORMEntityColumnsClauseRole(ORMColumnsClauseRole):
+class ORMEntityColumnsClauseRole(ORMColumnsClauseRole[_T]):
__slots__ = ()
_role_name = "ORM mapped or aliased entity"
@@ -127,8 +126,8 @@ class ORMColumnDescription(TypedDict):
# into "type" is a bad idea
type: Union[Type[Any], TypeEngine[Any]]
aliased: bool
- expr: _ColumnsClauseArgument
- entity: Optional[_ColumnsClauseArgument]
+ expr: _ColumnsClauseArgument[Any]
+ entity: Optional[_ColumnsClauseArgument[Any]]
class _IntrospectsAnnotations:
@@ -282,7 +281,7 @@ class MapperProperty(
query_entity: _MapperEntity,
path: PathRegistry,
mapper: Mapper[Any],
- result: Result,
+ result: Result[Any],
adapter: Optional[ColumnAdapter],
populators: _PopulatorDict,
) -> None:
@@ -1170,7 +1169,7 @@ class LoaderStrategy:
path: AbstractEntityRegistry,
loadopt: Optional[_LoadElement],
mapper: Mapper[Any],
- result: Result,
+ result: Result[Any],
adapter: Optional[ORMAdapter],
populators: _PopulatorDict,
) -> None:
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index b37c080ea..083035093 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -96,7 +96,6 @@ if TYPE_CHECKING:
from .descriptor_props import Synonym
from .events import MapperEvents
from .instrumentation import ClassManager
- from .path_registry import AbstractEntityRegistry
from .path_registry import CachingEntityRegistry
from .properties import ColumnProperty
from .relationships import Relationship
@@ -108,10 +107,10 @@ if TYPE_CHECKING:
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
+ from ..sql.elements import KeyedColumnElement
from ..sql.schema import Column
from ..sql.schema import Table
from ..sql.selectable import FromClause
- from ..sql.selectable import TableClause
from ..sql.util import ColumnAdapter
from ..util import OrderedSet
@@ -161,7 +160,7 @@ _CONFIGURE_MUTEX = threading.RLock()
@log.class_logger
class Mapper(
ORMFromClauseRole,
- ORMEntityColumnsClauseRole,
+ ORMEntityColumnsClauseRole[_O],
MemoizedHasCacheKey,
InspectionAttr,
log.Identified,
@@ -1006,7 +1005,7 @@ class Mapper(
"""
- polymorphic_on: Optional[ColumnElement[Any]]
+ polymorphic_on: Optional[KeyedColumnElement[Any]]
"""The :class:`_schema.Column` or SQL expression specified as the
``polymorphic_on`` argument
for this :class:`_orm.Mapper`, within an inheritance scenario.
@@ -1699,10 +1698,10 @@ class Mapper(
instrument = True
key = getattr(col, "key", None)
if key:
- if self._should_exclude(col.key, col.key, False, col):
+ if self._should_exclude(key, key, False, col):
raise sa_exc.InvalidRequestError(
"Cannot exclude or override the "
- "discriminator column %r" % col.key
+ "discriminator column %r" % key
)
else:
self.polymorphic_on = col = col.label("_sa_polymorphic_on")
@@ -2948,7 +2947,7 @@ class Mapper(
def identity_key_from_row(
self,
- row: Optional[Union[Row, RowMapping]],
+ row: Optional[Union[Row[Any], RowMapping]],
identity_token: Optional[Any] = None,
adapter: Optional[ColumnAdapter] = None,
) -> _IdentityKeyType[_O]:
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 9f37e8457..0ca0559b4 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -54,7 +54,7 @@ from ..util.typing import NoneType
if TYPE_CHECKING:
from ._typing import _ORMColumnExprArgument
from ..sql._typing import _InfoType
- from ..sql.elements import ColumnElement
+ from ..sql.elements import KeyedColumnElement
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
@@ -85,7 +85,8 @@ class ColumnProperty(
inherit_cache = True
_links_to_entity = False
- columns: List[ColumnElement[Any]]
+ columns: List[KeyedColumnElement[Any]]
+ _orig_columns: List[KeyedColumnElement[Any]]
_is_polymorphic_discriminator: bool
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 395d01a1e..5bd302b21 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -27,6 +27,8 @@ from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -36,6 +38,7 @@ from . import exc as orm_exc
from . import interfaces
from . import loading
from . import util as orm_util
+from ._typing import _O
from .base import _assertions
from .context import _column_descriptions
from .context import _determine_last_joined_entity
@@ -56,6 +59,7 @@ from .. import log
from .. import sql
from .. import util
from ..engine import Result
+from ..engine import Row
from ..sql import coercions
from ..sql import expression
from ..sql import roles
@@ -63,10 +67,12 @@ from ..sql import Select
from ..sql import util as sql_util
from ..sql import visitors
from ..sql._typing import _FromClauseArgument
+from ..sql._typing import _TP
from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import _entity_namespace_key
from ..sql.base import _generative
from ..sql.base import Executable
+from ..sql.base import Generative
from ..sql.expression import Exists
from ..sql.selectable import _MemoizedSelectEntities
from ..sql.selectable import _SelectFromElements
@@ -75,10 +81,33 @@ from ..sql.selectable import HasHints
from ..sql.selectable import HasPrefixes
from ..sql.selectable import HasSuffixes
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util.typing import Literal
if TYPE_CHECKING:
+ from ._typing import _EntityType
+ from .session import Session
+ from ..engine.result import ScalarResult
+ from ..engine.row import Row
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _ColumnsClauseArgument
+ from ..sql._typing import _MAYBE_ENTITY
+ from ..sql._typing import _no_kw
+ from ..sql._typing import _NOT_ENTITY
+ from ..sql._typing import _PropagateAttrsType
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.selectable import _SetupJoinsElement
from ..sql.selectable import Alias
+ from ..sql.selectable import ExecutableReturnsRows
+ from ..sql.selectable import ScalarSelect
from ..sql.selectable import Subquery
__all__ = ["Query", "QueryContext"]
@@ -97,6 +126,7 @@ class Query(
HasSuffixes,
HasHints,
log.Identified,
+ Generative,
Executable,
Generic[_T],
):
@@ -159,9 +189,15 @@ class Query(
# mirrors that of ClauseElement, used to propagate the "orm"
# plugin as well as the "subject" of the plugin, e.g. the mapper
# we are querying against.
- _propagate_attrs = util.immutabledict()
+ @util.memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ return util.EMPTY_DICT
- def __init__(self, entities, session=None):
+ def __init__(
+ self,
+ entities: Sequence[_ColumnsClauseArgument[Any]],
+ session: Optional[Session] = None,
+ ):
"""Construct a :class:`_query.Query` directly.
E.g.::
@@ -207,6 +243,36 @@ class Query(
for ent in util.to_list(entities)
]
+ @overload
+ def tuples(self: Query[Row[_TP]]) -> Query[_TP]:
+ ...
+
+ @overload
+ def tuples(self: Query[_O]) -> Query[Tuple[_O]]:
+ ...
+
+ def tuples(self) -> Query[Any]:
+ """return a tuple-typed form of this :class:`.Query`.
+
+ This method invokes the :meth:`.Query.only_return_tuples`
+ method with a value of ``True``, which by itself ensures that this
+ :class:`.Query` will always return :class:`.Row` objects, even
+ if the query is made against a single entity. It then also
+ at the typing level will return a "typed" query, if possible,
+ that will type result rows as ``Tuple`` objects with typed
+ elements.
+
+ This method can be compared to the :meth:`.Result.tuples` method,
+ which returns "self", but from a typing perspective returns an object
+ that will yield typed ``Tuple`` objects for results. Typing
+ takes effect only if this :class:`.Query` object is a typed
+ query object already.
+
+ .. versionadded:: 2.0
+
+ """
+ return self.only_return_tuples(True)
+
def _entity_from_pre_ent_zero(self):
if not self._raw_columns:
return None
@@ -582,20 +648,52 @@ class Query(
return self.enable_eagerloads(False).statement.label(name)
+ @overload
+ def as_scalar(
+ self: Query[Tuple[_MAYBE_ENTITY]],
+ ) -> ScalarSelect[_MAYBE_ENTITY]:
+ ...
+
+ @overload
+ def as_scalar(
+ self: Query[Tuple[_NOT_ENTITY]],
+ ) -> ScalarSelect[_NOT_ENTITY]:
+ ...
+
+ @overload
+ def as_scalar(self) -> ScalarSelect[Any]:
+ ...
+
@util.deprecated(
"1.4",
"The :meth:`_query.Query.as_scalar` method is deprecated and will be "
"removed in a future release. Please refer to "
":meth:`_query.Query.scalar_subquery`.",
)
- def as_scalar(self):
+ def as_scalar(self) -> ScalarSelect[Any]:
"""Return the full SELECT statement represented by this
:class:`_query.Query`, converted to a scalar subquery.
"""
return self.scalar_subquery()
- def scalar_subquery(self):
+ @overload
+ def scalar_subquery(
+ self: Query[Tuple[_MAYBE_ENTITY]],
+ ) -> ScalarSelect[Any]:
+ ...
+
+ @overload
+ def scalar_subquery(
+ self: Query[Tuple[_NOT_ENTITY]],
+ ) -> ScalarSelect[_NOT_ENTITY]:
+ ...
+
+ @overload
+ def scalar_subquery(self) -> ScalarSelect[Any]:
+ ...
+
+ def scalar_subquery(self) -> ScalarSelect[Any]:
"""Return the full SELECT statement represented by this
:class:`_query.Query`, converted to a scalar subquery.
@@ -630,16 +728,31 @@ class Query(
.statement
)
- @_generative
- def only_return_tuples(self: SelfQuery, value) -> SelfQuery:
- """When set to True, the query results will always be a tuple.
+ @overload
+ def only_return_tuples(
+ self: Query[_O], value: Literal[True]
+ ) -> RowReturningQuery[Tuple[_O]]:
+ ...
- This is specifically for single element queries. The default is False.
+ @overload
+ def only_return_tuples(
+ self: Query[_O], value: Literal[False]
+ ) -> Query[_O]:
+ ...
- .. versionadded:: 1.2.5
+ @_generative
+ def only_return_tuples(self, value: bool) -> Query[Any]:
+ """When set to True, the query results will always be a
+ :class:`.Row` object.
+
+ This can change a query that normally returns a single entity
+ as a scalar to return a :class:`.Row` result in all cases.
.. seealso::
+ :meth:`.Query.tuples` - returns tuples, but also at the typing
+ level will type results as ``Tuple``.
+
:meth:`_query.Query.is_single_entity`
"""
@@ -1077,7 +1190,11 @@ class Query(
return self.filter(with_parent(instance, property, entity_zero.entity))
@_generative
- def add_entity(self: SelfQuery, entity, alias=None) -> SelfQuery:
+ def add_entity(
+ self,
+ entity: _EntityType[Any],
+ alias: Optional[Union[Alias, Subquery]] = None,
+ ) -> Query[Any]:
"""add a mapped entity to the list of result columns
to be returned."""
@@ -1209,8 +1326,107 @@ class Query(
except StopIteration:
return None
+ @overload
+ def with_entities(
+ self, _entity: _EntityType[_O], **kwargs: Any
+ ) -> ScalarInstanceQuery[_O]:
+ ...
+
+ @overload
+ def with_entities(
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def with_entities(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def with_entities(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.with_entities
+
+ @overload
+ def with_entities(
+ self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
+ ) -> SelfQuery:
+ ...
+
@_generative
- def with_entities(self: SelfQuery, *entities) -> SelfQuery:
+ def with_entities(
+ self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> SelfQuery:
r"""Return a new :class:`_query.Query`
replacing the SELECT list with the
given entities.
@@ -1234,12 +1450,14 @@ class Query(
limit(1)
"""
+ if __kw:
+ raise _no_kw()
_MemoizedSelectEntities._generate_for_statement(self)
self._set_entities(entities)
return self
@_generative
- def add_columns(self: SelfQuery, *column) -> SelfQuery:
+ def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]:
"""Add one or more column expressions to the list
of result columns to be returned."""
@@ -1262,7 +1480,7 @@ class Query(
"is deprecated and will be removed in a "
"future release. Please use :meth:`_query.Query.add_columns`",
)
- def add_column(self, column):
+ def add_column(self, column) -> Query[Any]:
"""Add a column expression to the list of result columns to be
returned.
@@ -1472,7 +1690,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def filter(self: SelfQuery, *criterion) -> SelfQuery:
+ def filter(
+ self: SelfQuery, *criterion: _ColumnExpressionArgument[bool]
+ ) -> SelfQuery:
r"""Apply the given filtering criterion to a copy
of this :class:`_query.Query`, using SQL expressions.
@@ -1556,7 +1776,7 @@ class Query(
return self._raw_columns[0]
- def filter_by(self, **kwargs):
+ def filter_by(self: SelfQuery, **kwargs: Any) -> SelfQuery:
r"""Apply the given filtering criterion to a copy
of this :class:`_query.Query`, using keyword expressions.
@@ -1597,7 +1817,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def order_by(self: SelfQuery, *clauses) -> SelfQuery:
+ def order_by(
+ self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfQuery:
"""Apply one or more ORDER BY criteria to the query and return
the newly resulting :class:`_query.Query`.
@@ -1635,7 +1857,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def group_by(self: SelfQuery, *clauses) -> SelfQuery:
+ def group_by(
+ self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfQuery:
"""Apply one or more GROUP BY criterion to the query and return
the newly resulting :class:`_query.Query`.
@@ -1667,7 +1891,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def having(self: SelfQuery, criterion) -> SelfQuery:
+ def having(
+ self: SelfQuery, *having: _ColumnExpressionArgument[bool]
+ ) -> SelfQuery:
r"""Apply a HAVING criterion to the query and return the
newly resulting :class:`_query.Query`.
@@ -1684,17 +1910,17 @@ class Query(
"""
- self._having_criteria += (
- coercions.expect(
- roles.WhereHavingRole, criterion, apply_propagate_attrs=self
- ),
- )
+ for criterion in having:
+ having_criteria = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
+ self._having_criteria += (having_criteria,)
return self
def _set_op(self, expr_fn, *q):
return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
- def union(self, *q):
+ def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce a UNION of this Query against one or more queries.
e.g.::
@@ -1733,7 +1959,7 @@ class Query(
"""
return self._set_op(expression.union, *q)
- def union_all(self, *q):
+ def union_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce a UNION ALL of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1742,7 +1968,7 @@ class Query(
"""
return self._set_op(expression.union_all, *q)
- def intersect(self, *q):
+ def intersect(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an INTERSECT of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1751,7 +1977,7 @@ class Query(
"""
return self._set_op(expression.intersect, *q)
- def intersect_all(self, *q):
+ def intersect_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an INTERSECT ALL of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1760,7 +1986,7 @@ class Query(
"""
return self._set_op(expression.intersect_all, *q)
- def except_(self, *q):
+ def except_(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an EXCEPT of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1769,7 +1995,7 @@ class Query(
"""
return self._set_op(expression.except_, *q)
- def except_all(self, *q):
+ def except_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an EXCEPT ALL of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -2194,7 +2420,9 @@ class Query(
@_generative
@_assertions(_no_clauseelement_condition)
- def from_statement(self: SelfQuery, statement) -> SelfQuery:
+ def from_statement(
+ self: SelfQuery, statement: ExecutableReturnsRows
+ ) -> SelfQuery:
"""Execute the given SELECT statement and return results.
This method bypasses all internal statement compilation, and the
@@ -2283,7 +2511,7 @@ class Query(
:meth:`_query.Query.one_or_none`
"""
- return self._iter().one()
+ return self._iter().one() # type: ignore
def scalar(self) -> Any:
"""Return the first element of the first result or None
@@ -2316,7 +2544,7 @@ class Query(
def __iter__(self) -> Iterable[_T]:
return self._iter().__iter__()
- def _iter(self):
+ def _iter(self) -> Union[ScalarResult[_T], Result[_T]]:
# new style execution.
params = self._params
@@ -2837,3 +3065,7 @@ class BulkUpdate(BulkUD):
class BulkDelete(BulkUD):
"""BulkUD which handles DELETEs."""
+
+
+class RowReturningQuery(Query[Row[_TP]]):
+ pass
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 93d18b8d7..9220c44c7 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -13,6 +13,7 @@ from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
@@ -20,8 +21,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from . import exc as orm_exc
-from .base import class_mapper
from .session import Session
from .. import exc as sa_exc
from .. import util
@@ -33,11 +32,13 @@ from ..util import warn_deprecated
from ..util.typing import Protocol
if TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _IdentityKeyType
from .identity import IdentityMap
from .interfaces import ORMOption
from .mapper import Mapper
from .query import Query
+ from .query import RowReturningQuery
from .session import _BindArguments
from .session import _EntityBindKey
from .session import _PKIdentityArgument
@@ -48,19 +49,33 @@ if TYPE_CHECKING:
from ..engine import Engine
from ..engine import Result
from ..engine import Row
+ from ..engine import RowMapping
from ..engine.interfaces import _CoreAnyExecuteParams
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptions
from ..engine.interfaces import _ExecuteOptionsParameter
from ..engine.result import ScalarResult
from ..sql._typing import _ColumnsClauseArgument
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
from ..sql.elements import ClauseElement
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.selectable import ForUpdateArg
+ from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
class _QueryDescriptorType(Protocol):
- def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]:
+ def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]:
...
@@ -236,7 +251,7 @@ class scoped_session:
self.registry.clear()
def query_property(
- self, query_cls: Optional[Type[Query[Any]]] = None
+ self, query_cls: Optional[Type[Query[_T]]] = None
) -> _QueryDescriptorType:
"""return a class property which produces a :class:`_query.Query`
object
@@ -264,20 +279,13 @@ class scoped_session:
"""
class query:
- def __get__(
- s, instance: Any, owner: Type[Any]
- ) -> Optional[Query[Any]]:
- try:
- mapper = class_mapper(owner)
- assert mapper is not None
- if query_cls:
- # custom query class
- return query_cls(mapper, session=self.registry())
- else:
- # session's configured query class
- return self.registry().query(mapper)
- except orm_exc.UnmappedClassError:
- return None
+ def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]:
+ if query_cls:
+ # custom query class
+ return query_cls(owner, session=self.registry()) # type: ignore # noqa: E501
+ else:
+ # session's configured query class
+ return self.registry().query(owner)
return query()
@@ -548,6 +556,32 @@ class scoped_session:
return self._proxied.delete(instance)
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
def execute(
self,
statement: Executable,
@@ -557,7 +591,7 @@ class scoped_session:
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a SQL expression construct.
.. container:: class_bases
@@ -1430,8 +1464,103 @@ class scoped_session:
return self._proxied.merge(instance, load=load, options=options)
+ @overload
+ def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+ ...
+
+ @overload
def query(
- self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.query
+
+ @overload
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+ ) -> Query[Any]:
+ ...
+
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
) -> Query[Any]:
r"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
@@ -1559,6 +1688,30 @@ class scoped_session:
return self._proxied.rollback()
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
@@ -1590,6 +1743,30 @@ class scoped_session:
**kw,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
def scalars(
self,
statement: Executable,
@@ -1848,7 +2025,7 @@ class scoped_session:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
r"""Return an identity key.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 74035ec0a..263d56101 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -27,6 +27,7 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -85,12 +86,16 @@ from ..util.typing import Literal
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InstanceDict
+ from ._typing import _O
+ from .context import FromStatement
from .interfaces import ORMOption
from .interfaces import UserDefinedOption
from .mapper import Mapper
from .path_registry import PathRegistry
+ from .query import RowReturningQuery
from ..engine import Result
from ..engine import Row
from ..engine import RowMapping
@@ -104,10 +109,23 @@ if typing.TYPE_CHECKING:
from ..event import _InstanceLevelDispatch
from ..sql._typing import _ColumnsClauseArgument
from ..sql._typing import _InfoType
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
from ..sql.elements import ClauseElement
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.schema import Table
- from ..sql.selectable import TableClause
+ from ..sql.selectable import Select
+ from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
__all__ = [
"Session",
@@ -189,7 +207,7 @@ class _SessionClassMethods:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Union[Row, RowMapping]] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
"""Return an identity key.
@@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots):
params: Optional[_CoreAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
bind_arguments: Optional[_BindArguments] = None,
- ) -> Result:
+ ) -> Result[Any]:
"""Execute the statement represented by this
:class:`.ORMExecuteState`, without re-invoking events that have
already proceeded.
@@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget):
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
_scalar_result: bool = ...,
- ) -> Result:
+ ) -> Result[Any]:
...
def _execute_internal(
@@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget):
)
for idx, fn in enumerate(events_todo):
orm_exec_state._starting_event_idx = idx
- fn_result: Optional[Result] = fn(orm_exec_state)
+ fn_result: Optional[Result[Any]] = fn(orm_exec_state)
if fn_result:
if _scalar_result:
return fn_result.scalar()
@@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget):
if _scalar_result and not compile_state_cls:
if TYPE_CHECKING:
params = cast(_CoreSingleExecuteParams, params)
- return conn.scalar(statement, params or {}, execution_options)
+ return conn.scalar(
+ statement, params or {}, execution_options=execution_options
+ )
- result: Result = conn.execute(
- statement, params or {}, execution_options
+ result: Result[Any] = conn.execute(
+ statement, params or {}, execution_options=execution_options
)
if compile_state_cls:
@@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget):
else:
return result
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
def execute(
self,
statement: Executable,
@@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a SQL expression construct.
Returns a :class:`_engine.Result` object representing
@@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget):
_add_event=_add_event,
)
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
@@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget):
**kw,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
def scalars(
self,
statement: Executable,
@@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget):
f'{", ".join(context)} or this Session.'
)
+ @overload
+ def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+ ...
+
+ @overload
+ def query(
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.query
+
+ @overload
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+ ) -> Query[Any]:
+ ...
+
def query(
- self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
) -> Query[Any]:
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
@@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget):
with_for_update = ForUpdateArg._from_argument(with_for_update)
- stmt = sql.select(object_mapper(instance))
+ stmt: Select[Any] = sql.select(object_mapper(instance))
if (
loading.load_on_ident(
self,
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 58f141997..ab32a3981 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -656,13 +656,13 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
@classmethod
def _instance_level_callable_processor(
cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
- ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]:
+ ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]:
impl = manager[key].impl
if is_collection_impl(impl):
fixed_impl = impl
def _set_callable(
- state: InstanceState[_O], dict_: _InstanceDict, row: Row
+ state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
) -> None:
if "callables" not in state.__dict__:
state.callables = {}
@@ -674,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
else:
def _set_callable(
- state: InstanceState[_O], dict_: _InstanceDict, row: Row
+ state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
) -> None:
if "callables" not in state.__dict__:
state.callables = {}
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 3934de535..8148793b1 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -28,6 +28,7 @@ from typing import Union
import weakref
from . import attributes # noqa
+from . import exc
from ._typing import _O
from ._typing import insp_is_aliased_class
from ._typing import insp_is_mapper
@@ -41,6 +42,7 @@ from .base import InspectionAttr as InspectionAttr
from .base import instance_str as instance_str
from .base import object_mapper as object_mapper
from .base import object_state as object_state
+from .base import opt_manager_of_class
from .base import state_attribute_str as state_attribute_str
from .base import state_class_str as state_class_str
from .base import state_str as state_str
@@ -68,6 +70,7 @@ from ..sql.base import ColumnCollection
from ..sql.cache_key import HasCacheKey
from ..sql.cache_key import MemoizedHasCacheKey
from ..sql.elements import ColumnElement
+from ..sql.elements import KeyedColumnElement
from ..sql.selectable import FromClause
from ..util.langhelpers import MemoizedSlots
from ..util.typing import de_stringify_annotation
@@ -95,9 +98,7 @@ if typing.TYPE_CHECKING:
from ..sql.selectable import _ColumnsClauseElement
from ..sql.selectable import Alias
from ..sql.selectable import Subquery
- from ..sql.visitors import _ET
from ..sql.visitors import anon_map
- from ..sql.visitors import ExternallyTraversible
_T = TypeVar("_T", bound=Any)
@@ -341,7 +342,7 @@ def identity_key(
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[_T] = None,
- row: Optional[Union[Row, RowMapping]] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[_T]:
r"""Generate "identity key" tuples, as are used as keys in the
@@ -468,7 +469,9 @@ class ORMAdapter(sql_util.ColumnAdapter):
return not entity or entity.isa(self.mapper)
-class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
+class AliasedClass(
+ inspection.Inspectable["AliasedInsp[_O]"], ORMColumnsClauseRole[_O]
+):
r"""Represents an "aliased" form of a mapped class for usage with Query.
The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias`
@@ -663,7 +666,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
@inspection._self_inspects
class AliasedInsp(
- ORMEntityColumnsClauseRole,
+ ORMEntityColumnsClauseRole[_O],
ORMFromClauseRole,
HasCacheKey,
InspectionAttr,
@@ -1276,12 +1279,29 @@ class LoaderCriteriaOption(CriteriaOption):
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
+@inspection._inspects(type)
+def _inspect_mc(
+ class_: Type[_O],
+) -> Optional[Mapper[_O]]:
+
+ try:
+ class_manager = opt_manager_of_class(class_)
+ if class_manager is None or not class_manager.is_mapped:
+ return None
+ mapper = class_manager.mapper
+ except exc.NO_STATE:
+
+ return None
+ else:
+ return mapper
+
+
@inspection._self_inspects
class Bundle(
- ORMColumnsClauseRole,
+ ORMColumnsClauseRole[_T],
SupportsCloneAnnotations,
MemoizedHasCacheKey,
- inspection.Inspectable["Bundle"],
+ inspection.Inspectable["Bundle[_T]"],
InspectionAttr,
):
"""A grouping of SQL expressions that are returned by a :class:`.Query`
@@ -1373,10 +1393,10 @@ class Bundle(
@property
def entity_namespace(
self,
- ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
+ ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
return self.c
- columns: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+ columns: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]
"""A namespace of SQL expressions referred to by this :class:`.Bundle`.
@@ -1402,7 +1422,7 @@ class Bundle(
"""
- c: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+ c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]
"""An alias for :attr:`.Bundle.columns`."""
def _clone(self):
@@ -1908,9 +1928,10 @@ def _extract_mapped_subtype(
raw_annotation: Union[type, str],
cls: type,
key: str,
- attr_cls: type,
+ attr_cls: Type[Any],
required: bool,
is_dataclass_field: bool,
+ superclasses: Optional[Tuple[Type[Any], ...]] = None,
) -> Optional[Union[type, str]]:
if raw_annotation is None:
@@ -1930,9 +1951,13 @@ def _extract_mapped_subtype(
if is_dataclass_field:
return annotated
else:
- if (
- not hasattr(annotated, "__origin__")
- or not issubclass(annotated.__origin__, attr_cls) # type: ignore
+ # TODO: there don't seem to be tests for the failure
+ # conditions here
+ if not hasattr(annotated, "__origin__") or (
+ not issubclass(
+ annotated.__origin__, # type: ignore
+ superclasses if superclasses else attr_cls,
+ )
and not issubclass(attr_cls, annotated.__origin__) # type: ignore
):
our_annotated_str = (
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 84913225d..c3ebb4596 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -121,7 +121,6 @@ def __go(lcls: Any) -> None:
coercions.lambdas = lambdas
coercions.schema = schema
coercions.selectable = selectable
- coercions.traversals = traversals
from .annotation import _prepare_annotations
from .annotation import Annotated
diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py
index 37d44976a..f89e8f578 100644
--- a/lib/sqlalchemy/sql/_selectable_constructors.py
+++ b/lib/sqlalchemy/sql/_selectable_constructors.py
@@ -9,12 +9,16 @@ from __future__ import annotations
from typing import Any
from typing import Optional
+from typing import overload
+from typing import Tuple
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import coercions
from . import roles
from ._typing import _ColumnsClauseArgument
+from ._typing import _no_kw
from .elements import ColumnClause
from .selectable import Alias
from .selectable import CompoundSelect
@@ -34,6 +38,17 @@ if TYPE_CHECKING:
from ._typing import _FromClauseArgument
from ._typing import _OnClauseArgument
from ._typing import _SelectStatementForCompoundArgument
+ from ._typing import _T0
+ from ._typing import _T1
+ from ._typing import _T2
+ from ._typing import _T3
+ from ._typing import _T4
+ from ._typing import _T5
+ from ._typing import _T6
+ from ._typing import _T7
+ from ._typing import _T8
+ from ._typing import _T9
+ from ._typing import _TypedColumnClauseArgument as _TCCA
from .functions import Function
from .selectable import CTE
from .selectable import HasCTE
@@ -41,6 +56,9 @@ if TYPE_CHECKING:
from .selectable import SelectBase
+_T = TypeVar("_T", bound=Any)
+
+
def alias(
selectable: FromClause, name: Optional[str] = None, flat: bool = False
) -> NamedFromClause:
@@ -89,7 +107,9 @@ def cte(
)
-def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def except_(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``EXCEPT`` of multiple selectables.
The returned object is an instance of
@@ -119,7 +139,7 @@ def except_all(
def exists(
__argument: Optional[
- Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]]
] = None,
) -> Exists:
"""Construct a new :class:`_expression.Exists` construct.
@@ -162,7 +182,9 @@ def exists(
return Exists(__argument)
-def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def intersect(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``INTERSECT`` of multiple selectables.
The returned object is an instance of
@@ -306,7 +328,129 @@ def outerjoin(
return Join(left, right, onclause, isouter=True, full=full)
-def select(*entities: _ColumnsClauseArgument) -> Select:
+# START OVERLOADED FUNCTIONS select Select 1-10
+
+# code within this block is **programmatically,
+# statically generated** by tools/generate_tuple_map_overloads.py
+
+
+@overload
+def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]:
+ ...
+
+
+@overload
+def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+) -> Select[Tuple[_T0, _T1, _T2]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ __ent8: _TCCA[_T8],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ __ent8: _TCCA[_T8],
+ __ent9: _TCCA[_T9],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]:
+ ...
+
+
+# END OVERLOADED FUNCTIONS select
+
+
+@overload
+def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
+ ...
+
+
+def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
r"""Construct a new :class:`_expression.Select`.
@@ -343,7 +487,11 @@ def select(*entities: _ColumnsClauseArgument) -> Select:
given, as well as ORM-mapped classes.
"""
-
+ # the keyword args are a necessary element in order for the typing
+ # to work out w/ the varargs vs. having named "keyword" arguments that
+ # aren't always present.
+ if __kw:
+ raise _no_kw()
return Select(*entities)
@@ -425,7 +573,9 @@ def tablesample(
return TableSample._factory(selectable, sampling, name=name, seed=seed)
-def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def union(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return a ``UNION`` of multiple selectables.
The returned object is an instance of
@@ -445,7 +595,9 @@ def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
return CompoundSelect._create_union(*selects)
-def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def union_all(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return a ``UNION ALL`` of multiple selectables.
The returned object is an instance of
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 53d29b628..1df530dbd 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -5,18 +5,27 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Set
+from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import roles
+from .. import exc
from .. import util
from ..inspection import Inspectable
from ..util.typing import Literal
from ..util.typing import Protocol
if TYPE_CHECKING:
+ from datetime import date
+ from datetime import datetime
+ from datetime import time
+ from datetime import timedelta
+ from decimal import Decimal
+ from uuid import UUID
+
from .base import Executable
from .compiler import Compiled
from .compiler import DDLCompiler
@@ -26,17 +35,15 @@ if TYPE_CHECKING:
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .elements import quoted_name
- from .elements import SQLCoreOperations
from .elements import TextClause
from .lambdas import LambdaElement
from .roles import ColumnsClauseRole
from .roles import FromClauseRole
from .schema import Column
- from .schema import DefaultGenerator
- from .schema import Sequence
- from .schema import Table
from .selectable import Alias
+ from .selectable import CTE
from .selectable import FromClause
from .selectable import Join
from .selectable import NamedFromClause
@@ -61,6 +68,30 @@ class _HasClauseElement(Protocol):
...
+# match column types that are not ORM entities
+_NOT_ENTITY = TypeVar(
+ "_NOT_ENTITY",
+ int,
+ str,
+ "datetime",
+ "date",
+ "time",
+ "timedelta",
+ "UUID",
+ float,
+ "Decimal",
+)
+
+_MAYBE_ENTITY = TypeVar(
+ "_MAYBE_ENTITY",
+ roles.ColumnsClauseRole,
+ Literal["*", 1],
+ Type[Any],
+ Inspectable[_HasClauseElement],
+ _HasClauseElement,
+)
+
+
# convention:
# XYZArgument - something that the end user is passing to a public API method
# XYZElement - the internal representation that we use for the thing.
@@ -76,9 +107,10 @@ _TextCoercedExpressionArgument = Union[
]
_ColumnsClauseArgument = Union[
- Literal["*", 1],
+ roles.TypedColumnsClauseRole[_T],
roles.ColumnsClauseRole,
- Type[Any],
+ Literal["*", 1],
+ Type[_T],
Inspectable[_HasClauseElement],
_HasClauseElement,
]
@@ -92,6 +124,24 @@ sets; select(...), insert().returning(...), etc.
"""
+_TypedColumnClauseArgument = Union[
+ roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T]
+]
+
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
+
+_T0 = TypeVar("_T0", bound=Any)
+_T1 = TypeVar("_T1", bound=Any)
+_T2 = TypeVar("_T2", bound=Any)
+_T3 = TypeVar("_T3", bound=Any)
+_T4 = TypeVar("_T4", bound=Any)
+_T5 = TypeVar("_T5", bound=Any)
+_T6 = TypeVar("_T6", bound=Any)
+_T7 = TypeVar("_T7", bound=Any)
+_T8 = TypeVar("_T8", bound=Any)
+_T9 = TypeVar("_T9", bound=Any)
+
+
_ColumnExpressionArgument = Union[
"ColumnElement[_T]",
_HasClauseElement,
@@ -169,6 +219,7 @@ _DMLTableArgument = Union[
"TableClause",
"Join",
"Alias",
+ "CTE",
Type[Any],
Inspectable[_HasClauseElement],
_HasClauseElement,
@@ -194,6 +245,11 @@ if TYPE_CHECKING:
def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
...
+ def is_keyed_column_element(
+ c: ClauseElement,
+ ) -> TypeGuard[KeyedColumnElement[Any]]:
+ ...
+
def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
...
@@ -216,7 +272,7 @@ if TYPE_CHECKING:
def is_select_statement(
t: Union[Executable, ReturnsRows]
- ) -> TypeGuard[Select]:
+ ) -> TypeGuard[Select[Any]]:
...
def is_table(t: FromClause) -> TypeGuard[TableClause]:
@@ -234,6 +290,7 @@ else:
is_ddl_compiler = operator.attrgetter("is_ddl")
is_named_from_clause = operator.attrgetter("named_with_column")
is_column_element = operator.attrgetter("_is_column_element")
+ is_keyed_column_element = operator.attrgetter("_is_keyed_column_element")
is_text_clause = operator.attrgetter("_is_text_clause")
is_from_clause = operator.attrgetter("_is_from_clause")
is_tuple_type = operator.attrgetter("_is_tuple_type")
@@ -260,3 +317,10 @@ def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]:
return c.is_dml and (c.is_insert or c.is_update) # type: ignore
+
+
+def _no_kw() -> exc.ArgumentError:
+ return exc.ArgumentError(
+ "Additional keyword arguments are not accepted by this "
+ "function/method. The presence of **kw is for pep-484 typing purposes"
+ )
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index f81878d55..790edefc6 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -62,10 +62,10 @@ if TYPE_CHECKING:
from . import coercions
from . import elements
from . import type_api
- from ._typing import _ColumnsClauseArgument
from .elements import BindParameter
- from .elements import ColumnClause
+ from .elements import ColumnClause # noqa
from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
from .elements import TextClause
@@ -74,7 +74,6 @@ if TYPE_CHECKING:
from .selectable import FromClause
from ..engine import Connection
from ..engine import CursorResult
- from ..engine import Result
from ..engine.base import _CompiledCacheType
from ..engine.interfaces import _CoreMultiExecuteParams
from ..engine.interfaces import _ExecuteOptions
@@ -704,8 +703,11 @@ class InPlaceGenerative(HasMemoized):
"""Provide a method-chaining pattern in conjunction with the
@_generative decorator that mutates in place."""
+ __slots__ = ()
+
def _generate(self):
skip = self._memoized_keys
+ # note __dict__ needs to be in __slots__ if this is used
for k in skip:
self.__dict__.pop(k, None)
return self
@@ -937,7 +939,7 @@ class ExecutableOption(HasCopyInternals):
SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
-class Executable(roles.StatementRole, Generative):
+class Executable(roles.StatementRole):
"""Mark a :class:`_expression.ClauseElement` as supporting execution.
:class:`.Executable` is a superclass for all "statement" types
@@ -994,7 +996,7 @@ class Executable(roles.StatementRole, Generative):
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
...
def _execute_on_scalar(
@@ -1253,7 +1255,7 @@ class SchemaVisitor(ClauseVisitor):
_COLKEY = TypeVar("_COLKEY", Union[None, str], str)
_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
-_COL = TypeVar("_COL", bound="ColumnElement[Any]")
+_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]")
class ColumnCollection(Generic[_COLKEY, _COL_co]):
@@ -1505,6 +1507,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
+
self._collection[:] = cols
self._colset.update(c for k, c in self._collection)
self._index.update(
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 0659709ab..9b7231360 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -29,6 +29,7 @@ from typing import Union
from . import operators
from . import roles
from . import visitors
+from ._typing import is_from_clause
from .base import ExecutableOption
from .base import Options
from .cache_key import HasCacheKey
@@ -38,25 +39,18 @@ from .. import inspection
from .. import util
from ..util.typing import Literal
-if not typing.TYPE_CHECKING:
- elements = None
- lambdas = None
- schema = None
- selectable = None
- traversals = None
-
if typing.TYPE_CHECKING:
from . import elements
from . import lambdas
from . import schema
from . import selectable
- from . import traversals
from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnsClauseArgument
from ._typing import _DDLColumnArgument
from ._typing import _DMLTableArgument
from ._typing import _FromClauseArgument
from .dml import _DMLTableElement
+ from .elements import BindParameter
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
@@ -64,9 +58,7 @@ if typing.TYPE_CHECKING:
from .elements import SQLCoreOperations
from .schema import Column
from .selectable import _ColumnsClauseElement
- from .selectable import _JoinTargetElement
from .selectable import _JoinTargetProtocol
- from .selectable import _OnClauseElement
from .selectable import FromClause
from .selectable import HasCTE
from .selectable import SelectBase
@@ -170,6 +162,15 @@ def expect(
@overload
def expect(
+ role: Type[roles.LiteralValueRole],
+ element: Any,
+ **kw: Any,
+) -> BindParameter[Any]:
+ ...
+
+
+@overload
+def expect(
role: Type[roles.DDLReferredColumnRole],
element: Any,
**kw: Any,
@@ -272,7 +273,7 @@ def expect(
@overload
def expect(
role: Type[roles.ColumnsClauseRole],
- element: _ColumnsClauseArgument,
+ element: _ColumnsClauseArgument[Any],
**kw: Any,
) -> _ColumnsClauseElement:
...
@@ -933,7 +934,7 @@ class GroupByImpl(ByOfImpl, RoleImpl):
argname: Optional[str] = None,
**kw: Any,
) -> Any:
- if isinstance(resolved, roles.StrictFromClauseRole):
+ if is_from_clause(resolved):
return elements.ClauseList(*resolved.c)
else:
return resolved
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index c524a2602..a1b25b8a6 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -80,7 +80,6 @@ from ..util.typing import Protocol
from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
- from . import roles
from .annotation import _AnnotationDict
from .base import _AmbiguousTableNameMap
from .base import CompileState
@@ -95,7 +94,6 @@ if typing.TYPE_CHECKING:
from .elements import ColumnElement
from .elements import Label
from .functions import Function
- from .selectable import Alias
from .selectable import AliasedReturnsRows
from .selectable import CompoundSelectState
from .selectable import CTE
@@ -386,7 +384,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
need_result_map_for_nested: bool
need_result_map_for_compound: bool
select_0: ReturnsRows
- insert_from_select: Select
+ insert_from_select: Select[Any]
class ExpandedState(NamedTuple):
@@ -2834,15 +2832,31 @@ class SQLCompiler(Compiled):
"unique bind parameter of the same name" % name
)
elif existing._is_crud or bindparam._is_crud:
- raise exc.CompileError(
- "bindparam() name '%s' is reserved "
- "for automatic usage in the VALUES or SET "
- "clause of this "
- "insert/update statement. Please use a "
- "name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')."
- % (bindparam.key, bindparam.key)
- )
+ if existing._is_crud and bindparam._is_crud:
+ # TODO: this condition is not well understood.
+ # see tests in test/sql/test_update.py
+ raise exc.CompileError(
+ "Encountered unsupported case when compiling an "
+ "INSERT or UPDATE statement. If this is a "
+ "multi-table "
+ "UPDATE statement, please provide string-named "
+ "arguments to the "
+ "values() method with distinct names; support for "
+ "multi-table UPDATE statements that "
+ "target multiple tables for UPDATE is very "
+ "limited",
+ )
+ else:
+ raise exc.CompileError(
+ f"bindparam() name '{bindparam.key}' is reserved "
+ "for automatic usage in the VALUES or SET "
+ "clause of this "
+ "insert/update statement. Please use a "
+ "name other than column name when using "
+ "bindparam() "
+ "with insert() or update() (for example, "
+ f"'b_{bindparam.key}')."
+ )
self.binds[bindparam.key] = self.binds[name] = bindparam
@@ -3881,7 +3895,7 @@ class SQLCompiler(Compiled):
return text
def _setup_select_hints(
- self, select: Select
+ self, select: Select[Any]
) -> Tuple[str, _FromHintsType]:
byfrom = dict(
[
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index e4408cd31..29d7b45d7 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -22,6 +22,7 @@ from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
@@ -30,8 +31,10 @@ from . import coercions
from . import dml
from . import elements
from . import roles
+from .elements import ColumnClause
from .schema import default_is_clause_element
from .schema import default_is_sequence
+from .selectable import TableClause
from .. import exc
from .. import util
from ..util.typing import Literal
@@ -41,16 +44,9 @@ if TYPE_CHECKING:
from .compiler import SQLCompiler
from .dml import _DMLColumnElement
from .dml import DMLState
- from .dml import Insert
- from .dml import Update
- from .dml import UpdateDMLState
from .dml import ValuesBase
- from .elements import ClauseElement
- from .elements import ColumnClause
from .elements import ColumnElement
- from .elements import TextClause
from .schema import _SQLExprDefault
- from .schema import Column
from .selectable import TableClause
REQUIRED = util.symbol(
@@ -68,12 +64,20 @@ values present.
)
+def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]:
+ if not isinstance(c, ColumnClause):
+ raise exc.CompileError(
+ f"Can't create DML statement against column expression {c!r}"
+ )
+ return c
+
+
class _CrudParams(NamedTuple):
- single_params: List[
- Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ single_params: Sequence[
+ Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]]
]
all_multi_params: List[
- List[
+ Sequence[
Tuple[
ColumnClause[Any],
str,
@@ -274,7 +278,7 @@ def _get_crud_params(
compiler,
stmt,
compile_state,
- cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+ cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values),
cast("Callable[..., str]", _column_as_key),
kw,
)
@@ -290,7 +294,7 @@ def _get_crud_params(
# insert_executemany_returning mode :)
values = [
(
- stmt.table.columns[0],
+ _as_dml_column(stmt.table.columns[0]),
compiler.preparer.format_column(stmt.table.columns[0]),
"DEFAULT",
)
@@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams(
compiler: SQLCompiler,
stmt: ValuesBase,
compile_state: DMLState,
- initial_values: List[Tuple[ColumnClause[Any], str, str]],
+ initial_values: Sequence[Tuple[ColumnClause[Any], str, str]],
_column_as_key: Callable[..., str],
kw: Dict[str, Any],
-) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]:
values_0 = initial_values
values = [initial_values]
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 8307f6400..e0f162fc8 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -22,15 +22,19 @@ from typing import List
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import coercions
from . import roles
from . import util as sql_util
+from ._typing import _no_kw
+from ._typing import _TP
from ._typing import is_column_element
from ._typing import is_named_from_clause
from .base import _entity_namespace_key
@@ -42,6 +46,7 @@ from .base import ColumnCollection
from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
+from .base import Generative
from .base import HasCompileState
from .elements import BooleanClauseList
from .elements import ClauseElement
@@ -49,12 +54,13 @@ from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Null
from .selectable import Alias
+from .selectable import ExecutableReturnsRows
from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import Join
-from .selectable import ReturnsRows
from .selectable import TableClause
+from .selectable import TypedReturnsRows
from .sqltypes import NullType
from .visitors import InternalTraversal
from .. import exc
@@ -66,9 +72,19 @@ if TYPE_CHECKING:
from ._typing import _ColumnsClauseArgument
from ._typing import _DMLColumnArgument
from ._typing import _DMLTableArgument
- from ._typing import _FromClauseArgument
+ from ._typing import _T0 # noqa
+ from ._typing import _T1 # noqa
+ from ._typing import _T2 # noqa
+ from ._typing import _T3 # noqa
+ from ._typing import _T4 # noqa
+ from ._typing import _T5 # noqa
+ from ._typing import _T6 # noqa
+ from ._typing import _T7 # noqa
+ from ._typing import _TypedColumnClauseArgument as _TCCA # noqa
from .base import ReadOnlyColumnCollection
from .compiler import SQLCompiler
+ from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .selectable import _ColumnsClauseElement
from .selectable import _SelectIterable
from .selectable import Select
@@ -88,6 +104,8 @@ else:
isinsert = operator.attrgetter("isinsert")
+_T = TypeVar("_T", bound=Any)
+
_DMLColumnElement = Union[str, ColumnClause[Any]]
_DMLTableElement = Union[TableClause, Alias, Join]
@@ -185,6 +203,11 @@ class DMLState(CompileState):
"%s construct does not support "
"multiple parameter sets." % statement.__visit_name__.upper()
)
+ else:
+ assert isinstance(statement, Insert)
+
+ # which implies...
+ # assert isinstance(statement.table, TableClause)
for parameters in statement._multi_values:
multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [
@@ -291,7 +314,9 @@ class UpdateDMLState(DMLState):
elif statement._multi_values:
self._process_multi_values(statement)
self._extra_froms = ef = self._make_extra_froms(statement)
- self.is_multitable = mt = ef and self._dict_parameters
+
+ self.is_multitable = mt = ef
+
self.include_table_with_column_exprs = bool(
mt and compiler.render_table_with_column_in_update_from
)
@@ -317,8 +342,8 @@ class UpdateBase(
HasCompileState,
DialectKWArgs,
HasPrefixes,
- ReturnsRows,
- Executable,
+ Generative,
+ ExecutableReturnsRows,
ClauseElement,
):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
@@ -383,8 +408,8 @@ class UpdateBase(
@_generative
def returning(
- self: SelfUpdateBase, *cols: _ColumnsClauseArgument
- ) -> SelfUpdateBase:
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> UpdateBase:
r"""Add a :term:`RETURNING` or equivalent clause to this statement.
e.g.:
@@ -454,6 +479,8 @@ class UpdateBase(
:ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial`
""" # noqa: E501
+ if __kw:
+ raise _no_kw()
if self._return_defaults:
raise exc.InvalidRequestError(
"return_defaults() is already configured on this statement"
@@ -464,7 +491,7 @@ class UpdateBase(
return self
def corresponding_column(
- self, column: ColumnElement[Any], require_embedded: bool = False
+ self, column: KeyedColumnElement[Any], require_embedded: bool = False
) -> Optional[ColumnElement[Any]]:
return self.exported_columns.corresponding_column(
column, require_embedded=require_embedded
@@ -628,7 +655,7 @@ class ValuesBase(UpdateBase):
_supports_multi_parameters = False
- select: Optional[Select] = None
+ select: Optional[Select[Any]] = None
"""SELECT statement for INSERT .. FROM SELECT"""
_post_values_clause: Optional[ClauseElement] = None
@@ -804,11 +831,15 @@ class ValuesBase(UpdateBase):
)
elif isinstance(arg, collections_abc.Sequence):
-
if arg and isinstance(arg[0], (list, dict, tuple)):
self._multi_values += (arg,)
return self
+ if TYPE_CHECKING:
+ # crud.py raises during compilation if this is not the
+ # case
+ assert isinstance(self, Insert)
+
# tuple values
arg = {c.key: value for c, value in zip(self.table.c, arg)}
@@ -1010,7 +1041,7 @@ class Insert(ValuesBase):
def from_select(
self: SelfInsert,
names: List[str],
- select: Select,
+ select: Select[Any],
include_defaults: bool = True,
) -> SelfInsert:
"""Return a new :class:`_expression.Insert` construct which represents
@@ -1073,6 +1104,114 @@ class Insert(ValuesBase):
self.select = coercions.expect(roles.DMLSelectRole, select)
return self
+ if TYPE_CHECKING:
+
+ # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> ReturningInsert[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.returning
+
+ @overload
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningInsert[Any]:
+ ...
+
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningInsert[Any]:
+ ...
+
+
+class ReturningInsert(Insert, TypedReturnsRows[_TP]):
+ """Typing-only class that establishes a generic type form of
+ :class:`.Insert` which tracks returned column types.
+
+ This datatype is delivered when calling the
+ :meth:`.Insert.returning` method.
+
+ .. versionadded:: 2.0
+
+ """
+
SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
@@ -1264,6 +1403,113 @@ class Update(DMLWhereBase, ValuesBase):
self._inline = True
return self
+ if TYPE_CHECKING:
+ # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> ReturningUpdate[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.returning
+
+ @overload
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningUpdate[Any]:
+ ...
+
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningUpdate[Any]:
+ ...
+
+
+class ReturningUpdate(Update, TypedReturnsRows[_TP]):
+ """Typing-only class that establishes a generic type form of
+ :class:`.Update` which tracks returned column types.
+
+ This datatype is delivered when calling the
+ :meth:`.Update.returning` method.
+
+ .. versionadded:: 2.0
+
+ """
+
SelfDelete = typing.TypeVar("SelfDelete", bound="Delete")
@@ -1297,3 +1543,111 @@ class Delete(DMLWhereBase, UpdateBase):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
+
+ if TYPE_CHECKING:
+
+ # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> ReturningDelete[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.returning
+
+ @overload
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningDelete[Any]:
+ ...
+
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningDelete[Any]:
+ ...
+
+
+class ReturningDelete(Update, TypedReturnsRows[_TP]):
+ """Typing-only class that establishes a generic type form of
+ :class:`.Delete` which tracks returned column types.
+
+ This datatype is delivered when calling the
+ :meth:`.Delete.returning` method.
+
+ .. versionadded:: 2.0
+
+ """
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 34d5127ab..a29561291 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -54,6 +54,7 @@ from .base import _clone
from .base import _generative
from .base import _NoArg
from .base import Executable
+from .base import Generative
from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
@@ -94,10 +95,7 @@ if typing.TYPE_CHECKING:
from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import NamedFromClause
- from .selectable import ReturnsRows
from .selectable import Select
- from .selectable import TableClause
- from .sqltypes import Boolean
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import _CloneCallableType
@@ -122,7 +120,9 @@ _NT = TypeVar("_NT", bound="_NUMERIC")
_NMT = TypeVar("_NMT", bound="_NUMBER")
-def literal(value, type_=None):
+def literal(
+ value: Any, type_: Optional[_TypeEngineArgument[_T]] = None
+) -> BindParameter[_T]:
r"""Return a literal clause, bound to a bind parameter.
Literal clauses are created automatically when non-
@@ -144,7 +144,9 @@ def literal(value, type_=None):
return coercions.expect(roles.LiteralValueRole, value, type_=type_)
-def literal_column(text, type_=None):
+def literal_column(
+ text: str, type_: Optional[_TypeEngineArgument[_T]] = None
+) -> ColumnClause[_T]:
r"""Produce a :class:`.ColumnClause` object that has the
:paramref:`_expression.column.is_literal` flag set to True.
@@ -316,6 +318,7 @@ class ClauseElement(
is_selectable = False
is_dml = False
_is_column_element = False
+ _is_keyed_column_element = False
_is_table = False
_is_textual = False
_is_from_clause = False
@@ -342,7 +345,7 @@ class ClauseElement(
if typing.TYPE_CHECKING:
def get_children(
- self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any
+ self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any
) -> Iterable[ClauseElement]:
...
@@ -455,7 +458,7 @@ class ClauseElement(
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptions,
- ) -> Result:
+ ) -> Result[Any]:
if self.supports_execution:
if TYPE_CHECKING:
assert isinstance(self, Executable)
@@ -833,13 +836,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
def in_(
self,
- other: Union[Sequence[Any], BindParameter[Any], Select],
+ other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
) -> BinaryExpression[bool]:
...
def not_in(
self,
- other: Union[Sequence[Any], BindParameter[Any], Select],
+ other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
) -> BinaryExpression[bool]:
...
@@ -1699,6 +1702,14 @@ class ColumnElement(
return self._anon_label(label, add_hash=idx)
+class KeyedColumnElement(ColumnElement[_T]):
+ """ColumnElement where ``.key`` is non-None."""
+
+ _is_keyed_column_element = True
+
+ key: str
+
+
class WrapsColumnExpression(ColumnElement[_T]):
"""Mixin that defines a :class:`_expression.ColumnElement`
as a wrapper with special
@@ -1760,7 +1771,7 @@ class WrapsColumnExpression(ColumnElement[_T]):
SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]")
-class BindParameter(roles.InElementRole, ColumnElement[_T]):
+class BindParameter(roles.InElementRole, KeyedColumnElement[_T]):
r"""Represent a "bound expression".
:class:`.BindParameter` is invoked explicitly using the
@@ -2073,6 +2084,7 @@ class TextClause(
roles.FromClauseRole,
roles.SelectStatementRole,
roles.InElementRole,
+ Generative,
Executable,
DQLDMLClauseElement,
roles.BinaryElementRole[Any],
@@ -4160,7 +4172,7 @@ class FunctionFilter(ColumnElement[_T]):
)
-class NamedColumn(ColumnElement[_T]):
+class NamedColumn(KeyedColumnElement[_T]):
is_literal = False
table: Optional[FromClause] = None
name: str
@@ -4502,7 +4514,7 @@ class ColumnClause(
self.is_literal = is_literal
- def get_children(self, column_tables=False, **kw):
+ def get_children(self, *, column_tables=False, **kw):
# override base get_children() to not return the Table
# or selectable that is parent to this column. Traversals
# expect the columns of tables and subqueries to be leaf nodes.
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 648168235..b827df3df 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -175,7 +175,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
return connection._execute_function(
self, distilled_params, execution_options
)
@@ -623,7 +623,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
joins_implicitly=joins_implicitly,
)
- def select(self) -> "Select":
+ def select(self) -> Select[Any]:
"""Produce a :func:`_expression.select` construct
against this :class:`.FunctionElement`.
@@ -632,7 +632,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
s = select(function_element)
"""
- s = Select(self)
+ s: Select[Any] = Select(self)
if self._execution_options:
s = s.execution_options(**self._execution_options)
return s
@@ -846,7 +846,7 @@ class _FunctionGenerator:
@overload
def __call__(
- self, *c: Any, type_: TypeEngine[_T], **kwargs: Any
+ self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any
) -> Function[_T]:
...
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 231c70a5b..09d4b35ad 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -8,8 +8,6 @@ from __future__ import annotations
from typing import Any
from typing import Generic
-from typing import Iterable
-from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -19,12 +17,7 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _PropagateAttrsType
- from .base import _EntityNamespace
- from .base import ColumnCollection
- from .base import ReadOnlyColumnCollection
- from .elements import ColumnClause
from .elements import Label
- from .elements import NamedColumn
from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import Subquery
@@ -108,13 +101,21 @@ class TruncatedLabelRole(StringRole, SQLRole):
class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
__slots__ = ()
- _role_name = "Column expression or FROM clause"
+ _role_name = (
+ "Column expression, FROM clause, or other columns clause element"
+ )
@property
def _select_iterable(self) -> _SelectIterable:
raise NotImplementedError()
+class TypedColumnsClauseRole(Generic[_T], SQLRole):
+ """element-typed form of ColumnsClauseRole"""
+
+ __slots__ = ()
+
+
class LimitOffsetRole(SQLRole):
__slots__ = ()
_role_name = "LIMIT / OFFSET expression"
@@ -161,7 +162,7 @@ class WhereHavingRole(OnClauseRole):
_role_name = "SQL expression for WHERE/HAVING role"
-class ExpressionElementRole(Generic[_T], SQLRole):
+class ExpressionElementRole(TypedColumnsClauseRole[_T]):
# note when using generics for ExpressionElementRole,
# the generic type needs to be in
# sqlalchemy.sql.coercions._impl_lookup mapping also.
@@ -212,39 +213,11 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
named_with_column: bool
- if TYPE_CHECKING:
-
- @util.ro_non_memoized_property
- def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
- ...
-
- @util.ro_non_memoized_property
- def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
- ...
-
- @util.ro_non_memoized_property
- def entity_namespace(self) -> _EntityNamespace:
- ...
-
- @util.ro_non_memoized_property
- def _hide_froms(self) -> Iterable[FromClause]:
- ...
-
- @util.ro_non_memoized_property
- def _from_objects(self) -> List[FromClause]:
- ...
-
class StrictFromClauseRole(FromClauseRole):
__slots__ = ()
# does not allow text() or select() objects
- if TYPE_CHECKING:
-
- @util.ro_non_memoized_property
- def description(self) -> str:
- ...
-
class AnonymizedFromClauseRole(StrictFromClauseRole):
__slots__ = ()
@@ -317,16 +290,6 @@ class DMLTableRole(FromClauseRole):
__slots__ = ()
_role_name = "subject table for an INSERT, UPDATE or DELETE"
- if TYPE_CHECKING:
-
- @util.ro_non_memoized_property
- def primary_key(self) -> Iterable[NamedColumn[Any]]:
- ...
-
- @util.ro_non_memoized_property
- def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
- ...
-
class DMLColumnRole(SQLRole):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 52ba60a62..27456d2be 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -86,7 +86,6 @@ if typing.TYPE_CHECKING:
from ._typing import _InfoType
from ._typing import _TextCoercedExpressionArgument
from ._typing import _TypeEngineArgument
- from .base import ColumnCollection
from .base import DedupeColumnCollection
from .base import ReadOnlyColumnCollection
from .compiler import DDLCompiler
@@ -97,9 +96,7 @@ if typing.TYPE_CHECKING:
from .visitors import anon_map
from ..engine import Connection
from ..engine import Engine
- from ..engine.cursor import CursorResult
from ..engine.interfaces import _CoreMultiExecuteParams
- from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptionsParameter
from ..engine.interfaces import ExecutionContext
from ..engine.mock import MockConnection
@@ -2609,8 +2606,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
:class:`_schema.Table`.
"""
-
- return table.columns.corresponding_column(self.column)
+ # our column is a Column, and any subquery etc. proxying us
+ # would be doing so via another Column, so that's what would
+ # be returned here
+ return table.columns.corresponding_column(self.column) # type: ignore
@util.memoized_property
def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 9d4d1d6c7..b08f13f99 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -23,6 +23,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
+from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
@@ -46,6 +47,8 @@ from . import traversals
from . import type_api
from . import visitors
from ._typing import _ColumnsClauseArgument
+from ._typing import _no_kw
+from ._typing import _TP
from ._typing import is_column_element
from ._typing import is_select_statement
from ._typing import is_subquery
@@ -103,9 +106,20 @@ if TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _FromClauseArgument
from ._typing import _JoinTargetArgument
+ from ._typing import _MAYBE_ENTITY
+ from ._typing import _NOT_ENTITY
from ._typing import _OnClauseArgument
from ._typing import _SelectStatementForCompoundArgument
+ from ._typing import _T0
+ from ._typing import _T1
+ from ._typing import _T2
+ from ._typing import _T3
+ from ._typing import _T4
+ from ._typing import _T5
+ from ._typing import _T6
+ from ._typing import _T7
from ._typing import _TextCoercedExpressionArgument
+ from ._typing import _TypedColumnClauseArgument as _TCCA
from ._typing import _TypeEngineArgument
from .base import _AmbiguousTableNameMap
from .base import ExecutableOption
@@ -115,14 +129,13 @@ if TYPE_CHECKING:
from .dml import Delete
from .dml import Insert
from .dml import Update
+ from .elements import KeyedColumnElement
from .elements import NamedColumn
from .elements import TextClause
from .functions import Function
- from .schema import Column
from .schema import ForeignKey
from .schema import ForeignKeyConstraint
from .type_api import TypeEngine
- from .util import ClauseAdapter
from .visitors import _CloneCallableType
@@ -245,6 +258,14 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
raise NotImplementedError()
+class ExecutableReturnsRows(Executable, ReturnsRows):
+ """base for executable statements that return rows."""
+
+
+class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]):
+ """base for executable statements that return rows."""
+
+
SelfSelectable = TypeVar("SelfSelectable", bound="Selectable")
@@ -293,8 +314,8 @@ class Selectable(ReturnsRows):
)
def corresponding_column(
- self, column: ColumnElement[Any], require_embedded: bool = False
- ) -> Optional[ColumnElement[Any]]:
+ self, column: KeyedColumnElement[Any], require_embedded: bool = False
+ ) -> Optional[KeyedColumnElement[Any]]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from the
:attr:`_expression.Selectable.exported_columns`
@@ -593,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_use_schema_map = False
- def select(self) -> Select:
+ def select(self) -> Select[Any]:
r"""Return a SELECT of this :class:`_expression.FromClause`.
@@ -795,7 +816,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
)
@util.ro_non_memoized_property
- def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
+ def exported_columns(
+ self,
+ ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`.
@@ -817,7 +840,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.c
@util.ro_non_memoized_property
- def columns(self) -> ReadOnlyColumnCollection[str, Any]:
+ def columns(
+ self,
+ ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""A named-based collection of :class:`_expression.ColumnElement`
objects maintained by this :class:`_expression.FromClause`.
@@ -833,7 +858,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.c
@util.ro_memoized_property
- def c(self) -> ReadOnlyColumnCollection[str, Any]:
+ def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""
A synonym for :attr:`.FromClause.columns`
@@ -1223,7 +1248,7 @@ class Join(roles.DMLTableRole, FromClause):
@util.preload_module("sqlalchemy.sql.util")
def _populate_column_collection(self):
sqlutil = util.preloaded.sql_util
- columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [
+ columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [
c for c in self.right.c
]
@@ -1458,7 +1483,7 @@ class Join(roles.DMLTableRole, FromClause):
"join explicitly." % (a.description, b.description)
)
- def select(self) -> "Select":
+ def select(self) -> Select[Any]:
r"""Create a :class:`_expression.Select` from this
:class:`_expression.Join`.
@@ -2764,6 +2789,7 @@ class Subquery(AliasedReturnsRows):
cls, selectable: SelectBase, name: Optional[str] = None
) -> Subquery:
"""Return a :class:`.Subquery` object."""
+
return coercions.expect(
roles.SelectStatementRole, selectable
).subquery(name=name)
@@ -3216,7 +3242,6 @@ class SelectBase(
roles.CompoundElementRole,
roles.InElementRole,
HasCTE,
- Executable,
SupportsCloneAnnotations,
Selectable,
):
@@ -3239,7 +3264,9 @@ class SelectBase(
self._reset_memoizations()
@util.ro_non_memoized_property
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set.
@@ -3284,7 +3311,9 @@ class SelectBase(
raise NotImplementedError()
@property
- def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
+ def exported_columns(
+ self,
+ ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`, not including
@@ -3377,7 +3406,7 @@ class SelectBase(
def as_scalar(self):
return self.scalar_subquery()
- def exists(self):
+ def exists(self) -> Exists:
"""Return an :class:`_sql.Exists` representation of this selectable,
which can be used as a column expression.
@@ -3394,7 +3423,7 @@ class SelectBase(
"""
return Exists(self)
- def scalar_subquery(self):
+ def scalar_subquery(self) -> ScalarSelect[Any]:
"""Return a 'scalar' representation of this selectable, which can be
used as a column expression.
@@ -3607,7 +3636,7 @@ SelfGenerativeSelect = typing.TypeVar(
)
-class GenerativeSelect(SelectBase):
+class GenerativeSelect(SelectBase, Generative):
"""Base class for SELECT statements where additional elements can be
added.
@@ -4128,7 +4157,7 @@ class _CompoundSelectKeyword(Enum):
INTERSECT_ALL = "INTERSECT ALL"
-class CompoundSelect(HasCompileState, GenerativeSelect):
+class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
SELECT-based set operations.
@@ -4293,7 +4322,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return self.selects[0]._all_selected_columns
@util.ro_non_memoized_property
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -4343,7 +4374,10 @@ class SelectState(util.MemoizedSlots, CompileState):
...
def __init__(
- self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any
+ self,
+ statement: Select[Any],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
):
self.statement = statement
self.from_clauses = statement._from_obj
@@ -4369,7 +4403,7 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def get_column_descriptions(
- cls, statement: Select
+ cls, statement: Select[Any]
) -> List[Dict[str, Any]]:
return [
{
@@ -4384,12 +4418,14 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def from_statement(
- cls, statement: Select, from_statement: ReturnsRows
- ) -> Any:
+ cls, statement: Select[Any], from_statement: ExecutableReturnsRows
+ ) -> ExecutableReturnsRows:
cls._plugin_not_implemented()
@classmethod
- def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]:
+ def get_columns_clause_froms(
+ cls, statement: Select[Any]
+ ) -> List[FromClause]:
return cls._normalize_froms(
itertools.chain.from_iterable(
element._from_objects for element in statement._raw_columns
@@ -4439,7 +4475,7 @@ class SelectState(util.MemoizedSlots, CompileState):
return go
- def _get_froms(self, statement: Select) -> List[FromClause]:
+ def _get_froms(self, statement: Select[Any]) -> List[FromClause]:
ambiguous_table_name_map: _AmbiguousTableNameMap
self._ambiguous_table_name_map = ambiguous_table_name_map = {}
@@ -4467,7 +4503,7 @@ class SelectState(util.MemoizedSlots, CompileState):
def _normalize_froms(
cls,
iterable_of_froms: Iterable[FromClause],
- check_statement: Optional[Select] = None,
+ check_statement: Optional[Select[Any]] = None,
ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
) -> List[FromClause]:
"""given an iterable of things to select FROM, reduce them to what
@@ -4615,7 +4651,7 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def determine_last_joined_entity(
- cls, stmt: Select
+ cls, stmt: Select[Any]
) -> Optional[_JoinTargetElement]:
if stmt._setup_joins:
return stmt._setup_joins[-1][0]
@@ -4623,7 +4659,7 @@ class SelectState(util.MemoizedSlots, CompileState):
return None
@classmethod
- def all_selected_columns(cls, statement: Select) -> _SelectIterable:
+ def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable:
return [c for c in _select_iterables(statement._raw_columns)]
def _setup_joins(
@@ -4876,7 +4912,7 @@ class _MemoizedSelectEntities(
return c # type: ignore
@classmethod
- def _generate_for_statement(cls, select_stmt: Select) -> None:
+ def _generate_for_statement(cls, select_stmt: Select[Any]) -> None:
if select_stmt._setup_joins or select_stmt._with_options:
self = _MemoizedSelectEntities()
self._raw_columns = select_stmt._raw_columns
@@ -4888,7 +4924,7 @@ class _MemoizedSelectEntities(
select_stmt._setup_joins = select_stmt._with_options = ()
-SelfSelect = typing.TypeVar("SelfSelect", bound="Select")
+SelfSelect = typing.TypeVar("SelfSelect", bound="Select[Any]")
class Select(
@@ -4898,6 +4934,7 @@ class Select(
HasCompileState,
_SelectFromElements,
GenerativeSelect,
+ TypedReturnsRows[_TP],
):
"""Represents a ``SELECT`` statement.
@@ -4973,7 +5010,7 @@ class Select(
_compile_state_factory: Type[SelectState]
@classmethod
- def _create_raw_select(cls, **kw: Any) -> Select:
+ def _create_raw_select(cls, **kw: Any) -> Select[Any]:
"""Create a :class:`.Select` using raw ``__new__`` with no coercions.
Used internally to build up :class:`.Select` constructs with
@@ -4985,7 +5022,7 @@ class Select(
stmt.__dict__.update(kw)
return stmt
- def __init__(self, *entities: _ColumnsClauseArgument):
+ def __init__(self, *entities: _ColumnsClauseArgument[Any]):
r"""Construct a new :class:`_expression.Select`.
The public constructor for :class:`_expression.Select` is the
@@ -5013,7 +5050,9 @@ class Select(
cols = list(elem._select_iterable)
return cols[0].type
- def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect:
+ def filter(
+ self: SelfSelect, *criteria: _ColumnExpressionArgument[bool]
+ ) -> SelfSelect:
"""A synonym for the :meth:`_future.Select.where` method."""
return self.where(*criteria)
@@ -5032,7 +5071,28 @@ class Select(
return self._raw_columns[0]
- def filter_by(self, **kwargs):
+ if TYPE_CHECKING:
+
+ @overload
+ def scalar_subquery(
+ self: Select[Tuple[_MAYBE_ENTITY]],
+ ) -> ScalarSelect[Any]:
+ ...
+
+ @overload
+ def scalar_subquery(
+ self: Select[Tuple[_NOT_ENTITY]],
+ ) -> ScalarSelect[_NOT_ENTITY]:
+ ...
+
+ @overload
+ def scalar_subquery(self) -> ScalarSelect[Any]:
+ ...
+
+ def scalar_subquery(self) -> ScalarSelect[Any]:
+ ...
+
+ def filter_by(self: SelfSelect, **kwargs: Any) -> SelfSelect:
r"""apply the given filtering criterion as a WHERE clause
to this select.
@@ -5046,7 +5106,7 @@ class Select(
return self.filter(*clauses)
@property
- def column_descriptions(self):
+ def column_descriptions(self) -> Any:
"""Return a :term:`plugin-enabled` 'column descriptions' structure
referring to the columns which are SELECTed by this statement.
@@ -5089,7 +5149,9 @@ class Select(
meth = SelectState.get_plugin_class(self).get_column_descriptions
return meth(self)
- def from_statement(self, statement):
+ def from_statement(
+ self, statement: ExecutableReturnsRows
+ ) -> ExecutableReturnsRows:
"""Apply the columns which this :class:`.Select` would select
onto another statement.
@@ -5410,7 +5472,7 @@ class Select(
)
@property
- def inner_columns(self):
+ def inner_columns(self) -> _SelectIterable:
"""An iterator of all :class:`_expression.ColumnElement`
expressions which would
be rendered into the columns clause of the resulting SELECT statement.
@@ -5487,18 +5549,19 @@ class Select(
self._reset_memoizations()
- def get_children(self, **kwargs):
+ def get_children(self, **kw: Any) -> Iterable[ClauseElement]:
return itertools.chain(
super(Select, self).get_children(
- omit_attrs=("_from_obj", "_correlate", "_correlate_except")
+ omit_attrs=("_from_obj", "_correlate", "_correlate_except"),
+ **kw,
),
self._iterate_from_elements(),
)
@_generative
def add_columns(
- self: SelfSelect, *columns: _ColumnsClauseArgument
- ) -> SelfSelect:
+ self, *columns: _ColumnsClauseArgument[Any]
+ ) -> Select[Any]:
"""Return a new :func:`_expression.select` construct with
the given column expressions added to its columns clause.
@@ -5523,7 +5586,7 @@ class Select(
return self
def _set_entities(
- self, entities: Iterable[_ColumnsClauseArgument]
+ self, entities: Iterable[_ColumnsClauseArgument[Any]]
) -> None:
self._raw_columns = [
coercions.expect(
@@ -5538,7 +5601,7 @@ class Select(
"be removed in a future release. Please use "
":meth:`_expression.Select.add_columns`",
)
- def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect:
+ def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]:
"""Return a new :func:`_expression.select` construct with
the given column expression added to its columns clause.
@@ -5555,9 +5618,7 @@ class Select(
return self.add_columns(column)
@util.preload_module("sqlalchemy.sql.util")
- def reduce_columns(
- self: SelfSelect, only_synonyms: bool = True
- ) -> SelfSelect:
+ def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]:
"""Return a new :func:`_expression.select` construct with redundantly
named, equivalently-valued columns removed from the columns clause.
@@ -5580,20 +5641,115 @@ class Select(
all columns that are equivalent to another are removed.
"""
- return self.with_only_columns(
+ woc: Select[Any]
+ woc = self.with_only_columns(
*util.preloaded.sql_util.reduce_columns(
self._all_selected_columns,
only_synonyms=only_synonyms,
*(self._where_criteria + self._from_obj),
)
)
+ return woc
+
+ # START OVERLOADED FUNCTIONS self.with_only_columns Select 8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_sel_v1_overloads.py
+
+ @overload
+ def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> Select[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> Select[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.with_only_columns
+
+ @overload
+ def with_only_columns(
+ self,
+ *columns: _ColumnsClauseArgument[Any],
+ maintain_column_froms: bool = False,
+ **__kw: Any,
+ ) -> Select[Any]:
+ ...
@_generative
def with_only_columns(
- self: SelfSelect,
- *columns: _ColumnsClauseArgument,
+ self,
+ *columns: _ColumnsClauseArgument[Any],
maintain_column_froms: bool = False,
- ) -> SelfSelect:
+ **__kw: Any,
+ ) -> Select[Any]:
r"""Return a new :func:`_expression.select` construct with its columns
clause replaced with the given columns.
@@ -5647,6 +5803,9 @@ class Select(
""" # noqa: E501
+ if __kw:
+ raise _no_kw()
+
# memoizations should be cleared here as of
# I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this
# is the case for now.
@@ -5915,7 +6074,9 @@ class Select(
return self
@HasMemoized_ro_memoized_attribute
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -6215,7 +6376,7 @@ class ScalarSelect(
by this :class:`_expression.ScalarSelect`.
"""
- self.element = cast(Select, self.element).where(crit)
+ self.element = cast("Select[Any]", self.element).where(crit)
return self
@overload
@@ -6269,7 +6430,9 @@ class ScalarSelect(
"""
- self.element = cast(Select, self.element).correlate(*fromclauses)
+ self.element = cast("Select[Any]", self.element).correlate(
+ *fromclauses
+ )
return self
@_generative
@@ -6307,7 +6470,7 @@ class ScalarSelect(
"""
- self.element = cast(Select, self.element).correlate_except(
+ self.element = cast("Select[Any]", self.element).correlate_except(
*fromclauses
)
return self
@@ -6331,12 +6494,18 @@ class Exists(UnaryExpression[bool]):
def __init__(
self,
__argument: Optional[
- Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]]
] = None,
):
+ s: ScalarSelect[Any]
+
+ # TODO: this seems like we should be using coercions for this
if __argument is None:
s = Select(literal_column("*")).scalar_subquery()
- elif isinstance(__argument, (SelectBase, ScalarSelect)):
+ elif isinstance(__argument, SelectBase):
+ s = __argument.scalar_subquery()
+ s._propagate_attrs = __argument._propagate_attrs
+ elif isinstance(__argument, ScalarSelect):
s = __argument
else:
s = Select(__argument).scalar_subquery()
@@ -6358,7 +6527,7 @@ class Exists(UnaryExpression[bool]):
element = fn(element)
return element.self_group(against=operators.exists)
- def select(self) -> Select:
+ def select(self) -> Select[Any]:
r"""Return a SELECT of this :class:`_expression.Exists`.
e.g.::
@@ -6452,7 +6621,7 @@ class Exists(UnaryExpression[bool]):
SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect")
-class TextualSelect(SelectBase):
+class TextualSelect(SelectBase, Executable, Generative):
"""Wrap a :class:`_expression.TextClause` construct within a
:class:`_expression.SelectBase`
interface.
@@ -6503,7 +6672,9 @@ class TextualSelect(SelectBase):
self.positional = positional
@HasMemoized_ro_memoized_attribute
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, KeyedColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d08fef60a..8c45ba410 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -50,6 +50,7 @@ from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Grouping
+from .elements import KeyedColumnElement
from .elements import Label
from .elements import Null
from .elements import UnaryExpression
@@ -72,9 +73,7 @@ if typing.TYPE_CHECKING:
from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
from .elements import TextClause
- from .roles import FromClauseRole
from .selectable import _JoinTargetElement
- from .selectable import _OnClauseElement
from .selectable import _SelectIterable
from .selectable import Selectable
from .visitors import _TraverseCallableType
@@ -569,7 +568,7 @@ class _repr_row(_repr_base):
__slots__ = ("row",)
- def __init__(self, row: "Row", max_chars: int = 300):
+ def __init__(self, row: "Row[Any]", max_chars: int = 300):
self.row = row
self.max_chars = max_chars
@@ -1068,7 +1067,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
col = col._annotations["adapt_column"]
if TYPE_CHECKING:
- assert isinstance(col, ColumnElement)
+ assert isinstance(col, KeyedColumnElement)
if self.adapt_from_selectables and col not in self.equivalents:
for adp in self.adapt_from_selectables:
@@ -1078,7 +1077,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
return None
if TYPE_CHECKING:
- assert isinstance(col, ColumnElement)
+ assert isinstance(col, KeyedColumnElement)
if self.include_fn and not self.include_fn(col):
return None
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index e0a66fbcf..88586d834 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -450,7 +450,7 @@ class HasTraverseInternals:
@util.preload_module("sqlalchemy.sql.traversals")
def get_children(
- self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[HasTraverseInternals]:
r"""Return immediate child :class:`.visitors.HasTraverseInternals`
elements of this :class:`.visitors.HasTraverseInternals`.
@@ -594,7 +594,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
if typing.TYPE_CHECKING:
def get_children(
- self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[ExternallyTraversible]:
...
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 49c5d693a..da3fbc718 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -18,6 +18,7 @@ import hashlib
import inspect
import itertools
import operator
+import os
import re
import sys
import textwrap
@@ -32,6 +33,7 @@ from typing import Generic
from typing import Iterator
from typing import List
from typing import Mapping
+from typing import no_type_check
from typing import NoReturn
from typing import Optional
from typing import overload
@@ -2106,3 +2108,45 @@ def has_compiled_ext(raise_=False):
)
else:
return False
+
+
+@no_type_check
+def console_scripts(
+ path: str, options: dict, ignore_output: bool = False
+) -> None:
+
+ import subprocess
+ import shlex
+ from pathlib import Path
+
+ is_posix = os.name == "posix"
+
+ entrypoint_name = options["entrypoint"]
+
+ for entry in compat.importlib_metadata_get("console_scripts"):
+ if entry.name == entrypoint_name:
+ impl = entry
+ break
+ else:
+ raise Exception(
+ f"Could not find entrypoint console_scripts.{entrypoint_name}"
+ )
+ cmdline_options_str = options.get("options", "")
+ cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
+ path
+ ]
+
+ kw = {}
+ if ignore_output:
+ kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
+
+ subprocess.run(
+ [
+ sys.executable,
+ "-c",
+ "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
+ ]
+ + cmdline_options_list,
+ cwd=Path(__file__).parent.parent,
+ **kw,
+ )
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index d192dc06b..2a215c4f1 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -14,7 +14,7 @@ from typing import Type
from typing import TypeVar
from typing import Union
-from typing_extensions import NotRequired as NotRequired # noqa
+from typing_extensions import NotRequired as NotRequired
from . import compat