summaryrefslogtreecommitdiff
path: root/test/sql
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-03-10 21:05:35 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-03-10 21:05:35 +0000
commit7f4df1625ef06cbfda5ca2db34984fd576506fff (patch)
tree87678b76f0a1ece5644599aa83851db4cdbe4965 /test/sql
parent3a7bd8405c3a5970ed295b4efdfa790c4a7d8875 (diff)
parent2c9796b10c3e85450afdeedc4003607abda2f2db (diff)
downloadsqlalchemy-7f4df1625ef06cbfda5ca2db34984fd576506fff.tar.gz
Merge "repair broken lambda patch" into main
Diffstat (limited to 'test/sql')
-rw-r--r--test/sql/test_lambdas.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py
index c3e271706..002a13db9 100644
--- a/test/sql/test_lambdas.py
+++ b/test/sql/test_lambdas.py
@@ -1,3 +1,10 @@
+from __future__ import annotations
+
+import threading
+import time
+from typing import List
+from typing import Optional
+
from sqlalchemy import exc
from sqlalchemy import testing
from sqlalchemy.future import select as future_select
@@ -2083,3 +2090,98 @@ class DeferredLambdaElementTest(
eq_(e12key[0], e1key[0])
eq_(e32key[0], e3key[0])
+
+
+class ConcurrencyTest(fixtures.TestBase):
+ """test for #8098 and #9461"""
+
+ __requires__ = ("independent_readonly_connections",)
+
+ __only_on__ = ("+psycopg2", "+mysqldb", "+pysqlite", "+pymysql")
+
+ THREADS = 10
+
+ @testing.fixture
+ def mapping_fixture(self, decl_base):
+ class A(decl_base):
+ __tablename__ = "a"
+ id = Column(Integer, primary_key=True)
+ col1 = Column(String(100))
+ col2 = Column(String(100))
+ col3 = Column(String(100))
+ col4 = Column(String(100))
+
+ decl_base.metadata.create_all(testing.db)
+
+ from sqlalchemy.orm import Session
+
+ with testing.db.connect() as conn:
+ with Session(conn) as session:
+ session.add_all(
+ [
+ A(col1=str(i), col2=str(i), col3=str(i), col4=str(i))
+ for i in range(self.THREADS + 1)
+ ]
+ )
+ session.commit()
+
+ return A
+
+ @testing.requires.timing_intensive
+ def test_lambda_concurrency(self, testing_engine, mapping_fixture):
+ A = mapping_fixture
+ engine = testing_engine(options={"pool_size": self.THREADS + 5})
+ NUM_OF_LAMBDAS = 150
+
+ code = """
+from sqlalchemy import lambda_stmt, select
+
+
+def generate_lambda_stmt(wanted):
+ stmt = lambda_stmt(lambda: select(A.col1, A.col2, A.col3, A.col4))
+"""
+
+ for _ in range(NUM_OF_LAMBDAS):
+ code += (
+ " stmt += lambda s: s.where((A.col1 == wanted) & "
+ "(A.col2 == wanted) & (A.col3 == wanted) & "
+ "(A.col4 == wanted))\n"
+ )
+
+ code += """
+ return stmt
+"""
+
+ d = {"A": A, "__name__": "lambda_fake"}
+ exec(code, d)
+ generate_lambda_stmt = d["generate_lambda_stmt"]
+
+ runs: List[Optional[int]] = [None for _ in range(self.THREADS)]
+ conns = [engine.connect() for _ in range(self.THREADS)]
+
+ def run(num):
+ wanted = str(num)
+ connection = conns[num]
+ time.sleep(0.1)
+ stmt = generate_lambda_stmt(wanted)
+ time.sleep(0.1)
+ row = connection.execute(stmt).first()
+ if not row:
+ runs[num] = False
+ else:
+ runs[num] = True
+
+ threads = [
+ threading.Thread(target=run, args=(num,))
+ for num in range(self.THREADS)
+ ]
+
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join(timeout=10)
+ for conn in conns:
+ conn.close()
+
+ fails = len([r for r in runs if r is False])
+ assert not fails, f"{fails} runs failed"