summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-10-02 16:07:30 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-10-02 16:07:30 +0000
commit7bb9ea911cb2e573696a91392a6a08161950ac9f (patch)
tree2d6022ba9f36633211167bfc64fb1203af9a824b
parent1a9618afdd6413fd04ee44b797a15735eaa1a230 (diff)
parent1430a30dd2033550699cf5e074cb81058687eb13 (diff)
downloadsqlalchemy-7bb9ea911cb2e573696a91392a6a08161950ac9f.tar.gz
Merge "use execute_20 to preserve compiled cache"
-rw-r--r--lib/sqlalchemy/orm/persistence.py97
-rw-r--r--test/orm/test_unitofworkv2.py35
2 files changed, 86 insertions, 46 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index fa126a279..022f6611f 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -52,8 +52,6 @@ def _bulk_insert(
):
base_mapper = mapper.base_mapper
- cached_connections = _cached_connection_dict(base_mapper)
-
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
@@ -105,7 +103,6 @@ def _bulk_insert(
_emit_insert_statements(
base_mapper,
None,
- cached_connections,
super_mapper,
table,
records,
@@ -127,8 +124,6 @@ def _bulk_update(
):
base_mapper = mapper.base_mapper
- cached_connections = _cached_connection_dict(base_mapper)
-
search_keys = mapper._primary_key_propkeys
if mapper._version_id_prop:
search_keys = {mapper._version_id_prop.key}.union(search_keys)
@@ -183,7 +178,6 @@ def _bulk_update(
_emit_update_statements(
base_mapper,
None,
- cached_connections,
super_mapper,
table,
records,
@@ -210,7 +204,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
states_to_update = []
states_to_insert = []
- cached_connections = _cached_connection_dict(base_mapper)
for (
state,
@@ -240,7 +233,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
_emit_update_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
update,
@@ -249,7 +241,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
_emit_insert_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
insert,
@@ -282,7 +273,6 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
specifies post_update.
"""
- cached_connections = _cached_connection_dict(base_mapper)
states_to_update = list(
_organize_states_for_post_update(base_mapper, states, uowtransaction)
@@ -315,7 +305,6 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
_emit_post_update_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
update,
@@ -330,8 +319,6 @@ def delete_obj(base_mapper, states, uowtransaction):
"""
- cached_connections = _cached_connection_dict(base_mapper)
-
states_to_delete = list(
_organize_states_for_delete(base_mapper, states, uowtransaction)
)
@@ -352,7 +339,6 @@ def delete_obj(base_mapper, states, uowtransaction):
_emit_delete_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
delete,
@@ -856,7 +842,6 @@ def _collect_delete_commands(
def _emit_update_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
update,
@@ -870,6 +855,8 @@ def _emit_update_statements(
and mapper.version_id_col in mapper._cols_by_table[table]
)
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
def update_stmt():
clauses = BooleanClauseList._construct_raw(operators.and_)
@@ -948,7 +935,11 @@ def _emit_update_statements(
has_all_defaults,
has_all_pks,
) in records:
- c = connection.execute(statement.values(value_params), params)
+ c = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
+ )
if bookkeeping:
_postfetch(
mapper,
@@ -977,8 +968,8 @@ def _emit_update_statements(
has_all_defaults,
has_all_pks,
) in records:
- c = cached_connections[connection].execute(
- statement, params
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
)
# TODO: why with bookkeeping=False?
@@ -1003,8 +994,8 @@ def _emit_update_statements(
assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].execute(
- statement, multiparams
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
)
rows += c.rowcount
@@ -1054,7 +1045,6 @@ def _emit_update_statements(
def _emit_insert_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
insert,
@@ -1065,6 +1055,8 @@ def _emit_insert_statements(
cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
for (
(connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
records,
@@ -1098,7 +1090,10 @@ def _emit_insert_statements(
records = list(records)
multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].execute(statement, multiparams)
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
if bookkeeping:
for (
(
@@ -1154,9 +1149,10 @@ def _emit_insert_statements(
if do_executemany:
multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].execute(
- statement, multiparams
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
)
+
if bookkeeping:
for (
(
@@ -1213,12 +1209,16 @@ def _emit_insert_statements(
has_all_defaults,
) in records:
if value_params:
- result = connection.execute(
- statement.values(value_params), params
+ result = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
)
else:
- result = cached_connections[connection].execute(
- statement, params
+ result = connection._execute_20(
+ statement,
+ params,
+ execution_options=execution_options,
)
primary_key = result.inserted_primary_key
@@ -1253,11 +1253,13 @@ def _emit_insert_statements(
def _emit_post_update_statements(
- base_mapper, uowtransaction, cached_connections, mapper, table, update
+ base_mapper, uowtransaction, mapper, table, update
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
needs_version_id = (
mapper.version_id_col is not None
and mapper.version_id_col in mapper._cols_by_table[table]
@@ -1316,7 +1318,11 @@ def _emit_post_update_statements(
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, mapper_rec, connection, params in records:
- c = cached_connections[connection].execute(statement, params)
+
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+
_postfetch_post_update(
mapper_rec,
uowtransaction,
@@ -1337,7 +1343,9 @@ def _emit_post_update_statements(
assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].execute(statement, multiparams)
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
rows += c.rowcount
for state, state_dict, mapper_rec, connection, params in records:
@@ -1368,7 +1376,7 @@ def _emit_post_update_statements(
def _emit_delete_statements(
- base_mapper, uowtransaction, cached_connections, mapper, table, delete
+ base_mapper, uowtransaction, mapper, table, delete
):
"""Emit DELETE statements corresponding to value lists collected
by _collect_delete_commands()."""
@@ -1400,8 +1408,7 @@ def _emit_delete_statements(
for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
del_objects = [params for params, connection in recs]
- connection = cached_connections[connection]
-
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
expected = len(del_objects)
rows_matched = -1
only_warn = False
@@ -1415,7 +1422,10 @@ def _emit_delete_statements(
# execute deletes individually so that versioned
# rows can be verified
for params in del_objects:
- c = connection.execute(statement, params)
+
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
rows_matched += c.rowcount
else:
util.warn(
@@ -1423,9 +1433,13 @@ def _emit_delete_statements(
"- versioning cannot be verified."
% connection.dialect.dialect_description
)
- connection.execute(statement, del_objects)
+ connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
else:
- c = connection.execute(statement, del_objects)
+ c = connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
if not need_version_id:
only_warn = True
@@ -1702,15 +1716,6 @@ def _connections_for_states(base_mapper, uowtransaction, states):
yield state, state.dict, mapper, connection
-def _cached_connection_dict(base_mapper):
- # dictionary of connection->connection_with_cache_options.
- return util.PopulateDict(
- lambda conn: conn.execution_options(
- compiled_cache=base_mapper._compiled_cache
- )
- )
-
-
def _sort_states(mapper, states):
pending = set(states)
persistent = set(s for s in pending if s.key is not None)
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index e5d9a2f7a..ed320db10 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -4,6 +4,7 @@ from sqlalchemy import exc
from sqlalchemy import FetchedValue
from sqlalchemy import ForeignKey
from sqlalchemy import func
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import JSON
from sqlalchemy import literal
@@ -25,6 +26,7 @@ from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing.assertsql import AllOf
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.assertsql import Conditional
@@ -3066,3 +3068,36 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults):
s.commit()
eq_(s.query(cast(JSONThing.data, String)).scalar(), "null")
eq_(s.query(cast(JSONThing.data_null, String)).scalar(), None)
+
+
+class EnsureCacheTest(fixtures.FutureEngineMixin, UOWTest):
+ def test_ensure_cache(self):
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ cache = {}
+ eq_(len(inspect(User)._compiled_cache), 0)
+
+ with testing.db.connect().execution_options(
+ compiled_cache=cache
+ ) as conn:
+ s = Session(conn)
+ u1 = User(name="adf")
+ s.add(u1)
+ s.flush()
+
+ is_(conn._execution_options["compiled_cache"], cache)
+ eq_(len(inspect(User)._compiled_cache), 1)
+
+ u1.name = "newname"
+ s.flush()
+
+ is_(conn._execution_options["compiled_cache"], cache)
+ eq_(len(inspect(User)._compiled_cache), 2)
+
+ s.delete(u1)
+ s.flush()
+
+ is_(conn._execution_options["compiled_cache"], cache)
+ eq_(len(inspect(User)._compiled_cache), 3)