diff options
| author | Federico Caselli <cfederico87@gmail.com> | 2022-11-03 20:52:21 +0100 |
|---|---|---|
| committer | Federico Caselli <cfederico87@gmail.com> | 2022-11-16 23:03:04 +0100 |
| commit | 4eb4ceca36c7ce931ea65ac06d6ed08bf459fc66 (patch) | |
| tree | 4970cff3f78489a4a0066cd27fd4bae682402957 /lib/sqlalchemy/testing | |
| parent | 3fc6c40ea77c971d3067dab0fdf57a5b5313b69b (diff) | |
| download | sqlalchemy-4eb4ceca36c7ce931ea65ac06d6ed08bf459fc66.tar.gz | |
Try running pyupgrade on the code
command run is "pyupgrade --py37-plus --keep-runtime-typing --keep-percent-format <files...>"
pyupgrade will change assert_ to assertTrue. That was reverted since assertTrue does not
exists in sqlalchemy fixtures
Change-Id: Ie1ed2675c7b11d893d78e028aad0d1576baebb55
Diffstat (limited to 'lib/sqlalchemy/testing')
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/engines.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/exclusions.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/profiling.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_dialect.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 76 |
10 files changed, 64 insertions, 77 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 2fda1e9cb..d183372c3 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -269,9 +269,9 @@ class DialectSQL(CompiledSQL): return received_stmt == stmt def _received_statement(self, execute_observed): - received_stmt, received_params = super( - DialectSQL, self - )._received_statement(execute_observed) + received_stmt, received_params = super()._received_statement( + execute_observed + ) # TODO: why do we need this part? for real_stmt in execute_observed.statements: @@ -392,15 +392,15 @@ class EachOf(AssertRule): if self.rules and not self.rules[0].is_consumed: self.rules[0].no_more_statements() elif self.rules: - super(EachOf, self).no_more_statements() + super().no_more_statements() class Conditional(EachOf): def __init__(self, condition, rules, else_rules): if condition: - super(Conditional, self).__init__(*rules) + super().__init__(*rules) else: - super(Conditional, self).__init__(*else_rules) + super().__init__(*else_rules) class Or(AllOf): diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index c083f4e73..0a60a20d3 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -285,21 +285,21 @@ def reconnecting_engine(url=None, options=None): @typing.overload def testing_engine( - url: Optional["URL"] = None, + url: Optional[URL] = None, options: Optional[Dict[str, Any]] = None, asyncio: Literal[False] = False, transfer_staticpool: bool = False, -) -> "Engine": +) -> Engine: ... @typing.overload def testing_engine( - url: Optional["URL"] = None, + url: Optional[URL] = None, options: Optional[Dict[str, Any]] = None, asyncio: Literal[True] = True, transfer_staticpool: bool = False, -) -> "AsyncEngine": +) -> AsyncEngine: ... diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 25c6a0482..3cb060d01 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -129,10 +129,8 @@ class compound: for fail in self.fails: if fail(config): print( - ( - "%s failed as expected (%s): %s " - % (name, fail._as_string(config), ex) - ) + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), ex) ) break else: diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index dcee3f18b..12b5acba4 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -814,7 +814,7 @@ class DeclarativeMappedTest(MappedTest): # sets up cls.Basic which is helpful for things like composite # classes - super(DeclarativeMappedTest, cls)._with_register_classes(fn) + super()._with_register_classes(fn) if cls._tables_metadata.tables and cls.run_create_tables: cls._tables_metadata.create_all(config.db) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 0a70f4008..d590ecbe4 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -49,7 +49,7 @@ def pytest_addoption(parser): required=False, help=None, # noqa ): - super(CallableAction, self).__init__( + super().__init__( option_strings=option_strings, dest=dest, nargs=0, @@ -210,7 +210,7 @@ def pytest_collection_modifyitems(session, config, items): and not item.getparent(pytest.Class).name.startswith("_") ] - test_classes = set(item.getparent(pytest.Class) for item in items) + test_classes = {item.getparent(pytest.Class) for item in items} def collect(element): for inst_or_fn in element.collect(): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 7672bcde5..dfc3f28f6 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -195,7 +195,7 @@ class ProfileStatsFile: def _read(self): try: profile_f = open(self.fname) - except IOError: + except OSError: return for lineno, line in enumerate(profile_f): line = line.strip() @@ -212,7 +212,7 @@ class ProfileStatsFile: profile_f.close() def _write(self): - print(("Writing profile file %s" % self.fname)) + print("Writing profile file %s" % self.fname) profile_f = open(self.fname, "w") profile_f.write(self._header()) for test_key in sorted(self.data): @@ -293,7 +293,7 @@ def count_functions(variance=0.05): else: line_no, expected_count = expected - print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) + print("Pstats calls: %d Expected %s" % (callcount, expected_count)) stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort)) stats.print_stats() if _profile_stats.dump: diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 33e395c48..01cec1fb0 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -1,4 +1,3 @@ -#! coding: utf-8 # mypy: ignore-errors diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 68d1c13fa..bf745095d 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1871,14 +1871,12 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest): # "unique constraints" are actually unique indexes (with possible # exception of a unique that is a dupe of another one in the case # of Oracle). make sure # they aren't duplicated. - idx_names = set([idx.name for idx in reflected.indexes]) - uq_names = set( - [ - uq.name - for uq in reflected.constraints - if isinstance(uq, sa.UniqueConstraint) - ] - ).difference(["unique_c_a_b"]) + idx_names = {idx.name for idx in reflected.indexes} + uq_names = { + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + }.difference(["unique_c_a_b"]) assert not idx_names.intersection(uq_names) if names_that_duplicate_index: @@ -2519,10 +2517,10 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): ) t.create(connection) eq_( - dict( - (col["name"], col["nullable"]) + { + col["name"]: col["nullable"] for col in inspect(connection).get_columns("t") - ), + }, {"a": True, "b": False}, ) @@ -2613,7 +2611,7 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): # that can reflect these, since alembic looks for this opts = insp.get_foreign_keys("table")[0]["options"] - eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + eq_({k: opts[k] for k in opts if opts[k]}, {}) opts = insp.get_foreign_keys("user")[0]["options"] eq_(opts, expected) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 838b740fd..6394e4b9a 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -552,7 +552,7 @@ class FetchLimitOffsetTest(fixtures.TablesTest): .offset(2) ).fetchall() eq_(fa[0], (3, 3, 4)) - eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) @testing.requires.fetch_ties @testing.requires.fetch_offset_with_options @@ -623,7 +623,7 @@ class FetchLimitOffsetTest(fixtures.TablesTest): .offset(2) ).fetchall() eq_(fa[0], (3, 3, 4)) - eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) class SameNamedSchemaTableTest(fixtures.TablesTest): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 25ed041c2..36fd7f247 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -832,8 +832,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): result = {row[0] for row in connection.execute(t.select())} output = set(output) if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) + result = {filter_(x) for x in result} + output = {filter_(x) for x in output} eq_(result, output) if check_scale: eq_([str(x) for x in result], [str(x) for x in output]) @@ -969,13 +969,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.requires.precision_numerics_general def test_precision_decimal(self, do_numeric_test): - numbers = set( - [ - decimal.Decimal("54.234246451650"), - decimal.Decimal("0.004354"), - decimal.Decimal("900.0"), - ] - ) + numbers = { + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + } do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) @@ -988,52 +986,46 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set( - [ - decimal.Decimal("1E-2"), - decimal.Decimal("1E-3"), - decimal.Decimal("1E-4"), - decimal.Decimal("1E-5"), - decimal.Decimal("1E-6"), - decimal.Decimal("1E-7"), - decimal.Decimal("1E-8"), - decimal.Decimal("0.01000005940696"), - decimal.Decimal("0.00000005940696"), - decimal.Decimal("0.00000000000696"), - decimal.Decimal("0.70000000000696"), - decimal.Decimal("696E-12"), - ] - ) + numbers = { + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + } do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large def test_enotation_decimal_large(self, do_numeric_test): """test exceedingly large decimals.""" - numbers = set( - [ - decimal.Decimal("4E+8"), - decimal.Decimal("5748E+15"), - decimal.Decimal("1.521E+15"), - decimal.Decimal("00000000000000.1E+12"), - ] - ) + numbers = { + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + } do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits def test_many_significant_digits(self, do_numeric_test): - numbers = set( - [ - decimal.Decimal("31943874831932418390.01"), - decimal.Decimal("319438950232418390.273596"), - decimal.Decimal("87673.594069654243"), - ] - ) + numbers = { + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + } do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits def test_numeric_no_decimal(self, do_numeric_test): - numbers = set([decimal.Decimal("1.000")]) + numbers = {decimal.Decimal("1.000")} do_numeric_test( Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -1258,7 +1250,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def default(self, o): if isinstance(o, decimal.Decimal): return str(o) - return super(DecimalEncoder, self).default(o) + return super().default(o) json_data = json.dumps(data_element, cls=DecimalEncoder) |
