summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-20 16:39:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-24 16:57:30 -0400
commit6f02d5edd88fe2475629438b0730181a2b00c5fe (patch)
treebbf9e9f3e8a2363659be35d59a7749c7fe35ef7c /lib/sqlalchemy/sql
parentc565c470517e1cc70a7f33d1ad3d3256935f1121 (diff)
downloadsqlalchemy-6f02d5edd88fe2475629438b0730181a2b00c5fe.tar.gz
pep484 - SQL internals
non-strict checking for mostly internal or semi-internal code Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/__init__.py3
-rw-r--r--lib/sqlalchemy/sql/_selectable_constructors.py16
-rw-r--r--lib/sqlalchemy/sql/_typing.py6
-rw-r--r--lib/sqlalchemy/sql/base.py269
-rw-r--r--lib/sqlalchemy/sql/compiler.py59
-rw-r--r--lib/sqlalchemy/sql/crud.py305
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py78
-rw-r--r--lib/sqlalchemy/sql/dml.py24
-rw-r--r--lib/sqlalchemy/sql/elements.py21
-rw-r--r--lib/sqlalchemy/sql/events.py46
-rw-r--r--lib/sqlalchemy/sql/expression.py1
-rw-r--r--lib/sqlalchemy/sql/roles.py6
-rw-r--r--lib/sqlalchemy/sql/schema.py436
-rw-r--r--lib/sqlalchemy/sql/selectable.py10
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py10
-rw-r--r--lib/sqlalchemy/sql/type_api.py6
16 files changed, 906 insertions, 390 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 169ddf3db..2e766f976 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import Any
from .base import Executable as Executable
from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
@@ -97,7 +98,7 @@ from .expression import within_group as within_group
from .visitors import ClauseVisitor as ClauseVisitor
-def __go(lcls):
+def __go(lcls: Any) -> None:
from .. import util as _sa_util
from . import base
diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py
index 9043aa6d0..e9acc7e6d 100644
--- a/lib/sqlalchemy/sql/_selectable_constructors.py
+++ b/lib/sqlalchemy/sql/_selectable_constructors.py
@@ -8,7 +8,7 @@
from __future__ import annotations
from typing import Any
-from typing import Union
+from typing import Optional
from . import coercions
from . import roles
@@ -23,8 +23,6 @@ from .selectable import Select
from .selectable import TableClause
from .selectable import TableSample
from .selectable import Values
-from ..util.typing import _LiteralStar
-from ..util.typing import Literal
def alias(selectable, name=None, flat=False):
@@ -283,9 +281,7 @@ def outerjoin(left, right, onclause=None, full=False):
return Join(left, right, onclause, isouter=True, full=full)
-def select(
- *entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement]
-) -> "Select":
+def select(*entities: _ColumnsClauseElement) -> Select:
r"""Construct a new :class:`_expression.Select`.
@@ -326,7 +322,7 @@ def select(
return Select(*entities)
-def table(name: str, *columns: ColumnClause, **kw: Any) -> "TableClause":
+def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause:
"""Produce a new :class:`_expression.TableClause`.
The object returned is an instance of
@@ -435,7 +431,11 @@ def union_all(*selects):
return CompoundSelect._create_union_all(*selects)
-def values(*columns, name=None, literal_binds=False) -> "Values":
+def values(
+ *columns: ColumnClause[Any],
+ name: Optional[str] = None,
+ literal_binds: bool = False,
+) -> Values:
r"""Construct a :class:`_expression.Values` construct.
The column expressions and the actual data for
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 2be98b88f..b50a7bf6a 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -9,6 +9,7 @@ from typing import Union
from . import roles
from .. import util
from ..inspection import Inspectable
+from ..util.typing import Literal
if TYPE_CHECKING:
from .elements import quoted_name
@@ -24,12 +25,13 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
_ColumnsClauseElement = Union[
+ Literal["*", 1],
roles.ColumnsClauseRole,
- Type,
+ Type[Any],
Inspectable[roles.HasColumnElementClauseElement],
]
_FromClauseElement = Union[
- roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement]
+ roles.FromClauseRole, Type[Any], Inspectable[roles.HasFromClauseElement]
]
_ColumnExpression = Union[
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 6a6b389de..8f5135915 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -12,22 +12,32 @@
from __future__ import annotations
-import collections.abc as collections_abc
from enum import Enum
from functools import reduce
import itertools
from itertools import zip_longest
import operator
import re
-import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
+from typing import Generic
from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NoReturn
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import roles
from . import visitors
@@ -36,17 +46,26 @@ from .cache_key import MemoizedHasCacheKey # noqa
from .traversals import HasCopyInternals # noqa
from .visitors import ClauseVisitor
from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
+from .. import event
from .. import exc
from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
+from ..util.typing import Protocol
from ..util.typing import Self
+from ..util.typing import TypeGuard
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+ from . import coercions
+ from . import elements
+ from . import type_api
from .elements import BindParameter
+ from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import SQLCoreOperations
from ..engine import Connection
from ..engine import Result
from ..engine.base import _CompiledCacheType
@@ -58,10 +77,12 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import CacheStats
from ..engine.interfaces import Compiled
from ..engine.interfaces import Dialect
+ from ..event import dispatcher
-coercions = None
-elements = None
-type_api = None
+if not TYPE_CHECKING:
+ coercions = None # noqa
+ elements = None # noqa
+ type_api = None # noqa
class _NoArg(Enum):
@@ -70,13 +91,24 @@ class _NoArg(Enum):
NO_ARG = _NoArg.NO_ARG
-# if I use sqlalchemy.util.typing, which has the exact same
-# symbols, mypy reports: "error: _Fn? not callable"
-_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Fn = TypeVar("_Fn", bound=Callable[..., Any])
_AmbiguousTableNameMap = MutableMapping[str, str]
+class _EntityNamespace(Protocol):
+ def __getattr__(self, key: str) -> SQLCoreOperations[Any]:
+ ...
+
+
+class _HasEntityNamespace(Protocol):
+ entity_namespace: _EntityNamespace
+
+
+def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
+ return hasattr(element, "entity_namespace")
+
+
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -107,10 +139,14 @@ class SingletonConstant(Immutable):
def __new__(cls, *arg, **kw):
return cls._singleton
+ @util.non_memoized_property
+ def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
+ raise NotImplementedError()
+
@classmethod
def _create_singleton(cls):
obj = object.__new__(cls)
- obj.__init__()
+ obj.__init__() # type: ignore
# for a long time this was an empty frozenset, meaning
# a SingletonConstant would never be a "corresponding column" in
@@ -139,12 +175,11 @@ def _select_iterables(elements):
)
-_Self = typing.TypeVar("_Self", bound="_GenerativeType")
-_Args = compat_typing.ParamSpec("_Args")
+_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType")
class _GenerativeType(compat_typing.Protocol):
- def _generate(self: "_Self") -> "_Self":
+ def _generate(self: _SelfGenerativeType) -> _SelfGenerativeType:
...
@@ -158,8 +193,8 @@ def _generative(fn: _Fn) -> _Fn:
@util.decorator
def _generative(
- fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
- ) -> _Self:
+ fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
+ ) -> _SelfGenerativeType:
"""Mark a method as generative."""
self = self._generate()
@@ -167,9 +202,9 @@ def _generative(fn: _Fn) -> _Fn:
assert x is self, "generative methods must return self"
return self
- decorated = _generative(fn)
- decorated.non_generative = fn
- return decorated
+ decorated = _generative(fn) # type: ignore
+ decorated.non_generative = fn # type: ignore
+ return decorated # type: ignore
def _exclusive_against(*names, **kw):
@@ -233,7 +268,7 @@ def _cloned_difference(a, b):
)
-class _DialectArgView(collections_abc.MutableMapping):
+class _DialectArgView(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments in the form
<dialectname>_<argument_name>.
@@ -290,7 +325,7 @@ class _DialectArgView(collections_abc.MutableMapping):
)
-class _DialectArgDict(collections_abc.MutableMapping):
+class _DialectArgDict(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments for a specific
dialect.
@@ -343,6 +378,8 @@ class DialectKWArgs:
"""
+ __slots__ = ()
+
_dialect_kwargs_traverse_internals = [
("dialect_options", InternalTraversal.dp_dialect_options)
]
@@ -534,7 +571,7 @@ class CompileState:
__slots__ = ("statement", "_ambiguous_table_name_map")
- plugins = {}
+ plugins: Dict[Tuple[str, str], Type[CompileState]] = {}
_ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
@@ -639,9 +676,9 @@ class InPlaceGenerative(HasMemoized):
class HasCompileState(Generative):
"""A class that has a :class:`.CompileState` associated with it."""
- _compile_state_plugin = None
+ _compile_state_plugin: Optional[Type[CompileState]] = None
- _attributes = util.immutabledict()
+ _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT
_compile_state_factory = CompileState.create_for_statement
@@ -655,6 +692,8 @@ class _MetaOptions(type):
"""
+ _cache_attrs: Tuple[str, ...]
+
def __add__(self, other):
o1 = self()
@@ -674,6 +713,8 @@ class Options(metaclass=_MetaOptions):
__slots__ = ()
+ _cache_attrs: Tuple[str, ...]
+
def __init_subclass__(cls) -> None:
dict_ = cls.__dict__
cls._cache_attrs = tuple(
@@ -732,13 +773,13 @@ class Options(metaclass=_MetaOptions):
return self + {name: getattr(self, name) + value}
@hybridmethod
- def _state_dict(self):
+ def _state_dict_inst(self) -> Mapping[str, Any]:
return self.__dict__
- _state_dict_const = util.immutabledict()
+ _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT
- @_state_dict.classlevel
- def _state_dict(cls):
+ @_state_dict_inst.classlevel
+ def _state_dict(cls) -> Mapping[str, Any]:
return cls._state_dict_const
@classmethod
@@ -825,10 +866,10 @@ class CacheableOptions(Options, HasCacheKey):
__slots__ = ()
@hybridmethod
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key_inst(self, anon_map, bindparams):
return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
- @_gen_cache_key.classlevel
+ @_gen_cache_key_inst.classlevel
def _gen_cache_key(cls, anon_map, bindparams):
return (cls, ())
@@ -849,11 +890,11 @@ class ExecutableOption(HasCopyInternals):
def _clone(self, **kw):
"""Create a shallow copy of this ExecutableOption."""
c = self.__class__.__new__(self.__class__)
- c.__dict__ = dict(self.__dict__)
+ c.__dict__ = dict(self.__dict__) # type: ignore
return c
-SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable")
+SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
class Executable(roles.StatementRole, Generative):
@@ -866,9 +907,12 @@ class Executable(roles.StatementRole, Generative):
"""
supports_execution: bool = True
- _execution_options: _ImmutableExecuteOptions = util.immutabledict()
- _with_options = ()
- _with_context_options = ()
+ _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
+ _with_options: Tuple[ExecutableOption, ...] = ()
+ _with_context_options: Tuple[
+ Tuple[Callable[[CompileState], None], Any], ...
+ ] = ()
+ _compile_options: Optional[CacheableOptions]
_executable_traverse_internals = [
("_with_options", InternalTraversal.dp_executable_options),
@@ -886,7 +930,9 @@ class Executable(roles.StatementRole, Generative):
is_delete = False
is_dml = False
- if typing.TYPE_CHECKING:
+ if TYPE_CHECKING:
+
+ __visit_name__: str
def _compile_w_cache(
self,
@@ -916,11 +962,13 @@ class Executable(roles.StatementRole, Generative):
raise NotImplementedError()
@property
- def _effective_plugin_target(self):
+ def _effective_plugin_target(self) -> str:
return self.__visit_name__
@_generative
- def options(self: SelfExecutable, *options) -> SelfExecutable:
+ def options(
+ self: SelfExecutable, *options: ExecutableOption
+ ) -> SelfExecutable:
"""Apply options to this statement.
In the general sense, options are any kind of Python object
@@ -957,7 +1005,7 @@ class Executable(roles.StatementRole, Generative):
@_generative
def _set_compile_options(
- self: SelfExecutable, compile_options
+ self: SelfExecutable, compile_options: CacheableOptions
) -> SelfExecutable:
"""Assign the compile options to a new value.
@@ -970,16 +1018,19 @@ class Executable(roles.StatementRole, Generative):
@_generative
def _update_compile_options(
- self: SelfExecutable, options
+ self: SelfExecutable, options: CacheableOptions
) -> SelfExecutable:
"""update the _compile_options with new keys."""
+ assert self._compile_options is not None
self._compile_options += options
return self
@_generative
def _add_context_option(
- self: SelfExecutable, callable_, cache_args
+ self: SelfExecutable,
+ callable_: Callable[[CompileState], None],
+ cache_args: Any,
) -> SelfExecutable:
"""Add a context option to this statement.
@@ -995,7 +1046,7 @@ class Executable(roles.StatementRole, Generative):
return self
@_generative
- def execution_options(self: SelfExecutable, **kw) -> SelfExecutable:
+ def execution_options(self: SelfExecutable, **kw: Any) -> SelfExecutable:
"""Set non-SQL options for the statement which take effect during
execution.
@@ -1112,7 +1163,7 @@ class Executable(roles.StatementRole, Generative):
self._execution_options = self._execution_options.union(kw)
return self
- def get_execution_options(self):
+ def get_execution_options(self) -> _ExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
@@ -1124,7 +1175,7 @@ class Executable(roles.StatementRole, Generative):
return self._execution_options
-class SchemaEventTarget:
+class SchemaEventTarget(event.EventTarget):
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
@@ -1132,6 +1183,8 @@ class SchemaEventTarget:
"""
+ dispatch: dispatcher[SchemaEventTarget]
+
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
"""Associate with this SchemaEvent's parent object."""
@@ -1149,7 +1202,10 @@ class SchemaVisitor(ClauseVisitor):
__traverse_options__ = {"schema_visitor": True}
-class ColumnCollection:
+_COL = TypeVar("_COL", bound="ColumnClause[Any]")
+
+
+class ColumnCollection(Generic[_COL]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
:class:`_sql.FromClause` objects.
@@ -1260,32 +1316,36 @@ class ColumnCollection:
__slots__ = "_collection", "_index", "_colset"
- def __init__(self, columns=None):
+ _collection: List[Tuple[str, _COL]]
+ _index: Dict[Union[str, int], _COL]
+ _colset: Set[_COL]
+
+ def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
object.__setattr__(self, "_collection", [])
if columns:
self._initial_populate(columns)
- def _initial_populate(self, iter_):
+ def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None:
self._populate_separate_keys(iter_)
@property
- def _all_columns(self):
+ def _all_columns(self) -> List[_COL]:
return [col for (k, col) in self._collection]
- def keys(self):
+ def keys(self) -> List[str]:
"""Return a sequence of string key names for all columns in this
collection."""
return [k for (k, col) in self._collection]
- def values(self):
+ def values(self) -> List[_COL]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
return [col for (k, col) in self._collection]
- def items(self):
+ def items(self) -> List[Tuple[str, _COL]]:
"""Return a sequence of (key, column) tuples for all columns in this
collection each consisting of a string key name and a
:class:`_sql.ColumnClause` or
@@ -1294,17 +1354,17 @@ class ColumnCollection:
return list(self._collection)
- def __bool__(self):
+ def __bool__(self) -> bool:
return bool(self._collection)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._collection)
- def __iter__(self):
+ def __iter__(self) -> Iterator[_COL]:
# turn to a list first to maintain over a course of changes
return iter([col for k, col in self._collection])
- def __getitem__(self, key):
+ def __getitem__(self, key: Union[str, int]) -> _COL:
try:
return self._index[key]
except KeyError as err:
@@ -1313,13 +1373,13 @@ class ColumnCollection:
else:
raise
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> _COL:
try:
return self._index[key]
except KeyError as err:
raise AttributeError(key) from err
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
if key not in self._index:
if not isinstance(key, str):
raise exc.ArgumentError(
@@ -1329,7 +1389,7 @@ class ColumnCollection:
else:
return True
- def compare(self, other):
+ def compare(self, other: ColumnCollection[Any]) -> bool:
"""Compare this :class:`_expression.ColumnCollection` to another
based on the names of the keys"""
@@ -1339,10 +1399,10 @@ class ColumnCollection:
else:
return True
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return self.compare(other)
- def get(self, key, default=None):
+ def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
@@ -1352,39 +1412,40 @@ class ColumnCollection:
else:
return default
- def __str__(self):
+ def __str__(self) -> str:
return "%s(%s)" % (
self.__class__.__name__,
", ".join(str(c) for c in self),
)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: str, value: Any) -> NoReturn:
raise NotImplementedError()
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> NoReturn:
raise NotImplementedError()
- def __setattr__(self, key, obj):
+ def __setattr__(self, key: str, obj: Any) -> NoReturn:
raise NotImplementedError()
- def clear(self):
+ def clear(self) -> NoReturn:
"""Dictionary clear() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
- def remove(self, column):
- """Dictionary remove() is not implemented for
- :class:`_sql.ColumnCollection`."""
+ def remove(self, column: Any) -> None:
raise NotImplementedError()
- def update(self, iter_):
+ def update(self, iter_: Any) -> NoReturn:
"""Dictionary update() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
- __hash__ = None
+ # https://github.com/python/mypy/issues/4266
+ __hash__ = None # type: ignore
- def _populate_separate_keys(self, iter_):
+ def _populate_separate_keys(
+ self, iter_: Iterable[Tuple[str, _COL]]
+ ) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
self._collection[:] = cols
@@ -1394,7 +1455,7 @@ class ColumnCollection:
)
self._index.update({k: col for k, col in reversed(self._collection)})
- def add(self, column, key=None):
+ def add(self, column: _COL, key: Optional[str] = None) -> None:
"""Add a column to this :class:`_sql.ColumnCollection`.
.. note::
@@ -1416,17 +1477,17 @@ class ColumnCollection:
if key not in self._index:
self._index[key] = column
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {"_collection": self._collection, "_index": self._index}
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_index", state["_index"])
object.__setattr__(self, "_collection", state["_collection"])
object.__setattr__(
self, "_colset", {col for k, col in self._collection}
)
- def contains_column(self, col):
+ def contains_column(self, col: _COL) -> bool:
"""Checks if a column object exists in this collection"""
if col not in self._colset:
if isinstance(col, str):
@@ -1438,13 +1499,15 @@ class ColumnCollection:
else:
return True
- def as_immutable(self):
+ def as_immutable(self) -> ImmutableColumnCollection[_COL]:
"""Return an "immutable" form of this
:class:`_sql.ColumnCollection`."""
return ImmutableColumnCollection(self)
- def corresponding_column(self, column, require_embedded=False):
+ def corresponding_column(
+ self, column: _COL, require_embedded: bool = False
+ ) -> Optional[_COL]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from this
:class:`_expression.ColumnCollection`
@@ -1497,7 +1560,7 @@ class ColumnCollection:
not require_embedded
or embedded(expanded_proxy_set, target_set)
):
- if col is None:
+ if col is None or intersect is None:
# no corresponding column yet, pick this one.
@@ -1542,7 +1605,7 @@ class ColumnCollection:
return col
-class DedupeColumnCollection(ColumnCollection):
+class DedupeColumnCollection(ColumnCollection[_COL]):
"""A :class:`_expression.ColumnCollection`
that maintains deduplicating behavior.
@@ -1555,7 +1618,7 @@ class DedupeColumnCollection(ColumnCollection):
"""
- def add(self, column, key=None):
+ def add(self, column: _COL, key: Optional[str] = None) -> None:
if key is not None and column.key != key:
raise exc.ArgumentError(
@@ -1589,7 +1652,9 @@ class DedupeColumnCollection(ColumnCollection):
self._index[l] = column
self._index[key] = column
- def _populate_separate_keys(self, iter_):
+ def _populate_separate_keys(
+ self, iter_: Iterable[Tuple[str, _COL]]
+ ) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
@@ -1614,10 +1679,10 @@ class DedupeColumnCollection(ColumnCollection):
for col in replace_col:
self.replace(col)
- def extend(self, iter_):
+ def extend(self, iter_: Iterable[_COL]) -> None:
self._populate_separate_keys((col.key, col) for col in iter_)
- def remove(self, column):
+ def remove(self, column: _COL) -> None:
if column not in self._colset:
raise ValueError(
"Can't remove column %r; column is not in this collection"
@@ -1634,7 +1699,7 @@ class DedupeColumnCollection(ColumnCollection):
# delete higher index
del self._index[len(self._collection)]
- def replace(self, column):
+ def replace(self, column: _COL) -> None:
"""add the given column to this collection, removing unaliased
versions of this column as well as existing columns with the
same key.
@@ -1687,7 +1752,9 @@ class DedupeColumnCollection(ColumnCollection):
self._index.update(self._collection)
-class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+class ImmutableColumnCollection(
+ util.ImmutableContainer, ColumnCollection[_COL]
+):
__slots__ = ("_parent",)
def __init__(self, collection):
@@ -1701,12 +1768,19 @@ class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
def __setstate__(self, state):
parent = state["_parent"]
- self.__init__(parent)
+ self.__init__(parent) # type: ignore
- add = extend = remove = util.ImmutableContainer._immutable
+ def add(self, column: Any, key: Any = ...) -> Any:
+ self._immutable()
+ def extend(self, elements: Any) -> None:
+ self._immutable()
-class ColumnSet(util.ordered_column_set):
+ def remove(self, item: Any) -> None:
+ self._immutable()
+
+
+class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
def contains_column(self, col):
return col in self
@@ -1714,9 +1788,6 @@ class ColumnSet(util.ordered_column_set):
for col in cols:
self.add(col)
- def __add__(self, other):
- return list(self) + list(other)
-
def __eq__(self, other):
l = []
for c in other:
@@ -1729,7 +1800,9 @@ class ColumnSet(util.ordered_column_set):
return hash(tuple(x for x in self))
-def _entity_namespace(entity):
+def _entity_namespace(
+ entity: Union[_HasEntityNamespace, ExternallyTraversible]
+) -> _EntityNamespace:
"""Return the nearest .entity_namespace for the given entity.
If not immediately available, does an iterate to find a sub-element
@@ -1737,16 +1810,20 @@ def _entity_namespace(entity):
"""
try:
- return entity.entity_namespace
+ return cast(_HasEntityNamespace, entity).entity_namespace
except AttributeError:
- for elem in visitors.iterate(entity):
- if hasattr(elem, "entity_namespace"):
+ for elem in visitors.iterate(cast(ExternallyTraversible, entity)):
+ if _is_has_entity_namespace(elem):
return elem.entity_namespace
else:
raise
-def _entity_namespace_key(entity, key, default=NO_ARG):
+def _entity_namespace_key(
+ entity: Union[_HasEntityNamespace, ExternallyTraversible],
+ key: str,
+ default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG,
+) -> SQLCoreOperations[Any]:
"""Return an entry from an entity_namespace.
@@ -1760,7 +1837,7 @@ def _entity_namespace_key(entity, key, default=NO_ARG):
if default is not NO_ARG:
return getattr(ns, key, default)
else:
- return getattr(ns, key)
+ return getattr(ns, key) # type: ignore
except AttributeError as err:
raise exc.InvalidRequestError(
'Entity namespace for "%s" has no property "%s"' % (entity, key)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index f8019b9c6..5ba52ae51 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -71,6 +71,7 @@ from .schema import Column
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import prefix_anon_map
+from .visitors import Visitable
from .. import exc
from .. import util
from ..util.typing import Literal
@@ -614,10 +615,10 @@ class Compiled:
raise NotImplementedError()
- def process(self, obj, **kwargs):
+ def process(self, obj: Visitable, **kwargs: Any) -> str:
return obj._compiler_dispatch(self, **kwargs)
- def __str__(self):
+ def __str__(self) -> str:
"""Return the string text of the generated SQL or DDL."""
return self.string or ""
@@ -723,7 +724,7 @@ class SQLCompiler(Compiled):
"""list of columns for which onupdate default values should be evaluated
before an UPDATE takes place"""
- returning: Optional[List[Column[Any]]]
+ returning: Optional[List[ColumnClause[Any]]]
"""list of columns that will be delivered to cursor.description or
dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
@@ -1485,15 +1486,12 @@ class SQLCompiler(Compiled):
self._result_columns
)
- _key_getters_for_crud_column: Tuple[
- Callable[[Union[str, Column[Any]]], str],
- Callable[[Column[Any]], str],
- Callable[[Column[Any]], str],
- ]
+ # assigned by crud.py for insert/update statements
+ _get_bind_name_for_col: _BindNameForColProtocol
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
- getter = self._key_getters_for_crud_column[2]
+ getter = self._get_bind_name_for_col
if self.escaped_bind_names:
def _get(obj):
@@ -4098,7 +4096,9 @@ class SQLCompiler(Compiled):
def for_update_clause(self, select, **kw):
return " FOR UPDATE"
- def returning_clause(self, stmt, returning_cols):
+ def returning_clause(
+ self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]]
+ ) -> str:
raise exc.CompileError(
"RETURNING is not supported by this "
"dialect's statement compiler."
@@ -4243,12 +4243,13 @@ class SQLCompiler(Compiled):
}
)
- crud_params = crud._get_crud_params(
+ crud_params_struct = crud._get_crud_params(
self, insert_stmt, compile_state, **kw
)
+ crud_params_single = crud_params_struct.single_params
if (
- not crud_params
+ not crud_params_single
and not self.dialect.supports_default_values
and not self.dialect.supports_default_metavalue
and not self.dialect.supports_empty_insert
@@ -4266,9 +4267,9 @@ class SQLCompiler(Compiled):
"version settings does not support "
"in-place multirow inserts." % self.dialect.name
)
- crud_params_single = crud_params[0]
+ crud_params_single = crud_params_struct.single_params
else:
- crud_params_single = crud_params
+ crud_params_single = crud_params_struct.single_params
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -4293,7 +4294,7 @@ class SQLCompiler(Compiled):
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
- [expr for c, expr, value in crud_params_single]
+ [expr for _, expr, _ in crud_params_single]
)
if self.returning or insert_stmt._returning:
@@ -4323,19 +4324,24 @@ class SQLCompiler(Compiled):
)
else:
text += " %s" % select_text
- elif not crud_params and supports_default_values:
+ elif not crud_params_single and supports_default_values:
text += " DEFAULT VALUES"
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)"
- % (", ".join(value for c, expr, value in crud_param_set))
- for crud_param_set in crud_params
+ % (", ".join(value for _, _, value in crud_param_set))
+ for crud_param_set in crud_params_struct.all_multi_params
)
)
else:
insert_single_values_expr = ", ".join(
- [value for c, expr, value in crud_params]
+ [
+ value
+ for _, _, value in cast(
+ "List[Tuple[Any, Any, str]]", crud_params_single
+ )
+ ]
)
text += " VALUES (%s)" % insert_single_values_expr
if toplevel and insert_stmt._post_values_clause is None:
@@ -4443,9 +4449,10 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)
- crud_params = crud._get_crud_params(
+ crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, **kw
)
+ crud_params = crud_params_struct.single_params
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
@@ -4460,7 +4467,12 @@ class SQLCompiler(Compiled):
text += table_text
text += " SET "
- text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
+ text += ", ".join(
+ expr + "=" + value
+ for _, expr, value in cast(
+ "List[Tuple[Any, str, str]]", crud_params
+ )
+ )
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
@@ -5446,6 +5458,11 @@ class _SchemaForObjectCallable(Protocol):
...
+class _BindNameForColProtocol(Protocol):
+ def __call__(self, col: ColumnClause[Any]) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 4292aa916..533a2f6cd 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -13,13 +13,44 @@ from __future__ import annotations
import functools
import operator
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import NamedTuple
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from . import coercions
from . import dml
from . import elements
from . import roles
+from .schema import default_is_clause_element
+from .schema import default_is_sequence
from .. import exc
from .. import util
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+ from .compiler import _BindNameForColProtocol
+ from .compiler import SQLCompiler
+ from .dml import DMLState
+ from .dml import Insert
+ from .dml import Update
+ from .dml import UpdateDMLState
+ from .dml import ValuesBase
+ from .elements import ClauseElement
+ from .elements import ColumnClause
+ from .elements import ColumnElement
+ from .elements import TextClause
+ from .schema import _SQLExprDefault
+ from .schema import Column
+ from .selectable import TableClause
REQUIRED = util.symbol(
"REQUIRED",
@@ -36,7 +67,27 @@ values present.
)
-def _get_crud_params(compiler, stmt, compile_state, **kw):
+class _CrudParams(NamedTuple):
+ single_params: List[
+ Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ ]
+ all_multi_params: List[
+ List[
+ Tuple[
+ ColumnClause[Any],
+ str,
+ str,
+ ]
+ ]
+ ]
+
+
+def _get_crud_params(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ compile_state: DMLState,
+ **kw: Any,
+) -> _CrudParams:
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -59,24 +110,32 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
_column_as_key,
_getattr_col_key,
_col_bind_name,
- ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+ ) = _key_getters_for_crud_column(compiler, stmt, compile_state)
- compiler._key_getters_for_crud_column = getters
+ compiler._get_bind_name_for_col = _col_bind_name
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
- return [
- (
- c,
- compiler.preparer.format_column(c),
- _create_bind_param(compiler, c, None, required=True),
- )
- for c in stmt.table.columns
- ]
+ return _CrudParams(
+ [
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_bind_param(compiler, c, None, required=True),
+ )
+ for c in stmt.table.columns
+ ],
+ [],
+ )
+
+ stmt_parameter_tuples: Optional[List[Any]]
+ spd: Optional[MutableMapping[str, Any]]
if compile_state._has_multi_parameters:
- spd = compile_state._multi_parameters[0]
+ mp = compile_state._multi_parameters
+ assert mp is not None
+ spd = mp[0]
stmt_parameter_tuples = list(spd.items())
elif compile_state._ordered_values:
spd = compile_state._dict_parameters
@@ -92,6 +151,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
if compiler.column_keys is None:
parameters = {}
elif stmt_parameter_tuples:
+ assert spd is not None
parameters = dict(
(_column_as_key(key), REQUIRED)
for key in compiler.column_keys
@@ -103,7 +163,9 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
# create a list of column assignment clauses as tuples
- values = []
+ values: List[
+ Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ ] = []
if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
@@ -116,11 +178,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
kw,
)
- check_columns = {}
+ check_columns: Dict[str, ColumnClause[Any]] = {}
# special logic that only occurs for multi-table UPDATE
# statements
- if compile_state.isupdate and compile_state.is_multitable:
+ if dml.isupdate(compile_state) and compile_state.is_multitable:
_get_update_multitable_params(
compiler,
stmt,
@@ -134,6 +196,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
if compile_state.isinsert and stmt._select_names:
+ # is an insert from select, is not a multiparams
+
+ assert not compile_state._has_multi_parameters
+
_scan_insert_from_select_cols(
compiler,
stmt,
@@ -173,14 +239,17 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
if compile_state._has_multi_parameters:
- values = _extend_values_for_multiparams(
+ # is a multiparams, is not an insert from a select
+ assert not stmt._select_names
+ multi_extended_values = _extend_values_for_multiparams(
compiler,
stmt,
compile_state,
- values,
- _column_as_key,
+ cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+ cast("Callable[..., str]", _column_as_key),
kw,
)
+ return _CrudParams(values, multi_extended_values)
elif (
not values
and compiler.for_executemany
@@ -198,12 +267,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
]
- return values
+ return _CrudParams(values, [])
+@overload
def _create_bind_param(
- compiler, col, value, process=True, required=False, name=None, **kw
-):
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ process: Literal[True] = ...,
+ required: bool = False,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_bind_param(
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ **kw: Any,
+) -> str:
+ ...
+
+
+def _create_bind_param(
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ process: bool = True,
+ required: bool = False,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[str, elements.BindParameter[Any]]:
if name is None:
name = col.key
bindparam = elements.BindParameter(
@@ -211,8 +309,9 @@ def _create_bind_param(
)
bindparam._is_crud = True
if process:
- bindparam = bindparam._compiler_dispatch(compiler, **kw)
- return bindparam
+ return bindparam._compiler_dispatch(compiler, **kw)
+ else:
+ return bindparam
def _handle_values_anonymous_param(compiler, col, value, name, **kw):
@@ -253,8 +352,14 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw):
return value._compiler_dispatch(compiler, **kw)
-def _key_getters_for_crud_column(compiler, stmt, compile_state):
- if compile_state.isupdate and compile_state._extra_froms:
+def _key_getters_for_crud_column(
+ compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState
+) -> Tuple[
+ Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]],
+ Callable[[Column[Any]], Union[str, Tuple[str, str]]],
+ _BindNameForColProtocol,
+]:
+ if dml.isupdate(compile_state) and compile_state._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
@@ -267,30 +372,36 @@ def _key_getters_for_crud_column(compiler, stmt, compile_state):
coercions.expect_as_key, roles.DMLColumnRole
)
- def _column_as_key(key):
+ def _column_as_key(
+ key: Union[ColumnClause[Any], str]
+ ) -> Union[str, Tuple[str, str]]:
str_key = c_key_role(key)
- if hasattr(key, "table") and key.table in _et:
- return (key.table.name, str_key)
+ if hasattr(key, "table") and key.table in _et: # type: ignore
+ return (key.table.name, str_key) # type: ignore
else:
- return str_key
+ return str_key # type: ignore
- def _getattr_col_key(col):
+ def _getattr_col_key(
+ col: ColumnClause[Any],
+ ) -> Union[str, Tuple[str, str]]:
if col.table in _et:
- return (col.table.name, col.key)
+ return (col.table.name, col.key) # type: ignore
else:
return col.key
- def _col_bind_name(col):
+ def _col_bind_name(col: ColumnClause[Any]) -> str:
if col.table in _et:
+ if TYPE_CHECKING:
+ assert isinstance(col.table, TableClause)
return "%s_%s" % (col.table.name, col.key)
else:
return col.key
else:
- _column_as_key = functools.partial(
+ _column_as_key = functools.partial( # type: ignore
coercions.expect_as_key, roles.DMLColumnRole
)
- _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa E501
return _column_as_key, _getattr_col_key, _col_bind_name
@@ -321,7 +432,7 @@ def _scan_insert_from_select_cols(
compiler.stack[-1]["insert_from_select"] = stmt.select
- add_select_cols = []
+ add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
@@ -707,16 +818,22 @@ def _append_param_insert_hasdefault(
)
-def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ c: ColumnClause[Any],
+ values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
+ kw: Dict[str, Any],
+) -> None:
- if c.default.is_sequence:
+ if default_is_sequence(c.default):
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
values.append(
(c, compiler.preparer.format_column(c), c.default.next_value())
)
- elif c.default.is_clause_element:
+ elif default_is_clause_element(c.default):
values.append(
(c, compiler.preparer.format_column(c), c.default.arg.self_group())
)
@@ -777,28 +894,76 @@ def _append_param_update(
compiler.returning.append(c)
+@overload
def _create_insert_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[True] = ...,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_insert_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[False],
+ **kw: Any,
+) -> elements.BindParameter[Any]:
+ ...
+
+
+def _create_insert_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: bool = True,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
- compiler.insert_prefetch.append(c)
+ compiler.insert_prefetch.append(c) # type: ignore
return param
+@overload
def _create_update_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[True] = ...,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_update_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[False],
+ **kw: Any,
+) -> elements.BindParameter[Any]:
+ ...
+
+
+def _create_update_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: bool = True,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
- compiler.update_prefetch.append(c)
+ compiler.update_prefetch.append(c) # type: ignore
return param
-class _multiparam_column(elements.ColumnElement):
+class _multiparam_column(elements.ColumnElement[Any]):
_is_multiparam_column = True
def __init__(self, original, index):
@@ -822,14 +987,20 @@ class _multiparam_column(elements.ColumnElement):
)
-def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
+def _process_multiparam_default_bind(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ c: ColumnClause[Any],
+ index: int,
+ kw: Dict[str, Any],
+) -> str:
if not c.default:
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
"a Python-side value or SQL expression is required" % c
)
- elif c.default.is_clause_element:
+ elif default_is_clause_element(c.default):
return compiler.process(c.default.arg.self_group(), **kw)
elif c.default.is_sequence:
# these conditions would have been established
@@ -844,9 +1015,13 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
- return _create_insert_prefetch_bind_param(compiler, col, **kw)
+ return _create_insert_prefetch_bind_param(
+ compiler, col, process=True, **kw
+ )
else:
- return _create_update_prefetch_bind_param(compiler, col, **kw)
+ return _create_update_prefetch_bind_param(
+ compiler, col, process=True, **kw
+ )
def _get_update_multitable_params(
@@ -926,18 +1101,26 @@ def _get_update_multitable_params(
def _extend_values_for_multiparams(
- compiler,
- stmt,
- compile_state,
- values,
- _column_as_key,
- kw,
-):
- values_0 = values
- values = [values]
-
- for i, row in enumerate(compile_state._multi_parameters[1:]):
- extension = []
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ compile_state: DMLState,
+ initial_values: List[Tuple[ColumnClause[Any], str, str]],
+ _column_as_key: Callable[..., str],
+ kw: Dict[str, Any],
+) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+ values_0 = initial_values
+ values = [initial_values]
+
+ mp = compile_state._multi_parameters
+ assert mp is not None
+ for i, row in enumerate(mp[1:]):
+ extension: List[
+ Tuple[
+ ColumnClause[Any],
+ str,
+ str,
+ ]
+ ] = []
row = {_column_as_key(key): v for key, v in row.items()}
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 91bb0a5c5..944a0a5ce 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -26,6 +26,7 @@ from . import roles
from . import type_api
from .elements import and_
from .elements import BinaryExpression
+from .elements import ClauseElement
from .elements import ClauseList
from .elements import CollationClause
from .elements import CollectionAggregate
@@ -43,7 +44,7 @@ _T = typing.TypeVar("_T", bound=Any)
if typing.TYPE_CHECKING:
from .elements import ColumnElement
from .operators import custom_op
- from .sqltypes import TypeEngine
+ from .type_api import TypeEngine
def _boolean_compare(
@@ -53,10 +54,10 @@ def _boolean_compare(
*,
negate_op: Optional[OperatorType] = None,
reverse: bool = False,
- _python_is_types=(util.NoneType, bool),
- _any_all_expr=False,
+ _python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
+ _any_all_expr: bool = False,
result_type: Optional[
- Union[Type["TypeEngine[bool]"], "TypeEngine[bool]"]
+ Union[Type[TypeEngine[bool]], TypeEngine[bool]]
] = None,
**kwargs: Any,
) -> BinaryExpression[bool]:
@@ -165,7 +166,7 @@ def _custom_op_operate(
def _binary_operate(
expr: ColumnElement[Any],
op: OperatorType,
- obj: roles.BinaryElementRole,
+ obj: roles.BinaryElementRole[Any],
*,
reverse: bool = False,
result_type: Optional[
@@ -192,7 +193,7 @@ def _binary_operate(
def _conjunction_operate(
- expr: ColumnElement[Any], op: OperatorType, other, **kw
+ expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
if op is operators.and_:
return and_(expr, other)
@@ -203,7 +204,10 @@ def _conjunction_operate(
def _scalar(
- expr: ColumnElement[Any], op: OperatorType, fn, **kw
+ expr: ColumnElement[Any],
+ op: OperatorType,
+ fn: Callable[[ColumnElement[Any]], ColumnElement[Any]],
+ **kw: Any,
) -> ColumnElement[Any]:
return fn(expr)
@@ -211,9 +215,9 @@ def _scalar(
def _in_impl(
expr: ColumnElement[Any],
op: OperatorType,
- seq_or_selectable,
+ seq_or_selectable: ClauseElement,
negate_op: OperatorType,
- **kw,
+ **kw: Any,
) -> ColumnElement[Any]:
seq_or_selectable = coercions.expect(
roles.InElementRole, seq_or_selectable, expr=expr, operator=op
@@ -227,7 +231,7 @@ def _in_impl(
def _getitem_impl(
- expr: ColumnElement[Any], op: OperatorType, other, **kw
+ expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
if isinstance(expr.type, type_api.INDEXABLE):
other = coercions.expect(
@@ -239,7 +243,7 @@ def _getitem_impl(
def _unsupported_impl(
- expr: ColumnElement[Any], op: OperatorType, *arg, **kw
+ expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any
) -> NoReturn:
raise NotImplementedError(
"Operator '%s' is not supported on " "this expression" % op.__name__
@@ -247,7 +251,7 @@ def _unsupported_impl(
def _inv_impl(
- expr: ColumnElement[Any], op: OperatorType, **kw
+ expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__inv__`."""
@@ -260,14 +264,14 @@ def _inv_impl(
def _neg_impl(
- expr: ColumnElement[Any], op: OperatorType, **kw
+ expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__neg__`."""
return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
def _match_impl(
- expr: ColumnElement[Any], op: OperatorType, other, **kw
+ expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.match`."""
@@ -289,7 +293,7 @@ def _match_impl(
def _distinct_impl(
- expr: ColumnElement[Any], op: OperatorType, **kw
+ expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.distinct`."""
return UnaryExpression(
@@ -298,7 +302,11 @@ def _distinct_impl(
def _between_impl(
- expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw
+ expr: ColumnElement[Any],
+ op: OperatorType,
+ cleft: Any,
+ cright: Any,
+ **kw: Any,
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
@@ -329,26 +337,32 @@ def _between_impl(
def _collate_impl(
- expr: ColumnElement[Any], op: OperatorType, collation, **kw
-) -> ColumnElement[Any]:
+ expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any
+) -> ColumnElement[str]:
return CollationClause._create_collation_expression(expr, collation)
def _regexp_match_impl(
- expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw
+ expr: ColumnElement[str],
+ op: OperatorType,
+ pattern: Any,
+ flags: Optional[str],
+ **kw: Any,
) -> ColumnElement[Any]:
if flags is not None:
- flags = coercions.expect(
+ flags_expr = coercions.expect(
roles.BinaryElementRole,
flags,
expr=expr,
operator=operators.regexp_replace_op,
)
+ else:
+ flags_expr = None
return _boolean_compare(
expr,
op,
pattern,
- flags=flags,
+ flags=flags_expr,
negate_op=operators.not_regexp_match_op
if op is operators.regexp_match_op
else operators.regexp_match_op,
@@ -359,10 +373,10 @@ def _regexp_match_impl(
def _regexp_replace_impl(
expr: ColumnElement[Any],
op: OperatorType,
- pattern,
- replacement,
- flags,
- **kw,
+ pattern: Any,
+ replacement: Any,
+ flags: Optional[str],
+ **kw: Any,
) -> ColumnElement[Any]:
replacement = coercions.expect(
roles.BinaryElementRole,
@@ -371,21 +385,29 @@ def _regexp_replace_impl(
operator=operators.regexp_replace_op,
)
if flags is not None:
- flags = coercions.expect(
+ flags_expr = coercions.expect(
roles.BinaryElementRole,
flags,
expr=expr,
operator=operators.regexp_replace_op,
)
+ else:
+ flags_expr = None
return _binary_operate(
- expr, op, pattern, replacement=replacement, flags=flags, **kw
+ expr, op, pattern, replacement=replacement, flags=flags_expr, **kw
)
# a mapping of operators with the method they use, along with
# additional keyword arguments to be passed
operator_lookup: Dict[
- str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict]
+ str,
+ Tuple[
+ Callable[..., ColumnElement[Any]],
+ util.immutabledict[
+ str, Union[OperatorType, Callable[..., ColumnElement[Any]]]
+ ],
+ ],
] = {
"and_": (_conjunction_operate, util.EMPTY_DICT),
"or_": (_conjunction_operate, util.EMPTY_DICT),
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 1271c5977..10316dd2b 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -12,11 +12,13 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
from __future__ import annotations
import collections.abc as collections_abc
+import operator
import typing
from typing import Any
from typing import List
from typing import MutableMapping
from typing import Optional
+from typing import TYPE_CHECKING
from . import coercions
from . import roles
@@ -36,10 +38,29 @@ from .elements import Null
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import ReturnsRows
+from .selectable import TableClause
from .sqltypes import NullType
from .visitors import InternalTraversal
from .. import exc
from .. import util
+from ..util.typing import TypeGuard
+
+
+if TYPE_CHECKING:
+
+ def isupdate(dml) -> TypeGuard[UpdateDMLState]:
+ ...
+
+ def isdelete(dml) -> TypeGuard[DeleteDMLState]:
+ ...
+
+ def isinsert(dml) -> TypeGuard[InsertDMLState]:
+ ...
+
+else:
+ isupdate = operator.attrgetter("isupdate")
+ isdelete = operator.attrgetter("isdelete")
+ isinsert = operator.attrgetter("isinsert")
class DMLState(CompileState):
@@ -49,6 +70,7 @@ class DMLState(CompileState):
_ordered_values = None
_parameter_ordering = None
_has_multi_parameters = False
+
isupdate = False
isdelete = False
isinsert = False
@@ -237,6 +259,8 @@ class UpdateBase(
_hints = util.immutabledict()
named_with_column = False
+ table: TableClause
+
_return_defaults = False
_return_defaults_columns = None
_returning = ()
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 48c3c3be6..691eb10ec 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -18,6 +18,7 @@ import itertools
import operator
import re
import typing
+from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
@@ -83,6 +84,7 @@ if typing.TYPE_CHECKING:
from .operators import OperatorType
from .schema import Column
from .schema import DefaultGenerator
+ from .schema import FetchedValue
from .schema import ForeignKey
from .selectable import FromClause
from .selectable import NamedFromClause
@@ -290,7 +292,7 @@ class ClauseElement(
"""
- @util.memoized_property
+ @util.ro_memoized_property
def description(self) -> Optional[str]:
return None
@@ -319,7 +321,7 @@ class ClauseElement(
_cache_key_traversal = None
- negation_clause: ClauseElement
+ negation_clause: ColumnElement[bool]
if typing.TYPE_CHECKING:
@@ -1153,9 +1155,7 @@ class ColumnElement(
primary_key: bool = False
_is_clone_of: Optional[ColumnElement[_T]]
- @util.memoized_property
- def foreign_keys(self) -> Iterable[ForeignKey]:
- return []
+ foreign_keys: AbstractSet[ForeignKey] = frozenset()
@util.memoized_property
def _proxies(self) -> List[ColumnElement[Any]]:
@@ -1494,6 +1494,8 @@ class ColumnElement(
else:
key = name
+ assert key is not None
+
co: ColumnClause[_T] = ColumnClause(
coercions.expect(roles.TruncatedLabelRole, name)
if name_is_truncatable
@@ -1506,7 +1508,6 @@ class ColumnElement(
co._proxies = [self]
if selectable._is_clone_of is not None:
co._is_clone_of = selectable._is_clone_of.columns.get(key)
- assert key is not None
return key, co
def cast(self, type_: TypeEngine[_T]) -> Cast[_T]:
@@ -4050,13 +4051,14 @@ class NamedColumn(ColumnElement[_T]):
is_literal = False
table: Optional[FromClause] = None
name: str
+ key: str
def _compare_name_for_result(self, other):
return (hasattr(other, "name") and self.name == other.name) or (
hasattr(other, "_label") and self._label == other._label
)
- @util.memoized_property
+ @util.ro_memoized_property
def description(self) -> str:
return self.name
@@ -4125,6 +4127,7 @@ class NamedColumn(ColumnElement[_T]):
_selectable=selectable,
is_literal=False,
)
+
c._propagate_attrs = selectable._propagate_attrs
if name is None:
c.key = self.key
@@ -4192,8 +4195,8 @@ class ColumnClause(
onupdate: Optional[DefaultGenerator] = None
default: Optional[DefaultGenerator] = None
- server_default: Optional[DefaultGenerator] = None
- server_onupdate: Optional[DefaultGenerator] = None
+ server_default: Optional[FetchedValue] = None
+ server_onupdate: Optional[FetchedValue] = None
_is_multiparam_column = False
diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py
index 1a1fc4c41..0d74e2e4c 100644
--- a/lib/sqlalchemy/sql/events.py
+++ b/lib/sqlalchemy/sql/events.py
@@ -7,11 +7,23 @@
from __future__ import annotations
+from typing import Any
+from typing import TYPE_CHECKING
+
from .base import SchemaEventTarget
from .. import event
+if TYPE_CHECKING:
+ from .schema import Column
+ from .schema import Constraint
+ from .schema import SchemaItem
+ from .schema import Table
+ from ..engine.base import Connection
+ from ..engine.interfaces import ReflectedColumn
+ from ..engine.reflection import Inspector
+
-class DDLEvents(event.Events):
+class DDLEvents(event.Events[SchemaEventTarget]):
"""
Define event listeners for schema objects,
that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget`
@@ -93,7 +105,9 @@ class DDLEvents(event.Events):
_target_class_doc = "SomeSchemaClassOrObject"
_dispatch_target = SchemaEventTarget
- def before_create(self, target, connection, **kw):
+ def before_create(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called before CREATE statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -120,7 +134,9 @@ class DDLEvents(event.Events):
"""
- def after_create(self, target, connection, **kw):
+ def after_create(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called after CREATE statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -142,7 +158,9 @@ class DDLEvents(event.Events):
"""
- def before_drop(self, target, connection, **kw):
+ def before_drop(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called before DROP statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -164,7 +182,9 @@ class DDLEvents(event.Events):
"""
- def after_drop(self, target, connection, **kw):
+ def after_drop(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called after DROP statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -186,7 +206,9 @@ class DDLEvents(event.Events):
"""
- def before_parent_attach(self, target, parent):
+ def before_parent_attach(
+ self, target: SchemaEventTarget, parent: SchemaItem
+ ) -> None:
"""Called before a :class:`.SchemaItem` is associated with
a parent :class:`.SchemaItem`.
@@ -201,7 +223,9 @@ class DDLEvents(event.Events):
"""
- def after_parent_attach(self, target, parent):
+ def after_parent_attach(
+ self, target: SchemaEventTarget, parent: SchemaItem
+ ) -> None:
"""Called after a :class:`.SchemaItem` is associated with
a parent :class:`.SchemaItem`.
@@ -216,13 +240,17 @@ class DDLEvents(event.Events):
"""
- def _sa_event_column_added_to_pk_constraint(self, const, col):
+ def _sa_event_column_added_to_pk_constraint(
+ self, const: Constraint, col: Column[Any]
+ ) -> None:
"""internal event hook used for primary key naming convention
updates.
"""
- def column_reflect(self, inspector, table, column_info):
+ def column_reflect(
+ self, inspector: Inspector, table: Table, column_info: ReflectedColumn
+ ) -> None:
"""Called for each unit of 'column info' retrieved when
a :class:`_schema.Table` is being reflected.
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 36ddbf309..455e74f7b 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -43,7 +43,6 @@ from ._elements_constructors import text as text
from ._elements_constructors import true as true
from ._elements_constructors import tuple_ as tuple_
from ._elements_constructors import type_coerce as type_coerce
-from ._elements_constructors import typing as typing
from ._elements_constructors import within_group as within_group
from ._selectable_constructors import alias as alias
from ._selectable_constructors import cte as cte
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 4c4f49aa8..beb73c1b5 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -211,9 +211,11 @@ class StrictFromClauseRole(FromClauseRole):
__slots__ = ()
# does not allow text() or select() objects
- c: ColumnCollection
+ c: ColumnCollection[Any]
- @property
+ # this should be ->str , however, working around:
+ # https://github.com/python/mypy/issues/12440
+ @util.ro_non_memoized_property
def description(self) -> str:
raise NotImplementedError()
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 5cfb55603..540b62e8a 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -30,16 +30,22 @@ as components in SQL expressions.
"""
from __future__ import annotations
+from abc import ABC
import collections
+import operator
import typing
from typing import Any
+from typing import Callable
from typing import Dict
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Sequence as _typing_Sequence
+from typing import Set
+from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -48,6 +54,7 @@ from . import ddl
from . import roles
from . import type_api
from . import visitors
+from .base import ColumnCollection
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
@@ -67,12 +74,15 @@ from .. import exc
from .. import inspection
from .. import util
from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypeGuard
if typing.TYPE_CHECKING:
from .type_api import TypeEngine
from ..engine import Connection
from ..engine import Engine
-
+ from ..engine.interfaces import ExecutionContext
+ from ..engine.mock import MockConnection
_T = TypeVar("_T", bound="Any")
_ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement]
_TAB = TypeVar("_TAB", bound="Table")
@@ -102,7 +112,7 @@ NULL_UNSPECIFIED = util.symbol(
)
-def _get_table_key(name, schema):
+def _get_table_key(name: str, schema: Optional[str]) -> str:
if schema is None:
return name
else:
@@ -207,7 +217,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
__visit_name__ = "table"
- constraints = None
+ constraints: Set[Constraint]
"""A collection of all :class:`_schema.Constraint` objects associated with
this :class:`_schema.Table`.
@@ -235,7 +245,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
"""
- indexes = None
+ indexes: Set[Index]
"""A collection of all :class:`_schema.Index` objects associated with this
:class:`_schema.Table`.
@@ -249,6 +259,14 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
("schema", InternalTraversal.dp_string)
]
+ if TYPE_CHECKING:
+
+ @util.non_memoized_property
+ def columns(self) -> ColumnCollection[Column[Any]]:
+ ...
+
+ c: ColumnCollection[Column[Any]]
+
def _gen_cache_key(self, anon_map, bindparams):
if self._annotations:
return (self,) + self._annotations_cache_key
@@ -736,11 +754,12 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
)
@property
- def _sorted_constraints(self):
+ def _sorted_constraints(self) -> List[Constraint]:
"""Return the set of constraints as a list, sorted by creation
order.
"""
+
return sorted(self.constraints, key=lambda c: c._creation_order)
@property
@@ -801,6 +820,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
)
self.info = kwargs.pop("info", self.info)
+ exclude_columns: _typing_Sequence[str]
+
if autoload:
if not autoload_replace:
# don't replace columns already present.
@@ -1074,8 +1095,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
return metadata.tables[key]
args = []
- for c in self.columns:
- args.append(c._copy(schema=schema))
+ for col in self.columns:
+ args.append(col._copy(schema=schema))
table = Table(
name,
metadata,
@@ -1084,28 +1105,30 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
*args,
**self.kwargs,
)
- for c in self.constraints:
- if isinstance(c, ForeignKeyConstraint):
- referred_schema = c._referred_schema
+ for const in self.constraints:
+ if isinstance(const, ForeignKeyConstraint):
+ referred_schema = const._referred_schema
if referred_schema_fn:
fk_constraint_schema = referred_schema_fn(
- self, schema, c, referred_schema
+ self, schema, const, referred_schema
)
else:
fk_constraint_schema = (
schema if referred_schema == self.schema else None
)
table.append_constraint(
- c._copy(schema=fk_constraint_schema, target_table=table)
+ const._copy(
+ schema=fk_constraint_schema, target_table=table
+ )
)
- elif not c._type_bound:
+ elif not const._type_bound:
# skip unique constraints that would be generated
# by the 'unique' flag on Column
- if c._column_flag:
+ if const._column_flag:
continue
table.append_constraint(
- c._copy(schema=schema, target_table=table)
+ const._copy(schema=schema, target_table=table)
)
for index in self.indexes:
# skip indexes that would be generated
@@ -1734,23 +1757,25 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
name = kwargs.pop("name", None)
type_ = kwargs.pop("type_", None)
- args = list(args)
- if args:
- if isinstance(args[0], str):
+ l_args = list(args)
+ del args
+
+ if l_args:
+ if isinstance(l_args[0], str):
if name is not None:
raise exc.ArgumentError(
"May not pass name positionally and as a keyword."
)
- name = args.pop(0)
- if args:
- coltype = args[0]
+ name = l_args.pop(0)
+ if l_args:
+ coltype = l_args[0]
if hasattr(coltype, "_sqla_type"):
if type_ is not None:
raise exc.ArgumentError(
"May not pass type_ positionally and as a keyword."
)
- type_ = args.pop(0)
+ type_ = l_args.pop(0)
if name is not None:
name = quoted_name(name, kwargs.pop("quote", None))
@@ -1772,7 +1797,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
else:
self.nullable = not primary_key
- self.default = kwargs.pop("default", None)
+ default = kwargs.pop("default", None)
+ onupdate = kwargs.pop("onupdate", None)
+
self.server_default = kwargs.pop("server_default", None)
self.server_onupdate = kwargs.pop("server_onupdate", None)
@@ -1784,7 +1811,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
self.system = kwargs.pop("system", False)
self.doc = kwargs.pop("doc", None)
- self.onupdate = kwargs.pop("onupdate", None)
self.autoincrement = kwargs.pop("autoincrement", "auto")
self.constraints = set()
self.foreign_keys = set()
@@ -1803,32 +1829,38 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
if isinstance(impl, SchemaEventTarget):
impl._set_parent_with_dispatch(self)
- if self.default is not None:
- if isinstance(self.default, (ColumnDefault, Sequence)):
- args.append(self.default)
- else:
- args.append(ColumnDefault(self.default))
+ if default is not None:
+ if not isinstance(default, (ColumnDefault, Sequence)):
+ default = ColumnDefault(default)
+
+ self.default = default
+ l_args.append(default)
+ else:
+ self.default = None
+
+ if onupdate is not None:
+ if not isinstance(onupdate, (ColumnDefault, Sequence)):
+ onupdate = ColumnDefault(onupdate, for_update=True)
+
+ self.onupdate = onupdate
+ l_args.append(onupdate)
+ else:
+ self.onpudate = None
if self.server_default is not None:
if isinstance(self.server_default, FetchedValue):
- args.append(self.server_default._as_for_update(False))
+ l_args.append(self.server_default._as_for_update(False))
else:
- args.append(DefaultClause(self.server_default))
-
- if self.onupdate is not None:
- if isinstance(self.onupdate, (ColumnDefault, Sequence)):
- args.append(self.onupdate)
- else:
- args.append(ColumnDefault(self.onupdate, for_update=True))
+ l_args.append(DefaultClause(self.server_default))
if self.server_onupdate is not None:
if isinstance(self.server_onupdate, FetchedValue):
- args.append(self.server_onupdate._as_for_update(True))
+ l_args.append(self.server_onupdate._as_for_update(True))
else:
- args.append(
+ l_args.append(
DefaultClause(self.server_onupdate, for_update=True)
)
- self._init_items(*args)
+ self._init_items(*l_args)
util.set_creation_order(self)
@@ -1837,7 +1869,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
self._extra_kwargs(**kwargs)
- foreign_keys = None
+ table: Table
+
+ constraints: Set[Constraint]
+
+ foreign_keys: Set[ForeignKey]
"""A collection of all :class:`_schema.ForeignKey` marker objects
associated with this :class:`_schema.Column`.
@@ -1850,7 +1886,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
"""
- index = None
+ index: bool
"""The value of the :paramref:`_schema.Column.index` parameter.
Does not indicate if this :class:`_schema.Column` is actually indexed
@@ -1861,7 +1897,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
:attr:`_schema.Table.indexes`
"""
- unique = None
+ unique: bool
"""The value of the :paramref:`_schema.Column.unique` parameter.
Does not indicate if this :class:`_schema.Column` is actually subject to
@@ -2074,8 +2110,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
server_default = self.server_default
server_onupdate = self.server_onupdate
if isinstance(server_default, (Computed, Identity)):
+ args.append(server_default._copy(**kw))
server_default = server_onupdate = None
- args.append(self.server_default._copy(**kw))
type_ = self.type
if isinstance(type_, SchemaEventTarget):
@@ -2203,9 +2239,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
__visit_name__ = "foreign_key"
+ parent: Column[Any]
+
def __init__(
self,
- column: Union[str, Column, SQLCoreOperations],
+ column: Union[str, Column[Any], SQLCoreOperations[Any]],
_constraint: Optional["ForeignKeyConstraint"] = None,
use_alter: bool = False,
name: Optional[str] = None,
@@ -2296,7 +2334,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self._table_column = self._colspec
if not isinstance(
- self._table_column.table, (util.NoneType, TableClause)
+ self._table_column.table, (type(None), TableClause)
):
raise exc.ArgumentError(
"ForeignKey received Column not bound "
@@ -2309,7 +2347,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# object passes itself in when creating ForeignKey
# markers.
self.constraint = _constraint
- self.parent = None
+
+ # .parent is not Optional under normal use
+ self.parent = None # type: ignore
+
self.use_alter = use_alter
self.name = name
self.onupdate = onupdate
@@ -2501,19 +2542,18 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return parenttable, tablekey, colname
def _link_to_col_by_colstring(self, parenttable, table, colname):
- if not hasattr(self.constraint, "_referred_table"):
- self.constraint._referred_table = table
- else:
- assert self.constraint._referred_table is table
-
_column = None
if colname is None:
# colname is None in the case that ForeignKey argument
# was specified as table name only, in which case we
# match the column name to the same column on the
# parent.
- key = self.parent
- _column = table.c.get(self.parent.key, None)
+ # this use case wasn't working in later 1.x series
+ # as it had no test coverage; fixed in 2.0
+ parent = self.parent
+ assert parent is not None
+ key = parent.key
+ _column = table.c.get(key, None)
elif self.link_to_name:
key = colname
for c in table.c:
@@ -2533,10 +2573,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
key,
)
- self._set_target_column(_column)
+ return _column
def _set_target_column(self, column):
- assert isinstance(self.parent.table, Table)
+ assert self.parent is not None
# propagate TypeEngine to parent if it didn't have one
if self.parent.type._isnull:
@@ -2561,11 +2601,6 @@ class ForeignKey(DialectKWArgs, SchemaItem):
If no target column has been established, an exception
is raised.
- .. versionchanged:: 0.9.0
- Foreign key target column resolution now occurs as soon as both
- the ForeignKey object and the remote Column to which it refers
- are both associated with the same MetaData object.
-
"""
if isinstance(self._colspec, str):
@@ -2586,14 +2621,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
"parent MetaData" % parenttable
)
else:
- raise exc.NoReferencedColumnError(
- "Could not initialize target column for "
- "ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'"
- % (self._colspec, parenttable.name, tablekey, colname),
- tablekey,
- colname,
+ table = parenttable.metadata.tables[tablekey]
+ return self._link_to_col_by_colstring(
+ parenttable, table, colname
)
+
elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
return _column
@@ -2601,18 +2633,22 @@ class ForeignKey(DialectKWArgs, SchemaItem):
_column = self._colspec
return _column
- def _set_parent(self, column, **kw):
- if self.parent is not None and self.parent is not column:
+ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+ assert isinstance(parent, Column)
+
+ if self.parent is not None and self.parent is not parent:
raise exc.InvalidRequestError(
"This ForeignKey already has a parent !"
)
- self.parent = column
+ self.parent = parent
self.parent.foreign_keys.add(self)
self.parent._on_table_attach(self._set_table)
def _set_remote_table(self, table):
- parenttable, tablekey, colname = self._resolve_col_tokens()
- self._link_to_col_by_colstring(parenttable, table, colname)
+ parenttable, _, colname = self._resolve_col_tokens()
+ _column = self._link_to_col_by_colstring(parenttable, table, colname)
+ self._set_target_column(_column)
+ assert self.constraint is not None
self.constraint._validate_dest_table(table)
def _remove_from_metadata(self, metadata):
@@ -2651,10 +2687,15 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if table_key in parenttable.metadata.tables:
table = parenttable.metadata.tables[table_key]
try:
- self._link_to_col_by_colstring(parenttable, table, colname)
+ _column = self._link_to_col_by_colstring(
+ parenttable, table, colname
+ )
except exc.NoReferencedColumnError:
# this is OK, we'll try later
pass
+ else:
+ self._set_target_column(_column)
+
parenttable.metadata._fk_memos[fk_key].append(self)
elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
@@ -2664,6 +2705,31 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self._set_target_column(_column)
+if TYPE_CHECKING:
+
+ def default_is_sequence(
+ obj: Optional[DefaultGenerator],
+ ) -> TypeGuard[Sequence]:
+ ...
+
+ def default_is_clause_element(
+ obj: Optional[DefaultGenerator],
+ ) -> TypeGuard[ColumnElementColumnDefault]:
+ ...
+
+ def default_is_scalar(
+ obj: Optional[DefaultGenerator],
+ ) -> TypeGuard[ScalarElementColumnDefault]:
+ ...
+
+else:
+ default_is_sequence = operator.attrgetter("is_sequence")
+
+ default_is_clause_element = operator.attrgetter("is_clause_element")
+
+ default_is_scalar = operator.attrgetter("is_scalar")
+
+
class DefaultGenerator(Executable, SchemaItem):
"""Base class for column *default* values."""
@@ -2671,18 +2737,18 @@ class DefaultGenerator(Executable, SchemaItem):
is_sequence = False
is_server_default = False
+ is_clause_element = False
+ is_callable = False
is_scalar = False
- column = None
+ column: Optional[Column[Any]]
def __init__(self, for_update=False):
self.for_update = for_update
- @util.memoized_property
- def is_callable(self):
- raise NotImplementedError()
-
- def _set_parent(self, column, **kw):
- self.column = column
+ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+ if TYPE_CHECKING:
+ assert isinstance(parent, Column)
+ self.column = parent
if self.for_update:
self.column.onupdate = self
else:
@@ -2696,7 +2762,7 @@ class DefaultGenerator(Executable, SchemaItem):
)
-class ColumnDefault(DefaultGenerator):
+class ColumnDefault(DefaultGenerator, ABC):
"""A plain default value on a column.
This could correspond to a constant, a callable function,
@@ -2718,7 +2784,30 @@ class ColumnDefault(DefaultGenerator):
"""
- def __init__(self, arg, **kwargs):
+ arg: Any
+
+ @overload
+ def __new__(
+ cls, arg: Callable[..., Any], for_update: bool = ...
+ ) -> CallableColumnDefault:
+ ...
+
+ @overload
+ def __new__(
+ cls, arg: ColumnElement[Any], for_update: bool = ...
+ ) -> ColumnElementColumnDefault:
+ ...
+
+ # if I return ScalarElementColumnDefault here, which is what's actually
+ # returned, mypy complains that
+ # overloads overlap w/ incompatible return types.
+ @overload
+ def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault:
+ ...
+
+ def __new__(
+ cls, arg: Any = None, for_update: bool = False
+ ) -> ColumnDefault:
"""Construct a new :class:`.ColumnDefault`.
@@ -2744,70 +2833,121 @@ class ColumnDefault(DefaultGenerator):
statement and parameters.
"""
- super(ColumnDefault, self).__init__(**kwargs)
+
if isinstance(arg, FetchedValue):
raise exc.ArgumentError(
"ColumnDefault may not be a server-side default type."
)
- if callable(arg):
- arg = self._maybe_wrap_callable(arg)
+ elif callable(arg):
+ cls = CallableColumnDefault
+ elif isinstance(arg, ClauseElement):
+ cls = ColumnElementColumnDefault
+ elif arg is not None:
+ cls = ScalarElementColumnDefault
+
+ return object.__new__(cls)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({self.arg!r})"
+
+
+class ScalarElementColumnDefault(ColumnDefault):
+ """default generator for a fixed scalar Python value
+
+ .. versionadded: 2.0
+
+ """
+
+ is_scalar = True
+
+ def __init__(self, arg: Any, for_update: bool = False):
+ self.for_update = for_update
self.arg = arg
- @util.memoized_property
- def is_callable(self):
- return callable(self.arg)
- @util.memoized_property
- def is_clause_element(self):
- return isinstance(self.arg, ClauseElement)
+# _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"]
+_SQLExprDefault = Union["ColumnElement[Any]", "TextClause"]
- @util.memoized_property
- def is_scalar(self):
- return (
- not self.is_callable
- and not self.is_clause_element
- and not self.is_sequence
- )
+
+class ColumnElementColumnDefault(ColumnDefault):
+ """default generator for a SQL expression
+
+ .. versionadded:: 2.0
+
+ """
+
+ is_clause_element = True
+
+ arg: _SQLExprDefault
+
+ def __init__(
+ self,
+ arg: _SQLExprDefault,
+ for_update: bool = False,
+ ):
+ self.for_update = for_update
+ self.arg = arg
@util.memoized_property
@util.preload_module("sqlalchemy.sql.sqltypes")
def _arg_is_typed(self):
sqltypes = util.preloaded.sql_sqltypes
- if self.is_clause_element:
- return not isinstance(self.arg.type, sqltypes.NullType)
- else:
- return False
+ return not isinstance(self.arg.type, sqltypes.NullType)
+
+
+class _CallableColumnDefaultProtocol(Protocol):
+ def __call__(self, context: ExecutionContext) -> Any:
+ ...
- def _maybe_wrap_callable(self, fn):
+
+class CallableColumnDefault(ColumnDefault):
+ """default generator for a callable Python function
+
+ .. versionadded:: 2.0
+
+ """
+
+ is_callable = True
+ arg: _CallableColumnDefaultProtocol
+
+ def __init__(
+ self,
+ arg: Union[_CallableColumnDefaultProtocol, Callable[[], Any]],
+ for_update: bool = False,
+ ):
+ self.for_update = for_update
+ self.arg = self._maybe_wrap_callable(arg)
+
+ def _maybe_wrap_callable(
+ self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]]
+ ) -> _CallableColumnDefaultProtocol:
"""Wrap callables that don't accept a context.
This is to allow easy compatibility with default callables
that aren't specific to accepting of a context.
"""
+
try:
argspec = util.get_callable_argspec(fn, no_self=True)
except TypeError:
- return util.wrap_callable(lambda ctx: fn(), fn)
+ return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore
defaulted = argspec[3] is not None and len(argspec[3]) or 0
positionals = len(argspec[0]) - defaulted
if positionals == 0:
- return util.wrap_callable(lambda ctx: fn(), fn)
+ return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore
elif positionals == 1:
- return fn
+ return fn # type: ignore
else:
raise exc.ArgumentError(
"ColumnDefault Python function takes zero or one "
"positional arguments"
)
- def __repr__(self):
- return "ColumnDefault(%r)" % (self.arg,)
-
class IdentityOptions:
"""Defines options for a named database sequence or an identity column.
@@ -2899,6 +3039,8 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
is_sequence = True
+ column: Optional[Column[Any]] = None
+
def __init__(
self,
name,
@@ -3087,14 +3229,6 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
else:
self.data_type = None
- @util.memoized_property
- def is_callable(self):
- return False
-
- @util.memoized_property
- def is_clause_element(self):
- return False
-
@util.preload_module("sqlalchemy.sql.functions")
def next_value(self):
"""Return a :class:`.next_value` function element
@@ -3235,6 +3369,9 @@ class Constraint(DialectKWArgs, SchemaItem):
__visit_name__ = "constraint"
+ _creation_order: int
+ _column_flag: bool
+
def __init__(
self,
name=None,
@@ -3316,8 +3453,6 @@ class Constraint(DialectKWArgs, SchemaItem):
class ColumnCollectionMixin:
-
- columns = None
"""A :class:`_expression.ColumnCollection` of :class:`_schema.Column`
objects.
@@ -3326,8 +3461,17 @@ class ColumnCollectionMixin:
"""
+ columns: ColumnCollection[Column[Any]]
+
_allow_multiple_tables = False
+ if TYPE_CHECKING:
+
+ def _set_parent_with_dispatch(
+ self, parent: SchemaEventTarget, **kw: Any
+ ) -> None:
+ ...
+
def __init__(self, *columns, **kw):
_autoattach = kw.pop("_autoattach", True)
self._column_flag = kw.pop("_column_flag", False)
@@ -3404,14 +3548,16 @@ class ColumnCollectionMixin:
)
)
- def _col_expressions(self, table):
+ def _col_expressions(self, table: Table) -> List[Column[Any]]:
return [
table.c[col] if isinstance(col, str) else col
for col in self._pending_colargs
]
- def _set_parent(self, table, **kw):
- for col in self._col_expressions(table):
+ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+ if TYPE_CHECKING:
+ assert isinstance(parent, Table)
+ for col in self._col_expressions(parent):
if col is not None:
self.columns.add(col)
@@ -3446,7 +3592,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
)
- columns = None
+ columns: DedupeColumnCollection[Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
@@ -3568,7 +3714,7 @@ class CheckConstraint(ColumnCollectionConstraint):
"""
self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
- columns = []
+ columns: List[Column[Any]] = []
visitors.traverse(self.sqltext, {}, {"column": columns.append})
super(CheckConstraint, self).__init__(
@@ -3779,17 +3925,17 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
assert table is self.parent
self._set_parent_with_dispatch(table)
- def _append_element(self, column, fk):
+ def _append_element(self, column: Column[Any], fk: ForeignKey) -> None:
self.columns.add(column)
self.elements.append(fk)
- columns = None
+ columns: DedupeColumnCollection[Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
"""
- elements = None
+ elements: List[ForeignKey]
"""A sequence of :class:`_schema.ForeignKey` objects.
Each :class:`_schema.ForeignKey`
@@ -4271,7 +4417,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
self._validate_dialect_kwargs(kw)
- self.expressions = []
+ self.expressions: List[ColumnElement[Any]] = []
# will call _set_parent() if table-bound column
# objects are present
ColumnCollectionMixin.__init__(
@@ -4501,11 +4647,13 @@ class MetaData(HasSchemaAttr):
)
if info:
self.info = info
- self._schemas = set()
- self._sequences = {}
- self._fk_memos = collections.defaultdict(list)
+ self._schemas: Set[str] = set()
+ self._sequences: Dict[str, Sequence] = {}
+ self._fk_memos: Dict[
+ Tuple[str, str], List[ForeignKey]
+ ] = collections.defaultdict(list)
- tables: Dict[str, Table]
+ tables: util.FacadeDict[str, Table]
"""A dictionary of :class:`_schema.Table`
objects keyed to their name or "table key".
@@ -4539,7 +4687,7 @@ class MetaData(HasSchemaAttr):
def _remove_table(self, name, schema):
key = _get_table_key(name, schema)
- removed = dict.pop(self.tables, key, None)
+ removed = dict.pop(self.tables, key, None) # type: ignore
if removed is not None:
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
@@ -4634,12 +4782,12 @@ class MetaData(HasSchemaAttr):
"""
return ddl.sort_tables(
- sorted(self.tables.values(), key=lambda t: t.key)
+ sorted(self.tables.values(), key=lambda t: t.key) # type: ignore
)
def reflect(
self,
- bind: Union["Engine", "Connection"],
+ bind: Union[Engine, Connection],
schema: Optional[str] = None,
views: bool = False,
only: Optional[_typing_Sequence[str]] = None,
@@ -4647,7 +4795,7 @@ class MetaData(HasSchemaAttr):
autoload_replace: bool = True,
resolve_fks: bool = True,
**dialect_kwargs: Any,
- ):
+ ) -> None:
r"""Load all available table definitions from the database.
Automatically creates ``Table`` entries in this ``MetaData`` for any
@@ -4748,12 +4896,14 @@ class MetaData(HasSchemaAttr):
if schema is not None:
reflect_opts["schema"] = schema
- available = util.OrderedSet(insp.get_table_names(schema))
+ available: util.OrderedSet[str] = util.OrderedSet(
+ insp.get_table_names(schema)
+ )
if views:
available.update(insp.get_view_names(schema))
if schema is not None:
- available_w_schema = util.OrderedSet(
+ available_w_schema: util.OrderedSet[str] = util.OrderedSet(
["%s.%s" % (schema, name) for name in available]
)
else:
@@ -4796,10 +4946,10 @@ class MetaData(HasSchemaAttr):
def create_all(
self,
- bind: Union["Engine", "Connection"],
+ bind: Union[Engine, Connection, MockConnection],
tables: Optional[_typing_Sequence[Table]] = None,
checkfirst: bool = True,
- ):
+ ) -> None:
"""Create all tables stored in this metadata.
Conditional by default, will not attempt to recreate tables already
@@ -4824,10 +4974,10 @@ class MetaData(HasSchemaAttr):
def drop_all(
self,
- bind: Union["Engine", "Connection"],
+ bind: Union[Engine, Connection, MockConnection],
tables: Optional[_typing_Sequence[Table]] = None,
checkfirst: bool = True,
- ):
+ ) -> None:
"""Drop all tables stored in this metadata.
Conditional by default, will not attempt to drop tables not present in
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index e143d1476..8665a74db 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -463,7 +463,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_is_clone_of: Optional[FromClause]
- schema = None
+ schema: Optional[str] = None
"""Define the 'schema' attribute for this :class:`_expression.FromClause`.
This is typically ``None`` for most objects except that of
@@ -673,7 +673,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return self._cloned_set.intersection(other._cloned_set)
- @property
+ @util.non_memoized_property
def description(self) -> str:
"""A brief description of this :class:`_expression.FromClause`.
@@ -710,7 +710,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.columns
@util.memoized_property
- def columns(self) -> ColumnCollection:
+ def columns(self) -> ColumnCollection[Any]:
"""A named-based collection of :class:`_expression.ColumnElement`
objects maintained by this :class:`_expression.FromClause`.
@@ -796,7 +796,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
# this is awkward. maybe there's a better way
if TYPE_CHECKING:
- c: ColumnCollection
+ c: ColumnCollection[Any]
else:
c = property(
attrgetter("columns"),
@@ -2399,6 +2399,8 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
_is_table = True
+ fullname: str
+
implicit_returning = False
""":class:`_expression.TableClause`
doesn't support having a primary key or column
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 829c1b72e..1a6de34b0 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -345,6 +345,12 @@ class Integer(HasExpressionLookup, TypeEngine[int]):
__visit_name__ = "integer"
+ if TYPE_CHECKING:
+
+ @util.ro_memoized_property
+ def _type_affinity(self) -> Type[Integer]:
+ ...
+
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
@@ -1892,8 +1898,8 @@ class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]):
operators.truediv: {Numeric: self.__class__},
}
- @util.non_memoized_property
- def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
+ @util.ro_non_memoized_property
+ def _type_affinity(self) -> Type[Interval]:
return Interval
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 5a0aba694..9a934a50b 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -705,7 +705,7 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return self
- @util.memoized_property
+ @util.ro_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
"""Return a rudimental 'affinity' value expressing the general class
of type."""
@@ -719,7 +719,7 @@ class TypeEngine(Visitable, Generic[_T]):
else:
return self.__class__
- @util.memoized_property
+ @util.ro_memoized_property
def _generic_type_affinity(
self,
) -> Type[TypeEngine[_T]]:
@@ -1694,7 +1694,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
tt.impl = tt.impl_instance = typedesc
return tt
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
return self.impl_instance._type_affinity