summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/__init__.py2
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py25
-rw-r--r--lib/sqlalchemy/engine/base.py6
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py6
-rw-r--r--lib/sqlalchemy/orm/clsregistry.py3
-rw-r--r--lib/sqlalchemy/orm/events.py22
-rw-r--r--lib/sqlalchemy/orm/loading.py6
-rw-r--r--lib/sqlalchemy/orm/session.py4
-rw-r--r--lib/sqlalchemy/sql/_typing.py1
-rw-r--r--lib/sqlalchemy/sql/elements.py24
-rw-r--r--lib/sqlalchemy/sql/expression.py4
-rw-r--r--lib/sqlalchemy/sql/functions.py14
-rw-r--r--lib/sqlalchemy/sql/schema.py5
-rw-r--r--lib/sqlalchemy/sql/selectable.py30
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py20
-rw-r--r--lib/sqlalchemy/testing/fixtures.py8
16 files changed, 124 insertions, 56 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 6720437fa..70eace964 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -162,6 +162,8 @@ from .sql.expression import Null as Null
from .sql.expression import null as null
from .sql.expression import nulls_first as nulls_first
from .sql.expression import nulls_last as nulls_last
+from .sql.expression import nullsfirst as nullsfirst
+from .sql.expression import nullslast as nullslast
from .sql.expression import Operators as Operators
from .sql.expression import or_ as or_
from .sql.expression import outerjoin as outerjoin
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index 22604955d..8c09eddda 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -7,7 +7,6 @@
# mypy: ignore-errors
from __future__ import annotations
-from itertools import zip_longest
from typing import Any
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -164,15 +163,19 @@ class ExcludeConstraint(ColumnCollectionConstraint):
:param \*elements:
A sequence of two tuples of the form ``(column, operator)`` where
- "column" is a SQL expression element or the name of a column as
- string, most typically a :class:`_schema.Column` object,
- and "operator" is a string containing the operator to use.
+ "column" is either a :class:`_schema.Column` object, or a SQL
+ expression element (e.g. ``func.int8range(table.from, table.to)``)
+ or the name of a column as string, and "operator" is a string
+ containing the operator to use (e.g. `"&&"` or `"="`).
+
In order to specify a column name when a :class:`_schema.Column`
object is not available, while ensuring
that any necessary quoting rules take effect, an ad-hoc
:class:`_schema.Column` or :func:`_expression.column`
- object should be used. ``column`` may also be a string SQL
- expression when passed as :func:`_expression.literal_column`
+ object should be used.
+ The ``column`` may also be a string SQL expression when
+ passed as :func:`_expression.literal_column` or
+ :func:`_expression.text`
:param name:
Optional, the in-database name of this constraint.
@@ -252,22 +255,20 @@ class ExcludeConstraint(ColumnCollectionConstraint):
self._render_exprs = [
(
- expr if isinstance(expr, elements.ClauseElement) else colexpr,
+ expr if not isinstance(expr, str) else table.c[expr],
name,
operator,
)
- for (expr, name, operator), colexpr in zip_longest(
- self._render_exprs, self.columns
- )
+ for expr, name, operator in (self._render_exprs)
]
def _copy(self, target_table=None, **kw):
elements = [
(
schema._copy_expression(expr, self.parent, target_table),
- self.operators[expr.name],
+ operator,
)
- for expr in self.columns
+ for expr, _, operator in self._render_exprs
]
c = self.__class__(
*elements,
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index f6c637aa8..926a08b76 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1306,7 +1306,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[_T]:
@@ -1316,7 +1316,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
@@ -1325,7 +1325,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 86e257bdd..325c58bda 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -646,7 +646,7 @@ class AsyncConnection(
async def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[_T]:
@@ -656,7 +656,7 @@ class AsyncConnection(
async def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
@@ -665,7 +665,7 @@ class AsyncConnection(
async def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py
index e5fff4a5e..c4d6c29eb 100644
--- a/lib/sqlalchemy/orm/clsregistry.py
+++ b/lib/sqlalchemy/orm/clsregistry.py
@@ -553,7 +553,8 @@ def _resolver(
if _fallback_dict is None:
import sqlalchemy
- from sqlalchemy.orm import foreign, remote
+ from . import foreign
+ from . import remote
_fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
{"foreign": foreign, "remote": remote}
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index bf3da5015..413bfbfcb 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -1683,11 +1683,13 @@ class SessionEvents(event.Events[Session]):
This event is invoked for all top-level SQL statements invoked from the
:meth:`_orm.Session.execute` method, as well as related methods such as
:meth:`_orm.Session.scalars` and :meth:`_orm.Session.scalar`. As of
- SQLAlchemy 1.4, all ORM queries emitted on behalf of a
- :class:`_orm.Session` will flow through this method, so this event hook
- provides the single point at which ORM queries of all types may be
- intercepted before they are invoked, and additionally to replace their
- execution with a different process.
+ SQLAlchemy 1.4, all ORM queries that run through the
+ :meth:`_orm.Session.execute` method as well as related methods
+ :meth:`_orm.Session.scalars`, :meth:`_orm.Session.scalar` etc.
+ will participate in this event.
+ This event hook does **not** apply to the queries that are
+ emitted internally within the ORM flush process, i.e. the
+ process described at :ref:`session_flushing`.
.. note:: The :meth:`_orm.SessionEvents.do_orm_execute` event hook
is triggered **for ORM statement executions only**, meaning those
@@ -1698,11 +1700,17 @@ class SessionEvents(event.Events[Session]):
otherwise originating from an :class:`_engine.Engine` object without
any :class:`_orm.Session` involved. To intercept **all** SQL
executions regardless of whether the Core or ORM APIs are in use,
- see the event hooks at
- :class:`.ConnectionEvents`, such as
+ see the event hooks at :class:`.ConnectionEvents`, such as
:meth:`.ConnectionEvents.before_execute` and
:meth:`.ConnectionEvents.before_cursor_execute`.
+ Also, this event hook does **not** apply to queries that are
+ emitted internally within the ORM flush process,
+ i.e. the process described at :ref:`session_flushing`; to
+ intercept steps within the flush process, see the event
+ hooks described at :ref:`session_persistence_events` as
+ well as :ref:`session_persistence_mapper`.
+
This event is a ``do_`` event, meaning it has the capability to replace
the operation that the :meth:`_orm.Session.execute` method normally
performs. The intended use for this includes sharding and
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 7974d94c5..3d9ff7b0a 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -37,6 +37,8 @@ from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
from .base import PassiveFlag
from .context import FromStatement
+from .context import ORMCompileState
+from .context import QueryContext
from .util import _none_set
from .util import state_str
from .. import exc as sa_exc
@@ -55,7 +57,6 @@ from ..util import EMPTY_DICT
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
from .base import LoaderCallableStatus
- from .context import QueryContext
from .interfaces import ORMOption
from .mapper import Mapper
from .query import Query
@@ -519,9 +520,6 @@ def load_on_pk_identity(
assert not q._is_lambda_element
- # TODO: fix these imports ....
- from .context import QueryContext, ORMCompileState
-
if load_options is None:
load_options = QueryContext.default_load_options
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index a39dbc3ec..760c71afd 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -546,8 +546,8 @@ class ORMExecuteState(util.MemoizedSlots):
def is_orm_statement(self) -> bool:
"""return True if the operation is an ORM statement.
- This indicates that the select(), update(), or delete() being
- invoked contains ORM entities as subjects. For a statement
+ This indicates that the select(), insert(), update(), or delete()
+ being invoked contains ORM entities as subjects. For a statement
that does not have ORM entities and instead refers only to
:class:`.Table` metadata, it is invoked as a Core SQL statement
and no ORM-level automation takes place.
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index a828d6a0f..14b1b9594 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -186,6 +186,7 @@ overall which brings in the TextClause object also.
"""
+
_ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]]
_ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]]
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index e51b755dd..a416b6ac0 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -99,6 +99,7 @@ if typing.TYPE_CHECKING:
from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import NamedFromClause
+ from .selectable import TextualSelect
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import _CloneCallableType
@@ -2385,7 +2386,9 @@ class TextClause(
return self
@util.preload_module("sqlalchemy.sql.selectable")
- def columns(self, *cols, **types):
+ def columns(
+ self, *cols: _ColumnExpressionArgument[Any], **types: TypeEngine[Any]
+ ) -> TextualSelect:
r"""Turn this :class:`_expression.TextClause` object into a
:class:`_expression.TextualSelect`
object that serves the same role as a SELECT
@@ -2503,29 +2506,38 @@ class TextClause(
"""
selectable = util.preloaded.sql_selectable
+
+ input_cols: List[NamedColumn[Any]] = [
+ coercions.expect(roles.LabeledColumnExprRole, col) for col in cols
+ ]
+
positional_input_cols = [
ColumnClause(col.key, types.pop(col.key))
if col.key in types
else col
- for col in cols
+ for col in input_cols
]
- keyed_input_cols: List[ColumnClause[Any]] = [
+ keyed_input_cols: List[NamedColumn[Any]] = [
ColumnClause(key, type_) for key, type_ in types.items()
]
- return selectable.TextualSelect(
+ elem = selectable.TextualSelect.__new__(selectable.TextualSelect)
+ elem._init(
self,
positional_input_cols + keyed_input_cols,
positional=bool(positional_input_cols) and not keyed_input_cols,
)
+ return elem
@property
- def type(self):
+ def type(self) -> TypeEngine[Any]:
return type_api.NULLTYPE
@property
def comparator(self):
- return self.type.comparator_factory(self)
+ # TODO: this seems wrong, it seems like we might not
+ # be using this method.
+ return self.type.comparator_factory(self) # type: ignore
def self_group(self, against=None):
if against is operators.in_op:
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 7076cd10d..4fa9cda00 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -152,4 +152,8 @@ from .selectable import Values as Values
from .visitors import Visitable as Visitable
nullsfirst = nulls_first
+"""Synonym for the :func:`.nulls_first` function."""
+
+
nullslast = nulls_last
+"""Synonym for the :func:`.nulls_last` function."""
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 6054be98a..5f2e67288 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -13,6 +13,7 @@
from __future__ import annotations
import datetime
+import decimal
from typing import Any
from typing import cast
from typing import Dict
@@ -54,7 +55,6 @@ from .elements import WithinGroup
from .selectable import FromClause
from .selectable import Select
from .selectable import TableValuedAlias
-from .sqltypes import _N
from .sqltypes import TableValueType
from .type_api import TypeEngine
from .visitors import InternalTraversal
@@ -950,7 +950,7 @@ class _FunctionGenerator:
...
@property
- def cume_dist(self) -> Type[cume_dist[Any]]:
+ def cume_dist(self) -> Type[cume_dist]:
...
@property
@@ -1014,7 +1014,7 @@ class _FunctionGenerator:
...
@property
- def percent_rank(self) -> Type[percent_rank[Any]]:
+ def percent_rank(self) -> Type[percent_rank]:
...
@property
@@ -1703,7 +1703,7 @@ class dense_rank(GenericFunction[int]):
inherit_cache = True
-class percent_rank(GenericFunction[_N]):
+class percent_rank(GenericFunction[decimal.Decimal]):
"""Implement the ``percent_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1715,11 +1715,11 @@ class percent_rank(GenericFunction[_N]):
"""
- type: sqltypes.Numeric[_N] = sqltypes.Numeric()
+ type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric()
inherit_cache = True
-class cume_dist(GenericFunction[_N]):
+class cume_dist(GenericFunction[decimal.Decimal]):
"""Implement the ``cume_dist`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1731,7 +1731,7 @@ class cume_dist(GenericFunction[_N]):
"""
- type: sqltypes.Numeric[_N] = sqltypes.Numeric()
+ type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric()
inherit_cache = True
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 20c0341ad..b4263137b 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -5227,6 +5227,11 @@ class MetaData(HasSchemaAttr):
examples.
"""
+ if schema is not None and not isinstance(schema, str):
+ raise exc.ArgumentError(
+ "expected schema argument to be a string, "
+ f"got {type(schema)}."
+ )
self.tables = util.FacadeDict()
self.schema = quoted_name.construct(schema, quote_schema)
self.naming_convention = (
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 75b5d09e3..39ef420dd 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -4551,7 +4551,7 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def from_statement(
- cls, statement: Select[Any], from_statement: ExecutableReturnsRows
+ cls, statement: Select[Any], from_statement: roles.ReturnsRowsRole
) -> ExecutableReturnsRows:
cls._plugin_not_implemented()
@@ -5273,7 +5273,7 @@ class Select(
return meth(self)
def from_statement(
- self, statement: ExecutableReturnsRows
+ self, statement: roles.ReturnsRowsRole
) -> ExecutableReturnsRows:
"""Apply the columns which this :class:`.Select` would select
onto another statement.
@@ -6770,7 +6770,7 @@ class Exists(UnaryExpression[bool]):
return e
-class TextualSelect(SelectBase, Executable, Generative):
+class TextualSelect(SelectBase, ExecutableReturnsRows, Generative):
"""Wrap a :class:`_expression.TextClause` construct within a
:class:`_expression.SelectBase`
interface.
@@ -6815,14 +6815,28 @@ class TextualSelect(SelectBase, Executable, Generative):
def __init__(
self,
text: TextClause,
- columns: List[ColumnClause[Any]],
+ columns: List[_ColumnExpressionArgument[Any]],
+ positional: bool = False,
+ ) -> None:
+
+ self._init(
+ text,
+ # convert for ORM attributes->columns, etc
+ [
+ coercions.expect(roles.LabeledColumnExprRole, c)
+ for c in columns
+ ],
+ positional,
+ )
+
+ def _init(
+ self,
+ text: TextClause,
+ columns: List[NamedColumn[Any]],
positional: bool = False,
) -> None:
self.element = text
- # convert for ORM attributes->columns, etc
- self.column_args = [
- coercions.expect(roles.ColumnsClauseRole, c) for c in columns
- ]
+ self.column_args = columns
self.positional = positional
@HasMemoized_ro_memoized_attribute
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 3c6cb0cb5..458394870 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -470,6 +470,26 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]):
_default_decimal_return_scale = 10
+ @overload
+ def __init__(
+ self: Numeric[decimal.Decimal],
+ precision: Optional[int] = ...,
+ scale: Optional[int] = ...,
+ decimal_return_scale: Optional[int] = ...,
+ asdecimal: Literal[True] = ...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: Numeric[float],
+ precision: Optional[int] = ...,
+ scale: Optional[int] = ...,
+ decimal_return_scale: Optional[int] = ...,
+ asdecimal: Literal[False] = ...,
+ ):
+ ...
+
def __init__(
self,
precision: Optional[int] = None,
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 3b58052b8..a8bc6c50a 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -24,7 +24,12 @@ from .entities import ComparableEntity
from .entities import ComparableMixin # noqa
from .util import adict
from .util import drop_all_tables_from_metadata
+from .. import Column
from .. import event
+from .. import func
+from .. import Integer
+from .. import select
+from .. import Table
from .. import util
from ..orm import DeclarativeBase
from ..orm import events as orm_events
@@ -247,9 +252,6 @@ class TestBase:
def trans_ctx_manager_fixture(self, request, metadata):
rollback, second_operation, begin_nested = request.param
- from sqlalchemy import Table, Column, Integer, func, select
- from . import eq_
-
t = Table("test", metadata, Column("data", Integer))
eng = getattr(self, "bind", None) or config.db