diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-12-17 18:04:47 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-01-14 16:54:13 -0500 |
| commit | 06f83c26ea3636eaec0b85fc9d733ab4bfb827ec (patch) | |
| tree | 13d43b9007f956bf514d757ce6781a378125fc3e /lib | |
| parent | a869dc8fe3cd579ed9bab665d215a6c3e3d8a4ca (diff) | |
| download | sqlalchemy-06f83c26ea3636eaec0b85fc9d733ab4bfb827ec.tar.gz | |
track item schema names to identify name collisions w/ default schema
Added an additional lookup step to the compiler which will track all FROM
clauses which are tables, that may have the same name shared in multiple
schemas where one of the schemas is the implicit "default" schema; in this
case, the table name when referring to that name without a schema
qualification will be rendered with an anonymous alias name at the compiler
level in order to disambiguate the two (or more) names. The approach of
schema-qualifying the normally unqualified name with the server-detected
"default schema name" value was also considered, however this approach
doesn't apply to Oracle nor is it accepted by SQL Server, nor would it work
with multiple entries in the PostgreSQL search path. The name collision
issue resolved here has been identified as affecting at least Oracle,
PostgreSQL, SQL Server, MySQL and MariaDB.
Fixes: #7471
Change-Id: Id65e7ca8c43fe8d95777084e8d5ec140ebcd784d
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 99 |
6 files changed, 156 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index cba7cf07d..4cff8defb 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -683,7 +683,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._setup_for_generate() SelectState.__init__(self, self.statement, compiler, **kw) - return self def _dump_option_struct(self): diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6ab9a75f6..ae586c9f2 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -499,7 +499,7 @@ class CompileState: """ - __slots__ = ("statement",) + __slots__ = ("statement", "_ambiguous_table_name_map") plugins = {} diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cb10811c6..af39f0672 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1466,6 +1466,7 @@ class SQLCompiler(Compiled): add_to_result_map=None, include_table=True, result_map_targets=(), + ambiguous_table_name_map=None, **kwargs, ): name = orig_name = column.name @@ -1502,6 +1503,14 @@ class SQLCompiler(Compiled): else: schema_prefix = "" tablename = table.name + + if ( + not effective_schema + and ambiguous_table_name_map + and tablename in ambiguous_table_name_map + ): + tablename = ambiguous_table_name_map[tablename] + if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) @@ -3252,6 +3261,10 @@ class SQLCompiler(Compiled): compile_state = select_stmt._compile_state_factory( select_stmt, self, **kwargs ) + kwargs[ + "ambiguous_table_name_map" + ] = compile_state._ambiguous_table_name_map + select_stmt = compile_state.statement toplevel = not self.stack @@ -3732,6 +3745,7 @@ class SQLCompiler(Compiled): fromhints=None, use_schema=True, from_linter=None, + ambiguous_table_name_map=None, **kwargs, ): if from_linter: @@ -3748,6 +3762,20 @@ class SQLCompiler(Compiled): ) else: ret = self.preparer.quote(table.name) + + if ( + not effective_schema + and ambiguous_table_name_map + and table.name in ambiguous_table_name_map + ): + anon_name = self._truncated_identifier( + "alias", ambiguous_table_name_map[table.name] + ) + + ret = ret + self.get_render_as_alias_suffix( + self.preparer.format_alias(None, anon_name) + ) + if fromhints and table in fromhints: ret = self.format_from_hint_text( ret, table, fromhints[table], iscrud diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index a025cce35..1fa312b7e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -286,6 +286,7 @@ class ClauseElement( is_clause_element = True is_selectable = False + _is_table = False _is_textual = False _is_from_clause = False _is_returns_rows = False diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e674c4b74..6a7b83504 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2484,6 +2484,8 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause): named_with_column = True + _is_table = True + implicit_returning = False """:class:`_expression.TableClause` doesn't support having a primary key or column @@ -3980,6 +3982,8 @@ class SelectState(util.MemoizedSlots, CompileState): return go def _get_froms(self, statement): + self._ambiguous_table_name_map = ambiguous_table_name_map = {} + return self._normalize_froms( itertools.chain( itertools.chain.from_iterable( @@ -3997,10 +4001,16 @@ class SelectState(util.MemoizedSlots, CompileState): self.from_clauses, ), check_statement=statement, + ambiguous_table_name_map=ambiguous_table_name_map, ) @classmethod - def _normalize_froms(cls, iterable_of_froms, check_statement=None): + def _normalize_froms( + cls, + iterable_of_froms, + check_statement=None, + ambiguous_table_name_map=None, + ): """given an iterable of things to select FROM, reduce them to what would actually render in the FROM clause of a SELECT. @@ -4013,6 +4023,7 @@ class SelectState(util.MemoizedSlots, CompileState): froms = [] for item in iterable_of_froms: + if item._is_subquery and item.element is check_statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" @@ -4033,6 +4044,21 @@ class SelectState(util.MemoizedSlots, CompileState): # using a list to maintain ordering froms = [f for f in froms if f not in toremove] + if ambiguous_table_name_map is not None: + ambiguous_table_name_map.update( + ( + fr.name, + _anonymous_label.safe_construct( + hash(fr.name), fr.name + ), + ) + for item in froms + for fr in item._from_objects + if fr._is_table + and fr.schema + and fr.name not in ambiguous_table_name_map + ) + return froms def _get_display_froms( diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index c1228f5df..92fd29503 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -624,6 +624,105 @@ class FetchLimitOffsetTest(fixtures.TablesTest): eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) +class SameNamedSchemaTableTest(fixtures.TablesTest): + """tests for #7471""" + + __backend__ = True + + __requires__ = ("schemas",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + schema=config.test_schema, + ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column( + "some_table_id", + Integer, + # ForeignKey("%s.some_table.id" % config.test_schema), + nullable=False, + ), + ) + + @classmethod + def insert_data(cls, connection): + some_table, some_table_schema = cls.tables( + "some_table", "%s.some_table" % config.test_schema + ) + connection.execute(some_table_schema.insert(), {"id": 1}) + connection.execute(some_table.insert(), {"id": 1, "some_table_id": 1}) + + def test_simple_join_both_tables(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table, some_table_schema).join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + ).first(), + (1, 1, 1), + ) + + def test_simple_join_whereclause_only(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + eq_( + connection.execute( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1), + ) + + def test_subquery(self, connection): + some_table, some_table_schema = self.tables( + "some_table", "%s.some_table" % config.test_schema + ) + + subq = ( + select(some_table) + .join_from( + some_table, + some_table_schema, + some_table.c.some_table_id == some_table_schema.c.id, + ) + .where(some_table.c.id == 1) + .subquery() + ) + + eq_( + connection.execute( + select(some_table, subq.c.id) + .join_from( + some_table, + subq, + some_table.c.some_table_id == subq.c.id, + ) + .where(some_table.c.id == 1) + ).first(), + (1, 1, 1), + ) + + class JoinTest(fixtures.TablesTest): __backend__ = True |
