diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2021-03-18 19:47:02 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2021-03-18 19:47:02 +0000 |
| commit | cdfc7c276c661569befdb828a045a0888d485680 (patch) | |
| tree | f3fb425b8d16c925c04c5982742b800e61bd12ca /lib | |
| parent | b728ff1601336459840a7dcdab7697fa3535dbf5 (diff) | |
| parent | ce2f28c37e0a2f2aa3b4a404ee190cdc00b8b918 (diff) | |
| download | sqlalchemy-cdfc7c276c661569befdb828a045a0888d485680.tar.gz | |
Merge "Adjust dataclass rules to account for field w/ default"
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a21af192e..0a73288fd 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import instrumentation from . import clsregistry from . import exc as orm_exc from . import mapper as mapperlib +from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class from .base import InspectionAttr @@ -366,18 +367,24 @@ class _ClassScanMapperConfig(_MapperConfig): elif ret is not absent: return True + all_field = all_datacls_fields.get(key, absent) + ret = getattr(cls, key, obj) if ret is obj: return False - elif ret is not absent: - return True - ret = all_datacls_fields.get(key, absent) + # for dataclasses, this could be the + # 'default' of the field. so filter more specifically + # for an already-mapped InstrumentedAttribute + if ret is not absent and isinstance( + ret, InstrumentedAttribute + ): + return True - if ret is obj: + if all_field is obj: return False - elif ret is not absent: + elif all_field is not absent: return True # can't find another attribute @@ -401,15 +408,18 @@ class _ClassScanMapperConfig(_MapperConfig): yield name, obj else: + field_names = set() def local_attributes_for_class(): - for name, obj in vars(cls).items(): - yield name, obj for field in util.local_dataclass_fields(cls): if sa_dataclass_metadata_key in field.metadata: + field_names.add(field.name) yield field.name, field.metadata[ sa_dataclass_metadata_key ] + for name, obj in vars(cls).items(): + if name not in field_names: + yield name, obj return local_attributes_for_class |
