diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2023-03-10 21:05:35 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-03-10 21:05:35 +0000 |
| commit | 7f4df1625ef06cbfda5ca2db34984fd576506fff (patch) | |
| tree | 87678b76f0a1ece5644599aa83851db4cdbe4965 | |
| parent | 3a7bd8405c3a5970ed295b4efdfa790c4a7d8875 (diff) | |
| parent | 2c9796b10c3e85450afdeedc4003607abda2f2db (diff) | |
| download | sqlalchemy-7f4df1625ef06cbfda5ca2db34984fd576506fff.tar.gz | |
Merge "repair broken lambda patch" into main
| -rw-r--r-- | doc/build/changelog/unreleased_20/9461.rst | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 15 | ||||
| -rw-r--r-- | test/ext/test_automap.py | 2 | ||||
| -rw-r--r-- | test/requirements.py | 28 | ||||
| -rw-r--r-- | test/sql/test_lambdas.py | 102 |
5 files changed, 135 insertions(+), 20 deletions(-)
diff --git a/doc/build/changelog/unreleased_20/9461.rst b/doc/build/changelog/unreleased_20/9461.rst new file mode 100644 index 000000000..3397cfe27 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9461.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql, regression + :tickets: 9461 + + Fixed regression where the fix for :ticket:`8098`, which was released in + the 1.4 series and provided a layer of concurrency-safe checks for the + lambda SQL API, included additional fixes in the patch that failed to be + applied to the main branch. These additional fixes have been applied. diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 04bf86ee6..12175c75d 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -272,11 +272,16 @@ class LambdaElement(elements.ClauseElement): if rec is None: if cache_key is not _cache_key.NO_CACHE: - rec = AnalyzedFunction( - tracker, self, apply_propagate_attrs, fn - ) - rec.closure_bindparams = bindparams - lambda_cache[tracker_key + cache_key] = rec + with AnalyzedCode._generation_mutex: + key = tracker_key + cache_key + if key not in lambda_cache: + rec = AnalyzedFunction( + tracker, self, apply_propagate_attrs, fn + ) + rec.closure_bindparams = bindparams + lambda_cache[key] = rec + else: + rec = lambda_cache[key] else: rec = NonAnalyzedFunction(self._invoke_user_fn(fn)) diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py index dca0bb063..c84bc1c78 100644 --- a/test/ext/test_automap.py +++ b/test/ext/test_automap.py @@ -657,7 +657,7 @@ class AutomapInhTest(fixtures.MappedTest): class ConcurrentAutomapTest(fixtures.TestBase): - __only_on__ = "sqlite" + __only_on__ = "sqlite+pysqlite" def _make_tables(self, e): m = MetaData() diff --git a/test/requirements.py b/test/requirements.py index 9d51ae477..67ecdc405 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -367,22 +367,22 @@ class DefaultRequirements(SuiteRequirements): Target must support simultaneous, 
independent database connections. """ - # This is also true of some configurations of UnixODBC and probably - # win32 ODBC as well. + # note: **do not** let any sqlite driver run "independent connection" + # tests. Use independent_readonly_connections for a concurrency + # related test that only uses reads to use sqlite + return skip_if(["sqlite"]) + + @property + def independent_readonly_connections(self): + """ + Target must support simultaneous, independent database connections + that will be used in a readonly fashion. + + """ return skip_if( [ - no_support( - "sqlite", - "independent connections disabled " - "when :memory: connections are used", - ), - exclude( - "mssql", - "<", - (9, 0, 0), - "SQL Server 2005+ is required for " - "independent connections", - ), + self._sqlite_memory_db, + "+aiosqlite", ] ) diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index c3e271706..002a13db9 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -1,3 +1,10 @@ +from __future__ import annotations + +import threading +import time +from typing import List +from typing import Optional + from sqlalchemy import exc from sqlalchemy import testing from sqlalchemy.future import select as future_select @@ -2083,3 +2090,98 @@ class DeferredLambdaElementTest( eq_(e12key[0], e1key[0]) eq_(e32key[0], e3key[0]) + + +class ConcurrencyTest(fixtures.TestBase): + """test for #8098 and #9461""" + + __requires__ = ("independent_readonly_connections",) + + __only_on__ = ("+psycopg2", "+mysqldb", "+pysqlite", "+pymysql") + + THREADS = 10 + + @testing.fixture + def mapping_fixture(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + col1 = Column(String(100)) + col2 = Column(String(100)) + col3 = Column(String(100)) + col4 = Column(String(100)) + + decl_base.metadata.create_all(testing.db) + + from sqlalchemy.orm import Session + + with testing.db.connect() as conn: + with Session(conn) as session: + 
session.add_all( + [ + A(col1=str(i), col2=str(i), col3=str(i), col4=str(i)) + for i in range(self.THREADS + 1) + ] + ) + session.commit() + + return A + + @testing.requires.timing_intensive + def test_lambda_concurrency(self, testing_engine, mapping_fixture): + A = mapping_fixture + engine = testing_engine(options={"pool_size": self.THREADS + 5}) + NUM_OF_LAMBDAS = 150 + + code = """ +from sqlalchemy import lambda_stmt, select + + +def generate_lambda_stmt(wanted): + stmt = lambda_stmt(lambda: select(A.col1, A.col2, A.col3, A.col4)) +""" + + for _ in range(NUM_OF_LAMBDAS): + code += ( + " stmt += lambda s: s.where((A.col1 == wanted) & " + "(A.col2 == wanted) & (A.col3 == wanted) & " + "(A.col4 == wanted))\n" + ) + + code += """ + return stmt +""" + + d = {"A": A, "__name__": "lambda_fake"} + exec(code, d) + generate_lambda_stmt = d["generate_lambda_stmt"] + + runs: List[Optional[int]] = [None for _ in range(self.THREADS)] + conns = [engine.connect() for _ in range(self.THREADS)] + + def run(num): + wanted = str(num) + connection = conns[num] + time.sleep(0.1) + stmt = generate_lambda_stmt(wanted) + time.sleep(0.1) + row = connection.execute(stmt).first() + if not row: + runs[num] = False + else: + runs[num] = True + + threads = [ + threading.Thread(target=run, args=(num,)) + for num in range(self.THREADS) + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join(timeout=10) + for conn in conns: + conn.close() + + fails = len([r for r in runs if r is False]) + assert not fails, f"{fails} runs failed" |
