summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-11-15 16:58:50 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2020-12-11 13:26:05 -0500
commitba5cbf9366e9b2c5ed8e27e91815d7a2c3b63e41 (patch)
tree038f2263d581d5e49d74731af68febc4bf64eb19 /test
parent87d58b6d8188ccff808b3207d5f9398bb9adf9b9 (diff)
downloadsqlalchemy-ba5cbf9366e9b2c5ed8e27e91815d7a2c3b63e41.tar.gz
correct for "autocommit" deprecation warning
Ensure no autocommit warnings occur internally or within tests. Also includes fixes for SQL Server full text tests which apparently have not been working at all for a long time, as it used long removed APIs. CI has not had fulltext running for some years and is now installed. Change-Id: Id806e1856c9da9f0a9eac88cebc7a94ecc95eb96
Diffstat (limited to 'test')
-rw-r--r--test/aaa_profiling/test_resultset.py41
-rwxr-xr-xtest/conftest.py2
-rw-r--r--test/dialect/mssql/test_engine.py2
-rw-r--r--test/dialect/mssql/test_query.py336
-rw-r--r--test/dialect/mssql/test_reflection.py17
-rw-r--r--test/dialect/mssql/test_types.py182
-rw-r--r--test/dialect/mysql/test_dialect.py94
-rw-r--r--test/dialect/mysql/test_on_duplicate.py263
-rw-r--r--test/dialect/mysql/test_query.py108
-rw-r--r--test/dialect/mysql/test_reflection.py96
-rw-r--r--test/dialect/oracle/test_dialect.py93
-rw-r--r--test/dialect/oracle/test_reflection.py221
-rw-r--r--test/dialect/oracle/test_types.py79
-rw-r--r--test/dialect/postgresql/test_dialect.py291
-rw-r--r--test/dialect/postgresql/test_on_conflict.py894
-rw-r--r--test/dialect/postgresql/test_query.py220
-rw-r--r--test/dialect/postgresql/test_reflection.py180
-rw-r--r--test/dialect/postgresql/test_types.py2
-rw-r--r--test/dialect/test_mxodbc.py63
-rw-r--r--test/dialect/test_sqlite.py136
-rw-r--r--test/engine/test_ddlevents.py1
-rw-r--r--test/engine/test_deprecations.py379
-rw-r--r--test/engine/test_execute.py161
-rw-r--r--test/engine/test_logging.py56
-rw-r--r--test/engine/test_reconnect.py16
-rw-r--r--test/engine/test_reflection.py8
-rw-r--r--test/engine/test_transaction.py281
-rw-r--r--test/ext/test_associationproxy.py56
-rw-r--r--test/ext/test_horizontal_shard.py15
-rw-r--r--test/orm/inheritance/test_selects.py15
-rw-r--r--test/orm/test_bind.py6
-rw-r--r--test/orm/test_compile.py56
-rw-r--r--test/orm/test_eager_relations.py119
-rw-r--r--test/orm/test_expire.py47
-rw-r--r--test/orm/test_lazy_relations.py24
-rw-r--r--test/orm/test_mapper.py22
-rw-r--r--test/orm/test_naturalpks.py16
-rw-r--r--test/orm/test_query.py15
-rw-r--r--test/orm/test_session.py73
-rw-r--r--test/orm/test_transaction.py4
-rw-r--r--test/orm/test_unitofworkv2.py3
-rw-r--r--test/sql/test_defaults.py14
-rw-r--r--test/sql/test_delete.py62
-rw-r--r--test/sql/test_deprecations.py252
-rw-r--r--test/sql/test_query.py39
-rw-r--r--test/sql/test_quote.py123
-rw-r--r--test/sql/test_resultset.py81
-rw-r--r--test/sql/test_returning.py74
-rw-r--r--test/sql/test_sequences.py64
-rw-r--r--test/sql/test_type_expressions.py28
-rw-r--r--test/sql/test_types.py385
-rw-r--r--test/sql/test_update.py86
52 files changed, 3075 insertions, 2826 deletions
diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py
index 7188c4125..d36a0c9e1 100644
--- a/test/aaa_profiling/test_resultset.py
+++ b/test/aaa_profiling/test_resultset.py
@@ -48,25 +48,28 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults):
)
def setup(self):
- metadata.create_all()
- t.insert().execute(
- [
- dict(
- ("field%d" % fnum, u("value%d" % fnum))
- for fnum in range(NUM_FIELDS)
- )
- for r_num in range(NUM_RECORDS)
- ]
- )
- t2.insert().execute(
- [
- dict(
- ("field%d" % fnum, u("value%d" % fnum))
- for fnum in range(NUM_FIELDS)
- )
- for r_num in range(NUM_RECORDS)
- ]
- )
+ with testing.db.begin() as conn:
+ metadata.create_all(conn)
+ conn.execute(
+ t.insert(),
+ [
+ dict(
+ ("field%d" % fnum, u("value%d" % fnum))
+ for fnum in range(NUM_FIELDS)
+ )
+ for r_num in range(NUM_RECORDS)
+ ],
+ )
+ conn.execute(
+ t2.insert(),
+ [
+ dict(
+ ("field%d" % fnum, u("value%d" % fnum))
+ for fnum in range(NUM_FIELDS)
+ )
+ for r_num in range(NUM_RECORDS)
+ ],
+ )
# warm up type caches
with testing.db.connect() as conn:
diff --git a/test/conftest.py b/test/conftest.py
index 63f3989eb..0db4486a9 100755
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -12,6 +12,8 @@ import sys
import pytest
+os.environ["SQLALCHEMY_WARN_20"] = "true"
+
collect_ignore_glob = []
# minimum version for a py3k only test is at
diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py
index 444455958..668df6ecb 100644
--- a/test/dialect/mssql/test_engine.py
+++ b/test/dialect/mssql/test_engine.py
@@ -382,7 +382,7 @@ class FastExecutemanyTest(fixtures.TestBase):
if executemany:
assert cursor.fast_executemany
- with eng.connect() as conn:
+ with eng.begin() as conn:
conn.execute(
t.insert(),
[{"id": i, "data": "data_%d" % i} for i in range(100)],
diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py
index d9dc033e1..ea0bfa4d2 100644
--- a/test/dialect/mssql/test_query.py
+++ b/test/dialect/mssql/test_query.py
@@ -9,7 +9,6 @@ from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy import Integer
from sqlalchemy import literal
-from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import select
@@ -26,22 +25,15 @@ from sqlalchemy.testing.assertsql import CursorSQL
from sqlalchemy.testing.assertsql import DialectSQL
from sqlalchemy.util import ue
-metadata = None
-cattable = None
-matchtable = None
-
-class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL):
+class IdentityInsertTest(fixtures.TablesTest, AssertsCompiledSQL):
__only_on__ = "mssql"
__dialect__ = mssql.MSDialect()
__backend__ = True
@classmethod
- def setup_class(cls):
- global metadata, cattable
- metadata = MetaData(testing.db)
-
- cattable = Table(
+ def define_tables(cls, metadata):
+ Table(
"cattable",
metadata,
Column("id", Integer),
@@ -49,82 +41,82 @@ class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL):
PrimaryKeyConstraint("id", name="PK_cattable"),
)
- def setup(self):
- metadata.create_all()
-
- def teardown(self):
- metadata.drop_all()
-
def test_compiled(self):
+ cattable = self.tables.cattable
self.assert_compile(
cattable.insert().values(id=9, description="Python"),
"INSERT INTO cattable (id, description) "
"VALUES (:id, :description)",
)
- def test_execute(self):
- with testing.db.connect() as conn:
- conn.execute(cattable.insert().values(id=9, description="Python"))
-
- cats = conn.execute(cattable.select().order_by(cattable.c.id))
- eq_([(9, "Python")], list(cats))
+ def test_execute(self, connection):
+ conn = connection
+ cattable = self.tables.cattable
+ conn.execute(cattable.insert().values(id=9, description="Python"))
- result = conn.execute(cattable.insert().values(description="PHP"))
- eq_(result.inserted_primary_key, (10,))
- lastcat = conn.execute(
- cattable.select().order_by(desc(cattable.c.id))
- )
- eq_((10, "PHP"), lastcat.first())
-
- def test_executemany(self):
- with testing.db.connect() as conn:
- conn.execute(
- cattable.insert(),
- [
- {"id": 89, "description": "Python"},
- {"id": 8, "description": "Ruby"},
- {"id": 3, "description": "Perl"},
- {"id": 1, "description": "Java"},
- ],
- )
- cats = conn.execute(cattable.select().order_by(cattable.c.id))
- eq_(
- [(1, "Java"), (3, "Perl"), (8, "Ruby"), (89, "Python")],
- list(cats),
- )
- conn.execute(
- cattable.insert(),
- [{"description": "PHP"}, {"description": "Smalltalk"}],
- )
- lastcats = conn.execute(
- cattable.select().order_by(desc(cattable.c.id)).limit(2)
- )
- eq_([(91, "Smalltalk"), (90, "PHP")], list(lastcats))
+ cats = conn.execute(cattable.select().order_by(cattable.c.id))
+ eq_([(9, "Python")], list(cats))
- def test_insert_plain_param(self):
- with testing.db.connect() as conn:
- conn.execute(cattable.insert(), id=5)
- eq_(conn.scalar(select(cattable.c.id)), 5)
+ result = conn.execute(cattable.insert().values(description="PHP"))
+ eq_(result.inserted_primary_key, (10,))
+ lastcat = conn.execute(cattable.select().order_by(desc(cattable.c.id)))
+ eq_((10, "PHP"), lastcat.first())
- def test_insert_values_key_plain(self):
- with testing.db.connect() as conn:
- conn.execute(cattable.insert().values(id=5))
- eq_(conn.scalar(select(cattable.c.id)), 5)
-
- def test_insert_values_key_expression(self):
- with testing.db.connect() as conn:
- conn.execute(cattable.insert().values(id=literal(5)))
- eq_(conn.scalar(select(cattable.c.id)), 5)
-
- def test_insert_values_col_plain(self):
- with testing.db.connect() as conn:
- conn.execute(cattable.insert().values({cattable.c.id: 5}))
- eq_(conn.scalar(select(cattable.c.id)), 5)
-
- def test_insert_values_col_expression(self):
- with testing.db.connect() as conn:
- conn.execute(cattable.insert().values({cattable.c.id: literal(5)}))
- eq_(conn.scalar(select(cattable.c.id)), 5)
+ def test_executemany(self, connection):
+ conn = connection
+ cattable = self.tables.cattable
+ conn.execute(
+ cattable.insert(),
+ [
+ {"id": 89, "description": "Python"},
+ {"id": 8, "description": "Ruby"},
+ {"id": 3, "description": "Perl"},
+ {"id": 1, "description": "Java"},
+ ],
+ )
+ cats = conn.execute(cattable.select().order_by(cattable.c.id))
+ eq_(
+ [(1, "Java"), (3, "Perl"), (8, "Ruby"), (89, "Python")],
+ list(cats),
+ )
+ conn.execute(
+ cattable.insert(),
+ [{"description": "PHP"}, {"description": "Smalltalk"}],
+ )
+ lastcats = conn.execute(
+ cattable.select().order_by(desc(cattable.c.id)).limit(2)
+ )
+ eq_([(91, "Smalltalk"), (90, "PHP")], list(lastcats))
+
+ def test_insert_plain_param(self, connection):
+ conn = connection
+ cattable = self.tables.cattable
+ conn.execute(cattable.insert(), id=5)
+ eq_(conn.scalar(select(cattable.c.id)), 5)
+
+ def test_insert_values_key_plain(self, connection):
+ conn = connection
+ cattable = self.tables.cattable
+ conn.execute(cattable.insert().values(id=5))
+ eq_(conn.scalar(select(cattable.c.id)), 5)
+
+ def test_insert_values_key_expression(self, connection):
+ conn = connection
+ cattable = self.tables.cattable
+ conn.execute(cattable.insert().values(id=literal(5)))
+ eq_(conn.scalar(select(cattable.c.id)), 5)
+
+ def test_insert_values_col_plain(self, connection):
+ conn = connection
+ cattable = self.tables.cattable
+ conn.execute(cattable.insert().values({cattable.c.id: 5}))
+ eq_(conn.scalar(select(cattable.c.id)), 5)
+
+ def test_insert_values_col_expression(self, connection):
+ conn = connection
+ cattable = self.tables.cattable
+ conn.execute(cattable.insert().values({cattable.c.id: literal(5)}))
+ eq_(conn.scalar(select(cattable.c.id)), 5)
class QueryUnicodeTest(fixtures.TestBase):
@@ -391,37 +383,35 @@ def full_text_search_missing():
"""Test if full text search is not implemented and return False if
it is and True otherwise."""
- try:
- connection = testing.db.connect()
- try:
- connection.exec_driver_sql(
- "CREATE FULLTEXT CATALOG Catalog AS " "DEFAULT"
- )
- return False
- except Exception:
- return True
- finally:
- connection.close()
+ if not testing.against("mssql"):
+ return True
+
+ with testing.db.connect() as conn:
+ result = conn.exec_driver_sql(
+ "SELECT cast(SERVERPROPERTY('IsFullTextInstalled') as integer)"
+ )
+ return result.scalar() == 0
-class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
+class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
__only_on__ = "mssql"
__skip_if__ = (full_text_search_missing,)
__backend__ = True
+ run_setup_tables = "once"
+ run_inserts = run_deletes = "once"
+
@classmethod
- def setup_class(cls):
- global metadata, cattable, matchtable
- metadata = MetaData(testing.db)
- cattable = Table(
+ def define_tables(cls, metadata):
+ Table(
"cattable",
metadata,
Column("id", Integer),
Column("description", String(50)),
PrimaryKeyConstraint("id", name="PK_cattable"),
)
- matchtable = Table(
+ Table(
"matchtable",
metadata,
Column("id", Integer),
@@ -429,24 +419,65 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
Column("category_id", Integer, ForeignKey("cattable.id")),
PrimaryKeyConstraint("id", name="PK_matchtable"),
)
- DDL(
- """CREATE FULLTEXT INDEX
+
+ event.listen(
+ metadata,
+ "before_create",
+ DDL("CREATE FULLTEXT CATALOG Catalog AS DEFAULT"),
+ )
+ event.listen(
+ metadata,
+ "after_create",
+ DDL(
+ """CREATE FULLTEXT INDEX
ON cattable (description)
KEY INDEX PK_cattable"""
- ).execute_at("after-create", matchtable)
- DDL(
- """CREATE FULLTEXT INDEX
+ ),
+ )
+ event.listen(
+ metadata,
+ "after_create",
+ DDL(
+ """CREATE FULLTEXT INDEX
ON matchtable (title)
KEY INDEX PK_matchtable"""
- ).execute_at("after-create", matchtable)
- metadata.create_all()
- cattable.insert().execute(
+ ),
+ )
+
+ event.listen(
+ metadata,
+ "after_drop",
+ DDL("DROP FULLTEXT CATALOG Catalog"),
+ )
+
+ @classmethod
+ def setup_bind(cls):
+ return testing.db.execution_options(isolation_level="AUTOCOMMIT")
+
+ @classmethod
+ def setup_class(cls):
+ with testing.db.connect().execution_options(
+ isolation_level="AUTOCOMMIT"
+ ) as conn:
+ try:
+ conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
+ except:
+ pass
+ super(MatchTest, cls).setup_class()
+
+ @classmethod
+ def insert_data(cls, connection):
+ cattable, matchtable = cls.tables("cattable", "matchtable")
+
+ connection.execute(
+ cattable.insert(),
[
{"id": 1, "description": "Python"},
{"id": 2, "description": "Ruby"},
- ]
+ ],
)
- matchtable.insert().execute(
+ connection.execute(
+ matchtable.insert(),
[
{
"id": 1,
@@ -461,62 +492,53 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
},
{"id": 4, "title": "Guide to Django", "category_id": 1},
{"id": 5, "title": "Python in a Nutshell", "category_id": 1},
- ]
+ ],
)
- DDL("WAITFOR DELAY '00:00:05'").execute(bind=engines.testing_engine())
-
- @classmethod
- def teardown_class(cls):
- metadata.drop_all()
- connection = testing.db.connect()
- connection.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
- connection.close()
+ # apparently this is needed! index must run asynchronously
+ connection.execute(DDL("WAITFOR DELAY '00:00:05'"))
def test_expression(self):
+ matchtable = self.tables.matchtable
self.assert_compile(
matchtable.c.title.match("somstr"),
"CONTAINS (matchtable.title, ?)",
)
- def test_simple_match(self):
- results = (
+ def test_simple_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
matchtable.select()
.where(matchtable.c.title.match("python"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([2, 5], [r.id for r in results])
- def test_simple_match_with_apostrophe(self):
- results = (
- matchtable.select()
- .where(matchtable.c.title.match("Matz's"))
- .execute()
- .fetchall()
- )
+ def test_simple_match_with_apostrophe(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
+ matchtable.select().where(matchtable.c.title.match("Matz's"))
+ ).fetchall()
eq_([3], [r.id for r in results])
- def test_simple_prefix_match(self):
- results = (
- matchtable.select()
- .where(matchtable.c.title.match('"nut*"'))
- .execute()
- .fetchall()
- )
+ def test_simple_prefix_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
+ matchtable.select().where(matchtable.c.title.match('"nut*"'))
+ ).fetchall()
eq_([5], [r.id for r in results])
- def test_simple_inflectional_match(self):
- results = (
- matchtable.select()
- .where(matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")'))
- .execute()
- .fetchall()
- )
+ def test_simple_inflectional_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
+ matchtable.select().where(
+ matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")')
+ )
+ ).fetchall()
eq_([2], [r.id for r in results])
- def test_or_match(self):
- results1 = (
+ def test_or_match(self, connection):
+ matchtable = self.tables.matchtable
+ results1 = connection.execute(
matchtable.select()
.where(
or_(
@@ -525,31 +547,25 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
)
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([3, 5], [r.id for r in results1])
- results2 = (
+ results2 = connection.execute(
matchtable.select()
.where(matchtable.c.title.match("nutshell OR ruby"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([3, 5], [r.id for r in results2])
- def test_and_match(self):
- results1 = (
- matchtable.select()
- .where(
+ def test_and_match(self, connection):
+ matchtable = self.tables.matchtable
+ results1 = connection.execute(
+ matchtable.select().where(
and_(
matchtable.c.title.match("python"),
matchtable.c.title.match("nutshell"),
)
)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([5], [r.id for r in results1])
results2 = (
matchtable.select()
@@ -559,8 +575,10 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
eq_([5], [r.id for r in results2])
- def test_match_across_joins(self):
- results = (
+ def test_match_across_joins(self, connection):
+ matchtable = self.tables.matchtable
+ cattable = self.tables.cattable
+ results = connection.execute(
matchtable.select()
.where(
and_(
@@ -572,7 +590,5 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
)
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([1, 3, 5], [r.id for r in results])
diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py
index 6009bfb6c..86c97316a 100644
--- a/test/dialect/mssql/test_reflection.py
+++ b/test/dialect/mssql/test_reflection.py
@@ -741,14 +741,9 @@ class IdentityReflectionTest(fixtures.TablesTest):
@testing.requires.views
def test_reflect_views(self, connection):
- try:
- with testing.db.connect() as conn:
- conn.exec_driver_sql("CREATE VIEW view1 AS SELECT * FROM t1")
- insp = inspect(testing.db)
- for col in insp.get_columns("view1"):
- is_true("dialect_options" not in col)
- is_true("identity" in col)
- eq_(col["identity"], {})
- finally:
- with testing.db.connect() as conn:
- conn.exec_driver_sql("DROP VIEW view1")
+ connection.exec_driver_sql("CREATE VIEW view1 AS SELECT * FROM t1")
+ insp = inspect(connection)
+ for col in insp.get_columns("view1"):
+ is_true("dialect_options" not in col)
+ is_true("identity" in col)
+ eq_(col["identity"], {})
diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py
index 11a2a25b3..a4a3bedda 100644
--- a/test/dialect/mssql/test_types.py
+++ b/test/dialect/mssql/test_types.py
@@ -221,7 +221,7 @@ class RowVersionTest(fixtures.TablesTest):
Column("rv", cls(convert_int=convert_int)),
)
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(t.insert().values(data="foo"))
last_ts_1 = conn.exec_driver_sql("SELECT @@DBTS").scalar()
@@ -545,7 +545,7 @@ class TypeRoundTripTest(
__backend__ = True
@testing.provide_metadata
- def test_decimal_notation(self):
+ def test_decimal_notation(self, connection):
metadata = self.metadata
numeric_table = Table(
"numeric_table",
@@ -560,7 +560,7 @@ class TypeRoundTripTest(
"numericcol", Numeric(precision=38, scale=20, asdecimal=True)
),
)
- metadata.create_all()
+ metadata.create_all(connection)
test_items = [
decimal.Decimal(d)
for d in (
@@ -623,21 +623,20 @@ class TypeRoundTripTest(
)
]
- with testing.db.connect() as conn:
- for value in test_items:
- result = conn.execute(
- numeric_table.insert(), dict(numericcol=value)
- )
- primary_key = result.inserted_primary_key
- returned = conn.scalar(
- select(numeric_table.c.numericcol).where(
- numeric_table.c.id == primary_key[0]
- )
+ for value in test_items:
+ result = connection.execute(
+ numeric_table.insert(), dict(numericcol=value)
+ )
+ primary_key = result.inserted_primary_key
+ returned = connection.scalar(
+ select(numeric_table.c.numericcol).where(
+ numeric_table.c.id == primary_key[0]
)
- eq_(value, returned)
+ )
+ eq_(value, returned)
@testing.provide_metadata
- def test_float(self):
+ def test_float(self, connection):
metadata = self.metadata
float_table = Table(
@@ -652,41 +651,47 @@ class TypeRoundTripTest(
Column("floatcol", Float()),
)
- metadata.create_all()
- try:
- test_items = [
- float(d)
- for d in (
- "1500000.00000000000000000000",
- "-1500000.00000000000000000000",
- "1500000",
- "0.0000000000000000002",
- "0.2",
- "-0.0000000000000000002",
- "156666.458923543",
- "-156666.458923543",
- "1",
- "-1",
- "1234",
- "2E-12",
- "4E8",
- "3E-6",
- "3E-7",
- "4.1",
- "1E-1",
- "1E-2",
- "1E-3",
- "1E-4",
- "1E-5",
- "1E-6",
- "1E-7",
- "1E-8",
+ metadata.create_all(connection)
+ test_items = [
+ float(d)
+ for d in (
+ "1500000.00000000000000000000",
+ "-1500000.00000000000000000000",
+ "1500000",
+ "0.0000000000000000002",
+ "0.2",
+ "-0.0000000000000000002",
+ "156666.458923543",
+ "-156666.458923543",
+ "1",
+ "-1",
+ "1234",
+ "2E-12",
+ "4E8",
+ "3E-6",
+ "3E-7",
+ "4.1",
+ "1E-1",
+ "1E-2",
+ "1E-3",
+ "1E-4",
+ "1E-5",
+ "1E-6",
+ "1E-7",
+ "1E-8",
+ )
+ ]
+ for value in test_items:
+ result = connection.execute(
+ float_table.insert(), dict(floatcol=value)
+ )
+ primary_key = result.inserted_primary_key
+ returned = connection.scalar(
+ select(float_table.c.floatcol).where(
+ float_table.c.id == primary_key[0]
)
- ]
- for value in test_items:
- float_table.insert().execute(floatcol=value)
- except Exception as e:
- raise e
+ )
+ eq_(value, returned)
# todo this should suppress warnings, but it does not
@emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*")
@@ -770,18 +775,17 @@ class TypeRoundTripTest(
d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
return t, (d1, t1, d2)
- def test_date_roundtrips(self, date_fixture):
+ def test_date_roundtrips(self, date_fixture, connection):
t, (d1, t1, d2) = date_fixture
- with testing.db.begin() as conn:
- conn.execute(
- t.insert(), adate=d1, adatetime=d2, atime1=t1, atime2=d2
- )
+ connection.execute(
+ t.insert(), adate=d1, adatetime=d2, atime1=t1, atime2=d2
+ )
- row = conn.execute(t.select()).first()
- eq_(
- (row.adate, row.adatetime, row.atime1, row.atime2),
- (d1, d2, t1, d2.time()),
- )
+ row = connection.execute(t.select()).first()
+ eq_(
+ (row.adate, row.adatetime, row.atime1, row.atime2),
+ (d1, d2, t1, d2.time()),
+ )
@testing.metadata_fixture()
def datetimeoffset_fixture(self, metadata):
@@ -870,45 +874,45 @@ class TypeRoundTripTest(
dto_param_value,
expected_offset_hours,
should_fail,
+ connection,
):
t = datetimeoffset_fixture
dto_param_value = dto_param_value()
- with testing.db.begin() as conn:
- if should_fail:
- assert_raises(
- sa.exc.DBAPIError,
- conn.execute,
- t.insert(),
- adatetimeoffset=dto_param_value,
- )
- return
-
- conn.execute(
+ if should_fail:
+ assert_raises(
+ sa.exc.DBAPIError,
+ connection.execute,
t.insert(),
adatetimeoffset=dto_param_value,
)
+ return
- row = conn.execute(t.select()).first()
+ connection.execute(
+ t.insert(),
+ adatetimeoffset=dto_param_value,
+ )
- if dto_param_value is None:
- is_(row.adatetimeoffset, None)
- else:
- eq_(
- row.adatetimeoffset,
- datetime.datetime(
- 2007,
- 10,
- 30,
- 11,
- 2,
- 32,
- 123456,
- util.timezone(
- datetime.timedelta(hours=expected_offset_hours)
- ),
+ row = connection.execute(t.select()).first()
+
+ if dto_param_value is None:
+ is_(row.adatetimeoffset, None)
+ else:
+ eq_(
+ row.adatetimeoffset,
+ datetime.datetime(
+ 2007,
+ 10,
+ 30,
+ 11,
+ 2,
+ 32,
+ 123456,
+ util.timezone(
+ datetime.timedelta(hours=expected_offset_hours)
),
- )
+ ),
+ )
@emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*")
@testing.provide_metadata
@@ -1173,7 +1177,7 @@ class BinaryTest(fixtures.TestBase):
if expected is None:
expected = data
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(binary_table.insert(), data=data)
eq_(conn.scalar(select(binary_table.c.data)), expected)
diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py
index abd3a491f..3c569bf05 100644
--- a/test/dialect/mysql/test_dialect.py
+++ b/test/dialect/mysql/test_dialect.py
@@ -20,7 +20,7 @@ from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
-from ...engine import test_execute
+from ...engine import test_deprecations
class BackendDialectTest(fixtures.TestBase):
@@ -382,56 +382,56 @@ class RemoveUTCTimestampTest(fixtures.TablesTest):
Column("udata", DateTime, onupdate=func.utc_timestamp()),
)
- def test_insert_executemany(self):
- with testing.db.connect() as conn:
- conn.execute(
- self.tables.t.insert().values(data=func.utc_timestamp()),
- [{"x": 5}, {"x": 6}, {"x": 7}],
- )
+ def test_insert_executemany(self, connection):
+ conn = connection
+ conn.execute(
+ self.tables.t.insert().values(data=func.utc_timestamp()),
+ [{"x": 5}, {"x": 6}, {"x": 7}],
+ )
- def test_update_executemany(self):
- with testing.db.connect() as conn:
- timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
- conn.execute(
- self.tables.t.insert(),
- [
- {"x": 5, "data": timestamp},
- {"x": 6, "data": timestamp},
- {"x": 7, "data": timestamp},
- ],
- )
+ def test_update_executemany(self, connection):
+ conn = connection
+ timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
+ conn.execute(
+ self.tables.t.insert(),
+ [
+ {"x": 5, "data": timestamp},
+ {"x": 6, "data": timestamp},
+ {"x": 7, "data": timestamp},
+ ],
+ )
- conn.execute(
- self.tables.t.update()
- .values(data=func.utc_timestamp())
- .where(self.tables.t.c.x == bindparam("xval")),
- [{"xval": 5}, {"xval": 6}, {"xval": 7}],
- )
+ conn.execute(
+ self.tables.t.update()
+ .values(data=func.utc_timestamp())
+ .where(self.tables.t.c.x == bindparam("xval")),
+ [{"xval": 5}, {"xval": 6}, {"xval": 7}],
+ )
- def test_insert_executemany_w_default(self):
- with testing.db.connect() as conn:
- conn.execute(
- self.tables.t_default.insert(), [{"x": 5}, {"x": 6}, {"x": 7}]
- )
+ def test_insert_executemany_w_default(self, connection):
+ conn = connection
+ conn.execute(
+ self.tables.t_default.insert(), [{"x": 5}, {"x": 6}, {"x": 7}]
+ )
- def test_update_executemany_w_default(self):
- with testing.db.connect() as conn:
- timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
- conn.execute(
- self.tables.t_default.insert(),
- [
- {"x": 5, "idata": timestamp},
- {"x": 6, "idata": timestamp},
- {"x": 7, "idata": timestamp},
- ],
- )
+ def test_update_executemany_w_default(self, connection):
+ conn = connection
+ timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
+ conn.execute(
+ self.tables.t_default.insert(),
+ [
+ {"x": 5, "idata": timestamp},
+ {"x": 6, "idata": timestamp},
+ {"x": 7, "idata": timestamp},
+ ],
+ )
- conn.execute(
- self.tables.t_default.update()
- .values(idata=func.utc_timestamp())
- .where(self.tables.t_default.c.x == bindparam("xval")),
- [{"xval": 5}, {"xval": 6}, {"xval": 7}],
- )
+ conn.execute(
+ self.tables.t_default.update()
+ .values(idata=func.utc_timestamp())
+ .where(self.tables.t_default.c.x == bindparam("xval")),
+ [{"xval": 5}, {"xval": 6}, {"xval": 7}],
+ )
class SQLModeDetectionTest(fixtures.TestBase):
@@ -505,7 +505,7 @@ class ExecutionTest(fixtures.TestBase):
class AutocommitTextTest(
- test_execute.AutocommitKeywordFixture, fixtures.TestBase
+ test_deprecations.AutocommitKeywordFixture, fixtures.TestBase
):
__only_on__ = "mysql", "mariadb"
diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py
index ed88121a5..dc86aaeb0 100644
--- a/test/dialect/mysql/test_on_duplicate.py
+++ b/test/dialect/mysql/test_on_duplicate.py
@@ -5,7 +5,6 @@ from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Table
-from sqlalchemy import testing
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
@@ -47,155 +46,145 @@ class OnDuplicateTest(fixtures.TablesTest):
{"id": 2, "bar": "baz"},
)
- def test_on_duplicate_key_update_multirow(self):
+ def test_on_duplicate_key_update_multirow(self, connection):
foos = self.tables.foos
- with testing.db.connect() as conn:
- conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
- stmt = insert(foos).values(
- [dict(id=1, bar="ab"), dict(id=2, bar="b")]
- )
- stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
-
- result = conn.execute(stmt)
-
- # multirow, so its ambiguous. this is a behavioral change
- # in 1.4
- eq_(result.inserted_primary_key, (None,))
- eq_(
- conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
- [(1, "ab", "bz", False)],
- )
+ conn = connection
+ conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+ stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")])
+ stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
+ result = conn.execute(stmt)
+
+ # multirow, so its ambiguous. this is a behavioral change
+ # in 1.4
+ eq_(result.inserted_primary_key, (None,))
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+ [(1, "ab", "bz", False)],
+ )
- def test_on_duplicate_key_update_singlerow(self):
+ def test_on_duplicate_key_update_singlerow(self, connection):
foos = self.tables.foos
- with testing.db.connect() as conn:
- conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
- stmt = insert(foos).values(dict(id=2, bar="b"))
- stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
-
- result = conn.execute(stmt)
-
- # only one row in the INSERT so we do inserted_primary_key
- eq_(result.inserted_primary_key, (2,))
- eq_(
- conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
- [(1, "b", "bz", False)],
- )
+ conn = connection
+ conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+ stmt = insert(foos).values(dict(id=2, bar="b"))
+ stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
+ result = conn.execute(stmt)
+
+ # only one row in the INSERT so we do inserted_primary_key
+ eq_(result.inserted_primary_key, (2,))
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+ [(1, "b", "bz", False)],
+ )
- def test_on_duplicate_key_update_null_multirow(self):
+ def test_on_duplicate_key_update_null_multirow(self, connection):
foos = self.tables.foos
- with testing.db.connect() as conn:
- conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
- stmt = insert(foos).values(
- [dict(id=1, bar="ab"), dict(id=2, bar="b")]
- )
- stmt = stmt.on_duplicate_key_update(updated_once=None)
- result = conn.execute(stmt)
-
- # ambiguous
- eq_(result.inserted_primary_key, (None,))
- eq_(
- conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
- [(1, "b", "bz", None)],
- )
+ conn = connection
+ conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+ stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")])
+ stmt = stmt.on_duplicate_key_update(updated_once=None)
+ result = conn.execute(stmt)
+
+ # ambiguous
+ eq_(result.inserted_primary_key, (None,))
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+ [(1, "b", "bz", None)],
+ )
- def test_on_duplicate_key_update_expression_multirow(self):
+ def test_on_duplicate_key_update_expression_multirow(self, connection):
foos = self.tables.foos
- with testing.db.connect() as conn:
- conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
- stmt = insert(foos).values(
- [dict(id=1, bar="ab"), dict(id=2, bar="b")]
- )
- stmt = stmt.on_duplicate_key_update(
- bar=func.concat(stmt.inserted.bar, "_foo")
- )
- result = conn.execute(stmt)
- eq_(result.inserted_primary_key, (None,))
- eq_(
- conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
- [(1, "ab_foo", "bz", False)],
- )
+ conn = connection
+ conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+ stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")])
+ stmt = stmt.on_duplicate_key_update(
+ bar=func.concat(stmt.inserted.bar, "_foo")
+ )
+ result = conn.execute(stmt)
+ eq_(result.inserted_primary_key, (None,))
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+ [(1, "ab_foo", "bz", False)],
+ )
- def test_on_duplicate_key_update_preserve_order(self):
+ def test_on_duplicate_key_update_preserve_order(self, connection):
foos = self.tables.foos
- with testing.db.connect() as conn:
- conn.execute(
- insert(
- foos,
- [
- dict(id=1, bar="b", baz="bz"),
- dict(id=2, bar="b", baz="bz2"),
- ],
- )
- )
-
- stmt = insert(foos)
- update_condition = foos.c.updated_once == False
-
- # The following statements show importance of the columns update
- # ordering as old values being referenced in UPDATE clause are
- # getting replaced one by one from left to right with their new
- # values.
- stmt1 = stmt.on_duplicate_key_update(
+ conn = connection
+ conn.execute(
+ insert(
+ foos,
[
- (
- "bar",
- func.if_(
- update_condition,
- func.values(foos.c.bar),
- foos.c.bar,
- ),
- ),
- (
- "updated_once",
- func.if_(update_condition, True, foos.c.updated_once),
- ),
- ]
+ dict(id=1, bar="b", baz="bz"),
+ dict(id=2, bar="b", baz="bz2"),
+ ],
)
- stmt2 = stmt.on_duplicate_key_update(
- [
- (
- "updated_once",
- func.if_(update_condition, True, foos.c.updated_once),
+ )
+
+ stmt = insert(foos)
+ update_condition = foos.c.updated_once == False
+
+ # The following statements show importance of the columns update
+ # ordering as old values being referenced in UPDATE clause are
+ # getting replaced one by one from left to right with their new
+ # values.
+ stmt1 = stmt.on_duplicate_key_update(
+ [
+ (
+ "bar",
+ func.if_(
+ update_condition,
+ func.values(foos.c.bar),
+ foos.c.bar,
),
- (
- "bar",
- func.if_(
- update_condition,
- func.values(foos.c.bar),
- foos.c.bar,
- ),
+ ),
+ (
+ "updated_once",
+ func.if_(update_condition, True, foos.c.updated_once),
+ ),
+ ]
+ )
+ stmt2 = stmt.on_duplicate_key_update(
+ [
+ (
+ "updated_once",
+ func.if_(update_condition, True, foos.c.updated_once),
+ ),
+ (
+ "bar",
+ func.if_(
+ update_condition,
+ func.values(foos.c.bar),
+ foos.c.bar,
),
- ]
- )
- # First statement should succeed updating column bar
- conn.execute(stmt1, dict(id=1, bar="ab"))
- eq_(
- conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
- [(1, "ab", "bz", True)],
- )
- # Second statement will do noop update of column bar
- conn.execute(stmt2, dict(id=2, bar="ab"))
- eq_(
- conn.execute(foos.select().where(foos.c.id == 2)).fetchall(),
- [(2, "b", "bz2", True)],
- )
+ ),
+ ]
+ )
+ # First statement should succeed updating column bar
+ conn.execute(stmt1, dict(id=1, bar="ab"))
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+ [(1, "ab", "bz", True)],
+ )
+ # Second statement will do noop update of column bar
+ conn.execute(stmt2, dict(id=2, bar="ab"))
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 2)).fetchall(),
+ [(2, "b", "bz2", True)],
+ )
- def test_last_inserted_id(self):
+ def test_last_inserted_id(self, connection):
foos = self.tables.foos
- with testing.db.connect() as conn:
- stmt = insert(foos).values({"bar": "b", "baz": "bz"})
- result = conn.execute(
- stmt.on_duplicate_key_update(
- bar=stmt.inserted.bar, baz="newbz"
- )
- )
- eq_(result.inserted_primary_key, (1,))
+ conn = connection
+ stmt = insert(foos).values({"bar": "b", "baz": "bz"})
+ result = conn.execute(
+ stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz")
+ )
+ eq_(result.inserted_primary_key, (1,))
- stmt = insert(foos).values({"id": 1, "bar": "b", "baz": "bz"})
- result = conn.execute(
- stmt.on_duplicate_key_update(
- bar=stmt.inserted.bar, baz="newbz"
- )
- )
- eq_(result.inserted_primary_key, (1,))
+ stmt = insert(foos).values({"id": 1, "bar": "b", "baz": "bz"})
+ result = conn.execute(
+ stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz")
+ )
+ eq_(result.inserted_primary_key, (1,))
diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py
index f9d9caf16..f56cd98aa 100644
--- a/test/dialect/mysql/test_query.py
+++ b/test/dialect/mysql/test_query.py
@@ -9,7 +9,6 @@ from sqlalchemy import Column
from sqlalchemy import false
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
-from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import String
@@ -44,16 +43,13 @@ class IdiosyncrasyTest(fixtures.TestBase):
)
-class MatchTest(fixtures.TestBase):
+class MatchTest(fixtures.TablesTest):
__only_on__ = "mysql", "mariadb"
__backend__ = True
@classmethod
- def setup_class(cls):
- global metadata, cattable, matchtable
- metadata = MetaData(testing.db)
-
- cattable = Table(
+ def define_tables(cls, metadata):
+ Table(
"cattable",
metadata,
Column("id", Integer, primary_key=True),
@@ -61,7 +57,7 @@ class MatchTest(fixtures.TestBase):
mysql_engine="MyISAM",
mariadb_engine="MyISAM",
)
- matchtable = Table(
+ Table(
"matchtable",
metadata,
Column("id", Integer, primary_key=True),
@@ -70,15 +66,20 @@ class MatchTest(fixtures.TestBase):
mysql_engine="MyISAM",
mariadb_engine="MyISAM",
)
- metadata.create_all()
- cattable.insert().execute(
+ @classmethod
+ def insert_data(cls, connection):
+ cattable, matchtable = cls.tables("cattable", "matchtable")
+
+ connection.execute(
+ cattable.insert(),
[
{"id": 1, "description": "Python"},
{"id": 2, "description": "Ruby"},
- ]
+ ],
)
- matchtable.insert().execute(
+ connection.execute(
+ matchtable.insert(),
[
{
"id": 1,
@@ -97,43 +98,36 @@ class MatchTest(fixtures.TestBase):
"category_id": 1,
},
{"id": 5, "title": "Python in a Nutshell", "category_id": 1},
- ]
+ ],
)
- @classmethod
- def teardown_class(cls):
- metadata.drop_all()
-
- def test_simple_match(self):
- results = (
+ def test_simple_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
matchtable.select()
.where(matchtable.c.title.match("python"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([2, 5], [r.id for r in results])
- def test_not_match(self):
- results = (
+ def test_not_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
matchtable.select()
.where(~matchtable.c.title.match("python"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
)
eq_([1, 3, 4], [r.id for r in results])
- def test_simple_match_with_apostrophe(self):
- results = (
- matchtable.select()
- .where(matchtable.c.title.match("Matz's"))
- .execute()
- .fetchall()
- )
+ def test_simple_match_with_apostrophe(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
+ matchtable.select().where(matchtable.c.title.match("Matz's"))
+ ).fetchall()
eq_([3], [r.id for r in results])
def test_return_value(self, connection):
+ matchtable = self.tables.matchtable
# test [ticket:3263]
result = connection.execute(
select(
@@ -155,8 +149,9 @@ class MatchTest(fixtures.TestBase):
],
)
- def test_or_match(self):
- results1 = (
+ def test_or_match(self, connection):
+ matchtable = self.tables.matchtable
+ results1 = connection.execute(
matchtable.select()
.where(
or_(
@@ -165,42 +160,37 @@ class MatchTest(fixtures.TestBase):
)
)
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([1, 3, 5], [r.id for r in results1])
- results2 = (
+ results2 = connection.execute(
matchtable.select()
.where(matchtable.c.title.match("nutshell ruby"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([1, 3, 5], [r.id for r in results2])
- def test_and_match(self):
- results1 = (
- matchtable.select()
- .where(
+ def test_and_match(self, connection):
+ matchtable = self.tables.matchtable
+ results1 = connection.execute(
+ matchtable.select().where(
and_(
matchtable.c.title.match("python"),
matchtable.c.title.match("nutshell"),
)
)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([5], [r.id for r in results1])
- results2 = (
- matchtable.select()
- .where(matchtable.c.title.match("+python +nutshell"))
- .execute()
- .fetchall()
- )
+ results2 = connection.execute(
+ matchtable.select().where(
+ matchtable.c.title.match("+python +nutshell")
+ )
+ ).fetchall()
eq_([5], [r.id for r in results2])
- def test_match_across_joins(self):
- results = (
+ def test_match_across_joins(self, connection):
+ matchtable = self.tables.matchtable
+ cattable = self.tables.cattable
+ results = connection.execute(
matchtable.select()
.where(
and_(
@@ -212,9 +202,7 @@ class MatchTest(fixtures.TestBase):
)
)
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([1, 3, 5], [r.id for r in results])
diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py
index 3871dbecc..55d88957a 100644
--- a/test/dialect/mysql/test_reflection.py
+++ b/test/dialect/mysql/test_reflection.py
@@ -324,7 +324,8 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
str(reflected.c.c6.server_default.arg).upper(),
)
- def test_reflection_with_table_options(self):
+ @testing.provide_metadata
+ def test_reflection_with_table_options(self, connection):
comment = r"""Comment types type speedily ' " \ '' Fun!"""
if testing.against("mariadb"):
kwargs = dict(
@@ -347,18 +348,15 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
def_table = Table(
"mysql_def",
- MetaData(),
+ self.metadata,
Column("c1", Integer()),
comment=comment,
**kwargs
)
- with testing.db.connect() as conn:
- def_table.create(conn)
- try:
- reflected = Table("mysql_def", MetaData(), autoload_with=conn)
- finally:
- def_table.drop(conn)
+ conn = connection
+ def_table.create(conn)
+ reflected = Table("mysql_def", MetaData(), autoload_with=conn)
if testing.against("mariadb"):
assert def_table.kwargs["mariadb_engine"] == "MEMORY"
@@ -554,31 +552,31 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
assert 1 not in list(conn.execute(tbl.select()).first())
@testing.provide_metadata
- def test_view_reflection(self):
+ def test_view_reflection(self, connection):
Table(
"x", self.metadata, Column("a", Integer), Column("b", String(50))
)
- self.metadata.create_all()
+ self.metadata.create_all(connection)
- with testing.db.connect() as conn:
- conn.exec_driver_sql("CREATE VIEW v1 AS SELECT * FROM x")
- conn.exec_driver_sql(
- "CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x"
- )
- conn.exec_driver_sql(
- "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x"
- )
- conn.exec_driver_sql(
- "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x"
- )
+ conn = connection
+ conn.exec_driver_sql("CREATE VIEW v1 AS SELECT * FROM x")
+ conn.exec_driver_sql(
+ "CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x"
+ )
+ conn.exec_driver_sql(
+ "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x"
+ )
+ conn.exec_driver_sql(
+ "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x"
+ )
@event.listens_for(self.metadata, "before_drop")
def cleanup(*arg, **kw):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
for v in ["v1", "v2", "v3", "v4"]:
conn.exec_driver_sql("DROP VIEW %s" % v)
- insp = inspect(testing.db)
+ insp = inspect(connection)
for v in ["v1", "v2", "v3", "v4"]:
eq_(
[
@@ -589,38 +587,36 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
)
@testing.provide_metadata
- def test_skip_not_describable(self):
+ def test_skip_not_describable(self, connection):
@event.listens_for(self.metadata, "before_drop")
def cleanup(*arg, **kw):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql("DROP TABLE IF EXISTS test_t1")
conn.exec_driver_sql("DROP TABLE IF EXISTS test_t2")
conn.exec_driver_sql("DROP VIEW IF EXISTS test_v")
- with testing.db.connect() as conn:
- conn.exec_driver_sql("CREATE TABLE test_t1 (id INTEGER)")
- conn.exec_driver_sql("CREATE TABLE test_t2 (id INTEGER)")
- conn.exec_driver_sql(
- "CREATE VIEW test_v AS SELECT id FROM test_t1"
- )
- conn.exec_driver_sql("DROP TABLE test_t1")
-
- m = MetaData()
- with expect_warnings(
- "Skipping .* Table or view named .?test_v.? could not be "
- "reflected: .* references invalid table"
- ):
- m.reflect(views=True, bind=conn)
- eq_(m.tables["test_t2"].name, "test_t2")
-
- assert_raises_message(
- exc.UnreflectableTableError,
- "references invalid table",
- Table,
- "test_v",
- MetaData(),
- autoload_with=conn,
- )
+ conn = connection
+ conn.exec_driver_sql("CREATE TABLE test_t1 (id INTEGER)")
+ conn.exec_driver_sql("CREATE TABLE test_t2 (id INTEGER)")
+ conn.exec_driver_sql("CREATE VIEW test_v AS SELECT id FROM test_t1")
+ conn.exec_driver_sql("DROP TABLE test_t1")
+
+ m = MetaData()
+ with expect_warnings(
+ "Skipping .* Table or view named .?test_v.? could not be "
+ "reflected: .* references invalid table"
+ ):
+ m.reflect(views=True, bind=conn)
+ eq_(m.tables["test_t2"].name, "test_t2")
+
+ assert_raises_message(
+ exc.UnreflectableTableError,
+ "references invalid table",
+ Table,
+ "test_v",
+ MetaData(),
+ autoload_with=conn,
+ )
@testing.exclude("mysql", "<", (5, 0, 0), "no information_schema support")
def test_system_views(self):
@@ -663,7 +659,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
):
Table("nn_t%d" % idx, meta) # to allow DROP
- with testing.db.connect() as c:
+ with testing.db.begin() as c:
c.exec_driver_sql(
"""
CREATE TABLE nn_t%d (
diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py
index aafad8dc1..9a2174a24 100644
--- a/test/dialect/oracle/test_dialect.py
+++ b/test/dialect/oracle/test_dialect.py
@@ -89,6 +89,8 @@ class DefaultSchemaNameTest(fixtures.TestBase):
eng = engines.testing_engine()
with eng.connect() as conn:
+
+ trans = conn.begin()
eq_(
testing.db.dialect._get_default_schema_name(conn),
default_schema_name,
@@ -104,6 +106,7 @@ class DefaultSchemaNameTest(fixtures.TestBase):
)
conn.invalidate()
+ trans.rollback()
eq_(
testing.db.dialect._get_default_schema_name(conn),
@@ -317,53 +320,51 @@ class ComputedReturningTest(fixtures.TablesTest):
implicit_returning=False,
)
- def test_computed_insert(self):
+ def test_computed_insert(self, connection):
test = self.tables.test
- with testing.db.connect() as conn:
- result = conn.execute(
- test.insert().return_defaults(), {"id": 1, "foo": 5}
- )
+ conn = connection
+ result = conn.execute(
+ test.insert().return_defaults(), {"id": 1, "foo": 5}
+ )
- eq_(result.returned_defaults, (47,))
+ eq_(result.returned_defaults, (47,))
- eq_(conn.scalar(select(test.c.bar)), 47)
+ eq_(conn.scalar(select(test.c.bar)), 47)
- def test_computed_update_warning(self):
+ def test_computed_update_warning(self, connection):
test = self.tables.test
- with testing.db.connect() as conn:
- conn.execute(test.insert(), {"id": 1, "foo": 5})
+ conn = connection
+ conn.execute(test.insert(), {"id": 1, "foo": 5})
- if testing.db.dialect._supports_update_returning_computed_cols:
+ if testing.db.dialect._supports_update_returning_computed_cols:
+ result = conn.execute(
+ test.update().values(foo=10).return_defaults()
+ )
+ eq_(result.returned_defaults, (52,))
+ else:
+ with testing.expect_warnings(
+ "Computed columns don't work with Oracle UPDATE"
+ ):
result = conn.execute(
test.update().values(foo=10).return_defaults()
)
- eq_(result.returned_defaults, (52,))
- else:
- with testing.expect_warnings(
- "Computed columns don't work with Oracle UPDATE"
- ):
- result = conn.execute(
- test.update().values(foo=10).return_defaults()
- )
- # returns the *old* value
- eq_(result.returned_defaults, (47,))
+ # returns the *old* value
+ eq_(result.returned_defaults, (47,))
- eq_(conn.scalar(select(test.c.bar)), 52)
+ eq_(conn.scalar(select(test.c.bar)), 52)
- def test_computed_update_no_warning(self):
+ def test_computed_update_no_warning(self, connection):
test = self.tables.test_no_returning
- with testing.db.connect() as conn:
- conn.execute(test.insert(), {"id": 1, "foo": 5})
+ conn = connection
+ conn.execute(test.insert(), {"id": 1, "foo": 5})
- result = conn.execute(
- test.update().values(foo=10).return_defaults()
- )
+ result = conn.execute(test.update().values(foo=10).return_defaults())
- # no returning
- eq_(result.returned_defaults, None)
+ # no returning
+ eq_(result.returned_defaults, None)
- eq_(conn.scalar(select(test.c.bar)), 52)
+ eq_(conn.scalar(select(test.c.bar)), 52)
class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
@@ -372,7 +373,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
@classmethod
def setup_class(cls):
- with testing.db.connect() as c:
+ with testing.db.begin() as c:
c.exec_driver_sql(
"""
create or replace procedure foo(x_in IN number, x_out OUT number,
@@ -404,7 +405,7 @@ end;
@classmethod
def teardown_class(cls):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(text("DROP PROCEDURE foo"))
@@ -674,7 +675,7 @@ class ExecuteTest(fixtures.TestBase):
seq.drop(connection)
@testing.provide_metadata
- def test_limit_offset_for_update(self):
+ def test_limit_offset_for_update(self, connection):
metadata = self.metadata
# oracle can't actually do the ROWNUM thing with FOR UPDATE
# very well.
@@ -685,19 +686,24 @@ class ExecuteTest(fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("data", Integer),
)
- metadata.create_all()
+ metadata.create_all(connection)
- t.insert().execute(
- {"id": 1, "data": 1},
- {"id": 2, "data": 7},
- {"id": 3, "data": 12},
- {"id": 4, "data": 15},
- {"id": 5, "data": 32},
+ connection.execute(
+ t.insert(),
+ [
+ {"id": 1, "data": 1},
+ {"id": 2, "data": 7},
+ {"id": 3, "data": 12},
+ {"id": 4, "data": 15},
+ {"id": 5, "data": 32},
+ ],
)
# here, we can't use ORDER BY.
eq_(
- t.select().with_for_update().limit(2).execute().fetchall(),
+ connection.execute(
+ t.select().with_for_update().limit(2)
+ ).fetchall(),
[(1, 1), (2, 7)],
)
@@ -706,7 +712,8 @@ class ExecuteTest(fixtures.TestBase):
assert_raises_message(
exc.DatabaseError,
"ORA-02014",
- t.select().with_for_update().limit(2).offset(3).execute,
+ connection.execute,
+ t.select().with_for_update().limit(2).offset(3),
)
diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py
index efa21fc1a..2e515556f 100644
--- a/test/dialect/oracle/test_reflection.py
+++ b/test/dialect/oracle/test_reflection.py
@@ -34,11 +34,6 @@ from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-def exec_sql(engine, sql, *args, **kwargs):
- with engine.connect() as conn:
- return conn.exec_driver_sql(sql, *args, **kwargs)
-
-
class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
__only_on__ = "oracle"
__backend__ = True
@@ -49,62 +44,64 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
# don't really know how else to go here unless
# we connect as the other user.
- for stmt in (
- """
-create table %(test_schema)s.parent(
- id integer primary key,
- data varchar2(50)
-);
-
-COMMENT ON TABLE %(test_schema)s.parent IS 'my table comment';
-
-create table %(test_schema)s.child(
- id integer primary key,
- data varchar2(50),
- parent_id integer references %(test_schema)s.parent(id)
-);
-
-create table local_table(
- id integer primary key,
- data varchar2(50)
-);
-
-create synonym %(test_schema)s.ptable for %(test_schema)s.parent;
-create synonym %(test_schema)s.ctable for %(test_schema)s.child;
-
-create synonym %(test_schema)s_pt for %(test_schema)s.parent;
-
-create synonym %(test_schema)s.local_table for local_table;
-
--- can't make a ref from local schema to the
--- remote schema's table without this,
--- *and* cant give yourself a grant !
--- so we give it to public. ideas welcome.
-grant references on %(test_schema)s.parent to public;
-grant references on %(test_schema)s.child to public;
-"""
- % {"test_schema": testing.config.test_schema}
- ).split(";"):
- if stmt.strip():
- exec_sql(testing.db, stmt)
+ with testing.db.begin() as conn:
+ for stmt in (
+ """
+ create table %(test_schema)s.parent(
+ id integer primary key,
+ data varchar2(50)
+ );
+
+ COMMENT ON TABLE %(test_schema)s.parent IS 'my table comment';
+
+ create table %(test_schema)s.child(
+ id integer primary key,
+ data varchar2(50),
+ parent_id integer references %(test_schema)s.parent(id)
+ );
+
+ create table local_table(
+ id integer primary key,
+ data varchar2(50)
+ );
+
+ create synonym %(test_schema)s.ptable for %(test_schema)s.parent;
+ create synonym %(test_schema)s.ctable for %(test_schema)s.child;
+
+ create synonym %(test_schema)s_pt for %(test_schema)s.parent;
+
+ create synonym %(test_schema)s.local_table for local_table;
+
+ -- can't make a ref from local schema to the
+ -- remote schema's table without this,
+ -- *and* cant give yourself a grant !
+ -- so we give it to public. ideas welcome.
+ grant references on %(test_schema)s.parent to public;
+ grant references on %(test_schema)s.child to public;
+ """
+ % {"test_schema": testing.config.test_schema}
+ ).split(";"):
+ if stmt.strip():
+ conn.exec_driver_sql(stmt)
@classmethod
def teardown_class(cls):
- for stmt in (
- """
-drop table %(test_schema)s.child;
-drop table %(test_schema)s.parent;
-drop table local_table;
-drop synonym %(test_schema)s.ctable;
-drop synonym %(test_schema)s.ptable;
-drop synonym %(test_schema)s_pt;
-drop synonym %(test_schema)s.local_table;
-
-"""
- % {"test_schema": testing.config.test_schema}
- ).split(";"):
- if stmt.strip():
- exec_sql(testing.db, stmt)
+ with testing.db.begin() as conn:
+ for stmt in (
+ """
+ drop table %(test_schema)s.child;
+ drop table %(test_schema)s.parent;
+ drop table local_table;
+ drop synonym %(test_schema)s.ctable;
+ drop synonym %(test_schema)s.ptable;
+ drop synonym %(test_schema)s_pt;
+ drop synonym %(test_schema)s.local_table;
+
+ """
+ % {"test_schema": testing.config.test_schema}
+ ).split(";"):
+ if stmt.strip():
+ conn.exec_driver_sql(stmt)
@testing.provide_metadata
def test_create_same_names_explicit_schema(self):
@@ -162,7 +159,7 @@ drop synonym %(test_schema)s.local_table;
)
@testing.provide_metadata
- def test_create_same_names_implicit_schema(self):
+ def test_create_same_names_implicit_schema(self, connection):
meta = self.metadata
parent = Table(
"parent", meta, Column("pid", Integer, primary_key=True)
@@ -173,10 +170,11 @@ drop synonym %(test_schema)s.local_table;
Column("cid", Integer, primary_key=True),
Column("pid", Integer, ForeignKey("parent.pid")),
)
- meta.create_all()
- parent.insert().execute({"pid": 1})
- child.insert().execute({"cid": 1, "pid": 1})
- eq_(child.select().execute().fetchall(), [(1, 1)])
+ meta.create_all(connection)
+
+ connection.execute(parent.insert(), {"pid": 1})
+ connection.execute(child.insert(), {"cid": 1, "pid": 1})
+ eq_(connection.execute(child.select()).fetchall(), [(1, 1)])
def test_reflect_alt_owner_explicit(self):
meta = MetaData()
@@ -238,9 +236,8 @@ drop synonym %(test_schema)s.local_table;
{"text": "my local comment"},
)
- def test_reflect_local_to_remote(self):
- exec_sql(
- testing.db,
+ def test_reflect_local_to_remote(self, connection):
+ connection.exec_driver_sql(
"CREATE TABLE localtable (id INTEGER "
"PRIMARY KEY, parent_id INTEGER REFERENCES "
"%(test_schema)s.parent(id))"
@@ -258,7 +255,7 @@ drop synonym %(test_schema)s.local_table;
% {"test_schema": testing.config.test_schema},
)
finally:
- exec_sql(testing.db, "DROP TABLE localtable")
+ connection.exec_driver_sql("DROP TABLE localtable")
def test_reflect_alt_owner_implicit(self):
meta = MetaData()
@@ -286,9 +283,8 @@ drop synonym %(test_schema)s.local_table;
select(parent, child).select_from(parent.join(child))
).fetchall()
- def test_reflect_alt_owner_synonyms(self):
- exec_sql(
- testing.db,
+ def test_reflect_alt_owner_synonyms(self, connection):
+ connection.exec_driver_sql(
"CREATE TABLE localtable (id INTEGER "
"PRIMARY KEY, parent_id INTEGER REFERENCES "
"%s.ptable(id))" % testing.config.test_schema,
@@ -298,7 +294,7 @@ drop synonym %(test_schema)s.local_table;
lcl = Table(
"localtable",
meta,
- autoload_with=testing.db,
+ autoload_with=connection,
oracle_resolve_synonyms=True,
)
parent = meta.tables["%s.ptable" % testing.config.test_schema]
@@ -309,12 +305,11 @@ drop synonym %(test_schema)s.local_table;
"localtable.parent_id"
% {"test_schema": testing.config.test_schema},
)
- with testing.db.connect() as conn:
- conn.execute(
- select(parent, lcl).select_from(parent.join(lcl))
- ).fetchall()
+ connection.execute(
+ select(parent, lcl).select_from(parent.join(lcl))
+ ).fetchall()
finally:
- exec_sql(testing.db, "DROP TABLE localtable")
+ connection.exec_driver_sql("DROP TABLE localtable")
def test_reflect_remote_synonyms(self):
meta = MetaData()
@@ -389,19 +384,20 @@ class SystemTableTablenamesTest(fixtures.TestBase):
__backend__ = True
def setup(self):
- exec_sql(testing.db, "create table my_table (id integer)")
- exec_sql(
- testing.db,
- "create global temporary table my_temp_table (id integer)",
- )
- exec_sql(
- testing.db, "create table foo_table (id integer) tablespace SYSTEM"
- )
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql("create table my_table (id integer)")
+ conn.exec_driver_sql(
+ "create global temporary table my_temp_table (id integer)",
+ )
+ conn.exec_driver_sql(
+ "create table foo_table (id integer) tablespace SYSTEM"
+ )
def teardown(self):
- exec_sql(testing.db, "drop table my_temp_table")
- exec_sql(testing.db, "drop table my_table")
- exec_sql(testing.db, "drop table foo_table")
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql("drop table my_temp_table")
+ conn.exec_driver_sql("drop table my_table")
+ conn.exec_driver_sql("drop table foo_table")
def test_table_names_no_system(self):
insp = inspect(testing.db)
@@ -430,24 +426,25 @@ class DontReflectIOTTest(fixtures.TestBase):
__backend__ = True
def setup(self):
- exec_sql(
- testing.db,
- """
- CREATE TABLE admin_docindex(
- token char(20),
- doc_id NUMBER,
- token_frequency NUMBER,
- token_offsets VARCHAR2(2000),
- CONSTRAINT pk_admin_docindex PRIMARY KEY (token, doc_id))
- ORGANIZATION INDEX
- TABLESPACE users
- PCTTHRESHOLD 20
- OVERFLOW TABLESPACE users
- """,
- )
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql(
+ """
+ CREATE TABLE admin_docindex(
+ token char(20),
+ doc_id NUMBER,
+ token_frequency NUMBER,
+ token_offsets VARCHAR2(2000),
+ CONSTRAINT pk_admin_docindex PRIMARY KEY (token, doc_id))
+ ORGANIZATION INDEX
+ TABLESPACE users
+ PCTTHRESHOLD 20
+ OVERFLOW TABLESPACE users
+ """,
+ )
def teardown(self):
- exec_sql(testing.db, "drop table admin_docindex")
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql("drop table admin_docindex")
def test_reflect_all(self):
m = MetaData(testing.db)
@@ -456,30 +453,24 @@ class DontReflectIOTTest(fixtures.TestBase):
def all_tables_compression_missing():
- try:
- exec_sql(testing.db, "SELECT compression FROM all_tables")
+ with testing.db.connect() as conn:
if (
"Enterprise Edition"
- not in exec_sql(testing.db, "select * from v$version").scalar()
+ not in conn.exec_driver_sql("select * from v$version").scalar()
# this works in Oracle Database 18c Express Edition Release
) and testing.db.dialect.server_version_info < (18,):
return True
return False
- except Exception:
- return True
def all_tables_compress_for_missing():
- try:
- exec_sql(testing.db, "SELECT compress_for FROM all_tables")
+ with testing.db.connect() as conn:
if (
"Enterprise Edition"
- not in exec_sql(testing.db, "select * from v$version").scalar()
+ not in conn.exec_driver_sql("select * from v$version").scalar()
):
return True
return False
- except Exception:
- return True
class TableReflectionTest(fixtures.TestBase):
@@ -748,7 +739,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
# note that the synonym here is still not totally functional
# when accessing via a different username as we do with the
# multiprocess test suite, so testing here is minimal
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql(
"create table test_table "
"(id integer primary key, data varchar2(50))"
@@ -760,7 +751,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
@classmethod
def teardown_class(cls):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql("drop synonym test_table_syn")
conn.exec_driver_sql("drop table test_table")
diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py
index 8fbf374ee..db3825d13 100644
--- a/test/dialect/oracle/test_types.py
+++ b/test/dialect/oracle/test_types.py
@@ -228,16 +228,16 @@ class TypesTest(fixtures.TestBase):
@testing.requires.returning
@testing.provide_metadata
- def test_int_not_float(self):
+ def test_int_not_float(self, connection):
m = self.metadata
t1 = Table("t1", m, Column("foo", Integer))
- t1.create()
- r = t1.insert().values(foo=5).returning(t1.c.foo).execute()
+ t1.create(connection)
+ r = connection.execute(t1.insert().values(foo=5).returning(t1.c.foo))
x = r.scalar()
assert x == 5
assert isinstance(x, int)
- x = t1.select().scalar()
+ x = connection.scalar(t1.select())
assert x == 5
assert isinstance(x, int)
@@ -281,7 +281,7 @@ class TypesTest(fixtures.TestBase):
eq_(conn.execute(s3).fetchall(), [(5, rowid)])
@testing.provide_metadata
- def test_interval(self):
+ def test_interval(self, connection):
metadata = self.metadata
interval_table = Table(
"intervaltable",
@@ -291,11 +291,12 @@ class TypesTest(fixtures.TestBase):
),
Column("day_interval", oracle.INTERVAL(day_precision=3)),
)
- metadata.create_all()
- interval_table.insert().execute(
- day_interval=datetime.timedelta(days=35, seconds=5743)
+ metadata.create_all(connection)
+ connection.execute(
+ interval_table.insert(),
+ dict(day_interval=datetime.timedelta(days=35, seconds=5743)),
)
- row = interval_table.select().execute().first()
+ row = connection.execute(interval_table.select()).first()
eq_(row["day_interval"], datetime.timedelta(days=35, seconds=5743))
@testing.provide_metadata
@@ -364,16 +365,19 @@ class TypesTest(fixtures.TestBase):
Column("intcol", Integer),
Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)),
)
- t1.create()
- t1.insert().execute(
+ t1.create(connection)
+ connection.execute(
+ t1.insert(),
[
dict(intcol=1, numericcol=float("inf")),
dict(intcol=2, numericcol=float("-inf")),
- ]
+ ],
)
eq_(
- select(t1.c.numericcol).order_by(t1.c.intcol).execute().fetchall(),
+ connection.execute(
+ select(t1.c.numericcol).order_by(t1.c.intcol)
+ ).fetchall(),
[(float("inf"),), (float("-inf"),)],
)
@@ -393,16 +397,19 @@ class TypesTest(fixtures.TestBase):
Column("intcol", Integer),
Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True)),
)
- t1.create()
- t1.insert().execute(
+ t1.create(connection)
+ connection.execute(
+ t1.insert(),
[
dict(intcol=1, numericcol=decimal.Decimal("Infinity")),
dict(intcol=2, numericcol=decimal.Decimal("-Infinity")),
- ]
+ ],
)
eq_(
- select(t1.c.numericcol).order_by(t1.c.intcol).execute().fetchall(),
+ connection.execute(
+ select(t1.c.numericcol).order_by(t1.c.intcol)
+ ).fetchall(),
[(decimal.Decimal("Infinity"),), (decimal.Decimal("-Infinity"),)],
)
@@ -422,20 +429,21 @@ class TypesTest(fixtures.TestBase):
Column("intcol", Integer),
Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)),
)
- t1.create()
- t1.insert().execute(
+ t1.create(connection)
+ connection.execute(
+ t1.insert(),
[
dict(intcol=1, numericcol=float("nan")),
dict(intcol=2, numericcol=float("-nan")),
- ]
+ ],
)
eq_(
[
tuple(str(col) for col in row)
- for row in select(t1.c.numericcol)
- .order_by(t1.c.intcol)
- .execute()
+ for row in connection.execute(
+ select(t1.c.numericcol).order_by(t1.c.intcol)
+ )
],
[("nan",), ("nan",)],
)
@@ -786,7 +794,7 @@ class TypesTest(fixtures.TestBase):
eq_(connection.execute(raw_table.select()).first(), (1, b("ABCDEF")))
@testing.provide_metadata
- def test_reflect_nvarchar(self):
+ def test_reflect_nvarchar(self, connection):
metadata = self.metadata
Table(
"tnv",
@@ -794,31 +802,30 @@ class TypesTest(fixtures.TestBase):
Column("nv_data", sqltypes.NVARCHAR(255)),
Column("c_data", sqltypes.NCHAR(20)),
)
- metadata.create_all()
+ metadata.create_all(connection)
m2 = MetaData()
- t2 = Table("tnv", m2, autoload_with=testing.db)
+ t2 = Table("tnv", m2, autoload_with=connection)
assert isinstance(t2.c.nv_data.type, sqltypes.NVARCHAR)
assert isinstance(t2.c.c_data.type, sqltypes.NCHAR)
if testing.against("oracle+cx_oracle"):
assert isinstance(
- t2.c.nv_data.type.dialect_impl(testing.db.dialect),
+ t2.c.nv_data.type.dialect_impl(connection.dialect),
cx_oracle._OracleUnicodeStringNCHAR,
)
assert isinstance(
- t2.c.c_data.type.dialect_impl(testing.db.dialect),
+ t2.c.c_data.type.dialect_impl(connection.dialect),
cx_oracle._OracleNChar,
)
data = u("m’a réveillé.")
- with testing.db.connect() as conn:
- conn.execute(t2.insert(), dict(nv_data=data, c_data=data))
- nv_data, c_data = conn.execute(t2.select()).first()
- eq_(nv_data, data)
- eq_(c_data, data + (" " * 7)) # char is space padded
- assert isinstance(nv_data, util.text_type)
- assert isinstance(c_data, util.text_type)
+ connection.execute(t2.insert(), dict(nv_data=data, c_data=data))
+ nv_data, c_data = connection.execute(t2.select()).first()
+ eq_(nv_data, data)
+ eq_(c_data, data + (" " * 7)) # char is space padded
+ assert isinstance(nv_data, util.text_type)
+ assert isinstance(c_data, util.text_type)
@testing.provide_metadata
def test_reflect_unicode_no_nvarchar(self):
@@ -1183,7 +1190,7 @@ class SetInputSizesTest(fixtures.TestBase):
else:
engine = testing.db
- with engine.connect() as conn:
+ with engine.begin() as conn:
connection_fairy = conn.connection
for tab in [t1, t2, t3]:
with mock.patch.object(
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index 5cea604d6..3bd8e9da0 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -36,6 +36,7 @@ from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_VALUES
from sqlalchemy.engine import cursor as _cursor
from sqlalchemy.engine import engine_from_config
from sqlalchemy.engine import url
+from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
@@ -51,7 +52,7 @@ from sqlalchemy.testing.assertions import eq_regex
from sqlalchemy.testing.assertions import ne_
from sqlalchemy.util import u
from sqlalchemy.util import ue
-from ...engine import test_execute
+from ...engine import test_deprecations
if True:
from sqlalchemy.dialects.postgresql.psycopg2 import (
@@ -195,6 +196,20 @@ class ExecuteManyMode(object):
options = None
+ @config.fixture()
+ def connection(self):
+ eng = engines.testing_engine(options=self.options)
+
+ conn = eng.connect()
+ trans = conn.begin()
+ try:
+ yield conn
+ finally:
+ if trans.is_active:
+ trans.rollback()
+ conn.close()
+ eng.dispose()
+
@classmethod
def define_tables(cls, metadata):
Table(
@@ -213,20 +228,12 @@ class ExecuteManyMode(object):
Column(ue("\u6e2c\u8a66"), Integer),
)
- def setup(self):
- super(ExecuteManyMode, self).setup()
- self.engine = engines.testing_engine(options=self.options)
-
- def teardown(self):
- self.engine.dispose()
- super(ExecuteManyMode, self).teardown()
-
- def test_insert(self):
+ def test_insert(self, connection):
from psycopg2 import extras
- values_page_size = self.engine.dialect.executemany_values_page_size
- batch_page_size = self.engine.dialect.executemany_batch_page_size
- if self.engine.dialect.executemany_mode & EXECUTEMANY_VALUES:
+ values_page_size = connection.dialect.executemany_values_page_size
+ batch_page_size = connection.dialect.executemany_batch_page_size
+ if connection.dialect.executemany_mode & EXECUTEMANY_VALUES:
meth = extras.execute_values
stmt = "INSERT INTO data (x, y) VALUES %s"
expected_kwargs = {
@@ -234,7 +241,7 @@ class ExecuteManyMode(object):
"page_size": values_page_size,
"fetch": False,
}
- elif self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH:
+ elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
meth = extras.execute_batch
stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)"
expected_kwargs = {"page_size": batch_page_size}
@@ -244,24 +251,23 @@ class ExecuteManyMode(object):
with mock.patch.object(
extras, meth.__name__, side_effect=meth
) as mock_exec:
- with self.engine.connect() as conn:
- conn.execute(
- self.tables.data.insert(),
- [
- {"x": "x1", "y": "y1"},
- {"x": "x2", "y": "y2"},
- {"x": "x3", "y": "y3"},
- ],
- )
+ connection.execute(
+ self.tables.data.insert(),
+ [
+ {"x": "x1", "y": "y1"},
+ {"x": "x2", "y": "y2"},
+ {"x": "x3", "y": "y3"},
+ ],
+ )
- eq_(
- conn.execute(select(self.tables.data)).fetchall(),
- [
- (1, "x1", "y1", 5),
- (2, "x2", "y2", 5),
- (3, "x3", "y3", 5),
- ],
- )
+ eq_(
+ connection.execute(select(self.tables.data)).fetchall(),
+ [
+ (1, "x1", "y1", 5),
+ (2, "x2", "y2", 5),
+ (3, "x3", "y3", 5),
+ ],
+ )
eq_(
mock_exec.mock_calls,
[
@@ -278,14 +284,13 @@ class ExecuteManyMode(object):
],
)
- def test_insert_no_page_size(self):
+ def test_insert_no_page_size(self, connection):
from psycopg2 import extras
- values_page_size = self.engine.dialect.executemany_values_page_size
- batch_page_size = self.engine.dialect.executemany_batch_page_size
+ values_page_size = connection.dialect.executemany_values_page_size
+ batch_page_size = connection.dialect.executemany_batch_page_size
- eng = self.engine
- if eng.dialect.executemany_mode & EXECUTEMANY_VALUES:
+ if connection.dialect.executemany_mode & EXECUTEMANY_VALUES:
meth = extras.execute_values
stmt = "INSERT INTO data (x, y) VALUES %s"
expected_kwargs = {
@@ -293,7 +298,7 @@ class ExecuteManyMode(object):
"page_size": values_page_size,
"fetch": False,
}
- elif eng.dialect.executemany_mode & EXECUTEMANY_BATCH:
+ elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
meth = extras.execute_batch
stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)"
expected_kwargs = {"page_size": batch_page_size}
@@ -303,15 +308,14 @@ class ExecuteManyMode(object):
with mock.patch.object(
extras, meth.__name__, side_effect=meth
) as mock_exec:
- with eng.connect() as conn:
- conn.execute(
- self.tables.data.insert(),
- [
- {"x": "x1", "y": "y1"},
- {"x": "x2", "y": "y2"},
- {"x": "x3", "y": "y3"},
- ],
- )
+ connection.execute(
+ self.tables.data.insert(),
+ [
+ {"x": "x1", "y": "y1"},
+ {"x": "x2", "y": "y2"},
+ {"x": "x3", "y": "y3"},
+ ],
+ )
eq_(
mock_exec.mock_calls,
@@ -356,7 +360,7 @@ class ExecuteManyMode(object):
with mock.patch.object(
extras, meth.__name__, side_effect=meth
) as mock_exec:
- with eng.connect() as conn:
+ with eng.begin() as conn:
conn.execute(
self.tables.data.insert(),
[
@@ -398,11 +402,10 @@ class ExecuteManyMode(object):
eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)])
- def test_update_fallback(self):
+ def test_update_fallback(self, connection):
from psycopg2 import extras
- batch_page_size = self.engine.dialect.executemany_batch_page_size
- eng = self.engine
+ batch_page_size = connection.dialect.executemany_batch_page_size
meth = extras.execute_batch
stmt = "UPDATE data SET y=%(yval)s WHERE data.x = %(xval)s"
expected_kwargs = {"page_size": batch_page_size}
@@ -410,18 +413,17 @@ class ExecuteManyMode(object):
with mock.patch.object(
extras, meth.__name__, side_effect=meth
) as mock_exec:
- with eng.connect() as conn:
- conn.execute(
- self.tables.data.update()
- .where(self.tables.data.c.x == bindparam("xval"))
- .values(y=bindparam("yval")),
- [
- {"xval": "x1", "yval": "y5"},
- {"xval": "x3", "yval": "y6"},
- ],
- )
+ connection.execute(
+ self.tables.data.update()
+ .where(self.tables.data.c.x == bindparam("xval"))
+ .values(y=bindparam("yval")),
+ [
+ {"xval": "x1", "yval": "y5"},
+ {"xval": "x3", "yval": "y6"},
+ ],
+ )
- if eng.dialect.executemany_mode & EXECUTEMANY_BATCH:
+ if connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
eq_(
mock_exec.mock_calls,
[
@@ -439,36 +441,34 @@ class ExecuteManyMode(object):
else:
eq_(mock_exec.mock_calls, [])
- def test_not_sane_rowcount(self):
- self.engine.connect().close()
- if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH:
- assert not self.engine.dialect.supports_sane_multi_rowcount
+ def test_not_sane_rowcount(self, connection):
+ if connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
+ assert not connection.dialect.supports_sane_multi_rowcount
else:
- assert self.engine.dialect.supports_sane_multi_rowcount
+ assert connection.dialect.supports_sane_multi_rowcount
- def test_update(self):
- with self.engine.connect() as conn:
- conn.execute(
- self.tables.data.insert(),
- [
- {"x": "x1", "y": "y1"},
- {"x": "x2", "y": "y2"},
- {"x": "x3", "y": "y3"},
- ],
- )
+ def test_update(self, connection):
+ connection.execute(
+ self.tables.data.insert(),
+ [
+ {"x": "x1", "y": "y1"},
+ {"x": "x2", "y": "y2"},
+ {"x": "x3", "y": "y3"},
+ ],
+ )
- conn.execute(
- self.tables.data.update()
- .where(self.tables.data.c.x == bindparam("xval"))
- .values(y=bindparam("yval")),
- [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}],
- )
- eq_(
- conn.execute(
- select(self.tables.data).order_by(self.tables.data.c.id)
- ).fetchall(),
- [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)],
- )
+ connection.execute(
+ self.tables.data.update()
+ .where(self.tables.data.c.x == bindparam("xval"))
+ .values(y=bindparam("yval")),
+ [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}],
+ )
+ eq_(
+ connection.execute(
+ select(self.tables.data).order_by(self.tables.data.c.id)
+ ).fetchall(),
+ [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)],
+ )
class ExecutemanyBatchModeTest(ExecuteManyMode, fixtures.TablesTest):
@@ -578,7 +578,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
[(pk,) for pk in range(1 + first_pk, total_rows + first_pk)],
)
- def test_insert_w_newlines(self):
+ def test_insert_w_newlines(self, connection):
from psycopg2 import extras
t = self.tables.data
@@ -606,15 +606,14 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
extras, "execute_values", side_effect=meth
) as mock_exec:
- with self.engine.connect() as conn:
- conn.execute(
- ins,
- [
- {"id": 1, "y": "y1", "z": 1},
- {"id": 2, "y": "y2", "z": 2},
- {"id": 3, "y": "y3", "z": 3},
- ],
- )
+ connection.execute(
+ ins,
+ [
+ {"id": 1, "y": "y1", "z": 1},
+ {"id": 2, "y": "y2", "z": 2},
+ {"id": 3, "y": "y3", "z": 3},
+ ],
+ )
eq_(
mock_exec.mock_calls,
@@ -629,12 +628,12 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
),
template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)",
fetch=False,
- page_size=conn.dialect.executemany_values_page_size,
+ page_size=connection.dialect.executemany_values_page_size,
)
],
)
- def test_insert_modified_by_event(self):
+ def test_insert_modified_by_event(self, connection):
from psycopg2 import extras
t = self.tables.data
@@ -664,33 +663,33 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
extras, "execute_batch", side_effect=meth
) as mock_batch:
- with self.engine.connect() as conn:
-
- # create an event hook that will change the statement to
- # something else, meaning the dialect has to detect that
- # insert_single_values_expr is no longer useful
- @event.listens_for(conn, "before_cursor_execute", retval=True)
- def before_cursor_execute(
- conn, cursor, statement, parameters, context, executemany
- ):
- statement = (
- "INSERT INTO data (id, y, z) VALUES "
- "(%(id)s, %(y)s, %(z)s)"
- )
- return statement, parameters
-
- conn.execute(
- ins,
- [
- {"id": 1, "y": "y1", "z": 1},
- {"id": 2, "y": "y2", "z": 2},
- {"id": 3, "y": "y3", "z": 3},
- ],
+ # create an event hook that will change the statement to
+ # something else, meaning the dialect has to detect that
+ # insert_single_values_expr is no longer useful
+ @event.listens_for(
+ connection, "before_cursor_execute", retval=True
+ )
+ def before_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ statement = (
+ "INSERT INTO data (id, y, z) VALUES "
+ "(%(id)s, %(y)s, %(z)s)"
)
+ return statement, parameters
+
+ connection.execute(
+ ins,
+ [
+ {"id": 1, "y": "y1", "z": 1},
+ {"id": 2, "y": "y2", "z": 2},
+ {"id": 3, "y": "y3", "z": 3},
+ ],
+ )
eq_(mock_values.mock_calls, [])
- if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH:
+ if connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
eq_(
mock_batch.mock_calls,
[
@@ -727,10 +726,10 @@ class ExecutemanyFlagOptionsTest(fixtures.TablesTest):
("values_only", EXECUTEMANY_VALUES),
("values_plus_batch", EXECUTEMANY_VALUES_PLUS_BATCH),
]:
- self.engine = engines.testing_engine(
+ connection = engines.testing_engine(
options={"executemany_mode": opt}
)
- is_(self.engine.dialect.executemany_mode, expected)
+ is_(connection.dialect.executemany_mode, expected)
def test_executemany_wrong_flag_options(self):
for opt in [1, True, "batch_insert"]:
@@ -1082,7 +1081,7 @@ $$ LANGUAGE plpgsql;
t.create(connection, checkfirst=True)
@testing.provide_metadata
- def test_schema_roundtrips(self):
+ def test_schema_roundtrips(self, connection):
meta = self.metadata
users = Table(
"users",
@@ -1091,33 +1090,37 @@ $$ LANGUAGE plpgsql;
Column("name", String(50)),
schema="test_schema",
)
- users.create()
- users.insert().execute(id=1, name="name1")
- users.insert().execute(id=2, name="name2")
- users.insert().execute(id=3, name="name3")
- users.insert().execute(id=4, name="name4")
+ users.create(connection)
+ connection.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=2, name="name2"))
+ connection.execute(users.insert(), dict(id=3, name="name3"))
+ connection.execute(users.insert(), dict(id=4, name="name4"))
eq_(
- users.select().where(users.c.name == "name2").execute().fetchall(),
+ connection.execute(
+ users.select().where(users.c.name == "name2")
+ ).fetchall(),
[(2, "name2")],
)
eq_(
- users.select(use_labels=True)
- .where(users.c.name == "name2")
- .execute()
- .fetchall(),
+ connection.execute(
+ users.select().apply_labels().where(users.c.name == "name2")
+ ).fetchall(),
[(2, "name2")],
)
- users.delete().where(users.c.id == 3).execute()
+ connection.execute(users.delete().where(users.c.id == 3))
eq_(
- users.select().where(users.c.name == "name3").execute().fetchall(),
+ connection.execute(
+ users.select().where(users.c.name == "name3")
+ ).fetchall(),
[],
)
- users.update().where(users.c.name == "name4").execute(name="newname")
+ connection.execute(
+ users.update().where(users.c.name == "name4"), dict(name="newname")
+ )
eq_(
- users.select(use_labels=True)
- .where(users.c.id == 4)
- .execute()
- .fetchall(),
+ connection.execute(
+ users.select().apply_labels().where(users.c.id == 4)
+ ).fetchall(),
[(4, "newname")],
)
@@ -1233,7 +1236,7 @@ $$ LANGUAGE plpgsql;
ne_(conn.connection.status, STATUS_IN_TRANSACTION)
-class AutocommitTextTest(test_execute.AutocommitTextTest):
+class AutocommitTextTest(test_deprecations.AutocommitTextTest):
__only_on__ = "postgresql"
def test_grant(self):
diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py
index 760487842..4e96cc6a2 100644
--- a/test/dialect/postgresql/test_on_conflict.py
+++ b/test/dialect/postgresql/test_on_conflict.py
@@ -99,28 +99,29 @@ class OnConflictTest(fixtures.TablesTest):
ValueError, insert(self.tables.users).on_conflict_do_update
)
- def test_on_conflict_do_nothing(self):
+ def test_on_conflict_do_nothing(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- result = conn.execute(
- insert(users).on_conflict_do_nothing(),
- dict(id=1, name="name1"),
- )
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
-
- result = conn.execute(
- insert(users).on_conflict_do_nothing(),
- dict(id=1, name="name2"),
- )
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
-
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name1")],
- )
+ result = connection.execute(
+ insert(users).on_conflict_do_nothing(),
+ dict(id=1, name="name1"),
+ )
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ result = connection.execute(
+ insert(users).on_conflict_do_nothing(),
+ dict(id=1, name="name2"),
+ )
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name1")],
+ )
def test_on_conflict_do_nothing_connectionless(self, connection):
users = self.tables.users_xtra
@@ -147,95 +148,99 @@ class OnConflictTest(fixtures.TablesTest):
)
@testing.provide_metadata
- def test_on_conflict_do_nothing_target(self):
+ def test_on_conflict_do_nothing_target(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- result = conn.execute(
- insert(users).on_conflict_do_nothing(
- index_elements=users.primary_key.columns
- ),
- dict(id=1, name="name1"),
- )
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
-
- result = conn.execute(
- insert(users).on_conflict_do_nothing(
- index_elements=users.primary_key.columns
- ),
- dict(id=1, name="name2"),
- )
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
-
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name1")],
- )
-
- def test_on_conflict_do_update_one(self):
+ result = connection.execute(
+ insert(users).on_conflict_do_nothing(
+ index_elements=users.primary_key.columns
+ ),
+ dict(id=1, name="name1"),
+ )
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ result = connection.execute(
+ insert(users).on_conflict_do_nothing(
+ index_elements=users.primary_key.columns
+ ),
+ dict(id=1, name="name2"),
+ )
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name1")],
+ )
+
+ def test_on_conflict_do_update_one(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.id], set_=dict(name=i.excluded.name)
- )
- result = conn.execute(i, dict(id=1, name="name1"))
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+ )
+ result = connection.execute(i, dict(id=1, name="name1"))
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name1")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name1")],
+ )
- def test_on_conflict_do_update_schema(self):
+ def test_on_conflict_do_update_schema(self, connection):
users = self.tables.get("%s.users_schema" % config.test_schema)
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.id], set_=dict(name=i.excluded.name)
- )
- result = conn.execute(i, dict(id=1, name="name1"))
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+ )
+ result = connection.execute(i, dict(id=1, name="name1"))
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name1")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name1")],
+ )
- def test_on_conflict_do_update_column_as_key_set(self):
+ def test_on_conflict_do_update_column_as_key_set(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.id],
- set_={users.c.name: i.excluded.name},
- )
- result = conn.execute(i, dict(id=1, name="name1"))
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: i.excluded.name},
+ )
+ result = connection.execute(i, dict(id=1, name="name1"))
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name1")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name1")],
+ )
- def test_on_conflict_do_update_clauseelem_as_key_set(self):
+ def test_on_conflict_do_update_clauseelem_as_key_set(self, connection):
users = self.tables.users
class MyElem(object):
@@ -245,162 +250,165 @@ class OnConflictTest(fixtures.TablesTest):
def __clause_element__(self):
return self.expr
- with testing.db.connect() as conn:
- conn.execute(
- users.insert(),
- {"id": 1, "name": "name1"},
- )
+ connection.execute(
+ users.insert(),
+ {"id": 1, "name": "name1"},
+ )
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.id],
- set_={MyElem(users.c.name): i.excluded.name},
- ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"})
- result = conn.execute(i)
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={MyElem(users.c.name): i.excluded.name},
+ ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"})
+ result = connection.execute(i)
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name1")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name1")],
+ )
- def test_on_conflict_do_update_column_as_key_set_schema(self):
+ def test_on_conflict_do_update_column_as_key_set_schema(self, connection):
users = self.tables.get("%s.users_schema" % config.test_schema)
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.id],
- set_={users.c.name: i.excluded.name},
- )
- result = conn.execute(i, dict(id=1, name="name1"))
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: i.excluded.name},
+ )
+ result = connection.execute(i, dict(id=1, name="name1"))
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name1")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name1")],
+ )
- def test_on_conflict_do_update_two(self):
+ def test_on_conflict_do_update_two(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.id],
- set_=dict(id=i.excluded.id, name=i.excluded.name),
- )
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_=dict(id=i.excluded.id, name=i.excluded.name),
+ )
- result = conn.execute(i, dict(id=1, name="name2"))
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ result = connection.execute(i, dict(id=1, name="name2"))
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name2")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name2")],
+ )
- def test_on_conflict_do_update_three(self):
+ def test_on_conflict_do_update_three(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=users.primary_key.columns,
- set_=dict(name=i.excluded.name),
- )
- result = conn.execute(i, dict(id=1, name="name3"))
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=users.primary_key.columns,
+ set_=dict(name=i.excluded.name),
+ )
+ result = connection.execute(i, dict(id=1, name="name3"))
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name3")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name3")],
+ )
- def test_on_conflict_do_update_four(self):
+ def test_on_conflict_do_update_four(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=users.primary_key.columns,
- set_=dict(id=i.excluded.id, name=i.excluded.name),
- ).values(id=1, name="name4")
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=users.primary_key.columns,
+ set_=dict(id=i.excluded.id, name=i.excluded.name),
+ ).values(id=1, name="name4")
- result = conn.execute(i)
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ result = connection.execute(i)
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name4")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name4")],
+ )
- def test_on_conflict_do_update_five(self):
+ def test_on_conflict_do_update_five(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=1, name="name1"))
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=users.primary_key.columns,
- set_=dict(id=10, name="I'm a name"),
- ).values(id=1, name="name4")
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=users.primary_key.columns,
+ set_=dict(id=10, name="I'm a name"),
+ ).values(id=1, name="name4")
- result = conn.execute(i)
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
+ result = connection.execute(i)
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
- eq_(
- conn.execute(
- users.select().where(users.c.id == 10)
- ).fetchall(),
- [(10, "I'm a name")],
- )
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 10)
+ ).fetchall(),
+ [(10, "I'm a name")],
+ )
- def test_on_conflict_do_update_multivalues(self):
+ def test_on_conflict_do_update_multivalues(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(id=1, name="name1"))
- conn.execute(users.insert(), dict(id=2, name="name2"))
-
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=users.primary_key.columns,
- set_=dict(name="updated"),
- where=(i.excluded.name != "name12"),
- ).values(
- [
- dict(id=1, name="name11"),
- dict(id=2, name="name12"),
- dict(id=3, name="name13"),
- dict(id=4, name="name14"),
- ]
- )
-
- result = conn.execute(i)
- eq_(result.inserted_primary_key, (None,))
- eq_(result.returned_defaults, None)
-
- eq_(
- conn.execute(users.select().order_by(users.c.id)).fetchall(),
- [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")],
- )
+ connection.execute(users.insert(), dict(id=1, name="name1"))
+ connection.execute(users.insert(), dict(id=2, name="name2"))
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=users.primary_key.columns,
+ set_=dict(name="updated"),
+ where=(i.excluded.name != "name12"),
+ ).values(
+ [
+ dict(id=1, name="name11"),
+ dict(id=2, name="name12"),
+ dict(id=3, name="name13"),
+ dict(id=4, name="name14"),
+ ]
+ )
+
+ result = connection.execute(i)
+ eq_(result.inserted_primary_key, (None,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ connection.execute(users.select().order_by(users.c.id)).fetchall(),
+ [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")],
+ )
def _exotic_targets_fixture(self, conn):
users = self.tables.users_xtra
@@ -429,260 +437,250 @@ class OnConflictTest(fixtures.TablesTest):
[(1, "name1", "name1@gmail.com", "not")],
)
- def test_on_conflict_do_update_exotic_targets_two(self):
+ def test_on_conflict_do_update_exotic_targets_two(self, connection):
users = self.tables.users_xtra
- with testing.db.connect() as conn:
- self._exotic_targets_fixture(conn)
- # try primary key constraint: cause an upsert on unique id column
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=users.primary_key.columns,
- set_=dict(
- name=i.excluded.name, login_email=i.excluded.login_email
- ),
- )
- result = conn.execute(
- i,
- dict(
- id=1,
- name="name2",
- login_email="name1@gmail.com",
- lets_index_this="not",
- ),
- )
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, None)
-
- eq_(
- conn.execute(users.select().where(users.c.id == 1)).fetchall(),
- [(1, "name2", "name1@gmail.com", "not")],
- )
-
- def test_on_conflict_do_update_exotic_targets_three(self):
+ self._exotic_targets_fixture(connection)
+ # try primary key constraint: cause an upsert on unique id column
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=users.primary_key.columns,
+ set_=dict(
+ name=i.excluded.name, login_email=i.excluded.login_email
+ ),
+ )
+ result = connection.execute(
+ i,
+ dict(
+ id=1,
+ name="name2",
+ login_email="name1@gmail.com",
+ lets_index_this="not",
+ ),
+ )
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ connection.execute(
+ users.select().where(users.c.id == 1)
+ ).fetchall(),
+ [(1, "name2", "name1@gmail.com", "not")],
+ )
+
+ def test_on_conflict_do_update_exotic_targets_three(self, connection):
users = self.tables.users_xtra
- with testing.db.connect() as conn:
- self._exotic_targets_fixture(conn)
- # try unique constraint: cause an upsert on target
- # login_email, not id
- i = insert(users)
- i = i.on_conflict_do_update(
- constraint=self.unique_constraint,
- set_=dict(
- id=i.excluded.id,
- name=i.excluded.name,
- login_email=i.excluded.login_email,
- ),
- )
- # note: lets_index_this value totally ignored in SET clause.
- result = conn.execute(
- i,
- dict(
- id=42,
- name="nameunique",
- login_email="name2@gmail.com",
- lets_index_this="unique",
- ),
- )
- eq_(result.inserted_primary_key, (42,))
- eq_(result.returned_defaults, None)
-
- eq_(
- conn.execute(
- users.select().where(
- users.c.login_email == "name2@gmail.com"
- )
- ).fetchall(),
- [(42, "nameunique", "name2@gmail.com", "not")],
- )
-
- def test_on_conflict_do_update_exotic_targets_four(self):
+ self._exotic_targets_fixture(connection)
+ # try unique constraint: cause an upsert on target
+ # login_email, not id
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ constraint=self.unique_constraint,
+ set_=dict(
+ id=i.excluded.id,
+ name=i.excluded.name,
+ login_email=i.excluded.login_email,
+ ),
+ )
+ # note: lets_index_this value totally ignored in SET clause.
+ result = connection.execute(
+ i,
+ dict(
+ id=42,
+ name="nameunique",
+ login_email="name2@gmail.com",
+ lets_index_this="unique",
+ ),
+ )
+ eq_(result.inserted_primary_key, (42,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ connection.execute(
+ users.select().where(users.c.login_email == "name2@gmail.com")
+ ).fetchall(),
+ [(42, "nameunique", "name2@gmail.com", "not")],
+ )
+
+ def test_on_conflict_do_update_exotic_targets_four(self, connection):
users = self.tables.users_xtra
- with testing.db.connect() as conn:
- self._exotic_targets_fixture(conn)
- # try unique constraint by name: cause an
- # upsert on target login_email, not id
- i = insert(users)
- i = i.on_conflict_do_update(
- constraint=self.unique_constraint.name,
- set_=dict(
- id=i.excluded.id,
- name=i.excluded.name,
- login_email=i.excluded.login_email,
- ),
- )
- # note: lets_index_this value totally ignored in SET clause.
-
- result = conn.execute(
- i,
- dict(
- id=43,
- name="nameunique2",
- login_email="name2@gmail.com",
- lets_index_this="unique",
- ),
- )
- eq_(result.inserted_primary_key, (43,))
- eq_(result.returned_defaults, None)
-
- eq_(
- conn.execute(
- users.select().where(
- users.c.login_email == "name2@gmail.com"
- )
- ).fetchall(),
- [(43, "nameunique2", "name2@gmail.com", "not")],
- )
-
- def test_on_conflict_do_update_exotic_targets_four_no_pk(self):
+ self._exotic_targets_fixture(connection)
+ # try unique constraint by name: cause an
+ # upsert on target login_email, not id
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ constraint=self.unique_constraint.name,
+ set_=dict(
+ id=i.excluded.id,
+ name=i.excluded.name,
+ login_email=i.excluded.login_email,
+ ),
+ )
+ # note: lets_index_this value totally ignored in SET clause.
+
+ result = connection.execute(
+ i,
+ dict(
+ id=43,
+ name="nameunique2",
+ login_email="name2@gmail.com",
+ lets_index_this="unique",
+ ),
+ )
+ eq_(result.inserted_primary_key, (43,))
+ eq_(result.returned_defaults, None)
+
+ eq_(
+ connection.execute(
+ users.select().where(users.c.login_email == "name2@gmail.com")
+ ).fetchall(),
+ [(43, "nameunique2", "name2@gmail.com", "not")],
+ )
+
+ def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection):
users = self.tables.users_xtra
- with testing.db.connect() as conn:
- self._exotic_targets_fixture(conn)
- # try unique constraint by name: cause an
- # upsert on target login_email, not id
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.login_email],
- set_=dict(
- id=i.excluded.id,
- name=i.excluded.name,
- login_email=i.excluded.login_email,
- ),
- )
-
- result = conn.execute(
- i, dict(name="name3", login_email="name1@gmail.com")
- )
- eq_(result.inserted_primary_key, (1,))
- eq_(result.returned_defaults, (1,))
-
- eq_(
- conn.execute(users.select().order_by(users.c.id)).fetchall(),
- [
- (1, "name3", "name1@gmail.com", "not"),
- (2, "name2", "name2@gmail.com", "not"),
- ],
- )
-
- def test_on_conflict_do_update_exotic_targets_five(self):
+ self._exotic_targets_fixture(connection)
+ # try unique constraint by name: cause an
+ # upsert on target login_email, not id
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.login_email],
+ set_=dict(
+ id=i.excluded.id,
+ name=i.excluded.name,
+ login_email=i.excluded.login_email,
+ ),
+ )
+
+ result = connection.execute(
+ i, dict(name="name3", login_email="name1@gmail.com")
+ )
+ eq_(result.inserted_primary_key, (1,))
+ eq_(result.returned_defaults, (1,))
+
+ eq_(
+ connection.execute(users.select().order_by(users.c.id)).fetchall(),
+ [
+ (1, "name3", "name1@gmail.com", "not"),
+ (2, "name2", "name2@gmail.com", "not"),
+ ],
+ )
+
+ def test_on_conflict_do_update_exotic_targets_five(self, connection):
users = self.tables.users_xtra
- with testing.db.connect() as conn:
- self._exotic_targets_fixture(conn)
- # try bogus index
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=self.bogus_index.columns,
- index_where=self.bogus_index.dialect_options["postgresql"][
- "where"
- ],
- set_=dict(
- name=i.excluded.name, login_email=i.excluded.login_email
- ),
- )
-
- assert_raises(
- exc.ProgrammingError,
- conn.execute,
- i,
- dict(
- id=1,
- name="namebogus",
- login_email="bogus@gmail.com",
- lets_index_this="bogus",
- ),
- )
-
- def test_on_conflict_do_update_exotic_targets_six(self):
+ self._exotic_targets_fixture(connection)
+ # try bogus index
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=self.bogus_index.columns,
+ index_where=self.bogus_index.dialect_options["postgresql"][
+ "where"
+ ],
+ set_=dict(
+ name=i.excluded.name, login_email=i.excluded.login_email
+ ),
+ )
+
+ assert_raises(
+ exc.ProgrammingError,
+ connection.execute,
+ i,
+ dict(
+ id=1,
+ name="namebogus",
+ login_email="bogus@gmail.com",
+ lets_index_this="bogus",
+ ),
+ )
+
+ def test_on_conflict_do_update_exotic_targets_six(self, connection):
users = self.tables.users_xtra
- with testing.db.connect() as conn:
- conn.execute(
- insert(users),
+ connection.execute(
+ insert(users),
+ dict(
+ id=1,
+ name="name1",
+ login_email="mail1@gmail.com",
+ lets_index_this="unique_name",
+ ),
+ )
+
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=self.unique_partial_index.columns,
+ index_where=self.unique_partial_index.dialect_options[
+ "postgresql"
+ ]["where"],
+ set_=dict(
+ name=i.excluded.name, login_email=i.excluded.login_email
+ ),
+ )
+
+ connection.execute(
+ i,
+ [
dict(
- id=1,
name="name1",
- login_email="mail1@gmail.com",
+ login_email="mail2@gmail.com",
lets_index_this="unique_name",
- ),
- )
-
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=self.unique_partial_index.columns,
- index_where=self.unique_partial_index.dialect_options[
- "postgresql"
- ]["where"],
- set_=dict(
- name=i.excluded.name, login_email=i.excluded.login_email
- ),
- )
-
- conn.execute(
- i,
- [
- dict(
- name="name1",
- login_email="mail2@gmail.com",
- lets_index_this="unique_name",
- )
- ],
- )
-
- eq_(
- conn.execute(users.select()).fetchall(),
- [(1, "name1", "mail2@gmail.com", "unique_name")],
- )
-
- def test_on_conflict_do_update_no_row_actually_affected(self):
+ )
+ ],
+ )
+
+ eq_(
+ connection.execute(users.select()).fetchall(),
+ [(1, "name1", "mail2@gmail.com", "unique_name")],
+ )
+
+ def test_on_conflict_do_update_no_row_actually_affected(self, connection):
users = self.tables.users_xtra
- with testing.db.connect() as conn:
- self._exotic_targets_fixture(conn)
- i = insert(users)
- i = i.on_conflict_do_update(
- index_elements=[users.c.login_email],
- set_=dict(name="new_name"),
- where=(i.excluded.name == "other_name"),
- )
- result = conn.execute(
- i, dict(name="name2", login_email="name1@gmail.com")
- )
-
- eq_(result.returned_defaults, None)
- eq_(result.inserted_primary_key, None)
-
- eq_(
- conn.execute(users.select()).fetchall(),
- [
- (1, "name1", "name1@gmail.com", "not"),
- (2, "name2", "name2@gmail.com", "not"),
- ],
- )
-
- def test_on_conflict_do_update_special_types_in_set(self):
+ self._exotic_targets_fixture(connection)
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=[users.c.login_email],
+ set_=dict(name="new_name"),
+ where=(i.excluded.name == "other_name"),
+ )
+ result = connection.execute(
+ i, dict(name="name2", login_email="name1@gmail.com")
+ )
+
+ eq_(result.returned_defaults, None)
+ eq_(result.inserted_primary_key, None)
+
+ eq_(
+ connection.execute(users.select()).fetchall(),
+ [
+ (1, "name1", "name1@gmail.com", "not"),
+ (2, "name2", "name2@gmail.com", "not"),
+ ],
+ )
+
+ def test_on_conflict_do_update_special_types_in_set(self, connection):
bind_targets = self.tables.bind_targets
- with testing.db.connect() as conn:
- i = insert(bind_targets)
- conn.execute(i, {"id": 1, "data": "initial data"})
-
- eq_(
- conn.scalar(sql.select(bind_targets.c.data)),
- "initial data processed",
- )
-
- i = insert(bind_targets)
- i = i.on_conflict_do_update(
- index_elements=[bind_targets.c.id],
- set_=dict(data="new updated data"),
- )
- conn.execute(i, {"id": 1, "data": "new inserted data"})
-
- eq_(
- conn.scalar(sql.select(bind_targets.c.data)),
- "new updated data processed",
- )
+ i = insert(bind_targets)
+ connection.execute(i, {"id": 1, "data": "initial data"})
+
+ eq_(
+ connection.scalar(sql.select(bind_targets.c.data)),
+ "initial data processed",
+ )
+
+ i = insert(bind_targets)
+ i = i.on_conflict_do_update(
+ index_elements=[bind_targets.c.id],
+ set_=dict(data="new updated data"),
+ )
+ connection.execute(i, {"id": 1, "data": "new inserted data"})
+
+ eq_(
+ connection.scalar(sql.select(bind_targets.c.data)),
+ "new updated data processed",
+ )
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index c959acf35..94af168ee 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -35,30 +35,32 @@ from sqlalchemy.testing.assertsql import CursorSQL
from sqlalchemy.testing.assertsql import DialectSQL
-matchtable = cattable = None
-
-
class InsertTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
__backend__ = True
- @classmethod
- def setup_class(cls):
- cls.metadata = MetaData(testing.db)
+ def setup(self):
+ self.metadata = MetaData()
def teardown(self):
- self.metadata.drop_all()
- self.metadata.clear()
+ with testing.db.begin() as conn:
+ self.metadata.drop_all(conn)
+
+ @testing.combinations((False,), (True,))
+ def test_foreignkey_missing_insert(self, implicit_returning):
+ engine = engines.testing_engine(
+ options={"implicit_returning": implicit_returning}
+ )
- def test_foreignkey_missing_insert(self):
Table("t1", self.metadata, Column("id", Integer, primary_key=True))
t2 = Table(
"t2",
self.metadata,
Column("id", Integer, ForeignKey("t1.id"), primary_key=True),
)
- self.metadata.create_all()
+
+ self.metadata.create_all(engine)
# want to ensure that "null value in column "id" violates not-
# null constraint" is raised (IntegrityError on psycoopg2, but
@@ -67,19 +69,13 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
# the latter corresponds to autoincrement behavior, which is not
# the case here due to the foreign key.
- for eng in [
- engines.testing_engine(options={"implicit_returning": False}),
- engines.testing_engine(options={"implicit_returning": True}),
- ]:
- with expect_warnings(
- ".*has no Python-side or server-side default.*"
- ):
- with eng.connect() as conn:
- assert_raises(
- (exc.IntegrityError, exc.ProgrammingError),
- conn.execute,
- t2.insert(),
- )
+ with expect_warnings(".*has no Python-side or server-side default.*"):
+ with engine.begin() as conn:
+ assert_raises(
+ (exc.IntegrityError, exc.ProgrammingError),
+ conn.execute,
+ t2.insert(),
+ )
def test_sequence_insert(self):
table = Table(
@@ -88,7 +84,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
Column("id", Integer, Sequence("my_seq"), primary_key=True),
Column("data", String(30)),
)
- self.metadata.create_all()
+ self.metadata.create_all(testing.db)
self._assert_data_with_sequence(table, "my_seq")
@testing.requires.returning
@@ -99,7 +95,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
Column("id", Integer, Sequence("my_seq"), primary_key=True),
Column("data", String(30)),
)
- self.metadata.create_all()
+ self.metadata.create_all(testing.db)
self._assert_data_with_sequence_returning(table, "my_seq")
def test_opt_sequence_insert(self):
@@ -114,7 +110,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
),
Column("data", String(30)),
)
- self.metadata.create_all()
+ self.metadata.create_all(testing.db)
self._assert_data_autoincrement(table)
@testing.requires.returning
@@ -130,7 +126,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
),
Column("data", String(30)),
)
- self.metadata.create_all()
+ self.metadata.create_all(testing.db)
self._assert_data_autoincrement_returning(table)
def test_autoincrement_insert(self):
@@ -140,7 +136,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
Column("id", Integer, primary_key=True),
Column("data", String(30)),
)
- self.metadata.create_all()
+ self.metadata.create_all(testing.db)
self._assert_data_autoincrement(table)
@testing.requires.returning
@@ -151,7 +147,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
Column("id", Integer, primary_key=True),
Column("data", String(30)),
)
- self.metadata.create_all()
+ self.metadata.create_all(testing.db)
self._assert_data_autoincrement_returning(table)
def test_noautoincrement_insert(self):
@@ -161,7 +157,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
Column("id", Integer, primary_key=True, autoincrement=False),
Column("data", String(30)),
)
- self.metadata.create_all()
+ self.metadata.create_all(testing.db)
self._assert_data_noautoincrement(table)
def _assert_data_autoincrement(self, table):
@@ -169,7 +165,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
with self.sql_execution_asserter(engine) as asserter:
- with engine.connect() as conn:
+ with engine.begin() as conn:
# execute with explicit id
r = conn.execute(table.insert(), {"id": 30, "data": "d1"})
@@ -226,7 +222,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
),
)
- with engine.connect() as conn:
+ with engine.begin() as conn:
eq_(
conn.execute(table.select()).fetchall(),
[
@@ -250,7 +246,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
table = Table(table.name, m2, autoload_with=engine)
with self.sql_execution_asserter(engine) as asserter:
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(table.insert(), {"id": 30, "data": "d1"})
r = conn.execute(table.insert(), {"data": "d2"})
eq_(r.inserted_primary_key, (5,))
@@ -288,7 +284,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
"INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}]
),
)
- with engine.connect() as conn:
+ with engine.begin() as conn:
eq_(
conn.execute(table.select()).fetchall(),
[
@@ -308,7 +304,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
engine = engines.testing_engine(options={"implicit_returning": True})
with self.sql_execution_asserter(engine) as asserter:
- with engine.connect() as conn:
+ with engine.begin() as conn:
# execute with explicit id
@@ -367,7 +363,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
),
)
- with engine.connect() as conn:
+ with engine.begin() as conn:
eq_(
conn.execute(table.select()).fetchall(),
[
@@ -390,7 +386,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
table = Table(table.name, m2, autoload_with=engine)
with self.sql_execution_asserter(engine) as asserter:
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(table.insert(), {"id": 30, "data": "d1"})
r = conn.execute(table.insert(), {"data": "d2"})
eq_(r.inserted_primary_key, (5,))
@@ -430,7 +426,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
),
)
- with engine.connect() as conn:
+ with engine.begin() as conn:
eq_(
conn.execute(table.select()).fetchall(),
[
@@ -450,7 +446,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
engine = engines.testing_engine(options={"implicit_returning": False})
with self.sql_execution_asserter(engine) as asserter:
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(table.insert(), {"id": 30, "data": "d1"})
conn.execute(table.insert(), {"data": "d2"})
conn.execute(
@@ -491,7 +487,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
[{"data": "d8"}],
),
)
- with engine.connect() as conn:
+ with engine.begin() as conn:
eq_(
conn.execute(table.select()).fetchall(),
[
@@ -513,7 +509,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
engine = engines.testing_engine(options={"implicit_returning": True})
with self.sql_execution_asserter(engine) as asserter:
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(table.insert(), {"id": 30, "data": "d1"})
conn.execute(table.insert(), {"data": "d2"})
conn.execute(
@@ -555,7 +551,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
),
)
- with engine.connect() as conn:
+ with engine.begin() as conn:
eq_(
conn.execute(table.select()).fetchall(),
[
@@ -578,9 +574,12 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
# turning off the cache because we are checking for compile-time
# warnings
- with engine.connect().execution_options(compiled_cache=None) as conn:
+ engine = engine.execution_options(compiled_cache=None)
+
+ with engine.begin() as conn:
conn.execute(table.insert(), {"id": 30, "data": "d1"})
+ with engine.begin() as conn:
with expect_warnings(
".*has no Python-side or server-side default.*"
):
@@ -590,6 +589,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
table.insert(),
{"data": "d2"},
)
+
+ with engine.begin() as conn:
with expect_warnings(
".*has no Python-side or server-side default.*"
):
@@ -599,6 +600,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
table.insert(),
[{"data": "d2"}, {"data": "d3"}],
)
+
+ with engine.begin() as conn:
with expect_warnings(
".*has no Python-side or server-side default.*"
):
@@ -608,6 +611,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
table.insert(),
{"data": "d2"},
)
+
+ with engine.begin() as conn:
with expect_warnings(
".*has no Python-side or server-side default.*"
):
@@ -618,6 +623,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
[{"data": "d2"}, {"data": "d3"}],
)
+ with engine.begin() as conn:
conn.execute(
table.insert(),
[{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}],
@@ -634,9 +640,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
m2 = MetaData()
table = Table(table.name, m2, autoload_with=engine)
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(table.insert(), {"id": 30, "data": "d1"})
+ with engine.begin() as conn:
with expect_warnings(
".*has no Python-side or server-side default.*"
):
@@ -646,6 +653,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
table.insert(),
{"data": "d2"},
)
+
+ with engine.begin() as conn:
with expect_warnings(
".*has no Python-side or server-side default.*"
):
@@ -655,6 +664,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
table.insert(),
[{"data": "d2"}, {"data": "d3"}],
)
+
+ with engine.begin() as conn:
conn.execute(
table.insert(),
[{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}],
@@ -666,36 +677,40 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
)
-class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
+class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
__only_on__ = "postgresql >= 8.3"
__backend__ = True
@classmethod
- def setup_class(cls):
- global metadata, cattable, matchtable
- metadata = MetaData(testing.db)
- cattable = Table(
+ def define_tables(cls, metadata):
+ Table(
"cattable",
metadata,
Column("id", Integer, primary_key=True),
Column("description", String(50)),
)
- matchtable = Table(
+ Table(
"matchtable",
metadata,
Column("id", Integer, primary_key=True),
Column("title", String(200)),
Column("category_id", Integer, ForeignKey("cattable.id")),
)
- metadata.create_all()
- cattable.insert().execute(
+
+ @classmethod
+ def insert_data(cls, connection):
+ cattable, matchtable = cls.tables("cattable", "matchtable")
+
+ connection.execute(
+ cattable.insert(),
[
{"id": 1, "description": "Python"},
{"id": 2, "description": "Ruby"},
- ]
+ ],
)
- matchtable.insert().execute(
+ connection.execute(
+ matchtable.insert(),
[
{
"id": 1,
@@ -714,15 +729,12 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
"category_id": 1,
},
{"id": 5, "title": "Python in a Nutshell", "category_id": 1},
- ]
+ ],
)
- @classmethod
- def teardown_class(cls):
- metadata.drop_all()
-
@testing.requires.pyformat_paramstyle
def test_expression_pyformat(self):
+ matchtable = self.tables.matchtable
self.assert_compile(
matchtable.c.title.match("somstr"),
"matchtable.title @@ to_tsquery(%(title_1)s" ")",
@@ -730,51 +742,47 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
@testing.requires.format_paramstyle
def test_expression_positional(self):
+ matchtable = self.tables.matchtable
self.assert_compile(
matchtable.c.title.match("somstr"),
"matchtable.title @@ to_tsquery(%s)",
)
- def test_simple_match(self):
- results = (
+ def test_simple_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
matchtable.select()
.where(matchtable.c.title.match("python"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([2, 5], [r.id for r in results])
- def test_not_match(self):
- results = (
+ def test_not_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
matchtable.select()
.where(~matchtable.c.title.match("python"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([1, 3, 4], [r.id for r in results])
- def test_simple_match_with_apostrophe(self):
- results = (
- matchtable.select()
- .where(matchtable.c.title.match("Matz's"))
- .execute()
- .fetchall()
- )
+ def test_simple_match_with_apostrophe(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
+ matchtable.select().where(matchtable.c.title.match("Matz's"))
+ ).fetchall()
eq_([3], [r.id for r in results])
- def test_simple_derivative_match(self):
- results = (
- matchtable.select()
- .where(matchtable.c.title.match("nutshells"))
- .execute()
- .fetchall()
- )
+ def test_simple_derivative_match(self, connection):
+ matchtable = self.tables.matchtable
+ results = connection.execute(
+ matchtable.select().where(matchtable.c.title.match("nutshells"))
+ ).fetchall()
eq_([5], [r.id for r in results])
- def test_or_match(self):
- results1 = (
+ def test_or_match(self, connection):
+ matchtable = self.tables.matchtable
+ results1 = connection.execute(
matchtable.select()
.where(
or_(
@@ -783,42 +791,36 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
)
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([3, 5], [r.id for r in results1])
- results2 = (
+ results2 = connection.execute(
matchtable.select()
.where(matchtable.c.title.match("nutshells | rubies"))
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([3, 5], [r.id for r in results2])
- def test_and_match(self):
- results1 = (
- matchtable.select()
- .where(
+ def test_and_match(self, connection):
+ matchtable = self.tables.matchtable
+ results1 = connection.execute(
+ matchtable.select().where(
and_(
matchtable.c.title.match("python"),
matchtable.c.title.match("nutshells"),
)
)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([5], [r.id for r in results1])
- results2 = (
- matchtable.select()
- .where(matchtable.c.title.match("python & nutshells"))
- .execute()
- .fetchall()
- )
+ results2 = connection.execute(
+ matchtable.select().where(
+ matchtable.c.title.match("python & nutshells")
+ )
+ ).fetchall()
eq_([5], [r.id for r in results2])
- def test_match_across_joins(self):
- results = (
+ def test_match_across_joins(self, connection):
+ cattable, matchtable = self.tables("cattable", "matchtable")
+ results = connection.execute(
matchtable.select()
.where(
and_(
@@ -830,9 +832,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
)
.order_by(matchtable.c.id)
- .execute()
- .fetchall()
- )
+ ).fetchall()
eq_([1, 3, 5], [r.id for r in results])
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index 4de4d88e3..824f6cd36 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -291,63 +291,64 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
@classmethod
def setup_class(cls):
- con = testing.db.connect()
- for ddl in [
- 'CREATE SCHEMA "SomeSchema"',
- "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42",
- "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0",
- "CREATE TYPE testtype AS ENUM ('test')",
- "CREATE DOMAIN enumdomain AS testtype",
- "CREATE DOMAIN arraydomain AS INTEGER[]",
- 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0',
- ]:
- try:
- con.exec_driver_sql(ddl)
- except exc.DBAPIError as e:
- if "already exists" not in str(e):
- raise e
- con.exec_driver_sql(
- "CREATE TABLE testtable (question integer, answer " "testdomain)"
- )
- con.exec_driver_sql(
- "CREATE TABLE test_schema.testtable(question "
- "integer, answer test_schema.testdomain, anything "
- "integer)"
- )
- con.exec_driver_sql(
- "CREATE TABLE crosschema (question integer, answer "
- "test_schema.testdomain)"
- )
+ with testing.db.begin() as con:
+ for ddl in [
+ 'CREATE SCHEMA "SomeSchema"',
+ "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42",
+ "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0",
+ "CREATE TYPE testtype AS ENUM ('test')",
+ "CREATE DOMAIN enumdomain AS testtype",
+ "CREATE DOMAIN arraydomain AS INTEGER[]",
+ 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0',
+ ]:
+ try:
+ con.exec_driver_sql(ddl)
+ except exc.DBAPIError as e:
+ if "already exists" not in str(e):
+ raise e
+ con.exec_driver_sql(
+ "CREATE TABLE testtable (question integer, answer "
+ "testdomain)"
+ )
+ con.exec_driver_sql(
+ "CREATE TABLE test_schema.testtable(question "
+ "integer, answer test_schema.testdomain, anything "
+ "integer)"
+ )
+ con.exec_driver_sql(
+ "CREATE TABLE crosschema (question integer, answer "
+ "test_schema.testdomain)"
+ )
- con.exec_driver_sql(
- "CREATE TABLE enum_test (id integer, data enumdomain)"
- )
+ con.exec_driver_sql(
+ "CREATE TABLE enum_test (id integer, data enumdomain)"
+ )
- con.exec_driver_sql(
- "CREATE TABLE array_test (id integer, data arraydomain)"
- )
+ con.exec_driver_sql(
+ "CREATE TABLE array_test (id integer, data arraydomain)"
+ )
- con.exec_driver_sql(
- "CREATE TABLE quote_test "
- '(id integer, data "SomeSchema"."Quoted.Domain")'
- )
+ con.exec_driver_sql(
+ "CREATE TABLE quote_test "
+ '(id integer, data "SomeSchema"."Quoted.Domain")'
+ )
@classmethod
def teardown_class(cls):
- con = testing.db.connect()
- con.exec_driver_sql("DROP TABLE testtable")
- con.exec_driver_sql("DROP TABLE test_schema.testtable")
- con.exec_driver_sql("DROP TABLE crosschema")
- con.exec_driver_sql("DROP TABLE quote_test")
- con.exec_driver_sql("DROP DOMAIN testdomain")
- con.exec_driver_sql("DROP DOMAIN test_schema.testdomain")
- con.exec_driver_sql("DROP TABLE enum_test")
- con.exec_driver_sql("DROP DOMAIN enumdomain")
- con.exec_driver_sql("DROP TYPE testtype")
- con.exec_driver_sql("DROP TABLE array_test")
- con.exec_driver_sql("DROP DOMAIN arraydomain")
- con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
- con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
+ with testing.db.begin() as con:
+ con.exec_driver_sql("DROP TABLE testtable")
+ con.exec_driver_sql("DROP TABLE test_schema.testtable")
+ con.exec_driver_sql("DROP TABLE crosschema")
+ con.exec_driver_sql("DROP TABLE quote_test")
+ con.exec_driver_sql("DROP DOMAIN testdomain")
+ con.exec_driver_sql("DROP DOMAIN test_schema.testdomain")
+ con.exec_driver_sql("DROP TABLE enum_test")
+ con.exec_driver_sql("DROP DOMAIN enumdomain")
+ con.exec_driver_sql("DROP TYPE testtype")
+ con.exec_driver_sql("DROP TABLE array_test")
+ con.exec_driver_sql("DROP DOMAIN arraydomain")
+ con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
+ con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
def test_table_is_reflected(self):
metadata = MetaData()
@@ -486,7 +487,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("ref", Integer, ForeignKey("subject.id$")),
)
- meta1.create_all()
+ meta1.create_all(testing.db)
meta2 = MetaData()
subject = Table("subject", meta2, autoload_with=testing.db)
referer = Table("referer", meta2, autoload_with=testing.db)
@@ -523,9 +524,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
with testing.db.begin() as conn:
r = conn.execute(t2.insert())
eq_(r.inserted_primary_key, (1,))
- testing.db.connect().execution_options(
- autocommit=True
- ).exec_driver_sql("alter table t_id_seq rename to foobar_id_seq")
+
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql(
+ "alter table t_id_seq rename to foobar_id_seq"
+ )
m3 = MetaData()
t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False)
eq_(
@@ -545,10 +548,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all()
- testing.db.connect().execution_options(
- autocommit=True
- ).exec_driver_sql("alter table t alter column id type varchar(50)")
+ metadata.create_all(testing.db)
+
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql(
+ "alter table t alter column id type varchar(50)"
+ )
m2 = MetaData()
t2 = Table("t", m2, autoload_with=testing.db)
eq_(t2.c.id.autoincrement, False)
@@ -558,10 +563,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
def test_renamed_pk_reflection(self):
metadata = self.metadata
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all()
- testing.db.connect().execution_options(
- autocommit=True
- ).exec_driver_sql("alter table t rename id to t_id")
+ metadata.create_all(testing.db)
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql("alter table t rename id to t_id")
m2 = MetaData()
t2 = Table("t", m2, autoload_with=testing.db)
eq_([c.name for c in t2.primary_key], ["t_id"])
@@ -936,13 +940,13 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("name", String(20), index=True),
Column("aname", String(20)),
)
- metadata.create_all()
- with testing.db.connect() as c:
- c.exec_driver_sql("create index idx1 on party ((id || name))")
- c.exec_driver_sql(
+ metadata.create_all(testing.db)
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql("create index idx1 on party ((id || name))")
+ conn.exec_driver_sql(
"create unique index idx2 on party (id) where name = 'test'"
)
- c.exec_driver_sql(
+ conn.exec_driver_sql(
"""
create index idx3 on party using btree
(lower(name::text), lower(aname::text))
@@ -1029,7 +1033,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("aname", String(20)),
)
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
t1.create(conn)
@@ -1109,18 +1113,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all()
- conn = testing.db.connect().execution_options(autocommit=True)
- conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
- conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
+ metadata.create_all(testing.db)
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
+ conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
+ ind = testing.db.dialect.get_indexes(conn, "t", None)
+ expected = [
+ {"name": "idx1", "unique": False, "column_names": ["y"]}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
- eq_(ind, expected)
- conn.close()
+ eq_(ind, expected)
@testing.fails_if("postgresql < 8.2", "reloptions not supported")
@testing.provide_metadata
@@ -1135,9 +1140,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all()
+ metadata.create_all(testing.db)
- with testing.db.connect().execution_options(autocommit=True) as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql(
"CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
)
@@ -1177,8 +1182,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", ARRAY(Integer)),
)
- metadata.create_all()
- with testing.db.connect().execution_options(autocommit=True) as conn:
+ metadata.create_all(testing.db)
+ with testing.db.begin() as conn:
conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
ind = testing.db.dialect.get_indexes(conn, "t", None)
@@ -1215,7 +1220,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("name", String(20)),
)
metadata.create_all()
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql("CREATE INDEX idx1 ON t (x) INCLUDE (name)")
# prior to #5205, this would return:
@@ -1312,8 +1317,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
eq_(fk, fk_ref[fk["name"]])
@testing.provide_metadata
- def test_inspect_enums_schema(self):
- conn = testing.db.connect()
+ def test_inspect_enums_schema(self, connection):
enum_type = postgresql.ENUM(
"sad",
"ok",
@@ -1322,8 +1326,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
schema="test_schema",
metadata=self.metadata,
)
- enum_type.create(conn)
- inspector = inspect(conn)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums("test_schema"),
[
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index e7174f234..ae7a65a3a 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -206,7 +206,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
),
schema=symbol_name,
)
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn = conn.execution_options(
schema_translate_map={symbol_name: testing.config.test_schema}
)
diff --git a/test/dialect/test_mxodbc.py b/test/dialect/test_mxodbc.py
index de8b22b67..cd8768d73 100644
--- a/test/dialect/test_mxodbc.py
+++ b/test/dialect/test_mxodbc.py
@@ -30,34 +30,37 @@ class MxODBCTest(fixtures.TestBase):
)
conn = engine.connect()
- # crud: uses execute
- conn.execute(t1.insert().values(c1="foo"))
- conn.execute(t1.delete().where(t1.c.c1 == "foo"))
- conn.execute(t1.update().where(t1.c.c1 == "foo").values(c1="bar"))
-
- # select: uses executedirect
- conn.execute(t1.select())
-
- # manual flagging
- conn.execution_options(native_odbc_execute=True).execute(t1.select())
- conn.execution_options(native_odbc_execute=False).execute(
- t1.insert().values(c1="foo")
- )
+ with conn.begin():
+ # crud: uses execute
+ conn.execute(t1.insert().values(c1="foo"))
+ conn.execute(t1.delete().where(t1.c.c1 == "foo"))
+ conn.execute(t1.update().where(t1.c.c1 == "foo").values(c1="bar"))
- eq_(
- # fmt: off
- [
- c[2]
- for c in dbapi.connect.return_value.cursor.
- return_value.execute.mock_calls
- ],
- # fmt: on
- [
- {"direct": True},
- {"direct": True},
- {"direct": True},
- {"direct": True},
- {"direct": False},
- {"direct": True},
- ]
- )
+ # select: uses executedirect
+ conn.execute(t1.select())
+
+ # manual flagging
+ conn.execution_options(native_odbc_execute=True).execute(
+ t1.select()
+ )
+ conn.execution_options(native_odbc_execute=False).execute(
+ t1.insert().values(c1="foo")
+ )
+
+ eq_(
+ # fmt: off
+ [
+ c[2]
+ for c in dbapi.connect.return_value.cursor.
+ return_value.execute.mock_calls
+ ],
+ # fmt: on
+ [
+ {"direct": True},
+ {"direct": True},
+ {"direct": True},
+ {"direct": True},
+ {"direct": False},
+ {"direct": True},
+ ]
+ )
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index f8b50f888..12200f832 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -63,8 +63,9 @@ from sqlalchemy.util import ue
def exec_sql(engine, sql, *args, **kwargs):
- conn = engine.connect(close_with_result=True)
- return conn.exec_driver_sql(sql, *args, **kwargs)
+ # TODO: convert all tests to not use this
+ with engine.begin() as conn:
+ conn.exec_driver_sql(sql, *args, **kwargs)
class TestTypes(fixtures.TestBase, AssertsExecutionResults):
@@ -189,11 +190,13 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
connection.execute(
t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0))
)
- exec_sql(
- testing.db, "insert into t (d) values ('2004-05-21T00:00:00')"
+ connection.exec_driver_sql(
+ "insert into t (d) values ('2004-05-21T00:00:00')"
)
eq_(
- exec_sql(testing.db, "select * from t order by d").fetchall(),
+ connection.exec_driver_sql(
+ "select * from t order by d"
+ ).fetchall(),
[("2004-05-21T00:00:00",), ("2010-10-15T12:37:00",)],
)
eq_(
@@ -216,9 +219,13 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
connection.execute(
t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0))
)
- exec_sql(testing.db, "insert into t (d) values ('20040521000000')")
+ connection.exec_driver_sql(
+ "insert into t (d) values ('20040521000000')"
+ )
eq_(
- exec_sql(testing.db, "select * from t order by d").fetchall(),
+ connection.exec_driver_sql(
+ "select * from t order by d"
+ ).fetchall(),
[("20040521000000",), ("20101015123700",)],
)
eq_(
@@ -238,9 +245,11 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
t = Table("t", self.metadata, Column("d", sqlite_date))
self.metadata.create_all(connection)
connection.execute(t.insert().values(d=datetime.date(2010, 10, 15)))
- exec_sql(testing.db, "insert into t (d) values ('20040521')")
+ connection.exec_driver_sql("insert into t (d) values ('20040521')")
eq_(
- exec_sql(testing.db, "select * from t order by d").fetchall(),
+ connection.exec_driver_sql(
+ "select * from t order by d"
+ ).fetchall(),
[("20040521",), ("20101015",)],
)
eq_(
@@ -256,11 +265,15 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
regexp=r"(\d+)\|(\d+)\|(\d+)",
)
t = Table("t", self.metadata, Column("d", sqlite_date))
- self.metadata.create_all(testing.db)
+ self.metadata.create_all(connection)
connection.execute(t.insert().values(d=datetime.date(2010, 10, 15)))
- exec_sql(testing.db, "insert into t (d) values ('2004|05|21')")
+
+ connection.exec_driver_sql("insert into t (d) values ('2004|05|21')")
+
eq_(
- exec_sql(testing.db, "select * from t order by d").fetchall(),
+ connection.exec_driver_sql(
+ "select * from t order by d"
+ ).fetchall(),
[("2004|05|21",), ("2010|10|15",)],
)
eq_(
@@ -313,7 +326,7 @@ class JSONTest(fixtures.TestBase):
value = {"json": {"foo": "bar"}, "recs": ["one", "two"]}
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(sqlite_json.insert(), foo=value)
eq_(conn.scalar(select(sqlite_json.c.foo)), value)
@@ -328,7 +341,7 @@ class JSONTest(fixtures.TestBase):
value = {"json": {"foo": "bar"}}
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(sqlite_json.insert(), foo=value)
eq_(conn.scalar(select(sqlite_json.c.foo["json"])), value["json"])
@@ -551,7 +564,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
Column("x", Boolean, server_default=sql.false()),
)
t.create(testing.db)
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(t.insert())
conn.execute(t.insert().values(x=True))
eq_(
@@ -568,7 +581,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
Column("x", DateTime(), server_default=func.now()),
)
t.create(testing.db)
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
now = conn.scalar(func.now())
today = datetime.datetime.today()
conn.execute(t.insert())
@@ -587,7 +600,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
Column("x", Integer(), server_default=func.abs(-5) + 17),
)
t.create(testing.db)
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(t.insert())
conn.execute(t.insert().values(x=35))
eq_(
@@ -622,7 +635,8 @@ class DialectTest(
)
)
- def test_extra_reserved_words(self):
+ @testing.provide_metadata
+ def test_extra_reserved_words(self, connection):
"""Tests reserved words in identifiers.
'true', 'false', and 'column' are undocumented reserved words
@@ -630,22 +644,19 @@ class DialectTest(
here to ensure they remain in place if the dialect's
reserved_words set is updated in the future."""
- meta = MetaData(testing.db)
t = Table(
"reserved",
- meta,
+ self.metadata,
Column("safe", Integer),
Column("true", Integer),
Column("false", Integer),
Column("column", Integer),
Column("exists", Integer),
)
- try:
- meta.create_all()
- t.insert().execute(safe=1)
- list(t.select().execute())
- finally:
- meta.drop_all()
+ self.metadata.create_all(connection)
+ connection.execute(t.insert(), dict(safe=1))
+ result = connection.execute(t.select())
+ eq_(list(result), [(1, None, None, None, None)])
@testing.provide_metadata
def test_quoted_identifiers_functional_one(self):
@@ -827,7 +838,8 @@ class AttachedDBTest(fixtures.TestBase):
schema="test_schema",
)
- meta.create_all(self.conn)
+ with self.conn.begin():
+ meta.create_all(self.conn)
return ct
def setup(self):
@@ -835,7 +847,8 @@ class AttachedDBTest(fixtures.TestBase):
self.metadata = MetaData()
def teardown(self):
- self.metadata.drop_all(self.conn)
+ with self.conn.begin():
+ self.metadata.drop_all(self.conn)
self.conn.close()
def test_no_tables(self):
@@ -928,18 +941,20 @@ class AttachedDBTest(fixtures.TestBase):
def test_crud(self):
ct = self._fixture()
- self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
- eq_(self.conn.execute(ct.select()).fetchall(), [(1, "foo")])
+ with self.conn.begin():
+ self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
+ eq_(self.conn.execute(ct.select()).fetchall(), [(1, "foo")])
- self.conn.execute(ct.update(), {"id": 2, "name": "bar"})
- eq_(self.conn.execute(ct.select()).fetchall(), [(2, "bar")])
- self.conn.execute(ct.delete())
- eq_(self.conn.execute(ct.select()).fetchall(), [])
+ self.conn.execute(ct.update(), {"id": 2, "name": "bar"})
+ eq_(self.conn.execute(ct.select()).fetchall(), [(2, "bar")])
+ self.conn.execute(ct.delete())
+ eq_(self.conn.execute(ct.select()).fetchall(), [])
def test_col_targeting(self):
ct = self._fixture()
- self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
+ with self.conn.begin():
+ self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
row = self.conn.execute(ct.select()).first()
eq_(row._mapping["id"], 1)
eq_(row._mapping["name"], "foo")
@@ -947,7 +962,8 @@ class AttachedDBTest(fixtures.TestBase):
def test_col_targeting_union(self):
ct = self._fixture()
- self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
+ with self.conn.begin():
+ self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
row = self.conn.execute(ct.select().union(ct.select())).first()
eq_(row._mapping["id"], 1)
eq_(row._mapping["name"], "foo")
@@ -2236,7 +2252,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
)
def test_foreign_key_options_unnamed_inline(self):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql(
"create table foo (id integer, "
"foreign key (id) references bar (id) on update cascade)"
@@ -2571,33 +2587,33 @@ class TypeReflectionTest(fixtures.TestBase):
def _test_round_trip(self, fixture, warnings=False):
from sqlalchemy import inspect
- conn = testing.db.connect()
for from_, to_ in self._fixture_as_string(fixture):
- inspector = inspect(conn)
- conn.exec_driver_sql("CREATE TABLE foo (data %s)" % from_)
- try:
- if warnings:
+ with testing.db.begin() as conn:
+ inspector = inspect(conn)
+ conn.exec_driver_sql("CREATE TABLE foo (data %s)" % from_)
+ try:
+ if warnings:
- def go():
- return inspector.get_columns("foo")[0]
+ def go():
+ return inspector.get_columns("foo")[0]
- col_info = testing.assert_warnings(
- go, ["Could not instantiate"], regex=True
- )
- else:
- col_info = inspector.get_columns("foo")[0]
- expected_type = type(to_)
- is_(type(col_info["type"]), expected_type)
-
- # test args
- for attr in ("scale", "precision", "length"):
- if getattr(to_, attr, None) is not None:
- eq_(
- getattr(col_info["type"], attr),
- getattr(to_, attr, None),
+ col_info = testing.assert_warnings(
+ go, ["Could not instantiate"], regex=True
)
- finally:
- conn.exec_driver_sql("DROP TABLE foo")
+ else:
+ col_info = inspector.get_columns("foo")[0]
+ expected_type = type(to_)
+ is_(type(col_info["type"]), expected_type)
+
+ # test args
+ for attr in ("scale", "precision", "length"):
+ if getattr(to_, attr, None) is not None:
+ eq_(
+ getattr(col_info["type"], attr),
+ getattr(to_, attr, None),
+ )
+ finally:
+ conn.exec_driver_sql("DROP TABLE foo")
def test_lookup_direct_lookup(self):
self._test_lookup_direct(self._fixed_lookup_fixture())
diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py
index f2429175f..5cbb47854 100644
--- a/test/engine/test_ddlevents.py
+++ b/test/engine/test_ddlevents.py
@@ -489,6 +489,7 @@ class DDLExecutionTest(fixtures.TestBase):
def test_ddl_execute(self):
engine = create_engine("sqlite:///")
cx = engine.connect()
+ cx.begin()
table = self.users
ddl = DDL("SELECT 1")
diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py
index 5e32cc3e9..47e59b55d 100644
--- a/test/engine/test_deprecations.py
+++ b/test/engine/test_deprecations.py
@@ -93,6 +93,9 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
for meta in (MetaData, ThreadLocalMetaData):
for bind in (testing.db, testing.db.connect()):
+ if isinstance(bind, engine.Connection):
+ bind.begin()
+
if meta is ThreadLocalMetaData:
with testing.expect_deprecated(
"ThreadLocalMetaData is deprecated"
@@ -151,6 +154,8 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
def test_bind_create_drop_constructor_bound(self):
for bind in (testing.db, testing.db.connect()):
+ if isinstance(bind, engine.Connection):
+ bind.begin()
try:
for args in (([bind], {}), ([], {"bind": bind})):
metadata = MetaData(*args[0], **args[1])
@@ -177,15 +182,25 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
test_needs_acid=True,
)
conn = testing.db.connect()
- metadata.create_all(bind=conn)
+ with conn.begin():
+ metadata.create_all(bind=conn)
try:
trans = conn.begin()
metadata.bind = conn
t = table.insert()
assert t.bind is conn
- table.insert().execute(foo=5)
- table.insert().execute(foo=6)
- table.insert().execute(foo=7)
+ with testing.expect_deprecated_20(
+ r"The Executable.execute\(\) method is considered legacy"
+ ):
+ table.insert().execute(foo=5)
+ with testing.expect_deprecated_20(
+ r"The Executable.execute\(\) method is considered legacy"
+ ):
+ table.insert().execute(foo=6)
+ with testing.expect_deprecated_20(
+ r"The Executable.execute\(\) method is considered legacy"
+ ):
+ table.insert().execute(foo=7)
trans.rollback()
metadata.bind = None
assert (
@@ -195,7 +210,8 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
== 0
)
finally:
- metadata.drop_all(bind=conn)
+ with conn.begin():
+ metadata.drop_all(bind=conn)
def test_bind_clauseelement(self):
metadata = MetaData()
@@ -215,14 +231,21 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
):
e = elem(bind=bind)
assert e.bind is bind
- e.execute().close()
+ with testing.expect_deprecated_20(
+ r"The Executable.execute\(\) method is "
+ "considered legacy"
+ ):
+ e.execute().close()
finally:
if isinstance(bind, engine.Connection):
bind.close()
e = elem()
assert e.bind is None
- assert_raises(exc.UnboundExecutionError, e.execute)
+ with testing.expect_deprecated_20(
+ r"The Executable.execute\(\) method is considered legacy"
+ ):
+ assert_raises(exc.UnboundExecutionError, e.execute)
finally:
if isinstance(bind, engine.Connection):
bind.close()
@@ -365,6 +388,11 @@ class TransactionTest(fixtures.TablesTest):
)
Table("inserttable", metadata, Column("data", String(20)))
+ @testing.fixture
+ def local_connection(self):
+ with testing.db.connect() as conn:
+ yield conn
+
def test_transaction_container(self):
users = self.tables.users
@@ -429,6 +457,110 @@ class TransactionTest(fixtures.TablesTest):
"insert into inserttable (data) values ('thedata')"
)
+ def test_branch_autorollback(self, local_connection):
+ connection = local_connection
+ users = self.tables.users
+ branched = connection.connect()
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ branched.execute(
+ users.insert(), dict(user_id=1, user_name="user1")
+ )
+ assert_raises(
+ exc.DBAPIError,
+ branched.execute,
+ users.insert(),
+ dict(user_id=1, user_name="user1"),
+ )
+ # can continue w/o issue
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ branched.execute(
+ users.insert(), dict(user_id=2, user_name="user2")
+ )
+
+ def test_branch_orig_rollback(self, local_connection):
+ connection = local_connection
+ users = self.tables.users
+ branched = connection.connect()
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ branched.execute(
+ users.insert(), dict(user_id=1, user_name="user1")
+ )
+ nested = branched.begin()
+ assert branched.in_transaction()
+ branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
+ nested.rollback()
+ eq_(
+ connection.exec_driver_sql("select count(*) from users").scalar(),
+ 1,
+ )
+
+ @testing.requires.independent_connections
+ def test_branch_autocommit(self, local_connection):
+ users = self.tables.users
+ with testing.db.connect() as connection:
+ branched = connection.connect()
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ branched.execute(
+ users.insert(), dict(user_id=1, user_name="user1")
+ )
+
+ eq_(
+ local_connection.execute(
+ text("select count(*) from users")
+ ).scalar(),
+ 1,
+ )
+
+ @testing.requires.savepoints
+ def test_branch_savepoint_rollback(self, local_connection):
+ connection = local_connection
+ users = self.tables.users
+ trans = connection.begin()
+ branched = connection.connect()
+ assert branched.in_transaction()
+ branched.execute(users.insert(), user_id=1, user_name="user1")
+ nested = branched.begin_nested()
+ branched.execute(users.insert(), user_id=2, user_name="user2")
+ nested.rollback()
+ assert connection.in_transaction()
+ trans.commit()
+ eq_(
+ connection.exec_driver_sql("select count(*) from users").scalar(),
+ 1,
+ )
+
+ @testing.requires.two_phase_transactions
+ def test_branch_twophase_rollback(self, local_connection):
+ connection = local_connection
+ users = self.tables.users
+ branched = connection.connect()
+ assert not branched.in_transaction()
+ with testing.expect_deprecated_20(
+ r"The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ branched.execute(users.insert(), user_id=1, user_name="user1")
+ nested = branched.begin_twophase()
+ branched.execute(users.insert(), user_id=2, user_name="user2")
+ nested.rollback()
+ assert not connection.in_transaction()
+ eq_(
+ connection.exec_driver_sql("select count(*) from users").scalar(),
+ 1,
+ )
+
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
@@ -699,20 +831,20 @@ class DeprecatedReflectionTest(fixtures.TablesTest):
def test_create_drop_explicit(self):
metadata = MetaData()
table = Table("test_table", metadata, Column("foo", Integer))
- for bind in (testing.db, testing.db.connect()):
- for args in [([], {"bind": bind}), ([bind], {})]:
- metadata.create_all(*args[0], **args[1])
- with testing.expect_deprecated(
- r"The Table.exists\(\) method is deprecated"
- ):
- assert table.exists(*args[0], **args[1])
- metadata.drop_all(*args[0], **args[1])
- table.create(*args[0], **args[1])
- table.drop(*args[0], **args[1])
- with testing.expect_deprecated(
- r"The Table.exists\(\) method is deprecated"
- ):
- assert not table.exists(*args[0], **args[1])
+ bind = testing.db
+ for args in [([], {"bind": bind}), ([bind], {})]:
+ metadata.create_all(*args[0], **args[1])
+ with testing.expect_deprecated(
+ r"The Table.exists\(\) method is deprecated"
+ ):
+ assert table.exists(*args[0], **args[1])
+ metadata.drop_all(*args[0], **args[1])
+ table.create(*args[0], **args[1])
+ table.drop(*args[0], **args[1])
+ with testing.expect_deprecated(
+ r"The Table.exists\(\) method is deprecated"
+ ):
+ assert not table.exists(*args[0], **args[1])
def test_create_drop_err_table(self):
metadata = MetaData()
@@ -1195,3 +1327,208 @@ class DDLExecutionTest(fixtures.TestBase):
with testing.expect_deprecated_20(ddl_msg):
r = fn(**kw)
eq_(list(r), [(1,)])
+
+
+class AutocommitKeywordFixture(object):
+ def _test_keyword(self, keyword, expected=True):
+ dbapi = Mock(
+ connect=Mock(
+ return_value=Mock(
+ cursor=Mock(return_value=Mock(description=()))
+ )
+ )
+ )
+ engine = engines.testing_engine(
+ options={"_initialize": False, "pool_reset_on_return": None}
+ )
+ engine.dialect.dbapi = dbapi
+
+ with engine.connect() as conn:
+ if expected:
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted "
+ "using implicit autocommit"
+ ):
+ conn.exec_driver_sql(
+ "%s something table something" % keyword
+ )
+ else:
+ conn.exec_driver_sql("%s something table something" % keyword)
+
+ if expected:
+ eq_(
+ [n for (n, k, s) in dbapi.connect().mock_calls],
+ ["cursor", "commit"],
+ )
+ else:
+ eq_(
+ [n for (n, k, s) in dbapi.connect().mock_calls], ["cursor"]
+ )
+
+
+class AutocommitTextTest(AutocommitKeywordFixture, fixtures.TestBase):
+ __backend__ = True
+
+ def test_update(self):
+ self._test_keyword("UPDATE")
+
+ def test_insert(self):
+ self._test_keyword("INSERT")
+
+ def test_delete(self):
+ self._test_keyword("DELETE")
+
+ def test_alter(self):
+ self._test_keyword("ALTER TABLE")
+
+ def test_create(self):
+ self._test_keyword("CREATE TABLE foobar")
+
+ def test_drop(self):
+ self._test_keyword("DROP TABLE foobar")
+
+ def test_select(self):
+ self._test_keyword("SELECT foo FROM table", False)
+
+
+class ExplicitAutoCommitTest(fixtures.TestBase):
+
+ """test the 'autocommit' flag on select() and text() objects.
+
+ Requires PostgreSQL so that we may define a custom function which
+ modifies the database."""
+
+ __only_on__ = "postgresql"
+
+ @classmethod
+ def setup_class(cls):
+ global metadata, foo
+ metadata = MetaData(testing.db)
+ foo = Table(
+ "foo",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
+ with testing.db.begin() as conn:
+ metadata.create_all(conn)
+ conn.exec_driver_sql(
+ "create function insert_foo(varchar) "
+ "returns integer as 'insert into foo(data) "
+ "values ($1);select 1;' language sql"
+ )
+
+ def teardown(self):
+ with testing.db.begin() as conn:
+ conn.execute(foo.delete())
+
+ @classmethod
+ def teardown_class(cls):
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql("drop function insert_foo(varchar)")
+ metadata.drop_all(conn)
+
+ def test_control(self):
+
+ # test that not using autocommit does not commit
+
+ conn1 = testing.db.connect()
+ conn2 = testing.db.connect()
+ conn1.execute(select(func.insert_foo("data1")))
+ assert conn2.execute(select(foo.c.data)).fetchall() == []
+ conn1.execute(text("select insert_foo('moredata')"))
+ assert conn2.execute(select(foo.c.data)).fetchall() == []
+ trans = conn1.begin()
+ trans.commit()
+ assert conn2.execute(select(foo.c.data)).fetchall() == [
+ ("data1",),
+ ("moredata",),
+ ]
+ conn1.close()
+ conn2.close()
+
+ def test_explicit_compiled(self):
+ conn1 = testing.db.connect()
+ conn2 = testing.db.connect()
+
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ conn1.execute(
+ select(func.insert_foo("data1")).execution_options(
+ autocommit=True
+ )
+ )
+ assert conn2.execute(select(foo.c.data)).fetchall() == [("data1",)]
+ conn1.close()
+ conn2.close()
+
+ def test_explicit_connection(self):
+ conn1 = testing.db.connect()
+ conn2 = testing.db.connect()
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ conn1.execution_options(autocommit=True).execute(
+ select(func.insert_foo("data1"))
+ )
+ eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
+
+ # connection supersedes statement
+
+ conn1.execution_options(autocommit=False).execute(
+ select(func.insert_foo("data2")).execution_options(autocommit=True)
+ )
+ eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
+
+ # ditto
+
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ conn1.execution_options(autocommit=True).execute(
+ select(func.insert_foo("data3")).execution_options(
+ autocommit=False
+ )
+ )
+ eq_(
+ conn2.execute(select(foo.c.data)).fetchall(),
+ [("data1",), ("data2",), ("data3",)],
+ )
+ conn1.close()
+ conn2.close()
+
+ def test_explicit_text(self):
+ conn1 = testing.db.connect()
+ conn2 = testing.db.connect()
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ conn1.execute(
+ text("select insert_foo('moredata')").execution_options(
+ autocommit=True
+ )
+ )
+ assert conn2.execute(select(foo.c.data)).fetchall() == [("moredata",)]
+ conn1.close()
+ conn2.close()
+
+ def test_implicit_text(self):
+ conn1 = testing.db.connect()
+ conn2 = testing.db.connect()
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit"
+ ):
+ conn1.execute(
+ text("insert into foo (data) values ('implicitdata')")
+ )
+ assert conn2.execute(select(foo.c.data)).fetchall() == [
+ ("implicitdata",)
+ ]
+ conn1.close()
+ conn2.close()
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index efec9376c..55a114409 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -543,13 +543,15 @@ class ExecuteTest(fixtures.TablesTest):
@testing.only_on("sqlite")
def test_execute_compiled_favors_compiled_paramstyle(self):
+ users = self.tables.users
+
with patch.object(testing.db.dialect, "do_execute") as do_exec:
stmt = users.update().values(user_id=1, user_name="foo")
d1 = default.DefaultDialect(paramstyle="format")
d2 = default.DefaultDialect(paramstyle="pyformat")
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(stmt.compile(dialect=d1))
conn.execute(stmt.compile(dialect=d2))
@@ -805,9 +807,8 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
def test_connection_as_ctx(self):
fn = self._trans_fn()
- ctx = testing.db.connect()
- testing.run_as_contextmanager(ctx, fn, 5, value=8)
- # autocommit is on
+ with testing.db.begin() as conn:
+ fn(conn, 5, value=8)
self._assert_fn(5, value=8)
@testing.fails_on("mysql+oursql", "oursql bug ? getting wrong rowcount")
@@ -822,14 +823,12 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
self._assert_no_data()
-class CompiledCacheTest(fixtures.TestBase):
+class CompiledCacheTest(fixtures.TablesTest):
__backend__ = True
@classmethod
- def setup_class(cls):
- global users, metadata
- metadata = MetaData(testing.db)
- users = Table(
+ def define_tables(cls, metadata):
+ Table(
"users",
metadata,
Column(
@@ -838,19 +837,11 @@ class CompiledCacheTest(fixtures.TestBase):
Column("user_name", VARCHAR(20)),
Column("extra_data", VARCHAR(20)),
)
- metadata.create_all()
- @engines.close_first
- def teardown(self):
- with testing.db.connect() as conn:
- conn.execute(users.delete())
-
- @classmethod
- def teardown_class(cls):
- metadata.drop_all()
+ def test_cache(self, connection):
+ users = self.tables.users
- def test_cache(self):
- conn = testing.db.connect()
+ conn = connection
cache = {}
cached_conn = conn.execution_options(compiled_cache=cache)
@@ -870,7 +861,7 @@ class CompiledCacheTest(fixtures.TestBase):
"uses blob value that is problematic for some DBAPIs",
)
@testing.provide_metadata
- def test_cache_noleak_on_statement_values(self):
+ def test_cache_noleak_on_statement_values(self, connection):
# This is a non regression test for an object reference leak caused
# by the compiled_cache.
@@ -883,11 +874,10 @@ class CompiledCacheTest(fixtures.TestBase):
),
Column("photo_blob", LargeBinary()),
)
- metadata.create_all()
+ metadata.create_all(connection)
- conn = testing.db.connect()
cache = {}
- cached_conn = conn.execution_options(compiled_cache=cache)
+ cached_conn = connection.execution_options(compiled_cache=cache)
class PhotoBlob(bytearray):
pass
@@ -902,7 +892,10 @@ class CompiledCacheTest(fixtures.TestBase):
cached_conn.execute(ins, {"photo_blob": blob})
eq_(compile_mock.call_count, 1)
eq_(len(cache), 1)
- eq_(conn.exec_driver_sql("select count(*) from photo").scalar(), 1)
+ eq_(
+ connection.exec_driver_sql("select count(*) from photo").scalar(),
+ 1,
+ )
del blob
@@ -912,14 +905,15 @@ class CompiledCacheTest(fixtures.TestBase):
# the statement values (only the keys).
eq_(ref_blob(), None)
- def test_keys_independent_of_ordering(self):
- conn = testing.db.connect()
- conn.execute(
+ def test_keys_independent_of_ordering(self, connection):
+ users = self.tables.users
+
+ connection.execute(
users.insert(),
{"user_id": 1, "user_name": "u1", "extra_data": "e1"},
)
cache = {}
- cached_conn = conn.execution_options(compiled_cache=cache)
+ cached_conn = connection.execution_options(compiled_cache=cache)
upd = users.update().where(users.c.user_id == bindparam("b_user_id"))
@@ -974,30 +968,32 @@ class CompiledCacheTest(fixtures.TestBase):
stmt = select(t1.c.q)
cache = {}
- with config.db.connect().execution_options(
- compiled_cache=cache
- ) as conn:
+ with config.db.begin() as conn:
+ conn = conn.execution_options(compiled_cache=cache)
conn.execute(ins, {"q": 1})
eq_(conn.scalar(stmt), 1)
- with config.db.connect().execution_options(
- compiled_cache=cache,
- schema_translate_map={None: config.test_schema},
- ) as conn:
+ with config.db.begin() as conn:
+ conn = conn.execution_options(
+ compiled_cache=cache,
+ schema_translate_map={None: config.test_schema},
+ )
conn.execute(ins, {"q": 2})
eq_(conn.scalar(stmt), 2)
- with config.db.connect().execution_options(
- compiled_cache=cache,
- schema_translate_map={None: None},
- ) as conn:
+ with config.db.begin() as conn:
+ conn = conn.execution_options(
+ compiled_cache=cache,
+ schema_translate_map={None: None},
+ )
# should use default schema again even though statement
# was compiled with test_schema in the map
eq_(conn.scalar(stmt), 1)
- with config.db.connect().execution_options(
- compiled_cache=cache
- ) as conn:
+ with config.db.begin() as conn:
+ conn = conn.execution_options(
+ compiled_cache=cache,
+ )
eq_(conn.scalar(stmt), 1)
@@ -1050,7 +1046,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
with self.sql_execution_asserter(config.db) as asserter:
- with config.db.connect().execution_options(
+ with config.db.begin() as conn, conn.execution_options(
schema_translate_map=map_
) as conn:
@@ -1091,9 +1087,8 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
Table("t2", metadata, Column("x", Integer), schema="foo")
Table("t3", metadata, Column("x", Integer), schema="bar")
- with config.db.connect().execution_options(
- schema_translate_map=map_
- ) as conn:
+ with config.db.begin() as conn:
+ conn = conn.execution_options(schema_translate_map=map_)
metadata.create_all(conn)
insp = inspect(config.db)
@@ -1101,9 +1096,8 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
is_true(insp.has_table("t2", schema=config.test_schema))
is_true(insp.has_table("t3", schema=None))
- with config.db.connect().execution_options(
- schema_translate_map=map_
- ) as conn:
+ with config.db.begin() as conn:
+ conn = conn.execution_options(schema_translate_map=map_)
metadata.drop_all(conn)
insp = inspect(config.db)
@@ -1127,7 +1121,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
with self.sql_execution_asserter(config.db) as asserter:
- with config.db.connect() as conn:
+ with config.db.begin() as conn:
execution_options = {"schema_translate_map": map_}
conn._execute_20(
@@ -1222,7 +1216,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
with self.sql_execution_asserter(config.db) as asserter:
- with config.db.connect().execution_options(
+ with config.db.begin() as conn, conn.execution_options(
schema_translate_map=map_
) as conn:
@@ -1790,6 +1784,7 @@ class EngineEventsTest(fixtures.TestBase):
else:
ctx = conn = engine.connect()
+ trans = conn.begin()
try:
m.create_all(conn, checkfirst=False)
try:
@@ -1801,8 +1796,7 @@ class EngineEventsTest(fixtures.TestBase):
)
finally:
m.drop_all(conn)
- if engine._is_future:
- conn.commit()
+ trans.commit()
finally:
if ctx:
ctx.close()
@@ -3046,7 +3040,7 @@ class DialectEventTest(fixtures.TestBase):
m1.do_execute_no_params.side_effect
) = mock_the_cursor
- with e.connect() as conn:
+ with e.begin() as conn:
yield conn, m1
def _assert(self, retval, m1, m2, mock_calls):
@@ -3244,59 +3238,6 @@ class DialectEventTest(fixtures.TestBase):
eq_(conn.info["boom"], "one")
-class AutocommitKeywordFixture(object):
- def _test_keyword(self, keyword, expected=True):
- dbapi = Mock(
- connect=Mock(
- return_value=Mock(
- cursor=Mock(return_value=Mock(description=()))
- )
- )
- )
- engine = engines.testing_engine(
- options={"_initialize": False, "pool_reset_on_return": None}
- )
- engine.dialect.dbapi = dbapi
-
- with engine.connect() as conn:
- conn.exec_driver_sql("%s something table something" % keyword)
-
- if expected:
- eq_(
- [n for (n, k, s) in dbapi.connect().mock_calls],
- ["cursor", "commit"],
- )
- else:
- eq_(
- [n for (n, k, s) in dbapi.connect().mock_calls], ["cursor"]
- )
-
-
-class AutocommitTextTest(AutocommitKeywordFixture, fixtures.TestBase):
- __backend__ = True
-
- def test_update(self):
- self._test_keyword("UPDATE")
-
- def test_insert(self):
- self._test_keyword("INSERT")
-
- def test_delete(self):
- self._test_keyword("DELETE")
-
- def test_alter(self):
- self._test_keyword("ALTER TABLE")
-
- def test_create(self):
- self._test_keyword("CREATE TABLE foobar")
-
- def test_drop(self):
- self._test_keyword("DROP TABLE foobar")
-
- def test_select(self):
- self._test_keyword("SELECT foo FROM table", False)
-
-
class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
__backend__ = True
@@ -3463,7 +3404,7 @@ class SetInputSizesTest(fixtures.TablesTest):
def test_set_input_sizes_no_event(self, input_sizes_fixture):
engine, canary = input_sizes_fixture
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(
self.tables.users.insert(),
[
@@ -3596,7 +3537,7 @@ class SetInputSizesTest(fixtures.TablesTest):
0,
)
- with engine.connect() as conn:
+ with engine.begin() as conn:
conn.execute(
self.tables.users.insert(),
[
diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py
index aa272c0cf..29b8132aa 100644
--- a/test/engine/test_logging.py
+++ b/test/engine/test_logging.py
@@ -22,7 +22,7 @@ from sqlalchemy.testing.util import lazy_gc
def exec_sql(engine, sql, *args, **kwargs):
- with engine.connect() as conn:
+ with engine.begin() as conn:
return conn.exec_driver_sql(sql, *args, **kwargs)
@@ -56,7 +56,7 @@ class LogParamsTest(fixtures.TestBase):
[{"data": str(i)} for i in range(100)],
)
eq_(
- self.buf.buffer[1].message,
+ self.buf.buffer[2].message,
"[raw sql] [{'data': '0'}, {'data': '1'}, {'data': '2'}, "
"{'data': '3'}, "
"{'data': '4'}, {'data': '5'}, {'data': '6'}, {'data': '7'}"
@@ -86,7 +86,7 @@ class LogParamsTest(fixtures.TestBase):
[{"data": str(i)} for i in range(100)],
)
eq_(
- self.buf.buffer[1].message,
+ self.buf.buffer[2].message,
"[raw sql] [SQL parameters hidden due to hide_parameters=True]",
)
@@ -97,7 +97,7 @@ class LogParamsTest(fixtures.TestBase):
[(str(i),) for i in range(100)],
)
eq_(
- self.buf.buffer[1].message,
+ self.buf.buffer[2].message,
"[raw sql] [('0',), ('1',), ('2',), ('3',), ('4',), ('5',), "
"('6',), ('7',) ... displaying 10 of 100 total "
"bound parameter sets ... ('98',), ('99',)]",
@@ -227,7 +227,7 @@ class LogParamsTest(fixtures.TestBase):
exec_sql(self.eng, "INSERT INTO foo (data) values (?)", (largeparam,))
eq_(
- self.buf.buffer[1].message,
+ self.buf.buffer[2].message,
"[raw sql] ('%s ... (4702 characters truncated) ... %s',)"
% (largeparam[0:149], largeparam[-149:]),
)
@@ -242,7 +242,7 @@ class LogParamsTest(fixtures.TestBase):
exec_sql(self.eng, "SELECT ?, ?, ?", (lp1, lp2, lp3))
eq_(
- self.buf.buffer[1].message,
+ self.buf.buffer[2].message,
"[raw sql] ('%s', '%s', '%s ... (372 characters truncated) "
"... %s')" % (lp1, lp2, lp3[0:149], lp3[-149:]),
)
@@ -261,7 +261,7 @@ class LogParamsTest(fixtures.TestBase):
)
eq_(
- self.buf.buffer[1].message,
+ self.buf.buffer[2].message,
"[raw sql] [('%s ... (4702 characters truncated) ... %s',), "
"('%s',), "
"('%s ... (372 characters truncated) ... %s',)]"
@@ -347,20 +347,20 @@ class LogParamsTest(fixtures.TestBase):
row = result.first()
eq_(
- self.buf.buffer[1].message,
+ self.buf.buffer[2].message,
"[raw sql] ('%s ... (4702 characters truncated) ... %s',)"
% (largeparam[0:149], largeparam[-149:]),
)
if util.py3k:
eq_(
- self.buf.buffer[3].message,
+ self.buf.buffer[5].message,
"Row ('%s ... (4702 characters truncated) ... %s',)"
% (largeparam[0:149], largeparam[-149:]),
)
else:
eq_(
- self.buf.buffer[3].message,
+ self.buf.buffer[5].message,
"Row (u'%s ... (4703 characters truncated) ... %s',)"
% (largeparam[0:148], largeparam[-149:]),
)
@@ -495,7 +495,8 @@ class LoggingNameTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
def _assert_names_in_execute(self, eng, eng_name, pool_name):
- eng.execute(select(1))
+ with eng.connect() as conn:
+ conn.execute(select(1))
assert self.buf.buffer
for name in [b.name for b in self.buf.buffer]:
assert name in (
@@ -505,7 +506,8 @@ class LoggingNameTest(fixtures.TestBase):
)
def _assert_no_name_in_execute(self, eng):
- eng.execute(select(1))
+ with eng.connect() as conn:
+ conn.execute(select(1))
assert self.buf.buffer
for name in [b.name for b in self.buf.buffer]:
assert name in (
@@ -548,7 +550,8 @@ class LoggingNameTest(fixtures.TestBase):
def test_named_logger_names_after_dispose(self):
eng = self._named_engine()
- eng.execute(select(1))
+ with eng.connect() as conn:
+ conn.execute(select(1))
eng.dispose()
eq_(eng.logging_name, "myenginename")
eq_(eng.pool.logging_name, "mypoolname")
@@ -568,7 +571,8 @@ class LoggingNameTest(fixtures.TestBase):
def test_named_logger_execute_after_dispose(self):
eng = self._named_engine()
- eng.execute(select(1))
+ with eng.connect() as conn:
+ conn.execute(select(1))
eng.dispose()
self._assert_names_in_execute(eng, "myenginename", "mypoolname")
@@ -599,7 +603,8 @@ class EchoTest(fixtures.TestBase):
# do an initial execute to clear out 'first connect'
# messages
- e.execute(select(10)).close()
+ with e.connect() as conn:
+ conn.execute(select(10)).close()
self.buf.flush()
return e
@@ -637,16 +642,25 @@ class EchoTest(fixtures.TestBase):
e2 = self._testing_engine()
e1.echo = True
- e1.execute(select(1)).close()
- e2.execute(select(2)).close()
+
+ with e1.connect() as conn:
+ conn.execute(select(1)).close()
+
+ with e2.connect() as conn:
+ conn.execute(select(2)).close()
e1.echo = False
- e1.execute(select(3)).close()
- e2.execute(select(4)).close()
+
+ with e1.connect() as conn:
+ conn.execute(select(3)).close()
+ with e2.connect() as conn:
+ conn.execute(select(4)).close()
e2.echo = True
- e1.execute(select(5)).close()
- e2.execute(select(6)).close()
+ with e1.connect() as conn:
+ conn.execute(select(5)).close()
+ with e2.connect() as conn:
+ conn.execute(select(6)).close()
assert self.buf.buffer[0].getMessage().startswith("SELECT 1")
assert self.buf.buffer[2].getMessage().startswith("SELECT 6")
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 0dc35f99e..ebdaa79a0 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -1340,20 +1340,24 @@ class InvalidateDuringResultTest(fixtures.TestBase):
def setup(self):
self.engine = engines.reconnecting_engine()
- self.meta = MetaData(self.engine)
+ self.meta = MetaData()
table = Table(
"sometable",
self.meta,
Column("id", Integer, primary_key=True),
Column("name", String(50)),
)
- self.meta.create_all()
- table.insert().execute(
- [{"id": i, "name": "row %d" % i} for i in range(1, 100)]
- )
+
+ with self.engine.begin() as conn:
+ self.meta.create_all(conn)
+ conn.execute(
+ table.insert(),
+ [{"id": i, "name": "row %d" % i} for i in range(1, 100)],
+ )
def teardown(self):
- self.meta.drop_all()
+ with self.engine.begin() as conn:
+ self.meta.drop_all(conn)
self.engine.dispose()
@testing.crashes(
diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py
index b19836c84..48b6c40d7 100644
--- a/test/engine/test_reflection.py
+++ b/test/engine/test_reflection.py
@@ -2016,7 +2016,7 @@ def createIndexes(con, schema=None):
@testing.requires.views
def _create_views(con, schema=None):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
for table_name in ("users", "email_addresses"):
fullname = table_name
if schema:
@@ -2031,7 +2031,7 @@ def _create_views(con, schema=None):
@testing.requires.views
def _drop_views(con, schema=None):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
for table_name in ("email_addresses", "users"):
fullname = table_name
if schema:
@@ -2047,7 +2047,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
@testing.requires.denormalized_names
def setup(self):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
CREATE TABLE weird_casing(
@@ -2060,7 +2060,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
@testing.requires.denormalized_names
def teardown(self):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.exec_driver_sql("drop table weird_casing")
@testing.requires.denormalized_names
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py
index d0774e846..4db5a745a 100644
--- a/test/engine/test_transaction.py
+++ b/test/engine/test_transaction.py
@@ -5,20 +5,16 @@ from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
from sqlalchemy import INT
-from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import pool as _pool
from sqlalchemy import select
-from sqlalchemy import String
from sqlalchemy import testing
-from sqlalchemy import text
from sqlalchemy import util
from sqlalchemy import VARCHAR
from sqlalchemy.engine import base
from sqlalchemy.engine import characteristics
from sqlalchemy.engine import default
from sqlalchemy.engine import url
-from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
@@ -29,31 +25,19 @@ from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-users, metadata = None, None
-
-class TransactionTest(fixtures.TestBase):
+class TransactionTest(fixtures.TablesTest):
__backend__ = True
@classmethod
- def setup_class(cls):
- global users, metadata
- metadata = MetaData()
- users = Table(
- "query_users",
+ def define_tables(cls, metadata):
+ Table(
+ "users",
metadata,
Column("user_id", INT, primary_key=True),
Column("user_name", VARCHAR(20)),
test_needs_acid=True,
)
- users.create(testing.db)
-
- def teardown(self):
- testing.db.execute(users.delete()).close()
-
- @classmethod
- def teardown_class(cls):
- users.drop(testing.db)
@testing.fixture
def local_connection(self):
@@ -61,6 +45,7 @@ class TransactionTest(fixtures.TestBase):
yield conn
def test_commits(self, local_connection):
+ users = self.tables.users
connection = local_connection
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
@@ -72,7 +57,7 @@ class TransactionTest(fixtures.TestBase):
transaction.commit()
transaction = connection.begin()
- result = connection.exec_driver_sql("select * from query_users")
+ result = connection.exec_driver_sql("select * from users")
assert len(result.fetchall()) == 3
transaction.commit()
connection.close()
@@ -80,17 +65,19 @@ class TransactionTest(fixtures.TestBase):
def test_rollback(self, local_connection):
"""test a basic rollback"""
+ users = self.tables.users
connection = local_connection
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
connection.execute(users.insert(), user_id=2, user_name="user2")
connection.execute(users.insert(), user_id=3, user_name="user3")
transaction.rollback()
- result = connection.exec_driver_sql("select * from query_users")
+ result = connection.exec_driver_sql("select * from users")
assert len(result.fetchall()) == 0
def test_raise(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
try:
@@ -103,11 +90,12 @@ class TransactionTest(fixtures.TestBase):
print("Exception: ", e)
transaction.rollback()
- result = connection.exec_driver_sql("select * from query_users")
+ result = connection.exec_driver_sql("select * from users")
assert len(result.fetchall()) == 0
def test_nested_rollback(self, local_connection):
connection = local_connection
+ users = self.tables.users
try:
transaction = connection.begin()
try:
@@ -146,6 +134,7 @@ class TransactionTest(fixtures.TestBase):
def test_branch_nested_rollback(self, local_connection):
connection = local_connection
+ users = self.tables.users
connection.begin()
branched = connection.connect()
assert branched.in_transaction()
@@ -179,6 +168,7 @@ class TransactionTest(fixtures.TestBase):
@testing.requires.savepoints
def test_savepoint_cancelled_by_toplevel_marker(self, local_connection):
conn = local_connection
+ users = self.tables.users
trans = conn.begin()
conn.execute(users.insert(), {"user_id": 1, "user_name": "name"})
@@ -245,85 +235,6 @@ class TransactionTest(fixtures.TestBase):
nested.commit,
)
- def test_branch_autorollback(self, local_connection):
- connection = local_connection
- branched = connection.connect()
- branched.execute(users.insert(), dict(user_id=1, user_name="user1"))
- assert_raises(
- exc.DBAPIError,
- branched.execute,
- users.insert(),
- dict(user_id=1, user_name="user1"),
- )
- # can continue w/o issue
- branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
-
- def test_branch_orig_rollback(self, local_connection):
- connection = local_connection
- branched = connection.connect()
- branched.execute(users.insert(), dict(user_id=1, user_name="user1"))
- nested = branched.begin()
- assert branched.in_transaction()
- branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
- nested.rollback()
- eq_(
- connection.exec_driver_sql(
- "select count(*) from query_users"
- ).scalar(),
- 1,
- )
-
- @testing.requires.independent_connections
- def test_branch_autocommit(self, local_connection):
- with testing.db.connect() as connection:
- branched = connection.connect()
- branched.execute(
- users.insert(), dict(user_id=1, user_name="user1")
- )
-
- eq_(
- local_connection.execute(
- text("select count(*) from query_users")
- ).scalar(),
- 1,
- )
-
- @testing.requires.savepoints
- def test_branch_savepoint_rollback(self, local_connection):
- connection = local_connection
- trans = connection.begin()
- branched = connection.connect()
- assert branched.in_transaction()
- branched.execute(users.insert(), user_id=1, user_name="user1")
- nested = branched.begin_nested()
- branched.execute(users.insert(), user_id=2, user_name="user2")
- nested.rollback()
- assert connection.in_transaction()
- trans.commit()
- eq_(
- connection.exec_driver_sql(
- "select count(*) from query_users"
- ).scalar(),
- 1,
- )
-
- @testing.requires.two_phase_transactions
- def test_branch_twophase_rollback(self, local_connection):
- connection = local_connection
- branched = connection.connect()
- assert not branched.in_transaction()
- branched.execute(users.insert(), user_id=1, user_name="user1")
- nested = branched.begin_twophase()
- branched.execute(users.insert(), user_id=2, user_name="user2")
- nested.rollback()
- assert not connection.in_transaction()
- eq_(
- connection.exec_driver_sql(
- "select count(*) from query_users"
- ).scalar(),
- 1,
- )
-
def test_deactivated_warning_ctxmanager(self, local_connection):
with expect_warnings(
"transaction already deassociated from connection"
@@ -472,20 +383,20 @@ class TransactionTest(fixtures.TestBase):
def test_retains_through_options(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
conn2 = connection.execution_options(dummy=True)
conn2.execute(users.insert(), user_id=2, user_name="user2")
transaction.rollback()
eq_(
- connection.exec_driver_sql(
- "select count(*) from query_users"
- ).scalar(),
+ connection.exec_driver_sql("select count(*) from users").scalar(),
0,
)
def test_nesting(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -497,15 +408,16 @@ class TransactionTest(fixtures.TestBase):
transaction.rollback()
self.assert_(
connection.exec_driver_sql(
- "select count(*) from " "query_users"
+ "select count(*) from " "users"
).scalar()
== 0
)
- result = connection.exec_driver_sql("select * from query_users")
+ result = connection.exec_driver_sql("select * from users")
assert len(result.fetchall()) == 0
def test_with_interface(self, local_connection):
connection = local_connection
+ users = self.tables.users
trans = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -517,7 +429,7 @@ class TransactionTest(fixtures.TestBase):
assert not trans.is_active
self.assert_(
connection.exec_driver_sql(
- "select count(*) from " "query_users"
+ "select count(*) from " "users"
).scalar()
== 0
)
@@ -528,13 +440,14 @@ class TransactionTest(fixtures.TestBase):
assert not trans.is_active
self.assert_(
connection.exec_driver_sql(
- "select count(*) from " "query_users"
+ "select count(*) from " "users"
).scalar()
== 1
)
def test_close(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -549,15 +462,16 @@ class TransactionTest(fixtures.TestBase):
assert not connection.in_transaction()
self.assert_(
connection.exec_driver_sql(
- "select count(*) from " "query_users"
+ "select count(*) from " "users"
).scalar()
== 5
)
- result = connection.exec_driver_sql("select * from query_users")
+ result = connection.exec_driver_sql("select * from users")
assert len(result.fetchall()) == 5
def test_close2(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -572,16 +486,17 @@ class TransactionTest(fixtures.TestBase):
assert not connection.in_transaction()
self.assert_(
connection.exec_driver_sql(
- "select count(*) from " "query_users"
+ "select count(*) from " "users"
).scalar()
== 0
)
- result = connection.exec_driver_sql("select * from query_users")
+ result = connection.exec_driver_sql("select * from users")
assert len(result.fetchall()) == 0
@testing.requires.savepoints
def test_nested_subtransaction_rollback(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
trans2 = connection.begin_nested()
@@ -599,6 +514,7 @@ class TransactionTest(fixtures.TestBase):
@testing.requires.savepoints
def test_nested_subtransaction_commit(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
trans2 = connection.begin_nested()
@@ -616,6 +532,7 @@ class TransactionTest(fixtures.TestBase):
@testing.requires.savepoints
def test_rollback_to_subtransaction(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin()
connection.execute(users.insert(), user_id=1, user_name="user1")
trans2 = connection.begin_nested()
@@ -646,6 +563,7 @@ class TransactionTest(fixtures.TestBase):
@testing.requires.two_phase_transactions
def test_two_phase_transaction(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin_twophase()
connection.execute(users.insert(), user_id=1, user_name="user1")
transaction.prepare()
@@ -680,6 +598,7 @@ class TransactionTest(fixtures.TestBase):
@testing.requires.savepoints
def test_mixed_two_phase_transaction(self, local_connection):
connection = local_connection
+ users = self.tables.users
transaction = connection.begin_twophase()
connection.execute(users.insert(), user_id=1, user_name="user1")
transaction2 = connection.begin()
@@ -704,6 +623,7 @@ class TransactionTest(fixtures.TestBase):
@testing.requires.two_phase_transactions
@testing.requires.two_phase_recovery
def test_two_phase_recover(self):
+ users = self.tables.users
# 2020, still can't get this to work w/ modern MySQL or MariaDB.
# the XA RECOVER comes back as bytes, OK, convert to string,
@@ -722,11 +642,14 @@ class TransactionTest(fixtures.TestBase):
with testing.db.connect() as connection2:
eq_(
- connection2.execution_options(autocommit=True)
- .execute(select(users.c.user_id).order_by(users.c.user_id))
- .fetchall(),
+ connection2.execute(
+ select(users.c.user_id).order_by(users.c.user_id)
+ ).fetchall(),
[],
)
+
+ # recover_twophase needs to be run in a new transaction
+ with testing.db.connect() as connection2:
recoverables = connection2.recover_twophase()
assert transaction.xid in recoverables
connection2.commit_prepared(transaction.xid, recover=True)
@@ -740,6 +663,7 @@ class TransactionTest(fixtures.TestBase):
@testing.requires.two_phase_transactions
def test_multiple_two_phase(self, local_connection):
conn = local_connection
+ users = self.tables.users
xa = conn.begin_twophase()
conn.execute(users.insert(), user_id=1, user_name="user1")
xa.prepare()
@@ -767,6 +691,7 @@ class TransactionTest(fixtures.TestBase):
# so that picky backends like MySQL correctly clear out
# their state when a connection is closed without handling
# the transaction explicitly.
+ users = self.tables.users
eng = testing_engine()
@@ -1005,7 +930,8 @@ class AutoRollbackTest(fixtures.TestBase):
Column("user_name", VARCHAR(20)),
test_needs_acid=True,
)
- users.create(conn1)
+ with conn1.begin():
+ users.create(conn1)
conn1.exec_driver_sql("select * from deadlock_users")
conn1.close()
@@ -1014,125 +940,8 @@ class AutoRollbackTest(fixtures.TestBase):
# pool but still has a lock on "deadlock_users". comment out the
# rollback in pool/ConnectionFairy._close() to see !
- users.drop(conn2)
- conn2.close()
-
-
-class ExplicitAutoCommitTest(fixtures.TestBase):
-
- """test the 'autocommit' flag on select() and text() objects.
-
- Requires PostgreSQL so that we may define a custom function which
- modifies the database."""
-
- __only_on__ = "postgresql"
-
- @classmethod
- def setup_class(cls):
- global metadata, foo
- metadata = MetaData(testing.db)
- foo = Table(
- "foo",
- metadata,
- Column("id", Integer, primary_key=True),
- Column("data", String(100)),
- )
- with testing.db.connect() as conn:
- metadata.create_all(conn)
- conn.exec_driver_sql(
- "create function insert_foo(varchar) "
- "returns integer as 'insert into foo(data) "
- "values ($1);select 1;' language sql"
- )
-
- def teardown(self):
- with testing.db.connect() as conn:
- conn.execute(foo.delete())
-
- @classmethod
- def teardown_class(cls):
- with testing.db.connect() as conn:
- conn.exec_driver_sql("drop function insert_foo(varchar)")
- metadata.drop_all(conn)
-
- def test_control(self):
-
- # test that not using autocommit does not commit
-
- conn1 = testing.db.connect()
- conn2 = testing.db.connect()
- conn1.execute(select(func.insert_foo("data1")))
- assert conn2.execute(select(foo.c.data)).fetchall() == []
- conn1.execute(text("select insert_foo('moredata')"))
- assert conn2.execute(select(foo.c.data)).fetchall() == []
- trans = conn1.begin()
- trans.commit()
- assert conn2.execute(select(foo.c.data)).fetchall() == [
- ("data1",),
- ("moredata",),
- ]
- conn1.close()
- conn2.close()
-
- def test_explicit_compiled(self):
- conn1 = testing.db.connect()
- conn2 = testing.db.connect()
- conn1.execute(
- select(func.insert_foo("data1")).execution_options(autocommit=True)
- )
- assert conn2.execute(select(foo.c.data)).fetchall() == [("data1",)]
- conn1.close()
- conn2.close()
-
- def test_explicit_connection(self):
- conn1 = testing.db.connect()
- conn2 = testing.db.connect()
- conn1.execution_options(autocommit=True).execute(
- select(func.insert_foo("data1"))
- )
- eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
-
- # connection supersedes statement
-
- conn1.execution_options(autocommit=False).execute(
- select(func.insert_foo("data2")).execution_options(autocommit=True)
- )
- eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
-
- # ditto
-
- conn1.execution_options(autocommit=True).execute(
- select(func.insert_foo("data3")).execution_options(
- autocommit=False
- )
- )
- eq_(
- conn2.execute(select(foo.c.data)).fetchall(),
- [("data1",), ("data2",), ("data3",)],
- )
- conn1.close()
- conn2.close()
-
- def test_explicit_text(self):
- conn1 = testing.db.connect()
- conn2 = testing.db.connect()
- conn1.execute(
- text("select insert_foo('moredata')").execution_options(
- autocommit=True
- )
- )
- assert conn2.execute(select(foo.c.data)).fetchall() == [("moredata",)]
- conn1.close()
- conn2.close()
-
- def test_implicit_text(self):
- conn1 = testing.db.connect()
- conn2 = testing.db.connect()
- conn1.execute(text("insert into foo (data) values ('implicitdata')"))
- assert conn2.execute(select(foo.c.data)).fetchall() == [
- ("implicitdata",)
- ]
- conn1.close()
+ with conn2.begin():
+ users.drop(conn2)
conn2.close()
diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py
index 3cb29c67d..df27c8d27 100644
--- a/test/ext/test_associationproxy.py
+++ b/test/ext/test_associationproxy.py
@@ -1329,10 +1329,13 @@ class KVChild(object):
self.value = value
-class ReconstitutionTest(fixtures.TestBase):
- def setup(self):
- metadata = MetaData(testing.db)
- parents = Table(
+class ReconstitutionTest(fixtures.MappedTest):
+ run_setup_mappers = "each"
+ run_setup_classes = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
"parents",
metadata,
Column(
@@ -1340,7 +1343,7 @@ class ReconstitutionTest(fixtures.TestBase):
),
Column("name", String(30)),
)
- children = Table(
+ Table(
"children",
metadata,
Column(
@@ -1349,22 +1352,23 @@ class ReconstitutionTest(fixtures.TestBase):
Column("parent_id", Integer, ForeignKey("parents.id")),
Column("name", String(30)),
)
- metadata.create_all()
- parents.insert().execute(name="p1")
- self.metadata = metadata
- self.parents = parents
- self.children = children
- Parent.kids = association_proxy("children", "name")
- def teardown(self):
- self.metadata.drop_all()
- clear_mappers()
+ @classmethod
+ def insert_data(cls, connection):
+ parents = cls.tables.parents
+ connection.execute(parents.insert(), dict(name="p1"))
+
+ @classmethod
+ def setup_classes(cls):
+ Parent.kids = association_proxy("children", "name")
def test_weak_identity_map(self):
mapper(
- Parent, self.parents, properties=dict(children=relationship(Child))
+ Parent,
+ self.tables.parents,
+ properties=dict(children=relationship(Child)),
)
- mapper(Child, self.children)
+ mapper(Child, self.tables.children)
session = create_session()
def add_child(parent_name, child_name):
@@ -1380,9 +1384,11 @@ class ReconstitutionTest(fixtures.TestBase):
def test_copy(self):
mapper(
- Parent, self.parents, properties=dict(children=relationship(Child))
+ Parent,
+ self.tables.parents,
+ properties=dict(children=relationship(Child)),
)
- mapper(Child, self.children)
+ mapper(Child, self.tables.children)
p = Parent("p1")
p.kids.extend(["c1", "c2"])
p_copy = copy.copy(p)
@@ -1392,9 +1398,11 @@ class ReconstitutionTest(fixtures.TestBase):
def test_pickle_list(self):
mapper(
- Parent, self.parents, properties=dict(children=relationship(Child))
+ Parent,
+ self.tables.parents,
+ properties=dict(children=relationship(Child)),
)
- mapper(Child, self.children)
+ mapper(Child, self.tables.children)
p = Parent("p1")
p.kids.extend(["c1", "c2"])
r1 = pickle.loads(pickle.dumps(p))
@@ -1407,12 +1415,12 @@ class ReconstitutionTest(fixtures.TestBase):
def test_pickle_set(self):
mapper(
Parent,
- self.parents,
+ self.tables.parents,
properties=dict(
children=relationship(Child, collection_class=set)
),
)
- mapper(Child, self.children)
+ mapper(Child, self.tables.children)
p = Parent("p1")
p.kids.update(["c1", "c2"])
r1 = pickle.loads(pickle.dumps(p))
@@ -1425,7 +1433,7 @@ class ReconstitutionTest(fixtures.TestBase):
def test_pickle_dict(self):
mapper(
Parent,
- self.parents,
+ self.tables.parents,
properties=dict(
children=relationship(
KVChild,
@@ -1435,7 +1443,7 @@ class ReconstitutionTest(fixtures.TestBase):
)
),
)
- mapper(KVChild, self.children)
+ mapper(KVChild, self.tables.children)
p = Parent("p1")
p.kids.update({"c1": "v1", "c2": "v2"})
assert p.kids == {"c1": "c1", "c2": "c2"}
diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py
index a8c17d7ac..e46c65ff0 100644
--- a/test/ext/test_horizontal_shard.py
+++ b/test/ext/test_horizontal_shard.py
@@ -53,10 +53,10 @@ class ShardTest(object):
def id_generator(ctx):
# in reality, might want to use a separate transaction for this.
- c = db1.connect()
- nextid = c.execute(ids.select().with_for_update()).scalar()
- c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1}))
- return nextid
+ with db1.begin() as c:
+ nextid = c.execute(ids.select().with_for_update()).scalar()
+ c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1}))
+ return nextid
weather_locations = Table(
"weather_locations",
@@ -80,7 +80,8 @@ class ShardTest(object):
for db in (db1, db2, db3, db4):
meta.create_all(db)
- db1.execute(ids.insert(), nextid=1)
+ with db1.begin() as conn:
+ conn.execute(ids.insert(), dict(nextid=1))
self.setup_session()
self.setup_mappers()
@@ -762,7 +763,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
)
e2 = testing_engine()
- with e2.connect() as conn:
+ with e2.begin() as conn:
for i in [2, 4]:
conn.exec_driver_sql(
"CREATE SCHEMA IF NOT EXISTS shard%s" % (i,)
@@ -784,7 +785,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
for i in [1, 3]:
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
- with self.postgresql_engine.connect() as conn:
+ with self.postgresql_engine.begin() as conn:
self.metadata.drop_all(conn)
for i in [2, 4]:
conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))
diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py
index c9a78db08..dab184194 100644
--- a/test/orm/inheritance/test_selects.py
+++ b/test/orm/inheritance/test_selects.py
@@ -2,7 +2,6 @@ from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
-from sqlalchemy import testing
from sqlalchemy.orm import mapper
from sqlalchemy.orm import Session
from sqlalchemy.testing import eq_
@@ -24,13 +23,13 @@ class InheritingSelectablesTest(fixtures.MappedTest):
cls.tables.bar = foo.select(foo.c.b == "bar").alias("bar")
cls.tables.baz = foo.select(foo.c.b == "baz").alias("baz")
- def test_load(self):
+ def test_load(self, connection):
foo, bar, baz = self.tables.foo, self.tables.bar, self.tables.baz
# TODO: add persistence test also
- testing.db.execute(foo.insert(), a="not bar", b="baz")
- testing.db.execute(foo.insert(), a="also not bar", b="baz")
- testing.db.execute(foo.insert(), a="i am bar", b="bar")
- testing.db.execute(foo.insert(), a="also bar", b="bar")
+ connection.execute(foo.insert(), dict(a="not bar", b="baz"))
+ connection.execute(foo.insert(), dict(a="also not bar", b="baz"))
+ connection.execute(foo.insert(), dict(a="i am bar", b="bar"))
+ connection.execute(foo.insert(), dict(a="also bar", b="bar"))
class Foo(fixtures.ComparableEntity):
pass
@@ -69,8 +68,8 @@ class InheritingSelectablesTest(fixtures.MappedTest):
polymorphic_identity="bar",
)
- s = Session()
- assert [Bar(), Bar()] == s.query(Bar).all()
+ s = Session(connection)
+ eq_(s.query(Bar).all(), [Bar(), Bar()])
class JoinFromSelectPersistenceTest(fixtures.MappedTest):
diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py
index 3a9959857..64f85b335 100644
--- a/test/orm/test_bind.py
+++ b/test/orm/test_bind.py
@@ -151,7 +151,7 @@ class BindIntegrationTest(_fixtures.FixtureTest):
mapper(User, users)
- session = create_session()
+ session = Session()
session.execute(users.insert(), dict(name="Johnny"))
@@ -447,7 +447,9 @@ class BindIntegrationTest(_fixtures.FixtureTest):
sess.commit()
assert not c.in_transaction()
assert c.exec_driver_sql("select count(1) from users").scalar() == 1
- c.exec_driver_sql("delete from users")
+
+ with c.begin():
+ c.exec_driver_sql("delete from users")
assert c.exec_driver_sql("select count(1) from users").scalar() == 0
c = testing.db.connect()
diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py
index dcf07eec8..c6a1226d4 100644
--- a/test/orm/test_compile.py
+++ b/test/orm/test_compile.py
@@ -190,8 +190,9 @@ class CompileTest(fixtures.ORMTest):
sa_exc.ArgumentError, "Error creating backref", configure_mappers
)
- def test_misc_one(self):
- metadata = MetaData(testing.db)
+ @testing.provide_metadata
+ def test_misc_one(self, connection):
+ metadata = self.metadata
node_table = Table(
"node",
metadata,
@@ -212,33 +213,30 @@ class CompileTest(fixtures.ORMTest):
Column("host_id", Integer, primary_key=True),
Column("hostname", String(64), nullable=False, unique=True),
)
- metadata.create_all()
- try:
- node_table.insert().execute(node_id=1, node_index=5)
-
- class Node(object):
- pass
-
- class NodeName(object):
- pass
-
- class Host(object):
- pass
-
- mapper(Node, node_table)
- mapper(Host, host_table)
- mapper(
- NodeName,
- node_name_table,
- properties={
- "node": relationship(Node, backref=backref("names")),
- "host": relationship(Host),
- },
- )
- sess = create_session()
- assert sess.query(Node).get(1).names == []
- finally:
- metadata.drop_all()
+ metadata.create_all(connection)
+ connection.execute(node_table.insert(), dict(node_id=1, node_index=5))
+
+ class Node(object):
+ pass
+
+ class NodeName(object):
+ pass
+
+ class Host(object):
+ pass
+
+ mapper(Node, node_table)
+ mapper(Host, host_table)
+ mapper(
+ NodeName,
+ node_name_table,
+ properties={
+ "node": relationship(Node, backref=backref("names")),
+ "host": relationship(Host),
+ },
+ )
+ sess = create_session(connection)
+ assert sess.query(Node).get(1).names == []
def test_conflicting_backref_two(self):
meta = MetaData()
diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py
index 57225d640..7bc82b2a3 100644
--- a/test/orm/test_eager_relations.py
+++ b/test/orm/test_eager_relations.py
@@ -4808,6 +4808,8 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
class SubqueryTest(fixtures.MappedTest):
+ run_deletes = "each"
+
@classmethod
def define_tables(cls, metadata):
Table(
@@ -4830,7 +4832,12 @@ class SubqueryTest(fixtures.MappedTest):
Column("score2", sa.Float),
)
- def test_label_anonymizing(self):
+ @testing.combinations(
+ (True, "score"),
+ (True, None),
+ (False, None),
+ )
+ def test_label_anonymizing(self, labeled, labelname):
"""Eager loading works with subqueries with labels,
Even if an explicit labelname which conflicts with a label on the
@@ -4859,75 +4866,65 @@ class SubqueryTest(fixtures.MappedTest):
def prop_score(self):
return self.score1 * self.score2
- for labeled, labelname in [
- (True, "score"),
- (True, None),
- (False, None),
- ]:
- sa.orm.clear_mappers()
-
- tag_score = tags_table.c.score1 * tags_table.c.score2
- user_score = sa.select(
- sa.func.sum(tags_table.c.score1 * tags_table.c.score2)
- ).where(
- tags_table.c.user_id == users_table.c.id,
- )
+ tag_score = tags_table.c.score1 * tags_table.c.score2
+ user_score = sa.select(
+ sa.func.sum(tags_table.c.score1 * tags_table.c.score2)
+ ).where(
+ tags_table.c.user_id == users_table.c.id,
+ )
- if labeled:
- tag_score = tag_score.label(labelname)
- user_score = user_score.label(labelname)
- else:
- user_score = user_score.scalar_subquery()
+ if labeled:
+ tag_score = tag_score.label(labelname)
+ user_score = user_score.label(labelname)
+ else:
+ user_score = user_score.scalar_subquery()
- mapper(
- Tag,
- tags_table,
- properties={"query_score": sa.orm.column_property(tag_score)},
- )
+ mapper(
+ Tag,
+ tags_table,
+ properties={"query_score": sa.orm.column_property(tag_score)},
+ )
- mapper(
- User,
- users_table,
- properties={
- "tags": relationship(Tag, backref="user", lazy="joined"),
- "query_score": sa.orm.column_property(user_score),
- },
- )
+ mapper(
+ User,
+ users_table,
+ properties={
+ "tags": relationship(Tag, backref="user", lazy="joined"),
+ "query_score": sa.orm.column_property(user_score),
+ },
+ )
- session = create_session()
- session.add(
- User(
- name="joe",
- tags=[
- Tag(score1=5.0, score2=3.0),
- Tag(score1=55.0, score2=1.0),
- ],
- )
+ session = create_session()
+ session.add(
+ User(
+ name="joe",
+ tags=[
+ Tag(score1=5.0, score2=3.0),
+ Tag(score1=55.0, score2=1.0),
+ ],
)
- session.add(
- User(
- name="bar",
- tags=[
- Tag(score1=5.0, score2=4.0),
- Tag(score1=50.0, score2=1.0),
- Tag(score1=15.0, score2=2.0),
- ],
- )
+ )
+ session.add(
+ User(
+ name="bar",
+ tags=[
+ Tag(score1=5.0, score2=4.0),
+ Tag(score1=50.0, score2=1.0),
+ Tag(score1=15.0, score2=2.0),
+ ],
)
- session.flush()
- session.expunge_all()
-
- for user in session.query(User).all():
- eq_(user.query_score, user.prop_score)
+ )
+ session.flush()
+ session.expunge_all()
- def go():
- u = session.query(User).filter_by(name="joe").one()
- eq_(u.query_score, u.prop_score)
+ for user in session.query(User).all():
+ eq_(user.query_score, user.prop_score)
- self.assert_sql_count(testing.db, go, 1)
+ def go():
+ u = session.query(User).filter_by(name="joe").one()
+ eq_(u.query_score, u.prop_score)
- for t in (tags_table, users_table):
- t.delete().execute()
+ self.assert_sql_count(testing.db, go, 1)
class CorrelatedSubqueryTest(fixtures.MappedTest):
diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py
index 7ccf2c1ae..5abaa03db 100644
--- a/test/orm/test_expire.py
+++ b/test/orm/test_expire.py
@@ -9,7 +9,6 @@ from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import attributes
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import defer
from sqlalchemy.orm import deferred
from sqlalchemy.orm import exc as orm_exc
@@ -26,6 +25,7 @@ from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import create_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import gc_collect
@@ -66,7 +66,7 @@ class ExpireTest(_fixtures.FixtureTest):
u.name = "foo"
sess.flush()
# change the value in the DB
- users.update(users.c.id == 7, values=dict(name="jack")).execute()
+ sess.execute(users.update(users.c.id == 7, values=dict(name="jack")))
sess.expire(u)
# object isn't refreshed yet, using dict to bypass trigger
assert u.__dict__.get("name") != "jack"
@@ -471,7 +471,7 @@ class ExpireTest(_fixtures.FixtureTest):
o = sess.query(Order).get(3)
sess.expire(o)
- orders.update().execute(description="order 3 modified")
+ sess.execute(orders.update(), dict(description="order 3 modified"))
assert o.isopen == 1
assert (
attributes.instance_state(o).dict["description"]
@@ -788,7 +788,7 @@ class ExpireTest(_fixtures.FixtureTest):
sess.expire(u)
assert "name" not in u.__dict__
- users.update(users.c.id == 7).execute(name="jack2")
+ sess.execute(users.update(users.c.id == 7), dict(name="jack2"))
assert u.name == "jack2"
assert u.uname == "jack2"
assert "name" in u.__dict__
@@ -812,7 +812,10 @@ class ExpireTest(_fixtures.FixtureTest):
assert "description" not in o.__dict__
assert attributes.instance_state(o).dict["isopen"] == 1
- orders.update(orders.c.id == 3).execute(description="order 3 modified")
+ sess.execute(
+ orders.update(orders.c.id == 3),
+ dict(description="order 3 modified"),
+ )
def go():
assert o.description == "order 3 modified"
@@ -1660,12 +1663,9 @@ class LifecycleTest(fixtures.MappedTest):
def test_cols_missing_in_load(self):
Data = self.classes.Data
- sess = create_session()
-
- d1 = Data(data="d1")
- sess.add(d1)
- sess.flush()
- sess.close()
+ with Session(testing.db) as sess, sess.begin():
+ d1 = Data(data="d1")
+ sess.add(d1)
sess = create_session()
d1 = sess.query(Data).from_statement(select(Data.id)).first()
@@ -1679,21 +1679,18 @@ class LifecycleTest(fixtures.MappedTest):
def test_deferred_cols_missing_in_load_state_reset(self):
Data = self.classes.DataDefer
- sess = create_session()
+ with Session(testing.db) as sess, sess.begin():
+ d1 = Data(data="d1")
+ sess.add(d1)
- d1 = Data(data="d1")
- sess.add(d1)
- sess.flush()
- sess.close()
-
- sess = create_session()
- d1 = (
- sess.query(Data)
- .from_statement(select(Data.id))
- .options(undefer(Data.data))
- .first()
- )
- d1.data = "d2"
+ with Session(testing.db) as sess:
+ d1 = (
+ sess.query(Data)
+ .from_statement(select(Data.id))
+ .options(undefer(Data.data))
+ .first()
+ )
+ d1.data = "d2"
# the deferred loader has to clear out any state
# on the col, including that 'd2' here
diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py
index c1cc85261..e1c0ec77b 100644
--- a/test/orm/test_lazy_relations.py
+++ b/test/orm/test_lazy_relations.py
@@ -1302,18 +1302,22 @@ class O2MWOSideFixedTest(fixtures.MappedTest):
def _fixture(self, include_other):
city, person = self.tables.city, self.tables.person
- if include_other:
- city.insert().execute({"id": 1, "deleted": False})
-
- person.insert().execute(
- {"id": 1, "city_id": 1}, {"id": 2, "city_id": 1}
- )
+ with testing.db.begin() as conn:
+ if include_other:
+ conn.execute(city.insert(), {"id": 1, "deleted": False})
+
+ conn.execute(
+ person.insert(),
+ {"id": 1, "city_id": 1},
+ {"id": 2, "city_id": 1},
+ )
- city.insert().execute({"id": 2, "deleted": True})
+ conn.execute(city.insert(), {"id": 2, "deleted": True})
- person.insert().execute(
- {"id": 3, "city_id": 2}, {"id": 4, "city_id": 2}
- )
+ conn.execute(
+ person.insert(),
+ [{"id": 3, "city_id": 2}, {"id": 4, "city_id": 2}],
+ )
def test_lazyload_assert_expected_sql(self):
self._fixture(True)
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index fc6caa75d..edbb4b0cd 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -129,7 +129,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
)
assert_raises(sa.exc.ArgumentError, sa.orm.configure_mappers)
- def test_update_attr_keys(self):
+ def test_update_attr_keys(self, connection):
"""test that update()/insert() use the correct key when given
InstrumentedAttributes."""
@@ -137,21 +137,21 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
self.mapper(User, users, properties={"foobar": users.c.name})
- users.insert().values({User.foobar: "name1"}).execute()
+ connection.execute(users.insert().values({User.foobar: "name1"}))
eq_(
- sa.select(User.foobar)
- .where(User.foobar == "name1")
- .execute()
- .fetchall(),
+ connection.execute(
+ sa.select(User.foobar).where(User.foobar == "name1")
+ ).fetchall(),
[("name1",)],
)
- users.update().values({User.foobar: User.foobar + "foo"}).execute()
+ connection.execute(
+ users.update().values({User.foobar: User.foobar + "foo"})
+ )
eq_(
- sa.select(User.foobar)
- .where(User.foobar == "name1foo")
- .execute()
- .fetchall(),
+ connection.execute(
+ sa.select(User.foobar).where(User.foobar == "name1foo")
+ ).fetchall(),
[("name1foo",)],
)
diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py
index 87ec0d79d..d814b0cab 100644
--- a/test/orm/test_naturalpks.py
+++ b/test/orm/test_naturalpks.py
@@ -12,7 +12,6 @@ from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import TypeDecorator
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
@@ -23,6 +22,7 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import ne_
+from sqlalchemy.testing.fixtures import create_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from test.orm import _fixtures
@@ -141,7 +141,9 @@ class NaturalPKTest(fixtures.MappedTest):
sess.flush()
assert sess.query(User).get("jack") is u1
- users.update(values={User.username: "jack"}).execute(username="ed")
+ sess.execute(
+ users.update(values={User.username: "jack"}), dict(username="ed")
+ )
# expire/refresh works off of primary key. the PK is gone
# in this case so there's no way to look it up. criterion-
@@ -1089,7 +1091,7 @@ class NonPKCascadeTest(fixtures.MappedTest):
a1 = u1.addresses[0]
eq_(
- sa.select(addresses.c.username).execute().fetchall(),
+ sess.execute(sa.select(addresses.c.username)).fetchall(),
[("jack",), ("jack",)],
)
@@ -1099,7 +1101,7 @@ class NonPKCascadeTest(fixtures.MappedTest):
sess.flush()
assert u1.addresses[0].username == "ed"
eq_(
- sa.select(addresses.c.username).execute().fetchall(),
+ sess.execute(sa.select(addresses.c.username)).fetchall(),
[("ed",), ("ed",)],
)
@@ -1141,7 +1143,7 @@ class NonPKCascadeTest(fixtures.MappedTest):
eq_(a1.username, None)
eq_(
- sa.select(addresses.c.username).execute().fetchall(),
+ sess.execute(sa.select(addresses.c.username)).fetchall(),
[(None,), (None,)],
)
@@ -1454,7 +1456,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
eq_(a1.username, "ed")
eq_(a2.username, "ed")
eq_(
- sa.select(addresses.c.username).execute().fetchall(),
+ sess.execute(sa.select(addresses.c.username)).fetchall(),
[("ed",), ("ed",)],
)
@@ -1465,7 +1467,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
eq_(a1.username, "jack")
eq_(a2.username, "jack")
eq_(
- sa.select(addresses.c.username).execute().fetchall(),
+ sess.execute(sa.select(addresses.c.username)).fetchall(),
[("jack",), ("jack",)],
)
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 8cca45b27..9e528dc0d 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -806,7 +806,7 @@ class GetTest(QueryTest):
@testing.provide_metadata
@testing.requires.unicode_connections
- def test_unicode(self):
+ def test_unicode(self, connection):
"""test that Query.get properly sets up the type for the bind
parameter. using unicode would normally fail on postgresql, mysql and
oracle unless it is converted to an encoded string"""
@@ -818,19 +818,20 @@ class GetTest(QueryTest):
Column("id", Unicode(40), primary_key=True),
Column("data", Unicode(40)),
)
- metadata.create_all()
+ metadata.create_all(connection)
ustring = util.b("petit voix m\xe2\x80\x99a").decode("utf-8")
- table.insert().execute(id=ustring, data=ustring)
+ connection.execute(table.insert(), dict(id=ustring, data=ustring))
class LocalFoo(self.classes.Base):
pass
mapper(LocalFoo, table)
- eq_(
- create_session().query(LocalFoo).get(ustring),
- LocalFoo(id=ustring, data=ustring),
- )
+ with Session(connection) as sess:
+ eq_(
+ sess.get(LocalFoo, ustring),
+ LocalFoo(id=ustring, data=ustring),
+ )
def test_populate_existing(self):
User, Address = self.classes.User, self.classes.Address
diff --git a/test/orm/test_session.py b/test/orm/test_session.py
index 165008234..d2838e5bf 100644
--- a/test/orm/test_session.py
+++ b/test/orm/test_session.py
@@ -12,7 +12,6 @@ from sqlalchemy import testing
from sqlalchemy.orm import attributes
from sqlalchemy.orm import backref
from sqlalchemy.orm import close_all_sessions
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import make_transient
@@ -35,6 +34,7 @@ from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing import pickleable
+from sqlalchemy.testing.fixtures import create_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import gc_collect
@@ -48,33 +48,33 @@ class ExecutionTest(_fixtures.FixtureTest):
__backend__ = True
@testing.requires.sequences
- def test_sequence_execute(self):
+ def test_sequence_execute(self, connection):
seq = Sequence("some_sequence")
- seq.create(testing.db)
+ seq.create(connection)
try:
- sess = create_session(bind=testing.db)
- eq_(sess.execute(seq), testing.db.dialect.default_sequence_base)
+ sess = Session(connection)
+ eq_(sess.execute(seq), connection.dialect.default_sequence_base)
finally:
- seq.drop(testing.db)
+ seq.drop(connection)
- def test_textual_execute(self):
+ def test_textual_execute(self, connection):
"""test that Session.execute() converts to text()"""
users = self.tables.users
- sess = create_session(bind=self.metadata.bind)
- users.insert().execute(id=7, name="jack")
+ with Session(bind=connection) as sess:
+ sess.execute(users.insert(), dict(id=7, name="jack"))
- # use :bindparam style
- eq_(
- sess.execute(
- "select * from users where id=:id", {"id": 7}
- ).fetchall(),
- [(7, "jack")],
- )
+ # use :bindparam style
+ eq_(
+ sess.execute(
+ "select * from users where id=:id", {"id": 7}
+ ).fetchall(),
+ [(7, "jack")],
+ )
- # use :bindparam style
- eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7)
+ # use :bindparam style
+ eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7)
def test_parameter_execute(self):
users = self.tables.users
@@ -104,7 +104,7 @@ class TransScopingTest(_fixtures.FixtureTest):
c.exec_driver_sql("select * from users")
mapper(User, users)
- s = create_session(bind=c)
+ s = Session(bind=c)
s.add(User(name="first"))
s.flush()
c.exec_driver_sql("select * from users")
@@ -118,7 +118,7 @@ class TransScopingTest(_fixtures.FixtureTest):
c.exec_driver_sql("select * from users")
mapper(User, users)
- s = create_session(bind=c)
+ s = Session(bind=c)
s.add(User(name="first"))
s.flush()
c.exec_driver_sql("select * from users")
@@ -189,7 +189,7 @@ class TransScopingTest(_fixtures.FixtureTest):
conn1 = testing.db.connect()
conn2 = testing.db.connect()
- sess = create_session(autocommit=False, bind=conn1)
+ sess = Session(autocommit=False, bind=conn1)
u = User(name="x")
sess.add(u)
sess.flush()
@@ -415,7 +415,7 @@ class SessionStateTest(_fixtures.FixtureTest):
conn1 = bind.connect()
conn2 = bind.connect()
- sess = create_session(bind=conn1, autocommit=False, autoflush=True)
+ sess = Session(bind=conn1, autocommit=False, autoflush=True)
u = User()
u.name = "ed"
sess.add(u)
@@ -600,7 +600,7 @@ class SessionStateTest(_fixtures.FixtureTest):
mapper(User, users)
conn1 = testing.db.connect()
- sess = create_session(bind=conn1, autocommit=False, autoflush=True)
+ sess = Session(bind=conn1, autocommit=False, autoflush=True)
u = User()
u.name = "ed"
sess.add(u)
@@ -620,7 +620,7 @@ class SessionStateTest(_fixtures.FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- session = create_session(autocommit=True)
+ session = Session(testing.db, autocommit=True)
session.add(User(name="ed"))
@@ -629,7 +629,7 @@ class SessionStateTest(_fixtures.FixtureTest):
session.commit()
def test_active_flag_autocommit(self):
- sess = create_session(bind=config.db, autocommit=True)
+ sess = Session(bind=config.db, autocommit=True)
assert not sess.is_active
sess.begin()
assert sess.is_active
@@ -637,7 +637,7 @@ class SessionStateTest(_fixtures.FixtureTest):
assert not sess.is_active
def test_active_flag_autobegin(self):
- sess = create_session(bind=config.db, autocommit=False)
+ sess = Session(bind=config.db, autocommit=False)
assert sess.is_active
assert not sess.in_transaction()
sess.begin()
@@ -646,7 +646,7 @@ class SessionStateTest(_fixtures.FixtureTest):
assert sess.is_active
def test_active_flag_autobegin_future(self):
- sess = create_session(bind=config.db, future=True)
+ sess = Session(bind=config.db, future=True)
assert sess.is_active
assert not sess.in_transaction()
sess.begin()
@@ -655,7 +655,7 @@ class SessionStateTest(_fixtures.FixtureTest):
assert sess.is_active
def test_active_flag_partial_rollback(self):
- sess = create_session(bind=config.db, autocommit=False)
+ sess = Session(bind=config.db, autocommit=False)
assert sess.is_active
assert not sess.in_transaction()
sess.begin()
@@ -693,7 +693,7 @@ class SessionStateTest(_fixtures.FixtureTest):
)
s.add(user)
- s.flush()
+ s.commit()
user = s.query(User).one()
s.expunge(user)
assert user not in s
@@ -703,8 +703,7 @@ class SessionStateTest(_fixtures.FixtureTest):
s.add(user)
assert user in s
assert user in s.dirty
- s.flush()
- s.expunge_all()
+ s.commit()
assert s.query(User).count() == 1
user = s.query(User).one()
assert user.name == "fred"
@@ -766,8 +765,9 @@ class SessionStateTest(_fixtures.FixtureTest):
users, User = self.tables.users, self.classes.User
mapper(User, users)
- for s in (create_session(), create_session()):
- users.delete().execute()
+
+ with create_session() as s:
+ s.execute(users.delete())
u1 = User(name="ed")
s.add(u1)
s.flush()
@@ -1774,7 +1774,8 @@ class DisposedStates(fixtures.MappedTest):
def _test_session(self, **kwargs):
T = self.classes.T
- sess = create_session(**kwargs)
+
+ sess = Session(config.db, **kwargs)
data = o1, o2, o3, o4, o5 = [
T("t1"),
@@ -1786,7 +1787,7 @@ class DisposedStates(fixtures.MappedTest):
sess.add_all(data)
- sess.flush()
+ sess.commit()
o1.data = "t1modified"
o5.data = "t5modified"
@@ -1925,7 +1926,7 @@ class SessionInterface(fixtures.TestBase):
def raises_(method, *args, **kw):
watchdog.add(method)
- callable_ = getattr(create_session(), method)
+ callable_ = getattr(Session(), method)
if is_class:
assert_raises(
sa.orm.exc.UnmappedClassError, callable_, *args, **kw
diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py
index e8f6c5c40..248f334cf 100644
--- a/test/orm/test_transaction.py
+++ b/test/orm/test_transaction.py
@@ -1951,9 +1951,7 @@ class AccountingFlagsTest(_LocalFixture):
sess.add(u1)
sess.commit()
- testing.db.execute(
- users.update(users.c.name == "ed").values(name="edward")
- )
+ sess.execute(users.update(users.c.name == "ed").values(name="edward"))
assert u1.name == "ed"
sess.expire_all()
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index ed320db10..31386b07f 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -778,7 +778,8 @@ class SingleCycleTest(UOWTest):
# mysql can't handle delete from nodes
# since it doesn't deal with the FKs correctly,
# so wipe out the parent_id first
- testing.db.execute(self.tables.nodes.update().values(parent_id=None))
+ with testing.db.begin() as conn:
+ conn.execute(self.tables.nodes.update().values(parent_id=None))
super(SingleCycleTest, self).teardown()
def test_one_to_many_save(self):
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index 4a6ebd0c8..2a2e70bc3 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -1012,9 +1012,7 @@ class PKIncrementTest(fixtures.TablesTest):
Column("str1", String(20)),
)
- # TODO: add coverage for increment on a secondary column in a key
- @testing.fails_on("firebird", "Data type unknown")
- def _test_autoincrement(self, connection):
+ def test_autoincrement(self, connection):
aitable = self.tables.aitable
ids = set()
@@ -1064,14 +1062,6 @@ class PKIncrementTest(fixtures.TablesTest):
],
)
- def test_autoincrement_autocommit(self):
- with testing.db.connect() as conn:
- self._test_autoincrement(conn)
-
- def test_autoincrement_transaction(self):
- with testing.db.begin() as conn:
- self._test_autoincrement(conn)
-
class EmptyInsertTest(fixtures.TestBase):
__backend__ = True
@@ -1267,7 +1257,7 @@ class SpecialTypePKTest(fixtures.TestBase):
implicit_returning=implicit_returning,
)
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
t.create(conn)
r = conn.execute(t.insert().values(data=5))
diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py
index 934022560..6f7b3f8f5 100644
--- a/test/sql/test_delete.py
+++ b/test/sql/test_delete.py
@@ -308,32 +308,31 @@ class DeleteFromRoundTripTest(fixtures.TablesTest):
)
@testing.requires.delete_from
- def test_exec_two_table(self):
+ def test_exec_two_table(self, connection):
users, addresses = self.tables.users, self.tables.addresses
dingalings = self.tables.dingalings
- with testing.db.connect() as conn:
- conn.execute(dingalings.delete()) # fk violation otherwise
+ connection.execute(dingalings.delete()) # fk violation otherwise
- conn.execute(
- addresses.delete()
- .where(users.c.id == addresses.c.user_id)
- .where(users.c.name == "ed")
- )
+ connection.execute(
+ addresses.delete()
+ .where(users.c.id == addresses.c.user_id)
+ .where(users.c.name == "ed")
+ )
- expected = [
- (1, 7, "x", "jack@bean.com"),
- (5, 9, "x", "fred@fred.com"),
- ]
- self._assert_table(addresses, expected)
+ expected = [
+ (1, 7, "x", "jack@bean.com"),
+ (5, 9, "x", "fred@fred.com"),
+ ]
+ self._assert_table(connection, addresses, expected)
@testing.requires.delete_from
- def test_exec_three_table(self):
+ def test_exec_three_table(self, connection):
users = self.tables.users
addresses = self.tables.addresses
dingalings = self.tables.dingalings
- testing.db.execute(
+ connection.execute(
dingalings.delete()
.where(users.c.id == addresses.c.user_id)
.where(users.c.name == "ed")
@@ -341,34 +340,33 @@ class DeleteFromRoundTripTest(fixtures.TablesTest):
)
expected = [(2, 5, "ding 2/5")]
- self._assert_table(dingalings, expected)
+ self._assert_table(connection, dingalings, expected)
@testing.requires.delete_from
- def test_exec_two_table_plus_alias(self):
+ def test_exec_two_table_plus_alias(self, connection):
users, addresses = self.tables.users, self.tables.addresses
dingalings = self.tables.dingalings
- with testing.db.connect() as conn:
- conn.execute(dingalings.delete()) # fk violation otherwise
- a1 = addresses.alias()
- conn.execute(
- addresses.delete()
- .where(users.c.id == addresses.c.user_id)
- .where(users.c.name == "ed")
- .where(a1.c.id == addresses.c.id)
- )
+ connection.execute(dingalings.delete()) # fk violation otherwise
+ a1 = addresses.alias()
+ connection.execute(
+ addresses.delete()
+ .where(users.c.id == addresses.c.user_id)
+ .where(users.c.name == "ed")
+ .where(a1.c.id == addresses.c.id)
+ )
expected = [(1, 7, "x", "jack@bean.com"), (5, 9, "x", "fred@fred.com")]
- self._assert_table(addresses, expected)
+ self._assert_table(connection, addresses, expected)
@testing.requires.delete_from
- def test_exec_alias_plus_table(self):
+ def test_exec_alias_plus_table(self, connection):
users, addresses = self.tables.users, self.tables.addresses
dingalings = self.tables.dingalings
d1 = dingalings.alias()
- testing.db.execute(
+ connection.execute(
delete(d1)
.where(users.c.id == addresses.c.user_id)
.where(users.c.name == "ed")
@@ -376,8 +374,8 @@ class DeleteFromRoundTripTest(fixtures.TablesTest):
)
expected = [(2, 5, "ding 2/5")]
- self._assert_table(dingalings, expected)
+ self._assert_table(connection, dingalings, expected)
- def _assert_table(self, table, expected):
+ def _assert_table(self, connection, table, expected):
stmt = table.select().order_by(table.c.id)
- eq_(testing.db.execute(stmt).fetchall(), expected)
+ eq_(connection.execute(stmt).fetchall(), expected)
diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py
index c0d2e87e8..e082cf55d 100644
--- a/test/sql/test_deprecations.py
+++ b/test/sql/test_deprecations.py
@@ -23,6 +23,7 @@ from sqlalchemy import MetaData
from sqlalchemy import null
from sqlalchemy import or_
from sqlalchemy import select
+from sqlalchemy import Sequence
from sqlalchemy import sql
from sqlalchemy import String
from sqlalchemy import table
@@ -1271,6 +1272,165 @@ class KeyTargetingTest(fixtures.TablesTest):
in_(stmt.c.keyed2_b, row)
+class PKIncrementTest(fixtures.TablesTest):
+ run_define_tables = "each"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "aitable",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("ai_id_seq", optional=True),
+ primary_key=True,
+ ),
+ Column("int1", Integer),
+ Column("str1", String(20)),
+ )
+
+ def _test_autoincrement(self, connection):
+ aitable = self.tables.aitable
+
+ ids = set()
+ rs = connection.execute(aitable.insert(), int1=1)
+ last = rs.inserted_primary_key[0]
+ self.assert_(last)
+ self.assert_(last not in ids)
+ ids.add(last)
+
+ rs = connection.execute(aitable.insert(), str1="row 2")
+ last = rs.inserted_primary_key[0]
+ self.assert_(last)
+ self.assert_(last not in ids)
+ ids.add(last)
+
+ rs = connection.execute(aitable.insert(), int1=3, str1="row 3")
+ last = rs.inserted_primary_key[0]
+ self.assert_(last)
+ self.assert_(last not in ids)
+ ids.add(last)
+
+ rs = connection.execute(
+ aitable.insert().values({"int1": func.length("four")})
+ )
+ last = rs.inserted_primary_key[0]
+ self.assert_(last)
+ self.assert_(last not in ids)
+ ids.add(last)
+
+ eq_(
+ ids,
+ set(
+ range(
+ testing.db.dialect.default_sequence_base,
+ testing.db.dialect.default_sequence_base + 4,
+ )
+ ),
+ )
+
+ eq_(
+ list(connection.execute(aitable.select().order_by(aitable.c.id))),
+ [
+ (testing.db.dialect.default_sequence_base, 1, None),
+ (testing.db.dialect.default_sequence_base + 1, None, "row 2"),
+ (testing.db.dialect.default_sequence_base + 2, 3, "row 3"),
+ (testing.db.dialect.default_sequence_base + 3, 4, None),
+ ],
+ )
+
+ def test_autoincrement_autocommit(self):
+ with testing.db.connect() as conn:
+ with testing.expect_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit, "
+ ):
+ self._test_autoincrement(conn)
+
+
+class ConnectionlessCursorResultTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "users",
+ metadata,
+ Column(
+ "user_id", INT, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("user_name", VARCHAR(20)),
+ test_needs_acid=True,
+ )
+
+ def test_connectionless_autoclose_rows_exhausted(self):
+ users = self.tables.users
+ with testing.db.begin() as conn:
+ conn.execute(users.insert(), dict(user_id=1, user_name="john"))
+
+ with testing.expect_deprecated_20(
+ r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+ ):
+ result = testing.db.execute(text("select * from users"))
+ connection = result.connection
+ assert not connection.closed
+ eq_(result.fetchone(), (1, "john"))
+ assert not connection.closed
+ eq_(result.fetchone(), None)
+ assert connection.closed
+
+ @testing.requires.returning
+ def test_connectionless_autoclose_crud_rows_exhausted(self):
+ users = self.tables.users
+ stmt = (
+ users.insert()
+ .values(user_id=1, user_name="john")
+ .returning(users.c.user_id)
+ )
+ with testing.expect_deprecated_20(
+ r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+ ):
+ result = testing.db.execute(stmt)
+ connection = result.connection
+ assert not connection.closed
+ eq_(result.fetchone(), (1,))
+ assert not connection.closed
+ eq_(result.fetchone(), None)
+ assert connection.closed
+
+ def test_connectionless_autoclose_no_rows(self):
+ with testing.expect_deprecated_20(
+ r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+ ):
+ result = testing.db.execute(text("select * from users"))
+ connection = result.connection
+ assert not connection.closed
+ eq_(result.fetchone(), None)
+ assert connection.closed
+
+ @testing.requires.updateable_autoincrement_pks
+ def test_connectionless_autoclose_no_metadata(self):
+ with testing.expect_deprecated_20(
+ r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+ ):
+ result = testing.db.execute(text("update users set user_id=5"))
+ connection = result.connection
+ assert connection.closed
+
+ assert_raises_message(
+ exc.ResourceClosedError,
+ "This result object does not return rows.",
+ result.fetchone,
+ )
+ assert_raises_message(
+ exc.ResourceClosedError,
+ "This result object does not return rows.",
+ result.keys,
+ )
+
+
class CursorResultTest(fixtures.TablesTest):
__backend__ = True
@@ -1436,7 +1596,7 @@ class CursorResultTest(fixtures.TablesTest):
def test_pickled_rows(self):
users = self.tables.users
addresses = self.tables.addresses
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(users.delete())
conn.execute(
users.insert(),
@@ -2319,3 +2479,93 @@ class LegacyOperatorTest(AssertsCompiledSQL, fixtures.TestBase):
_op_modern = getattr(operators.ColumnOperators, _modern)
_op_legacy = getattr(operators.ColumnOperators, _legacy)
assert _op_modern == _op_legacy
+
+
+class LegacySequenceExecTest(fixtures.TestBase):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ @classmethod
+ def setup_class(cls):
+ cls.seq = Sequence("my_sequence")
+ cls.seq.create(testing.db)
+
+ @classmethod
+ def teardown_class(cls):
+ cls.seq.drop(testing.db)
+
+ def _assert_seq_result(self, ret):
+ """asserts return of next_value is an int"""
+
+ assert isinstance(ret, util.int_types)
+ assert ret >= testing.db.dialect.default_sequence_base
+
+ def test_implicit_connectionless(self):
+ with testing.expect_deprecated_20(
+ r"The MetaData.bind argument is deprecated"
+ ):
+ s = Sequence("my_sequence", metadata=MetaData(testing.db))
+
+ with testing.expect_deprecated_20(
+ r"The DefaultGenerator.execute\(\) method is considered legacy "
+ "as of the 1.x",
+ ):
+ self._assert_seq_result(s.execute())
+
+ def test_explicit(self, connection):
+ s = Sequence("my_sequence")
+ with testing.expect_deprecated_20(
+ r"The DefaultGenerator.execute\(\) method is considered legacy"
+ ):
+ self._assert_seq_result(s.execute(connection))
+
+ def test_explicit_optional(self):
+ """test dialect executes a Sequence, returns nextval, whether
+ or not "optional" is set"""
+
+ s = Sequence("my_sequence", optional=True)
+ with testing.expect_deprecated_20(
+ r"The DefaultGenerator.execute\(\) method is considered legacy"
+ ):
+ self._assert_seq_result(s.execute(testing.db))
+
+ def test_func_implicit_connectionless_execute(self):
+ """test func.next_value().execute()/.scalar() works
+ with connectionless execution."""
+
+ with testing.expect_deprecated_20(
+ r"The MetaData.bind argument is deprecated"
+ ):
+ s = Sequence("my_sequence", metadata=MetaData(testing.db))
+ with testing.expect_deprecated_20(
+ r"The Executable.execute\(\) method is considered legacy"
+ ):
+ self._assert_seq_result(s.next_value().execute().scalar())
+
+ def test_func_explicit(self):
+ s = Sequence("my_sequence")
+ with testing.expect_deprecated_20(
+ r"The Engine.scalar\(\) method is considered legacy"
+ ):
+ self._assert_seq_result(testing.db.scalar(s.next_value()))
+
+ def test_func_implicit_connectionless_scalar(self):
+ """test func.next_value().execute()/.scalar() works. """
+
+ with testing.expect_deprecated_20(
+ r"The MetaData.bind argument is deprecated"
+ ):
+ s = Sequence("my_sequence", metadata=MetaData(testing.db))
+ with testing.expect_deprecated_20(
+ r"The Executable.execute\(\) method is considered legacy"
+ ):
+ self._assert_seq_result(s.next_value().scalar())
+
+ def test_func_embedded_select(self):
+ """test can use next_value() in select column expr"""
+
+ s = Sequence("my_sequence")
+ with testing.expect_deprecated_20(
+ r"The Engine.scalar\(\) method is considered legacy"
+ ):
+ self._assert_seq_result(testing.db.scalar(select(s.next_value())))
diff --git a/test/sql/test_query.py b/test/sql/test_query.py
index 7d05462ab..6d26f7975 100644
--- a/test/sql/test_query.py
+++ b/test/sql/test_query.py
@@ -84,7 +84,7 @@ class QueryTest(fixtures.TestBase):
@engines.close_first
def teardown(self):
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(addresses.delete())
conn.execute(users.delete())
conn.execute(users2.delete())
@@ -878,21 +878,22 @@ class RequiredBindTest(fixtures.TablesTest):
)
def _assert_raises(self, stmt, params):
- assert_raises_message(
- exc.StatementError,
- "A value is required for bind parameter 'x'",
- testing.db.execute,
- stmt,
- **params
- )
+ with testing.db.connect() as conn:
+ assert_raises_message(
+ exc.StatementError,
+ "A value is required for bind parameter 'x'",
+ conn.execute,
+ stmt,
+ **params
+ )
- assert_raises_message(
- exc.StatementError,
- "A value is required for bind parameter 'x'",
- testing.db.execute,
- stmt,
- params,
- )
+ assert_raises_message(
+ exc.StatementError,
+ "A value is required for bind parameter 'x'",
+ conn.execute,
+ stmt,
+ params,
+ )
def test_insert(self):
stmt = self.tables.foo.insert().values(
@@ -953,7 +954,7 @@ class LimitTest(fixtures.TestBase):
)
metadata.create_all()
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(users.insert(), user_id=1, user_name="john")
conn.execute(
addresses.insert(), address_id=1, user_id=1, address="addr1"
@@ -1105,7 +1106,7 @@ class CompoundTest(fixtures.TestBase):
)
metadata.create_all()
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(
t1.insert(),
[
@@ -1470,7 +1471,7 @@ class JoinTest(fixtures.TestBase):
metadata.drop_all()
metadata.create_all()
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
# t1.10 -> t2.20 -> t3.30
# t1.11 -> t2.21
# t1.12
@@ -1823,7 +1824,7 @@ class OperatorTest(fixtures.TestBase):
)
metadata.create_all()
- with testing.db.connect() as conn:
+ with testing.db.begin() as conn:
conn.execute(
flds.insert(),
[dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")],
diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py
index 1c023e7b1..a78d6c16b 100644
--- a/test/sql/test_quote.py
+++ b/test/sql/test_quote.py
@@ -25,19 +25,12 @@ from sqlalchemy.testing import is_
from sqlalchemy.testing.util import picklers
-class QuoteExecTest(fixtures.TestBase):
+class QuoteExecTest(fixtures.TablesTest):
__backend__ = True
@classmethod
- def setup_class(cls):
- # TODO: figure out which databases/which identifiers allow special
- # characters to be used, such as: spaces, quote characters,
- # punctuation characters, set up tests for those as well.
-
- global table1, table2
- metadata = MetaData(testing.db)
-
- table1 = Table(
+ def define_tables(cls, metadata):
+ Table(
"WorstCase1",
metadata,
Column("lowercase", Integer, primary_key=True),
@@ -45,7 +38,7 @@ class QuoteExecTest(fixtures.TestBase):
Column("MixedCase", Integer),
Column("ASC", Integer, key="a123"),
)
- table2 = Table(
+ Table(
"WorstCase2",
metadata,
Column("desc", Integer, primary_key=True, key="d123"),
@@ -53,18 +46,6 @@ class QuoteExecTest(fixtures.TestBase):
Column("MixedCase", Integer),
)
- table1.create()
- table2.create()
-
- def teardown(self):
- table1.delete().execute()
- table2.delete().execute()
-
- @classmethod
- def teardown_class(cls):
- table1.drop()
- table2.drop()
-
def test_reflect(self):
meta2 = MetaData()
t2 = Table("WorstCase1", meta2, autoload_with=testing.db, quote=True)
@@ -88,25 +69,22 @@ class QuoteExecTest(fixtures.TestBase):
assert "MixedCase" in t2.c
@testing.provide_metadata
- def test_has_table_case_sensitive(self):
+ def test_has_table_case_sensitive(self, connection):
preparer = testing.db.dialect.identifier_preparer
- with testing.db.connect() as conn:
- if conn.dialect.requires_name_normalize:
- conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)")
- else:
- conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)")
- conn.exec_driver_sql(
- "CREATE TABLE %s (id INTEGER)"
- % preparer.quote_identifier("tab2")
- )
- conn.exec_driver_sql(
- "CREATE TABLE %s (id INTEGER)"
- % preparer.quote_identifier("TAB3")
- )
- conn.exec_driver_sql(
- "CREATE TABLE %s (id INTEGER)"
- % preparer.quote_identifier("TAB4")
- )
+ conn = connection
+ if conn.dialect.requires_name_normalize:
+ conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)")
+ else:
+ conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)")
+ conn.exec_driver_sql(
+ "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("tab2")
+ )
+ conn.exec_driver_sql(
+ "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB3")
+ )
+ conn.exec_driver_sql(
+ "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB4")
+ )
t1 = Table(
"tab1", self.metadata, Column("id", Integer, primary_key=True)
@@ -127,7 +105,7 @@ class QuoteExecTest(fixtures.TestBase):
quote=True,
)
- insp = inspect(testing.db)
+ insp = inspect(connection)
assert insp.has_table(t1.name)
eq_([c["name"] for c in insp.get_columns(t1.name)], ["id"])
@@ -140,16 +118,24 @@ class QuoteExecTest(fixtures.TestBase):
assert insp.has_table(t4.name)
eq_([c["name"] for c in insp.get_columns(t4.name)], ["id"])
- def test_basic(self):
- table1.insert().execute(
- {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
- {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
- {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
+ def test_basic(self, connection):
+ table1, table2 = self.tables("WorstCase1", "WorstCase2")
+
+ connection.execute(
+ table1.insert(),
+ [
+ {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+ {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+ {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
+ ],
)
- table2.insert().execute(
- {"d123": 1, "u123": 2, "MixedCase": 3},
- {"d123": 2, "u123": 2, "MixedCase": 3},
- {"d123": 4, "u123": 3, "MixedCase": 2},
+ connection.execute(
+ table2.insert(),
+ [
+ {"d123": 1, "u123": 2, "MixedCase": 3},
+ {"d123": 2, "u123": 2, "MixedCase": 3},
+ {"d123": 4, "u123": 3, "MixedCase": 2},
+ ],
)
columns = [
@@ -158,23 +144,30 @@ class QuoteExecTest(fixtures.TestBase):
table1.c.MixedCase,
table1.c.a123,
]
- result = select(columns).execute().fetchall()
+ result = connection.execute(select(columns)).all()
assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)]
columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase]
- result = select(columns).execute().fetchall()
+ result = connection.execute(select(columns)).all()
assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)]
- def test_use_labels(self):
- table1.insert().execute(
- {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
- {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
- {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
- )
- table2.insert().execute(
- {"d123": 1, "u123": 2, "MixedCase": 3},
- {"d123": 2, "u123": 2, "MixedCase": 3},
- {"d123": 4, "u123": 3, "MixedCase": 2},
+ def test_use_labels(self, connection):
+ table1, table2 = self.tables("WorstCase1", "WorstCase2")
+ connection.execute(
+ table1.insert(),
+ [
+ {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+ {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+ {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
+ ],
+ )
+ connection.execute(
+ table2.insert(),
+ [
+ {"d123": 1, "u123": 2, "MixedCase": 3},
+ {"d123": 2, "u123": 2, "MixedCase": 3},
+ {"d123": 4, "u123": 3, "MixedCase": 2},
+ ],
)
columns = [
@@ -183,11 +176,11 @@ class QuoteExecTest(fixtures.TestBase):
table1.c.MixedCase,
table1.c.a123,
]
- result = select(columns, use_labels=True).execute().fetchall()
+ result = connection.execute(select(columns).apply_labels()).fetchall()
assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)]
columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase]
- result = select(columns, use_labels=True).execute().fetchall()
+ result = connection.execute(select(columns).apply_labels()).all()
assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)]
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 9ef533be3..db0e0d4c8 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -615,63 +615,6 @@ class CursorResultTest(fixtures.TablesTest):
result.fetchone,
)
- def test_connectionless_autoclose_rows_exhausted(self):
- # TODO: deprecate for 2.0
- users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(users.insert(), dict(user_id=1, user_name="john"))
-
- result = testing.db.execute(text("select * from users"))
- connection = result.connection
- assert not connection.closed
- eq_(result.fetchone(), (1, "john"))
- assert not connection.closed
- eq_(result.fetchone(), None)
- assert connection.closed
-
- @testing.requires.returning
- def test_connectionless_autoclose_crud_rows_exhausted(self):
- # TODO: deprecate for 2.0
- users = self.tables.users
- stmt = (
- users.insert()
- .values(user_id=1, user_name="john")
- .returning(users.c.user_id)
- )
- result = testing.db.execute(stmt)
- connection = result.connection
- assert not connection.closed
- eq_(result.fetchone(), (1,))
- assert not connection.closed
- eq_(result.fetchone(), None)
- assert connection.closed
-
- def test_connectionless_autoclose_no_rows(self):
- # TODO: deprecate for 2.0
- result = testing.db.execute(text("select * from users"))
- connection = result.connection
- assert not connection.closed
- eq_(result.fetchone(), None)
- assert connection.closed
-
- @testing.requires.updateable_autoincrement_pks
- def test_connectionless_autoclose_no_metadata(self):
- # TODO: deprecate for 2.0
- result = testing.db.execute(text("update users set user_id=5"))
- connection = result.connection
- assert connection.closed
-
- assert_raises_message(
- exc.ResourceClosedError,
- "This result object does not return rows.",
- result.fetchone,
- )
- assert_raises_message(
- exc.ResourceClosedError,
- "This result object does not return rows.",
- result.keys,
- )
-
def test_row_case_sensitive(self, connection):
row = connection.execute(
select(
@@ -1285,7 +1228,7 @@ class CursorResultTest(fixtures.TablesTest):
with patch.object(
engine.dialect.execution_ctx_cls, "rowcount"
) as mock_rowcount:
- with engine.connect() as conn:
+ with engine.begin() as conn:
mock_rowcount.__get__ = Mock()
conn.execute(
t.insert(), {"data": "d1"}, {"data": "d2"}, {"data": "d3"}
@@ -1362,20 +1305,14 @@ class CursorResultTest(fixtures.TablesTest):
eq_(row[1:0:-1], ("Uno",))
@testing.requires.cextensions
- def test_row_c_sequence_check(self):
- # TODO: modernize for 2.0
- metadata = MetaData()
- metadata.bind = "sqlite://"
- users = Table(
- "users",
- metadata,
- Column("id", Integer, primary_key=True),
- Column("name", String(40)),
- )
- users.create()
+ @testing.provide_metadata
+ def test_row_c_sequence_check(self, connection):
+ users = self.tables.users2
- users.insert().execute(name="Test")
- row = users.select().execute().fetchone()
+ connection.execute(users.insert(), dict(user_id=1, user_name="Test"))
+ row = connection.execute(
+ users.select().where(users.c.user_id == 1)
+ ).fetchone()
s = util.StringIO()
writer = csv.writer(s)
@@ -2340,7 +2277,7 @@ class AlternateCursorResultTest(fixtures.TablesTest):
@testing.fixture
def row_growth_fixture(self):
with self._proxy_fixture(_cursor.BufferedRowCursorFetchStrategy):
- with self.engine.connect() as conn:
+ with self.engine.begin() as conn:
conn.execute(
self.table.insert(),
[{"x": i, "y": "t_%d" % i} for i in range(15, 3000)],
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index 065205c45..9f2afd7b7 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -23,9 +23,6 @@ from sqlalchemy.testing.schema import Table
from sqlalchemy.types import TypeDecorator
-table = GoofyType = seq = None
-
-
class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "postgresql"
@@ -92,14 +89,14 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
)
-class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
+class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
__requires__ = ("returning",)
__backend__ = True
- def setup(self):
- meta = MetaData(testing.db)
- global table, GoofyType
+ run_create_tables = "each"
+ @classmethod
+ def define_tables(cls, metadata):
class GoofyType(TypeDecorator):
impl = String
@@ -113,9 +110,11 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
return None
return value + "BAR"
- table = Table(
+ cls.GoofyType = GoofyType
+
+ Table(
"tables",
- meta,
+ metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
@@ -123,14 +122,9 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
Column("full", Boolean),
Column("goofy", GoofyType(50)),
)
- with testing.db.connect() as conn:
- table.create(conn, checkfirst=True)
-
- def teardown(self):
- with testing.db.connect() as conn:
- table.drop(conn)
def test_column_targeting(self, connection):
+ table = self.tables.tables
result = connection.execute(
table.insert().returning(table.c.id, table.c.full),
{"persons": 1, "full": False},
@@ -155,6 +149,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
@testing.fails_on("firebird", "fb can't handle returning x AS y")
def test_labeling(self, connection):
+ table = self.tables.tables
result = connection.execute(
table.insert()
.values(persons=6)
@@ -167,6 +162,8 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
"firebird", "fb/kintersbasdb can't handle the bind params"
)
def test_anon_expressions(self, connection):
+ table = self.tables.tables
+ GoofyType = self.GoofyType
result = connection.execute(
table.insert()
.values(goofy="someOTHERgoofy")
@@ -182,6 +179,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
eq_(row[0], 30)
def test_update_returning(self, connection):
+ table = self.tables.tables
connection.execute(
table.insert(),
[{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -201,6 +199,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
@testing.requires.full_returning
def test_update_full_returning(self, connection):
+ table = self.tables.tables
connection.execute(
table.insert(),
[{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -215,6 +214,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
@testing.requires.full_returning
def test_delete_full_returning(self, connection):
+ table = self.tables.tables
connection.execute(
table.insert(),
[{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -226,6 +226,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
eq_(result.fetchall(), [(1, False), (2, False)])
def test_insert_returning(self, connection):
+ table = self.tables.tables
result = connection.execute(
table.insert().returning(table.c.id), {"persons": 1, "full": False}
)
@@ -234,6 +235,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
@testing.requires.multivalues_inserts
def test_multirow_returning(self, connection):
+ table = self.tables.tables
ins = (
table.insert()
.returning(table.c.id, table.c.persons)
@@ -249,6 +251,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
def test_no_ipk_on_returning(self, connection):
+ table = self.tables.tables
result = connection.execute(
table.insert().returning(table.c.id), {"persons": 1, "full": False}
)
@@ -274,6 +277,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
eq_([dict(row._mapping) for row in result4], [{"persons": 10}])
def test_delete_returning(self, connection):
+ table = self.tables.tables
connection.execute(
table.insert(),
[{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -319,17 +323,16 @@ class CompositeStatementTest(fixtures.TestBase):
eq_(result.scalar(), 5)
-class SequenceReturningTest(fixtures.TestBase):
+class SequenceReturningTest(fixtures.TablesTest):
__requires__ = "returning", "sequences"
__backend__ = True
- def setup(self):
- meta = MetaData(testing.db)
- global table, seq
+ @classmethod
+ def define_tables(cls, metadata):
seq = Sequence("tid_seq")
- table = Table(
+ Table(
"tables",
- meta,
+ metadata,
Column(
"id",
Integer,
@@ -338,38 +341,32 @@ class SequenceReturningTest(fixtures.TestBase):
),
Column("data", String(50)),
)
- with testing.db.connect() as conn:
- table.create(conn, checkfirst=True)
-
- def teardown(self):
- with testing.db.connect() as conn:
- table.drop(conn)
+ cls.sequences.tid_seq = seq
def test_insert(self, connection):
+ table = self.tables.tables
r = connection.execute(
table.insert().values(data="hi").returning(table.c.id)
)
eq_(r.first(), tuple([testing.db.dialect.default_sequence_base]))
eq_(
- connection.execute(seq),
+ connection.execute(self.sequences.tid_seq),
testing.db.dialect.default_sequence_base + 1,
)
-class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
+class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults):
"""test returning() works with columns that define 'key'."""
__requires__ = ("returning",)
__backend__ = True
- def setup(self):
- meta = MetaData(testing.db)
- global table
-
- table = Table(
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
"tables",
- meta,
+ metadata,
Column(
"id",
Integer,
@@ -379,16 +376,11 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
),
Column("data", String(20)),
)
- with testing.db.connect() as conn:
- table.create(conn, checkfirst=True)
-
- def teardown(self):
- with testing.db.connect() as conn:
- table.drop(conn)
@testing.exclude("firebird", "<", (2, 0), "2.0+ feature")
@testing.exclude("postgresql", "<", (8, 2), "8.2+ feature")
def test_insert(self, connection):
+ table = self.tables.tables
result = connection.execute(
table.insert().returning(table.c.foo_id), data="somedata"
)
diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py
index e609a8a91..1809e0cca 100644
--- a/test/sql/test_sequences.py
+++ b/test/sql/test_sequences.py
@@ -95,64 +95,6 @@ class SequenceDDLTest(fixtures.TestBase, testing.AssertsCompiledSQL):
)
-class LegacySequenceExecTest(fixtures.TestBase):
- __requires__ = ("sequences",)
- __backend__ = True
-
- @classmethod
- def setup_class(cls):
- cls.seq = Sequence("my_sequence")
- cls.seq.create(testing.db)
-
- @classmethod
- def teardown_class(cls):
- cls.seq.drop(testing.db)
-
- def _assert_seq_result(self, ret):
- """asserts return of next_value is an int"""
-
- assert isinstance(ret, util.int_types)
- assert ret >= testing.db.dialect.default_sequence_base
-
- def test_implicit_connectionless(self):
- s = Sequence("my_sequence", metadata=MetaData(testing.db))
- self._assert_seq_result(s.execute())
-
- def test_explicit(self, connection):
- s = Sequence("my_sequence")
- self._assert_seq_result(s.execute(connection))
-
- def test_explicit_optional(self):
- """test dialect executes a Sequence, returns nextval, whether
- or not "optional" is set"""
-
- s = Sequence("my_sequence", optional=True)
- self._assert_seq_result(s.execute(testing.db))
-
- def test_func_implicit_connectionless_execute(self):
- """test func.next_value().execute()/.scalar() works
- with connectionless execution."""
-
- s = Sequence("my_sequence", metadata=MetaData(testing.db))
- self._assert_seq_result(s.next_value().execute().scalar())
-
- def test_func_explicit(self):
- s = Sequence("my_sequence")
- self._assert_seq_result(testing.db.scalar(s.next_value()))
-
- def test_func_implicit_connectionless_scalar(self):
- """test func.next_value().execute()/.scalar() works. """
-
- s = Sequence("my_sequence", metadata=MetaData(testing.db))
- self._assert_seq_result(s.next_value().scalar())
-
- def test_func_embedded_select(self):
- """test can use next_value() in select column expr"""
-
- s = Sequence("my_sequence")
- self._assert_seq_result(testing.db.scalar(select(s.next_value())))
-
-
class SequenceExecTest(fixtures.TestBase):
__requires__ = ("sequences",)
__backend__ = True
@@ -247,7 +189,7 @@ class SequenceExecTest(fixtures.TestBase):
s = Sequence("my_sequence_here", metadata=metadata)
e = engines.testing_engine(options={"implicit_returning": False})
- with e.connect() as conn:
+ with e.begin() as conn:
t1.create(conn)
s.create(conn)
@@ -279,7 +221,7 @@ class SequenceExecTest(fixtures.TestBase):
t1.create(testing.db)
e = engines.testing_engine(options={"implicit_returning": True})
- with e.connect() as conn:
+ with e.begin() as conn:
r = conn.execute(t1.insert().values(x=s.next_value()))
self._assert_seq_result(r.inserted_primary_key[0])
@@ -476,7 +418,7 @@ class TableBoundSequenceTest(fixtures.TablesTest):
engine = engines.testing_engine(options={"implicit_returning": False})
- with engine.connect() as conn:
+ with engine.begin() as conn:
result = conn.execute(sometable.insert(), dict(name="somename"))
eq_(result.postfetch_cols(), [sometable.c.obj_id])
diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py
index 09ade319e..719f8e318 100644
--- a/test/sql/test_type_expressions.py
+++ b/test/sql/test_type_expressions.py
@@ -359,34 +359,34 @@ class RoundTripTestBase(object):
[("X1", "Y1"), ("X2", "Y2"), ("X3", "Y3")],
)
- def test_targeting_no_labels(self):
- testing.db.execute(
+ def test_targeting_no_labels(self, connection):
+ connection.execute(
self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
)
- row = testing.db.execute(select(self.tables.test_table)).first()
+ row = connection.execute(select(self.tables.test_table)).first()
eq_(row._mapping[self.tables.test_table.c.y], "Y1")
- def test_targeting_by_string(self):
- testing.db.execute(
+ def test_targeting_by_string(self, connection):
+ connection.execute(
self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
)
- row = testing.db.execute(select(self.tables.test_table)).first()
+ row = connection.execute(select(self.tables.test_table)).first()
eq_(row._mapping["y"], "Y1")
- def test_targeting_apply_labels(self):
- testing.db.execute(
+ def test_targeting_apply_labels(self, connection):
+ connection.execute(
self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
)
- row = testing.db.execute(
+ row = connection.execute(
select(self.tables.test_table).apply_labels()
).first()
eq_(row._mapping[self.tables.test_table.c.y], "Y1")
- def test_targeting_individual_labels(self):
- testing.db.execute(
+ def test_targeting_individual_labels(self, connection):
+ connection.execute(
self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
)
- row = testing.db.execute(
+ row = connection.execute(
select(
self.tables.test_table.c.x.label("xbar"),
self.tables.test_table.c.y.label("ybar"),
@@ -450,9 +450,9 @@ class ReturningTest(fixtures.TablesTest):
)
@testing.provide_metadata
- def test_insert_returning(self):
+ def test_insert_returning(self, connection):
table = self.tables.test_table
- result = testing.db.execute(
+ result = connection.execute(
table.insert().returning(table.c.y), {"x": "xvalue"}
)
eq_(result.first(), ("yvalue",))
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index fd1783e09..3f89d438a 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -535,49 +535,48 @@ class _UserDefinedTypeFixture(object):
class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
__backend__ = True
- def _data_fixture(self):
+ def _data_fixture(self, connection):
users = self.tables.users
- with testing.db.connect() as conn:
- conn.execute(
- users.insert(),
- dict(
- user_id=2,
- goofy="jack",
- goofy2="jack",
- goofy4=util.u("jack"),
- goofy7=util.u("jack"),
- goofy8=12,
- goofy9=12,
- ),
- )
- conn.execute(
- users.insert(),
- dict(
- user_id=3,
- goofy="lala",
- goofy2="lala",
- goofy4=util.u("lala"),
- goofy7=util.u("lala"),
- goofy8=15,
- goofy9=15,
- ),
- )
- conn.execute(
- users.insert(),
- dict(
- user_id=4,
- goofy="fred",
- goofy2="fred",
- goofy4=util.u("fred"),
- goofy7=util.u("fred"),
- goofy8=9,
- goofy9=9,
- ),
- )
+ connection.execute(
+ users.insert(),
+ dict(
+ user_id=2,
+ goofy="jack",
+ goofy2="jack",
+ goofy4=util.u("jack"),
+ goofy7=util.u("jack"),
+ goofy8=12,
+ goofy9=12,
+ ),
+ )
+ connection.execute(
+ users.insert(),
+ dict(
+ user_id=3,
+ goofy="lala",
+ goofy2="lala",
+ goofy4=util.u("lala"),
+ goofy7=util.u("lala"),
+ goofy8=15,
+ goofy9=15,
+ ),
+ )
+ connection.execute(
+ users.insert(),
+ dict(
+ user_id=4,
+ goofy="fred",
+ goofy2="fred",
+ goofy4=util.u("fred"),
+ goofy7=util.u("fred"),
+ goofy8=9,
+ goofy9=9,
+ ),
+ )
def test_processing(self, connection):
users = self.tables.users
- self._data_fixture()
+ self._data_fixture(connection)
result = connection.execute(
users.select().order_by(users.c.user_id)
@@ -601,7 +600,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
def test_plain_in(self, connection):
users = self.tables.users
- self._data_fixture()
+ self._data_fixture(connection)
stmt = (
select(users.c.user_id, users.c.goofy8)
@@ -613,7 +612,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
def test_expanding_in(self, connection):
users = self.tables.users
- self._data_fixture()
+ self._data_fixture(connection)
stmt = (
select(users.c.user_id, users.c.goofy8)
@@ -1225,41 +1224,38 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
@testing.only_on("sqlite")
@testing.provide_metadata
- def test_round_trip(self):
+ def test_round_trip(self, connection):
variant = self.UTypeOne().with_variant(self.UTypeTwo(), "sqlite")
t = Table("t", self.metadata, Column("x", variant))
- with testing.db.connect() as conn:
- t.create(conn)
+ t.create(connection)
- conn.execute(t.insert(), x="foo")
+ connection.execute(t.insert(), x="foo")
- eq_(conn.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO")
+ eq_(connection.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO")
@testing.only_on("sqlite")
@testing.provide_metadata
- def test_round_trip_sqlite_datetime(self):
+ def test_round_trip_sqlite_datetime(self, connection):
variant = DateTime().with_variant(
dialects.sqlite.DATETIME(truncate_microseconds=True), "sqlite"
)
t = Table("t", self.metadata, Column("x", variant))
- with testing.db.connect() as conn:
- t.create(conn)
+ t.create(connection)
- conn.execute(
- t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
- )
+ connection.execute(
+ t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
+ )
- eq_(
- conn.scalar(
- select(t.c.x).where(
- t.c.x
- == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059)
- )
- ),
- datetime.datetime(2015, 4, 18, 10, 15, 17),
- )
+ eq_(
+ connection.scalar(
+ select(t.c.x).where(
+ t.c.x == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059)
+ )
+ ),
+ datetime.datetime(2015, 4, 18, 10, 15, 17),
+ )
class UnicodeTest(fixtures.TestBase):
@@ -1702,14 +1698,25 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
2,
)
- with testing.db.connect() as conn:
- self.metadata.create_all(conn)
+ self.metadata.create_all(testing.db)
+
+ # not using the connection fixture because we need to rollback and
+ # start again in the middle
+ with testing.db.connect() as connection:
+ # postgresql needs this in order to continue after the exception
+ trans = connection.begin()
assert_raises(
(exc.DBAPIError,),
- conn.exec_driver_sql,
+ connection.exec_driver_sql,
"insert into my_table " "(data) values('four')",
)
- conn.exec_driver_sql("insert into my_table (data) values ('two')")
+ trans.rollback()
+
+ with connection.begin():
+ connection.exec_driver_sql(
+ "insert into my_table (data) values ('two')"
+ )
+ eq_(connection.execute(select(t.c.data)).scalar(), "two")
@testing.requires.enforces_check_constraints
@testing.provide_metadata
@@ -1747,34 +1754,44 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
2,
)
- with testing.db.connect() as conn:
- self.metadata.create_all(conn)
+ self.metadata.create_all(testing.db)
+
+ # not using the connection fixture because we need to rollback and
+ # start again in the middle
+ with testing.db.connect() as connection:
+ # postgresql needs this in order to continue after the exception
+ trans = connection.begin()
assert_raises(
(exc.DBAPIError,),
- conn.exec_driver_sql,
+ connection.exec_driver_sql,
"insert into my_table " "(data) values('two')",
)
- conn.exec_driver_sql("insert into my_table (data) values ('four')")
+ trans.rollback()
- def test_skip_check_constraint(self):
- with testing.db.connect() as conn:
- conn.exec_driver_sql(
- "insert into non_native_enum_table "
- "(id, someotherenum) values(1, 'four')"
- )
- eq_(
- conn.exec_driver_sql(
- "select someotherenum from non_native_enum_table"
- ).scalar(),
- "four",
- )
- assert_raises_message(
- LookupError,
- "'four' is not among the defined enum values. "
- "Enum name: None. Possible values: one, two, three",
- conn.scalar,
- select(self.tables.non_native_enum_table.c.someotherenum),
- )
+ with connection.begin():
+ connection.exec_driver_sql(
+ "insert into my_table (data) values ('four')"
+ )
+ eq_(connection.execute(select(t.c.data)).scalar(), "four")
+
+ def test_skip_check_constraint(self, connection):
+ connection.exec_driver_sql(
+ "insert into non_native_enum_table "
+ "(id, someotherenum) values(1, 'four')"
+ )
+ eq_(
+ connection.exec_driver_sql(
+ "select someotherenum from non_native_enum_table"
+ ).scalar(),
+ "four",
+ )
+ assert_raises_message(
+ LookupError,
+ "'four' is not among the defined enum values. "
+ "Enum name: None. Possible values: one, two, three",
+ connection.scalar,
+ select(self.tables.non_native_enum_table.c.someotherenum),
+ )
def test_non_native_round_trip(self, connection):
non_native_enum_table = self.tables["non_native_enum_table"]
@@ -2086,15 +2103,15 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
eq_(e.length, 42)
-binary_table = MyPickleType = metadata = None
+MyPickleType = None
-class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
+class BinaryTest(fixtures.TablesTest, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
- global binary_table, MyPickleType, metadata
+ def define_tables(cls, metadata):
+ global MyPickleType
class MyPickleType(types.TypeDecorator):
impl = PickleType
@@ -2109,8 +2126,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
value.stuff = "this is the right stuff"
return value
- metadata = MetaData(testing.db)
- binary_table = Table(
+ Table(
"binary_table",
metadata,
Column(
@@ -2125,19 +2141,11 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
Column("pickled", PickleType),
Column("mypickle", MyPickleType),
)
- metadata.create_all()
-
- @engines.close_first
- def teardown(self):
- with testing.db.connect() as conn:
- conn.execute(binary_table.delete())
-
- @classmethod
- def teardown_class(cls):
- metadata.drop_all()
@testing.requires.non_broken_binary
def test_round_trip(self, connection):
+ binary_table = self.tables.binary_table
+
testobj1 = pickleable.Foo("im foo 1")
testobj2 = pickleable.Foo("im foo 2")
testobj3 = pickleable.Foo("im foo 3")
@@ -2197,6 +2205,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
@testing.requires.binary_comparisons
def test_comparison(self, connection):
"""test that type coercion occurs on comparison for binary"""
+ binary_table = self.tables.binary_table
expr = binary_table.c.data == "foo"
assert isinstance(expr.right.type, LargeBinary)
@@ -2419,17 +2428,17 @@ class ArrayTest(fixtures.TestBase):
assert isinstance(arrtable.c.strarr[1:3].type, MyArray)
-test_table = meta = MyCustomType = MyTypeDec = None
+MyCustomType = MyTypeDec = None
class ExpressionTest(
- fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL
+ fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL
):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
- global test_table, meta, MyCustomType, MyTypeDec
+ def define_tables(cls, metadata):
+ global MyCustomType, MyTypeDec
class MyCustomType(types.UserDefinedType):
def get_col_spec(self):
@@ -2463,10 +2472,9 @@ class ExpressionTest(
def process_result_value(self, value, dialect):
return value + "BIND_OUT"
- meta = MetaData(testing.db)
- test_table = Table(
+ Table(
"test",
- meta,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data", String(30)),
Column("atimestamp", Date),
@@ -2474,25 +2482,22 @@ class ExpressionTest(
Column("bvalue", MyTypeDec(50)),
)
- meta.create_all()
-
- with testing.db.connect() as conn:
- conn.execute(
- test_table.insert(),
- {
- "id": 1,
- "data": "somedata",
- "atimestamp": datetime.date(2007, 10, 15),
- "avalue": 25,
- "bvalue": "foo",
- },
- )
-
@classmethod
- def teardown_class(cls):
- meta.drop_all()
+ def insert_data(cls, connection):
+ test_table = cls.tables.test
+ connection.execute(
+ test_table.insert(),
+ {
+ "id": 1,
+ "data": "somedata",
+ "atimestamp": datetime.date(2007, 10, 15),
+ "avalue": 25,
+ "bvalue": "foo",
+ },
+ )
def test_control(self, connection):
+ test_table = self.tables.test
assert (
connection.exec_driver_sql("select avalue from test").scalar()
== 250
@@ -2513,6 +2518,9 @@ class ExpressionTest(
def test_bind_adapt(self, connection):
# test an untyped bind gets the left side's type
+
+ test_table = self.tables.test
+
expr = test_table.c.atimestamp == bindparam("thedate")
eq_(expr.right.type._type_affinity, Date)
@@ -2565,6 +2573,8 @@ class ExpressionTest(
)
def test_grouped_bind_adapt(self):
+ test_table = self.tables.test
+
expr = test_table.c.atimestamp == elements.Grouping(
bindparam("thedate")
)
@@ -2579,6 +2589,8 @@ class ExpressionTest(
eq_(expr.right.element.element.type._type_affinity, Date)
def test_bind_adapt_update(self):
+ test_table = self.tables.test
+
bp = bindparam("somevalue")
stmt = test_table.update().values(avalue=bp)
compiled = stmt.compile()
@@ -2586,13 +2598,17 @@ class ExpressionTest(
eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType)
def test_bind_adapt_insert(self):
+ test_table = self.tables.test
bp = bindparam("somevalue")
+
stmt = test_table.insert().values(avalue=bp)
compiled = stmt.compile()
eq_(bp.type._type_affinity, types.NullType)
eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType)
def test_bind_adapt_expression(self):
+ test_table = self.tables.test
+
bp = bindparam("somevalue")
stmt = test_table.c.avalue == bp
eq_(bp.type._type_affinity, types.NullType)
@@ -2629,6 +2645,8 @@ class ExpressionTest(
is_(literal(data).type.__class__, expected)
def test_typedec_operator_adapt(self, connection):
+ test_table = self.tables.test
+
expr = test_table.c.bvalue + "hi"
assert expr.type.__class__ is MyTypeDec
@@ -2846,6 +2864,8 @@ class ExpressionTest(
eq_(expr.type, types.NULLTYPE)
def test_distinct(self, connection):
+ test_table = self.tables.test
+
s = select(distinct(test_table.c.avalue))
eq_(connection.execute(s).scalar(), 25)
@@ -3004,17 +3024,18 @@ class NumericRawSQLTest(fixtures.TestBase):
__backend__ = True
- def _fixture(self, metadata, type_, data):
+ def _fixture(self, connection, metadata, type_, data):
t = Table("t", metadata, Column("val", type_))
- metadata.create_all()
- with testing.db.connect() as conn:
- conn.execute(t.insert(), val=data)
+ metadata.create_all(connection)
+ connection.execute(t.insert(), val=data)
@testing.fails_on("sqlite", "Doesn't provide Decimal results natively")
@testing.provide_metadata
def test_decimal_fp(self, connection):
metadata = self.metadata
- self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45.5"))
+ self._fixture(
+ connection, metadata, Numeric(10, 5), decimal.Decimal("45.5")
+ )
val = connection.exec_driver_sql("select val from t").scalar()
assert isinstance(val, decimal.Decimal)
eq_(val, decimal.Decimal("45.5"))
@@ -3023,7 +3044,9 @@ class NumericRawSQLTest(fixtures.TestBase):
@testing.provide_metadata
def test_decimal_int(self, connection):
metadata = self.metadata
- self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45"))
+ self._fixture(
+ connection, metadata, Numeric(10, 5), decimal.Decimal("45")
+ )
val = connection.exec_driver_sql("select val from t").scalar()
assert isinstance(val, decimal.Decimal)
eq_(val, decimal.Decimal("45"))
@@ -3031,7 +3054,7 @@ class NumericRawSQLTest(fixtures.TestBase):
@testing.provide_metadata
def test_ints(self, connection):
metadata = self.metadata
- self._fixture(metadata, Integer, 45)
+ self._fixture(connection, metadata, Integer, 45)
val = connection.exec_driver_sql("select val from t").scalar()
assert isinstance(val, util.int_types)
eq_(val, 45)
@@ -3039,7 +3062,7 @@ class NumericRawSQLTest(fixtures.TestBase):
@testing.provide_metadata
def test_float(self, connection):
metadata = self.metadata
- self._fixture(metadata, Float, 46.583)
+ self._fixture(connection, metadata, Float, 46.583)
val = connection.exec_driver_sql("select val from t").scalar()
assert isinstance(val, float)
@@ -3050,19 +3073,14 @@ class NumericRawSQLTest(fixtures.TestBase):
eq_(val, 46.583)
-interval_table = metadata = None
-
-
-class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
+class IntervalTest(fixtures.TablesTest, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
- global interval_table, metadata
- metadata = MetaData(testing.db)
- interval_table = Table(
- "intervaltable",
+ def define_tables(cls, metadata):
+ Table(
+ "intervals",
metadata,
Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -3074,16 +3092,6 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
),
Column("non_native_interval", Interval(native=False)),
)
- metadata.create_all()
-
- @engines.close_first
- def teardown(self):
- with testing.db.connect() as conn:
- conn.execute(interval_table.delete())
-
- @classmethod
- def teardown_class(cls):
- metadata.drop_all()
def test_non_native_adapt(self):
interval = Interval(native=False)
@@ -3092,30 +3100,32 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
assert adapted.native is False
eq_(str(adapted), "DATETIME")
- def test_roundtrip(self):
+ def test_roundtrip(self, connection):
+ interval_table = self.tables.intervals
+
small_delta = datetime.timedelta(days=15, seconds=5874)
delta = datetime.timedelta(14)
- with testing.db.begin() as conn:
- conn.execute(
- interval_table.insert(),
- native_interval=small_delta,
- native_interval_args=delta,
- non_native_interval=delta,
- )
- row = conn.execute(interval_table.select()).first()
+ connection.execute(
+ interval_table.insert(),
+ native_interval=small_delta,
+ native_interval_args=delta,
+ non_native_interval=delta,
+ )
+ row = connection.execute(interval_table.select()).first()
eq_(row.native_interval, small_delta)
eq_(row.native_interval_args, delta)
eq_(row.non_native_interval, delta)
- def test_null(self):
- with testing.db.begin() as conn:
- conn.execute(
- interval_table.insert(),
- id=1,
- native_inverval=None,
- non_native_interval=None,
- )
- row = conn.execute(interval_table.select()).first()
+ def test_null(self, connection):
+ interval_table = self.tables.intervals
+
+ connection.execute(
+ interval_table.insert(),
+ id=1,
+ native_inverval=None,
+ non_native_interval=None,
+ )
+ row = connection.execute(interval_table.select()).first()
eq_(row.native_interval, None)
eq_(row.native_interval_args, None)
eq_(row.non_native_interval, None)
@@ -3215,25 +3225,24 @@ class BooleanTest(
)
@testing.requires.non_native_boolean_unconstrained
- def test_nonnative_processor_coerces_integer_to_boolean(self):
+ def test_nonnative_processor_coerces_integer_to_boolean(self, connection):
boolean_table = self.tables.boolean_table
- with testing.db.connect() as conn:
- conn.exec_driver_sql(
- "insert into boolean_table (id, unconstrained_value) "
- "values (1, 5)"
- )
+ connection.exec_driver_sql(
+ "insert into boolean_table (id, unconstrained_value) "
+ "values (1, 5)"
+ )
- eq_(
- conn.exec_driver_sql(
- "select unconstrained_value from boolean_table"
- ).scalar(),
- 5,
- )
+ eq_(
+ connection.exec_driver_sql(
+ "select unconstrained_value from boolean_table"
+ ).scalar(),
+ 5,
+ )
- eq_(
- conn.scalar(select(boolean_table.c.unconstrained_value)),
- True,
- )
+ eq_(
+ connection.scalar(select(boolean_table.c.unconstrained_value)),
+ True,
+ )
def test_bind_processor_coercion_native_true(self):
proc = Boolean().bind_processor(
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index ec96af207..946a01651 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -1263,10 +1263,10 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
__backend__ = True
@testing.requires.update_from
- def test_exec_two_table(self):
+ def test_exec_two_table(self, connection):
users, addresses = self.tables.users, self.tables.addresses
- testing.db.execute(
+ connection.execute(
addresses.update()
.values(email_address=users.c.name)
.where(users.c.id == addresses.c.user_id)
@@ -1280,14 +1280,14 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
(4, 8, "x", "ed"),
(5, 9, "x", "fred@fred.com"),
]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
@testing.requires.update_from
- def test_exec_two_table_plus_alias(self):
+ def test_exec_two_table_plus_alias(self, connection):
users, addresses = self.tables.users, self.tables.addresses
a1 = addresses.alias()
- testing.db.execute(
+ connection.execute(
addresses.update()
.values(email_address=users.c.name)
.where(users.c.id == a1.c.user_id)
@@ -1302,15 +1302,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
(4, 8, "x", "ed"),
(5, 9, "x", "fred@fred.com"),
]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
@testing.requires.update_from
- def test_exec_three_table(self):
+ def test_exec_three_table(self, connection):
users = self.tables.users
addresses = self.tables.addresses
dingalings = self.tables.dingalings
- testing.db.execute(
+ connection.execute(
addresses.update()
.values(email_address=users.c.name)
.where(users.c.id == addresses.c.user_id)
@@ -1326,15 +1326,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
(4, 8, "x", "ed@lala.com"),
(5, 9, "x", "fred@fred.com"),
]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
@testing.only_on("mysql", "Multi table update")
- def test_exec_multitable(self):
+ def test_exec_multitable(self, connection):
users, addresses = self.tables.users, self.tables.addresses
values = {addresses.c.email_address: "updated", users.c.name: "ed2"}
- testing.db.execute(
+ connection.execute(
addresses.update()
.values(values)
.where(users.c.id == addresses.c.user_id)
@@ -1348,18 +1348,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
(4, 8, "x", "updated"),
(5, 9, "x", "fred@fred.com"),
]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
- self._assert_users(users, expected)
+ self._assert_users(connection, users, expected)
@testing.only_on("mysql", "Multi table update")
- def test_exec_join_multitable(self):
+ def test_exec_join_multitable(self, connection):
users, addresses = self.tables.users, self.tables.addresses
values = {addresses.c.email_address: "updated", users.c.name: "ed2"}
- testing.db.execute(
+ connection.execute(
update(users.join(addresses))
.values(values)
.where(users.c.name == "ed")
@@ -1372,18 +1372,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
(4, 8, "x", "updated"),
(5, 9, "x", "fred@fred.com"),
]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
- self._assert_users(users, expected)
+ self._assert_users(connection, users, expected)
@testing.only_on("mysql", "Multi table update")
- def test_exec_multitable_same_name(self):
+ def test_exec_multitable_same_name(self, connection):
users, addresses = self.tables.users, self.tables.addresses
values = {addresses.c.name: "ad_ed2", users.c.name: "ed2"}
- testing.db.execute(
+ connection.execute(
addresses.update()
.values(values)
.where(users.c.id == addresses.c.user_id)
@@ -1397,18 +1397,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
(4, 8, "ad_ed2", "ed@lala.com"),
(5, 9, "x", "fred@fred.com"),
]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
- self._assert_users(users, expected)
+ self._assert_users(connection, users, expected)
- def _assert_addresses(self, addresses, expected):
+ def _assert_addresses(self, connection, addresses, expected):
stmt = addresses.select().order_by(addresses.c.id)
- eq_(testing.db.execute(stmt).fetchall(), expected)
+ eq_(connection.execute(stmt).fetchall(), expected)
- def _assert_users(self, users, expected):
+ def _assert_users(self, connection, users, expected):
stmt = users.select().order_by(users.c.id)
- eq_(testing.db.execute(stmt).fetchall(), expected)
+ eq_(connection.execute(stmt).fetchall(), expected)
class UpdateFromMultiTableUpdateDefaultsTest(
@@ -1472,12 +1472,12 @@ class UpdateFromMultiTableUpdateDefaultsTest(
)
@testing.only_on("mysql", "Multi table update")
- def test_defaults_second_table(self):
+ def test_defaults_second_table(self, connection):
users, addresses = self.tables.users, self.tables.addresses
values = {addresses.c.email_address: "updated", users.c.name: "ed2"}
- ret = testing.db.execute(
+ ret = connection.execute(
addresses.update()
.values(values)
.where(users.c.id == addresses.c.user_id)
@@ -1491,18 +1491,18 @@ class UpdateFromMultiTableUpdateDefaultsTest(
(3, 8, "updated"),
(4, 9, "fred@fred.com"),
]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
expected = [(8, "ed2", "im the update"), (9, "fred", "value")]
- self._assert_users(users, expected)
+ self._assert_users(connection, users, expected)
@testing.only_on("mysql", "Multi table update")
- def test_defaults_second_table_same_name(self):
+ def test_defaults_second_table_same_name(self, connection):
users, foobar = self.tables.users, self.tables.foobar
values = {foobar.c.data: foobar.c.data + "a", users.c.name: "ed2"}
- ret = testing.db.execute(
+ ret = connection.execute(
users.update()
.values(values)
.where(users.c.id == foobar.c.user_id)
@@ -1519,16 +1519,16 @@ class UpdateFromMultiTableUpdateDefaultsTest(
(3, 8, "d2a", "im the other update"),
(4, 9, "d3", None),
]
- self._assert_foobar(foobar, expected)
+ self._assert_foobar(connection, foobar, expected)
expected = [(8, "ed2", "im the update"), (9, "fred", "value")]
- self._assert_users(users, expected)
+ self._assert_users(connection, users, expected)
@testing.only_on("mysql", "Multi table update")
- def test_no_defaults_second_table(self):
+ def test_no_defaults_second_table(self, connection):
users, addresses = self.tables.users, self.tables.addresses
- ret = testing.db.execute(
+ ret = connection.execute(
addresses.update()
.values({"email_address": users.c.name})
.where(users.c.id == addresses.c.user_id)
@@ -1538,20 +1538,20 @@ class UpdateFromMultiTableUpdateDefaultsTest(
eq_(ret.prefetch_cols(), [])
expected = [(2, 8, "ed"), (3, 8, "ed"), (4, 9, "fred@fred.com")]
- self._assert_addresses(addresses, expected)
+ self._assert_addresses(connection, addresses, expected)
# users table not actually updated, so no onupdate
expected = [(8, "ed", "value"), (9, "fred", "value")]
- self._assert_users(users, expected)
+ self._assert_users(connection, users, expected)
- def _assert_foobar(self, foobar, expected):
+ def _assert_foobar(self, connection, foobar, expected):
stmt = foobar.select().order_by(foobar.c.id)
- eq_(testing.db.execute(stmt).fetchall(), expected)
+ eq_(connection.execute(stmt).fetchall(), expected)
- def _assert_addresses(self, addresses, expected):
+ def _assert_addresses(self, connection, addresses, expected):
stmt = addresses.select().order_by(addresses.c.id)
- eq_(testing.db.execute(stmt).fetchall(), expected)
+ eq_(connection.execute(stmt).fetchall(), expected)
- def _assert_users(self, users, expected):
+ def _assert_users(self, connection, users, expected):
stmt = users.select().order_by(users.c.id)
- eq_(testing.db.execute(stmt).fetchall(), expected)
+ eq_(connection.execute(stmt).fetchall(), expected)