summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/suite
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/suite')
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py2
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py32
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py74
3 files changed, 49 insertions, 59 deletions
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 6c3c1005a..de157d028 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
from sqlalchemy import pool
return engines.testing_engine(
- options=dict(poolclass=pool.StaticPool)
+ options=dict(poolclass=pool.StaticPool, scope="class"),
)
else:
return config.db
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index e0fdbe47a..e8dd6cf2c 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -261,10 +261,6 @@ class ServerSideCursorsTest(
)
return self.engine
- def tearDown(self):
- engines.testing_reaper.close_all()
- self.engine.dispose()
-
@testing.combinations(
("global_string", True, "select 1", True),
("global_text", True, text("select 1"), True),
@@ -309,24 +305,22 @@ class ServerSideCursorsTest(
def test_conn_option(self):
engine = self._fixture(False)
- # should be enabled for this one
- result = (
- engine.connect()
- .execution_options(stream_results=True)
- .exec_driver_sql("select 1")
- )
- assert self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # should be enabled for this one
+ result = conn.execution_options(
+ stream_results=True
+ ).exec_driver_sql("select 1")
+ assert self._is_server_side(result.cursor)
def test_stmt_enabled_conn_option_disabled(self):
engine = self._fixture(False)
s = select(1).execution_options(stream_results=True)
- # not this one
- result = (
- engine.connect().execution_options(stream_results=False).execute(s)
- )
- assert not self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # not this one
+ result = conn.execution_options(stream_results=False).execute(s)
+ assert not self._is_server_side(result.cursor)
def test_aliases_and_ss(self):
engine = self._fixture(False)
@@ -344,8 +338,7 @@ class ServerSideCursorsTest(
assert not self._is_server_side(result.cursor)
result.close()
- @testing.provide_metadata
- def test_roundtrip_fetchall(self):
+ def test_roundtrip_fetchall(self, metadata):
md = self.metadata
engine = self._fixture(True)
@@ -385,8 +378,7 @@ class ServerSideCursorsTest(
0,
)
- @testing.provide_metadata
- def test_roundtrip_fetchmany(self):
+ def test_roundtrip_fetchmany(self, metadata):
md = self.metadata
engine = self._fixture(True)
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 3a5e02c32..ebcceaae7 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
__backend__ = True
@testing.fixture
- def do_numeric_test(self, metadata):
+ def do_numeric_test(self, metadata, connection):
@testing.emits_warning(
r".*does \*not\* support Decimal objects natively"
)
def run(type_, input_, output, filter_=None, check_scale=False):
t = Table("t", metadata, Column("x", type_))
- t.create(testing.db)
- with config.db.begin() as conn:
- conn.execute(t.insert(), [{"x": x} for x in input_])
-
- result = {row[0] for row in conn.execute(t.select())}
- output = set(output)
- if filter_:
- result = set(filter_(x) for x in result)
- output = set(filter_(x) for x in output)
- eq_(result, output)
- if check_scale:
- eq_([str(x) for x in result], [str(x) for x in output])
+ t.create(connection)
+ connection.execute(t.insert(), [{"x": x} for x in input_])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ output = set(output)
+ if filter_:
+ result = set(filter_(x) for x in result)
+ output = set(filter_(x) for x in output)
+ eq_(result, output)
+ if check_scale:
+ eq_([str(x) for x in result], [str(x) for x in output])
return run
@@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
class JSONLegacyStringCastIndexTest(