summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_typing.py9
-rw-r--r--lib/sqlalchemy/sql/base.py41
-rw-r--r--lib/sqlalchemy/sql/cache_key.py13
-rw-r--r--lib/sqlalchemy/sql/compiler.py180
-rw-r--r--lib/sqlalchemy/sql/elements.py12
-rw-r--r--lib/sqlalchemy/sql/schema.py12
-rw-r--r--lib/sqlalchemy/sql/util.py38
-rw-r--r--lib/sqlalchemy/sql/visitors.py11
8 files changed, 233 insertions, 83 deletions
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 7d8b9ee5c..69e4645fa 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -1,20 +1,11 @@
from __future__ import annotations
-from typing import Any
-from typing import Mapping
-from typing import Sequence
from typing import Type
from typing import Union
from . import roles
from ..inspection import Inspectable
-from ..util import immutabledict
-_SingleExecuteParams = Mapping[str, Any]
-_MultiExecuteParams = Sequence[_SingleExecuteParams]
-_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams]
-_ExecuteOptions = Mapping[str, Any]
-_ImmutableExecuteOptions = immutabledict[str, Any]
_ColumnsClauseElement = Union[
roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement]
]
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 3936ed9c6..a94590da1 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -19,11 +19,12 @@ from itertools import zip_longest
import operator
import re
import typing
+from typing import Optional
+from typing import Sequence
from typing import TypeVar
from . import roles
from . import visitors
-from ._typing import _ImmutableExecuteOptions
from .cache_key import HasCacheKey # noqa
from .cache_key import MemoizedHasCacheKey # noqa
from .traversals import HasCopyInternals # noqa
@@ -32,7 +33,7 @@ from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import exc
from .. import util
-from ..util import HasMemoized
+from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
from ..util._has_cy import HAS_CYEXTENSION
@@ -42,6 +43,16 @@ if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
else:
from sqlalchemy.cyextension.util import prefix_anon_map # noqa
+if typing.TYPE_CHECKING:
+ from ..engine import Connection
+ from ..engine import Result
+ from ..engine.interfaces import _CoreMultiExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..engine.interfaces import _ImmutableExecuteOptions
+ from ..engine.interfaces import CacheStats
+
+
coercions = None
elements = None
type_api = None
@@ -856,6 +867,32 @@ class Executable(roles.StatementRole, Generative):
is_delete = False
is_dml = False
+ if typing.TYPE_CHECKING:
+
+ def _compile_w_cache(
+ self,
+ dialect: Dialect,
+ compiled_cache: Optional[_CompiledCacheType] = None,
+ column_keys: Optional[Sequence[str]] = None,
+ for_executemany: bool = False,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ **kw: Any,
+ ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]:
+ ...
+
+ def _execute_on_connection(
+ self,
+ connection: Connection,
+ distilled_params: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ _force: bool = False,
+ ) -> Result:
+ ...
+
+ @property
+ def _all_selected_columns(self):
+ raise NotImplementedError()
+
@property
def _effective_plugin_target(self):
return self.__visit_name__
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index 49f1899d5..ff659b77d 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -7,10 +7,12 @@
from __future__ import annotations
-from collections import namedtuple
import enum
from itertools import zip_longest
+import typing
+from typing import Any
from typing import Callable
+from typing import NamedTuple
from typing import Union
from .visitors import anon_map
@@ -22,6 +24,10 @@ from ..util import HasMemoized
from ..util.typing import Literal
+if typing.TYPE_CHECKING:
+ from .elements import BindParameter
+
+
class CacheConst(enum.Enum):
NO_CACHE = 0
@@ -345,7 +351,7 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
return HasCacheKey._generate_cache_key(self)
-class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+class CacheKey(NamedTuple):
"""The key used to identify a SQL statement construct in the
SQL compilation cache.
@@ -355,6 +361,9 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
"""
+ key: Tuple[Any, ...]
+ bindparams: Sequence[BindParameter]
+
def __hash__(self):
"""CacheKey itself is not hashable - hash the .key portion"""
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d0f114d6c..712d31462 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -27,6 +27,7 @@ from __future__ import annotations
import collections
import collections.abc as collections_abc
import contextlib
+from enum import IntEnum
import itertools
import operator
import re
@@ -35,9 +36,13 @@ import typing
from typing import Any
from typing import Dict
from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NamedTuple
from typing import Optional
+from typing import Sequence
from typing import Tuple
+from typing import Union
from . import base
from . import coercions
@@ -51,12 +56,17 @@ from . import sqltypes
from .base import NO_ARG
from .base import prefix_anon_map
from .elements import quoted_name
+from .schema import Column
+from .type_api import TypeEngine
from .. import exc
from .. import util
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
from .selectable import CTE
from .selectable import FromClause
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.result import _ProcessorType
_FromHintsType = Dict["FromClause", str]
@@ -271,42 +281,71 @@ COMPOUND_KEYWORDS = {
}
-RM_RENDERED_NAME = 0
-RM_NAME = 1
-RM_OBJECTS = 2
-RM_TYPE = 3
+class ResultColumnsEntry(NamedTuple):
+ """Tracks a column expression that is expected to be represented
+ in the result rows for this statement.
+ This normally refers to the columns clause of a SELECT statement
+ but may also refer to a RETURNING clause, as well as for dialect-specific
+ emulations.
-ExpandedState = collections.namedtuple(
- "ExpandedState",
- [
- "statement",
- "additional_parameters",
- "processors",
- "positiontup",
- "parameter_expansion",
- ],
-)
+ """
+ keyname: str
+ """string name that's expected in cursor.description"""
-NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0)
+ name: str
+ """column name, may be labeled"""
-COLLECT_CARTESIAN_PRODUCTS = util.symbol(
- "COLLECT_CARTESIAN_PRODUCTS",
- "Collect data on FROMs and cartesian products and gather "
- "into 'self.from_linter'",
- canonical=1,
-)
+ objects: List[Any]
+ """list of objects that should be able to locate this column
+ in a RowMapping. This is typically string names and aliases
+ as well as Column objects.
-WARN_LINTING = util.symbol(
- "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2
-)
+ """
+
+ type: TypeEngine[Any]
+ """Datatype to be associated with this column. This is where
+ the "result processing" logic directly links the compiled statement
+ to the rows that come back from the cursor.
+
+ """
+
+
+# integer indexes into ResultColumnsEntry used by cursor.py.
+# some profiling showed integer access faster than named tuple
+RM_RENDERED_NAME: Literal[0] = 0
+RM_NAME: Literal[1] = 1
+RM_OBJECTS: Literal[2] = 2
+RM_TYPE: Literal[3] = 3
+
+
+class ExpandedState(NamedTuple):
+ statement: str
+ additional_parameters: _CoreSingleExecuteParams
+ processors: Mapping[str, _ProcessorType]
+ positiontup: Optional[Sequence[str]]
+ parameter_expansion: Mapping[str, List[str]]
+
+
+class Linting(IntEnum):
+ NO_LINTING = 0
+ "Disable all linting."
+
+ COLLECT_CARTESIAN_PRODUCTS = 1
+ """Collect data on FROMs and cartesian products and gather into
+ 'self.from_linter'"""
+
+ WARN_LINTING = 2
+ "Emit warnings for linters that find problems"
-FROM_LINTING = util.symbol(
- "FROM_LINTING",
- "Warn for cartesian products; "
- "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
- canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
+ FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING
+ """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS
+ and WARN_LINTING"""
+
+
+NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple(
+ Linting
)
@@ -389,7 +428,7 @@ class Compiled:
_cached_metadata = None
- _result_columns = None
+ _result_columns: Optional[List[ResultColumnsEntry]] = None
schema_translate_map = None
@@ -418,7 +457,8 @@ class Compiled:
"""
cache_key = None
- _gen_time = None
+
+ _gen_time: float
def __init__(
self,
@@ -573,15 +613,43 @@ class SQLCompiler(Compiled):
extract_map = EXTRACT_MAP
+ _result_columns: List[ResultColumnsEntry]
+
compound_keywords = COMPOUND_KEYWORDS
- isdelete = isinsert = isupdate = False
+ isdelete: bool = False
+ isinsert: bool = False
+ isupdate: bool = False
"""class-level defaults which can be set at the instance
level to define if this Compiled instance represents
INSERT/UPDATE/DELETE
"""
- isplaintext = False
+ postfetch: Optional[List[Column[Any]]]
+ """list of columns that can be post-fetched after INSERT or UPDATE to
+ receive server-updated values"""
+
+ insert_prefetch: Optional[List[Column[Any]]]
+ """list of columns for which default values should be evaluated before
+ an INSERT takes place"""
+
+ update_prefetch: Optional[List[Column[Any]]]
+ """list of columns for which onupdate default values should be evaluated
+ before an UPDATE takes place"""
+
+ returning: Optional[List[Column[Any]]]
+ """list of columns that will be delivered to cursor.description or
+ dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
+
+ """
+
+ isplaintext: bool = False
+
+ result_columns: List[ResultColumnsEntry]
+ """relates label names in the final SQL to a tuple of local
+ column/label name, ColumnElement object (if any) and
+ TypeEngine. CursorResult uses this for type processing and
+ column targeting"""
returning = None
"""holds the "returning" collection of columns if
@@ -589,18 +657,18 @@ class SQLCompiler(Compiled):
either implicitly or explicitly
"""
- returning_precedes_values = False
+ returning_precedes_values: bool = False
"""set to True classwide to generate RETURNING
clauses before the VALUES or WHERE clause (i.e. MSSQL)
"""
- render_table_with_column_in_update_from = False
+ render_table_with_column_in_update_from: bool = False
"""set to True classwide to indicate the SET clause
in a multi-table UPDATE statement should qualify
columns with the table name (i.e. MySQL only)
"""
- ansi_bind_rules = False
+ ansi_bind_rules: bool = False
"""SQL 92 doesn't allow bind parameters to be used
in the columns clause of a SELECT, nor does it allow
ambiguous expressions like "? = ?". A compiler
@@ -608,33 +676,33 @@ class SQLCompiler(Compiled):
driver/DB enforces this
"""
- _textual_ordered_columns = False
+ _textual_ordered_columns: bool = False
"""tell the result object that the column names as rendered are important,
but they are also "ordered" vs. what is in the compiled object here.
"""
- _ordered_columns = True
+ _ordered_columns: bool = True
"""
if False, means we can't be sure the list of entries
in _result_columns is actually the rendered order. Usually
True unless using an unordered TextualSelect.
"""
- _loose_column_name_matching = False
+ _loose_column_name_matching: bool = False
"""tell the result object that the SQL statement is textual, wants to match
up to Column objects, and may be using the ._tq_label in the SELECT rather
than the base name.
"""
- _numeric_binds = False
+ _numeric_binds: bool = False
"""
True if paramstyle is "numeric". This paramstyle is trickier than
all the others.
"""
- _render_postcompile = False
+ _render_postcompile: bool = False
"""
whether to render out POSTCOMPILE params during the compile phase.
@@ -684,7 +752,7 @@ class SQLCompiler(Compiled):
"""
- positiontup = None
+ positiontup: Optional[Sequence[str]] = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
@@ -699,7 +767,7 @@ class SQLCompiler(Compiled):
"""
- inline = False
+ inline: bool = False
def __init__(
self,
@@ -760,10 +828,6 @@ class SQLCompiler(Compiled):
# stack which keeps track of nested SELECT statements
self.stack = []
- # relates label names in the final SQL to a tuple of local
- # column/label name, ColumnElement object (if any) and
- # TypeEngine. CursorResult uses this for type processing and
- # column targeting
self._result_columns = []
# true if the paramstyle is positional
@@ -910,7 +974,9 @@ class SQLCompiler(Compiled):
)
@util.memoized_property
- def _bind_processors(self):
+ def _bind_processors(
+ self,
+ ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]:
return dict(
(key, value)
for key, value in (
@@ -1098,8 +1164,10 @@ class SQLCompiler(Compiled):
return self.construct_params(_check=False)
def _process_parameters_for_postcompile(
- self, parameters=None, _populate_self=False
- ):
+ self,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ _populate_self: bool = False,
+ ) -> ExpandedState:
"""handle special post compile parameters.
These include:
@@ -3070,7 +3138,13 @@ class SQLCompiler(Compiled):
def get_render_as_alias_suffix(self, alias_name_text):
return " AS " + alias_name_text
- def _add_to_result_map(self, keyname, name, objects, type_):
+ def _add_to_result_map(
+ self,
+ keyname: str,
+ name: str,
+ objects: List[Any],
+ type_: TypeEngine[Any],
+ ) -> None:
if keyname is None or keyname == "*":
self._ordered_columns = False
self._textual_ordered_columns = True
@@ -3080,7 +3154,9 @@ class SQLCompiler(Compiled):
"from a tuple() object. If this is an ORM query, "
"consider using the Bundle object."
)
- self._result_columns.append((keyname, name, objects, type_))
+ self._result_columns.append(
+ ResultColumnsEntry(keyname, name, objects, type_)
+ )
def _label_returning_column(self, stmt, column, column_clause_args=None):
"""Render a column with necessary labels inside of a RETURNING clause.
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 0c532a135..ac5dc46db 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -61,6 +61,11 @@ if typing.TYPE_CHECKING:
from .selectable import Select
from .sqltypes import Boolean # noqa
from .type_api import TypeEngine
+ from ..engine import Compiled
+ from ..engine import Connection
+ from ..engine import Dialect
+ from ..engine import Engine
+
_NUMERIC = Union[complex, "Decimal"]
@@ -145,7 +150,12 @@ class CompilerElement(Visitable):
@util.preload_module("sqlalchemy.engine.default")
@util.preload_module("sqlalchemy.engine.url")
- def compile(self, bind=None, dialect=None, **kw):
+ def compile(
+ self,
+ bind: Optional[Union[Engine, Connection]] = None,
+ dialect: Optional[Dialect] = None,
+ **kw: Any,
+ ) -> Compiled:
"""Compile this SQL expression.
The return value is a :class:`~.Compiled` object.
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 528691795..fdae4d7b0 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -174,7 +174,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
_use_schema_map = True
-class Table(DialectKWArgs, SchemaItem, TableClause):
+class HasSchemaAttr(SchemaItem):
+ """schema item that includes a top-level schema name"""
+
+ schema: Optional[str]
+
+
+class Table(DialectKWArgs, HasSchemaAttr, TableClause):
r"""Represent a table in a database.
e.g.::
@@ -2850,7 +2856,7 @@ class IdentityOptions:
self.order = order
-class Sequence(IdentityOptions, DefaultGenerator):
+class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
"""Represents a named database sequence.
The :class:`.Sequence` object represents the name and configurational
@@ -4330,7 +4336,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"})
-class MetaData(SchemaItem):
+class MetaData(HasSchemaAttr):
"""A collection of :class:`_schema.Table`
objects and their associated schema
constructs.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index e3e358cdb..e0248adf0 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -21,9 +21,6 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
-from ._typing import _ExecuteParams
-from ._typing import _MultiExecuteParams
-from ._typing import _SingleExecuteParams
from .annotation import _deep_annotate # noqa
from .annotation import _deep_deannotate # noqa
from .annotation import _shallow_annotate # noqa
@@ -54,6 +51,10 @@ from .. import exc
from .. import util
if typing.TYPE_CHECKING:
+ from ..engine.interfaces import _AnyExecuteParams
+ from ..engine.interfaces import _AnyMultiExecuteParams
+ from ..engine.interfaces import _AnySingleExecuteParams
+ from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.row import Row
@@ -550,12 +551,12 @@ class _repr_params(_repr_base):
def __init__(
self,
- params: _ExecuteParams,
+ params: Optional[_AnyExecuteParams],
batches: int,
max_chars: int = 300,
ismulti: Optional[bool] = None,
):
- self.params: _ExecuteParams = params
+ self.params = params
self.ismulti = ismulti
self.batches = batches
self.max_chars = max_chars
@@ -575,7 +576,10 @@ class _repr_params(_repr_base):
return self.trunc(self.params)
if self.ismulti:
- multi_params = cast(_MultiExecuteParams, self.params)
+ multi_params = cast(
+ "_AnyMultiExecuteParams",
+ self.params,
+ )
if len(self.params) > self.batches:
msg = (
@@ -595,10 +599,18 @@ class _repr_params(_repr_base):
return self._repr_multi(multi_params, typ)
else:
return self._repr_params(
- cast(_SingleExecuteParams, self.params), typ
+ cast(
+ "_AnySingleExecuteParams",
+ self.params,
+ ),
+ typ,
)
- def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str:
+ def _repr_multi(
+ self,
+ multi_params: _AnyMultiExecuteParams,
+ typ,
+ ) -> str:
if multi_params:
if isinstance(multi_params[0], list):
elem_type = self._LIST
@@ -622,13 +634,19 @@ class _repr_params(_repr_base):
else:
return "(%s)" % elements
- def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str:
+ def _repr_params(
+ self,
+ params: Optional[_AnySingleExecuteParams],
+ typ: int,
+ ) -> str:
trunc = self.trunc
if typ is self._DICT:
return "{%s}" % (
", ".join(
"%r: %s" % (key, trunc(value))
- for key, value in params.items()
+ for key, value in cast(
+ "_CoreSingleExecuteParams", params
+ ).items()
)
)
elif typ is self._TUPLE:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 523426d09..111ecd32e 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -28,6 +28,8 @@ from __future__ import annotations
from collections import deque
import itertools
import operator
+import typing
+from typing import Any
from typing import List
from typing import Tuple
@@ -35,12 +37,13 @@ from .. import exc
from .. import util
from ..util import langhelpers
from ..util import symbol
+from ..util._has_cy import HAS_CYEXTENSION
from ..util.langhelpers import _symbol
-try:
- from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
-except ImportError:
+if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import cache_anon_map as anon_map # noqa
+else:
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
__all__ = [
"iterate",
@@ -554,7 +557,7 @@ class ExternalTraversal:
__traverse_options__ = {}
- def traverse_single(self, obj, **kw):
+ def traverse_single(self, obj: Visitable, **kw: Any) -> Any:
for v in self.visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth: