diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 109 |
1 files changed, 60 insertions, 49 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 8e91dd6c7..5dc5a90b1 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -18,7 +18,7 @@ import operator from itertools import groupby, chain from .. import sql, util, exc as sa_exc from . import attributes, sync, exc as orm_exc, evaluator -from .base import state_str, _attr_as_key, _entity_descriptor +from .base import state_str, _entity_descriptor from ..sql import expression from ..sql.base import _from_objects from . import loading @@ -1180,6 +1180,12 @@ class BulkUD(object): self._do_post_synchronize() self._do_post() + def _execute_stmt(self, stmt): + self.result = self.query.session.execute( + stmt, params=self.query._params, + mapper=self.mapper) + self.rowcount = self.result.rowcount + @util.dependencies("sqlalchemy.orm.query") def _do_pre(self, querylib): query = self.query @@ -1287,41 +1293,49 @@ class BulkUpdate(BulkUD): False: BulkUpdate }, synchronize_session, query, values, update_kwargs) - def _resolve_string_to_expr(self, key): - if self.mapper and isinstance(key, util.string_types): - attr = _entity_descriptor(self.mapper, key) - return attr.__clause_element__() - else: - return key - - def _resolve_key_to_attrname(self, key): - if self.mapper and isinstance(key, util.string_types): - attr = _entity_descriptor(self.mapper, key) - return attr.property.key - elif isinstance(key, attributes.InstrumentedAttribute): - return key.key - elif hasattr(key, '__clause_element__'): - key = key.__clause_element__() - - if self.mapper and isinstance(key, expression.ColumnElement): - try: - attr = self.mapper._columntoproperty[key] - except orm_exc.UnmappedColumnError: - return None + @property + def _resolved_values(self): + values = [] + for k, v in ( + self.values.items() if hasattr(self.values, 'items') + else self.values): + if self.mapper: + if isinstance(k, util.string_types): + desc = _entity_descriptor(self.mapper, k) + values.extend(desc._bulk_update_tuples(v)) + elif isinstance(k, attributes.QueryableAttribute): + values.extend(k._bulk_update_tuples(v)) + else: + values.append((k, v)) else: - return attr.key - else: - raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % key) + values.append((k, v)) + return values + + @property + def _resolved_values_keys_as_propnames(self): + values = [] + for k, v in self._resolved_values: + if isinstance(k, attributes.QueryableAttribute): + values.append((k.key, v)) + continue + elif hasattr(k, '__clause_element__'): + k = k.__clause_element__() + + if self.mapper and isinstance(k, expression.ColumnElement): + try: + attr = self.mapper._columntoproperty[k] + except orm_exc.UnmappedColumnError: + pass + else: + values.append((attr.key, v)) + else: + raise sa_exc.InvalidRequestError( + "Invalid expression type: %r" % k) + return values def _do_exec(self): + values = self._resolved_values - values = [ - (self._resolve_string_to_expr(k), v) - for k, v in ( - self.values.items() if hasattr(self.values, 'items') - else self.values) - ] if not self.update_kwargs.get('preserve_parameter_order', False): values = dict(values) @@ -1329,10 +1343,7 @@ class BulkUpdate(BulkUD): self.context.whereclause, values, **self.update_kwargs) - self.result = self.query.session.execute( - update_stmt, params=self.query._params, - mapper=self.mapper) - self.rowcount = self.result.rowcount + self._execute_stmt(update_stmt) def _do_post(self): session = self.query.session @@ -1357,11 +1368,7 @@ class BulkDelete(BulkUD): delete_stmt = sql.delete(self.primary_table, self.context.whereclause) - self.result = self.query.session.execute( - delete_stmt, - params=self.query._params, - mapper=self.mapper) - self.rowcount = self.result.rowcount + self._execute_stmt(delete_stmt) def _do_post(self): session = self.query.session @@ -1374,13 +1381,10 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): def _additional_evaluators(self, evaluator_compiler): self.value_evaluators = {} - values = (self.values.items() if hasattr(self.values, 'items') - else self.values) + values = self._resolved_values_keys_as_propnames for key, value in values: - key = self._resolve_key_to_attrname(key) - if key is not None: - self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value)) + self.value_evaluators[key] = evaluator_compiler.process( + expression._literal_as_binds(value)) def _do_post_synchronize(self): session = self.query.session @@ -1396,6 +1400,9 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): for key in to_evaluate: dict_[key] = self.value_evaluators[key](obj) + state.manager.dispatch.refresh( + state, None, to_evaluate) + state._commit(dict_, list(to_evaluate)) # expire attributes with pending changes @@ -1434,9 +1441,13 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): ] if identity_key in session.identity_map ]) - attrib = [_attr_as_key(k) for k in self.values] + + values = self._resolved_values_keys_as_propnames + attrib = set(k for k, v in values) for state in states: - session._expire_state(state, attrib) + to_expire = attrib.intersection(state.dict) + if to_expire: + session._expire_state(state, to_expire) session._register_altered(states) |
