diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 87 |
1 files changed, 66 insertions, 21 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index fd484b52b..3d20cfdea 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,6 +23,7 @@ from . import evaluator from . import exc as orm_exc from . import loading from . import sync +from .base import NO_VALUE from .base import state_str from .. import exc as sa_exc from .. import future @@ -34,6 +35,7 @@ from ..sql import expression from ..sql import operators from ..sql import roles from ..sql import select +from ..sql import sqltypes from ..sql.base import _entity_namespace_key from ..sql.base import CompileState from ..sql.base import Options @@ -2002,31 +2004,12 @@ class BulkUDCompileState(CompileState): if statement._multi_values: return [] elif statement._ordered_values: - iterator = statement._ordered_values + return list(statement._ordered_values) elif statement._values: - iterator = statement._values.items() + return list(statement._values.items()) else: return [] - values = [] - if iterator: - for k, v in iterator: - if mapper: - if isinstance(k, util.string_types): - desc = _entity_namespace_key(mapper, k) - values.extend(desc._bulk_update_tuples(v)) - elif "entity_namespace" in k._annotations: - k_anno = k._annotations - attr = _entity_namespace_key( - k_anno["entity_namespace"], k_anno["proxy_key"] - ) - values.extend(attr._bulk_update_tuples(v)) - else: - values.append((k, v)) - else: - values.append((k, v)) - return values - @classmethod def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] @@ -2191,6 +2174,68 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): return self @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if plugin_subject: + mapper = plugin_subject.mapper + else: + mapper = None + + values = [] + core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + if mapper: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, util.string_types): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + values.append( + ( + k, + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + ) + else: + values.extend( + core_get_crud_kv_pairs( + statement, desc._bulk_update_tuples(v) + ) + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + k_anno["entity_namespace"], k_anno["proxy_key"] + ) + values.extend( + core_get_crud_kv_pairs( + statement, attr._bulk_update_tuples(v) + ) + ) + else: + values.append( + ( + k, + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + ) + else: + values.extend(core_get_crud_kv_pairs(statement, [(k, v)])) + return values + + @classmethod def _do_post_synchronize_evaluate(cls, session, result, update_options): states = set() |