author     Mike Bayer <mike_mp@zzzcomputing.com>    2022-02-18 10:05:12 -0500
committer  Mike Bayer <mike_mp@zzzcomputing.com>    2022-05-20 14:19:02 -0400
commit     a463b1109abb60fc85f8356f30c0351a4e2ed71e (patch)
tree       de8f96b7bce319fc0f19f56b302202ea3e4e91db /lib/sqlalchemy
parent     9e7bed9df601ead02fd96bf2fc787b23b536d2d6 (diff)
download   sqlalchemy-a463b1109abb60fc85f8356f30c0351a4e2ed71e.tar.gz
implement dataclass_transforms
Implement a new means of creating a mapped dataclass where, instead of applying the `@dataclass` decorator separately, the declarative process itself creates the dataclass. MapperProperty and MappedColumn objects themselves take the place of the dataclasses.Field object when constructing the class.

The overall approach is made possible at the typing level using pep-681 dataclass transforms [1]. This new approach should be able to completely supersede the previous "dataclasses" approach of embedding metadata into Field() objects, which remains a mutually exclusive declarative setup style (mixing them introduces new issues that are not worth solving).

[1] https://peps.python.org/pep-0681/#transform-descriptor-types-example

Fixes: #7642
Change-Id: I6ba88a87c5df38270317b4faf085904d91c8a63c
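For orientation, a minimal usage sketch of the new style (table and model names are hypothetical; assumes the MappedAsDataclass mixin and the init/default parameters added to mapped_column() in this change):

    # sketch: declarative mapping that is also generated as a dataclass
    from __future__ import annotations

    from typing import Optional

    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        MappedAsDataclass,
        mapped_column,
    )


    class Base(MappedAsDataclass, DeclarativeBase):
        """Subclasses are mapped and converted to dataclasses in one step."""


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True, init=False)
        name: Mapped[str]
        fullname: Mapped[Optional[str]] = mapped_column(default=None)


    # __init__ / __repr__ / __eq__ are generated by the declarative process
    # itself; no separate @dataclasses.dataclass decorator is applied by the user.
    u = User(name="spongebob", fullname="Spongebob Squarepants")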
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--  lib/sqlalchemy/orm/__init__.py             1
-rw-r--r--  lib/sqlalchemy/orm/_orm_constructors.py  167
-rw-r--r--  lib/sqlalchemy/orm/decl_api.py           142
-rw-r--r--  lib/sqlalchemy/orm/decl_base.py          315
-rw-r--r--  lib/sqlalchemy/orm/descriptor_props.py    48
-rw-r--r--  lib/sqlalchemy/orm/instrumentation.py      9
-rw-r--r--  lib/sqlalchemy/orm/interfaces.py          97
-rw-r--r--  lib/sqlalchemy/orm/properties.py          63
-rw-r--r--  lib/sqlalchemy/orm/relationships.py       26
-rw-r--r--  lib/sqlalchemy/orm/util.py                45
-rw-r--r--  lib/sqlalchemy/testing/fixtures.py        25
-rw-r--r--  lib/sqlalchemy/util/compat.py             13
-rw-r--r--  lib/sqlalchemy/util/typing.py              8
13 files changed, 799 insertions, 160 deletions
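The registry also gains a decorator form (see the decl_api.py hunks below); a hedged sketch, again with hypothetical names:

    # sketch: decorator form via registry.mapped_as_dataclass()
    from sqlalchemy.orm import Mapped, mapped_column, registry

    reg = registry()


    @reg.mapped_as_dataclass  # also accepts init= / repr= / eq= / order= / unsafe_hash=
    class Address:
        __tablename__ = "address"

        id: Mapped[int] = mapped_column(primary_key=True, init=False)
        email: Mapped[str] = mapped_column(default="")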
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index b7d1df532..4f19ba946 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -60,6 +60,7 @@ from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta
from .decl_api import DeclarativeMeta as DeclarativeMeta
from .decl_api import declared_attr as declared_attr
from .decl_api import has_inherited_table as has_inherited_table
+from .decl_api import MappedAsDataclass as MappedAsDataclass
from .decl_api import registry as registry
from .decl_api import synonym_for as synonym_for
from .descriptor_props import Composite as Composite
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 0692cac09..ece6a52be 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -21,9 +21,9 @@ from typing import Union
from . import mapperlib as mapperlib
from ._typing import _O
-from .base import Mapped
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .properties import ColumnProperty
from .properties import MappedColumn
from .query import AliasOption
@@ -37,6 +37,8 @@ from .util import LoaderCriteriaOption
from .. import sql
from .. import util
from ..exc import InvalidRequestError
+from ..sql._typing import _no_kw
+from ..sql.base import _NoArg
from ..sql.base import SchemaEventTarget
from ..sql.schema import SchemaConst
from ..sql.selectable import FromClause
@@ -105,6 +107,10 @@ def mapped_column(
Union[_TypeEngineArgument[Any], SchemaEventTarget]
] = None,
*args: SchemaEventTarget,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
nullable: Optional[
Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
] = SchemaConst.NULL_UNSPECIFIED,
@@ -113,7 +119,6 @@ def mapped_column(
name: Optional[str] = None,
type_: Optional[_TypeEngineArgument[Any]] = None,
autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
doc: Optional[str] = None,
key: Optional[str] = None,
index: Optional[bool] = None,
@@ -300,6 +305,12 @@ def mapped_column(
type_=type_,
autoincrement=autoincrement,
default=default,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
key=key,
index=index,
@@ -325,6 +336,10 @@ def column_property(
deferred: bool = False,
raiseload: bool = False,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
active_history: bool = False,
expire_on_flush: bool = True,
info: Optional[_InfoType] = None,
@@ -416,6 +431,12 @@ def column_property(
return ColumnProperty(
column,
*additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
group=group,
deferred=deferred,
raiseload=raiseload,
@@ -429,25 +450,61 @@ def column_property(
@overload
def composite(
- class_: Type[_CC],
+ _class_or_attr: Type[_CC],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[_CC]:
...
@overload
def composite(
+ _class_or_attr: _CompositeAttrType[Any],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
...
def composite(
- class_: Any = None,
+ _class_or_attr: Union[
+ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+ ] = None,
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
r"""Return a composite column-based property for use with a Mapper.
@@ -497,7 +554,26 @@ def composite(
:attr:`.MapperProperty.info` attribute of this object.
"""
- return Composite(class_, *attrs, **kwargs)
+ if __kw:
+ raise _no_kw()
+
+ return Composite(
+ _class_or_attr,
+ *attrs,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=deferred,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ info=info,
+ doc=doc,
+ )
def with_loader_criteria(
@@ -700,6 +776,10 @@ def relationship(
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
@@ -1532,6 +1612,12 @@ def relationship(
post_update=post_update,
cascade=cascade,
viewonly=viewonly,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
lazy=lazy,
passive_deletes=passive_deletes,
passive_updates=passive_updates,
@@ -1559,6 +1645,10 @@ def synonym(
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
) -> Synonym[Any]:
@@ -1670,6 +1760,12 @@ def synonym(
map_column=map_column,
descriptor=descriptor,
comparator_factory=comparator_factory,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
info=info,
)
@@ -1784,7 +1880,17 @@ def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
def deferred(
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
- **kw: Any,
+ group: Optional[str] = None,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ active_history: bool = False,
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
) -> ColumnProperty[_T]:
r"""Indicate a column-based mapped attribute that by default will
not load unless accessed.
@@ -1803,21 +1909,41 @@ def deferred(
:ref:`deferred_raiseload`
- :param \**kw: additional keyword arguments passed to
- :class:`.ColumnProperty`.
+ Additional arguments are the same as that of :func:`_orm.column_property`.
.. seealso::
:ref:`deferred`
"""
- kw["deferred"] = True
- return ColumnProperty(column, *additional_columns, **kw)
+ return ColumnProperty(
+ column,
+ *additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=True,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
def query_expression(
default_expr: _ORMColumnExprArgument[_T] = sql.null(),
-) -> Mapped[_T]:
+ *,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+) -> ColumnProperty[_T]:
"""Indicate an attribute that populates from a query-time SQL expression.
:param default_expr: Optional SQL expression object that will be used in
@@ -1840,7 +1966,18 @@ def query_expression(
:ref:`mapper_querytime_expression`
"""
- prop = ColumnProperty(default_expr)
+ prop = ColumnProperty(
+ default_expr,
+ attribute_options=_AttributeOptions(
+ _NoArg.NO_ARG,
+ repr,
+ _NoArg.NO_ARG,
+ _NoArg.NO_ARG,
+ ),
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
prop.strategy_key = (("query_expression", True),)
return prop
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index 1c343b04c..553a50107 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -33,6 +33,13 @@ from . import clsregistry
from . import instrumentation
from . import interfaces
from . import mapperlib
+from ._orm_constructors import column_property
+from ._orm_constructors import composite
+from ._orm_constructors import deferred
+from ._orm_constructors import mapped_column
+from ._orm_constructors import query_expression
+from ._orm_constructors import relationship
+from ._orm_constructors import synonym
from .attributes import InstrumentedAttribute
from .base import _inspect_mapped_class
from .base import Mapped
@@ -42,8 +49,13 @@ from .decl_base import _declarative_constructor
from .decl_base import _DeferredMapperConfig
from .decl_base import _del_attribute
from .decl_base import _mapper
+from .descriptor_props import Composite
+from .descriptor_props import Synonym
from .descriptor_props import Synonym as _orm_synonym
from .mapper import Mapper
+from .properties import ColumnProperty
+from .properties import MappedColumn
+from .relationships import Relationship
from .state import InstanceState
from .. import exc
from .. import inspection
@@ -60,9 +72,9 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _O
from ._typing import _RegistryType
- from .descriptor_props import Synonym
from .instrumentation import ClassManager
from .interfaces import MapperProperty
+ from .state import InstanceState # noqa
from ..sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
@@ -120,6 +132,26 @@ class DeclarativeAttributeIntercept(
"""
+@compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+)
+class DCTransformDeclarative(DeclarativeAttributeIntercept):
+ """metaclass that includes @dataclass_transforms"""
+
+
class DeclarativeMeta(
_DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
@@ -543,12 +575,42 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]):
cls._sa_registry.map_declaratively(cls)
+class MappedAsDataclass(metaclass=DCTransformDeclarative):
+ """Mixin class to indicate when mapping this class, also convert it to be
+ a dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
+ .. versionadded:: 2.0
+ """
+
+ def __init_subclass__(
+ cls,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> None:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ super().__init_subclass__()
+
+
class DeclarativeBase(
inspection.Inspectable[InstanceState[Any]],
metaclass=DeclarativeAttributeIntercept,
):
"""Base class used for declarative class definitions.
+
The :class:`_orm.DeclarativeBase` allows for the creation of new
declarative bases in such a way that is compatible with type checkers::
@@ -1121,7 +1183,7 @@ class registry:
bases = not isinstance(cls, tuple) and (cls,) or cls
- class_dict = dict(registry=self, metadata=metadata)
+ class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata)
if isinstance(cls, type):
class_dict["__doc__"] = cls.__doc__
@@ -1142,6 +1204,78 @@ class registry:
return metaclass(name, bases, class_dict)
+ @compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+ )
+ @overload
+ def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]:
+ ...
+
+ @overload
+ def mapped_as_dataclass(
+ self,
+ __cls: Literal[None] = ...,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Callable[[Type[_O]], Type[_O]]:
+ ...
+
+ def mapped_as_dataclass(
+ self,
+ __cls: Optional[Type[_O]] = None,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
+ """Class decorator that will apply the Declarative mapping process
+ to a given class, and additionally convert the class to be a
+ Python dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped`
+
+ .. versionadded:: 2.0
+
+
+ """
+
+ def decorate(cls: Type[_O]) -> Type[_O]:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ _as_declarative(self, cls, cls.__dict__)
+ return cls
+
+ if __cls:
+ return decorate(__cls)
+ else:
+ return decorate
+
def mapped(self, cls: Type[_O]) -> Type[_O]:
"""Class decorator that will apply the Declarative mapping process
to a given class.
@@ -1174,6 +1308,10 @@ class registry:
that will apply Declarative mapping to subclasses automatically
using a Python metaclass.
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
"""
_as_declarative(self, cls, cls.__dict__)
return cls
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index a66421e22..54a272f86 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -10,6 +10,8 @@
from __future__ import annotations
import collections
+import dataclasses
+import re
from typing import Any
from typing import Callable
from typing import cast
@@ -40,6 +42,7 @@ from .base import _is_mapped_class
from .base import InspectionAttr
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
@@ -48,15 +51,18 @@ from .mapper import Mapper as mapper
from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
+from .util import _extract_mapped_subtype
from .util import _is_mapped_annotation
from .util import class_mapper
from .. import event
from .. import exc
from .. import util
from ..sql import expression
+from ..sql.base import _NoArg
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import _AnnotationScanType
from ..util.typing import Protocol
if TYPE_CHECKING:
@@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig):
"mapper_args",
"mapper_args_fn",
"inherits",
+ "allow_dataclass_fields",
+ "dataclass_setup_arguments",
)
registry: _RegistryType
clsdict_view: _ClassDict
- collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_annotations: Dict[str, Tuple[Any, Any, bool]]
collected_attributes: Dict[str, Any]
local_table: Optional[FromClause]
persist_selectable: Optional[FromClause]
@@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig):
mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
inherits: Optional[Type[Any]]
+ dataclass_setup_arguments: Optional[Dict[str, Any]]
+ """if the class has SQLAlchemy native dataclass parameters, where
+ we will create a SQLAlchemy dataclass (not a real dataclass).
+
+ """
+
+ allow_dataclass_fields: bool
+ """if true, look for dataclass-processed Field objects on the target
+ class as well as superclasses and extract ORM mapping directives from
+ the "metadata" attribute of each Field"""
+
def __init__(
self,
registry: _RegistryType,
@@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig):
self.declared_columns = util.OrderedSet()
self.column_copies = {}
+ self.dataclass_setup_arguments = dca = getattr(
+ self.cls, "_sa_apply_dc_transforms", None
+ )
+
+ cld = dataclasses.is_dataclass(cls_)
+
+ sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__")
+
+ # we don't want to consume Field objects from a not-already-dataclass.
+ # the Field objects won't have their "name" or "type" populated,
+ # and while it seems like we could just set these on Field as we
+ # read them, Field is documented as "user read only" and we need to
+ # stay far away from any off-label use of dataclasses APIs.
+ if (not cld or dca) and sdk:
+ raise exc.InvalidRequestError(
+ "SQLAlchemy mapped dataclasses can't consume mapping "
+ "information from dataclass.Field() objects if the immediate "
+ "class is not already a dataclass."
+ )
+
+ # if already a dataclass, and __sa_dataclass_metadata_key__ present,
+ # then also look inside of dataclass.Field() objects yielded by
+ # dataclasses.get_fields(cls) when scanning for attributes
+ self.allow_dataclass_fields = bool(sdk and cld)
+
self._setup_declared_events()
self._scan_attributes()
+ self._setup_dataclasses_transforms()
+
with mapperlib._CONFIGURE_MUTEX:
clsregistry.add_class(
self.classname, self.cls, registry._class_registry
@@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig):
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
- if sa_dataclass_metadata_key is None:
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
@@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig):
"__dict__",
"__weakref__",
"_sa_class_manager",
+ "_sa_apply_dc_transforms",
"__dict__",
"__weakref__",
]
@@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig):
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
-
cls_annotations = util.get_annotations(cls)
cls_vars = vars(cls)
@@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig):
names = util.merge_lists_w_ordering(
[n for n in cls_vars if n not in skip], list(cls_annotations)
)
- if sa_dataclass_metadata_key is None:
+
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def local_attributes_for_class() -> Iterable[
Tuple[str, Any, Any, bool]
@@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig):
name,
obj,
annotation,
- is_dataclass,
+ is_dataclass_field,
) in local_attributes_for_class():
- if name == "__mapper_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not mapper_args_fn and (not class_mapped or check_decl):
- # don't even invoke __mapper_args__ until
- # after we've determined everything about the
- # mapped table.
- # make a copy of it so a class-level dictionary
- # is not overwritten when we update column-based
- # arguments.
- def _mapper_args_fn() -> Dict[str, Any]:
- return dict(cls_as_Decl.__mapper_args__)
-
- mapper_args_fn = _mapper_args_fn
-
- elif name == "__tablename__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not tablename and (not class_mapped or check_decl):
- tablename = cls_as_Decl.__tablename__
- elif name == "__table_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not table_args and (not class_mapped or check_decl):
- table_args = cls_as_Decl.__table_args__
- if not isinstance(
- table_args, (tuple, dict, type(None))
+ if re.match(r"^__.+__$", name):
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (
+ not class_mapped or check_decl
):
- raise exc.ArgumentError(
- "__table_args__ value must be a tuple, "
- "dict, or None"
- )
- if base is not cls:
- inherited_table_args = True
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
+
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
+ tablename = cls_as_Decl.__tablename__
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
+ table_args = cls_as_Decl.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))
+ ):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None"
+ )
+ if base is not cls:
+ inherited_table_args = True
+ else:
+ # skip all other dunder names
+ continue
elif class_mapped:
if _is_declarative_props(obj):
util.warn(
@@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig):
# acting like that for now.
if isinstance(obj, (Column, MappedColumn)):
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
# already copied columns to the mapped class.
continue
@@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig):
] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
- if is_dataclass:
+ if is_dataclass_field:
# access attribute using normal class access
# first, to see if it's been mapped on a
# superclass. note if the dataclasses.field()
@@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig):
):
ret.doc = obj.__doc__
- self.collected_annotations[name] = (
+ self._collect_annotation(
+ name,
obj._collect_return_annotation(),
False,
+ True,
+ obj,
)
elif _is_mapped_annotation(annotation, cls):
- self.collected_annotations[name] = (
- annotation,
- is_dataclass,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
if obj is None:
if not fixed_table:
@@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# declarative mapping. however, check for some
# more common mistakes
self._warn_for_decl_attributes(base, name, obj)
- elif is_dataclass and (
+ elif is_dataclass_field and (
name not in clsdict_view or clsdict_view[name] is not obj
):
# here, we are definitely looking at the target class
@@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig):
obj = obj.fget()
collected_attributes[name] = obj
- self.collected_annotations[name] = (
- annotation,
- True,
+ self._collect_annotation(
+ name, annotation, True, False, obj
)
else:
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, False, None, obj
)
if (
obj is None
@@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig):
collected_attributes[name] = MappedColumn()
elif name in clsdict_view:
collected_attributes[name] = obj
+ # else if the name is not in the cls.__dict__,
+ # don't collect it as an attribute.
+ # we will see the annotation only, which is meaningful
+ # both for mapping and dataclasses setup
if inherited_table_args and not tablename:
table_args = None
@@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig):
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
+ def _setup_dataclasses_transforms(self) -> None:
+
+ dataclass_setup_arguments = self.dataclass_setup_arguments
+ if not dataclass_setup_arguments:
+ return
+
+ manager = instrumentation.manager_of_class(self.cls)
+ assert manager is not None
+
+ field_list = [
+ _AttributeOptions._get_arguments_for_make_dataclass(
+ key,
+ anno,
+ self.collected_attributes.get(key, _NoArg.NO_ARG),
+ )
+ for key, anno in (
+ (key, mapped_anno if mapped_anno else raw_anno)
+ for key, (
+ raw_anno,
+ mapped_anno,
+ is_dc,
+ ) in self.collected_annotations.items()
+ )
+ ]
+
+ annotations = {}
+ defaults = {}
+ for item in field_list:
+ if len(item) == 2:
+ name, tp = item # type: ignore
+ elif len(item) == 3:
+ name, tp, spec = item # type: ignore
+ defaults[name] = spec
+ else:
+ assert False
+ annotations[name] = tp
+
+ for k, v in defaults.items():
+ setattr(self.cls, k, v)
+ self.cls.__annotations__ = annotations
+
+ dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+
+ def _collect_annotation(
+ self,
+ name: str,
+ raw_annotation: _AnnotationScanType,
+ is_dataclass: bool,
+ expect_mapped: Optional[bool],
+ attr_value: Any,
+ ) -> None:
+
+ if expect_mapped is None:
+ expect_mapped = isinstance(attr_value, _MappedAttribute)
+
+ extracted_mapped_annotation = _extract_mapped_subtype(
+ raw_annotation,
+ self.cls,
+ name,
+ type(attr_value),
+ required=False,
+ is_dataclass_field=False,
+ expect_mapped=expect_mapped and not self.allow_dataclass_fields,
+ )
+
+ self.collected_annotations[name] = (
+ raw_annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
def _warn_for_decl_attributes(
self, cls: Type[Any], key: str, c: Any
) -> None:
@@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig):
_undefer_column_name(
k, self.column_copies.get(value, value) # type: ignore
)
- elif isinstance(value, _IntrospectsAnnotations):
- annotation, is_dataclass = self.collected_annotations.get(
- k, (None, False)
- )
- value.declarative_scan(
- self.registry, cls, k, annotation, is_dataclass
- )
+ else:
+ if isinstance(value, _IntrospectsAnnotations):
+ (
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ ) = self.collected_annotations.get(k, (None, None, False))
+ value.declarative_scan(
+ self.registry,
+ cls,
+ k,
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
+ if (
+ isinstance(value, (MapperProperty, _MapsColumns))
+ and value._has_dataclass_arguments
+ and not self.dataclass_setup_arguments
+ ):
+ if isinstance(value, MapperProperty):
+ argnames = [
+ "init",
+ "default_factory",
+ "repr",
+ "default",
+ ]
+ else:
+ argnames = ["init", "default_factory", "repr"]
+
+ args = {
+ a
+ for a in argnames
+ if getattr(
+ value._attribute_options, f"dataclasses_{a}"
+ )
+ is not _NoArg.NO_ARG
+ }
+ raise exc.ArgumentError(
+ f"Attribute '{k}' on class {cls} includes dataclasses "
+ f"argument(s): "
+ f"{', '.join(sorted(repr(a) for a in args))} but "
+ f"class does not specify "
+ "SQLAlchemy native dataclass configuration."
+ )
+
our_stuff[k] = value
def _extract_declared_columns(self) -> None:
@@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
+
for key, c in list(our_stuff.items()):
if isinstance(c, _MapsColumns):
@@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig):
# otherwise, Mapper will map it under the column key.
if mp_to_assign is None and key != col.key:
our_stuff[key] = col
-
elif isinstance(c, Column):
# undefer previously occurred here, and now occurs earlier.
# ensure every column we get here has been named
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 8c89f96aa..a366a9534 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -35,11 +35,11 @@ from .base import LoaderCallableStatus
from .base import Mapped
from .base import PassiveFlag
from .base import SQLORMOperations
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
-from .util import _extract_mapped_subtype
from .util import _none_set
from .. import event
from .. import exc as sa_exc
@@ -200,24 +200,26 @@ class Composite(
def __init__(
self,
- class_: Union[
+ _class_or_attr: Union[
None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
] = None,
*attrs: _CompositeAttrType[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
active_history: bool = False,
deferred: bool = False,
group: Optional[str] = None,
comparator_factory: Optional[Type[Comparator[_CC]]] = None,
info: Optional[_InfoType] = None,
+ **kwargs: Any,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
- if isinstance(class_, (Mapped, str, sql.ColumnElement)):
- self.attrs = (class_,) + attrs
+ if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)):
+ self.attrs = (_class_or_attr,) + attrs
# will initialize within declarative_scan
self.composite_class = None # type: ignore
else:
- self.composite_class = class_ # type: ignore
+ self.composite_class = _class_or_attr # type: ignore
self.attrs = attrs
self.active_history = active_history
@@ -332,19 +334,15 @@ class Composite(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- MappedColumn = util.preloaded.orm_properties.MappedColumn
-
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- self.composite_class is None,
- is_dataclass_field,
- )
-
+ if (
+ self.composite_class is None
+ and extracted_mapped_annotation is None
+ ):
+ self._raise_for_required(key, cls)
+ argument = extracted_mapped_annotation
if argument and self.composite_class is None:
if isinstance(argument, str) or hasattr(
argument, "__forward_arg__"
@@ -371,11 +369,18 @@ class Composite(
for param, attr in itertools.zip_longest(
insp.parameters.values(), self.attrs
):
- if param is None or attr is None:
+ if param is None:
raise sa_exc.ArgumentError(
- f"number of arguments to {self.composite_class.__name__} "
- f"class and number of attributes don't match"
+ f"number of composite attributes "
+ f"{len(self.attrs)} exceeds "
+ f"that of the number of attributes in class "
+ f"{self.composite_class.__name__} {len(insp.parameters)}"
)
+ if attr is None:
+ # fill in missing attr spots with empty MappedColumn
+ attr = MappedColumn()
+ self.attrs += (attr,)
+
if isinstance(attr, MappedColumn):
attr.declarative_scan_for_composite(
registry, cls, key, param.name, param.annotation
@@ -800,10 +805,11 @@ class Synonym(DescriptorProperty[_T]):
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ attribute_options: Optional[_AttributeOptions] = None,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
self.name = name
self.map_column = map_column
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 4fa61b7ce..33de2aee9 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -113,6 +113,7 @@ class ClassManager(
"previously known as deferred_scalar_loader"
init_method: Optional[Callable[..., None]]
+ original_init: Optional[Callable[..., None]] = None
factory: Optional[_ManagerFactory]
@@ -229,7 +230,7 @@ class ClassManager(
if finalize and not self._finalized:
self._finalize()
- def _finalize(self):
+ def _finalize(self) -> None:
if self._finalized:
return
self._finalized = True
@@ -238,14 +239,14 @@ class ClassManager(
_instrumentation_factory.dispatch.class_instrument(self.class_)
- def __hash__(self):
+ def __hash__(self) -> int: # type: ignore[override]
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return other is self
@property
- def is_mapped(self):
+ def is_mapped(self) -> bool:
return "mapper" in self.__dict__
@HasMemoized.memoized_attribute
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index b5569ce06..e0034061d 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -19,6 +19,7 @@ are exposed when inspecting mappings.
from __future__ import annotations
import collections
+import dataclasses
import typing
from typing import Any
from typing import Callable
@@ -27,6 +28,8 @@ from typing import ClassVar
from typing import Dict
from typing import Iterator
from typing import List
+from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Set
@@ -51,11 +54,13 @@ from .base import ONETOMANY as ONETOMANY # noqa: F401
from .base import RelationshipDirection as RelationshipDirection # noqa: F401
from .base import SQLORMOperations
from .. import ColumnElement
+from .. import exc as sa_exc
from .. import inspection
from .. import util
from ..sql import operators
from ..sql import roles
from ..sql import visitors
+from ..sql.base import _NoArg
from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
from ..sql.schema import Column
@@ -141,6 +146,7 @@ class _IntrospectsAnnotations:
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
"""Perform class-specific initializaton at early declarative scanning
@@ -150,6 +156,70 @@ class _IntrospectsAnnotations:
"""
+ def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn:
+ raise sa_exc.ArgumentError(
+ f"Python typing annotation is required for attribute "
+ f'"{cls.__name__}.{key}" when primary argument(s) for '
+ f'"{self.__class__.__name__}" construct are None or not present'
+ )
+
+
+class _AttributeOptions(NamedTuple):
+ """define Python-local attribute behavior options common to all
+ :class:`.MapperProperty` objects.
+
+ Currently this includes dataclass-generation arguments.
+
+ .. versionadded:: 2.0
+
+ """
+
+ dataclasses_init: Union[_NoArg, bool]
+ dataclasses_repr: Union[_NoArg, bool]
+ dataclasses_default: Union[_NoArg, Any]
+ dataclasses_default_factory: Union[_NoArg, Callable[[], Any]]
+
+ def _as_dataclass_field(self) -> Any:
+ """Return a ``dataclasses.Field`` object given these arguments."""
+
+ kw: Dict[str, Any] = {}
+ if self.dataclasses_default_factory is not _NoArg.NO_ARG:
+ kw["default_factory"] = self.dataclasses_default_factory
+ if self.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = self.dataclasses_default
+ if self.dataclasses_init is not _NoArg.NO_ARG:
+ kw["init"] = self.dataclasses_init
+ if self.dataclasses_repr is not _NoArg.NO_ARG:
+ kw["repr"] = self.dataclasses_repr
+
+ return dataclasses.field(**kw)
+
+ @classmethod
+ def _get_arguments_for_make_dataclass(
+ cls, key: str, annotation: Type[Any], elem: _T
+ ) -> Union[
+ Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]]
+ ]:
+ """given attribute key, annotation, and value from a class, return
+ the argument tuple we would pass to dataclasses.make_dataclass()
+ for this attribute.
+
+ """
+ if isinstance(elem, (MapperProperty, _MapsColumns)):
+ dc_field = elem._attribute_options._as_dataclass_field()
+
+ return (key, annotation, dc_field)
+ elif elem is not _NoArg.NO_ARG:
+ # why is typing not erroring on this?
+ return (key, annotation, elem)
+ else:
+ return (key, annotation)
+
+
+_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions(
+ _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG
+)
+
class _MapsColumns(_MappedAttribute[_T]):
"""interface for declarative-capable construct that delivers one or more
@@ -158,6 +228,9 @@ class _MapsColumns(_MappedAttribute[_T]):
__slots__ = ()
+ _attribute_options: _AttributeOptions
+ _has_dataclass_arguments: bool
+
@property
def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]:
"""return a MapperProperty to be assigned to the declarative mapping"""
@@ -199,6 +272,8 @@ class MapperProperty(
__slots__ = (
"_configure_started",
"_configure_finished",
+ "_attribute_options",
+ "_has_dataclass_arguments",
"parent",
"key",
"info",
@@ -241,6 +316,15 @@ class MapperProperty(
doc: Optional[str]
"""optional documentation string"""
+ _attribute_options: _AttributeOptions
+ """behavioral options for ORM-enabled Python attributes
+
+ .. versionadded:: 2.0
+
+ """
+
+ _has_dataclass_arguments: bool
+
def _memoized_attr_info(self) -> _InfoType:
"""Info dictionary associated with the object, allowing user-defined
data to be associated with this :class:`.InspectionAttr`.
@@ -349,9 +433,20 @@ class MapperProperty(
"""
- def __init__(self) -> None:
+ def __init__(
+ self, attribute_options: Optional[_AttributeOptions] = None
+ ) -> None:
self._configure_started = False
self._configure_finished = False
+ if (
+ attribute_options
+ and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS
+ ):
+ self._has_dataclass_arguments = True
+ self._attribute_options = attribute_options
+ else:
+ self._has_dataclass_arguments = False
+ self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS
def init(self) -> None:
"""Called after all mappers are created to assemble
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index ad3e9f248..7655f3ae2 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -30,13 +30,14 @@ from . import strategy_options
from .descriptor_props import Composite
from .descriptor_props import ConcreteInheritedProperty
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
+from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
from .interfaces import StrategizedProperty
from .relationships import Relationship
-from .util import _extract_mapped_subtype
from .util import _orm_full_deannotate
from .. import exc as sa_exc
from .. import ForeignKey
@@ -45,6 +46,7 @@ from .. import util
from ..sql import coercions
from ..sql import roles
from ..sql import sqltypes
+from ..sql.base import _NoArg
from ..sql.elements import SQLCoreOperations
from ..sql.schema import Column
from ..sql.schema import SchemaConst
@@ -131,6 +133,7 @@ class ColumnProperty(
self,
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
group: Optional[str] = None,
deferred: bool = False,
raiseload: bool = False,
@@ -141,7 +144,9 @@ class ColumnProperty(
doc: Optional[str] = None,
_instrument: bool = True,
):
- super(ColumnProperty, self).__init__()
+ super(ColumnProperty, self).__init__(
+ attribute_options=attribute_options
+ )
columns = (column,) + additional_columns
self._orig_columns = [
coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
@@ -193,6 +198,7 @@ class ColumnProperty(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.columns[0]
@@ -487,13 +493,38 @@ class MappedColumn(
"foreign_keys",
"_has_nullable",
"deferred",
+ "_attribute_options",
+ "_has_dataclass_arguments",
)
deferred: bool
column: Column[_T]
foreign_keys: Optional[Set[ForeignKey]]
+ _attribute_options: _AttributeOptions
def __init__(self, *arg: Any, **kw: Any):
+ self._attribute_options = attr_opts = kw.pop(
+ "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS
+ )
+
+ self._has_dataclass_arguments = False
+
+ if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS:
+ if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG:
+ self._has_dataclass_arguments = True
+ kw["default"] = attr_opts.dataclasses_default_factory
+ elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = attr_opts.dataclasses_default
+
+ if (
+ attr_opts.dataclasses_init is not _NoArg.NO_ARG
+ or attr_opts.dataclasses_repr is not _NoArg.NO_ARG
+ ):
+ self._has_dataclass_arguments = True
+
+ if "default" in kw and kw["default"] is _NoArg.NO_ARG:
+ kw.pop("default")
+
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
self.foreign_keys = self.column.foreign_keys
@@ -509,13 +540,19 @@ class MappedColumn(
new.deferred = self.deferred
new.foreign_keys = new.column.foreign_keys
new._has_nullable = self._has_nullable
+ new._attribute_options = self._attribute_options
+ new._has_dataclass_arguments = self._has_dataclass_arguments
util.set_creation_order(new)
return new
@property
def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
if self.deferred:
- return ColumnProperty(self.column, deferred=True)
+ return ColumnProperty(
+ self.column,
+ deferred=True,
+ attribute_options=self._attribute_options,
+ )
else:
return None
@@ -543,6 +580,7 @@ class MappedColumn(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.column
@@ -553,18 +591,15 @@ class MappedColumn(
sqltype = column.type
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- sqltype._isnull and not self.column.foreign_keys,
- is_dataclass_field,
- )
- if argument is None:
- return
+ if extracted_mapped_annotation is None:
+ if sqltype._isnull and not self.column.foreign_keys:
+ self._raise_for_required(key, cls)
+ else:
+ return
- self._init_column_for_annotation(cls, registry, argument)
+ self._init_column_for_annotation(
+ cls, registry, extracted_mapped_annotation
+ )
@util.preload_module("sqlalchemy.orm.decl_base")
def declarative_scan_for_composite(
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 1186f0f54..deaf52147 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -49,6 +49,7 @@ from .base import class_mapper
from .base import LoaderCallableStatus
from .base import PassiveFlag
from .base import state_str
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
from .interfaces import MANYTOONE
@@ -56,7 +57,6 @@ from .interfaces import ONETOMANY
from .interfaces import PropComparator
from .interfaces import RelationshipDirection
from .interfaces import StrategizedProperty
-from .util import _extract_mapped_subtype
from .util import _orm_annotate
from .util import _orm_deannotate
from .util import CascadeOptions
@@ -355,6 +355,7 @@ class Relationship(
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ attribute_options: Optional[_AttributeOptions] = None,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
@@ -380,7 +381,7 @@ class Relationship(
_local_remote_pairs: Optional[_ColumnPairs] = None,
_legacy_inactive_history_style: bool = False,
):
- super(Relationship, self).__init__()
+ super(Relationship, self).__init__(attribute_options=attribute_options)
self.uselist = uselist
self.argument = argument
@@ -1701,18 +1702,19 @@ class Relationship(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- Relationship,
- self.argument is None,
- is_dataclass_field,
- )
- if argument is None:
- return
+ argument = extracted_mapped_annotation
+
+ if extracted_mapped_annotation is None:
+
+ if self.argument is None:
+ self._raise_for_required(key, cls)
+ else:
+ return
+
+ argument = extracted_mapped_annotation
if hasattr(argument, "__origin__"):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index c50cc5bac..520c95672 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
def _is_mapped_annotation(
- raw_annotation: Union[type, str], cls: Type[Any]
+ raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
annotated = de_stringify_annotation(cls, raw_annotation)
return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
@@ -1969,9 +1969,14 @@ def _extract_mapped_subtype(
attr_cls: Type[Any],
required: bool,
is_dataclass_field: bool,
- superclasses: Optional[Tuple[Type[Any], ...]] = None,
+ expect_mapped: bool = True,
) -> Optional[Union[type, str]]:
+ """given an annotation, figure out if it's ``Mapped[something]`` and if
+ so, return the ``something`` part.
+ Includes error raise scenarios and other options.
+
+ """
if raw_annotation is None:
if required:
@@ -1989,25 +1994,29 @@ def _extract_mapped_subtype(
if is_dataclass_field:
return annotated
else:
- # TODO: there don't seem to be tests for the failure
- # conditions here
- if not hasattr(annotated, "__origin__") or (
- not issubclass(
- annotated.__origin__, # type: ignore
- superclasses if superclasses else attr_cls,
- )
- and not issubclass(attr_cls, annotated.__origin__) # type: ignore
+ if not hasattr(annotated, "__origin__") or not is_origin_of(
+ annotated, "Mapped", module="sqlalchemy.orm"
):
- our_annotated_str = (
- annotated.__name__
+ anno_name = (
+ getattr(annotated, "__name__", None)
if not isinstance(annotated, str)
- else repr(annotated)
- )
- raise sa_exc.ArgumentError(
- f'Type annotation for "{cls.__name__}.{key}" should use the '
- f'syntax "Mapped[{our_annotated_str}]" or '
- f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ else None
)
+ if anno_name is None:
+ our_annotated_str = repr(annotated)
+ else:
+ our_annotated_str = anno_name
+
+ if expect_mapped:
+ raise sa_exc.ArgumentError(
+ f'Type annotation for "{cls.__name__}.{key}" '
+ "should use the "
+ f'syntax "Mapped[{our_annotated_str}]" or '
+ f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ )
+
+ else:
+ return annotated
if len(annotated.__args__) != 1: # type: ignore
raise sa_exc.ArgumentError(
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 53f76f3ce..d4e4d2dca 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -25,6 +25,7 @@ from .. import event
from .. import util
from ..orm import declarative_base
from ..orm import DeclarativeBase
+from ..orm import MappedAsDataclass
from ..orm import registry
from ..schema import sort_tables_and_constraints
@@ -90,7 +91,14 @@ class TestBase:
@config.fixture()
def registry(self, metadata):
- reg = registry(metadata=metadata)
+ reg = registry(
+ metadata=metadata,
+ type_annotation_map={
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ },
+ )
yield reg
reg.dispose()
@@ -109,6 +117,21 @@ class TestBase:
yield Base
Base.registry.dispose()
+ @config.fixture
+ def dc_decl_base(self, metadata):
+ _md = metadata
+
+ class Base(MappedAsDataclass, DeclarativeBase):
+ metadata = _md
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ }
+
+ yield Base
+ Base.registry.dispose()
+
@config.fixture()
def future_connection(self, future_engine, connection):
# integrate the future_engine and connection fixtures so
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index adbbf143f..4ce1e7ff3 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -230,7 +230,11 @@ def inspect_formatargspec(
def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
- with a class."""
+ with a class as an already processed dataclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
return dataclasses.fields(cls)
@@ -240,7 +244,12 @@ def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
- a class, excluding those that originate from a superclass."""
+ an already processed dataclass, excluding those that originate from a
+ superclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
super_fields: Set[dataclasses.Field[Any]] = set()
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 44e26f609..454de100b 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -23,6 +23,14 @@ from typing_extensions import NotRequired as NotRequired # noqa: F401
from . import compat
+
+# more zimports issues
+if True:
+ from typing_extensions import ( # noqa: F401
+ dataclass_transform as dataclass_transform,
+ )
+
+
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
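Finally, a hedged sketch of the new validation path added in decl_base.py: per-attribute dataclass arguments such as init= now raise ArgumentError when the class is not configured as a SQLAlchemy native dataclass (hypothetical model; message text paraphrased):

    # sketch: dataclass-only arguments require native dataclass configuration
    from sqlalchemy.exc import ArgumentError
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class PlainBase(DeclarativeBase):  # note: not MappedAsDataclass
        pass


    try:
        class Widget(PlainBase):
            __tablename__ = "widget"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str] = mapped_column(init=False)  # dataclass-only argument
    except ArgumentError as err:
        # roughly: "Attribute 'name' ... includes dataclasses argument(s): 'init'
        # but class does not specify SQLAlchemy native dataclass configuration."
        print(err)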