diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ext.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/clsregistry.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/events.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 8 |
16 files changed, 124 insertions, 56 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 6720437fa..70eace964 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -162,6 +162,8 @@ from .sql.expression import Null as Null from .sql.expression import null as null from .sql.expression import nulls_first as nulls_first from .sql.expression import nulls_last as nulls_last +from .sql.expression import nullsfirst as nullsfirst +from .sql.expression import nullslast as nullslast from .sql.expression import Operators as Operators from .sql.expression import or_ as or_ from .sql.expression import outerjoin as outerjoin diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 22604955d..8c09eddda 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -7,7 +7,6 @@ # mypy: ignore-errors from __future__ import annotations -from itertools import zip_longest from typing import Any from typing import TYPE_CHECKING from typing import TypeVar @@ -164,15 +163,19 @@ class ExcludeConstraint(ColumnCollectionConstraint): :param \*elements: A sequence of two tuples of the form ``(column, operator)`` where - "column" is a SQL expression element or the name of a column as - string, most typically a :class:`_schema.Column` object, - and "operator" is a string containing the operator to use. + "column" is either a :class:`_schema.Column` object, or a SQL + expression element (e.g. ``func.int8range(table.from, table.to)``) + or the name of a column as string, and "operator" is a string + containing the operator to use (e.g. `"&&"` or `"="`). + In order to specify a column name when a :class:`_schema.Column` object is not available, while ensuring that any necessary quoting rules take effect, an ad-hoc :class:`_schema.Column` or :func:`_expression.column` - object should be used. ``column`` may also be a string SQL - expression when passed as :func:`_expression.literal_column` + object should be used. + The ``column`` may also be a string SQL expression when + passed as :func:`_expression.literal_column` or + :func:`_expression.text` :param name: Optional, the in-database name of this constraint. @@ -252,22 +255,20 @@ class ExcludeConstraint(ColumnCollectionConstraint): self._render_exprs = [ ( - expr if isinstance(expr, elements.ClauseElement) else colexpr, + expr if not isinstance(expr, str) else table.c[expr], name, operator, ) - for (expr, name, operator), colexpr in zip_longest( - self._render_exprs, self.columns - ) + for expr, name, operator in (self._render_exprs) ] def _copy(self, target_table=None, **kw): elements = [ ( schema._copy_expression(expr, self.parent, target_table), - self.operators[expr.name], + operator, ) - for expr in self.columns + for expr, _, operator in self._render_exprs ] c = self.__class__( *elements, diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index f6c637aa8..926a08b76 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1306,7 +1306,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[_T]: @@ -1316,7 +1316,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: @@ -1325,7 +1325,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 86e257bdd..325c58bda 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -646,7 +646,7 @@ class AsyncConnection( async def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[_T]: @@ -656,7 +656,7 @@ class AsyncConnection( async def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: @@ -665,7 +665,7 @@ class AsyncConnection( async def scalars( self, statement: Executable, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index e5fff4a5e..c4d6c29eb 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -553,7 +553,8 @@ def _resolver( if _fallback_dict is None: import sqlalchemy - from sqlalchemy.orm import foreign, remote + from . import foreign + from . import remote _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union( {"foreign": foreign, "remote": remote} diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index bf3da5015..413bfbfcb 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1683,11 +1683,13 @@ class SessionEvents(event.Events[Session]): This event is invoked for all top-level SQL statements invoked from the :meth:`_orm.Session.execute` method, as well as related methods such as :meth:`_orm.Session.scalars` and :meth:`_orm.Session.scalar`. As of - SQLAlchemy 1.4, all ORM queries emitted on behalf of a - :class:`_orm.Session` will flow through this method, so this event hook - provides the single point at which ORM queries of all types may be - intercepted before they are invoked, and additionally to replace their - execution with a different process. + SQLAlchemy 1.4, all ORM queries that run through the + :meth:`_orm.Session.execute` method as well as related methods + :meth:`_orm.Session.scalars`, :meth:`_orm.Session.scalar` etc. + will participate in this event. + This event hook does **not** apply to the queries that are + emitted internally within the ORM flush process, i.e. the + process described at :ref:`session_flushing`. .. note:: The :meth:`_orm.SessionEvents.do_orm_execute` event hook is triggered **for ORM statement executions only**, meaning those @@ -1698,11 +1700,17 @@ class SessionEvents(event.Events[Session]): otherwise originating from an :class:`_engine.Engine` object without any :class:`_orm.Session` involved. To intercept **all** SQL executions regardless of whether the Core or ORM APIs are in use, - see the event hooks at - :class:`.ConnectionEvents`, such as + see the event hooks at :class:`.ConnectionEvents`, such as :meth:`.ConnectionEvents.before_execute` and :meth:`.ConnectionEvents.before_cursor_execute`. + Also, this event hook does **not** apply to queries that are + emitted internally within the ORM flush process, + i.e. the process described at :ref:`session_flushing`; to + intercept steps within the flush process, see the event + hooks described at :ref:`session_persistence_events` as + well as :ref:`session_persistence_mapper`. + This event is a ``do_`` event, meaning it has the capability to replace the operation that the :meth:`_orm.Session.execute` method normally performs. The intended use for this includes sharding and diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 7974d94c5..3d9ff7b0a 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -37,6 +37,8 @@ from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED from .base import PassiveFlag from .context import FromStatement +from .context import ORMCompileState +from .context import QueryContext from .util import _none_set from .util import state_str from .. import exc as sa_exc @@ -55,7 +57,6 @@ from ..util import EMPTY_DICT if TYPE_CHECKING: from ._typing import _IdentityKeyType from .base import LoaderCallableStatus - from .context import QueryContext from .interfaces import ORMOption from .mapper import Mapper from .query import Query @@ -519,9 +520,6 @@ def load_on_pk_identity( assert not q._is_lambda_element - # TODO: fix these imports .... - from .context import QueryContext, ORMCompileState - if load_options is None: load_options = QueryContext.default_load_options diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a39dbc3ec..760c71afd 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -546,8 +546,8 @@ class ORMExecuteState(util.MemoizedSlots): def is_orm_statement(self) -> bool: """return True if the operation is an ORM statement. - This indicates that the select(), update(), or delete() being - invoked contains ORM entities as subjects. For a statement + This indicates that the select(), insert(), update(), or delete() + being invoked contains ORM entities as subjects. For a statement that does not have ORM entities and instead refers only to :class:`.Table` metadata, it is invoked as a Core SQL statement and no ORM-level automation takes place. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index a828d6a0f..14b1b9594 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -186,6 +186,7 @@ overall which brings in the TextClause object also. """ + _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] _ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e51b755dd..a416b6ac0 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -99,6 +99,7 @@ if typing.TYPE_CHECKING: from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause + from .selectable import TextualSelect from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -2385,7 +2386,9 @@ class TextClause( return self @util.preload_module("sqlalchemy.sql.selectable") - def columns(self, *cols, **types): + def columns( + self, *cols: _ColumnExpressionArgument[Any], **types: TypeEngine[Any] + ) -> TextualSelect: r"""Turn this :class:`_expression.TextClause` object into a :class:`_expression.TextualSelect` object that serves the same role as a SELECT @@ -2503,29 +2506,38 @@ class TextClause( """ selectable = util.preloaded.sql_selectable + + input_cols: List[NamedColumn[Any]] = [ + coercions.expect(roles.LabeledColumnExprRole, col) for col in cols + ] + positional_input_cols = [ ColumnClause(col.key, types.pop(col.key)) if col.key in types else col - for col in cols + for col in input_cols ] - keyed_input_cols: List[ColumnClause[Any]] = [ + keyed_input_cols: List[NamedColumn[Any]] = [ ColumnClause(key, type_) for key, type_ in types.items() ] - return selectable.TextualSelect( + elem = selectable.TextualSelect.__new__(selectable.TextualSelect) + elem._init( self, positional_input_cols + keyed_input_cols, positional=bool(positional_input_cols) and not keyed_input_cols, ) + return elem @property - def type(self): + def type(self) -> TypeEngine[Any]: return type_api.NULLTYPE @property def comparator(self): - return self.type.comparator_factory(self) + # TODO: this seems wrong, it seems like we might not + # be using this method. + return self.type.comparator_factory(self) # type: ignore def self_group(self, against=None): if against is operators.in_op: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 7076cd10d..4fa9cda00 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -152,4 +152,8 @@ from .selectable import Values as Values from .visitors import Visitable as Visitable nullsfirst = nulls_first +"""Synonym for the :func:`.nulls_first` function.""" + + nullslast = nulls_last +"""Synonym for the :func:`.nulls_last` function.""" diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6054be98a..5f2e67288 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -13,6 +13,7 @@ from __future__ import annotations import datetime +import decimal from typing import Any from typing import cast from typing import Dict @@ -54,7 +55,6 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias -from .sqltypes import _N from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal @@ -950,7 +950,7 @@ class _FunctionGenerator: ... @property - def cume_dist(self) -> Type[cume_dist[Any]]: + def cume_dist(self) -> Type[cume_dist]: ... @property @@ -1014,7 +1014,7 @@ class _FunctionGenerator: ... @property - def percent_rank(self) -> Type[percent_rank[Any]]: + def percent_rank(self) -> Type[percent_rank]: ... @property @@ -1703,7 +1703,7 @@ class dense_rank(GenericFunction[int]): inherit_cache = True -class percent_rank(GenericFunction[_N]): +class percent_rank(GenericFunction[decimal.Decimal]): """Implement the ``percent_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1715,11 +1715,11 @@ class percent_rank(GenericFunction[_N]): """ - type: sqltypes.Numeric[_N] = sqltypes.Numeric() + type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric() inherit_cache = True -class cume_dist(GenericFunction[_N]): +class cume_dist(GenericFunction[decimal.Decimal]): """Implement the ``cume_dist`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1731,7 +1731,7 @@ class cume_dist(GenericFunction[_N]): """ - type: sqltypes.Numeric[_N] = sqltypes.Numeric() + type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric() inherit_cache = True diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 20c0341ad..b4263137b 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -5227,6 +5227,11 @@ class MetaData(HasSchemaAttr): examples. """ + if schema is not None and not isinstance(schema, str): + raise exc.ArgumentError( + "expected schema argument to be a string, " + f"got {type(schema)}." + ) self.tables = util.FacadeDict() self.schema = quoted_name.construct(schema, quote_schema) self.naming_convention = ( diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 75b5d09e3..39ef420dd 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4551,7 +4551,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def from_statement( - cls, statement: Select[Any], from_statement: ExecutableReturnsRows + cls, statement: Select[Any], from_statement: roles.ReturnsRowsRole ) -> ExecutableReturnsRows: cls._plugin_not_implemented() @@ -5273,7 +5273,7 @@ class Select( return meth(self) def from_statement( - self, statement: ExecutableReturnsRows + self, statement: roles.ReturnsRowsRole ) -> ExecutableReturnsRows: """Apply the columns which this :class:`.Select` would select onto another statement. @@ -6770,7 +6770,7 @@ class Exists(UnaryExpression[bool]): return e -class TextualSelect(SelectBase, Executable, Generative): +class TextualSelect(SelectBase, ExecutableReturnsRows, Generative): """Wrap a :class:`_expression.TextClause` construct within a :class:`_expression.SelectBase` interface. @@ -6815,14 +6815,28 @@ class TextualSelect(SelectBase, Executable, Generative): def __init__( self, text: TextClause, - columns: List[ColumnClause[Any]], + columns: List[_ColumnExpressionArgument[Any]], + positional: bool = False, + ) -> None: + + self._init( + text, + # convert for ORM attributes->columns, etc + [ + coercions.expect(roles.LabeledColumnExprRole, c) + for c in columns + ], + positional, + ) + + def _init( + self, + text: TextClause, + columns: List[NamedColumn[Any]], positional: bool = False, ) -> None: self.element = text - # convert for ORM attributes->columns, etc - self.column_args = [ - coercions.expect(roles.ColumnsClauseRole, c) for c in columns - ] + self.column_args = columns self.positional = positional @HasMemoized_ro_memoized_attribute diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 3c6cb0cb5..458394870 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -470,6 +470,26 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): _default_decimal_return_scale = 10 + @overload + def __init__( + self: Numeric[decimal.Decimal], + precision: Optional[int] = ..., + scale: Optional[int] = ..., + decimal_return_scale: Optional[int] = ..., + asdecimal: Literal[True] = ..., + ): + ... + + @overload + def __init__( + self: Numeric[float], + precision: Optional[int] = ..., + scale: Optional[int] = ..., + decimal_return_scale: Optional[int] = ..., + asdecimal: Literal[False] = ..., + ): + ... + def __init__( self, precision: Optional[int] = None, diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 3b58052b8..a8bc6c50a 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -24,7 +24,12 @@ from .entities import ComparableEntity from .entities import ComparableMixin # noqa from .util import adict from .util import drop_all_tables_from_metadata +from .. import Column from .. import event +from .. import func +from .. import Integer +from .. import select +from .. import Table from .. import util from ..orm import DeclarativeBase from ..orm import events as orm_events @@ -247,9 +252,6 @@ class TestBase: def trans_ctx_manager_fixture(self, request, metadata): rollback, second_operation, begin_nested = request.param - from sqlalchemy import Table, Column, Integer, func, select - from . import eq_ - t = Table("test", metadata, Column("data", Integer)) eng = getattr(self, "bind", None) or config.db |
