From db08a699489c9b0259579d7ff7fd6bf3496ca3a2 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Thu, 14 Oct 2021 21:45:57 +0200 Subject: rearchitect reflection for batched performance Rearchitected the schema reflection API to allow some dialects to make use of high performing batch queries to reflect the schemas of many tables at once using much fewer queries. The new performance features are targeted first at the PostgreSQL and Oracle backends, and may be applied to any dialect that makes use of SELECT queries against system catalog tables to reflect tables (currently this omits the MySQL and SQLite dialects which instead make use of parsing the "CREATE TABLE" statement, however these dialects do not have a pre-existing performance issue with reflection. MS SQL Server is still a TODO). The new API is backwards compatible with the previous system, and should require no changes to third party dialects to retain compatibility; third party dialects can also opt into the new system by implementing batched queries for schema reflection. Along with this change is an updated reflection API that is fully :pep:`484` typed, features many new methods and some changes. Fixes: #4379 Change-Id: I897ec09843543aa7012bcdce758792ed3d415d08 --- lib/sqlalchemy/testing/assertions.py | 30 +- lib/sqlalchemy/testing/plugin/pytestplugin.py | 16 +- lib/sqlalchemy/testing/provision.py | 70 +- lib/sqlalchemy/testing/requirements.py | 46 + lib/sqlalchemy/testing/schema.py | 15 +- lib/sqlalchemy/testing/suite/test_reflection.py | 1544 +++++++++++++++++++---- lib/sqlalchemy/testing/suite/test_sequence.py | 33 +- lib/sqlalchemy/testing/util.py | 39 +- 8 files changed, 1492 insertions(+), 301 deletions(-) (limited to 'lib/sqlalchemy/testing') diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 9888d7c18..937706363 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -644,13 +644,21 @@ class AssertsCompiledSQL: class ComparesTables: - def assert_tables_equal(self, table, reflected_table, strict_types=False): + def assert_tables_equal( + self, + table, + reflected_table, + strict_types=False, + strict_constraints=True, + ): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] - eq_(c.primary_key, reflected_c.primary_key) - eq_(c.nullable, reflected_c.nullable) + + if strict_constraints: + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" @@ -664,18 +672,20 @@ class ComparesTables: if isinstance(c.type, sqltypes.String): eq_(c.type.length, reflected_c.type.length) - eq_( - {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys}, - ) + if strict_constraints: + eq_( + {f.column.name for f in c.foreign_keys}, + {f.column.name for f in reflected_c.foreign_keys}, + ) if c.server_default: assert isinstance( reflected_c.server_default, schema.FetchedValue ) - assert len(table.primary_key) == len(reflected_table.primary_key) - for c in table.primary_key: - assert reflected_table.primary_key.columns[c.name] is not None + if strict_constraints: + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): assert c1.type._compare_type_affinity( diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index fa7d2ca19..cea07b305 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -741,13 +741,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): fn._sa_parametrize.append((argnames, pytest_params)) return fn else: + _fn_argnames = inspect.getfullargspec(fn).args[1:] if argnames is None: - _argnames = inspect.getfullargspec(fn).args[1:] + _argnames = _fn_argnames else: _argnames = re.split(r", *", argnames) if has_exclusions: - _argnames += ["_exclusions"] + existing_exl = sum( + 1 for n in _fn_argnames if n.startswith("_exclusions") + ) + current_exclusion_name = f"_exclusions_{existing_exl}" + _argnames += [current_exclusion_name] @_pytest_fn_decorator def check_exclusions(fn, *args, **kw): @@ -755,13 +760,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): if _exclusions: exlu = exclusions.compound().add(*_exclusions) fn = exlu(fn) - return fn(*args[0:-1], **kw) - - def process_metadata(spec): - spec.args.append("_exclusions") + return fn(*args[:-1], **kw) fn = check_exclusions( - fn, add_positional_parameters=("_exclusions",) + fn, add_positional_parameters=(current_exclusion_name,) ) return pytest.mark.parametrize(_argnames, pytest_params)(fn) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index d38437732..498d92a77 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -230,7 +230,39 @@ def drop_all_schema_objects(cfg, eng): drop_all_schema_objects_pre_tables(cfg, eng) + drop_views(cfg, eng) + + if config.requirements.materialized_views.enabled: + drop_materialized_views(cfg, eng) + inspector = inspect(eng) + + consider_schemas = (None,) + if config.requirements.schemas.enabled_for_config(cfg): + consider_schemas += (cfg.test_schema, cfg.test_schema_2) + util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas) + + drop_all_schema_objects_post_tables(cfg, eng) + + if config.requirements.sequences.enabled_for_config(cfg): + with eng.begin() as conn: + for seq in inspector.get_sequence_names(): + conn.execute(ddl.DropSequence(schema.Sequence(seq))) + if config.requirements.schemas.enabled_for_config(cfg): + for schema_name in [cfg.test_schema, cfg.test_schema_2]: + for seq in inspector.get_sequence_names( + schema=schema_name + ): + conn.execute( + ddl.DropSequence( + schema.Sequence(seq, schema=schema_name) + ) + ) + + +def drop_views(cfg, eng): + inspector = inspect(eng) + try: view_names = inspector.get_view_names() except NotImplementedError: @@ -244,7 +276,7 @@ def drop_all_schema_objects(cfg, eng): if config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names(schema="test_schema") + view_names = inspector.get_view_names(schema=cfg.test_schema) except NotImplementedError: pass else: @@ -255,32 +287,30 @@ def drop_all_schema_objects(cfg, eng): schema.Table( vname, schema.MetaData(), - schema="test_schema", + schema=cfg.test_schema, ) ) ) - util.drop_all_tables(eng, inspector) - if config.requirements.schemas.enabled_for_config(cfg): - util.drop_all_tables(eng, inspector, schema=cfg.test_schema) - util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) - drop_all_schema_objects_post_tables(cfg, eng) +def drop_materialized_views(cfg, eng): + inspector = inspect(eng) - if config.requirements.sequences.enabled_for_config(cfg): + mview_names = inspector.get_materialized_view_names() + + with eng.begin() as conn: + for vname in mview_names: + conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}") + + if config.requirements.schemas.enabled_for_config(cfg): + mview_names = inspector.get_materialized_view_names( + schema=cfg.test_schema + ) with eng.begin() as conn: - for seq in inspector.get_sequence_names(): - conn.execute(ddl.DropSequence(schema.Sequence(seq))) - if config.requirements.schemas.enabled_for_config(cfg): - for schema_name in [cfg.test_schema, cfg.test_schema_2]: - for seq in inspector.get_sequence_names( - schema=schema_name - ): - conn.execute( - ddl.DropSequence( - schema.Sequence(seq, schema=schema_name) - ) - ) + for vname in mview_names: + conn.exec_driver_sql( + f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}" + ) @register.init diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 4f9c73cf6..038f6e9bd 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -64,6 +64,25 @@ class SuiteRequirements(Requirements): return exclusions.open() + @property + def foreign_keys_reflect_as_index(self): + """Target database creates an index that's reflected for + foreign keys.""" + + return exclusions.closed() + + @property + def unique_index_reflect_as_unique_constraints(self): + """Target database reflects unique indexes as unique constrains.""" + + return exclusions.closed() + + @property + def unique_constraints_reflect_as_index(self): + """Target database reflects unique constraints as indexes.""" + + return exclusions.closed() + @property def table_value_constructor(self): """Database / dialect supports a query like:: @@ -628,6 +647,12 @@ class SuiteRequirements(Requirements): def schema_reflection(self): return self.schemas + @property + def schema_create_delete(self): + """target database supports schema create and dropped with + 'CREATE SCHEMA' and 'DROP SCHEMA'""" + return exclusions.closed() + @property def primary_key_constraint_reflection(self): return exclusions.open() @@ -692,6 +717,12 @@ class SuiteRequirements(Requirements): """target database supports CREATE INDEX with per-column ASC/DESC.""" return exclusions.open() + @property + def reflect_indexes_with_ascdesc(self): + """target database supports reflecting INDEX with per-column + ASC/DESC.""" + return exclusions.open() + @property def indexes_with_expressions(self): """target database supports CREATE INDEX against SQL expressions.""" @@ -1567,3 +1598,18 @@ class SuiteRequirements(Requirements): def json_deserializer_binary(self): "indicates if the json_deserializer function is called with bytes" return exclusions.closed() + + @property + def reflect_table_options(self): + """Target database must support reflecting table_options.""" + return exclusions.closed() + + @property + def materialized_views(self): + """Target database must support MATERIALIZED VIEWs.""" + return exclusions.closed() + + @property + def materialized_views_reflect_pk(self): + """Target database reflect MATERIALIZED VIEWs pks.""" + return exclusions.closed() diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index e4a92a732..46cbf4759 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -23,7 +23,7 @@ __all__ = ["Table", "Column"] table_options = {} -def Table(*args, **kw): +def Table(*args, **kw) -> schema.Table: """A schema.Table wrapper/hook for dialect-specific tweaks.""" test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} @@ -134,6 +134,19 @@ class eq_type_affinity: return self.target._type_affinity is not other._type_affinity +class eq_compile_type: + """similar to eq_type_affinity but uses compile""" + + def __init__(self, target): + self.target = target + + def __eq__(self, other): + return self.target == other.compile() + + def __ne__(self, other): + return self.target != other.compile() + + class eq_clause_element: """Helper to compare SQL structures based on compare()""" diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index b09b96227..7b8e2aa8b 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -7,6 +7,8 @@ import sqlalchemy as sa from .. import config from .. import engines from .. import eq_ +from .. import expect_raises +from .. import expect_raises_message from .. import expect_warnings from .. import fixtures from .. import is_ @@ -24,12 +26,19 @@ from ... import MetaData from ... import String from ... import testing from ... import types as sql_types +from ...engine import Inspector +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...exc import NoSuchTableError +from ...exc import UnreflectableTableError from ...schema import DDL from ...schema import Index from ...sql.elements import quoted_name from ...sql.schema import BLANK_SCHEMA +from ...testing import ComparesTables from ...testing import is_false from ...testing import is_true +from ...testing import mock metadata, users = None, None @@ -61,6 +70,19 @@ class HasTableTest(fixtures.TablesTest): is_false(config.db.dialect.has_table(conn, "test_table_s")) is_false(config.db.dialect.has_table(conn, "nonexistent_table")) + def test_has_table_cache(self, metadata): + insp = inspect(config.db) + is_true(insp.has_table("test_table")) + nt = Table("new_table", metadata, Column("col", Integer)) + is_false(insp.has_table("new_table")) + nt.create(config.db) + try: + is_false(insp.has_table("new_table")) + insp.clear_cache() + is_true(insp.has_table("new_table")) + finally: + nt.drop(config.db) + @testing.requires.schemas def test_has_table_schema(self): with config.db.begin() as conn: @@ -117,6 +139,7 @@ class HasIndexTest(fixtures.TablesTest): metadata, Column("id", Integer, primary_key=True), Column("data", String(50)), + Column("data2", String(50)), ) Index("my_idx", tt.c.data) @@ -130,40 +153,56 @@ class HasIndexTest(fixtures.TablesTest): ) Index("my_idx_s", tt.c.data) - def test_has_index(self): - with config.db.begin() as conn: - assert config.db.dialect.has_index(conn, "test_table", "my_idx") - assert not config.db.dialect.has_index( - conn, "test_table", "my_idx_s" - ) - assert not config.db.dialect.has_index( - conn, "nonexistent_table", "my_idx" - ) - assert not config.db.dialect.has_index( - conn, "test_table", "nonexistent_idx" - ) + kind = testing.combinations("dialect", "inspector", argnames="kind") + + def _has_index(self, kind, conn): + if kind == "dialect": + return lambda *a, **k: config.db.dialect.has_index(conn, *a, **k) + else: + return inspect(conn).has_index + + @kind + def test_has_index(self, kind, connection, metadata): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx") + assert not meth("test_table", "my_idx_s") + assert not meth("nonexistent_table", "my_idx") + assert not meth("test_table", "nonexistent_idx") + + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + idx = Index("my_idx_2", self.tables.test_table.c.data2) + tbl = Table( + "test_table_2", + metadata, + Column("foo", Integer), + Index("my_idx_3", "foo"), + ) + idx.create(connection) + tbl.create(connection) + try: + if kind == "inspector": + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + meth.__self__.clear_cache() + assert meth("test_table", "my_idx_2") is True + assert meth("test_table_2", "my_idx_3") is True + finally: + tbl.drop(connection) + idx.drop(connection) @testing.requires.schemas - def test_has_index_schema(self): - with config.db.begin() as conn: - assert config.db.dialect.has_index( - conn, "test_table", "my_idx_s", schema=config.test_schema - ) - assert not config.db.dialect.has_index( - conn, "test_table", "my_idx", schema=config.test_schema - ) - assert not config.db.dialect.has_index( - conn, - "nonexistent_table", - "my_idx_s", - schema=config.test_schema, - ) - assert not config.db.dialect.has_index( - conn, - "test_table", - "nonexistent_idx_s", - schema=config.test_schema, - ) + @kind + def test_has_index_schema(self, kind, connection): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx_s", schema=config.test_schema) + assert not meth("test_table", "my_idx", schema=config.test_schema) + assert not meth( + "nonexistent_table", "my_idx_s", schema=config.test_schema + ) + assert not meth( + "test_table", "nonexistent_idx_s", schema=config.test_schema + ) class QuotedNameArgumentTest(fixtures.TablesTest): @@ -264,7 +303,12 @@ class QuotedNameArgumentTest(fixtures.TablesTest): def test_get_table_options(self, name): insp = inspect(config.db) - insp.get_table_options(name) + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options(name) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options(name) @quote_fixtures @testing.requires.view_column_reflection @@ -311,7 +355,37 @@ class QuotedNameArgumentTest(fixtures.TablesTest): assert insp.get_check_constraints(name) -class ComponentReflectionTest(fixtures.TablesTest): +def _multi_combination(fn): + schema = testing.combinations( + None, + ( + lambda: config.test_schema, + testing.requires.schemas, + ), + argnames="schema", + ) + scope = testing.combinations( + ObjectScope.DEFAULT, + ObjectScope.TEMPORARY, + ObjectScope.ANY, + argnames="scope", + ) + kind = testing.combinations( + ObjectKind.TABLE, + ObjectKind.VIEW, + ObjectKind.MATERIALIZED_VIEW, + ObjectKind.ANY, + ObjectKind.ANY_VIEW, + ObjectKind.TABLE | ObjectKind.VIEW, + ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW, + argnames="kind", + ) + filter_names = testing.combinations(True, False, argnames="use_filter") + + return schema(scope(kind(filter_names(fn)))) + + +class ComponentReflectionTest(ComparesTables, fixtures.TablesTest): run_inserts = run_deletes = None __backend__ = True @@ -354,6 +428,7 @@ class ComponentReflectionTest(fixtures.TablesTest): "%susers.user_id" % schema_prefix, name="user_id_fk" ), ), + sa.CheckConstraint("test2 > 0", name="test2_gt_zero"), schema=schema, test_needs_fk=True, ) @@ -364,6 +439,8 @@ class ComponentReflectionTest(fixtures.TablesTest): Column("user_id", sa.INT, primary_key=True), Column("test1", sa.CHAR(5), nullable=False), Column("test2", sa.Float(), nullable=False), + Column("parent_user_id", sa.Integer), + sa.CheckConstraint("test2 > 0", name="test2_gt_zero"), schema=schema, test_needs_fk=True, ) @@ -375,9 +452,19 @@ class ComponentReflectionTest(fixtures.TablesTest): Column( "address_id", sa.Integer, - sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + sa.ForeignKey( + "%semail_addresses.address_id" % schema_prefix, + name="email_add_id_fg", + ), + ), + Column("data", sa.String(30), unique=True), + sa.CheckConstraint( + "address_id > 0 AND address_id < 1000", + name="address_id_gt_zero", + ), + sa.UniqueConstraint( + "address_id", "dingaling_id", name="zz_dingalings_multiple" ), - Column("data", sa.String(30)), schema=schema, test_needs_fk=True, ) @@ -388,7 +475,7 @@ class ComponentReflectionTest(fixtures.TablesTest): Column( "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) ), - Column("email_address", sa.String(20)), + Column("email_address", sa.String(20), index=True), sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), schema=schema, test_needs_fk=True, @@ -406,6 +493,12 @@ class ComponentReflectionTest(fixtures.TablesTest): schema=schema, comment=r"""the test % ' " \ table comment""", ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) if testing.requires.cross_schema_fk_reflection.enabled: if schema is None: @@ -449,7 +542,10 @@ class ComponentReflectionTest(fixtures.TablesTest): ) if testing.requires.index_reflection.enabled: - cls.define_index(metadata, users) + Index("users_t_idx", users.c.test1, users.c.test2, unique=True) + Index( + "users_all_idx", users.c.user_id, users.c.test2, users.c.test1 + ) if not schema: # test_needs_fk is at the moment to force MySQL InnoDB @@ -468,7 +564,10 @@ class ComponentReflectionTest(fixtures.TablesTest): test_needs_fk=True, ) - if testing.requires.indexes_with_ascdesc.enabled: + if ( + testing.requires.indexes_with_ascdesc.enabled + and testing.requires.reflect_indexes_with_ascdesc.enabled + ): Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) @@ -477,12 +576,16 @@ class ComponentReflectionTest(fixtures.TablesTest): if not schema and testing.requires.temp_table_reflection.enabled: cls.define_temp_tables(metadata) + @classmethod + def temp_table_name(cls): + return get_temp_table_name( + config, config.db, f"user_tmp_{config.ident}" + ) + @classmethod def define_temp_tables(cls, metadata): kw = temp_table_keyword_args(config, config.db) - table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident - ) + table_name = cls.temp_table_name() user_tmp = Table( table_name, metadata, @@ -495,7 +598,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # unique constraints created against temp tables in different # databases. # https://www.arbinada.com/en/node/1645 - sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident), + sa.UniqueConstraint("name", name=f"user_tmp_uq_{config.ident}"), sa.Index("user_tmp_ix", "foo"), **kw, ) @@ -513,33 +616,636 @@ class ComponentReflectionTest(fixtures.TablesTest): ) event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) - @classmethod - def define_index(cls, metadata, users): - Index("users_t_idx", users.c.test1, users.c.test2) - Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1) - @classmethod def define_views(cls, metadata, schema): - for table_name in ("users", "email_addresses"): + if testing.requires.materialized_views.enabled: + materialized = {"dingalings"} + else: + materialized = set() + for table_name in ("users", "email_addresses", "dingalings"): fullname = table_name if schema: - fullname = "%s.%s" % (schema, table_name) + fullname = f"{schema}.{table_name}" view_name = fullname + "_v" - query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, - fullname, + prefix = "MATERIALIZED " if table_name in materialized else "" + query = ( + f"CREATE {prefix}VIEW {view_name} AS SELECT * FROM {fullname}" ) event.listen(metadata, "after_create", DDL(query)) + if table_name in materialized: + index_name = "mat_index" + if schema and testing.against("oracle"): + index_name = f"{schema}.{index_name}" + idx = f"CREATE INDEX {index_name} ON {view_name}(data)" + event.listen(metadata, "after_create", DDL(idx)) event.listen( - metadata, "before_drop", DDL("DROP VIEW %s" % view_name) + metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}") + ) + + def _resolve_kind(self, kind, tables, views, materialized): + res = {} + if ObjectKind.TABLE in kind: + res.update(tables) + if ObjectKind.VIEW in kind: + res.update(views) + if ObjectKind.MATERIALIZED_VIEW in kind: + res.update(materialized) + return res + + def _resolve_views(self, views, materialized): + if not testing.requires.view_column_reflection.enabled: + materialized.clear() + views.clear() + elif not testing.requires.materialized_views.enabled: + views.update(materialized) + materialized.clear() + + def _resolve_names(self, schema, scope, filter_names, values): + scope_filter = lambda _: True # noqa: E731 + if scope is ObjectScope.DEFAULT: + scope_filter = lambda k: "tmp" not in k[1] # noqa: E731 + if scope is ObjectScope.TEMPORARY: + scope_filter = lambda k: "tmp" in k[1] # noqa: E731 + + removed = { + None: {"remote_table", "remote_table_2"}, + testing.config.test_schema: { + "local_table", + "noncol_idx_test_nopk", + "noncol_idx_test_pk", + "user_tmp_v", + self.temp_table_name(), + }, + } + if not testing.requires.cross_schema_fk_reflection.enabled: + removed[None].add("local_table") + removed[testing.config.test_schema].update( + ["remote_table", "remote_table_2"] + ) + if not testing.requires.index_reflection.enabled: + removed[None].update( + ["noncol_idx_test_nopk", "noncol_idx_test_pk"] ) + if ( + not testing.requires.temp_table_reflection.enabled + or not testing.requires.temp_table_names.enabled + ): + removed[None].update(["user_tmp_v", self.temp_table_name()]) + if not testing.requires.temporary_views.enabled: + removed[None].update(["user_tmp_v"]) + + res = { + k: v + for k, v in values.items() + if scope_filter(k) + and k[1] not in removed[schema] + and (not filter_names or k[1] in filter_names) + } + return res + + def exp_options( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + materialized = {(schema, "dingalings_v"): mock.ANY} + views = { + (schema, "email_addresses_v"): mock.ANY, + (schema, "users_v"): mock.ANY, + (schema, "user_tmp_v"): mock.ANY, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): mock.ANY, + (schema, "dingalings"): mock.ANY, + (schema, "email_addresses"): mock.ANY, + (schema, "comment_test"): mock.ANY, + (schema, "no_constraints"): mock.ANY, + (schema, "local_table"): mock.ANY, + (schema, "remote_table"): mock.ANY, + (schema, "remote_table_2"): mock.ANY, + (schema, "noncol_idx_test_nopk"): mock.ANY, + (schema, "noncol_idx_test_pk"): mock.ANY, + (schema, self.temp_table_name()): mock.ANY, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_comments( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + empty = {"text": None} + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): empty, + (schema, "dingalings"): empty, + (schema, "email_addresses"): empty, + (schema, "comment_test"): { + "text": r"""the test % ' " \ table comment""" + }, + (schema, "no_constraints"): empty, + (schema, "local_table"): empty, + (schema, "remote_table"): empty, + (schema, "remote_table_2"): empty, + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): empty, + (schema, self.temp_table_name()): empty, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_columns( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def col( + name, auto=False, default=mock.ANY, comment=None, nullable=True + ): + res = { + "name": name, + "autoincrement": auto, + "type": mock.ANY, + "default": default, + "comment": comment, + "nullable": nullable, + } + if auto == "omit": + res.pop("autoincrement") + return res + + def pk(name, **kw): + kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw} + return col(name, **kw) + + materialized = { + (schema, "dingalings_v"): [ + col("dingaling_id", auto="omit", nullable=mock.ANY), + col("address_id"), + col("data"), + ] + } + views = { + (schema, "email_addresses_v"): [ + col("address_id", auto="omit", nullable=mock.ANY), + col("remote_user_id"), + col("email_address"), + ], + (schema, "users_v"): [ + col("user_id", auto="omit", nullable=mock.ANY), + col("test1", nullable=mock.ANY), + col("test2", nullable=mock.ANY), + col("parent_user_id"), + ], + (schema, "user_tmp_v"): [ + col("id", auto="omit", nullable=mock.ANY), + col("name"), + col("foo"), + ], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + pk("user_id"), + col("test1", nullable=False), + col("test2", nullable=False), + col("parent_user_id"), + ], + (schema, "dingalings"): [ + pk("dingaling_id"), + col("address_id"), + col("data"), + ], + (schema, "email_addresses"): [ + pk("address_id"), + col("remote_user_id"), + col("email_address"), + ], + (schema, "comment_test"): [ + pk("id", comment="id comment"), + col("data", comment="data % comment"), + col( + "d2", + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + ], + (schema, "no_constraints"): [col("data")], + (schema, "local_table"): [pk("id"), col("data"), col("remote_id")], + (schema, "remote_table"): [pk("id"), col("local_id"), col("data")], + (schema, "remote_table_2"): [pk("id"), col("data")], + (schema, "noncol_idx_test_nopk"): [col("q")], + (schema, "noncol_idx_test_pk"): [pk("id"), col("q")], + (schema, self.temp_table_name()): [ + pk("id"), + col("name"), + col("foo"), + ], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_column_keys(self): + return {"name", "type", "nullable", "default"} + + def exp_pks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def pk(*cols, name=mock.ANY): + return {"constrained_columns": list(cols), "name": name} + + empty = pk(name=None) + if testing.requires.materialized_views_reflect_pk.enabled: + materialized = {(schema, "dingalings_v"): pk("dingaling_id")} + else: + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): pk("user_id"), + (schema, "dingalings"): pk("dingaling_id"), + (schema, "email_addresses"): pk("address_id", name="email_ad_pk"), + (schema, "comment_test"): pk("id"), + (schema, "no_constraints"): empty, + (schema, "local_table"): pk("id"), + (schema, "remote_table"): pk("id"), + (schema, "remote_table_2"): pk("id"), + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): pk("id"), + (schema, self.temp_table_name()): pk("id"), + } + if not testing.requires.reflects_pk_names.enabled: + for val in tables.values(): + if val["name"] is not None: + val["name"] = mock.ANY + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_pk_keys(self): + return {"name", "constrained_columns"} + + def exp_fks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt: + def __eq__(self, other): + return ( + other is None + or config.db.dialect.default_schema_name == other + ) + + def fk(cols, ref_col, ref_table, ref_schema=schema, name=mock.ANY): + return { + "constrained_columns": cols, + "referred_columns": ref_col, + "name": name, + "options": mock.ANY, + "referred_schema": ref_schema + if ref_schema is not None + else tt(), + "referred_table": ref_table, + } + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk") + ], + (schema, "dingalings"): [ + fk( + ["address_id"], + ["address_id"], + "email_addresses", + name="email_add_id_fg", + ) + ], + (schema, "email_addresses"): [ + fk(["remote_user_id"], ["user_id"], "users") + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + fk( + ["remote_id"], + ["id"], + "remote_table_2", + ref_schema=config.test_schema, + ) + ], + (schema, "remote_table"): [ + fk(["local_id"], ["id"], "local_table", ref_schema=None) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + if not testing.requires.self_referential_foreign_keys.enabled: + tables[(schema, "users")].clear() + if not testing.requires.named_constraints.enabled: + for vals in tables.values(): + for val in vals: + if val["name"] is not mock.ANY: + val["name"] = mock.ANY + + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_fk_keys(self): + return { + "name", + "constrained_columns", + "referred_schema", + "referred_table", + "referred_columns", + } + + def exp_indexes( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def idx( + *cols, + name, + unique=False, + column_sorting=None, + duplicates=False, + fk=False, + ): + fk_req = testing.requires.foreign_keys_reflect_as_index + dup_req = testing.requires.unique_constraints_reflect_as_index + if (fk and not fk_req.enabled) or ( + duplicates and not dup_req.enabled + ): + return () + res = { + "unique": unique, + "column_names": list(cols), + "name": name, + "dialect_options": mock.ANY, + "include_columns": [], + } + if column_sorting: + res["column_sorting"] = {"q": ("desc",)} + if duplicates: + res["duplicates_constraint"] = name + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + if materialized: + materialized[(schema, "dingalings_v")].extend( + idx("data", name="mat_index") + ) + tables = { + (schema, "users"): [ + *idx("parent_user_id", name="user_id_fk", fk=True), + *idx("user_id", "test2", "test1", name="users_all_idx"), + *idx("test1", "test2", name="users_t_idx", unique=True), + ], + (schema, "dingalings"): [ + *idx("data", name=mock.ANY, unique=True, duplicates=True), + *idx( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + unique=True, + duplicates=True, + ), + ], + (schema, "email_addresses"): [ + *idx("email_address", name=mock.ANY), + *idx("remote_user_id", name=mock.ANY, fk=True), + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + *idx("remote_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table"): [ + *idx("local_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [ + *idx( + "q", + name="noncol_idx_nopk", + column_sorting={"q": ("desc",)}, + ) + ], + (schema, "noncol_idx_test_pk"): [ + *idx( + "q", name="noncol_idx_pk", column_sorting={"q": ("desc",)} + ) + ], + (schema, self.temp_table_name()): [ + *idx("foo", name="user_tmp_ix"), + *idx( + "name", + name=f"user_tmp_uq_{config.ident}", + duplicates=True, + unique=True, + ), + ], + } + if ( + not testing.requires.indexes_with_ascdesc.enabled + or not testing.requires.reflect_indexes_with_ascdesc.enabled + ): + tables[(schema, "noncol_idx_test_nopk")].clear() + tables[(schema, "noncol_idx_test_pk")].clear() + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_index_keys(self): + return {"name", "column_names", "unique"} + + def exp_ucs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + all_=False, + ): + def uc(*cols, name, duplicates_index=None, is_index=False): + req = testing.requires.unique_index_reflect_as_unique_constraints + if is_index and not req.enabled: + return () + res = { + "column_names": list(cols), + "name": name, + } + if duplicates_index: + res["duplicates_index"] = duplicates_index + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + *uc( + "test1", + "test2", + name="users_t_idx", + duplicates_index="users_t_idx", + is_index=True, + ) + ], + (schema, "dingalings"): [ + *uc("data", name=mock.ANY, duplicates_index=mock.ANY), + *uc( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + duplicates_index="zz_dingalings_multiple", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [ + *uc("name", name=f"user_tmp_uq_{config.ident}") + ], + } + if all_: + return {**materialized, **views, **tables} + else: + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_unique_cst_keys(self): + return {"name", "column_names"} + + def exp_ccs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt(str): + def __eq__(self, other): + res = ( + other.lower() + .replace("(", "") + .replace(")", "") + .replace("`", "") + ) + return self in res + + def cc(text, name): + return {"sqltext": tt(text), "name": name} + + # print({1: "test2 > (0)::double precision"} == {1: tt("test2 > 0")}) + # assert 0 + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [cc("test2 > 0", "test2_gt_zero")], + (schema, "dingalings"): [ + cc( + "address_id > 0 and address_id < 1000", + name="address_id_gt_zero", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_cc_keys(self): + return {"name", "sqltext"} @testing.requires.schema_reflection - def test_get_schema_names(self): - insp = inspect(self.bind) + def test_get_schema_names(self, connection): + insp = inspect(connection) - self.assert_(testing.config.test_schema in insp.get_schema_names()) + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema(self, connection): + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) @testing.requires.schema_reflection def test_get_schema_names_w_translate_map(self, connection): @@ -553,7 +1259,37 @@ class ComponentReflectionTest(fixtures.TablesTest): ) insp = inspect(connection) - self.assert_(testing.config.test_schema in insp.get_schema_names()) + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema_w_translate_map(self, connection): + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) + + @testing.requires.schema_reflection + @testing.requires.schema_create_delete + def test_schema_cache(self, connection): + insp = inspect(connection) + + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + connection.execute(DDL("CREATE SCHEMA foo_bar")) + try: + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + insp.clear_cache() + is_true("foo_bar" in insp.get_schema_names()) + is_true(insp.has_schema("foo_bar")) + finally: + connection.execute(DDL("DROP SCHEMA foo_bar")) @testing.requires.schema_reflection def test_dialect_initialize(self): @@ -562,113 +1298,115 @@ class ComponentReflectionTest(fixtures.TablesTest): assert hasattr(engine.dialect, "default_schema_name") @testing.requires.schema_reflection - def test_get_default_schema_name(self): - insp = inspect(self.bind) - eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + def test_get_default_schema_name(self, connection): + insp = inspect(connection) + eq_(insp.default_schema_name, connection.dialect.default_schema_name) - @testing.requires.foreign_key_constraint_reflection @testing.combinations( - (None, True, False, False), - (None, True, False, True, testing.requires.schemas), - ("foreign_key", True, False, False), - (None, False, True, False), - (None, False, True, True, testing.requires.schemas), - (None, True, True, False), - (None, True, True, True, testing.requires.schemas), - argnames="order_by,include_plain,include_views,use_schema", + None, + ("foreign_key", testing.requires.foreign_key_constraint_reflection), + argnames="order_by", ) - def test_get_table_names( - self, connection, order_by, include_plain, include_views, use_schema - ): + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_table_names(self, connection, order_by, use_schema): if use_schema: schema = config.test_schema else: schema = None - _ignore_tables = [ + _ignore_tables = { "comment_test", "noncol_idx_test_pk", "noncol_idx_test_nopk", "local_table", "remote_table", "remote_table_2", - ] + "no_constraints", + } insp = inspect(connection) - if include_views: - table_names = insp.get_view_names(schema) - table_names.sort() - answer = ["email_addresses_v", "users_v"] - eq_(sorted(table_names), answer) + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] - if include_plain: - if order_by: - tables = [ - rec[0] - for rec in insp.get_sorted_table_and_fkc_names(schema) - if rec[0] - ] - else: - tables = insp.get_table_names(schema) - table_names = [t for t in tables if t not in _ignore_tables] + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) - if order_by == "foreign_key": - answer = ["users", "email_addresses", "dingalings"] - eq_(table_names, answer) - else: - answer = ["dingalings", "email_addresses", "users"] - eq_(sorted(table_names), answer) + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_view_names(self, connection, use_schema): + insp = inspect(connection) + if use_schema: + schema = config.test_schema + else: + schema = None + table_names = insp.get_view_names(schema) + if testing.requires.materialized_views.enabled: + eq_(sorted(table_names), ["email_addresses_v", "users_v"]) + eq_(insp.get_materialized_view_names(schema), ["dingalings_v"]) + else: + answer = ["dingalings_v", "email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) @testing.requires.temp_table_names - def test_get_temp_table_names(self): - insp = inspect(self.bind) + def test_get_temp_table_names(self, connection): + insp = inspect(connection) temp_table_names = insp.get_temp_table_names() - eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident]) + eq_(sorted(temp_table_names), [f"user_tmp_{config.ident}"]) @testing.requires.view_reflection - @testing.requires.temp_table_names @testing.requires.temporary_views - def test_get_temp_view_names(self): - insp = inspect(self.bind) + def test_get_temp_view_names(self, connection): + insp = inspect(connection) temp_table_names = insp.get_temp_view_names() eq_(sorted(temp_table_names), ["user_tmp_v"]) @testing.requires.comment_reflection - def test_get_comments(self): - self._test_get_comments() + def test_get_comments(self, connection): + self._test_get_comments(connection) @testing.requires.comment_reflection @testing.requires.schemas - def test_get_comments_with_schema(self): - self._test_get_comments(testing.config.test_schema) - - def _test_get_comments(self, schema=None): - insp = inspect(self.bind) + def test_get_comments_with_schema(self, connection): + self._test_get_comments(connection, testing.config.test_schema) + def _test_get_comments(self, connection, schema=None): + insp = inspect(connection) + exp = self.exp_comments(schema=schema) eq_( insp.get_table_comment("comment_test", schema=schema), - {"text": r"""the test % ' " \ table comment"""}, + exp[(schema, "comment_test")], ) - eq_(insp.get_table_comment("users", schema=schema), {"text": None}) + eq_( + insp.get_table_comment("users", schema=schema), + exp[(schema, "users")], + ) eq_( - [ - {"name": rec["name"], "comment": rec["comment"]} - for rec in insp.get_columns("comment_test", schema=schema) - ], - [ - {"comment": "id comment", "name": "id"}, - {"comment": "data % comment", "name": "data"}, - { - "comment": ( - r"""Comment types type speedily ' " \ '' Fun!""" - ), - "name": "d2", - }, - ], + insp.get_table_comment("comment_test", schema=schema), + exp[(schema, "comment_test")], + ) + + no_cst = self.tables.no_constraints.name + eq_( + insp.get_table_comment(no_cst, schema=schema), + exp[(schema, no_cst)], ) @testing.combinations( @@ -691,7 +1429,7 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses = (self.tables.users, self.tables.email_addresses) if use_views: - table_names = ["users_v", "email_addresses_v"] + table_names = ["users_v", "email_addresses_v", "dingalings_v"] else: table_names = ["users", "email_addresses"] @@ -699,7 +1437,7 @@ class ComponentReflectionTest(fixtures.TablesTest): for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) - self.assert_(len(cols) > 0, len(cols)) + is_true(len(cols) > 0, len(cols)) # should be in order @@ -721,7 +1459,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # assert that the desired type and return type share # a base within one of the generic types. - self.assert_( + is_true( len( set(ctype.__mro__) .intersection(ctype_def.__mro__) @@ -745,15 +1483,29 @@ class ComponentReflectionTest(fixtures.TablesTest): if not col.primary_key: assert cols[i]["default"] is None + # The case of a table with no column + # is tested below in TableNoColumnsTest + @testing.requires.temp_table_reflection - def test_get_temp_table_columns(self): - table_name = get_temp_table_name( - config, self.bind, "user_tmp_%s" % config.ident + def test_reflect_table_temp_table(self, connection): + + table_name = self.temp_table_name() + user_tmp = self.tables[table_name] + + reflected_user_tmp = Table( + table_name, MetaData(), autoload_with=connection ) + self.assert_tables_equal( + user_tmp, reflected_user_tmp, strict_constraints=False + ) + + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self, connection): + table_name = self.temp_table_name() user_tmp = self.tables[table_name] - insp = inspect(self.bind) + insp = inspect(connection) cols = insp.get_columns(table_name) - self.assert_(len(cols) > 0, len(cols)) + is_true(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): eq_(col.name, cols[i]["name"]) @@ -761,8 +1513,8 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection @testing.requires.view_column_reflection @testing.requires.temporary_views - def test_get_temp_view_columns(self): - insp = inspect(self.bind) + def test_get_temp_view_columns(self, connection): + insp = inspect(connection) cols = insp.get_columns("user_tmp_v") eq_([col["name"] for col in cols], ["id", "name", "foo"]) @@ -778,18 +1530,27 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses = self.tables.users, self.tables.email_addresses insp = inspect(connection) + exp = self.exp_pks(schema=schema) users_cons = insp.get_pk_constraint(users.name, schema=schema) - users_pkeys = users_cons["constrained_columns"] - eq_(users_pkeys, ["user_id"]) + self._check_list( + [users_cons], [exp[(schema, users.name)]], self._required_pk_keys + ) addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) - addr_pkeys = addr_cons["constrained_columns"] - eq_(addr_pkeys, ["address_id"]) + exp_cols = exp[(schema, addresses.name)]["constrained_columns"] + eq_(addr_cons["constrained_columns"], exp_cols) with testing.requires.reflects_pk_names.fail_if(): eq_(addr_cons["name"], "email_ad_pk") + no_cst = self.tables.no_constraints.name + self._check_list( + [insp.get_pk_constraint(no_cst, schema=schema)], + [exp[(schema, no_cst)]], + self._required_pk_keys, + ) + @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" ) @@ -815,31 +1576,33 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(fkey1["referred_schema"], expected_schema) eq_(fkey1["referred_table"], users.name) eq_(fkey1["referred_columns"], ["user_id"]) - if testing.requires.self_referential_foreign_keys.enabled: - eq_(fkey1["constrained_columns"], ["parent_user_id"]) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) # addresses addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) fkey1 = addr_fkeys[0] with testing.requires.implicitly_named_constraints.fail_if(): - self.assert_(fkey1["name"] is not None) + is_true(fkey1["name"] is not None) eq_(fkey1["referred_schema"], expected_schema) eq_(fkey1["referred_table"], users.name) eq_(fkey1["referred_columns"], ["user_id"]) eq_(fkey1["constrained_columns"], ["remote_user_id"]) + no_cst = self.tables.no_constraints.name + eq_(insp.get_foreign_keys(no_cst, schema=schema), []) + @testing.requires.cross_schema_fk_reflection @testing.requires.schemas - def test_get_inter_schema_foreign_keys(self): + def test_get_inter_schema_foreign_keys(self, connection): local_table, remote_table, remote_table_2 = self.tables( - "%s.local_table" % self.bind.dialect.default_schema_name, + "%s.local_table" % connection.dialect.default_schema_name, "%s.remote_table" % testing.config.test_schema, "%s.remote_table_2" % testing.config.test_schema, ) - insp = inspect(self.bind) + insp = inspect(connection) local_fkeys = insp.get_foreign_keys(local_table.name) eq_(len(local_fkeys), 1) @@ -857,25 +1620,21 @@ class ComponentReflectionTest(fixtures.TablesTest): fkey2 = remote_fkeys[0] - assert fkey2["referred_schema"] in ( - None, - self.bind.dialect.default_schema_name, + is_true( + fkey2["referred_schema"] + in ( + None, + connection.dialect.default_schema_name, + ) ) eq_(fkey2["referred_table"], local_table.name) eq_(fkey2["referred_columns"], ["id"]) eq_(fkey2["constrained_columns"], ["local_id"]) - def _assert_insp_indexes(self, indexes, expected_indexes): - index_names = [d["name"] for d in indexes] - for e_index in expected_indexes: - assert e_index["name"] in index_names - index = indexes[index_names.index(e_index["name"])] - for key in e_index: - eq_(e_index[key], index[key]) - @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" ) + @testing.requires.index_reflection def test_get_indexes(self, connection, use_schema): if use_schema: @@ -885,21 +1644,19 @@ class ComponentReflectionTest(fixtures.TablesTest): # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. - insp = inspect(self.bind) + insp = inspect(connection) indexes = insp.get_indexes("users", schema=schema) - expected_indexes = [ - { - "unique": False, - "column_names": ["test1", "test2"], - "name": "users_t_idx", - }, - { - "unique": False, - "column_names": ["user_id", "test2", "test1"], - "name": "users_all_idx", - }, - ] - self._assert_insp_indexes(indexes, expected_indexes) + exp = self.exp_indexes(schema=schema) + self._check_list( + indexes, exp[(schema, "users")], self._required_index_keys + ) + + no_cst = self.tables.no_constraints.name + self._check_list( + insp.get_indexes(no_cst, schema=schema), + exp[(schema, no_cst)], + self._required_index_keys, + ) @testing.combinations( ("noncol_idx_test_nopk", "noncol_idx_nopk"), @@ -908,15 +1665,15 @@ class ComponentReflectionTest(fixtures.TablesTest): ) @testing.requires.index_reflection @testing.requires.indexes_with_ascdesc + @testing.requires.reflect_indexes_with_ascdesc def test_get_noncol_index(self, connection, tname, ixname): insp = inspect(connection) indexes = insp.get_indexes(tname) - # reflecting an index that has "x DESC" in it as the column. # the DB may or may not give us "x", but make sure we get the index # back, it has a name, it's connected to the table. - expected_indexes = [{"unique": False, "name": ixname}] - self._assert_insp_indexes(indexes, expected_indexes) + expected_indexes = self.exp_indexes()[(None, tname)] + self._check_list(indexes, expected_indexes, self._required_index_keys) t = Table(tname, MetaData(), autoload_with=connection) eq_(len(t.indexes), 1) @@ -925,29 +1682,17 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection @testing.requires.unique_constraint_reflection - def test_get_temp_table_unique_constraints(self): - insp = inspect(self.bind) - reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident) - for refl in reflected: - # Different dialects handle duplicate index and constraints - # differently, so ignore this flag - refl.pop("duplicates_index", None) - eq_( - reflected, - [ - { - "column_names": ["name"], - "name": "user_tmp_uq_%s" % config.ident, - } - ], - ) + def test_get_temp_table_unique_constraints(self, connection): + insp = inspect(connection) + name = self.temp_table_name() + reflected = insp.get_unique_constraints(name) + exp = self.exp_ucs(all_=True)[(None, name)] + self._check_list(reflected, exp, self._required_index_keys) @testing.requires.temp_table_reflect_indexes - def test_get_temp_table_indexes(self): - insp = inspect(self.bind) - table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident - ) + def test_get_temp_table_indexes(self, connection): + insp = inspect(connection) + table_name = self.temp_table_name() indexes = insp.get_indexes(table_name) for ind in indexes: ind.pop("dialect_options", None) @@ -1005,9 +1750,9 @@ class ComponentReflectionTest(fixtures.TablesTest): ) table.create(connection) - inspector = inspect(connection) + insp = inspect(connection) reflected = sorted( - inspector.get_unique_constraints("testtbl", schema=schema), + insp.get_unique_constraints("testtbl", schema=schema), key=operator.itemgetter("name"), ) @@ -1047,6 +1792,9 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(names_that_duplicate_index, idx_names) eq_(uq_names, set()) + no_cst = self.tables.no_constraints.name + eq_(insp.get_unique_constraints(no_cst, schema=schema), []) + @testing.requires.view_reflection @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" @@ -1056,32 +1804,21 @@ class ComponentReflectionTest(fixtures.TablesTest): schema = config.test_schema else: schema = None - view_name1 = "users_v" - view_name2 = "email_addresses_v" insp = inspect(connection) - v1 = insp.get_view_definition(view_name1, schema=schema) - self.assert_(v1) - v2 = insp.get_view_definition(view_name2, schema=schema) - self.assert_(v2) + for view in ["users_v", "email_addresses_v", "dingalings_v"]: + v = insp.get_view_definition(view, schema=schema) + is_true(bool(v)) - # why is this here if it's PG specific ? - @testing.combinations( - ("users", False), - ("users", True, testing.requires.schemas), - argnames="table_name,use_schema", - ) - @testing.only_on("postgresql", "PG specific feature") - def test_get_table_oid(self, connection, table_name, use_schema): - if use_schema: - schema = config.test_schema - else: - schema = None + @testing.requires.view_reflection + def test_get_view_definition_does_not_exist(self, connection): insp = inspect(connection) - oid = insp.get_table_oid(table_name, schema) - self.assert_(isinstance(oid, int)) + with expect_raises(NoSuchTableError): + insp.get_view_definition("view_does_not_exist") + with expect_raises(NoSuchTableError): + insp.get_view_definition("users") # a table @testing.requires.table_reflection - def test_autoincrement_col(self): + def test_autoincrement_col(self, connection): """test that 'autoincrement' is reflected according to sqla's policy. Don't mark this test as unsupported for any backend ! @@ -1094,7 +1831,7 @@ class ComponentReflectionTest(fixtures.TablesTest): """ - insp = inspect(self.bind) + insp = inspect(connection) for tname, cname in [ ("users", "user_id"), @@ -1105,6 +1842,330 @@ class ComponentReflectionTest(fixtures.TablesTest): id_ = {c["name"]: c for c in cols}[cname] assert id_.get("autoincrement", True) + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + def test_get_table_options(self, use_schema): + insp = inspect(config.db) + schema = config.test_schema if use_schema else None + + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options("users", schema=schema) + is_true(isinstance(res, dict)) + # NOTE: can't really create a table with no option + res = insp.get_table_options("no_constraints", schema=schema) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options("users", schema=schema) + + @testing.combinations((True, testing.requires.schemas), False) + def test_multi_get_table_options(self, use_schema): + insp = inspect(config.db) + if testing.requires.reflect_table_options.enabled: + schema = config.test_schema if use_schema else None + res = insp.get_multi_table_options(schema=schema) + + exp = { + (schema, table): insp.get_table_options(table, schema=schema) + for table in insp.get_table_names(schema=schema) + } + eq_(res, exp) + else: + with expect_raises(NotImplementedError): + res = insp.get_multi_table_options() + + @testing.fixture + def get_multi_exp(self, connection): + def provide_fixture( + schema, scope, kind, use_filter, single_reflect_fn, exp_method + ): + insp = inspect(connection) + # call the reflection function at least once to avoid + # "Unexpected success" errors if the result is actually empty + # and NotImplementedError is not raised + single_reflect_fn(insp, "email_addresses") + kw = {"scope": scope, "kind": kind} + if schema: + schema = schema() + + filter_names = [] + + if ObjectKind.TABLE in kind: + filter_names.extend( + ["comment_test", "users", "does-not-exist"] + ) + if ObjectKind.VIEW in kind: + filter_names.extend(["email_addresses_v", "does-not-exist"]) + if ObjectKind.MATERIALIZED_VIEW in kind: + filter_names.extend(["dingalings_v", "does-not-exist"]) + + if schema: + kw["schema"] = schema + if use_filter: + kw["filter_names"] = filter_names + + exp = exp_method( + schema=schema, + scope=scope, + kind=kind, + filter_names=kw.get("filter_names"), + ) + kws = [kw] + if scope == ObjectScope.DEFAULT: + nkw = kw.copy() + nkw.pop("scope") + kws.append(nkw) + if kind == ObjectKind.TABLE: + nkw = kw.copy() + nkw.pop("kind") + kws.append(nkw) + + return inspect(connection), kws, exp + + return provide_fixture + + @testing.requires.reflect_table_options + @_multi_combination + def test_multi_get_table_options_tables( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_options, + self.exp_options, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_table_options(**kw) + eq_(result, exp) + + @testing.requires.comment_reflection + @_multi_combination + def test_get_multi_table_comment( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_comment, + self.exp_comments, + ) + for kw in kws: + insp.clear_cache() + eq_(insp.get_multi_table_comment(**kw), exp) + + def _check_list(self, result, exp, req_keys=None, msg=None): + if req_keys is None: + eq_(result, exp, msg) + else: + eq_(len(result), len(exp), msg) + for r, e in zip(result, exp): + for k in set(r) | set(e): + if k in req_keys or (k in r and k in e): + eq_(r[k], e[k], f"{msg} - {k} - {r}") + + def _check_table_dict(self, result, exp, req_keys=None, make_lists=False): + eq_(set(result.keys()), set(exp.keys())) + for k in result: + r, e = result[k], exp[k] + if make_lists: + r, e = [r], [e] + self._check_list(r, e, req_keys, k) + + @_multi_combination + def test_get_multi_columns( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_columns, + self.exp_columns, + ) + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_columns(**kw) + self._check_table_dict(result, exp, self._required_column_keys) + + @testing.requires.primary_key_constraint_reflection + @_multi_combination + def test_get_multi_pk_constraint( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_pk_constraint, + self.exp_pks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_pk_constraint(**kw) + self._check_table_dict( + result, exp, self._required_pk_keys, make_lists=True + ) + + def _adjust_sort(self, result, expected, key): + if not testing.requires.implicitly_named_constraints.enabled: + for obj in [result, expected]: + for val in obj.values(): + if len(val) > 1 and any( + v.get("name") in (None, mock.ANY) for v in val + ): + val.sort(key=key) + + @testing.requires.foreign_key_constraint_reflection + @_multi_combination + def test_get_multi_foreign_keys( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_foreign_keys, + self.exp_fks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_foreign_keys(**kw) + self._adjust_sort( + result, exp, lambda d: tuple(d["constrained_columns"]) + ) + self._check_table_dict(result, exp, self._required_fk_keys) + + @testing.requires.index_reflection + @_multi_combination + def test_get_multi_indexes( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_indexes, + self.exp_indexes, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_indexes(**kw) + self._check_table_dict(result, exp, self._required_index_keys) + + @testing.requires.unique_constraint_reflection + @_multi_combination + def test_get_multi_unique_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_unique_constraints, + self.exp_ucs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_unique_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["column_names"])) + self._check_table_dict(result, exp, self._required_unique_cst_keys) + + @testing.requires.check_constraint_reflection + @_multi_combination + def test_get_multi_check_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_check_constraints, + self.exp_ccs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_check_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["sqltext"])) + self._check_table_dict(result, exp, self._required_cc_keys) + + @testing.combinations( + ("get_table_options", testing.requires.reflect_table_options), + "get_columns", + ( + "get_pk_constraint", + testing.requires.primary_key_constraint_reflection, + ), + ( + "get_foreign_keys", + testing.requires.foreign_key_constraint_reflection, + ), + ("get_indexes", testing.requires.index_reflection), + ( + "get_unique_constraints", + testing.requires.unique_constraint_reflection, + ), + ( + "get_check_constraints", + testing.requires.check_constraint_reflection, + ), + ("get_table_comment", testing.requires.comment_reflection), + argnames="method", + ) + def test_not_existing_table(self, method, connection): + insp = inspect(connection) + meth = getattr(insp, method) + with expect_raises(NoSuchTableError): + meth("table_does_not_exists") + + def test_unreflectable(self, connection): + mc = Inspector.get_multi_columns + + def patched(*a, **k): + ur = k.setdefault("unreflectable", {}) + ur[(None, "some_table")] = UnreflectableTableError("err") + return mc(*a, **k) + + with mock.patch.object(Inspector, "get_multi_columns", patched): + with expect_raises_message(UnreflectableTableError, "err"): + inspect(connection).reflect_table( + Table("some_table", MetaData()), None + ) + + @testing.combinations(True, False, argnames="use_schema") + @testing.combinations( + (True, testing.requires.views), False, argnames="views" + ) + def test_metadata(self, connection, use_schema, views): + m = MetaData() + schema = config.test_schema if use_schema else None + m.reflect(connection, schema=schema, views=views, resolve_fks=False) + + insp = inspect(connection) + tables = insp.get_table_names(schema) + if views: + tables += insp.get_view_names(schema) + try: + tables += insp.get_materialized_view_names(schema) + except NotImplementedError: + pass + if schema: + tables = [f"{schema}.{t}" for t in tables] + eq_(sorted(m.tables), sorted(tables)) + class TableNoColumnsTest(fixtures.TestBase): __requires__ = ("reflect_tables_no_columns",) @@ -1117,9 +2178,6 @@ class TableNoColumnsTest(fixtures.TestBase): @testing.fixture def view_no_columns(self, connection, metadata): - Table("empty", metadata) - metadata.create_all(connection) - Table("empty", metadata) event.listen( metadata, @@ -1134,31 +2192,32 @@ class TableNoColumnsTest(fixtures.TestBase): ) metadata.create_all(connection) - @testing.requires.reflect_tables_no_columns def test_reflect_table_no_columns(self, connection, table_no_columns): t2 = Table("empty", MetaData(), autoload_with=connection) eq_(list(t2.c), []) - @testing.requires.reflect_tables_no_columns def test_get_columns_table_no_columns(self, connection, table_no_columns): - eq_(inspect(connection).get_columns("empty"), []) + insp = inspect(connection) + eq_(insp.get_columns("empty"), []) + multi = insp.get_multi_columns() + eq_(multi, {(None, "empty"): []}) - @testing.requires.reflect_tables_no_columns def test_reflect_incl_table_no_columns(self, connection, table_no_columns): m = MetaData() m.reflect(connection) assert set(m.tables).intersection(["empty"]) @testing.requires.views - @testing.requires.reflect_tables_no_columns def test_reflect_view_no_columns(self, connection, view_no_columns): t2 = Table("empty_v", MetaData(), autoload_with=connection) eq_(list(t2.c), []) @testing.requires.views - @testing.requires.reflect_tables_no_columns def test_get_columns_view_no_columns(self, connection, view_no_columns): - eq_(inspect(connection).get_columns("empty_v"), []) + insp = inspect(connection) + eq_(insp.get_columns("empty_v"), []) + multi = insp.get_multi_columns(kind=ObjectKind.VIEW) + eq_(multi, {(None, "empty_v"): []}) class ComponentReflectionTestExtra(fixtures.TestBase): @@ -1185,12 +2244,18 @@ class ComponentReflectionTestExtra(fixtures.TestBase): ), schema=schema, ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) metadata.create_all(connection) - inspector = inspect(connection) + insp = inspect(connection) reflected = sorted( - inspector.get_check_constraints("sa_cc", schema=schema), + insp.get_check_constraints("sa_cc", schema=schema), key=operator.itemgetter("name"), ) @@ -1213,6 +2278,8 @@ class ComponentReflectionTestExtra(fixtures.TestBase): {"name": "cc1", "sqltext": "a > 1 and a < 5"}, ], ) + no_cst = "no_constraints" + eq_(insp.get_check_constraints(no_cst, schema=schema), []) @testing.requires.indexes_with_expressions def test_reflect_expression_based_indexes(self, metadata, connection): @@ -1642,7 +2709,8 @@ class IdentityReflectionTest(fixtures.TablesTest): if col["name"] == "normal": is_false("identity" in col) elif col["name"] == "id1": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1659,7 +2727,8 @@ class IdentityReflectionTest(fixtures.TablesTest): approx=True, ) elif col["name"] == "id2": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1685,7 +2754,8 @@ class IdentityReflectionTest(fixtures.TablesTest): if col["name"] == "normal": is_false("identity" in col) elif col["name"] == "id1": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1735,16 +2805,16 @@ class CompositeKeyReflectionTest(fixtures.TablesTest): ) @testing.requires.primary_key_constraint_reflection - def test_pk_column_order(self): + def test_pk_column_order(self, connection): # test for issue #5661 - insp = inspect(self.bind) + insp = inspect(connection) primary_key = insp.get_pk_constraint(self.tables.tb1.name) eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) @testing.requires.foreign_key_constraint_reflection - def test_fk_column_order(self): + def test_fk_column_order(self, connection): # test for issue #5661 - insp = inspect(self.bind) + insp = inspect(connection) foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) eq_(len(foreign_keys), 1) fkey1 = foreign_keys[0] diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index eae051992..e15fad642 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -194,16 +194,23 @@ class HasSequenceTest(fixtures.TablesTest): ) def test_has_sequence(self, connection): - eq_( - inspect(connection).has_sequence("user_id_seq"), - True, - ) + eq_(inspect(connection).has_sequence("user_id_seq"), True) + + def test_has_sequence_cache(self, connection, metadata): + insp = inspect(connection) + eq_(insp.has_sequence("user_id_seq"), True) + ss = Sequence("new_seq", metadata=metadata) + eq_(insp.has_sequence("new_seq"), False) + ss.create(connection) + try: + eq_(insp.has_sequence("new_seq"), False) + insp.clear_cache() + eq_(insp.has_sequence("new_seq"), True) + finally: + ss.drop(connection) def test_has_sequence_other_object(self, connection): - eq_( - inspect(connection).has_sequence("user_id_table"), - False, - ) + eq_(inspect(connection).has_sequence("user_id_table"), False) @testing.requires.schemas def test_has_sequence_schema(self, connection): @@ -215,10 +222,7 @@ class HasSequenceTest(fixtures.TablesTest): ) def test_has_sequence_neg(self, connection): - eq_( - inspect(connection).has_sequence("some_sequence"), - False, - ) + eq_(inspect(connection).has_sequence("some_sequence"), False) @testing.requires.schemas def test_has_sequence_schemas_neg(self, connection): @@ -240,10 +244,7 @@ class HasSequenceTest(fixtures.TablesTest): @testing.requires.schemas def test_has_sequence_remote_not_in_default(self, connection): - eq_( - inspect(connection).has_sequence("schema_seq"), - False, - ) + eq_(inspect(connection).has_sequence("schema_seq"), False) def test_get_sequence_names(self, connection): exp = {"other_seq", "user_id_seq"} diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 0070b4d67..6fd42af70 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -393,36 +393,55 @@ def drop_all_tables_from_metadata(metadata, engine_or_connection): go(engine_or_connection) -def drop_all_tables(engine, inspector, schema=None, include_names=None): +def drop_all_tables( + engine, + inspector, + schema=None, + consider_schemas=(None,), + include_names=None, +): if include_names is not None: include_names = set(include_names) + if schema is not None: + assert consider_schemas == ( + None, + ), "consider_schemas and schema are mutually exclusive" + consider_schemas = (schema,) + with engine.begin() as conn: - for tname, fkcs in reversed( - inspector.get_sorted_table_and_fkc_names(schema=schema) + for table_key, fkcs in reversed( + inspector.sort_tables_on_foreign_key_dependency( + consider_schemas=consider_schemas + ) ): - if tname: - if include_names is not None and tname not in include_names: + if table_key: + if ( + include_names is not None + and table_key[1] not in include_names + ): continue conn.execute( - DropTable(Table(tname, MetaData(), schema=schema)) + DropTable( + Table(table_key[1], MetaData(), schema=table_key[0]) + ) ) elif fkcs: if not engine.dialect.supports_alter: continue - for tname, fkc in fkcs: + for t_key, fkc in fkcs: if ( include_names is not None - and tname not in include_names + and t_key[1] not in include_names ): continue tb = Table( - tname, + t_key[1], MetaData(), Column("x", Integer), Column("y", Integer), - schema=schema, + schema=t_key[0], ) conn.execute( DropConstraint( -- cgit v1.2.1