summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-25 17:08:48 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-30 14:04:52 -0400
commit4e754a8914a1c2c16c97bdf363d2e24bfa823730 (patch)
treedb723242b4e4c0d4c7f15c167857dd79fdfa6ccb
parentdba480ebaf89c0b5ea787661583de9da3928920f (diff)
downloadsqlalchemy-4e754a8914a1c2c16c97bdf363d2e24bfa823730.tar.gz
pep-484: the pep-484ening, SQL part three
hitting DML which is causing us to open up the ColumnCollection structure a bit, as we do put anonymous column expressions with None here. However, we still want Table /TableClause to have named column collections that don't return None, so parametrize the "key" in this collection also. * rename some "immutable" elements to "readonly". we change the contents of immutablecolumncollection underneath, so it's not "immutable" Change-Id: I2593995a4e5c6eae874bed5bf76117198be8ae97
-rw-r--r--lib/sqlalchemy/cyextension/immutabledict.pyx16
-rw-r--r--lib/sqlalchemy/engine/default.py4
-rw-r--r--lib/sqlalchemy/ext/hybrid.py44
-rw-r--r--lib/sqlalchemy/orm/attributes.py3
-rw-r--r--lib/sqlalchemy/orm/interfaces.py10
-rw-r--r--lib/sqlalchemy/orm/mapper.py9
-rw-r--r--lib/sqlalchemy/orm/persistence.py2
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--lib/sqlalchemy/orm/session.py4
-rw-r--r--lib/sqlalchemy/orm/state.py2
-rw-r--r--lib/sqlalchemy/orm/util.py1
-rw-r--r--lib/sqlalchemy/sql/_dml_constructors.py90
-rw-r--r--lib/sqlalchemy/sql/_elements_constructors.py58
-rw-r--r--lib/sqlalchemy/sql/_selectable_constructors.py4
-rw-r--r--lib/sqlalchemy/sql/_typing.py65
-rw-r--r--lib/sqlalchemy/sql/annotation.py16
-rw-r--r--lib/sqlalchemy/sql/base.py110
-rw-r--r--lib/sqlalchemy/sql/coercions.py27
-rw-r--r--lib/sqlalchemy/sql/compiler.py9
-rw-r--r--lib/sqlalchemy/sql/crud.py11
-rw-r--r--lib/sqlalchemy/sql/dml.py267
-rw-r--r--lib/sqlalchemy/sql/elements.py344
-rw-r--r--lib/sqlalchemy/sql/functions.py7
-rw-r--r--lib/sqlalchemy/sql/roles.py104
-rw-r--r--lib/sqlalchemy/sql/schema.py95
-rw-r--r--lib/sqlalchemy/sql/selectable.py238
-rw-r--r--lib/sqlalchemy/sql/util.py5
-rw-r--r--lib/sqlalchemy/util/__init__.py7
-rw-r--r--lib/sqlalchemy/util/_collections.py10
-rw-r--r--lib/sqlalchemy/util/_py_collections.py28
-rw-r--r--lib/sqlalchemy/util/langhelpers.py7
-rw-r--r--pyproject.toml1
-rw-r--r--test/base/test_utils.py38
-rw-r--r--test/profiles.txt22
-rw-r--r--test/sql/test_quote.py44
-rw-r--r--test/sql/test_returning.py110
-rw-r--r--test/sql/test_selectable.py18
37 files changed, 1146 insertions, 688 deletions
diff --git a/lib/sqlalchemy/cyextension/immutabledict.pyx b/lib/sqlalchemy/cyextension/immutabledict.pyx
index 861e7574d..6ab255311 100644
--- a/lib/sqlalchemy/cyextension/immutabledict.pyx
+++ b/lib/sqlalchemy/cyextension/immutabledict.pyx
@@ -1,18 +1,24 @@
from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size
+def _readonly_fn(obj):
+ raise TypeError(
+ "%s object is immutable and/or readonly" % obj.__class__.__name__)
+
+
def _immutable_fn(obj):
- raise TypeError("%s object is immutable" % obj.__class__.__name__)
+ raise TypeError(
+ "%s object is immutable" % obj.__class__.__name__)
-class ImmutableContainer:
+class ReadOnlyContainer:
__slots__ = ()
- def _immutable(self, *a,**kw):
- _immutable_fn(self)
+ def _readonly(self, *a,**kw):
+ _readonly_fn(self)
- __delitem__ = __setitem__ = __setattr__ = _immutable
+ __delitem__ = __setitem__ = __setattr__ = _readonly
class ImmutableDictBase(dict):
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 65cb57e10..85ce91deb 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -54,7 +54,6 @@ from ..sql import expression
from ..sql._typing import is_tuple_type
from ..sql.compiler import DDLCompiler
from ..sql.compiler import SQLCompiler
-from ..sql.elements import ColumnClause
from ..sql.elements import quoted_name
from ..sql.schema import default_is_scalar
@@ -88,6 +87,7 @@ if typing.TYPE_CHECKING:
from ..sql.dml import DMLState
from ..sql.dml import UpdateBase
from ..sql.elements import BindParameter
+ from ..sql.roles import ColumnsClauseRole
from ..sql.schema import Column
from ..sql.schema import ColumnDefault
from ..sql.type_api import _BindProcessorType
@@ -1166,7 +1166,7 @@ class DefaultExecutionContext(ExecutionContext):
return ()
@util.memoized_property
- def returning_cols(self) -> Optional[Sequence[ColumnClause[Any]]]:
+ def returning_cols(self) -> Optional[Sequence[ColumnsClauseRole]]:
if TYPE_CHECKING:
assert isinstance(self.compiled, SQLCompiler)
return self.compiled.returning
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index 92b3ce54f..5ca8b03dd 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -813,6 +813,7 @@ from typing import Generic
from typing import List
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
@@ -824,15 +825,20 @@ from ..orm import attributes
from ..orm import InspectionAttrExtensionType
from ..orm import interfaces
from ..orm import ORMDescriptor
-from ..sql._typing import is_has_column_element_clause_element
+from ..sql._typing import is_has_clause_element
from ..sql.elements import ColumnElement
from ..sql.elements import SQLCoreOperations
from ..util.typing import Literal
from ..util.typing import Protocol
+
if TYPE_CHECKING:
from ..orm.util import AliasedInsp
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _DMLColumnArgument
+ from ..sql._typing import _HasClauseElement
from ..sql.operators import OperatorType
+ from ..sql.roles import ColumnsClauseRole
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
@@ -878,10 +884,12 @@ class _HybridSetterType(Protocol[_T_con]):
...
-class _HybridUpdaterType(Protocol[_T]):
+class _HybridUpdaterType(Protocol[_T_con]):
def __call__(
- self, cls: Type[Any], value: Union[_T, SQLCoreOperations[_T]]
- ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
+ self,
+ cls: Type[Any],
+ value: Union[_T_con, _ColumnExpressionArgument[_T_con]],
+ ) -> List[Tuple[_DMLColumnArgument, Any]]:
...
@@ -890,8 +898,10 @@ class _HybridDeleterType(Protocol[_T_co]):
...
-class _HybridExprCallableType(Protocol[_T]):
- def __call__(self, cls: Any) -> SQLCoreOperations[_T]:
+class _HybridExprCallableType(Protocol[_T_co]):
+ def __call__(
+ self, cls: Any
+ ) -> Union[_HasClauseElement, ColumnElement[_T_co]]:
...
@@ -1273,17 +1283,21 @@ class Comparator(interfaces.PropComparator[_T]):
:class:`~.orm.interfaces.PropComparator`
classes for usage with hybrids."""
- def __init__(self, expression: SQLCoreOperations[_T]):
+ def __init__(
+ self, expression: Union[_HasClauseElement, ColumnElement[_T]]
+ ):
self.expression = expression
- def __clause_element__(self) -> ColumnElement[_T]:
+ def __clause_element__(self) -> ColumnsClauseRole:
expr = self.expression
- if is_has_column_element_clause_element(expr):
- expr = expr.__clause_element__()
+ if is_has_clause_element(expr):
+ ret_expr = expr.__clause_element__()
+ else:
+ if TYPE_CHECKING:
+ assert isinstance(expr, ColumnElement)
+ ret_expr = expr
- elif TYPE_CHECKING:
- assert isinstance(expr, ColumnElement)
- return expr
+ return ret_expr
@util.non_memoized_property
def property(self) -> Any:
@@ -1298,7 +1312,7 @@ class ExprComparator(Comparator[_T]):
def __init__(
self,
cls: Type[Any],
- expression: SQLCoreOperations[_T],
+ expression: Union[_HasClauseElement, ColumnElement[_T]],
hybrid: hybrid_property[_T],
):
self.cls = cls
@@ -1314,7 +1328,7 @@ class ExprComparator(Comparator[_T]):
def _bulk_update_tuples(
self, value: Any
- ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
+ ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
if isinstance(self.expression, attributes.QueryableAttribute):
return self.expression._bulk_update_tuples(value)
elif self.hybrid.update_expr is not None:
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 2b6ca400e..3d3492710 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -68,6 +68,7 @@ from ..sql import traversals
from ..sql import visitors
if typing.TYPE_CHECKING:
+ from ..sql.dml import _DMLColumnElement
from ..sql.elements import ColumnElement
from ..sql.elements import SQLCoreOperations
@@ -281,7 +282,7 @@ class QueryableAttribute(
def _bulk_update_tuples(
self, value: Any
- ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
+ ) -> List[Tuple[_DMLColumnElement, Any]]:
"""Return setter tuples for a bulk UPDATE."""
return self.comparator._bulk_update_tuples(value)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index d79774187..b4228323b 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -24,6 +24,7 @@ from typing import Any
from typing import cast
from typing import List
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -50,7 +51,6 @@ from .. import util
from ..sql import operators
from ..sql import roles
from ..sql import visitors
-from ..sql._typing import _ColumnsClauseElement
from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
from ..sql.elements import SQLCoreOperations
@@ -60,6 +60,8 @@ from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
from .decl_api import RegistryType
+ from ..sql._typing import _ColumnsClauseArgument
+ from ..sql._typing import _DMLColumnArgument
_T = TypeVar("_T", bound=Any)
@@ -90,8 +92,8 @@ class ORMColumnDescription(TypedDict):
name: str
type: Union[Type, TypeEngine]
aliased: bool
- expr: _ColumnsClauseElement
- entity: Optional[_ColumnsClauseElement]
+ expr: _ColumnsClauseArgument
+ entity: Optional[_ColumnsClauseArgument]
class _IntrospectsAnnotations:
@@ -468,7 +470,7 @@ class PropComparator(SQLORMOperations[_T]):
def _bulk_update_tuples(
self, value: Any
- ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
+ ) -> Sequence[Tuple[_DMLColumnArgument, Any]]:
"""Receive a SQL expression that represents a value in the SET
clause of an UPDATE statement.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 7d1fc7643..e463dcdb5 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1749,6 +1749,7 @@ class Mapper(
col.key = col._tq_key_label = key
self.columns.add(col, key)
+
for col in prop.columns + prop._orig_columns:
for col in col.proxy_set:
self._columntoproperty[col] = prop
@@ -2381,7 +2382,7 @@ class Mapper(
yield c
@HasMemoized.memoized_attribute
- def attrs(self) -> util.ImmutableProperties["MapperProperty"]:
+ def attrs(self) -> util.ReadOnlyProperties["MapperProperty"]:
"""A namespace of all :class:`.MapperProperty` objects
associated this mapper.
@@ -2416,7 +2417,7 @@ class Mapper(
"""
self._check_configure()
- return util.ImmutableProperties(self._props)
+ return util.ReadOnlyProperties(self._props)
@HasMemoized.memoized_attribute
def all_orm_descriptors(self):
@@ -2484,7 +2485,7 @@ class Mapper(
:attr:`_orm.Mapper.attrs`
"""
- return util.ImmutableProperties(
+ return util.ReadOnlyProperties(
dict(self.class_manager._all_sqla_attributes())
)
@@ -2571,7 +2572,7 @@ class Mapper(
def _filter_properties(self, type_):
self._check_configure()
- return util.ImmutableProperties(
+ return util.ReadOnlyProperties(
util.OrderedDict(
(k, v) for k, v in self._props.items() if isinstance(v, type_)
)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 6478aac15..f2cddad53 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -2206,7 +2206,7 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState):
if opt._is_criteria_option:
opt.get_global_criteria(extra_criteria_attributes)
- if not statement._preserve_parameter_order and statement._values:
+ if statement._values:
self._resolved_values = dict(self._resolved_values)
new_stmt = sql.Update.__new__(sql.Update)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index ea5d5406e..18a14012f 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -60,7 +60,7 @@ from ..sql import roles
from ..sql import Select
from ..sql import util as sql_util
from ..sql import visitors
-from ..sql._typing import _FromClauseElement
+from ..sql._typing import _FromClauseArgument
from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import _entity_namespace_key
from ..sql.base import _generative
@@ -2018,7 +2018,7 @@ class Query(
@_generative
@_assertions(_no_clauseelement_condition)
def select_from(
- self: SelfQuery, *from_obj: _FromClauseElement
+ self: SelfQuery, *from_obj: _FromClauseArgument
) -> SelfQuery:
r"""Set the FROM clause of this :class:`.Query` explicitly.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 4140d52c5..58820fef6 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -56,7 +56,7 @@ from ..sql import coercions
from ..sql import dml
from ..sql import roles
from ..sql import visitors
-from ..sql._typing import _ColumnsClauseElement
+from ..sql._typing import _ColumnsClauseArgument
from ..sql.base import CompileState
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util.typing import Literal
@@ -2040,7 +2040,7 @@ class Session(_SessionClassMethods):
)
def query(
- self, *entities: "_ColumnsClauseElement", **kwargs: Any
+ self, *entities: "_ColumnsClauseArgument", **kwargs: Any
) -> "Query":
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 58fa3e41a..c3e4e299a 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -122,7 +122,7 @@ class InstanceState(interfaces.InspectionAttrInfo):
since the last flush.
"""
- return util.ImmutableProperties(
+ return util.ReadOnlyProperties(
dict((key, AttributeState(self, key)) for key in self.manager)
)
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 2b49d4400..baca8f547 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -632,7 +632,6 @@ class AliasedInsp(
ORMEntityColumnsClauseRole,
ORMFromClauseRole,
sql_base.HasCacheKey,
- roles.HasFromClauseElement,
InspectionAttr,
MemoizedSlots,
):
diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py
index a8c24413f..835819bac 100644
--- a/lib/sqlalchemy/sql/_dml_constructors.py
+++ b/lib/sqlalchemy/sql/_dml_constructors.py
@@ -112,82 +112,10 @@ def update(table):
object representing the database
table to be updated.
- :param whereclause: Optional SQL expression describing the ``WHERE``
- condition of the ``UPDATE`` statement; is equivalent to using the
- more modern :meth:`~Update.where()` method to specify the ``WHERE``
- clause.
-
- :param values:
- Optional dictionary which specifies the ``SET`` conditions of the
- ``UPDATE``. If left as ``None``, the ``SET``
- conditions are determined from those parameters passed to the
- statement during the execution and/or compilation of the
- statement. When compiled standalone without any parameters,
- the ``SET`` clause generates for all columns.
-
- Modern applications may prefer to use the generative
- :meth:`_expression.Update.values` method to set the values of the
- UPDATE statement.
-
- :param inline:
- if True, SQL defaults present on :class:`_schema.Column` objects via
- the ``default`` keyword will be compiled 'inline' into the statement
- and not pre-executed. This means that their values will not
- be available in the dictionary returned from
- :meth:`_engine.CursorResult.last_updated_params`.
-
- :param preserve_parameter_order: if True, the update statement is
- expected to receive parameters **only** via the
- :meth:`_expression.Update.values` method,
- and they must be passed as a Python
- ``list`` of 2-tuples. The rendered UPDATE statement will emit the SET
- clause for each referenced column maintaining this order.
-
- .. versionadded:: 1.0.10
-
- .. seealso::
-
- :ref:`updates_order_parameters` - illustrates the
- :meth:`_expression.Update.ordered_values` method.
-
- If both ``values`` and compile-time bind parameters are present, the
- compile-time bind parameters override the information specified
- within ``values`` on a per-key basis.
-
- The keys within ``values`` can be either :class:`_schema.Column`
- objects or their string identifiers (specifically the "key" of the
- :class:`_schema.Column`, normally but not necessarily equivalent to
- its "name"). Normally, the
- :class:`_schema.Column` objects used here are expected to be
- part of the target :class:`_schema.Table` that is the table
- to be updated. However when using MySQL, a multiple-table
- UPDATE statement can refer to columns from any of
- the tables referred to in the WHERE clause.
-
- The values referred to in ``values`` are typically:
-
- * a literal data value (i.e. string, number, etc.)
- * a SQL expression, such as a related :class:`_schema.Column`,
- a scalar-returning :func:`_expression.select` construct,
- etc.
-
- When combining :func:`_expression.select` constructs within the
- values clause of an :func:`_expression.update`
- construct, the subquery represented
- by the :func:`_expression.select` should be *correlated* to the
- parent table, that is, providing criterion which links the table inside
- the subquery to the outer table being updated::
-
- users.update().values(
- name=select(addresses.c.email_address).\
- where(addresses.c.user_id==users.c.id).\
- scalar_subquery()
- )
.. seealso::
- :ref:`inserts_and_updates` - SQL Expression
- Language Tutorial
+ :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial`
"""
@@ -210,24 +138,12 @@ def delete(table):
:meth:`_expression.TableClause.delete` method on
:class:`_schema.Table`.
- .. seealso::
-
- :ref:`inserts_and_updates` - in the
- :ref:`1.x tutorial <sqlexpression_toplevel>`
-
- :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial`
-
-
:param table: The table to delete rows from.
- :param whereclause: Optional SQL expression describing the ``WHERE``
- condition of the ``DELETE`` statement; is equivalent to using the
- more modern :meth:`~Delete.where()` method to specify the ``WHERE``
- clause.
-
.. seealso::
- :ref:`deletes` - SQL Expression Tutorial
+ :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial`
+
"""
return Delete(table)
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index aabd3871e..f647ae927 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -48,7 +48,7 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from . import sqltypes
- from ._typing import _ColumnExpression
+ from ._typing import _ColumnExpressionArgument
from ._typing import _TypeEngineArgument
from .elements import BinaryExpression
from .functions import FunctionElement
@@ -58,7 +58,7 @@ if typing.TYPE_CHECKING:
_T = TypeVar("_T")
-def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
+def all_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]:
"""Produce an ALL expression.
For dialects such as that of PostgreSQL, this operator applies
@@ -112,7 +112,7 @@ def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
return CollectionAggregate._create_all(expr)
-def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
+def and_(*clauses: _ColumnExpressionArgument[bool]) -> ColumnElement[bool]:
r"""Produce a conjunction of expressions joined by ``AND``.
E.g.::
@@ -173,7 +173,7 @@ def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
return BooleanClauseList.and_(*clauses)
-def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
+def any_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]:
"""Produce an ANY expression.
For dialects such as that of PostgreSQL, this operator applies
@@ -227,7 +227,7 @@ def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
return CollectionAggregate._create_any(expr)
-def asc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
+def asc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
"""Produce an ascending ``ORDER BY`` clause element.
e.g.::
@@ -266,7 +266,7 @@ def asc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
def collate(
- expression: _ColumnExpression[str], collation: str
+ expression: _ColumnExpressionArgument[str], collation: str
) -> BinaryExpression[str]:
"""Return the clause ``expression COLLATE collation``.
@@ -289,7 +289,7 @@ def collate(
def between(
- expr: _ColumnExpression[_T],
+ expr: _ColumnExpressionArgument[_T],
lower_bound: Any,
upper_bound: Any,
symmetric: bool = False,
@@ -364,17 +364,19 @@ def outparam(
return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
+# mypy insists that BinaryExpression and _HasClauseElement protocol overlap.
+# they do not. at all. bug in mypy?
@overload
-def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]:
+def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: # type: ignore
...
@overload
-def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]:
+def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
...
-def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]:
+def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
The ``~`` operator is also overloaded on all
@@ -646,7 +648,7 @@ def bindparam(
def case(
*whens: Union[
- typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any]
+ typing_Tuple[_ColumnExpressionArgument[bool], Any], Mapping[Any, Any]
],
value: Optional[Any] = None,
else_: Optional[Any] = None,
@@ -775,7 +777,7 @@ def case(
def cast(
- expression: _ColumnExpression[Any],
+ expression: _ColumnExpressionArgument[Any],
type_: _TypeEngineArgument[_T],
) -> Cast[_T]:
r"""Produce a ``CAST`` expression.
@@ -932,7 +934,7 @@ def column(
return ColumnClause(text, type_, is_literal, _selectable)
-def desc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
+def desc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
"""Produce a descending ``ORDER BY`` clause element.
e.g.::
@@ -971,7 +973,7 @@ def desc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
return UnaryExpression._create_desc(column)
-def distinct(expr: _ColumnExpression[_T]) -> UnaryExpression[_T]:
+def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
"""Produce an column-expression-level unary ``DISTINCT`` clause.
This applies the ``DISTINCT`` keyword to an individual column
@@ -1010,7 +1012,7 @@ def distinct(expr: _ColumnExpression[_T]) -> UnaryExpression[_T]:
return UnaryExpression._create_distinct(expr)
-def extract(field: str, expr: _ColumnExpression[Any]) -> Extract:
+def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract:
"""Return a :class:`.Extract` construct.
This is typically available as :func:`.extract`
@@ -1090,7 +1092,7 @@ def false() -> False_:
def funcfilter(
- func: FunctionElement[_T], *criterion: _ColumnExpression[bool]
+ func: FunctionElement[_T], *criterion: _ColumnExpressionArgument[bool]
) -> FunctionFilter[_T]:
"""Produce a :class:`.FunctionFilter` object against a function.
@@ -1122,7 +1124,7 @@ def funcfilter(
def label(
name: str,
- element: _ColumnExpression[_T],
+ element: _ColumnExpressionArgument[_T],
type_: Optional[_TypeEngineArgument[_T]] = None,
) -> "Label[_T]":
"""Return a :class:`Label` object for the
@@ -1149,7 +1151,7 @@ def null() -> Null:
return Null._instance()
-def nulls_first(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
+def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
"""Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression.
:func:`.nulls_first` is intended to modify the expression produced
@@ -1193,7 +1195,7 @@ def nulls_first(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
return UnaryExpression._create_nulls_first(column)
-def nulls_last(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
+def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
"""Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression.
:func:`.nulls_last` is intended to modify the expression produced
@@ -1237,7 +1239,7 @@ def nulls_last(column: _ColumnExpression[_T]) -> UnaryExpression[_T]:
return UnaryExpression._create_nulls_last(column)
-def or_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
+def or_(*clauses: _ColumnExpressionArgument[bool]) -> ColumnElement[bool]:
"""Produce a conjunction of expressions joined by ``OR``.
E.g.::
@@ -1291,10 +1293,16 @@ def or_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
def over(
element: FunctionElement[_T],
partition_by: Optional[
- Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ Union[
+ Iterable[_ColumnExpressionArgument[Any]],
+ _ColumnExpressionArgument[Any],
+ ]
] = None,
order_by: Optional[
- Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ Union[
+ Iterable[_ColumnExpressionArgument[Any]],
+ _ColumnExpressionArgument[Any],
+ ]
] = None,
range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
@@ -1502,7 +1510,7 @@ def true() -> True_:
def tuple_(
- *clauses: _ColumnExpression[Any],
+ *clauses: _ColumnExpressionArgument[Any],
types: Optional[Sequence[_TypeEngineArgument[Any]]] = None,
) -> Tuple:
"""Return a :class:`.Tuple`.
@@ -1531,7 +1539,7 @@ def tuple_(
def type_coerce(
- expression: _ColumnExpression[Any],
+ expression: _ColumnExpressionArgument[Any],
type_: _TypeEngineArgument[_T],
) -> TypeCoerce[_T]:
r"""Associate a SQL expression with a particular type, without rendering
@@ -1612,7 +1620,7 @@ def type_coerce(
def within_group(
- element: FunctionElement[_T], *order_by: _ColumnExpression[Any]
+ element: FunctionElement[_T], *order_by: _ColumnExpressionArgument[Any]
) -> WithinGroup[_T]:
r"""Produce a :class:`.WithinGroup` object against a function.
diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py
index e9acc7e6d..a17ee4ce8 100644
--- a/lib/sqlalchemy/sql/_selectable_constructors.py
+++ b/lib/sqlalchemy/sql/_selectable_constructors.py
@@ -12,7 +12,7 @@ from typing import Optional
from . import coercions
from . import roles
-from ._typing import _ColumnsClauseElement
+from ._typing import _ColumnsClauseArgument
from .elements import ColumnClause
from .selectable import Alias
from .selectable import CompoundSelect
@@ -281,7 +281,7 @@ def outerjoin(left, right, onclause=None, full=False):
return Join(left, right, onclause, isouter=True, full=full)
-def select(*entities: _ColumnsClauseElement) -> Select:
+def select(*entities: _ColumnsClauseArgument) -> Select:
r"""Construct a new :class:`_expression.Select`.
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index b50a7bf6a..a5da87802 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any
+from typing import Iterable
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -10,9 +11,17 @@ from . import roles
from .. import util
from ..inspection import Inspectable
from ..util.typing import Literal
+from ..util.typing import Protocol
if TYPE_CHECKING:
+ from .elements import ClauseElement
+ from .elements import ColumnClause
+ from .elements import ColumnElement
from .elements import quoted_name
+ from .elements import SQLCoreOperations
+ from .elements import TextClause
+ from .roles import ColumnsClauseRole
+ from .roles import FromClauseRole
from .schema import DefaultGenerator
from .schema import Sequence
from .selectable import FromClause
@@ -24,31 +33,61 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
-_ColumnsClauseElement = Union[
+
+class _HasClauseElement(Protocol):
+ """indicates a class that has a __clause_element__() method"""
+
+ def __clause_element__(self) -> ColumnsClauseRole:
+ ...
+
+
+# convention:
+# XYZArgument - something that the end user is passing to a public API method
+# XYZElement - the internal representation that we use for the thing.
+# the coercions system is responsible for converting from XYZArgument to
+# XYZElement.
+
+_ColumnsClauseArgument = Union[
Literal["*", 1],
roles.ColumnsClauseRole,
Type[Any],
- Inspectable[roles.HasColumnElementClauseElement],
+ Inspectable[_HasClauseElement],
+ _HasClauseElement,
]
-_FromClauseElement = Union[
- roles.FromClauseRole, Type[Any], Inspectable[roles.HasFromClauseElement]
+
+_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]]
+
+_FromClauseArgument = Union[
+ roles.FromClauseRole,
+ Type[Any],
+ Inspectable[_HasClauseElement],
+ _HasClauseElement,
]
-_ColumnExpression = Union[
- roles.ExpressionElementRole[_T],
- Inspectable[roles.HasColumnElementClauseElement],
+_ColumnExpressionArgument = Union[
+ "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
]
+_DMLColumnArgument = Union[str, "ColumnClause[Any]", _HasClauseElement]
+
_PropagateAttrsType = util.immutabledict[str, Any]
_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
-def is_named_from_clause(t: FromClause) -> TypeGuard[NamedFromClause]:
+def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]:
return t.named_with_column
-def has_schema_attr(t: FromClause) -> TypeGuard[TableClause]:
+def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
+ return c._is_column_element
+
+
+def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
+ return c._is_text_clause
+
+
+def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]:
return hasattr(t, "schema")
@@ -60,11 +99,5 @@ def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
return t._is_tuple_type
-def is_has_clause_element(s: object) -> TypeGuard[roles.HasClauseElement]:
- return hasattr(s, "__clause_element__")
-
-
-def is_has_column_element_clause_element(
- s: object,
-) -> TypeGuard[roles.HasColumnElementClauseElement]:
+def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
return hasattr(s, "__clause_element__")
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index f37ae9a60..f1919d1d3 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -25,6 +25,7 @@ from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import operators
@@ -35,9 +36,9 @@ from .visitors import InternalTraversal
from .. import util
from ..util.typing import Literal
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+ from .base import _EntityNamespace
from .visitors import _TraverseInternalsType
- from ..util.typing import Self
_AnnotationDict = Mapping[str, Any]
@@ -192,7 +193,12 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
__slots__ = ()
_constructor: Callable[..., SupportsWrappingAnnotations]
- entity_namespace: Mapping[str, Any]
+
+ if TYPE_CHECKING:
+
+ @util.ro_non_memoized_property
+ def entity_namespace(self) -> _EntityNamespace:
+ ...
def _annotate(self, values: _AnnotationDict) -> Annotated:
"""return a copy of this ClauseElement with annotations
@@ -380,8 +386,8 @@ class Annotated(SupportsAnnotations):
else:
return hash(other) == hash(self)
- @property
- def entity_namespace(self) -> Mapping[str, Any]:
+ @util.ro_non_memoized_property
+ def entity_namespace(self) -> _EntityNamespace:
if "entity_namespace" in self._annotations:
return cast(
SupportsWrappingAnnotations,
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 8f5135915..19e4c13d2 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -62,10 +62,14 @@ if TYPE_CHECKING:
from . import coercions
from . import elements
from . import type_api
+ from ._typing import _ColumnsClauseArgument
+ from ._typing import _SelectIterable
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import NamedColumn
from .elements import SQLCoreOperations
+ from .selectable import FromClause
from ..engine import Connection
from ..engine import Result
from ..engine.base import _CompiledCacheType
@@ -91,6 +95,8 @@ class _NoArg(Enum):
NO_ARG = _NoArg.NO_ARG
+_T = TypeVar("_T", bound=Any)
+
_Fn = TypeVar("_Fn", bound=Callable[..., Any])
_AmbiguousTableNameMap = MutableMapping[str, str]
@@ -102,7 +108,9 @@ class _EntityNamespace(Protocol):
class _HasEntityNamespace(Protocol):
- entity_namespace: _EntityNamespace
+ @util.ro_non_memoized_property
+ def entity_namespace(self) -> _EntityNamespace:
+ ...
def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
@@ -136,8 +144,8 @@ class SingletonConstant(Immutable):
_singleton: SingletonConstant
- def __new__(cls, *arg, **kw):
- return cls._singleton
+ def __new__(cls: _T, *arg: Any, **kw: Any) -> _T:
+ return cast(_T, cls._singleton)
@util.non_memoized_property
def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
@@ -159,13 +167,15 @@ class SingletonConstant(Immutable):
cls._singleton = obj
-def _from_objects(*elements):
+def _from_objects(*elements: ColumnElement[Any]) -> Iterator[FromClause]:
return itertools.chain.from_iterable(
[element._from_objects for element in elements]
)
-def _select_iterables(elements):
+def _select_iterables(
+ elements: Iterable[roles.ColumnsClauseRole],
+) -> _SelectIterable:
"""expand tables into individual columns in the
given list of column expressions.
@@ -207,7 +217,7 @@ def _generative(fn: _Fn) -> _Fn:
return decorated # type: ignore
-def _exclusive_against(*names, **kw):
+def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
msgs = kw.pop("msgs", {})
defaults = kw.pop("defaults", {})
@@ -502,7 +512,7 @@ class DialectKWArgs:
util.portable_instancemethod(self._kw_reg_for_dialect_cls)
)
- def _validate_dialect_kwargs(self, kwargs):
+ def _validate_dialect_kwargs(self, kwargs: Any) -> None:
# validate remaining kwargs that they all specify DB prefixes
if not kwargs:
@@ -605,7 +615,9 @@ class CompileState:
self.statement = statement
@classmethod
- def get_plugin_class(cls, statement):
+ def get_plugin_class(
+ cls, statement: Executable
+ ) -> Optional[Type[CompileState]]:
plugin_name = statement._propagate_attrs.get(
"compile_state_plugin", None
)
@@ -634,7 +646,9 @@ class CompileState:
return None
@classmethod
- def plugin_for(cls, plugin_name, visit_name):
+ def plugin_for(
+ cls, plugin_name: str, visit_name: str
+ ) -> Callable[[_Fn], _Fn]:
def decorate(cls_to_decorate):
cls.plugins[(plugin_name, visit_name)] = cls_to_decorate
return cls_to_decorate
@@ -957,7 +971,7 @@ class Executable(roles.StatementRole, Generative):
) -> Result:
...
- @property
+ @util.non_memoized_property
def _all_selected_columns(self):
raise NotImplementedError()
@@ -1202,10 +1216,11 @@ class SchemaVisitor(ClauseVisitor):
__traverse_options__ = {"schema_visitor": True}
-_COL = TypeVar("_COL", bound="ColumnClause[Any]")
+_COLKEY = TypeVar("_COLKEY", Union[None, str], str)
+_COL = TypeVar("_COL", bound="ColumnElement[Any]")
-class ColumnCollection(Generic[_COL]):
+class ColumnCollection(Generic[_COLKEY, _COL]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
:class:`_sql.FromClause` objects.
@@ -1316,25 +1331,27 @@ class ColumnCollection(Generic[_COL]):
__slots__ = "_collection", "_index", "_colset"
- _collection: List[Tuple[str, _COL]]
- _index: Dict[Union[str, int], _COL]
+ _collection: List[Tuple[_COLKEY, _COL]]
+ _index: Dict[Union[None, str, int], _COL]
_colset: Set[_COL]
- def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None):
+ def __init__(
+ self, columns: Optional[Iterable[Tuple[_COLKEY, _COL]]] = None
+ ):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
object.__setattr__(self, "_collection", [])
if columns:
self._initial_populate(columns)
- def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None:
+ def _initial_populate(self, iter_: Iterable[Tuple[_COLKEY, _COL]]) -> None:
self._populate_separate_keys(iter_)
@property
def _all_columns(self) -> List[_COL]:
return [col for (k, col) in self._collection]
- def keys(self) -> List[str]:
+ def keys(self) -> List[_COLKEY]:
"""Return a sequence of string key names for all columns in this
collection."""
return [k for (k, col) in self._collection]
@@ -1345,7 +1362,7 @@ class ColumnCollection(Generic[_COL]):
collection."""
return [col for (k, col) in self._collection]
- def items(self) -> List[Tuple[str, _COL]]:
+ def items(self) -> List[Tuple[_COLKEY, _COL]]:
"""Return a sequence of (key, column) tuples for all columns in this
collection each consisting of a string key name and a
:class:`_sql.ColumnClause` or
@@ -1389,7 +1406,7 @@ class ColumnCollection(Generic[_COL]):
else:
return True
- def compare(self, other: ColumnCollection[Any]) -> bool:
+ def compare(self, other: ColumnCollection[Any, Any]) -> bool:
"""Compare this :class:`_expression.ColumnCollection` to another
based on the names of the keys"""
@@ -1444,7 +1461,7 @@ class ColumnCollection(Generic[_COL]):
__hash__ = None # type: ignore
def _populate_separate_keys(
- self, iter_: Iterable[Tuple[str, _COL]]
+ self, iter_: Iterable[Tuple[_COLKEY, _COL]]
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
@@ -1455,7 +1472,7 @@ class ColumnCollection(Generic[_COL]):
)
self._index.update({k: col for k, col in reversed(self._collection)})
- def add(self, column: _COL, key: Optional[str] = None) -> None:
+ def add(self, column: _COL, key: Optional[_COLKEY] = None) -> None:
"""Add a column to this :class:`_sql.ColumnCollection`.
.. note::
@@ -1467,15 +1484,19 @@ class ColumnCollection(Generic[_COL]):
object, use the :meth:`_schema.Table.append_column` method.
"""
+ colkey: _COLKEY
+
if key is None:
- key = column.key
+ colkey = column.key # type: ignore
+ else:
+ colkey = key
l = len(self._collection)
- self._collection.append((key, column))
+ self._collection.append((colkey, column))
self._colset.add(column)
self._index[l] = column
- if key not in self._index:
- self._index[key] = column
+ if colkey not in self._index:
+ self._index[colkey] = column
def __getstate__(self) -> Dict[str, Any]:
return {"_collection": self._collection, "_index": self._index}
@@ -1499,11 +1520,11 @@ class ColumnCollection(Generic[_COL]):
else:
return True
- def as_immutable(self) -> ImmutableColumnCollection[_COL]:
- """Return an "immutable" form of this
+ def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL]:
+ """Return a "read only" form of this
:class:`_sql.ColumnCollection`."""
- return ImmutableColumnCollection(self)
+ return ReadOnlyColumnCollection(self)
def corresponding_column(
self, column: _COL, require_embedded: bool = False
@@ -1605,7 +1626,10 @@ class ColumnCollection(Generic[_COL]):
return col
-class DedupeColumnCollection(ColumnCollection[_COL]):
+_NAMEDCOL = TypeVar("_NAMEDCOL", bound="NamedColumn[Any]")
+
+
+class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
"""A :class:`_expression.ColumnCollection`
that maintains deduplicating behavior.
@@ -1618,7 +1642,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]):
"""
- def add(self, column: _COL, key: Optional[str] = None) -> None:
+ def add(self, column: _NAMEDCOL, key: Optional[str] = None) -> None:
if key is not None and column.key != key:
raise exc.ArgumentError(
@@ -1653,7 +1677,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]):
self._index[key] = column
def _populate_separate_keys(
- self, iter_: Iterable[Tuple[str, _COL]]
+ self, iter_: Iterable[Tuple[str, _NAMEDCOL]]
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
@@ -1679,10 +1703,10 @@ class DedupeColumnCollection(ColumnCollection[_COL]):
for col in replace_col:
self.replace(col)
- def extend(self, iter_: Iterable[_COL]) -> None:
- self._populate_separate_keys((col.key, col) for col in iter_)
+ def extend(self, iter_: Iterable[_NAMEDCOL]) -> None:
+ self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501
- def remove(self, column: _COL) -> None:
+ def remove(self, column: _NAMEDCOL) -> None:
if column not in self._colset:
raise ValueError(
"Can't remove column %r; column is not in this collection"
@@ -1699,7 +1723,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]):
# delete higher index
del self._index[len(self._collection)]
- def replace(self, column: _COL) -> None:
+ def replace(self, column: _NAMEDCOL) -> None:
"""add the given column to this collection, removing unaliased
versions of this column as well as existing columns with the
same key.
@@ -1726,7 +1750,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]):
if column.key in self._index:
remove_col.add(self._index[column.key])
- new_cols = []
+ new_cols: List[Tuple[str, _NAMEDCOL]] = []
replaced = False
for k, col in self._collection:
if col in remove_col:
@@ -1752,8 +1776,8 @@ class DedupeColumnCollection(ColumnCollection[_COL]):
self._index.update(self._collection)
-class ImmutableColumnCollection(
- util.ImmutableContainer, ColumnCollection[_COL]
+class ReadOnlyColumnCollection(
+ util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL]
):
__slots__ = ("_parent",)
@@ -1771,13 +1795,13 @@ class ImmutableColumnCollection(
self.__init__(parent) # type: ignore
def add(self, column: Any, key: Any = ...) -> Any:
- self._immutable()
+ self._readonly()
- def extend(self, elements: Any) -> None:
- self._immutable()
+ def extend(self, elements: Any) -> NoReturn:
+ self._readonly()
- def remove(self, item: Any) -> None:
- self._immutable()
+ def remove(self, item: Any) -> NoReturn:
+ self._readonly()
class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 35cd33a18..ccc8fba8d 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -46,6 +46,7 @@ if typing.TYPE_CHECKING:
from . import schema
from . import selectable
from . import traversals
+ from ._typing import _ColumnsClauseArgument
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
@@ -166,6 +167,32 @@ def expect(
@overload
def expect(
+ role: Type[roles.DMLTableRole],
+ element: Any,
+ *,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
+ argname: Optional[str] = None,
+ post_inspect: bool = False,
+ **kw: Any,
+) -> roles.DMLTableRole:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.ColumnsClauseRole],
+ element: Any,
+ *,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
+ argname: Optional[str] = None,
+ post_inspect: bool = False,
+ **kw: Any,
+) -> roles.ColumnsClauseRole:
+ ...
+
+
+@overload
+def expect(
role: Type[_SR],
element: Any,
*,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 7fd37e9b1..a2f731ac9 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -52,7 +52,6 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import Union
-from sqlalchemy.sql.ddl import DDLElement
from . import base
from . import coercions
from . import crud
@@ -79,10 +78,12 @@ from ..util.typing import Protocol
from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
+ from . import roles
from .annotation import _AnnotationDict
from .base import _AmbiguousTableNameMap
from .base import CompileState
from .cache_key import CacheKey
+ from .ddl import DDLElement
from .dml import Insert
from .dml import UpdateBase
from .dml import ValuesBase
@@ -724,7 +725,7 @@ class SQLCompiler(Compiled):
"""list of columns for which onupdate default values should be evaluated
before an UPDATE takes place"""
- returning: Optional[List[ColumnClause[Any]]]
+ returning: Optional[Sequence[roles.ColumnsClauseRole]]
"""list of columns that will be delivered to cursor.description or
dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
@@ -4099,7 +4100,9 @@ class SQLCompiler(Compiled):
return " FOR UPDATE"
def returning_clause(
- self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]]
+ self,
+ stmt: UpdateBase,
+ returning_cols: Sequence[roles.ColumnsClauseRole],
) -> str:
raise exc.CompileError(
"RETURNING is not supported by this "
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 533a2f6cd..91a3f70c9 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -39,6 +39,7 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from .compiler import _BindNameForColProtocol
from .compiler import SQLCompiler
+ from .dml import _DMLColumnElement
from .dml import DMLState
from .dml import Insert
from .dml import Update
@@ -129,8 +130,10 @@ def _get_crud_params(
[],
)
- stmt_parameter_tuples: Optional[List[Any]]
- spd: Optional[MutableMapping[str, Any]]
+ stmt_parameter_tuples: Optional[
+ List[Tuple[Union[str, ColumnClause[Any]], Any]]
+ ]
+ spd: Optional[MutableMapping[_DMLColumnElement, Any]]
if compile_state._has_multi_parameters:
mp = compile_state._multi_parameters
@@ -355,8 +358,8 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw):
def _key_getters_for_crud_column(
compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState
) -> Tuple[
- Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]],
- Callable[[Column[Any]], Union[str, Tuple[str, str]]],
+ Callable[[Union[str, ColumnClause[Any]]], Union[str, Tuple[str, str]]],
+ Callable[[ColumnClause[Any]], Union[str, Tuple[str, str]]],
_BindNameForColProtocol,
]:
if dml.isupdate(compile_state) and compile_state._extra_froms:
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index f5fb6b2f3..0c9056aee 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -15,18 +15,29 @@ import collections.abc as collections_abc
import operator
import typing
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterable
from typing import List
from typing import MutableMapping
+from typing import NoReturn
from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
+from typing import Union
from . import coercions
from . import roles
from . import util as sql_util
+from ._typing import is_column_element
+from ._typing import is_named_from_clause
from .base import _entity_namespace_key
from .base import _exclusive_against
from .base import _from_objects
from .base import _generative
+from .base import _select_iterables
from .base import ColumnCollection
from .base import CompileState
from .base import DialectKWArgs
@@ -34,7 +45,9 @@ from .base import Executable
from .base import HasCompileState
from .elements import BooleanClauseList
from .elements import ClauseElement
+from .elements import ColumnElement
from .elements import Null
+from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import ReturnsRows
@@ -45,16 +58,25 @@ from .. import exc
from .. import util
from ..util.typing import TypeGuard
-
if TYPE_CHECKING:
- def isupdate(dml) -> TypeGuard[UpdateDMLState]:
+ from ._typing import _ColumnsClauseArgument
+ from ._typing import _DMLColumnArgument
+ from ._typing import _FromClauseArgument
+ from ._typing import _HasClauseElement
+ from ._typing import _SelectIterable
+ from .base import ReadOnlyColumnCollection
+ from .compiler import SQLCompiler
+ from .elements import ColumnClause
+ from .selectable import Select
+
+ def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]:
...
- def isdelete(dml) -> TypeGuard[DeleteDMLState]:
+ def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]:
...
- def isinsert(dml) -> TypeGuard[InsertDMLState]:
+ def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]:
...
else:
@@ -63,27 +85,43 @@ else:
isinsert = operator.attrgetter("isinsert")
+_DMLColumnElement = Union[str, "ColumnClause[Any]"]
+
+
class DMLState(CompileState):
_no_parameters = True
- _dict_parameters: Optional[MutableMapping[str, Any]] = None
- _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None
- _ordered_values = None
- _parameter_ordering = None
+ _dict_parameters: Optional[MutableMapping[_DMLColumnElement, Any]] = None
+ _multi_parameters: Optional[
+ List[MutableMapping[_DMLColumnElement, Any]]
+ ] = None
+ _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None
+ _parameter_ordering: Optional[List[_DMLColumnElement]] = None
_has_multi_parameters = False
isupdate = False
isdelete = False
isinsert = False
- def __init__(self, statement, compiler, **kw):
+ statement: UpdateBase
+
+ def __init__(
+ self, statement: UpdateBase, compiler: SQLCompiler, **kw: Any
+ ):
raise NotImplementedError()
@classmethod
- def get_entity_description(cls, statement):
- return {"name": statement.table.name, "table": statement.table}
+ def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]:
+ return {
+ "name": statement.table.name
+ if is_named_from_clause(statement.table)
+ else None,
+ "table": statement.table,
+ }
@classmethod
- def get_returning_column_descriptions(cls, statement):
+ def get_returning_column_descriptions(
+ cls, statement: UpdateBase
+ ) -> List[Dict[str, Any]]:
return [
{
"name": c.key,
@@ -94,11 +132,21 @@ class DMLState(CompileState):
]
@property
- def dml_table(self):
+ def dml_table(self) -> roles.DMLTableRole:
return self.statement.table
+ if TYPE_CHECKING:
+
+ @classmethod
+ def get_plugin_class(cls, statement: Executable) -> Type[DMLState]:
+ ...
+
@classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator):
+ def _get_crud_kv_pairs(
+ cls,
+ statement: UpdateBase,
+ kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]],
+ ) -> List[Tuple[_DMLColumnElement, Any]]:
return [
(
coercions.expect(roles.DMLColumnRole, k),
@@ -112,8 +160,8 @@ class DMLState(CompileState):
for k, v in kv_iterator
]
- def _make_extra_froms(self, statement):
- froms = []
+ def _make_extra_froms(self, statement: DMLWhereBase) -> List[FromClause]:
+ froms: List[FromClause] = []
all_tables = list(sql_util.tables_from_leftmost(statement.table))
seen = {all_tables[0]}
@@ -127,7 +175,7 @@ class DMLState(CompileState):
froms.extend(all_tables[1:])
return froms
- def _process_multi_values(self, statement):
+ def _process_multi_values(self, statement: ValuesBase) -> None:
if not statement._supports_multi_parameters:
raise exc.InvalidRequestError(
"%s construct does not support "
@@ -135,7 +183,7 @@ class DMLState(CompileState):
)
for parameters in statement._multi_values:
- multi_parameters = [
+ multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [
{
c.key: value
for c, value in zip(statement.table.c, parameter_set)
@@ -153,9 +201,10 @@ class DMLState(CompileState):
elif not self._has_multi_parameters:
self._cant_mix_formats_error()
else:
+ assert self._multi_parameters
self._multi_parameters.extend(multi_parameters)
- def _process_values(self, statement):
+ def _process_values(self, statement: ValuesBase) -> None:
if self._no_parameters:
self._has_multi_parameters = False
self._dict_parameters = statement._values
@@ -163,11 +212,12 @@ class DMLState(CompileState):
elif self._has_multi_parameters:
self._cant_mix_formats_error()
- def _process_ordered_values(self, statement):
+ def _process_ordered_values(self, statement: ValuesBase) -> None:
parameters = statement._ordered_values
if self._no_parameters:
self._no_parameters = False
+ assert parameters is not None
self._dict_parameters = dict(parameters)
self._ordered_values = parameters
self._parameter_ordering = [key for key, value in parameters]
@@ -179,7 +229,8 @@ class DMLState(CompileState):
"with any other values() call"
)
- def _process_select_values(self, statement):
+ def _process_select_values(self, statement: ValuesBase) -> None:
+ assert statement._select_names is not None
parameters = {
coercions.expect(roles.DMLColumnRole, name, as_key=True): Null()
for name in statement._select_names
@@ -193,7 +244,7 @@ class DMLState(CompileState):
# does not allow this construction to occur
assert False, "This statement already has parameters"
- def _cant_mix_formats_error(self):
+ def _cant_mix_formats_error(self) -> NoReturn:
raise exc.InvalidRequestError(
"Can't mix single and multiple VALUES "
"formats in one INSERT statement; one style appends to a "
@@ -208,7 +259,7 @@ class InsertDMLState(DMLState):
include_table_with_column_exprs = False
- def __init__(self, statement, compiler, **kw):
+ def __init__(self, statement: Insert, compiler: SQLCompiler, **kw: Any):
self.statement = statement
self.isinsert = True
@@ -226,10 +277,9 @@ class UpdateDMLState(DMLState):
include_table_with_column_exprs = False
- def __init__(self, statement, compiler, **kw):
+ def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any):
self.statement = statement
self.isupdate = True
- self._preserve_parameter_order = statement._preserve_parameter_order
if statement._ordered_values is not None:
self._process_ordered_values(statement)
elif statement._values is not None:
@@ -238,7 +288,7 @@ class UpdateDMLState(DMLState):
self._process_multi_values(statement)
self._extra_froms = ef = self._make_extra_froms(statement)
self.is_multitable = mt = ef and self._dict_parameters
- self.include_table_with_column_exprs = (
+ self.include_table_with_column_exprs = bool(
mt and compiler.render_table_with_column_in_update_from
)
@@ -247,7 +297,7 @@ class UpdateDMLState(DMLState):
class DeleteDMLState(DMLState):
isdelete = True
- def __init__(self, statement, compiler, **kw):
+ def __init__(self, statement: Delete, compiler: SQLCompiler, **kw: Any):
self.statement = statement
self.isdelete = True
@@ -271,23 +321,31 @@ class UpdateBase(
__visit_name__ = "update_base"
- _hints = util.immutabledict()
+ _hints: util.immutabledict[
+ Tuple[roles.DMLTableRole, str], str
+ ] = util.EMPTY_DICT
named_with_column = False
- table: TableClause
+ table: roles.DMLTableRole
_return_defaults = False
- _return_defaults_columns = None
- _returning = ()
+ _return_defaults_columns: Optional[
+ Tuple[roles.ColumnsClauseRole, ...]
+ ] = None
+ _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
is_dml = True
- def _generate_fromclause_column_proxies(self, fromclause):
+ def _generate_fromclause_column_proxies(
+ self, fromclause: FromClause
+ ) -> None:
fromclause._columns._populate_separate_keys(
- col._make_proxy(fromclause) for col in self._returning
+ col._make_proxy(fromclause)
+ for col in self._all_selected_columns
+ if is_column_element(col)
)
- def params(self, *arg, **kw):
+ def params(self, *arg: Any, **kw: Any) -> NoReturn:
"""Set the parameters for the statement.
This method raises ``NotImplementedError`` on the base class,
@@ -302,7 +360,9 @@ class UpdateBase(
)
@_generative
- def with_dialect_options(self: SelfUpdateBase, **opt) -> SelfUpdateBase:
+ def with_dialect_options(
+ self: SelfUpdateBase, **opt: Any
+ ) -> SelfUpdateBase:
"""Add dialect options to this INSERT/UPDATE/DELETE object.
e.g.::
@@ -318,7 +378,9 @@ class UpdateBase(
return self
@_generative
- def returning(self: SelfUpdateBase, *cols) -> SelfUpdateBase:
+ def returning(
+ self: SelfUpdateBase, *cols: _ColumnsClauseArgument
+ ) -> SelfUpdateBase:
r"""Add a :term:`RETURNING` or equivalent clause to this statement.
e.g.:
@@ -397,26 +459,32 @@ class UpdateBase(
)
return self
- @property
- def _all_selected_columns(self):
- return self._returning
+ @util.non_memoized_property
+ def _all_selected_columns(self) -> _SelectIterable:
+ return [c for c in _select_iterables(self._returning)]
@property
- def exported_columns(self):
+ def exported_columns(
+ self,
+ ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]:
"""Return the RETURNING columns as a column collection for this
statement.
.. versionadded:: 1.4
"""
- # TODO: no coverage here
return ColumnCollection(
- (c.key, c) for c in self._all_selected_columns
- ).as_immutable()
+ (c.key, c)
+ for c in self._all_selected_columns
+ if is_column_element(c)
+ ).as_readonly()
@_generative
def with_hint(
- self: SelfUpdateBase, text, selectable=None, dialect_name="*"
+ self: SelfUpdateBase,
+ text: str,
+ selectable: Optional[roles.DMLTableRole] = None,
+ dialect_name: str = "*",
) -> SelfUpdateBase:
"""Add a table hint for a single table to this
INSERT/UPDATE/DELETE statement.
@@ -454,7 +522,7 @@ class UpdateBase(
return self
@property
- def entity_description(self):
+ def entity_description(self) -> Dict[str, Any]:
"""Return a :term:`plugin-enabled` description of the table and/or entity
which this DML construct is operating against.
@@ -490,7 +558,7 @@ class UpdateBase(
return meth(self)
@property
- def returning_column_descriptions(self):
+ def returning_column_descriptions(self) -> List[Dict[str, Any]]:
"""Return a :term:`plugin-enabled` description of the columns
which this DML construct is RETURNING against, in other words
the expressions established as part of :meth:`.UpdateBase.returning`.
@@ -547,18 +615,30 @@ class ValuesBase(UpdateBase):
__visit_name__ = "values_base"
_supports_multi_parameters = False
- _preserve_parameter_order = False
- select = None
- _post_values_clause = None
- _values = None
- _multi_values = ()
- _ordered_values = None
- _select_names = None
+ select: Optional[Select] = None
+ """SELECT statement for INSERT .. FROM SELECT"""
+
+ _post_values_clause: Optional[ClauseElement] = None
+ """used by extensions to Insert etc. to add additional syntacitcal
+ constructs, e.g. ON CONFLICT etc."""
+
+ _values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None
+ _multi_values: Tuple[
+ Union[
+ Sequence[Dict[_DMLColumnElement, Any]],
+ Sequence[Sequence[Any]],
+ ],
+ ...,
+ ] = ()
+
+ _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None
+
+ _select_names: Optional[List[str]] = None
_inline: bool = False
- _returning = ()
+ _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
- def __init__(self, table):
+ def __init__(self, table: _FromClauseArgument):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
@@ -573,7 +653,14 @@ class ValuesBase(UpdateBase):
"values present",
},
)
- def values(self: SelfValuesBase, *args, **kwargs) -> SelfValuesBase:
+ def values(
+ self: SelfValuesBase,
+ *args: Union[
+ Dict[_DMLColumnArgument, Any],
+ Sequence[Any],
+ ],
+ **kwargs: Any,
+ ) -> SelfValuesBase:
r"""Specify a fixed VALUES clause for an INSERT statement, or the SET
clause for an UPDATE.
@@ -704,9 +791,7 @@ class ValuesBase(UpdateBase):
"dictionaries/tuples is accepted positionally."
)
- elif not self._preserve_parameter_order and isinstance(
- arg, collections_abc.Sequence
- ):
+ elif isinstance(arg, collections_abc.Sequence):
if arg and isinstance(arg[0], (list, dict, tuple)):
self._multi_values += (arg,)
@@ -714,18 +799,11 @@ class ValuesBase(UpdateBase):
# tuple values
arg = {c.key: value for c, value in zip(self.table.c, arg)}
- elif self._preserve_parameter_order and not isinstance(
- arg, collections_abc.Sequence
- ):
- raise ValueError(
- "When preserve_parameter_order is True, "
- "values() only accepts a list of 2-tuples"
- )
else:
# kwarg path. this is the most common path for non-multi-params
# so this is fairly quick.
- arg = kwargs
+ arg = cast("Dict[_DMLColumnArgument, Any]", kwargs)
if args:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
@@ -739,15 +817,11 @@ class ValuesBase(UpdateBase):
# and ensures they get the "crud"-style name when rendered.
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
-
- if self._preserve_parameter_order:
- self._ordered_values = kv_generator(self, arg)
+ coerced_arg = {k: v for k, v in kv_generator(self, arg.items())}
+ if self._values:
+ self._values = self._values.union(coerced_arg)
else:
- arg = {k: v for k, v in kv_generator(self, arg.items())}
- if self._values:
- self._values = self._values.union(arg)
- else:
- self._values = util.immutabledict(arg)
+ self._values = util.immutabledict(coerced_arg)
return self
@_generative
@@ -758,7 +832,9 @@ class ValuesBase(UpdateBase):
},
defaults={"_returning": _returning},
)
- def return_defaults(self: SelfValuesBase, *cols) -> SelfValuesBase:
+ def return_defaults(
+ self: SelfValuesBase, *cols: _DMLColumnArgument
+ ) -> SelfValuesBase:
"""Make use of a :term:`RETURNING` clause for the purpose
of fetching server-side expressions and defaults.
@@ -843,7 +919,9 @@ class ValuesBase(UpdateBase):
"""
self._return_defaults = True
- self._return_defaults_columns = cols
+ self._return_defaults_columns = tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
return self
@@ -867,6 +945,8 @@ class Insert(ValuesBase):
is_insert = True
+ table: TableClause
+
_traverse_internals = (
[
("table", InternalTraversal.dp_clauseelement),
@@ -890,7 +970,7 @@ class Insert(ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table):
+ def __init__(self, table: roles.FromClauseRole):
super(Insert, self).__init__(table)
@_generative
@@ -916,7 +996,10 @@ class Insert(ValuesBase):
@_generative
def from_select(
- self: SelfInsert, names, select, include_defaults=True
+ self: SelfInsert,
+ names: List[str],
+ select: Select,
+ include_defaults: bool = True,
) -> SelfInsert:
"""Return a new :class:`_expression.Insert` construct which represents
an ``INSERT...FROM SELECT`` statement.
@@ -983,10 +1066,13 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
class DMLWhereBase:
- _where_criteria = ()
+ table: roles.DMLTableRole
+ _where_criteria: Tuple[ColumnElement[Any], ...] = ()
@_generative
- def where(self: SelfDMLWhereBase, *whereclause) -> SelfDMLWhereBase:
+ def where(
+ self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any]
+ ) -> SelfDMLWhereBase:
"""Return a new construct with the given expression(s) added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -1022,7 +1108,9 @@ class DMLWhereBase:
self._where_criteria += (where_criteria,)
return self
- def filter(self: SelfDMLWhereBase, *criteria) -> SelfDMLWhereBase:
+ def filter(
+ self: SelfDMLWhereBase, *criteria: roles.ExpressionElementRole[Any]
+ ) -> SelfDMLWhereBase:
"""A synonym for the :meth:`_dml.DMLWhereBase.where` method.
.. versionadded:: 1.4
@@ -1031,10 +1119,10 @@ class DMLWhereBase:
return self.where(*criteria)
- def _filter_by_zero(self):
+ def _filter_by_zero(self) -> roles.DMLTableRole:
return self.table
- def filter_by(self: SelfDMLWhereBase, **kwargs) -> SelfDMLWhereBase:
+ def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase:
r"""apply the given filtering criterion as a WHERE clause
to this select.
@@ -1048,7 +1136,7 @@ class DMLWhereBase:
return self.filter(*clauses)
@property
- def whereclause(self):
+ def whereclause(self) -> Optional[ColumnElement[Any]]:
"""Return the completed WHERE clause for this :class:`.DMLWhereBase`
statement.
@@ -1079,7 +1167,6 @@ class Update(DMLWhereBase, ValuesBase):
__visit_name__ = "update"
is_update = True
- _preserve_parameter_order = False
_traverse_internals = (
[
@@ -1102,11 +1189,13 @@ class Update(DMLWhereBase, ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table):
+ def __init__(self, table: roles.FromClauseRole):
super(Update, self).__init__(table)
@_generative
- def ordered_values(self: SelfUpdate, *args) -> SelfUpdate:
+ def ordered_values(
+ self: SelfUpdate, *args: Tuple[_DMLColumnArgument, Any]
+ ) -> SelfUpdate:
"""Specify the VALUES clause of this UPDATE statement with an explicit
parameter ordering that will be maintained in the SET clause of the
resulting UPDATE statement.
@@ -1190,7 +1279,7 @@ class Delete(DMLWhereBase, UpdateBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table):
+ def __init__(self, table: roles.FromClauseRole):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 691eb10ec..da1d50a53 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -70,12 +70,14 @@ from .visitors import Visitable
from .. import exc
from .. import inspection
from .. import util
-from ..util.langhelpers import TypingOnly
+from ..util import HasMemoized_ro_memoized_attribute
+from ..util import TypingOnly
from ..util.typing import Literal
if typing.TYPE_CHECKING:
- from ._typing import _ColumnExpression
+ from ._typing import _ColumnExpressionArgument
from ._typing import _PropagateAttrsType
+ from ._typing import _SelectIterable
from ._typing import _TypeEngineArgument
from .cache_key import CacheKey
from .compiler import Compiled
@@ -300,7 +302,7 @@ class ClauseElement(
is_clause_element = True
is_selectable = False
-
+ _is_column_element = False
_is_table = False
_is_textual = False
_is_from_clause = False
@@ -330,7 +332,7 @@ class ClauseElement(
) -> Iterable[ClauseElement]:
...
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return []
@@ -696,6 +698,9 @@ class CompilerColumnElement(
__slots__ = ()
+# SQLCoreOperations should be suiting the ExpressionElementRole
+# and ColumnsClauseRole. however the MRO issues become too elaborate
+# at the moment.
class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
__slots__ = ()
@@ -1154,6 +1159,7 @@ class ColumnElement(
primary_key: bool = False
_is_clone_of: Optional[ColumnElement[_T]]
+ _is_column_element = True
foreign_keys: AbstractSet[ForeignKey] = frozenset()
@@ -1396,7 +1402,7 @@ class ColumnElement(
return self
@property
- def _select_iterable(self) -> Iterable[ColumnElement[Any]]:
+ def _select_iterable(self) -> _SelectIterable:
return (self,)
@util.memoized_property
@@ -2075,7 +2081,7 @@ class TextClause(
return and_(self, other)
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> _SelectIterable:
return (self,)
# help in those cases where text() is
@@ -2491,9 +2497,11 @@ class ClauseList(
("operator", InternalTraversal.dp_operator),
]
+ clauses: List[ColumnElement[Any]]
+
def __init__(
self,
- *clauses: _ColumnExpression[Any],
+ *clauses: _ColumnExpressionArgument[Any],
operator: OperatorType = operators.comma_op,
group: bool = True,
group_contents: bool = True,
@@ -2541,7 +2549,7 @@ class ClauseList(
return len(self.clauses)
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> _SelectIterable:
return itertools.chain.from_iterable(
[elem._select_iterable for elem in self.clauses]
)
@@ -2558,7 +2566,7 @@ class ClauseList(
coercions.expect(self._text_converter_role, clause)
)
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return list(itertools.chain(*[c._from_objects for c in self.clauses]))
@@ -2580,8 +2588,12 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
@classmethod
def _process_clauses_for_boolean(
- cls, operator, continue_on, skip_on, clauses
- ):
+ cls,
+ operator: OperatorType,
+ continue_on: Any,
+ skip_on: Any,
+ clauses: Iterable[ColumnElement[Any]],
+ ) -> typing_Tuple[int, List[ColumnElement[Any]]]:
has_continue_on = None
convert_clauses = []
@@ -2623,9 +2635,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
operator: OperatorType,
continue_on: Any,
skip_on: Any,
- *clauses: _ColumnExpression[Any],
+ *clauses: _ColumnExpressionArgument[Any],
**kw: Any,
- ) -> BooleanClauseList:
+ ) -> ColumnElement[Any]:
lcc, convert_clauses = cls._process_clauses_for_boolean(
operator,
continue_on,
@@ -2639,7 +2651,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
if lcc > 1:
# multiple elements. Return regular BooleanClauseList
# which will link elements against the operator.
- return cls._construct_raw(operator, convert_clauses) # type: ignore[no-any-return] # noqa E501
+ return cls._construct_raw(operator, convert_clauses) # type: ignore # noqa E501
elif lcc == 1:
# just one element. return it as a single boolean element,
# not a list and discard the operator.
@@ -2663,7 +2675,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
return cls._construct_raw(operator) # type: ignore[no-any-return] # noqa E501
@classmethod
- def _construct_for_whereclause(cls, clauses):
+ def _construct_for_whereclause(
+ cls, clauses: Iterable[ColumnElement[Any]]
+ ) -> Optional[ColumnElement[bool]]:
operator, continue_on, skip_on = (
operators.and_,
True_._singleton,
@@ -2689,7 +2703,11 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
return None
@classmethod
- def _construct_raw(cls, operator, clauses=None):
+ def _construct_raw(
+ cls,
+ operator: OperatorType,
+ clauses: Optional[List[ColumnElement[Any]]] = None,
+ ) -> BooleanClauseList:
self = cls.__new__(cls)
self.clauses = clauses if clauses else []
self.group = True
@@ -2700,7 +2718,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
return self
@classmethod
- def and_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList:
+ def and_(
+ cls, *clauses: _ColumnExpressionArgument[bool]
+ ) -> ColumnElement[bool]:
r"""Produce a conjunction of expressions joined by ``AND``.
See :func:`_sql.and_` for full documentation.
@@ -2710,7 +2730,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
)
@classmethod
- def or_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList:
+ def or_(
+ cls, *clauses: _ColumnExpressionArgument[bool]
+ ) -> ColumnElement[bool]:
"""Produce a conjunction of expressions joined by ``OR``.
See :func:`_sql.or_` for full documentation.
@@ -2720,7 +2742,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
)
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> _SelectIterable:
return (self,)
def self_group(self, against=None):
@@ -2751,7 +2773,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
@util.preload_module("sqlalchemy.sql.sqltypes")
def __init__(
self,
- *clauses: _ColumnExpression[Any],
+ *clauses: _ColumnExpressionArgument[Any],
types: Optional[Sequence[_TypeEngineArgument[Any]]] = None,
):
sqltypes = util.preloaded.sql_sqltypes
@@ -2780,7 +2802,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
super(Tuple, self).__init__(*init_clauses)
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> _SelectIterable:
return (self,)
def _bind_param(self, operator, obj, type_=None, expanding=False):
@@ -2856,7 +2878,8 @@ class Case(ColumnElement[_T]):
def __init__(
self,
*whens: Union[
- typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any]
+ typing_Tuple[_ColumnExpressionArgument[bool], Any],
+ Mapping[Any, Any],
],
value: Optional[Any] = None,
else_: Optional[Any] = None,
@@ -2900,7 +2923,7 @@ class Case(ColumnElement[_T]):
else:
self.else_ = None
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(*[x._from_objects for x in self.get_children()])
@@ -2944,7 +2967,7 @@ class Cast(WrapsColumnExpression[_T]):
def __init__(
self,
- expression: _ColumnExpression[Any],
+ expression: _ColumnExpressionArgument[Any],
type_: _TypeEngineArgument[_T],
):
self.type = type_api.to_instance(type_)
@@ -2956,7 +2979,7 @@ class Cast(WrapsColumnExpression[_T]):
)
self.typeclause = TypeClause(self.type)
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.clause._from_objects
@@ -2995,7 +3018,7 @@ class TypeCoerce(WrapsColumnExpression[_T]):
def __init__(
self,
- expression: _ColumnExpression[Any],
+ expression: _ColumnExpressionArgument[Any],
type_: _TypeEngineArgument[_T],
):
self.type = type_api.to_instance(type_)
@@ -3006,7 +3029,7 @@ class TypeCoerce(WrapsColumnExpression[_T]):
apply_propagate_attrs=self,
)
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.clause._from_objects
@@ -3044,12 +3067,12 @@ class Extract(ColumnElement[int]):
expr: ColumnElement[Any]
field: str
- def __init__(self, field: str, expr: _ColumnExpression[Any]):
+ def __init__(self, field: str, expr: _ColumnExpressionArgument[Any]):
self.type = type_api.INTEGERTYPE
self.field = field
self.expr = coercions.expect(roles.ExpressionElementRole, expr)
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.expr._from_objects
@@ -3076,7 +3099,7 @@ class _label_reference(ColumnElement[_T]):
def __init__(self, element: ColumnElement[_T]):
self.element = element
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return []
@@ -3142,7 +3165,7 @@ class UnaryExpression(ColumnElement[_T]):
@classmethod
def _create_nulls_first(
cls,
- column: _ColumnExpression[_T],
+ column: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
@@ -3153,7 +3176,7 @@ class UnaryExpression(ColumnElement[_T]):
@classmethod
def _create_nulls_last(
cls,
- column: _ColumnExpression[_T],
+ column: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
@@ -3163,7 +3186,7 @@ class UnaryExpression(ColumnElement[_T]):
@classmethod
def _create_desc(
- cls, column: _ColumnExpression[_T]
+ cls, column: _ColumnExpressionArgument[_T]
) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
@@ -3174,7 +3197,7 @@ class UnaryExpression(ColumnElement[_T]):
@classmethod
def _create_asc(
cls,
- column: _ColumnExpression[_T],
+ column: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
return UnaryExpression(
coercions.expect(roles.ByOfRole, column),
@@ -3185,7 +3208,7 @@ class UnaryExpression(ColumnElement[_T]):
@classmethod
def _create_distinct(
cls,
- expr: _ColumnExpression[_T],
+ expr: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
col_expr = coercions.expect(roles.ExpressionElementRole, expr)
return UnaryExpression(
@@ -3202,7 +3225,7 @@ class UnaryExpression(ColumnElement[_T]):
else:
return None
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
@@ -3238,7 +3261,7 @@ class CollectionAggregate(UnaryExpression[_T]):
@classmethod
def _create_any(
- cls, expr: _ColumnExpression[_T]
+ cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
col_expr = coercions.expect(
roles.ExpressionElementRole,
@@ -3254,7 +3277,7 @@ class CollectionAggregate(UnaryExpression[_T]):
@classmethod
def _create_all(
- cls, expr: _ColumnExpression[_T]
+ cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
col_expr = coercions.expect(
roles.ExpressionElementRole,
@@ -3431,7 +3454,7 @@ class BinaryExpression(ColumnElement[_T]):
def is_comparison(self):
return operators.is_comparison(self.operator)
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.left._from_objects + self.right._from_objects
@@ -3557,7 +3580,7 @@ class Grouping(GroupedElement, ColumnElement[_T]):
else:
return []
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
@@ -3614,10 +3637,16 @@ class Over(ColumnElement[_T]):
self,
element: ColumnElement[_T],
partition_by: Optional[
- Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ Union[
+ Iterable[_ColumnExpressionArgument[Any]],
+ _ColumnExpressionArgument[Any],
+ ]
] = None,
order_by: Optional[
- Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ Union[
+ Iterable[_ColumnExpressionArgument[Any]],
+ _ColumnExpressionArgument[Any],
+ ]
] = None,
range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
@@ -3697,7 +3726,7 @@ class Over(ColumnElement[_T]):
def type(self):
return self.element.type
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(
@@ -3737,7 +3766,9 @@ class WithinGroup(ColumnElement[_T]):
order_by: Optional[ClauseList] = None
def __init__(
- self, element: FunctionElement[_T], *order_by: _ColumnExpression[Any]
+ self,
+ element: FunctionElement[_T],
+ *order_by: _ColumnExpressionArgument[Any],
):
self.element = element
if order_by is not None:
@@ -3774,7 +3805,7 @@ class WithinGroup(ColumnElement[_T]):
else:
return self.element.type
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(
@@ -3817,7 +3848,9 @@ class FunctionFilter(ColumnElement[_T]):
criterion: Optional[ColumnElement[bool]] = None
def __init__(
- self, func: FunctionElement[_T], *criterion: _ColumnExpression[bool]
+ self,
+ func: FunctionElement[_T],
+ *criterion: _ColumnExpressionArgument[bool],
):
self.func = func
self.filter(*criterion)
@@ -3847,10 +3880,16 @@ class FunctionFilter(ColumnElement[_T]):
def over(
self,
partition_by: Optional[
- Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ Union[
+ Iterable[_ColumnExpressionArgument[Any]],
+ _ColumnExpressionArgument[Any],
+ ]
] = None,
order_by: Optional[
- Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]]
+ Union[
+ Iterable[_ColumnExpressionArgument[Any]],
+ _ColumnExpressionArgument[Any],
+ ]
] = None,
range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
@@ -3890,7 +3929,7 @@ class FunctionFilter(ColumnElement[_T]):
def type(self):
return self.func.type
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return list(
itertools.chain(
@@ -3903,7 +3942,97 @@ class FunctionFilter(ColumnElement[_T]):
)
-class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
+class NamedColumn(ColumnElement[_T]):
+ is_literal = False
+ table: Optional[FromClause] = None
+ name: str
+ key: str
+
+ def _compare_name_for_result(self, other):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_label") and self._label == other._label
+ )
+
+ @util.ro_memoized_property
+ def description(self) -> str:
+ return self.name
+
+ @HasMemoized.memoized_attribute
+ def _tq_key_label(self):
+ """table qualified label based on column key.
+
+ for table-bound columns this is <tablename>_<column key/proxy key>;
+
+ all other expressions it resolves to key/proxy key.
+
+ """
+ proxy_key = self._proxy_key
+ if proxy_key and proxy_key != self.name:
+ return self._gen_tq_label(proxy_key)
+ else:
+ return self._tq_label
+
+ @HasMemoized.memoized_attribute
+ def _tq_label(self) -> Optional[str]:
+ """table qualified label based on column name.
+
+ for table-bound columns this is <tablename>_<columnname>; all other
+ expressions it resolves to .name.
+
+ """
+ return self._gen_tq_label(self.name)
+
+ @HasMemoized.memoized_attribute
+ def _render_label_in_columns_clause(self):
+ return True
+
+ @HasMemoized.memoized_attribute
+ def _non_anon_label(self):
+ return self.name
+
+ def _gen_tq_label(
+ self, name: str, dedupe_on_key: bool = True
+ ) -> Optional[str]:
+ return name
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ return BindParameter(
+ self.key,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ expanding=expanding,
+ )
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ name_is_truncatable=False,
+ disallow_is_literal=False,
+ **kw,
+ ):
+ c = ColumnClause(
+ coercions.expect(roles.TruncatedLabelRole, name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
+ type_=self.type,
+ _selectable=selectable,
+ is_literal=False,
+ )
+
+ c._propagate_attrs = selectable._propagate_attrs
+ if name is None:
+ c.key = self.key
+ c._proxies = [self]
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
+ return c.key, c
+
+
+class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
"""Represents a column label (AS).
Represent a label, as typically applied to any column-level
@@ -3925,7 +4054,7 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
def __init__(
self,
name: Optional[str],
- element: _ColumnExpression[_T],
+ element: _ColumnExpressionArgument[_T],
type_: Optional[_TypeEngineArgument[_T]] = None,
):
orig_element = element
@@ -3964,6 +4093,21 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
def __reduce__(self):
return self.__class__, (self.name, self._element, self.type)
+ @HasMemoized.memoized_attribute
+ def _render_label_in_columns_clause(self):
+ return True
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ expanding=expanding,
+ )
+
@util.memoized_property
def _is_implicitly_boolean(self):
return self.element._is_implicitly_boolean
@@ -4010,7 +4154,7 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
)
self.key = self._tq_label = self._tq_key_label = self.name
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
@@ -4047,96 +4191,6 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
return self.key, e
-class NamedColumn(ColumnElement[_T]):
- is_literal = False
- table: Optional[FromClause] = None
- name: str
- key: str
-
- def _compare_name_for_result(self, other):
- return (hasattr(other, "name") and self.name == other.name) or (
- hasattr(other, "_label") and self._label == other._label
- )
-
- @util.ro_memoized_property
- def description(self) -> str:
- return self.name
-
- @HasMemoized.memoized_attribute
- def _tq_key_label(self):
- """table qualified label based on column key.
-
- for table-bound columns this is <tablename>_<column key/proxy key>;
-
- all other expressions it resolves to key/proxy key.
-
- """
- proxy_key = self._proxy_key
- if proxy_key and proxy_key != self.name:
- return self._gen_tq_label(proxy_key)
- else:
- return self._tq_label
-
- @HasMemoized.memoized_attribute
- def _tq_label(self) -> Optional[str]:
- """table qualified label based on column name.
-
- for table-bound columns this is <tablename>_<columnname>; all other
- expressions it resolves to .name.
-
- """
- return self._gen_tq_label(self.name)
-
- @HasMemoized.memoized_attribute
- def _render_label_in_columns_clause(self):
- return True
-
- @HasMemoized.memoized_attribute
- def _non_anon_label(self):
- return self.name
-
- def _gen_tq_label(
- self, name: str, dedupe_on_key: bool = True
- ) -> Optional[str]:
- return name
-
- def _bind_param(self, operator, obj, type_=None, expanding=False):
- return BindParameter(
- self.key,
- obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True,
- expanding=expanding,
- )
-
- def _make_proxy(
- self,
- selectable,
- name=None,
- name_is_truncatable=False,
- disallow_is_literal=False,
- **kw,
- ):
- c = ColumnClause(
- coercions.expect(roles.TruncatedLabelRole, name or self.name)
- if name_is_truncatable
- else (name or self.name),
- type_=self.type,
- _selectable=selectable,
- is_literal=False,
- )
-
- c._propagate_attrs = selectable._propagate_attrs
- if name is None:
- c.key = self.key
- c._proxies = [self]
- if selectable._is_clone_of is not None:
- c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
- return c.key, c
-
-
class ColumnClause(
roles.DDLReferredColumnRole,
roles.LabeledColumnExprRole[_T],
@@ -4242,7 +4296,7 @@ class ColumnClause(
return super(ColumnClause, self)._clone(**kw)
- @HasMemoized.memoized_attribute
+ @HasMemoized_ro_memoized_attribute
def _from_objects(self) -> List[FromClause]:
t = self.table
if t is not None:
@@ -4395,7 +4449,7 @@ class TableValuedColumn(NamedColumn[_T]):
self.scalar_alias = clone(self.scalar_alias, **kw)
self.key = self.name = self.scalar_alias.name
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return [self.scalar_alias]
@@ -4409,7 +4463,7 @@ class CollationClause(ColumnElement[str]):
@classmethod
def _create_collation_expression(
- cls, expression: _ColumnExpression[str], collation: str
+ cls, expression: _ColumnExpressionArgument[str], collation: str
) -> BinaryExpression[str]:
expr = coercions.expect(roles.ExpressionElementRole, expression)
return BinaryExpression(
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 9e801a99f..3bca8b502 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -298,7 +298,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
return self.alias(name=name).column
- @property
+ @util.ro_non_memoized_property
def columns(self):
r"""The set of columns exported by this :class:`.FunctionElement`.
@@ -320,6 +320,11 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
SQL function expressions.
""" # noqa E501
+ return self.c
+
+ @util.ro_memoized_property
+ def c(self):
+ """synonym for :attr:`.FunctionElement.columns`."""
return ColumnCollection(
columns=[(col.key, col) for col in self._all_selected_columns]
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index beb73c1b5..86725f86f 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -6,27 +6,31 @@
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
-import typing
from typing import Any
from typing import Generic
from typing import Iterable
+from typing import List
from typing import Optional
+from typing import TYPE_CHECKING
from typing import TypeVar
from .. import util
-from ..util import TypingOnly
from ..util.typing import Literal
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
from ._typing import _PropagateAttrsType
+ from ._typing import _SelectIterable
+ from .base import _EntityNamespace
from .base import ColumnCollection
+ from .base import ReadOnlyColumnCollection
from .elements import ClauseElement
+ from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Label
+ from .elements import NamedColumn
from .selectable import FromClause
from .selectable import Subquery
-
_T = TypeVar("_T", bound=Any)
@@ -109,7 +113,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
_role_name = "Column expression or FROM clause"
@property
- def _select_iterable(self) -> Iterable[ColumnsClauseRole]:
+ def _select_iterable(self) -> _SelectIterable:
raise NotImplementedError()
@@ -202,32 +206,51 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
_is_subquery = False
- @property
- def _hide_froms(self) -> Iterable[FromClause]:
- raise NotImplementedError()
+ named_with_column: bool
+
+ if TYPE_CHECKING:
+
+ @util.ro_non_memoized_property
+ def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
+ ...
+
+ @util.ro_non_memoized_property
+ def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
+ ...
+
+ @util.ro_non_memoized_property
+ def entity_namespace(self) -> _EntityNamespace:
+ ...
+
+ @util.ro_non_memoized_property
+ def _hide_froms(self) -> Iterable[FromClause]:
+ ...
+
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ ...
class StrictFromClauseRole(FromClauseRole):
__slots__ = ()
# does not allow text() or select() objects
- c: ColumnCollection[Any]
+ if TYPE_CHECKING:
- # this should be ->str , however, working around:
- # https://github.com/python/mypy/issues/12440
- @util.ro_non_memoized_property
- def description(self) -> str:
- raise NotImplementedError()
+ @util.ro_non_memoized_property
+ def description(self) -> str:
+ ...
class AnonymizedFromClauseRole(StrictFromClauseRole):
__slots__ = ()
- # calls .alias() as a post processor
- def _anonymous_fromclause(
- self, name: Optional[str] = None, flat: bool = False
- ) -> FromClause:
- raise NotImplementedError()
+ if TYPE_CHECKING:
+
+ def _anonymous_fromclause(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> FromClause:
+ ...
class ReturnsRowsRole(SQLRole):
@@ -283,6 +306,16 @@ class DMLTableRole(FromClauseRole):
__slots__ = ()
_role_name = "subject table for an INSERT, UPDATE or DELETE"
+ if TYPE_CHECKING:
+
+ @util.ro_non_memoized_property
+ def primary_key(self) -> Iterable[NamedColumn[Any]]:
+ ...
+
+ @util.ro_non_memoized_property
+ def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
+ ...
+
class DMLColumnRole(SQLRole):
__slots__ = ()
@@ -315,36 +348,3 @@ class DDLReferredColumnRole(DDLConstraintColumnRole):
_role_name = (
"String column name or Column object for DDL foreign key constraint"
)
-
-
-class HasClauseElement(TypingOnly):
- """indicates a class that has a __clause_element__() method"""
-
- __slots__ = ()
-
- if typing.TYPE_CHECKING:
-
- def __clause_element__(self) -> ClauseElement:
- ...
-
-
-class HasColumnElementClauseElement(TypingOnly):
- """indicates a class that has a __clause_element__() method"""
-
- __slots__ = ()
-
- if typing.TYPE_CHECKING:
-
- def __clause_element__(self) -> ColumnElement[Any]:
- ...
-
-
-class HasFromClauseElement(HasClauseElement, TypingOnly):
- """indicates a class that has a __clause_element__() method"""
-
- __slots__ = ()
-
- if typing.TYPE_CHECKING:
-
- def __clause_element__(self) -> FromClause:
- ...
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 7206cfdba..0e3e24a14 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -37,6 +37,7 @@ import typing
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import Optional
@@ -54,7 +55,6 @@ from . import ddl
from . import roles
from . import type_api
from . import visitors
-from .base import ColumnCollection
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
@@ -78,6 +78,7 @@ from ..util.typing import Protocol
from ..util.typing import TypeGuard
if typing.TYPE_CHECKING:
+ from .base import ReadOnlyColumnCollection
from .type_api import TypeEngine
from ..engine import Connection
from ..engine import Engine
@@ -273,6 +274,16 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
__visit_name__ = "table"
+ if TYPE_CHECKING:
+
+ @util.ro_non_memoized_property
+ def primary_key(self) -> PrimaryKeyConstraint:
+ ...
+
+ @util.ro_non_memoized_property
+ def foreign_keys(self) -> Set[ForeignKey]:
+ ...
+
constraints: Set[Constraint]
"""A collection of all :class:`_schema.Constraint` objects associated with
this :class:`_schema.Table`.
@@ -316,12 +327,18 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
]
if TYPE_CHECKING:
-
- @util.non_memoized_property
- def columns(self) -> ColumnCollection[Column[Any]]:
+ # we are upgrading .c and .columns to return Column, not
+ # ColumnClause. mypy typically sees this as incompatible because
+ # the contract of TableClause is that we can put a ColumnClause
+ # into this collection. does not recognize its immutability
+ # for the moment.
+ @util.ro_non_memoized_property
+ def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: # type: ignore # noqa: E501
...
- c: ColumnCollection[Column[Any]]
+ @util.ro_non_memoized_property
+ def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: # type: ignore # noqa: E501
+ ...
def _gen_cache_key(self, anon_map, bindparams):
if self._annotations:
@@ -737,7 +754,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
PrimaryKeyConstraint(
_implicit_generated=True
)._set_parent_with_dispatch(self)
- self.foreign_keys = set()
+ self.foreign_keys = set() # type: ignore
self._extra_dependencies = set()
if self.schema is not None:
self.fullname = "%s.%s" % (self.schema, self.name)
@@ -3537,7 +3554,7 @@ class ColumnCollectionMixin:
"""
- columns: ColumnCollection[Column[Any]]
+ _columns: DedupeColumnCollection[Column[Any]]
_allow_multiple_tables = False
@@ -3551,7 +3568,7 @@ class ColumnCollectionMixin:
def __init__(self, *columns, **kw):
_autoattach = kw.pop("_autoattach", True)
self._column_flag = kw.pop("_column_flag", False)
- self.columns = DedupeColumnCollection()
+ self._columns = DedupeColumnCollection()
processed_expressions = kw.pop("_gather_expressions", None)
if processed_expressions is not None:
@@ -3624,6 +3641,14 @@ class ColumnCollectionMixin:
)
)
+ @util.ro_memoized_property
+ def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]:
+ return self._columns.as_readonly()
+
+ @util.ro_memoized_property
+ def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]:
+ return self._columns.as_readonly()
+
def _col_expressions(self, table: Table) -> List[Column[Any]]:
return [
table.c[col] if isinstance(col, str) else col
@@ -3635,7 +3660,7 @@ class ColumnCollectionMixin:
assert isinstance(parent, Table)
for col in self._col_expressions(parent):
if col is not None:
- self.columns.add(col)
+ self._columns.add(col)
class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
@@ -3668,7 +3693,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
)
- columns: DedupeColumnCollection[Column[Any]]
+ columns: ReadOnlyColumnCollection[str, Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
@@ -3679,7 +3704,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
ColumnCollectionMixin._set_parent(self, table)
def __contains__(self, x):
- return x in self.columns
+ return x in self._columns
@util.deprecated(
"1.4",
@@ -3708,7 +3733,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
initially=self.initially,
*[
_copy_expression(expr, self.parent, target_table)
- for expr in self.columns
+ for expr in self._columns
],
**constraint_kwargs,
)
@@ -3723,13 +3748,13 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
"""
- return self.columns.contains_column(col)
+ return self._columns.contains_column(col)
- def __iter__(self):
- return iter(self.columns)
+ def __iter__(self) -> Iterator[Column[Any]]:
+ return iter(self._columns)
- def __len__(self):
- return len(self.columns)
+ def __len__(self) -> int:
+ return len(self._columns)
class CheckConstraint(ColumnCollectionConstraint):
@@ -4002,10 +4027,10 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
self._set_parent_with_dispatch(table)
def _append_element(self, column: Column[Any], fk: ForeignKey) -> None:
- self.columns.add(column)
+ self._columns.add(column)
self.elements.append(fk)
- columns: DedupeColumnCollection[Column[Any]]
+ columns: ReadOnlyColumnCollection[str, Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
@@ -4072,7 +4097,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
"""
if hasattr(self, "parent"):
- return self.columns.keys()
+ return self._columns.keys()
else:
return [
col.key if isinstance(col, ColumnElement) else str(col)
@@ -4095,7 +4120,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
"named '%s' is present." % (table.description, ke.args[0])
) from ke
- for col, fk in zip(self.columns, self.elements):
+ for col, fk in zip(self._columns, self.elements):
if not hasattr(fk, "parent") or fk.parent is not col:
fk._set_parent_with_dispatch(col)
@@ -4226,7 +4251,11 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
table.constraints.add(self)
table_pks = [c for c in table.c if c.primary_key]
- if self.columns and table_pks and set(table_pks) != set(self.columns):
+ if (
+ self._columns
+ and table_pks
+ and set(table_pks) != set(self._columns)
+ ):
util.warn(
"Table '%s' specifies columns %s as primary_key=True, "
"not matching locally specified columns %s; setting the "
@@ -4235,18 +4264,18 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
% (
table.name,
", ".join("'%s'" % c.name for c in table_pks),
- ", ".join("'%s'" % c.name for c in self.columns),
- ", ".join("'%s'" % c.name for c in self.columns),
+ ", ".join("'%s'" % c.name for c in self._columns),
+ ", ".join("'%s'" % c.name for c in self._columns),
)
)
table_pks[:] = []
- for c in self.columns:
+ for c in self._columns:
c.primary_key = True
if c._user_defined_nullable is NULL_UNSPECIFIED:
c.nullable = False
if table_pks:
- self.columns.extend(table_pks)
+ self._columns.extend(table_pks)
def _reload(self, columns):
"""repopulate this :class:`.PrimaryKeyConstraint` given
@@ -4272,14 +4301,14 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
for col in columns:
col.primary_key = True
- self.columns.extend(columns)
+ self._columns.extend(columns)
PrimaryKeyConstraint._autoincrement_column._reset(self)
self._set_parent_with_dispatch(self.table)
def _replace(self, col):
PrimaryKeyConstraint._autoincrement_column._reset(self)
- self.columns.replace(col)
+ self._columns.replace(col)
self.dispatch._sa_event_column_added_to_pk_constraint(self, col)
@@ -4288,9 +4317,9 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
autoinc = self._autoincrement_column
if autoinc is not None:
- return [autoinc] + [c for c in self.columns if c is not autoinc]
+ return [autoinc] + [c for c in self._columns if c is not autoinc]
else:
- return list(self.columns)
+ return list(self._columns)
@util.memoized_property
def _autoincrement_column(self):
@@ -4323,8 +4352,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
return False
return True
- if len(self.columns) == 1:
- col = list(self.columns)[0]
+ if len(self._columns) == 1:
+ col = list(self._columns)[0]
if col.autoincrement is True:
_validate_autoinc(col, True)
@@ -4337,7 +4366,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
else:
autoinc = None
- for col in self.columns:
+ for col in self._columns:
if col.autoincrement is True:
_validate_autoinc(col, True)
if autoinc is not None:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 2f37317f2..24edc1cae 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -16,13 +16,14 @@ from __future__ import annotations
import collections
from enum import Enum
import itertools
-from operator import attrgetter
import typing
from typing import Any as TODO_Any
from typing import Any
from typing import Iterable
+from typing import List
from typing import NamedTuple
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -34,13 +35,15 @@ from . import roles
from . import traversals
from . import type_api
from . import visitors
-from ._typing import _ColumnsClauseElement
+from ._typing import _ColumnsClauseArgument
+from ._typing import is_column_element
from .annotation import Annotated
from .annotation import SupportsCloneAnnotations
from .base import _clone
from .base import _cloned_difference
from .base import _cloned_intersection
from .base import _entity_namespace_key
+from .base import _EntityNamespace
from .base import _expand_cloned
from .base import _from_objects
from .base import _generative
@@ -78,6 +81,13 @@ and_ = BooleanClauseList.and_
_T = TypeVar("_T", bound=Any)
+if TYPE_CHECKING:
+ from ._typing import _SelectIterable
+ from .base import ReadOnlyColumnCollection
+ from .elements import NamedColumn
+ from .schema import ForeignKey
+ from .schema import PrimaryKeyConstraint
+
class _OffsetLimitParam(BindParameter):
inherit_cache = True
@@ -111,8 +121,8 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
def selectable(self):
return self
- @property
- def _all_selected_columns(self):
+ @util.non_memoized_property
+ def _all_selected_columns(self) -> _SelectIterable:
"""A sequence of column expression objects that represents the
"selected" columns of this :class:`_expression.ReturnsRows`.
@@ -457,7 +467,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
__visit_name__ = "fromclause"
named_with_column = False
- @property
+ @util.ro_non_memoized_property
def _hide_froms(self) -> Iterable[FromClause]:
return ()
@@ -707,10 +717,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
- return self.columns
+ return self.c
- @util.memoized_property
- def columns(self) -> ColumnCollection[Any]:
+ @util.ro_non_memoized_property
+ def columns(self) -> ReadOnlyColumnCollection[str, Any]:
"""A named-based collection of :class:`_expression.ColumnElement`
objects maintained by this :class:`_expression.FromClause`.
@@ -723,14 +733,23 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
:return: a :class:`.ColumnCollection` object.
"""
+ return self.c
+
+ @util.ro_memoized_property
+ def c(self) -> ReadOnlyColumnCollection[str, Any]:
+ """
+ A synonym for :attr:`.FromClause.columns`
+
+ :return: a :class:`.ColumnCollection`
+ """
if "_columns" not in self.__dict__:
self._init_collections()
self._populate_column_collection()
- return self._columns.as_immutable()
+ return self._columns.as_readonly()
- @property
- def entity_namespace(self):
+ @util.ro_non_memoized_property
+ def entity_namespace(self) -> _EntityNamespace:
"""Return a namespace used for name-based access in SQL expressions.
This is the namespace that is used to resolve "filter_by()" type
@@ -743,10 +762,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
alternative results.
"""
- return self.columns
+ return self.c
- @util.memoized_property
- def primary_key(self):
+ @util.ro_memoized_property
+ def primary_key(self) -> Iterable[NamedColumn[Any]]:
"""Return the iterable collection of :class:`_schema.Column` objects
which comprise the primary key of this :class:`_selectable.FromClause`.
@@ -759,8 +778,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._populate_column_collection()
return self.primary_key
- @util.memoized_property
- def foreign_keys(self):
+ @util.ro_memoized_property
+ def foreign_keys(self) -> Iterable[ForeignKey]:
"""Return the collection of :class:`_schema.ForeignKey` marker objects
which this FromClause references.
@@ -791,28 +810,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
- for key in ["_columns", "columns", "primary_key", "foreign_keys"]:
+ for key in ["_columns", "columns", "c", "primary_key", "foreign_keys"]:
self.__dict__.pop(key, None)
- # this is awkward. maybe there's a better way
- if TYPE_CHECKING:
- c: ColumnCollection[Any]
- else:
- c = property(
- attrgetter("columns"),
- doc="""
- A named-based collection of :class:`_expression.ColumnElement`
- objects maintained by this :class:`_expression.FromClause`.
-
- The :attr:`_sql.FromClause.c` attribute is an alias for the
- :attr:`_sql.FromClause.columns` attribute.
-
- :return: a :class:`.ColumnCollection`
-
- """,
- )
-
- _select_iterable = property(attrgetter("columns"))
+ @util.ro_non_memoized_property
+ def _select_iterable(self) -> _SelectIterable:
+ return self.c
def _init_collections(self):
assert "_columns" not in self.__dict__
@@ -820,8 +823,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
assert "foreign_keys" not in self.__dict__
self._columns = ColumnCollection()
- self.primary_key = ColumnSet()
- self.foreign_keys = set()
+ self.primary_key = ColumnSet() # type: ignore
+ self.foreign_keys = set() # type: ignore
@property
def _cols_populated(self):
@@ -1050,9 +1053,7 @@ class Join(roles.DMLTableRole, FromClause):
@util.preload_module("sqlalchemy.sql.util")
def _populate_column_collection(self):
sqlutil = util.preloaded.sql_util
- columns = [c for c in self.left.columns] + [
- c for c in self.right.columns
- ]
+ columns = [c for c in self.left.c] + [c for c in self.right.c]
self.primary_key.extend(
sqlutil.reduce_columns(
@@ -1300,14 +1301,14 @@ class Join(roles.DMLTableRole, FromClause):
.alias(name)
)
- @property
+ @util.ro_non_memoized_property
def _hide_froms(self) -> Iterable[FromClause]:
return itertools.chain(
*[_from_objects(x.left, x.right) for x in self._cloned_set]
)
- @property
- def _from_objects(self):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return [self] + self.left._from_objects + self.right._from_objects
@@ -1415,7 +1416,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
self._reset_column_collection()
@property
- def _from_objects(self):
+ def _from_objects(self) -> List[FromClause]:
return [self]
@@ -2329,10 +2330,14 @@ class FromGrouping(GroupedElement, FromClause):
def _init_collections(self):
pass
- @property
+ @util.ro_non_memoized_property
def columns(self):
return self.element.columns
+ @util.ro_non_memoized_property
+ def c(self):
+ return self.element.columns
+
@property
def primary_key(self):
return self.element.primary_key
@@ -2350,12 +2355,12 @@ class FromGrouping(GroupedElement, FromClause):
def _anonymous_fromclause(self, **kw):
return FromGrouping(self.element._anonymous_fromclause(**kw))
- @property
+ @util.ro_non_memoized_property
def _hide_froms(self) -> Iterable[FromClause]:
return self.element._hide_froms
- @property
- def _from_objects(self):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
def __getstate__(self):
@@ -2436,6 +2441,16 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
if kw:
raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw))
+ if TYPE_CHECKING:
+
+ @util.ro_non_memoized_property
+ def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
+ ...
+
+ @util.ro_non_memoized_property
+ def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
+ ...
+
def __str__(self):
if self.schema is not None:
return self.schema + "." + self.name
@@ -2507,8 +2522,8 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
"""
return util.preloaded.sql_dml.Delete(self)
- @property
- def _from_objects(self):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return [self]
@@ -2669,11 +2684,14 @@ class Values(Generative, NamedFromClause):
self._columns.add(c)
c.table = self
- @property
- def _from_objects(self):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return [self]
+SelfSelectBase = TypeVar("SelfSelectBase", bound=Any)
+
+
class SelectBase(
roles.SelectStatementRole,
roles.DMLSelectRole,
@@ -2697,12 +2715,27 @@ class SelectBase(
_is_select_statement = True
is_select = True
- def _generate_fromclause_column_proxies(self, fromclause):
+ def _generate_fromclause_column_proxies(
+ self, fromclause: FromClause
+ ) -> None:
raise NotImplementedError()
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
self._reset_memoizations()
+ def _generate_columns_plus_names(
+ self, anon_for_dupe_key: bool
+ ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
+ raise NotImplementedError()
+
+ def set_label_style(
+ self: SelfSelectBase, label_style: SelectLabelStyle
+ ) -> SelfSelectBase:
+ raise NotImplementedError()
+
+ def get_label_style(self) -> SelectLabelStyle:
+ raise NotImplementedError()
+
@property
def selected_columns(self):
"""A :class:`_expression.ColumnCollection`
@@ -2733,8 +2766,8 @@ class SelectBase(
"""
raise NotImplementedError()
- @property
- def _all_selected_columns(self):
+ @util.non_memoized_property
+ def _all_selected_columns(self) -> _SelectIterable:
"""A sequence of expressions that correspond to what is rendered
in the columns clause, including :class:`_sql.TextClause`
constructs.
@@ -2893,8 +2926,8 @@ class SelectBase(
"""
return Lateral._factory(self, name)
- @property
- def _from_objects(self):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return [self]
def subquery(self, name=None):
@@ -2979,6 +3012,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
_is_select_container = True
+ element: SelectBase
+
def __init__(self, element):
self.element = coercions.expect(roles.SelectStatementRole, element)
@@ -2990,37 +3025,34 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
return self
def get_label_style(self) -> SelectLabelStyle:
- return self._label_style
+ return self.element.get_label_style()
def set_label_style(
self, label_style: SelectLabelStyle
- ) -> "SelectStatementGrouping":
+ ) -> SelectStatementGrouping:
return SelectStatementGrouping(
self.element.set_label_style(label_style)
)
@property
- def _label_style(self):
- return self.element._label_style
-
- @property
def select_statement(self):
return self.element
def self_group(self, against=None):
return self
- def _generate_columns_plus_names(self, anon_for_dupe_key):
+ def _generate_columns_plus_names(
+ self, anon_for_dupe_key: bool
+ ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
return self.element._generate_columns_plus_names(anon_for_dupe_key)
- def _generate_fromclause_column_proxies(self, subquery):
+ def _generate_fromclause_column_proxies(
+ self, subquery: FromClause
+ ) -> None:
self.element._generate_fromclause_column_proxies(subquery)
- def _generate_proxy_for_new_column(self, column, subquery):
- return self.element._generate_proxy_for_new_column(subquery)
-
- @property
- def _all_selected_columns(self):
+ @util.non_memoized_property
+ def _all_selected_columns(self) -> _SelectIterable:
return self.element._all_selected_columns
@property
@@ -3039,8 +3071,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
return self.element.selected_columns
- @property
- def _from_objects(self):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
@@ -3612,10 +3644,10 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return True
return False
- def _set_label_style(self, style):
+ def set_label_style(self, style):
if self._label_style is not style:
self = self._generate()
- select_0 = self.selects[0]._set_label_style(style)
+ select_0 = self.selects[0].set_label_style(style)
self.selects = [select_0] + self.selects[1:]
return self
@@ -3665,8 +3697,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
for select in self.selects:
select._refresh_for_new_column(column)
- @property
- def _all_selected_columns(self):
+ @util.non_memoized_property
+ def _all_selected_columns(self) -> _SelectIterable:
return self.selects[0]._all_selected_columns
@property
@@ -3701,8 +3733,18 @@ class SelectState(util.MemoizedSlots, CompileState):
"_label_resolve_dict",
)
- class default_select_compile_options(CacheableOptions):
- _cache_key_traversal = []
+ if TYPE_CHECKING:
+ default_select_compile_options: CacheableOptions
+ else:
+
+ class default_select_compile_options(CacheableOptions):
+ _cache_key_traversal = []
+
+ if TYPE_CHECKING:
+
+ @classmethod
+ def get_plugin_class(cls, statement: Select) -> SelectState:
+ ...
def __init__(self, statement, compiler, **kw):
self.statement = statement
@@ -3966,7 +4008,7 @@ class SelectState(util.MemoizedSlots, CompileState):
return None
@classmethod
- def all_selected_columns(cls, statement):
+ def all_selected_columns(cls, statement: Select) -> _SelectIterable:
return [c for c in _select_iterables(statement._raw_columns)]
def _setup_joins(self, args, raw_columns):
@@ -4205,15 +4247,17 @@ class Select(
_memoized_select_entities: Tuple[TODO_Any, ...] = ()
_distinct = False
- _distinct_on: Tuple[ColumnElement, ...] = ()
+ _distinct_on: Tuple[ColumnElement[Any], ...] = ()
_correlate: Tuple[FromClause, ...] = ()
_correlate_except: Optional[Tuple[FromClause, ...]] = None
- _where_criteria: Tuple[ColumnElement, ...] = ()
- _having_criteria: Tuple[ColumnElement, ...] = ()
+ _where_criteria: Tuple[ColumnElement[Any], ...] = ()
+ _having_criteria: Tuple[ColumnElement[Any], ...] = ()
_from_obj: Tuple[FromClause, ...] = ()
_auto_correlate = True
- _compile_options = SelectState.default_select_compile_options
+ _compile_options: CacheableOptions = (
+ SelectState.default_select_compile_options
+ )
_traverse_internals = (
[
@@ -4264,7 +4308,7 @@ class Select(
stmt.__dict__.update(kw)
return stmt
- def __init__(self, *entities: _ColumnsClauseElement):
+ def __init__(self, *entities: _ColumnsClauseArgument):
r"""Construct a new :class:`_expression.Select`.
The public constructor for :class:`_expression.Select` is the
@@ -4286,7 +4330,7 @@ class Select(
cols = list(elem._select_iterable)
return cols[0].type
- def filter(self, *criteria):
+ def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect:
"""A synonym for the :meth:`_future.Select.where` method."""
return self.where(*criteria)
@@ -4896,7 +4940,7 @@ class Select(
return self
@property
- def whereclause(self):
+ def whereclause(self) -> Optional[ColumnElement[Any]]:
"""Return the completed WHERE clause for this
:class:`_expression.Select` statement.
@@ -5161,12 +5205,12 @@ class Select(
[
(conv(c), c)
for c in self._all_selected_columns
- if not c._is_text_clause
+ if is_column_element(c)
]
- ).as_immutable()
+ ).as_readonly()
@HasMemoized.memoized_attribute
- def _all_selected_columns(self):
+ def _all_selected_columns(self) -> Sequence[ColumnElement[Any]]:
meth = SelectState.get_plugin_class(self).all_selected_columns
return list(meth(self))
@@ -5175,7 +5219,9 @@ class Select(
self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
return self
- def _generate_columns_plus_names(self, anon_for_dupe_key):
+ def _generate_columns_plus_names(
+ self, anon_for_dupe_key: bool
+ ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
"""Generate column names as rendered in a SELECT statement by
the compiler.
@@ -5805,13 +5851,13 @@ class TextualSelect(SelectBase):
"""
return ColumnCollection(
(c.key, c) for c in self.column_args
- ).as_immutable()
+ ).as_readonly()
- @property
- def _all_selected_columns(self):
+ @util.non_memoized_property
+ def _all_selected_columns(self) -> _SelectIterable:
return self.column_args
- def _set_label_style(self, style):
+ def set_label_style(self, style):
return self
def _ensure_disambiguated_names(self):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 5114a2431..cdce49f7b 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -15,6 +15,7 @@ from itertools import chain
import typing
from typing import Any
from typing import cast
+from typing import Iterator
from typing import Optional
from . import coercions
@@ -33,6 +34,7 @@ from .elements import _find_columns # noqa
from .elements import _label_reference
from .elements import _textual_label_reference
from .elements import BindParameter
+from .elements import ClauseElement # noqa
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Grouping
@@ -51,6 +53,7 @@ from .. import exc
from .. import util
if typing.TYPE_CHECKING:
+ from .roles import FromClauseRole
from ..engine.interfaces import _AnyExecuteParams
from ..engine.interfaces import _AnyMultiExecuteParams
from ..engine.interfaces import _AnySingleExecuteParams
@@ -404,7 +407,7 @@ def clause_is_present(clause, search):
return False
-def tables_from_leftmost(clause):
+def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]:
if isinstance(clause, Join):
for t in tables_from_leftmost(clause.left):
yield t
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 7e616cd74..406c8af24 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -21,9 +21,7 @@ from ._collections import flatten_iterator as flatten_iterator
from ._collections import has_dupes as has_dupes
from ._collections import has_intersection as has_intersection
from ._collections import IdentitySet as IdentitySet
-from ._collections import ImmutableContainer as ImmutableContainer
from ._collections import immutabledict as immutabledict
-from ._collections import ImmutableProperties as ImmutableProperties
from ._collections import LRUCache as LRUCache
from ._collections import merge_lists_w_ordering as merge_lists_w_ordering
from ._collections import ordered_column_set as ordered_column_set
@@ -33,6 +31,8 @@ from ._collections import OrderedProperties as OrderedProperties
from ._collections import OrderedSet as OrderedSet
from ._collections import PopulateDict as PopulateDict
from ._collections import Properties as Properties
+from ._collections import ReadOnlyContainer as ReadOnlyContainer
+from ._collections import ReadOnlyProperties as ReadOnlyProperties
from ._collections import ScopedRegistry as ScopedRegistry
from ._collections import sort_dictionary as sort_dictionary
from ._collections import ThreadLocalRegistry as ThreadLocalRegistry
@@ -107,6 +107,9 @@ from .langhelpers import get_func_kwargs as get_func_kwargs
from .langhelpers import getargspec_init as getargspec_init
from .langhelpers import has_compiled_ext as has_compiled_ext
from .langhelpers import HasMemoized as HasMemoized
+from .langhelpers import (
+ HasMemoized_ro_memoized_attribute as HasMemoized_ro_memoized_attribute,
+)
from .langhelpers import hybridmethod as hybridmethod
from .langhelpers import hybridproperty as hybridproperty
from .langhelpers import inject_docstring_text as inject_docstring_text
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 2d974b737..bd73bf714 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -38,13 +38,13 @@ from .typing import Protocol
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_collections import immutabledict as immutabledict
from ._py_collections import IdentitySet as IdentitySet
- from ._py_collections import ImmutableContainer as ImmutableContainer
+ from ._py_collections import ReadOnlyContainer as ReadOnlyContainer
from ._py_collections import ImmutableDictBase as ImmutableDictBase
from ._py_collections import OrderedSet as OrderedSet
from ._py_collections import unique_list as unique_list
else:
from sqlalchemy.cyextension.immutabledict import (
- ImmutableContainer as ImmutableContainer,
+ ReadOnlyContainer as ReadOnlyContainer,
)
from sqlalchemy.cyextension.immutabledict import (
ImmutableDictBase as ImmutableDictBase,
@@ -213,10 +213,10 @@ class Properties(Generic[_T]):
def __contains__(self, key: str) -> bool:
return key in self._data
- def as_immutable(self) -> "ImmutableProperties[_T]":
+ def as_readonly(self) -> "ReadOnlyProperties[_T]":
"""Return an immutable proxy for this :class:`.Properties`."""
- return ImmutableProperties(self._data)
+ return ReadOnlyProperties(self._data)
def update(self, value):
self._data.update(value)
@@ -263,7 +263,7 @@ class OrderedProperties(Properties[_T]):
Properties.__init__(self, OrderedDict())
-class ImmutableProperties(ImmutableContainer, Properties[_T]):
+class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
"""Provide immutable dict/object attribute to an underlying dictionary."""
__slots__ = ()
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index d50352930..1016871aa 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -29,37 +29,45 @@ _KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
-class ImmutableContainer:
+class ReadOnlyContainer:
__slots__ = ()
+ def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
+ raise TypeError(
+ "%s object is immutable and/or readonly" % self.__class__.__name__
+ )
+
def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
raise TypeError("%s object is immutable" % self.__class__.__name__)
def __delitem__(self, key: Any) -> NoReturn:
- self._immutable()
+ self._readonly()
def __setitem__(self, key: Any, value: Any) -> NoReturn:
- self._immutable()
+ self._readonly()
def __setattr__(self, key: str, value: Any) -> NoReturn:
- self._immutable()
+ self._readonly()
-class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]):
- def clear(self) -> NoReturn:
+class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]):
+ def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
self._immutable()
+ def clear(self) -> NoReturn:
+ self._readonly()
+
def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
- self._immutable()
+ self._readonly()
def popitem(self) -> NoReturn:
- self._immutable()
+ self._readonly()
def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
- self._immutable()
+ self._readonly()
def update(self, *arg: Any, **kw: Any) -> NoReturn:
- self._immutable()
+ self._readonly()
class immutabledict(ImmutableDictBase[_KT, _VT]):
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 8cf50c724..9e1194e23 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -1248,6 +1248,7 @@ if TYPE_CHECKING:
# of a property, meaning assignment needs to be disallowed
ro_memoized_property = property
ro_non_memoized_property = property
+
else:
memoized_property = ro_memoized_property = _memoized_property
non_memoized_property = ro_non_memoized_property = _non_memoized_property
@@ -1348,6 +1349,12 @@ class HasMemoized:
return update_wrapper(oneshot, fn)
+if TYPE_CHECKING:
+ HasMemoized_ro_memoized_attribute = property
+else:
+ HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
+
+
class MemoizedSlots:
"""Apply memoized items to an object using a __getattr__ scheme.
diff --git a/pyproject.toml b/pyproject.toml
index aa2790b04..cc79e8646 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,7 +81,6 @@ module = [
"sqlalchemy.sql.selectable", # would be nice as strict
"sqlalchemy.sql.functions", # would be nice as strict
"sqlalchemy.sql.lambdas",
- "sqlalchemy.sql.dml", # would be nice as strict
"sqlalchemy.sql.util",
# not yet classified:
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index 67fcc8870..fc61e39b6 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -357,8 +357,8 @@ class ImmutableTest(fixtures.TestBase):
with expect_raises_message(TypeError, "object is immutable"):
m()
- def test_immutable_properties(self):
- d = util.ImmutableProperties({3: 4})
+ def test_readonly_properties(self):
+ d = util.ReadOnlyProperties({3: 4})
calls = (
lambda: d.__delitem__(1),
lambda: d.__setitem__(2, 3),
@@ -563,7 +563,7 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL):
eq_(keys, ["c1", "foo", "c3"])
ne_(id(keys), id(cc.keys()))
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci.keys(), ["c1", "foo", "c3"])
def test_values(self):
@@ -576,7 +576,7 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL):
eq_(val, [c1, c2, c3])
ne_(id(val), id(cc.values()))
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci.values(), [c1, c2, c3])
def test_items(self):
@@ -589,7 +589,7 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL):
eq_(items, [("c1", c1), ("foo", c2), ("c3", c3)])
ne_(id(items), id(cc.items()))
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci.items(), [("c1", c1), ("foo", c2), ("c3", c3)])
def test_key_index_error(self):
@@ -732,7 +732,7 @@ class ColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
self._assert_collection_integrity(cc)
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci._all_columns, [c1, c2a, c3, c2b])
eq_(list(ci), [c1, c2a, c3, c2b])
eq_(ci.keys(), ["c1", "c2", "c3", "c2"])
@@ -763,7 +763,7 @@ class ColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
self._assert_collection_integrity(cc)
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci._all_columns, [c1, c2a, c3, c2b])
eq_(list(ci), [c1, c2a, c3, c2b])
eq_(ci.keys(), ["c1", "c2", "c3", "c2"])
@@ -786,7 +786,7 @@ class ColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
assert cc.contains_column(c2)
self._assert_collection_integrity(cc)
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci._all_columns, [c1, c2, c3, c2])
eq_(list(ci), [c1, c2, c3, c2])
@@ -821,7 +821,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
c2.key = "foo"
cc = self._column_collection(columns=[("c1", c1), ("foo", c2)])
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
d = {"cc": cc, "ci": ci}
@@ -922,7 +922,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
assert cc.contains_column(c2)
self._assert_collection_integrity(cc)
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci._all_columns, [c1, c2, c3])
eq_(list(ci), [c1, c2, c3])
@@ -944,13 +944,13 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
assert cc.contains_column(c2)
self._assert_collection_integrity(cc)
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(ci._all_columns, [c1, c2, c3])
eq_(list(ci), [c1, c2, c3])
def test_replace(self):
cc = DedupeColumnCollection()
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
c1, c2a, c3, c2b = (
column("c1"),
@@ -979,7 +979,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
def test_replace_key_matches_name_of_another(self):
cc = DedupeColumnCollection()
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
c1, c2a, c3, c2b = (
column("c1"),
@@ -1009,7 +1009,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
def test_replace_key_matches(self):
cc = DedupeColumnCollection()
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
c1, c2a, c3, c2b = (
column("c1"),
@@ -1041,7 +1041,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
def test_replace_name_matches(self):
cc = DedupeColumnCollection()
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
c1, c2a, c3, c2b = (
column("c1"),
@@ -1073,7 +1073,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
def test_replace_no_match(self):
cc = DedupeColumnCollection()
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
c1, c2, c3, c4 = column("c1"), column("c2"), column("c3"), column("c4")
c4.key = "X"
@@ -1123,7 +1123,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
cc = DedupeColumnCollection(
columns=[("c1", c1), ("c2", c2), ("c3", c3)]
)
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
eq_(cc._all_columns, [c1, c2, c3])
eq_(list(cc), [c1, c2, c3])
@@ -1184,7 +1184,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase):
def test_dupes_extend(self):
cc = DedupeColumnCollection()
- ci = cc.as_immutable()
+ ci = cc.as_readonly()
c1, c2a, c3, c2b = (
column("c1"),
@@ -3044,7 +3044,7 @@ class TestProperties(fixtures.TestBase):
def test_pickle_immuatbleprops(self):
data = {"hello": "bla"}
- props = util.Properties(data).as_immutable()
+ props = util.Properties(data).as_readonly()
for loader, dumper in picklers():
s = dumper(props)
diff --git a/test/profiles.txt b/test/profiles.txt
index 074b649f2..31f72bd16 100644
--- a/test/profiles.txt
+++ b/test/profiles.txt
@@ -69,17 +69,17 @@ test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3.
# TEST: test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 174
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 170
-test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 173
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 180
+test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 180
# TEST: test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached
diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py
index 9812f84c1..7d90bc67b 100644
--- a/test/sql/test_quote.py
+++ b/test/sql/test_quote.py
@@ -252,6 +252,50 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
eq_(repr(name), repr("姓名"))
+ def test_literal_column_label_embedded_select_samename_explicit_quote(
+ self,
+ ):
+ col = sql.literal_column("NEEDS QUOTES").label(
+ quoted_name("NEEDS QUOTES", True)
+ )
+
+ self.assert_compile(
+ select(col).subquery().select(),
+ 'SELECT anon_1."NEEDS QUOTES" FROM '
+ '(SELECT NEEDS QUOTES AS "NEEDS QUOTES") AS anon_1',
+ )
+
+ def test_literal_column_label_embedded_select_diffname_explicit_quote(
+ self,
+ ):
+ col = sql.literal_column("NEEDS QUOTES").label(
+ quoted_name("NEEDS QUOTES_", True)
+ )
+
+ self.assert_compile(
+ select(col).subquery().select(),
+ 'SELECT anon_1."NEEDS QUOTES_" FROM '
+ '(SELECT NEEDS QUOTES AS "NEEDS QUOTES_") AS anon_1',
+ )
+
+ def test_literal_column_label_embedded_select_diffname(self):
+ col = sql.literal_column("NEEDS QUOTES").label("NEEDS QUOTES_")
+
+ self.assert_compile(
+ select(col).subquery().select(),
+ 'SELECT anon_1."NEEDS QUOTES_" FROM (SELECT NEEDS QUOTES AS '
+ '"NEEDS QUOTES_") AS anon_1',
+ )
+
+ def test_literal_column_label_embedded_select_samename(self):
+ col = sql.literal_column("NEEDS QUOTES").label("NEEDS QUOTES")
+
+ self.assert_compile(
+ select(col).subquery().select(),
+ 'SELECT anon_1."NEEDS QUOTES" FROM (SELECT NEEDS QUOTES AS '
+ '"NEEDS QUOTES") AS anon_1',
+ )
+
def test_lower_case_names(self):
# Create table with quote defaults
metadata = MetaData()
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index 138e7a4c6..ffbab3223 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -1,6 +1,7 @@
import itertools
from sqlalchemy import Boolean
+from sqlalchemy import column
from sqlalchemy import delete
from sqlalchemy import exc as sa_exc
from sqlalchemy import func
@@ -10,9 +11,11 @@ from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import Sequence
from sqlalchemy import String
+from sqlalchemy import table
from sqlalchemy import testing
from sqlalchemy import type_coerce
from sqlalchemy import update
+from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
@@ -88,6 +91,113 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
t.c.x,
)
+ def test_named_expressions_selected_columns(self, table_fixture):
+ table = table_fixture
+ stmt = (
+ table.insert()
+ .values(goofy="someOTHERgoofy")
+ .returning(func.lower(table.c.x).label("goof"))
+ )
+ self.assert_compile(
+ select(stmt.exported_columns.goof),
+ "SELECT lower(foo.x) AS goof FROM foo",
+ )
+
+ def test_anon_expressions_selected_columns(self, table_fixture):
+ table = table_fixture
+ stmt = (
+ table.insert()
+ .values(goofy="someOTHERgoofy")
+ .returning(func.lower(table.c.x))
+ )
+ self.assert_compile(
+ select(stmt.exported_columns[0]),
+ "SELECT lower(foo.x) AS lower_1 FROM foo",
+ )
+
+ def test_returning_fromclause(self):
+ t = table("t", column("x"), column("y"), column("z"))
+ stmt = t.update().returning(t)
+
+ self.assert_compile(
+ stmt,
+ "UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s RETURNING t.x, t.y, t.z",
+ )
+
+ eq_(
+ stmt.returning_column_descriptions,
+ [
+ {
+ "name": "x",
+ "type": testing.eq_type_affinity(NullType),
+ "expr": t.c.x,
+ },
+ {
+ "name": "y",
+ "type": testing.eq_type_affinity(NullType),
+ "expr": t.c.y,
+ },
+ {
+ "name": "z",
+ "type": testing.eq_type_affinity(NullType),
+ "expr": t.c.z,
+ },
+ ],
+ )
+
+ cte = stmt.cte("c")
+
+ stmt = select(cte.c.z)
+ self.assert_compile(
+ stmt,
+ "WITH c AS (UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s "
+ "RETURNING t.x, t.y, t.z) SELECT c.z FROM c",
+ )
+
+ def test_returning_inspectable(self):
+ t = table("t", column("x"), column("y"), column("z"))
+
+ class HasClauseElement:
+ def __clause_element__(self):
+ return t
+
+ stmt = update(HasClauseElement()).returning(HasClauseElement())
+
+ eq_(
+ stmt.returning_column_descriptions,
+ [
+ {
+ "name": "x",
+ "type": testing.eq_type_affinity(NullType),
+ "expr": t.c.x,
+ },
+ {
+ "name": "y",
+ "type": testing.eq_type_affinity(NullType),
+ "expr": t.c.y,
+ },
+ {
+ "name": "z",
+ "type": testing.eq_type_affinity(NullType),
+ "expr": t.c.z,
+ },
+ ],
+ )
+
+ self.assert_compile(
+ stmt,
+ "UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s "
+ "RETURNING t.x, t.y, t.z",
+ )
+ cte = stmt.cte("c")
+
+ stmt = select(cte.c.z)
+ self.assert_compile(
+ stmt,
+ "WITH c AS (UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s "
+ "RETURNING t.x, t.y, t.z) SELECT c.z FROM c",
+ )
+
class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
__requires__ = ("returning",)
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 4944f2d57..ca5f43bb6 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -208,6 +208,24 @@ class SelectableTest(
{"name": "table1", "table": table1},
[],
),
+ (
+ table1.alias("some_alias"),
+ None,
+ {
+ "name": "some_alias",
+ "table": testing.eq_clause_element(table1.alias("some_alias")),
+ },
+ [],
+ ),
+ (
+ table1.join(table2),
+ None,
+ {
+ "name": None,
+ "table": testing.eq_clause_element(table1.join(table2)),
+ },
+ [],
+ ),
argnames="entity, cols, expected_entity, expected_returning",
)
def test_dml_descriptions(