diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 41 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/cache_key.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 180 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 38 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 11 |
8 files changed, 233 insertions, 83 deletions
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 7d8b9ee5c..69e4645fa 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,20 +1,11 @@ from __future__ import annotations -from typing import Any -from typing import Mapping -from typing import Sequence from typing import Type from typing import Union from . import roles from ..inspection import Inspectable -from ..util import immutabledict -_SingleExecuteParams = Mapping[str, Any] -_MultiExecuteParams = Sequence[_SingleExecuteParams] -_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams] -_ExecuteOptions = Mapping[str, Any] -_ImmutableExecuteOptions = immutabledict[str, Any] _ColumnsClauseElement = Union[ roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] ] diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 3936ed9c6..a94590da1 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -19,11 +19,12 @@ from itertools import zip_longest import operator import re import typing +from typing import Optional +from typing import Sequence from typing import TypeVar from . import roles from . import visitors -from ._typing import _ImmutableExecuteOptions from .cache_key import HasCacheKey # noqa from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa @@ -32,7 +33,7 @@ from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import exc from .. import util -from ..util import HasMemoized +from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing from ..util._has_cy import HAS_CYEXTENSION @@ -42,6 +43,16 @@ if typing.TYPE_CHECKING or not HAS_CYEXTENSION: else: from sqlalchemy.cyextension.util import prefix_anon_map # noqa +if typing.TYPE_CHECKING: + from ..engine import Connection + from ..engine import Result + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ExecuteOptionsParameter + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import CacheStats + + coercions = None elements = None type_api = None @@ -856,6 +867,32 @@ class Executable(roles.StatementRole, Generative): is_delete = False is_dml = False + if typing.TYPE_CHECKING: + + def _compile_w_cache( + self, + dialect: Dialect, + compiled_cache: Optional[_CompiledCacheType] = None, + column_keys: Optional[Sequence[str]] = None, + for_executemany: bool = False, + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]: + ... + + def _execute_on_connection( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + _force: bool = False, + ) -> Result: + ... + + @property + def _all_selected_columns(self): + raise NotImplementedError() + @property def _effective_plugin_target(self): return self.__visit_name__ diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 49f1899d5..ff659b77d 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -7,10 +7,12 @@ from __future__ import annotations -from collections import namedtuple import enum from itertools import zip_longest +import typing +from typing import Any from typing import Callable +from typing import NamedTuple from typing import Union from .visitors import anon_map @@ -22,6 +24,10 @@ from ..util import HasMemoized from ..util.typing import Literal +if typing.TYPE_CHECKING: + from .elements import BindParameter + + class CacheConst(enum.Enum): NO_CACHE = 0 @@ -345,7 +351,7 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): +class CacheKey(NamedTuple): """The key used to identify a SQL statement construct in the SQL compilation cache. @@ -355,6 +361,9 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): """ + key: Tuple[Any, ...] + bindparams: Sequence[BindParameter] + def __hash__(self): """CacheKey itself is not hashable - hash the .key portion""" diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d0f114d6c..712d31462 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -27,6 +27,7 @@ from __future__ import annotations import collections import collections.abc as collections_abc import contextlib +from enum import IntEnum import itertools import operator import re @@ -35,9 +36,13 @@ import typing from typing import Any from typing import Dict from typing import List +from typing import Mapping from typing import MutableMapping +from typing import NamedTuple from typing import Optional +from typing import Sequence from typing import Tuple +from typing import Union from . import base from . import coercions @@ -51,12 +56,17 @@ from . import sqltypes from .base import NO_ARG from .base import prefix_anon_map from .elements import quoted_name +from .schema import Column +from .type_api import TypeEngine from .. import exc from .. import util +from ..util.typing import Literal if typing.TYPE_CHECKING: from .selectable import CTE from .selectable import FromClause + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.result import _ProcessorType _FromHintsType = Dict["FromClause", str] @@ -271,42 +281,71 @@ COMPOUND_KEYWORDS = { } -RM_RENDERED_NAME = 0 -RM_NAME = 1 -RM_OBJECTS = 2 -RM_TYPE = 3 +class ResultColumnsEntry(NamedTuple): + """Tracks a column expression that is expected to be represented + in the result rows for this statement. + This normally refers to the columns clause of a SELECT statement + but may also refer to a RETURNING clause, as well as for dialect-specific + emulations. -ExpandedState = collections.namedtuple( - "ExpandedState", - [ - "statement", - "additional_parameters", - "processors", - "positiontup", - "parameter_expansion", - ], -) + """ + keyname: str + """string name that's expected in cursor.description""" -NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + name: str + """column name, may be labeled""" -COLLECT_CARTESIAN_PRODUCTS = util.symbol( - "COLLECT_CARTESIAN_PRODUCTS", - "Collect data on FROMs and cartesian products and gather " - "into 'self.from_linter'", - canonical=1, -) + objects: List[Any] + """list of objects that should be able to locate this column + in a RowMapping. This is typically string names and aliases + as well as Column objects. -WARN_LINTING = util.symbol( - "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 -) + """ + + type: TypeEngine[Any] + """Datatype to be associated with this column. This is where + the "result processing" logic directly links the compiled statement + to the rows that come back from the cursor. + + """ + + +# integer indexes into ResultColumnsEntry used by cursor.py. +# some profiling showed integer access faster than named tuple +RM_RENDERED_NAME: Literal[0] = 0 +RM_NAME: Literal[1] = 1 +RM_OBJECTS: Literal[2] = 2 +RM_TYPE: Literal[3] = 3 + + +class ExpandedState(NamedTuple): + statement: str + additional_parameters: _CoreSingleExecuteParams + processors: Mapping[str, _ProcessorType] + positiontup: Optional[Sequence[str]] + parameter_expansion: Mapping[str, List[str]] + + +class Linting(IntEnum): + NO_LINTING = 0 + "Disable all linting." + + COLLECT_CARTESIAN_PRODUCTS = 1 + """Collect data on FROMs and cartesian products and gather into + 'self.from_linter'""" + + WARN_LINTING = 2 + "Emit warnings for linters that find problems" -FROM_LINTING = util.symbol( - "FROM_LINTING", - "Warn for cartesian products; " - "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", - canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, + FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING + """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS + and WARN_LINTING""" + + +NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple( + Linting ) @@ -389,7 +428,7 @@ class Compiled: _cached_metadata = None - _result_columns = None + _result_columns: Optional[List[ResultColumnsEntry]] = None schema_translate_map = None @@ -418,7 +457,8 @@ class Compiled: """ cache_key = None - _gen_time = None + + _gen_time: float def __init__( self, @@ -573,15 +613,43 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + _result_columns: List[ResultColumnsEntry] + compound_keywords = COMPOUND_KEYWORDS - isdelete = isinsert = isupdate = False + isdelete: bool = False + isinsert: bool = False + isupdate: bool = False """class-level defaults which can be set at the instance level to define if this Compiled instance represents INSERT/UPDATE/DELETE """ - isplaintext = False + postfetch: Optional[List[Column[Any]]] + """list of columns that can be post-fetched after INSERT or UPDATE to + receive server-updated values""" + + insert_prefetch: Optional[List[Column[Any]]] + """list of columns for which default values should be evaluated before + an INSERT takes place""" + + update_prefetch: Optional[List[Column[Any]]] + """list of columns for which onupdate default values should be evaluated + before an UPDATE takes place""" + + returning: Optional[List[Column[Any]]] + """list of columns that will be delivered to cursor.description or + dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE + + """ + + isplaintext: bool = False + + result_columns: List[ResultColumnsEntry] + """relates label names in the final SQL to a tuple of local + column/label name, ColumnElement object (if any) and + TypeEngine. CursorResult uses this for type processing and + column targeting""" returning = None """holds the "returning" collection of columns if @@ -589,18 +657,18 @@ class SQLCompiler(Compiled): either implicitly or explicitly """ - returning_precedes_values = False + returning_precedes_values: bool = False """set to True classwide to generate RETURNING clauses before the VALUES or WHERE clause (i.e. MSSQL) """ - render_table_with_column_in_update_from = False + render_table_with_column_in_update_from: bool = False """set to True classwide to indicate the SET clause in a multi-table UPDATE statement should qualify columns with the table name (i.e. MySQL only) """ - ansi_bind_rules = False + ansi_bind_rules: bool = False """SQL 92 doesn't allow bind parameters to be used in the columns clause of a SELECT, nor does it allow ambiguous expressions like "? = ?". A compiler @@ -608,33 +676,33 @@ class SQLCompiler(Compiled): driver/DB enforces this """ - _textual_ordered_columns = False + _textual_ordered_columns: bool = False """tell the result object that the column names as rendered are important, but they are also "ordered" vs. what is in the compiled object here. """ - _ordered_columns = True + _ordered_columns: bool = True """ if False, means we can't be sure the list of entries in _result_columns is actually the rendered order. Usually True unless using an unordered TextualSelect. """ - _loose_column_name_matching = False + _loose_column_name_matching: bool = False """tell the result object that the SQL statement is textual, wants to match up to Column objects, and may be using the ._tq_label in the SELECT rather than the base name. """ - _numeric_binds = False + _numeric_binds: bool = False """ True if paramstyle is "numeric". This paramstyle is trickier than all the others. """ - _render_postcompile = False + _render_postcompile: bool = False """ whether to render out POSTCOMPILE params during the compile phase. @@ -684,7 +752,7 @@ class SQLCompiler(Compiled): """ - positiontup = None + positiontup: Optional[Sequence[str]] = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -699,7 +767,7 @@ class SQLCompiler(Compiled): """ - inline = False + inline: bool = False def __init__( self, @@ -760,10 +828,6 @@ class SQLCompiler(Compiled): # stack which keeps track of nested SELECT statements self.stack = [] - # relates label names in the final SQL to a tuple of local - # column/label name, ColumnElement object (if any) and - # TypeEngine. CursorResult uses this for type processing and - # column targeting self._result_columns = [] # true if the paramstyle is positional @@ -910,7 +974,9 @@ class SQLCompiler(Compiled): ) @util.memoized_property - def _bind_processors(self): + def _bind_processors( + self, + ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]: return dict( (key, value) for key, value in ( @@ -1098,8 +1164,10 @@ class SQLCompiler(Compiled): return self.construct_params(_check=False) def _process_parameters_for_postcompile( - self, parameters=None, _populate_self=False - ): + self, + parameters: Optional[_CoreSingleExecuteParams] = None, + _populate_self: bool = False, + ) -> ExpandedState: """handle special post compile parameters. These include: @@ -3070,7 +3138,13 @@ class SQLCompiler(Compiled): def get_render_as_alias_suffix(self, alias_name_text): return " AS " + alias_name_text - def _add_to_result_map(self, keyname, name, objects, type_): + def _add_to_result_map( + self, + keyname: str, + name: str, + objects: List[Any], + type_: TypeEngine[Any], + ) -> None: if keyname is None or keyname == "*": self._ordered_columns = False self._textual_ordered_columns = True @@ -3080,7 +3154,9 @@ class SQLCompiler(Compiled): "from a tuple() object. If this is an ORM query, " "consider using the Bundle object." ) - self._result_columns.append((keyname, name, objects, type_)) + self._result_columns.append( + ResultColumnsEntry(keyname, name, objects, type_) + ) def _label_returning_column(self, stmt, column, column_clause_args=None): """Render a column with necessary labels inside of a RETURNING clause. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 0c532a135..ac5dc46db 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -61,6 +61,11 @@ if typing.TYPE_CHECKING: from .selectable import Select from .sqltypes import Boolean # noqa from .type_api import TypeEngine + from ..engine import Compiled + from ..engine import Connection + from ..engine import Dialect + from ..engine import Engine + _NUMERIC = Union[complex, "Decimal"] @@ -145,7 +150,12 @@ class CompilerElement(Visitable): @util.preload_module("sqlalchemy.engine.default") @util.preload_module("sqlalchemy.engine.url") - def compile(self, bind=None, dialect=None, **kw): + def compile( + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> Compiled: """Compile this SQL expression. The return value is a :class:`~.Compiled` object. diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 528691795..fdae4d7b0 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -174,7 +174,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): _use_schema_map = True -class Table(DialectKWArgs, SchemaItem, TableClause): +class HasSchemaAttr(SchemaItem): + """schema item that includes a top-level schema name""" + + schema: Optional[str] + + +class Table(DialectKWArgs, HasSchemaAttr, TableClause): r"""Represent a table in a database. e.g.:: @@ -2850,7 +2856,7 @@ class IdentityOptions: self.order = order -class Sequence(IdentityOptions, DefaultGenerator): +class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): """Represents a named database sequence. The :class:`.Sequence` object represents the name and configurational @@ -4330,7 +4336,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"}) -class MetaData(SchemaItem): +class MetaData(HasSchemaAttr): """A collection of :class:`_schema.Table` objects and their associated schema constructs. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e3e358cdb..e0248adf0 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -21,9 +21,6 @@ from . import coercions from . import operators from . import roles from . import visitors -from ._typing import _ExecuteParams -from ._typing import _MultiExecuteParams -from ._typing import _SingleExecuteParams from .annotation import _deep_annotate # noqa from .annotation import _deep_deannotate # noqa from .annotation import _shallow_annotate # noqa @@ -54,6 +51,10 @@ from .. import exc from .. import util if typing.TYPE_CHECKING: + from ..engine.interfaces import _AnyExecuteParams + from ..engine.interfaces import _AnyMultiExecuteParams + from ..engine.interfaces import _AnySingleExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.row import Row @@ -550,12 +551,12 @@ class _repr_params(_repr_base): def __init__( self, - params: _ExecuteParams, + params: Optional[_AnyExecuteParams], batches: int, max_chars: int = 300, ismulti: Optional[bool] = None, ): - self.params: _ExecuteParams = params + self.params = params self.ismulti = ismulti self.batches = batches self.max_chars = max_chars @@ -575,7 +576,10 @@ class _repr_params(_repr_base): return self.trunc(self.params) if self.ismulti: - multi_params = cast(_MultiExecuteParams, self.params) + multi_params = cast( + "_AnyMultiExecuteParams", + self.params, + ) if len(self.params) > self.batches: msg = ( @@ -595,10 +599,18 @@ class _repr_params(_repr_base): return self._repr_multi(multi_params, typ) else: return self._repr_params( - cast(_SingleExecuteParams, self.params), typ + cast( + "_AnySingleExecuteParams", + self.params, + ), + typ, ) - def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str: + def _repr_multi( + self, + multi_params: _AnyMultiExecuteParams, + typ, + ) -> str: if multi_params: if isinstance(multi_params[0], list): elem_type = self._LIST @@ -622,13 +634,19 @@ class _repr_params(_repr_base): else: return "(%s)" % elements - def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str: + def _repr_params( + self, + params: Optional[_AnySingleExecuteParams], + typ: int, + ) -> str: trunc = self.trunc if typ is self._DICT: return "{%s}" % ( ", ".join( "%r: %s" % (key, trunc(value)) - for key, value in params.items() + for key, value in cast( + "_CoreSingleExecuteParams", params + ).items() ) ) elif typ is self._TUPLE: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 523426d09..111ecd32e 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,6 +28,8 @@ from __future__ import annotations from collections import deque import itertools import operator +import typing +from typing import Any from typing import List from typing import Tuple @@ -35,12 +37,13 @@ from .. import exc from .. import util from ..util import langhelpers from ..util import symbol +from ..util._has_cy import HAS_CYEXTENSION from ..util.langhelpers import _symbol -try: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa -except ImportError: +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import cache_anon_map as anon_map # noqa +else: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa __all__ = [ "iterate", @@ -554,7 +557,7 @@ class ExternalTraversal: __traverse_options__ = {} - def traverse_single(self, obj, **kw): + def traverse_single(self, obj: Visitable, **kw: Any) -> Any: for v in self.visitor_iterator: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: |
