Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r--  lib/sqlalchemy/dialects/mssql/base.py                 |  172
-rw-r--r--  lib/sqlalchemy/dialects/mysql/base.py                 |   47
-rw-r--r--  lib/sqlalchemy/dialects/oracle/base.py                | 2240
-rw-r--r--  lib/sqlalchemy/dialects/oracle/cx_oracle.py           |    4
-rw-r--r--  lib/sqlalchemy/dialects/oracle/dictionary.py          |  495
-rw-r--r--  lib/sqlalchemy/dialects/oracle/provision.py           |   54
-rw-r--r--  lib/sqlalchemy/dialects/oracle/types.py               |  233
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/__init__.py        |   31
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/_psycopg_common.py |   18
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/asyncpg.py         |    5
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py            | 2624
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pg8000.py          |    7
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pg_catalog.py      |  292
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/types.py           |  485
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/base.py                |  161
15 files changed, 4459 insertions, 2409 deletions
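The hunks that follow share one theme: dialect-level reflection methods (has_table, get_indexes, get_pk_constraint, get_columns, and so on) gain a **kw signature, are wrapped in @reflection.cache, and distinguish "object exists but has nothing to report" from "object does not exist". Empty results for an existing table come back as ReflectionDefaults values, while a missing table raises exc.NoSuchTableError (see the MSSQL _default_or_error helper and the MySQL ReflectionDefaults returns). The Oracle dialect additionally rewrites its data-dictionary queries as Core select() constructs against the new oracle/dictionary.py module, accepts dblink and oracle_resolve_synonyms keywords per the added docstrings, and grows batched get_multi_* variants. A minimal caller-facing sketch of that contract follows; the connection URL, schema and table names are illustrative only, not taken from this commit:

    from sqlalchemy import create_engine, inspect
    from sqlalchemy.exc import NoSuchTableError

    # hypothetical URL; any dialect touched by this diff behaves the same way
    engine = create_engine("mssql+pyodbc://scott:tiger@mydsn")
    insp = inspect(engine)

    # an existing table with no indexes reflects as an empty list,
    # not as an error
    print(insp.get_indexes("plain_table", schema="dbo"))

    # an existing table with no primary key reflects as the default
    # structure (per the MySQL hunk, {"constrained_columns": [], "name": None})
    print(insp.get_pk_constraint("plain_table", schema="dbo"))

    # a table that does not exist raises NoSuchTableError instead of
    # silently returning an empty result
    try:
        insp.get_indexes("no_such_table", schema="dbo")
    except NoSuchTableError as err:
        print("missing table:", err)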
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 12f495d6e..2a4362ccb 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -831,6 +831,7 @@ from ... import util from ...engine import cursor as _cursor from ...engine import default from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -3010,55 +3011,16 @@ class MSDialect(default.DefaultDialect): return self.schema_name @_db_plus_owner - def has_table(self, connection, tablename, dbname, owner, schema): + def has_table(self, connection, tablename, dbname, owner, schema, **kw): self._ensure_has_table_connection(connection) - if tablename.startswith("#"): # temporary table - # mssql does not support temporary views - # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed - tables = ischema.mssql_temp_table_columns - s = sql.select(tables.c.table_name).where( - tables.c.table_name.like( - self._temp_table_name_like_pattern(tablename) - ) - ) - - # #7168: fetch all (not just first match) in case some other #temp - # table with the same name happens to appear first - table_names = connection.execute(s).scalars().fetchall() - # #6910: verify it's not a temp table from another session - for table_name in table_names: - if bool( - connection.scalar( - text("SELECT object_id(:table_name)"), - {"table_name": "tempdb.dbo.[{}]".format(table_name)}, - ) - ): - return True - else: - return False - else: - tables = ischema.tables - - s = sql.select(tables.c.table_name).where( - sql.and_( - sql.or_( - tables.c.table_type == "BASE TABLE", - tables.c.table_type == "VIEW", - ), - tables.c.table_name == tablename, - ) - ) - - if owner: - s = s.where(tables.c.table_schema == owner) - - c = connection.execute(s) - - return c.first() is not None + return self._internal_has_table(connection, tablename, owner, **kw) + @reflection.cache @_db_plus_owner - def has_sequence(self, connection, sequencename, dbname, owner, schema): + def has_sequence( + self, connection, sequencename, dbname, owner, schema, **kw + ): sequences = ischema.sequences s = sql.select(sequences.c.sequence_name).where( @@ -3128,6 +3090,60 @@ class MSDialect(default.DefaultDialect): return view_names @reflection.cache + def _internal_has_table(self, connection, tablename, owner, **kw): + if tablename.startswith("#"): # temporary table + # mssql does not support temporary views + # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed + tables = ischema.mssql_temp_table_columns + + s = sql.select(tables.c.table_name).where( + tables.c.table_name.like( + self._temp_table_name_like_pattern(tablename) + ) + ) + + # #7168: fetch all (not just first match) in case some other #temp + # table with the same name happens to appear first + table_names = connection.scalars(s).all() + # #6910: verify it's not a temp table from another session + for table_name in table_names: + if bool( + connection.scalar( + text("SELECT object_id(:table_name)"), + {"table_name": "tempdb.dbo.[{}]".format(table_name)}, + ) + ): + return True + else: + return False + else: + tables = ischema.tables + + s = sql.select(tables.c.table_name).where( + sql.and_( + sql.or_( + tables.c.table_type == "BASE TABLE", + tables.c.table_type == "VIEW", + ), + tables.c.table_name == tablename, + ) + ) + + if owner: + s = s.where(tables.c.table_schema == owner) + + c = connection.execute(s) + + return c.first() is 
not None + + def _default_or_error(self, connection, tablename, owner, method, **kw): + # TODO: try to avoid having to run a separate query here + if self._internal_has_table(connection, tablename, owner, **kw): + return method() + else: + raise exc.NoSuchTableError(f"{owner}.{tablename}") + + @reflection.cache @_db_plus_owner def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): filter_definition = ( @@ -3138,14 +3154,14 @@ class MSDialect(default.DefaultDialect): rp = connection.execution_options(future_result=True).execute( sql.text( "select ind.index_id, ind.is_unique, ind.name, " - "%s " + f"{filter_definition} " "from sys.indexes as ind join sys.tables as tab on " "ind.object_id=tab.object_id " "join sys.schemas as sch on sch.schema_id=tab.schema_id " "where tab.name = :tabname " "and sch.name=:schname " - "and ind.is_primary_key=0 and ind.type != 0" - % filter_definition + "and ind.is_primary_key=0 and ind.type != 0 " + "order by ind.name " ) .bindparams( sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), @@ -3203,31 +3219,34 @@ class MSDialect(default.DefaultDialect): "mssql_include" ] = index_info["include_columns"] - return list(indexes.values()) + if indexes: + return list(indexes.values()) + else: + return self._default_or_error( + connection, tablename, owner, ReflectionDefaults.indexes, **kw + ) @reflection.cache @_db_plus_owner def get_view_definition( self, connection, viewname, dbname, owner, schema, **kw ): - rp = connection.execute( + view_def = connection.execute( sql.text( - "select definition from sys.sql_modules as mod, " - "sys.views as views, " - "sys.schemas as sch" - " where " - "mod.object_id=views.object_id and " - "views.schema_id=sch.schema_id and " - "views.name=:viewname and sch.name=:schname" + "select mod.definition " + "from sys.sql_modules as mod " + "join sys.views as views on mod.object_id = views.object_id " + "join sys.schemas as sch on views.schema_id = sch.schema_id " + "where views.name=:viewname and sch.name=:schname" ).bindparams( sql.bindparam("viewname", viewname, ischema.CoerceUnicode()), sql.bindparam("schname", owner, ischema.CoerceUnicode()), ) - ) - - if rp: - view_def = rp.scalar() + ).scalar() + if view_def: return view_def + else: + raise exc.NoSuchTableError(f"{owner}.{viewname}") def _temp_table_name_like_pattern(self, tablename): # LIKE uses '%' to match zero or more characters and '_' to match any @@ -3417,7 +3436,12 @@ class MSDialect(default.DefaultDialect): cols.append(cdict) - return cols + if cols: + return cols + else: + return self._default_or_error( + connection, tablename, owner, ReflectionDefaults.columns, **kw + ) @reflection.cache @_db_plus_owner @@ -3450,7 +3474,16 @@ class MSDialect(default.DefaultDialect): pkeys.append(row["COLUMN_NAME"]) if constraint_name is None: constraint_name = row[C.c.constraint_name.name] - return {"constrained_columns": pkeys, "name": constraint_name} + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.pk_constraint, + **kw, + ) @reflection.cache @_db_plus_owner @@ -3591,7 +3624,7 @@ index_info AS ( fkeys = util.defaultdict(fkey_rec) - for r in connection.execute(s).fetchall(): + for r in connection.execute(s).all(): ( _, # constraint schema rfknm, @@ -3632,4 +3665,13 @@ index_info AS ( local_cols.append(scol) remote_cols.append(rcol) - return list(fkeys.values()) + if fkeys: + return list(fkeys.values()) + else: + return 
self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.foreign_keys, + **kw, + ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7c9a68236..502371be9 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1056,6 +1056,7 @@ from ... import sql from ... import util from ...engine import default from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -2648,7 +2649,8 @@ class MySQLDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return connection.exec_driver_sql("SELECT DATABASE()").scalar() - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) if schema is None: @@ -2670,7 +2672,8 @@ class MySQLDialect(default.DefaultDialect): ) return bool(rs.scalar()) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): if not self.supports_sequences: self._sequences_not_supported() if not schema: @@ -2847,14 +2850,20 @@ class MySQLDialect(default.DefaultDialect): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return parsed_state.table_options + if parsed_state.table_options: + return parsed_state.table_options + else: + return ReflectionDefaults.table_options() @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return parsed_state.columns + if parsed_state.columns: + return parsed_state.columns + else: + return ReflectionDefaults.columns() @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): @@ -2866,7 +2875,7 @@ class MySQLDialect(default.DefaultDialect): # There can be only one. 
cols = [s[0] for s in key["columns"]] return {"constrained_columns": cols, "name": None} - return {"constrained_columns": [], "name": None} + return ReflectionDefaults.pk_constraint() @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -2909,7 +2918,7 @@ class MySQLDialect(default.DefaultDialect): if self._needs_correct_for_88718_96365: self._correct_for_mysql_bugs_88718_96365(fkeys, connection) - return fkeys + return fkeys if fkeys else ReflectionDefaults.foreign_keys() def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): # Foreign key is always in lower case (MySQL 8.0) @@ -3000,21 +3009,22 @@ class MySQLDialect(default.DefaultDialect): connection, table_name, schema, **kw ) - return [ + cks = [ {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] + return cks if cks else ReflectionDefaults.check_constraints() @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return { - "text": parsed_state.table_options.get( - "%s_comment" % self.name, None - ) - } + comment = parsed_state.table_options.get(f"{self.name}_comment", None) + if comment is not None: + return {"text": comment} + else: + return ReflectionDefaults.table_comment() @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): @@ -3058,7 +3068,8 @@ class MySQLDialect(default.DefaultDialect): if flavor: index_d["type"] = flavor indexes.append(index_d) - return indexes + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + return indexes if indexes else ReflectionDefaults.indexes() @reflection.cache def get_unique_constraints( @@ -3068,7 +3079,7 @@ class MySQLDialect(default.DefaultDialect): connection, table_name, schema, **kw ) - return [ + ucs = [ { "name": key["name"], "column_names": [col[0] for col in key["columns"]], @@ -3077,6 +3088,11 @@ class MySQLDialect(default.DefaultDialect): for key in parsed_state.keys if key["type"] == "UNIQUE" ] + ucs.sort(key=lambda d: d["name"] or "~") # sort None as last + if ucs: + return ucs + else: + return ReflectionDefaults.unique_constraints() @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): @@ -3088,6 +3104,9 @@ class MySQLDialect(default.DefaultDialect): sql = self._show_create_table( connection, None, charset, full_name=full_name ) + if sql.upper().startswith("CREATE TABLE"): + # it's a table, not a view + raise exc.NoSuchTableError(full_name) return sql def _parsed_state_or_create( diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index faac0deb7..fee098889 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -518,21 +518,52 @@ columns for non-unique indexes, all but the last column for unique indexes). """ # noqa -from itertools import groupby +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache +from functools import wraps import re +from . 
import dictionary +from .types import _OracleBoolean +from .types import _OracleDate +from .types import BFILE +from .types import BINARY_DOUBLE +from .types import BINARY_FLOAT +from .types import DATE +from .types import FLOAT +from .types import INTERVAL +from .types import LONG +from .types import NCLOB +from .types import NUMBER +from .types import NVARCHAR2 # noqa +from .types import OracleRaw # noqa +from .types import RAW +from .types import ROWID # noqa +from .types import VARCHAR2 # noqa from ... import Computed from ... import exc from ... import schema as sa_schema from ... import sql from ... import util from ...engine import default +from ...engine import ObjectKind +from ...engine import ObjectScope from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import and_ +from ...sql import bindparam from ...sql import compiler from ...sql import expression +from ...sql import func +from ...sql import null +from ...sql import or_ +from ...sql import select from ...sql import sqltypes from ...sql import util as sql_util from ...sql import visitors +from ...sql.visitors import InternalTraversal from ...types import BLOB from ...types import CHAR from ...types import CLOB @@ -561,229 +592,6 @@ NO_ARG_FNS = set( ) -class RAW(sqltypes._Binary): - __visit_name__ = "RAW" - - -OracleRaw = RAW - - -class NCLOB(sqltypes.Text): - __visit_name__ = "NCLOB" - - -class VARCHAR2(VARCHAR): - __visit_name__ = "VARCHAR2" - - -NVARCHAR2 = NVARCHAR - - -class NUMBER(sqltypes.Numeric, sqltypes.Integer): - __visit_name__ = "NUMBER" - - def __init__(self, precision=None, scale=None, asdecimal=None): - if asdecimal is None: - asdecimal = bool(scale and scale > 0) - - super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal - ) - - def adapt(self, impltype): - ret = super(NUMBER, self).adapt(impltype) - # leave a hint for the DBAPI handler - ret._is_oracle_number = True - return ret - - @property - def _type_affinity(self): - if bool(self.scale and self.scale > 0): - return sqltypes.Numeric - else: - return sqltypes.Integer - - -class FLOAT(sqltypes.FLOAT): - """Oracle FLOAT. - - This is the same as :class:`_sqltypes.FLOAT` except that - an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` - parameter is accepted, and - the :paramref:`_sqltypes.Float.precision` parameter is not accepted. - - Oracle FLOAT types indicate precision in terms of "binary precision", which - defaults to 126. For a REAL type, the value is 63. This parameter does not - cleanly map to a specific number of decimal places but is roughly - equivalent to the desired number of decimal places divided by 0.3103. - - .. versionadded:: 2.0 - - """ - - __visit_name__ = "FLOAT" - - def __init__( - self, - binary_precision=None, - asdecimal=False, - decimal_return_scale=None, - ): - r""" - Construct a FLOAT - - :param binary_precision: Oracle binary precision value to be rendered - in DDL. This may be approximated to the number of decimal characters - using the formula "decimal precision = 0.30103 * binary precision". - The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. 
- - :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` - - :param decimal_return_scale: See - :paramref:`_sqltypes.Float.decimal_return_scale` - - """ - super().__init__( - asdecimal=asdecimal, decimal_return_scale=decimal_return_scale - ) - self.binary_precision = binary_precision - - -class BINARY_DOUBLE(sqltypes.Float): - __visit_name__ = "BINARY_DOUBLE" - - -class BINARY_FLOAT(sqltypes.Float): - __visit_name__ = "BINARY_FLOAT" - - -class BFILE(sqltypes.LargeBinary): - __visit_name__ = "BFILE" - - -class LONG(sqltypes.Text): - __visit_name__ = "LONG" - - -class _OracleDateLiteralRender: - def _literal_processor_datetime(self, dialect): - def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS.FF')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS')""" - ) - return value - - return process - - def _literal_processor_date(self, dialect): - def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) - return value - - return process - - -class DATE(_OracleDateLiteralRender, sqltypes.DateTime): - """Provide the oracle DATE type. - - This type has no special Python behavior, except that it subclasses - :class:`_types.DateTime`; this is to suit the fact that the Oracle - ``DATE`` type supports a time value. - - .. versionadded:: 0.9.4 - - """ - - __visit_name__ = "DATE" - - def literal_processor(self, dialect): - return self._literal_processor_datetime(dialect) - - def _compare_type_affinity(self, other): - return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) - - -class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): - def literal_processor(self, dialect): - return self._literal_processor_date(dialect) - - -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): - __visit_name__ = "INTERVAL" - - def __init__(self, day_precision=None, second_precision=None): - """Construct an INTERVAL. - - Note that only DAY TO SECOND intervals are currently supported. - This is due to a lack of support for YEAR TO MONTH intervals - within available DBAPIs. - - :param day_precision: the day precision value. this is the number of - digits to store for the day field. Defaults to "2" - :param second_precision: the second precision value. this is the - number of digits to store for the fractional seconds field. - Defaults to "6". - - """ - self.day_precision = day_precision - self.second_precision = second_precision - - @classmethod - def _adapt_from_generic_interval(cls, interval): - return INTERVAL( - day_precision=interval.day_precision, - second_precision=interval.second_precision, - ) - - @property - def _type_affinity(self): - return sqltypes.Interval - - def as_generic(self, allow_nulltype=False): - return sqltypes.Interval( - native=True, - second_precision=self.second_precision, - day_precision=self.day_precision, - ) - - -class ROWID(sqltypes.TypeEngine): - """Oracle ROWID type. - - When used in a cast() or similar, generates ROWID. 
- - """ - - __visit_name__ = "ROWID" - - -class _OracleBoolean(sqltypes.Boolean): - def get_dbapi_type(self, dbapi): - return dbapi.NUMBER - - colspecs = { sqltypes.Boolean: _OracleBoolean, sqltypes.Interval: INTERVAL, @@ -1541,6 +1349,13 @@ class OracleExecutionContext(default.DefaultExecutionContext): type_, ) + def pre_exec(self): + if self.statement and "_oracle_dblink" in self.execution_options: + self.statement = self.statement.replace( + dictionary.DB_LINK_PLACEHOLDER, + self.execution_options["_oracle_dblink"], + ) + class OracleDialect(default.DefaultDialect): name = "oracle" @@ -1675,6 +1490,10 @@ class OracleDialect(default.DefaultDialect): # it may work also on versions before the 18 return self.server_version_info and self.server_version_info >= (18,) + @property + def _supports_except_all(self): + return self.server_version_info and self.server_version_info >= (21,) + def do_release_savepoint(self, connection, name): # Oracle does not support RELEASE SAVEPOINT pass @@ -1700,45 +1519,99 @@ class OracleDialect(default.DefaultDialect): except: return "READ COMMITTED" - def has_table(self, connection, table_name, schema=None): + def _execute_reflection( + self, connection, query, dblink, returns_long, params=None + ): + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + execution_options = { + # handle db links + "_oracle_dblink": dblink or "", + # override any schema translate map + "schema_translate_map": None, + } + + if dblink and returns_long: + # Oracle seems to error with + # "ORA-00997: illegal use of LONG datatype" when returning + # LONG columns via a dblink in a query with bind params + # This type seems to be very hard to cast into something else + # so it seems easier to just use bind param in this case + def visit_bindparam(bindparam): + bindparam.literal_execute = True + + query = visitors.cloned_traverse( + query, {}, {"bindparam": visit_bindparam} + ) + return connection.execute( + query, params, execution_options=execution_options + ) + + @util.memoized_property + def _has_table_query(self): + # materialized views are returned by all_tables + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + ) + .union_all( + select( + dictionary.all_views.c.view_name.label("table_name"), + dictionary.all_views.c.owner, + ) + ) + .subquery("tables_and_views") + ) + + query = select(tables.c.table_name).where( + tables.c.table_name == bindparam("table_name"), + tables.c.owner == bindparam("owner"), + ) + return query + + @reflection.cache + def has_table( + self, connection, table_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" self._ensure_has_table_connection(connection) if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - """SELECT table_name FROM all_tables - WHERE table_name = CAST(:name AS VARCHAR2(128)) - AND owner = CAST(:schema_name AS VARCHAR2(128)) - UNION ALL - SELECT view_name FROM all_views - WHERE view_name = CAST(:name AS VARCHAR2(128)) - AND owner = CAST(:schema_name AS VARCHAR2(128)) - """ - ), - dict( - name=self.denormalize_name(table_name), - schema_name=self.denormalize_name(schema), - ), + params = { + "table_name": self.denormalize_name(table_name), + "owner": self.denormalize_name(schema), + } + cursor = self._execute_reflection( + connection, + self._has_table_query, + dblink, + returns_long=False, + params=params, ) - return cursor.first() is not None + return bool(cursor.scalar()) - def 
has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence( + self, connection, sequence_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT sequence_name FROM all_sequences " - "WHERE sequence_name = :name AND " - "sequence_owner = :schema_name" - ), - dict( - name=self.denormalize_name(sequence_name), - schema_name=self.denormalize_name(schema), - ), + + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_name + == self.denormalize_name(sequence_name), + dictionary.all_sequences.c.sequence_owner + == self.denormalize_name(schema), ) - return cursor.first() is not None + + cursor = self._execute_reflection( + connection, query, dblink, returns_long=False + ) + return bool(cursor.scalar()) def _get_default_schema_name(self, connection): return self.normalize_name( @@ -1747,329 +1620,633 @@ class OracleDialect(default.DefaultDialect): ).scalar() ) - def _resolve_synonym( - self, - connection, - desired_owner=None, - desired_synonym=None, - desired_table=None, + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_synonyms(self, connection, schema, filter_names, dblink, **kw): + owner = self.denormalize_name(schema or self.default_schema_name) + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = select( + dictionary.all_synonyms.c.synonym_name, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.db_link, + ).where(dictionary.all_synonyms.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_synonyms.c.synonym_name.in_( + params["filter_names"] + ) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() + return result.all() + + @lru_cache() + def _all_objects_query( + self, owner, scope, kind, has_filter_names, has_mat_views ): - """search for a local synonym matching the given desired owner/name. - - if desired_owner is None, attempts to locate a distinct owner. - - returns the actual name, owner, dblink name, and synonym name if - found. - """ - - q = ( - "SELECT owner, table_owner, table_name, db_link, " - "synonym_name FROM all_synonyms WHERE " + query = ( + select(dictionary.all_objects.c.object_name) + .select_from(dictionary.all_objects) + .where(dictionary.all_objects.c.owner == owner) ) - clauses = [] - params = {} - if desired_synonym: - clauses.append( - "synonym_name = CAST(:synonym_name AS VARCHAR2(128))" + + # NOTE: materialized views are listed in all_objects twice; + # once as MATERIALIZE VIEW and once as TABLE + if kind is ObjectKind.ANY: + # materilaized view are listed also as tables so there is no + # need to add them to the in_. 
+ query = query.where( + dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW")) ) - params["synonym_name"] = desired_synonym - if desired_owner: - clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))") - params["desired_owner"] = desired_owner - if desired_table: - clauses.append("table_name = CAST(:tname AS VARCHAR2(128))") - params["tname"] = desired_table - - q += " AND ".join(clauses) - - result = connection.execution_options(future_result=True).execute( - sql.text(q), params - ) - if desired_owner: - row = result.mappings().first() - if row: - return ( - row["table_name"], - row["table_owner"], - row["db_link"], - row["synonym_name"], - ) - else: - return None, None, None, None else: - rows = result.mappings().all() - if len(rows) > 1: - raise AssertionError( - "There are multiple tables visible to the schema, you " - "must specify owner" - ) - elif len(rows) == 1: - row = rows[0] - return ( - row["table_name"], - row["table_owner"], - row["db_link"], - row["synonym_name"], - ) - else: - return None, None, None, None + object_type = [] + if ObjectKind.VIEW in kind: + object_type.append("VIEW") + if ( + ObjectKind.MATERIALIZED_VIEW in kind + and ObjectKind.TABLE not in kind + ): + # materilaized view are listed also as tables so there is no + # need to add them to the in_ if also selecting tables. + object_type.append("MATERIALIZED VIEW") + if ObjectKind.TABLE in kind: + object_type.append("TABLE") + if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind: + # materialized view are listed also as tables, + # so they need to be filtered out + # EXCEPT ALL / MINUS profiles as faster than using + # NOT EXISTS or NOT IN with a subquery, but it's in + # general faster to get the mat view names and exclude + # them only when needed + query = query.where( + dictionary.all_objects.c.object_name.not_in( + bindparam("mat_views") + ) + ) + query = query.where( + dictionary.all_objects.c.object_type.in_(object_type) + ) - @reflection.cache - def _prepare_reflection_args( - self, - connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", - **kw, - ): + # handles scope + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_objects.c.temporary == "N") + elif scope is ObjectScope.TEMPORARY: + query = query.where(dictionary.all_objects.c.temporary == "Y") - if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(schema), - desired_synonym=self.denormalize_name(table_name), + if has_filter_names: + query = query.where( + dictionary.all_objects.c.object_name.in_( + bindparam("filter_names") + ) ) - else: - actual_name, owner, dblink, synonym = None, None, None, None - if not actual_name: - actual_name = self.denormalize_name(table_name) - - if dblink: - # using user_db_links here since all_db_links appears - # to have more restricted permissions. - # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm - # will need to hear from more users if we are doing - # the right thing here. 
See [ticket:2619] - owner = connection.scalar( - sql.text( - "SELECT username FROM user_db_links " "WHERE db_link=:link" - ), - dict(link=dblink), + return query + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("scope", InternalTraversal.dp_plain_obj), + ("kind", InternalTraversal.dp_plain_obj), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_all_objects( + self, connection, schema, scope, kind, filter_names, dblink, **kw + ): + owner = self.denormalize_name(schema or self.default_schema_name) + + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _all_objects_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw ) - dblink = "@" + dblink - elif not owner: - owner = self.denormalize_name(schema or self.default_schema_name) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + + query = self._all_objects_query( + owner, scope, kind, has_filter_names, has_mat_views + ) - return (actual_name, owner, dblink or "", synonym) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ).scalars() - @reflection.cache - def get_schema_names(self, connection, **kw): - s = "SELECT username FROM all_users ORDER BY username" - cursor = connection.exec_driver_sql(s) - return [self.normalize_name(row[0]) for row in cursor] + return result.all() + + def _handle_synonyms_decorator(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + return self._handle_synonyms(fn, *args, **kwargs) + + return wrapper + + def _handle_synonyms(self, fn, connection, *args, **kwargs): + if not kwargs.get("oracle_resolve_synonyms", False): + return fn(self, connection, *args, **kwargs) + + original_kw = kwargs.copy() + schema = kwargs.pop("schema", None) + result = self._get_synonyms( + connection, + schema=schema, + filter_names=kwargs.pop("filter_names", None), + dblink=kwargs.pop("dblink", None), + info_cache=kwargs.get("info_cache", None), + ) + + dblinks_owners = defaultdict(dict) + for row in result: + key = row["db_link"], row["table_owner"] + tn = self.normalize_name(row["table_name"]) + dblinks_owners[key][tn] = row["synonym_name"] + + if not dblinks_owners: + # No synonym, do the plain thing + return fn(self, connection, *args, **original_kw) + + data = {} + for (dblink, table_owner), mapping in dblinks_owners.items(): + call_kw = { + **original_kw, + "schema": table_owner, + "dblink": self.normalize_name(dblink), + "filter_names": mapping.keys(), + } + call_result = fn(self, connection, *args, **call_kw) + for (_, tn), value in call_result: + synonym_name = self.normalize_name(mapping[tn]) + data[(schema, synonym_name)] = value + return data.items() @reflection.cache - def get_table_names(self, connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.default_schema_name) + def get_schema_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + query = select(dictionary.all_users.c.username).order_by( + dictionary.all_users.c.username + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + @reflection.cache + def get_table_names(self, connection, schema=None, dblink=None, **kw): + 
"""Supported kw arguments are: ``dblink`` to reflect via a db link.""" # note that table_names() isn't loading DBLINKed or synonym'ed tables if schema is None: schema = self.default_schema_name - sql_str = "SELECT table_name FROM all_tables WHERE " + den_schema = self.denormalize_name(schema) + if kw.get("oracle_resolve_synonyms", False): + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .union_all( + select( + dictionary.all_synonyms.c.synonym_name.label( + "table_name" + ), + dictionary.all_synonyms.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .select_from(dictionary.all_tables) + .join( + dictionary.all_synonyms, + and_( + dictionary.all_tables.c.table_name + == dictionary.all_synonyms.c.table_name, + dictionary.all_tables.c.owner + == func.coalesce( + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.owner, + ), + ), + ) + ) + .subquery("available_tables") + ) + else: + tables = dictionary.all_tables + + query = select(tables.c.table_name) if self.exclude_tablespaces: - sql_str += ( - "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " - % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + query = query.where( + func.coalesce( + tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) ) - sql_str += ( - "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL" + query = query.where( + tables.c.owner == den_schema, + tables.c.iot_name.is_(null()), + tables.c.duration.is_(null()), ) - cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) - return [self.normalize_name(row[0]) for row in cursor] + # remove materialized views + mat_query = select( + dictionary.all_mviews.c.mview_name.label("table_name") + ).where(dictionary.all_mviews.c.owner == den_schema) + + query = ( + query.except_all(mat_query) + if self._supports_except_all + else query.except_(mat_query) + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_temp_table_names(self, connection, **kw): + def get_temp_table_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" schema = self.denormalize_name(self.default_schema_name) - sql_str = "SELECT table_name FROM all_tables WHERE " + query = select(dictionary.all_tables.c.table_name) if self.exclude_tablespaces: - sql_str += ( - "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " - % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + query = query.where( + func.coalesce( + dictionary.all_tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) ) - sql_str += ( - "OWNER = :owner " - "AND IOT_NAME IS NULL " - "AND DURATION IS NOT NULL" + query = query.where( + dictionary.all_tables.c.owner == schema, + dictionary.all_tables.c.iot_name.is_(null()), + dictionary.all_tables.c.duration.is_not(null()), ) - cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_view_names(self, 
connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.default_schema_name) - s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") - cursor = connection.execute( - s, dict(owner=self.denormalize_name(schema)) + def get_materialized_view_names( + self, connection, schema=None, dblink=None, _normalize=True, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_mviews.c.mview_name).where( + dictionary.all_mviews.c.owner == self.denormalize_name(schema) ) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + if _normalize: + return [self.normalize_name(row) for row in result] + else: + return result.all() @reflection.cache - def get_sequence_names(self, connection, schema=None, **kw): + def get_view_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT sequence_name FROM all_sequences " - "WHERE sequence_owner = :schema_name" - ), - dict(schema_name=self.denormalize_name(schema)), + + query = select(dictionary.all_views.c.view_name).where( + dictionary.all_views.c.owner == self.denormalize_name(schema) ) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_table_options(self, connection, table_name, schema=None, **kw): - options = {} + def get_sequence_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_owner + == self.denormalize_name(schema) + ) - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + def _value_or_raise(self, data, table, schema): + table = self.normalize_name(str(table)) + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + fn = [self.denormalize_name(name) for name in filter_names] + return True, {"filter_names": fn} + else: + return False, {} + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_options( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - params = {"table_name": table_name} + @lru_cache() + def _table_options_query( + self, owner, scope, kind, 
has_filter_names, has_mat_views + ): + query = select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.compression, + dictionary.all_tables.c.compress_for, + ).where(dictionary.all_tables.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_tables.c.table_name.in_( + bindparam("filter_names") + ) + ) + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_tables.c.duration.is_(null())) + elif scope is ObjectScope.TEMPORARY: + query = query.where( + dictionary.all_tables.c.duration.is_not(null()) + ) - columns = ["table_name"] - if self._supports_table_compression: - columns.append("compression") - if self._supports_table_compress_for: - columns.append("compress_for") + if ( + has_mat_views + and ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # cant use EXCEPT ALL / MINUS here because we don't have an + # excludable row vs. the query above + # outerjoin + where null works better on oracle 21 but 11 does + # not like it at all. this is the next best thing + + query = query.where( + dictionary.all_tables.c.table_name.not_in( + bindparam("mat_views") + ) + ) + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + query = query.where( + dictionary.all_tables.c.table_name.in_(bindparam("mat_views")) + ) + return query - text = ( - "SELECT %(columns)s " - "FROM ALL_TABLES%(dblink)s " - "WHERE table_name = CAST(:table_name AS VARCHAR(128))" - ) + @_handle_synonyms_decorator + def get_multi_table_options( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_name(schema or self.default_schema_name) - if schema is not None: - params["owner"] = schema - text += " AND owner = CAST(:owner AS VARCHAR(128)) " - text = text % {"dblink": dblink, "columns": ", ".join(columns)} + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False - result = connection.execute(sql.text(text), params) + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _table_options_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + params["mat_views"] = mat_views - enabled = dict(DISABLED=False, ENABLED=True) + options = {} + default = ReflectionDefaults.table_options - row = result.first() - if row: - if "compression" in row._fields and enabled.get( - row.compression, False - ): - if "compress_for" in row._fields: - options["oracle_compress"] = row.compress_for + if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind: + query = self._table_options_query( + owner, scope, kind, has_filter_names, has_mat_views + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + + for table, compression, compress_for in result: + if compression == "ENABLED": + data = {"oracle_compress": compress_for} else: - options["oracle_compress"] = True + data = default() + options[(schema, self.normalize_name(table))] = data + if ObjectKind.VIEW in kind and 
ObjectScope.DEFAULT in scope: + # add the views (no temporary views) + for view in self.get_view_names(connection, schema, dblink, **kw): + if not filter_names or view in filter_names: + options[(schema, view)] = default() - return options + return options.items() @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ - kw arguments can be: + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def _run_batches( + self, connection, query, dblink, returns_long, mappings, all_objects + ): + each_batch = 500 + batches = list(all_objects) + while batches: + batch = batches[0:each_batch] + batches[0:each_batch] = [] + + result = self._execute_reflection( + connection, + query, + dblink, + returns_long=returns_long, + params={"all_objects": batch}, + ) + if mappings: + yield from result.mappings() + else: + yield from result + + @lru_cache() + def _column_query(self, owner): + all_cols = dictionary.all_tab_cols + all_comments = dictionary.all_col_comments + all_ids = dictionary.all_tab_identity_cols - oracle_resolve_synonyms + if self.server_version_info >= (12,): + add_cols = ( + all_cols.c.default_on_null, + sql.case( + (all_ids.c.table_name.is_(None), sql.null()), + else_=all_ids.c.generation_type + + "," + + all_ids.c.identity_options, + ).label("identity_options"), + ) + join_identity_cols = True + else: + add_cols = ( + sql.null().label("default_on_null"), + sql.null().label("identity_options"), + ) + join_identity_cols = False + + # NOTE: on oracle cannot create tables/views without columns and + # a table cannot have all column hidden: + # ORA-54039: table must have at least one column that is not invisible + # all_tab_cols returns data for tables/views/mat-views. 
+ # all_tab_cols does not return recycled tables + + query = ( + select( + all_cols.c.table_name, + all_cols.c.column_name, + all_cols.c.data_type, + all_cols.c.char_length, + all_cols.c.data_precision, + all_cols.c.data_scale, + all_cols.c.nullable, + all_cols.c.data_default, + all_comments.c.comments, + all_cols.c.virtual_column, + *add_cols, + ).select_from(all_cols) + # NOTE: all_col_comments has a row for each column even if no + # comment is present, so a join could be performed, but there + # seems to be no difference compared to an outer join + .outerjoin( + all_comments, + and_( + all_cols.c.table_name == all_comments.c.table_name, + all_cols.c.column_name == all_comments.c.column_name, + all_cols.c.owner == all_comments.c.owner, + ), + ) + ) + if join_identity_cols: + query = query.outerjoin( + all_ids, + and_( + all_cols.c.table_name == all_ids.c.table_name, + all_cols.c.column_name == all_ids.c.column_name, + all_cols.c.owner == all_ids.c.owner, + ), + ) - dblink + query = query.where( + all_cols.c.table_name.in_(bindparam("all_objects")), + all_cols.c.hidden_column == "NO", + all_cols.c.owner == owner, + ).order_by(all_cols.c.table_name, all_cols.c.column_id) + return query + @_handle_synonyms_decorator + def get_multi_columns( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + owner = self.denormalize_name(schema or self.default_schema_name) + query = self._column_query(owner) - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") + if ( + filter_names + and kind is ObjectKind.ANY + and scope is ObjectScope.ANY + ): + all_objects = [self.denormalize_name(n) for n in filter_names] + else: + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + columns = defaultdict(list) + + # all_tab_cols.data_default is LONG + result = self._run_batches( connection, - table_name, - schema, - resolve_synonyms, + query, dblink, - info_cache=info_cache, + returns_long=True, + mappings=True, + all_objects=all_objects, ) - columns = [] - if self._supports_char_length: - char_length_col = "char_length" - else: - char_length_col = "data_length" - if self.server_version_info >= (12,): - identity_cols = """\ - col.default_on_null, - ( - SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS - FROM ALL_TAB_IDENTITY_COLS%(dblink)s id - WHERE col.table_name = id.table_name - AND col.column_name = id.column_name - AND col.owner = id.owner - ) AS identity_options""" % { - "dblink": dblink - } - else: - identity_cols = "NULL as default_on_null, NULL as identity_options" - - params = {"table_name": table_name} - - text = """ - SELECT - col.column_name, - col.data_type, - col.%(char_length_col)s, - col.data_precision, - col.data_scale, - col.nullable, - col.data_default, - com.comments, - col.virtual_column, - %(identity_cols)s - FROM all_tab_cols%(dblink)s col - LEFT JOIN all_col_comments%(dblink)s com - ON col.table_name = com.table_name - AND col.column_name = com.column_name - AND col.owner = com.owner - WHERE col.table_name = CAST(:table_name AS VARCHAR2(128)) - AND col.hidden_column = 'NO' - """ - if schema is not None: - params["owner"] = schema - text += " AND col.owner = :owner " - text += " ORDER BY col.column_id" - text = 
text % { - "dblink": dblink, - "char_length_col": char_length_col, - "identity_cols": identity_cols, - } - - c = connection.execute(sql.text(text), params) - - for row in c: - colname = self.normalize_name(row[0]) - orig_colname = row[0] - coltype = row[1] - length = row[2] - precision = row[3] - scale = row[4] - nullable = row[5] == "Y" - default = row[6] - comment = row[7] - generated = row[8] - default_on_nul = row[9] - identity_options = row[10] + for row_dict in result: + table_name = self.normalize_name(row_dict["table_name"]) + orig_colname = row_dict["column_name"] + colname = self.normalize_name(orig_colname) + coltype = row_dict["data_type"] + precision = row_dict["data_precision"] if coltype == "NUMBER": + scale = row_dict["data_scale"] if precision is None and scale == 0: coltype = INTEGER() else: @@ -2089,7 +2266,9 @@ class OracleDialect(default.DefaultDialect): coltype = FLOAT(binary_precision=precision) elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"): - coltype = self.ischema_names.get(coltype)(length) + coltype = self.ischema_names.get(coltype)( + row_dict["char_length"] + ) elif "WITH TIME ZONE" in coltype: coltype = TIMESTAMP(timezone=True) else: @@ -2103,15 +2282,17 @@ class OracleDialect(default.DefaultDialect): ) coltype = sqltypes.NULLTYPE - if generated == "YES": + default = row_dict["data_default"] + if row_dict["virtual_column"] == "YES": computed = dict(sqltext=default) default = None else: computed = None + identity_options = row_dict["identity_options"] if identity_options is not None: identity = self._parse_identity_options( - identity_options, default_on_nul + identity_options, row_dict["default_on_null"] ) default = None else: @@ -2120,10 +2301,9 @@ class OracleDialect(default.DefaultDialect): cdict = { "name": colname, "type": coltype, - "nullable": nullable, + "nullable": row_dict["nullable"] == "Y", "default": default, - "autoincrement": "auto", - "comment": comment, + "comment": row_dict["comments"], } if orig_colname.lower() == orig_colname: cdict["quote"] = True @@ -2132,10 +2312,17 @@ class OracleDialect(default.DefaultDialect): if identity is not None: cdict["identity"] = identity - columns.append(cdict) - return columns + columns[(schema, table_name)].append(cdict) - def _parse_identity_options(self, identity_options, default_on_nul): + # NOTE: default not needed since all tables have columns + # default = ReflectionDefaults.columns + # return ( + # (key, value if value else default()) + # for key, value in columns.items() + # ) + return columns.items() + + def _parse_identity_options(self, identity_options, default_on_null): # identity_options is a string that starts with 'ALWAYS,' or # 'BY DEFAULT,' and continues with # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1, @@ -2144,7 +2331,7 @@ class OracleDialect(default.DefaultDialect): parts = [p.strip() for p in identity_options.split(",")] identity = { "always": parts[0] == "ALWAYS", - "on_null": default_on_nul == "YES", + "on_null": default_on_null == "YES", } for part in parts[1:]: @@ -2168,384 +2355,641 @@ class OracleDialect(default.DefaultDialect): return identity @reflection.cache - def get_table_comment( + def get_table_comment(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_comment( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + 
return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _comment_query(self, owner, scope, kind, has_filter_names): + # NOTE: all_tab_comments / all_mview_comments have a row for all + # object even if they don't have comments + queries = [] + if ObjectKind.TABLE in kind or ObjectKind.VIEW in kind: + # all_tab_comments returns also plain views + tbl_view = select( + dictionary.all_tab_comments.c.table_name, + dictionary.all_tab_comments.c.comments, + ).where( + dictionary.all_tab_comments.c.owner == owner, + dictionary.all_tab_comments.c.table_name.not_like("BIN$%"), + ) + if ObjectKind.VIEW not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "TABLE" + ) + elif ObjectKind.TABLE not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "VIEW" + ) + queries.append(tbl_view) + if ObjectKind.MATERIALIZED_VIEW in kind: + mat_view = select( + dictionary.all_mview_comments.c.mview_name.label("table_name"), + dictionary.all_mview_comments.c.comments, + ).where( + dictionary.all_mview_comments.c.owner == owner, + dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"), + ) + queries.append(mat_view) + if len(queries) == 1: + query = queries[0] + else: + union = sql.union_all(*queries).subquery("tables_and_views") + query = select(union.c.table_name, union.c.comments) + + name_col = query.selected_columns.table_name + + if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY): + temp = "Y" if scope is ObjectScope.TEMPORARY else "N" + # need distinct since materialized view are listed also + # as tables in all_objects + query = query.distinct().join( + dictionary.all_objects, + and_( + dictionary.all_objects.c.owner == owner, + dictionary.all_objects.c.object_name == name_col, + dictionary.all_objects.c.temporary == temp, + ), + ) + if has_filter_names: + query = query.where(name_col.in_(bindparam("filter_names"))) + return query + + @_handle_synonyms_decorator + def get_multi_table_comment( self, connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", + *, + schema, + filter_names, + scope, + kind, + dblink=None, **kw, ): - - info_cache = kw.get("info_cache") - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( - connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, - ) - - if not schema: - schema = self.default_schema_name - - COMMENT_SQL = """ - SELECT comments - FROM all_tab_comments - WHERE table_name = CAST(:table_name AS VARCHAR(128)) - AND owner = CAST(:schema_name AS VARCHAR(128)) + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + owner = self.denormalize_name(schema or self.default_schema_name) + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(owner, scope, kind, has_filter_names) - c = connection.execute( - sql.text(COMMENT_SQL), - dict(table_name=table_name, schema_name=schema), + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + default = ReflectionDefaults.table_comment + # materialized views by default seem to have a comment like + # "snapshot table for snapshot owner.mat_view_name" + ignore_mat_view = "snapshot table for snapshot " + return ( + ( + (schema, self.normalize_name(table)), + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default(), + ) + for table, comment in result ) - 
return {"text": c.scalar()} @reflection.cache - def get_indexes( - self, - connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", - **kw, - ): - - info_cache = kw.get("info_cache") - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + def get_indexes(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_indexes( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) - indexes = [] - - params = {"table_name": table_name} - text = ( - "SELECT a.index_name, a.column_name, " - "\nb.index_type, b.uniqueness, b.compression, b.prefix_length " - "\nFROM ALL_IND_COLUMNS%(dblink)s a, " - "\nALL_INDEXES%(dblink)s b " - "\nWHERE " - "\na.index_name = b.index_name " - "\nAND a.table_owner = b.table_owner " - "\nAND a.table_name = b.table_name " - "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))" + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _index_query(self, owner): + return ( + select( + dictionary.all_ind_columns.c.table_name, + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_name, + dictionary.all_indexes.c.index_type, + dictionary.all_indexes.c.uniqueness, + dictionary.all_indexes.c.compression, + dictionary.all_indexes.c.prefix_length, + ) + .select_from(dictionary.all_ind_columns) + .join( + dictionary.all_indexes, + sql.and_( + dictionary.all_ind_columns.c.index_name + == dictionary.all_indexes.c.index_name, + dictionary.all_ind_columns.c.table_owner + == dictionary.all_indexes.c.table_owner, + # NOTE: this condition on table_name is not required + # but it improves the query performance noticeably + dictionary.all_ind_columns.c.table_name + == dictionary.all_indexes.c.table_name, + ), + ) + .where( + dictionary.all_ind_columns.c.table_owner == owner, + dictionary.all_ind_columns.c.table_name.in_( + bindparam("all_objects") + ), + ) + .order_by( + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_position, + ) ) - if schema is not None: - params["schema"] = schema - text += "AND a.table_owner = :schema " + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw): + owner = self.denormalize_name(schema or self.default_schema_name) - text += "ORDER BY a.index_name, a.column_position" + query = self._index_query(owner) - text = text % {"dblink": dblink} + pks = { + row_dict["constraint_name"] + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ) + if row_dict["constraint_type"] == "P" + } - q = sql.text(text) - rp = connection.execute(q, params) - indexes = [] - last_index_name = None - pk_constraint = self.get_pk_constraint( + result = self._run_batches( connection, - table_name, - schema, - resolve_synonyms=resolve_synonyms, - dblink=dblink, - info_cache=kw.get("info_cache"), + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, ) - uniqueness = dict(NONUNIQUE=False, UNIQUE=True) - enabled = dict(DISABLED=False, ENABLED=True) + return [ + row_dict + for row_dict in result + if row_dict["index_name"] 
not in pks + ] - oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) + @_handle_synonyms_decorator + def get_multi_indexes( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - index = None - for rset in rp: - index_name_normalized = self.normalize_name(rset.index_name) + uniqueness = {"NONUNIQUE": False, "UNIQUE": True} + enabled = {"DISABLED": False, "ENABLED": True} + is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"} - # skip primary key index. This is refined as of - # [ticket:5421]. Note that ALL_INDEXES.GENERATED will by "Y" - # if the name of this index was generated by Oracle, however - # if a named primary key constraint was created then this flag - # is false. - if ( - pk_constraint - and index_name_normalized == pk_constraint["name"] - ): - continue + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) - if rset.index_name != last_index_name: - index = dict( - name=index_name_normalized, - column_names=[], - dialect_options={}, - ) - indexes.append(index) - index["unique"] = uniqueness.get(rset.uniqueness, False) + indexes = defaultdict(dict) + + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ): + index_name = self.normalize_name(row_dict["index_name"]) + table_name = self.normalize_name(row_dict["table_name"]) + table_indexes = indexes[(schema, table_name)] + + if index_name not in table_indexes: + table_indexes[index_name] = index_dict = { + "name": index_name, + "column_names": [], + "dialect_options": {}, + "unique": uniqueness.get(row_dict["uniqueness"], False), + } + do = index_dict["dialect_options"] + if row_dict["index_type"] in is_bitmap: + do["oracle_bitmap"] = True + if enabled.get(row_dict["compression"], False): + do["oracle_compress"] = row_dict["prefix_length"] - if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"): - index["dialect_options"]["oracle_bitmap"] = True - if enabled.get(rset.compression, False): - index["dialect_options"][ - "oracle_compress" - ] = rset.prefix_length + else: + index_dict = table_indexes[index_name] # filter out Oracle SYS_NC names. could also do an outer join - # to the all_tab_columns table and check for real col names there. - if not oracle_sys_col.match(rset.column_name): - index["column_names"].append( - self.normalize_name(rset.column_name) + # to the all_tab_columns table and check for real col names + # there. 
+ if not oracle_sys_col.match(row_dict["column_name"]): + index_dict["column_names"].append( + self.normalize_name(row_dict["column_name"]) ) - last_index_name = rset.index_name - return indexes + default = ReflectionDefaults.indexes - @reflection.cache - def _get_constraint_data( - self, connection, table_name, schema=None, dblink="", **kw - ): - - params = {"table_name": table_name} - - text = ( - "SELECT" - "\nac.constraint_name," # 0 - "\nac.constraint_type," # 1 - "\nloc.column_name AS local_column," # 2 - "\nrem.table_name AS remote_table," # 3 - "\nrem.column_name AS remote_column," # 4 - "\nrem.owner AS remote_owner," # 5 - "\nloc.position as loc_pos," # 6 - "\nrem.position as rem_pos," # 7 - "\nac.search_condition," # 8 - "\nac.delete_rule" # 9 - "\nFROM all_constraints%(dblink)s ac," - "\nall_cons_columns%(dblink)s loc," - "\nall_cons_columns%(dblink)s rem" - "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))" - "\nAND ac.constraint_type IN ('R','P', 'U', 'C')" - ) - - if schema is not None: - params["owner"] = schema - text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))" - - text += ( - "\nAND ac.owner = loc.owner" - "\nAND ac.constraint_name = loc.constraint_name" - "\nAND ac.r_owner = rem.owner(+)" - "\nAND ac.r_constraint_name = rem.constraint_name(+)" - "\nAND (rem.position IS NULL or loc.position=rem.position)" - "\nORDER BY ac.constraint_name, loc.position" + return ( + (key, list(indexes[key].values()) if key in indexes else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) ) - text = text % {"dblink": dblink} - rp = connection.execute(sql.text(text), params) - constraint_data = rp.fetchall() - return constraint_data - @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_pk_constraint( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) - pkeys = [] - constraint_name = None - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _constraint_query(self, owner): + local = dictionary.all_cons_columns.alias("local") + remote = dictionary.all_cons_columns.alias("remote") + return ( + select( + dictionary.all_constraints.c.table_name, + dictionary.all_constraints.c.constraint_type, + dictionary.all_constraints.c.constraint_name, + local.c.column_name.label("local_column"), + remote.c.table_name.label("remote_table"), + remote.c.column_name.label("remote_column"), + remote.c.owner.label("remote_owner"), + dictionary.all_constraints.c.search_condition, + dictionary.all_constraints.c.delete_rule, + ) + .select_from(dictionary.all_constraints) + .join( + local, + and_( + local.c.owner == dictionary.all_constraints.c.owner, + dictionary.all_constraints.c.constraint_name + == local.c.constraint_name, + ), + ) + .outerjoin( + remote, + and_( + dictionary.all_constraints.c.r_owner == remote.c.owner, + 
dictionary.all_constraints.c.r_constraint_name + == remote.c.constraint_name, + or_( + remote.c.position.is_(sql.null()), + local.c.position == remote.c.position, + ), + ), + ) + .where( + dictionary.all_constraints.c.owner == owner, + dictionary.all_constraints.c.table_name.in_( + bindparam("all_objects") + ), + dictionary.all_constraints.c.constraint_type.in_( + ("R", "P", "U", "C") + ), + ) + .order_by( + dictionary.all_constraints.c.constraint_name, local.c.position + ) ) - for row in constraint_data: - ( - cons_name, - cons_type, - local_column, - remote_table, - remote_column, - remote_owner, - ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - if cons_type == "P": - if constraint_name is None: - constraint_name = self.normalize_name(cons_name) - pkeys.append(local_column) - return {"constrained_columns": pkeys, "name": constraint_name} + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_all_constraint_rows( + self, connection, schema, dblink, all_objects, **kw + ): + owner = self.denormalize_name(schema or self.default_schema_name) + query = self._constraint_query(owner) - @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # since the result is cached a list must be created + values = list( + self._run_batches( + connection, + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, + ) + ) + return values + + @_handle_synonyms_decorator + def get_multi_pk_constraint( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - kw arguments can be: + primary_keys = defaultdict(dict) + default = ReflectionDefaults.pk_constraint - oracle_resolve_synonyms + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "P": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + column_name = self.normalize_name(row_dict["local_column"]) + + table_pk = primary_keys[(schema, table_name)] + if not table_pk: + table_pk["name"] = constraint_name + table_pk["constrained_columns"] = [column_name] + else: + table_pk["constrained_columns"].append(column_name) - dblink + return ( + (key, primary_keys[key] if key in primary_keys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ - requested_schema = schema # to check later on - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + data = self.get_multi_foreign_keys( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return 
self._value_or_raise(data, table_name, schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_foreign_keys( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - def fkey_rec(): - return { - "name": None, - "constrained_columns": [], - "referred_schema": None, - "referred_table": None, - "referred_columns": [], - "options": {}, - } + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - fkeys = util.defaultdict(fkey_rec) + owner = self.denormalize_name(schema or self.default_schema_name) - for row in constraint_data: - ( - cons_name, - cons_type, - local_column, - remote_table, - remote_column, - remote_owner, - ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - - cons_name = self.normalize_name(cons_name) - - if cons_type == "R": - if remote_table is None: - # ticket 363 - util.warn( - ( - "Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?" - ) - % {"dblink": dblink} - ) - continue + all_remote_owners = set() + fkeys = defaultdict(dict) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "R": + continue + + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + table_fkey = fkeys[(schema, table_name)] + + assert constraint_name is not None - rec = fkeys[cons_name] - rec["name"] = cons_name - local_cols, remote_cols = ( - rec["constrained_columns"], - rec["referred_columns"], + local_column = self.normalize_name(row_dict["local_column"]) + remote_table = self.normalize_name(row_dict["remote_table"]) + remote_column = self.normalize_name(row_dict["remote_column"]) + remote_owner_orig = row_dict["remote_owner"] + remote_owner = self.normalize_name(remote_owner_orig) + if remote_owner_orig is not None: + all_remote_owners.add(remote_owner_orig) + + if remote_table is None: + # ticket 363 + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + util.warn( + "Got 'None' querying 'table_name' from " + f"all_cons_columns{dblink or ''} - does the user have " + "proper rights to the table?" 
) + continue - if not rec["referred_table"]: - if resolve_synonyms: - ( - ref_remote_name, - ref_remote_owner, - ref_dblink, - ref_synonym, - ) = self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(remote_owner), - desired_table=self.denormalize_name(remote_table), - ) - if ref_synonym: - remote_table = self.normalize_name(ref_synonym) - remote_owner = self.normalize_name( - ref_remote_owner - ) + if constraint_name not in table_fkey: + table_fkey[constraint_name] = fkey = { + "name": constraint_name, + "constrained_columns": [], + "referred_schema": None, + "referred_table": remote_table, + "referred_columns": [], + "options": {}, + } - rec["referred_table"] = remote_table + if resolve_synonyms: + # will be removed below + fkey["_ref_schema"] = remote_owner - if ( - requested_schema is not None - or self.denormalize_name(remote_owner) != schema - ): - rec["referred_schema"] = remote_owner + if schema is not None or remote_owner_orig != owner: + fkey["referred_schema"] = remote_owner + + delete_rule = row_dict["delete_rule"] + if delete_rule != "NO ACTION": + fkey["options"]["ondelete"] = delete_rule + + else: + fkey = table_fkey[constraint_name] + + fkey["constrained_columns"].append(local_column) + fkey["referred_columns"].append(remote_column) + + if resolve_synonyms and all_remote_owners: + query = select( + dictionary.all_synonyms.c.owner, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.synonym_name, + ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners)) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() - if row[9] != "NO ACTION": - rec["options"]["ondelete"] = row[9] + remote_owners_lut = {} + for row in result: + synonym_owner = self.normalize_name(row["owner"]) + table_name = self.normalize_name(row["table_name"]) - local_cols.append(local_column) - remote_cols.append(remote_column) + remote_owners_lut[(synonym_owner, table_name)] = ( + row["table_owner"], + row["synonym_name"], + ) + + empty = (None, None) + for table_fkeys in fkeys.values(): + for table_fkey in table_fkeys.values(): + key = ( + table_fkey.pop("_ref_schema"), + table_fkey["referred_table"], + ) + remote_owner, syn_name = remote_owners_lut.get(key, empty) + if syn_name: + sn = self.normalize_name(syn_name) + table_fkey["referred_table"] = sn + if schema is not None or remote_owner != owner: + ro = self.normalize_name(remote_owner) + table_fkey["referred_schema"] = ro + else: + table_fkey["referred_schema"] = None + default = ReflectionDefaults.foreign_keys - return list(fkeys.values()) + return ( + (key, list(fkeys[key].values()) if key in fkeys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) @reflection.cache def get_unique_constraints( self, connection, table_name, schema=None, **kw ): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_unique_constraints( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, 
schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_unique_constraints( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - unique_keys = filter(lambda x: x[1] == "U", constraint_data) - uniques_group = groupby(unique_keys, lambda x: x[0]) + unique_cons = defaultdict(dict) index_names = { - ix["name"] - for ix in self.get_indexes(connection, table_name, schema=schema) + row_dict["index_name"] + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ) } - return [ - { - "name": name, - "column_names": cols, - "duplicates_index": name if name in index_names else None, - } - for name, cols in [ - [ - self.normalize_name(i[0]), - [self.normalize_name(x[2]) for x in i[1]], - ] - for i in uniques_group - ] - ] + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "U": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name_orig = row_dict["constraint_name"] + constraint_name = self.normalize_name(constraint_name_orig) + column_name = self.normalize_name(row_dict["local_column"]) + table_uc = unique_cons[(schema, table_name)] + + assert constraint_name is not None + + if constraint_name not in table_uc: + table_uc[constraint_name] = uc = { + "name": constraint_name, + "column_names": [], + "duplicates_index": constraint_name + if constraint_name_orig in index_names + else None, + } + else: + uc = table_uc[constraint_name] + + uc["column_names"].append(column_name) + + default = ReflectionDefaults.unique_constraints + + return ( + ( + key, + list(unique_cons[key].values()) + if key in unique_cons + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) @reflection.cache def get_view_definition( @@ -2553,65 +2997,129 @@ class OracleDialect(default.DefaultDialect): connection, view_name, schema=None, - resolve_synonyms=False, - dblink="", + dblink=None, **kw, ): - info_cache = kw.get("info_cache") - (view_name, schema, dblink, synonym) = self._prepare_reflection_args( - connection, - view_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + if kw.get("oracle_resolve_synonyms", False): + synonyms = self._get_synonyms( + connection, schema, filter_names=[view_name], dblink=dblink + ) + if synonyms: + assert len(synonyms) == 1 + row_dict = synonyms[0] + dblink = self.normalize_name(row_dict["db_link"]) + schema = row_dict["table_owner"] + view_name = row_dict["table_name"] + + name = self.denormalize_name(view_name) + owner = self.denormalize_name(schema or self.default_schema_name) + query = ( + select(dictionary.all_views.c.text) + .where( + dictionary.all_views.c.view_name == name, + dictionary.all_views.c.owner == owner, + ) + .union_all( + select(dictionary.all_mviews.c.query).where( + dictionary.all_mviews.c.mview_name == name, + dictionary.all_mviews.c.owner == owner, + ) + ) ) - params = {"view_name": view_name} - text = 
"SELECT text FROM all_views WHERE view_name=:view_name" - - if schema is not None: - text += " AND owner = :schema" - params["schema"] = schema - - rp = connection.execute(sql.text(text), params).scalar() - if rp: - return rp + rp = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalar() + if rp is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) else: - return None + return rp @reflection.cache def get_check_constraints( self, connection, table_name, schema=None, include_all=False, **kw ): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_check_constraints( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + include_all=include_all, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_check_constraints( + self, + connection, + *, + schema, + filter_names, + dblink=None, + scope, + kind, + include_all=False, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - check_constraints = filter(lambda x: x[1] == "C", constraint_data) + not_null = re.compile(r"..+?. IS NOT NULL$") - return [ - {"name": self.normalize_name(cons[0]), "sqltext": cons[8]} - for cons in check_constraints - if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8]) - ] + check_constraints = defaultdict(list) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "C": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + search_condition = row_dict["search_condition"] + + table_checks = check_constraints[(schema, table_name)] + if constraint_name is not None and ( + include_all or not not_null.match(search_condition) + ): + table_checks.append( + {"name": constraint_name, "sqltext": search_condition} + ) + + default = ReflectionDefaults.check_constraints + + return ( + ( + key, + check_constraints[key] + if key in check_constraints + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + def _list_dblinks(self, connection, dblink=None): + query = select(dictionary.all_db_links.c.db_link) + links = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(link) for link in links] class _OuterJoinColumn(sql.ClauseElement): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 25e93632c..d2ee0a96e 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -431,6 +431,7 @@ from . 
import base as oracle from .base import OracleCompiler from .base import OracleDialect from .base import OracleExecutionContext +from .types import _OracleDateLiteralRender from ... import exc from ... import util from ...engine import cursor as _cursor @@ -573,7 +574,7 @@ class _CXOracleDate(oracle._OracleDate): return process -class _CXOracleTIMESTAMP(oracle._OracleDateLiteralRender, sqltypes.TIMESTAMP): +class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP): def literal_processor(self, dialect): return self._literal_processor_datetime(dialect) @@ -812,6 +813,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): return None def pre_exec(self): + super().pre_exec() if not getattr(self.compiled, "_oracle_cx_sql_compiler", False): return diff --git a/lib/sqlalchemy/dialects/oracle/dictionary.py b/lib/sqlalchemy/dialects/oracle/dictionary.py new file mode 100644 index 000000000..ac7a350da --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/dictionary.py @@ -0,0 +1,495 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .types import DATE +from .types import LONG +from .types import NUMBER +from .types import RAW +from .types import VARCHAR2 +from ... import Column +from ... import MetaData +from ... import Table +from ... import table +from ...sql.sqltypes import CHAR + +# constants +DB_LINK_PLACEHOLDER = "__$sa_dblink$__" +# tables +dual = table("dual") +dictionary_meta = MetaData() + +# NOTE: all the dictionary_meta are aliases because oracle does not like +# using the full table@dblink for every column in query, and complains with +# ORA-00960: ambiguous column naming in select list +all_tables = Table( + "all_tables" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("tablespace_name", VARCHAR2(30)), + Column("cluster_name", VARCHAR2(128)), + Column("iot_name", VARCHAR2(128)), + Column("status", VARCHAR2(8)), + Column("pct_free", NUMBER), + Column("pct_used", NUMBER), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("logging", VARCHAR2(3)), + Column("backed_up", VARCHAR2(1)), + Column("num_rows", NUMBER), + Column("blocks", NUMBER), + Column("empty_blocks", NUMBER), + Column("avg_space", NUMBER), + Column("chain_cnt", NUMBER), + Column("avg_row_len", NUMBER), + Column("avg_space_freelist_blocks", NUMBER), + Column("num_freelist_blocks", NUMBER), + Column("degree", VARCHAR2(10)), + Column("instances", VARCHAR2(10)), + Column("cache", VARCHAR2(5)), + Column("table_lock", VARCHAR2(8)), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("partitioned", VARCHAR2(3)), + Column("iot_type", VARCHAR2(12)), + Column("temporary", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("nested", VARCHAR2(3)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("row_movement", VARCHAR2(8)), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + 
Column("skip_corrupt", VARCHAR2(8)), + Column("monitoring", VARCHAR2(3)), + Column("cluster_owner", VARCHAR2(128)), + Column("dependencies", VARCHAR2(8)), + Column("compression", VARCHAR2(8)), + Column("compress_for", VARCHAR2(30)), + Column("dropped", VARCHAR2(3)), + Column("read_only", VARCHAR2(3)), + Column("segment_created", VARCHAR2(3)), + Column("result_cache", VARCHAR2(7)), + Column("clustering", VARCHAR2(3)), + Column("activity_tracking", VARCHAR2(23)), + Column("dml_timestamp", VARCHAR2(25)), + Column("has_identity", VARCHAR2(3)), + Column("container_data", VARCHAR2(3)), + Column("inmemory", VARCHAR2(8)), + Column("inmemory_priority", VARCHAR2(8)), + Column("inmemory_distribute", VARCHAR2(15)), + Column("inmemory_compression", VARCHAR2(17)), + Column("inmemory_duplicate", VARCHAR2(13)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("externally_sharded", VARCHAR2(1)), + Column("externally_duplicated", VARCHAR2(1)), + Column("external", VARCHAR2(3)), + Column("hybrid", VARCHAR2(3)), + Column("cellmemory", VARCHAR2(24)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("inmemory_service", VARCHAR2(12)), + Column("inmemory_service_name", VARCHAR2(1000)), + Column("container_map_object", VARCHAR2(3)), + Column("memoptimize_read", VARCHAR2(8)), + Column("memoptimize_write", VARCHAR2(8)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("data_link_dml_enabled", VARCHAR2(3)), + Column("logical_replication", VARCHAR2(8)), +).alias("a_tables") + +all_views = Table( + "all_views" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("view_name", VARCHAR2(128), nullable=False), + Column("text_length", NUMBER), + Column("text", LONG), + Column("text_vc", VARCHAR2(4000)), + Column("type_text_length", NUMBER), + Column("type_text", VARCHAR2(4000)), + Column("oid_text_length", NUMBER), + Column("oid_text", VARCHAR2(4000)), + Column("view_type_owner", VARCHAR2(128)), + Column("view_type", VARCHAR2(128)), + Column("superview_name", VARCHAR2(128)), + Column("editioning_view", VARCHAR2(1)), + Column("read_only", VARCHAR2(1)), + Column("container_data", VARCHAR2(1)), + Column("bequeath", VARCHAR2(12)), + Column("origin_con_id", VARCHAR2(256)), + Column("default_collation", VARCHAR2(100)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("pdb_local_only", VARCHAR2(3)), +).alias("a_views") + +all_sequences = Table( + "all_sequences" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("sequence_owner", VARCHAR2(128), nullable=False), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("min_value", NUMBER), + Column("max_value", NUMBER), + Column("increment_by", NUMBER, nullable=False), + Column("cycle_flag", VARCHAR2(1)), + Column("order_flag", VARCHAR2(1)), + Column("cache_size", NUMBER, nullable=False), + Column("last_number", NUMBER, nullable=False), + Column("scale_flag", VARCHAR2(1)), + Column("extend_flag", VARCHAR2(1)), + Column("sharded_flag", VARCHAR2(1)), + Column("session_flag", VARCHAR2(1)), + Column("keep_value", VARCHAR2(1)), +).alias("a_sequences") + 
+all_users = Table( + "all_users" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("username", VARCHAR2(128), nullable=False), + Column("user_id", NUMBER, nullable=False), + Column("created", DATE, nullable=False), + Column("common", VARCHAR2(3)), + Column("oracle_maintained", VARCHAR2(1)), + Column("inherited", VARCHAR2(3)), + Column("default_collation", VARCHAR2(100)), + Column("implicit", VARCHAR2(3)), + Column("all_shard", VARCHAR2(3)), + Column("external_shard", VARCHAR2(3)), +).alias("a_users") + +all_mviews = Table( + "all_mviews" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("container_name", VARCHAR2(128), nullable=False), + Column("query", LONG), + Column("query_len", NUMBER(38)), + Column("updatable", VARCHAR2(1)), + Column("update_log", VARCHAR2(128)), + Column("master_rollback_seg", VARCHAR2(128)), + Column("master_link", VARCHAR2(128)), + Column("rewrite_enabled", VARCHAR2(1)), + Column("rewrite_capability", VARCHAR2(9)), + Column("refresh_mode", VARCHAR2(6)), + Column("refresh_method", VARCHAR2(8)), + Column("build_mode", VARCHAR2(9)), + Column("fast_refreshable", VARCHAR2(18)), + Column("last_refresh_type", VARCHAR2(8)), + Column("last_refresh_date", DATE), + Column("last_refresh_end_time", DATE), + Column("staleness", VARCHAR2(19)), + Column("after_fast_refresh", VARCHAR2(19)), + Column("unknown_prebuilt", VARCHAR2(1)), + Column("unknown_plsql_func", VARCHAR2(1)), + Column("unknown_external_table", VARCHAR2(1)), + Column("unknown_consider_fresh", VARCHAR2(1)), + Column("unknown_import", VARCHAR2(1)), + Column("unknown_trusted_fd", VARCHAR2(1)), + Column("compile_state", VARCHAR2(19)), + Column("use_no_index", VARCHAR2(1)), + Column("stale_since", DATE), + Column("num_pct_tables", NUMBER), + Column("num_fresh_pct_regions", NUMBER), + Column("num_stale_pct_regions", NUMBER), + Column("segment_created", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("default_collation", VARCHAR2(100)), + Column("on_query_computation", VARCHAR2(1)), + Column("auto", VARCHAR2(3)), +).alias("a_mviews") + +all_tab_identity_cols = Table( + "all_tab_identity_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("generation_type", VARCHAR2(10)), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("identity_options", VARCHAR2(298)), +).alias("a_tab_identity_cols") + +all_tab_cols = Table( + "all_tab_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("data_type", VARCHAR2(128)), + Column("data_type_mod", VARCHAR2(3)), + Column("data_type_owner", VARCHAR2(128)), + Column("data_length", NUMBER, nullable=False), + Column("data_precision", NUMBER), + Column("data_scale", NUMBER), + Column("nullable", VARCHAR2(1)), + Column("column_id", NUMBER), + Column("default_length", NUMBER), + Column("data_default", LONG), + Column("num_distinct", NUMBER), + Column("low_value", RAW(1000)), + Column("high_value", RAW(1000)), + Column("density", NUMBER), + Column("num_nulls", NUMBER), + Column("num_buckets", NUMBER), + Column("last_analyzed", DATE), + 
Column("sample_size", NUMBER), + Column("character_set_name", VARCHAR2(44)), + Column("char_col_decl_length", NUMBER), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("avg_col_len", NUMBER), + Column("char_length", NUMBER), + Column("char_used", VARCHAR2(1)), + Column("v80_fmt_image", VARCHAR2(3)), + Column("data_upgraded", VARCHAR2(3)), + Column("hidden_column", VARCHAR2(3)), + Column("virtual_column", VARCHAR2(3)), + Column("segment_column_id", NUMBER), + Column("internal_column_id", NUMBER, nullable=False), + Column("histogram", VARCHAR2(15)), + Column("qualified_col_name", VARCHAR2(4000)), + Column("user_generated", VARCHAR2(3)), + Column("default_on_null", VARCHAR2(3)), + Column("identity_column", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("collation", VARCHAR2(100)), + Column("collated_column_id", NUMBER), +).alias("a_tab_cols") + +all_tab_comments = Table( + "all_tab_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", VARCHAR2(11)), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_tab_comments") + +all_col_comments = Table( + "all_col_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_col_comments") + +all_mview_comments = Table( + "all_mview_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), +).alias("a_mview_comments") + +all_ind_columns = Table( + "all_ind_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("index_owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("column_position", NUMBER, nullable=False), + Column("column_length", NUMBER, nullable=False), + Column("char_length", NUMBER), + Column("descend", VARCHAR2(4)), + Column("collated_column_id", NUMBER), +).alias("a_ind_columns") + +all_indexes = Table( + "all_indexes" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("index_type", VARCHAR2(27)), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", CHAR(11)), + Column("uniqueness", VARCHAR2(9)), + Column("compression", VARCHAR2(13)), + Column("prefix_length", NUMBER), + Column("tablespace_name", VARCHAR2(30)), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("pct_threshold", NUMBER), + Column("include_column", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("pct_free", NUMBER), + Column("logging", VARCHAR2(3)), + Column("blevel", NUMBER), + 
Column("leaf_blocks", NUMBER), + Column("distinct_keys", NUMBER), + Column("avg_leaf_blocks_per_key", NUMBER), + Column("avg_data_blocks_per_key", NUMBER), + Column("clustering_factor", NUMBER), + Column("status", VARCHAR2(8)), + Column("num_rows", NUMBER), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("degree", VARCHAR2(40)), + Column("instances", VARCHAR2(40)), + Column("partitioned", VARCHAR2(3)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("pct_direct_access", NUMBER), + Column("ityp_owner", VARCHAR2(128)), + Column("ityp_name", VARCHAR2(128)), + Column("parameters", VARCHAR2(1000)), + Column("global_stats", VARCHAR2(3)), + Column("domidx_status", VARCHAR2(12)), + Column("domidx_opstatus", VARCHAR2(6)), + Column("funcidx_status", VARCHAR2(8)), + Column("join_index", VARCHAR2(3)), + Column("iot_redundant_pkey_elim", VARCHAR2(3)), + Column("dropped", VARCHAR2(3)), + Column("visibility", VARCHAR2(9)), + Column("domidx_management", VARCHAR2(14)), + Column("segment_created", VARCHAR2(3)), + Column("orphaned_entries", VARCHAR2(3)), + Column("indexing", VARCHAR2(7)), + Column("auto", VARCHAR2(3)), +).alias("a_indexes") + +all_constraints = Table( + "all_constraints" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("constraint_name", VARCHAR2(128)), + Column("constraint_type", VARCHAR2(1)), + Column("table_name", VARCHAR2(128)), + Column("search_condition", LONG), + Column("search_condition_vc", VARCHAR2(4000)), + Column("r_owner", VARCHAR2(128)), + Column("r_constraint_name", VARCHAR2(128)), + Column("delete_rule", VARCHAR2(9)), + Column("status", VARCHAR2(8)), + Column("deferrable", VARCHAR2(14)), + Column("deferred", VARCHAR2(9)), + Column("validated", VARCHAR2(13)), + Column("generated", VARCHAR2(14)), + Column("bad", VARCHAR2(3)), + Column("rely", VARCHAR2(4)), + Column("last_change", DATE), + Column("index_owner", VARCHAR2(128)), + Column("index_name", VARCHAR2(128)), + Column("invalid", VARCHAR2(7)), + Column("view_related", VARCHAR2(14)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_constraints") + +all_cons_columns = Table( + "all_cons_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("constraint_name", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("position", NUMBER), +).alias("a_cons_columns") + +# TODO figure out if it's still relevant, since there is no mention from here +# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html +# original note: +# using user_db_links here since all_db_links appears +# to have more restricted permissions. +# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm +# will need to hear from more users if we are doing +# the right thing here. 
See [ticket:2619] +all_db_links = Table( + "all_db_links" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("db_link", VARCHAR2(128), nullable=False), + Column("username", VARCHAR2(128)), + Column("host", VARCHAR2(2000)), + Column("created", DATE, nullable=False), + Column("hidden", VARCHAR2(3)), + Column("shard_internal", VARCHAR2(3)), + Column("valid", VARCHAR2(3)), + Column("intra_cdb", VARCHAR2(3)), +).alias("a_db_links") + +all_synonyms = Table( + "all_synonyms" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("synonym_name", VARCHAR2(128)), + Column("table_owner", VARCHAR2(128)), + Column("table_name", VARCHAR2(128)), + Column("db_link", VARCHAR2(128)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_synonyms") + +all_objects = Table( + "all_objects" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("object_name", VARCHAR2(128), nullable=False), + Column("subobject_name", VARCHAR2(128)), + Column("object_id", NUMBER, nullable=False), + Column("data_object_id", NUMBER), + Column("object_type", VARCHAR2(23)), + Column("created", DATE, nullable=False), + Column("last_ddl_time", DATE, nullable=False), + Column("timestamp", VARCHAR2(19)), + Column("status", VARCHAR2(7)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("namespace", NUMBER, nullable=False), + Column("edition_name", VARCHAR2(128)), + Column("sharing", VARCHAR2(13)), + Column("editionable", VARCHAR2(1)), + Column("oracle_maintained", VARCHAR2(1)), + Column("application", VARCHAR2(1)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("created_appid", NUMBER), + Column("created_vsnid", NUMBER), + Column("modified_appid", NUMBER), + Column("modified_vsnid", NUMBER), +).alias("a_objects") diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index cba3b5be4..75b7a7aa9 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -2,9 +2,12 @@ from ... import create_engine from ... import exc +from ... 
import inspect from ...engine import url as sa_url from ...testing.provision import configure_follower from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import follower_url_from_main from ...testing.provision import log @@ -28,6 +31,10 @@ def _oracle_create_db(cfg, eng, ident): conn.exec_driver_sql("grant unlimited tablespace to %s" % ident) conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident) conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident) + # these are needed to create materialized views + conn.exec_driver_sql("grant create table to %s" % ident) + conn.exec_driver_sql("grant create table to %s_ts1" % ident) + conn.exec_driver_sql("grant create table to %s_ts2" % ident) @configure_follower.for_db("oracle") @@ -46,6 +53,30 @@ def _ora_drop_ignore(conn, dbname): return False +@drop_all_schema_objects_pre_tables.for_db("oracle") +def _ora_drop_all_schema_objects_pre_tables(cfg, eng): + _purge_recyclebin(eng) + _purge_recyclebin(eng, cfg.test_schema) + + +@drop_all_schema_objects_post_tables.for_db("oracle") +def _ora_drop_all_schema_objects_post_tables(cfg, eng): + + with eng.begin() as conn: + for syn in conn.dialect._get_synonyms(conn, None, None, None): + conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}") + + for syn in conn.dialect._get_synonyms( + conn, cfg.test_schema, None, None + ): + conn.exec_driver_sql( + f"drop synonym {cfg.test_schema}.{syn['synonym_name']}" + ) + + for tmp_table in inspect(conn).get_temp_table_names(): + conn.exec_driver_sql(f"drop table {tmp_table}") + + @drop_db.for_db("oracle") def _oracle_drop_db(cfg, eng, ident): with eng.begin() as conn: @@ -60,13 +91,10 @@ def _oracle_drop_db(cfg, eng, ident): @stop_test_class_outside_fixtures.for_db("oracle") -def stop_test_class_outside_fixtures(config, db, cls): +def _ora_stop_test_class_outside_fixtures(config, db, cls): try: - with db.begin() as conn: - # run magic command to get rid of identity sequences - # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 - conn.exec_driver_sql("purge recyclebin") + _purge_recyclebin(db) except exc.DatabaseError as err: log.warning("purge recyclebin command failed: %s", err) @@ -85,6 +113,22 @@ def stop_test_class_outside_fixtures(config, db, cls): _all_conns.clear() +def _purge_recyclebin(eng, schema=None): + with eng.begin() as conn: + if schema is None: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 + conn.exec_driver_sql("purge recyclebin") + else: + # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501 + for owner, object_name, type_ in conn.exec_driver_sql( + "select owner, object_name,type from " + "dba_recyclebin where owner=:schema and type='TABLE'", + {"schema": conn.dialect.denormalize_name(schema)}, + ).all(): + conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"') + + _all_conns = set() diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py new file mode 100644 index 000000000..60a8ebcb5 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -0,0 +1,233 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS 
file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ...sql import sqltypes +from ...types import NVARCHAR +from ...types import VARCHAR + + +class RAW(sqltypes._Binary): + __visit_name__ = "RAW" + + +OracleRaw = RAW + + +class NCLOB(sqltypes.Text): + __visit_name__ = "NCLOB" + + +class VARCHAR2(VARCHAR): + __visit_name__ = "VARCHAR2" + + +NVARCHAR2 = NVARCHAR + + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = "NUMBER" + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super(NUMBER, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal + ) + + def adapt(self, impltype): + ret = super(NUMBER, self).adapt(impltype) + # leave a hint for the DBAPI handler + ret._is_oracle_number = True + return ret + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class FLOAT(sqltypes.FLOAT): + """Oracle FLOAT. + + This is the same as :class:`_sqltypes.FLOAT` except that + an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` + parameter is accepted, and + the :paramref:`_sqltypes.Float.precision` parameter is not accepted. + + Oracle FLOAT types indicate precision in terms of "binary precision", which + defaults to 126. For a REAL type, the value is 63. This parameter does not + cleanly map to a specific number of decimal places but is roughly + equivalent to the desired number of decimal places divided by 0.3103. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "FLOAT" + + def __init__( + self, + binary_precision=None, + asdecimal=False, + decimal_return_scale=None, + ): + r""" + Construct a FLOAT + + :param binary_precision: Oracle binary precision value to be rendered + in DDL. This may be approximated to the number of decimal characters + using the formula "decimal precision = 0.30103 * binary precision". + The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. 
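 For example (an illustrative sketch; the column name and the value 16 are arbitrary), a binary precision of 16 corresponds to roughly five decimal digits and would presumably render as ``FLOAT(16)`` in DDL::

     from sqlalchemy import Column
     from sqlalchemy.dialects import oracle

     Column("value", oracle.FLOAT(binary_precision=16))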
+ + :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` + + :param decimal_return_scale: See + :paramref:`_sqltypes.Float.decimal_return_scale` + + """ + super().__init__( + asdecimal=asdecimal, decimal_return_scale=decimal_return_scale + ) + self.binary_precision = binary_precision + + +class BINARY_DOUBLE(sqltypes.Float): + __visit_name__ = "BINARY_DOUBLE" + + +class BINARY_FLOAT(sqltypes.Float): + __visit_name__ = "BINARY_FLOAT" + + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = "BFILE" + + +class LONG(sqltypes.Text): + __visit_name__ = "LONG" + + +class _OracleDateLiteralRender: + def _literal_processor_datetime(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS.FF')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS')""" + ) + return value + + return process + + def _literal_processor_date(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + return value + + return process + + +class DATE(_OracleDateLiteralRender, sqltypes.DateTime): + """Provide the oracle DATE type. + + This type has no special Python behavior, except that it subclasses + :class:`_types.DateTime`; this is to suit the fact that the Oracle + ``DATE`` type supports a time value. + + .. versionadded:: 0.9.4 + + """ + + __visit_name__ = "DATE" + + def literal_processor(self, dialect): + return self._literal_processor_datetime(dialect) + + def _compare_type_affinity(self, other): + return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) + + +class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): + def literal_processor(self, dialect): + return self._literal_processor_date(dialect) + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + __visit_name__ = "INTERVAL" + + def __init__(self, day_precision=None, second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs. + + :param day_precision: the day precision value. this is the number of + digits to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the + number of digits to store for the fractional seconds field. + Defaults to "6". + + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval( + native=True, + second_precision=self.second_precision, + day_precision=self.day_precision, + ) + + +class ROWID(sqltypes.TypeEngine): + """Oracle ROWID type. + + When used in a cast() or similar, generates ROWID. 
+ + """ + + __visit_name__ = "ROWID" + + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index c2472fb55..85bbf8c5b 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -19,31 +19,16 @@ from .array import Any from .array import ARRAY from .array import array from .base import BIGINT -from .base import BIT from .base import BOOLEAN -from .base import BYTEA from .base import CHAR -from .base import CIDR -from .base import CreateEnumType from .base import DATE from .base import DOUBLE_PRECISION -from .base import DropEnumType -from .base import ENUM from .base import FLOAT -from .base import INET from .base import INTEGER -from .base import INTERVAL -from .base import MACADDR -from .base import MONEY from .base import NUMERIC -from .base import OID from .base import REAL -from .base import REGCLASS from .base import SMALLINT from .base import TEXT -from .base import TIME -from .base import TIMESTAMP -from .base import TSVECTOR from .base import UUID from .base import VARCHAR from .dml import Insert @@ -61,7 +46,21 @@ from .ranges import INT8RANGE from .ranges import NUMRANGE from .ranges import TSRANGE from .ranges import TSTZRANGE -from ...util import compat +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CreateEnumType +from .types import DropEnumType +from .types import ENUM +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MONEY +from .types import OID +from .types import REGCLASS +from .types import TIME +from .types import TIMESTAMP +from .types import TSVECTOR # Alias psycopg also as psycopg_async psycopg_async = type( diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index e831f2ed9..8dcd36c6d 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -1,3 +1,8 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import decimal @@ -9,6 +14,9 @@ from .base import _INT_TYPES from .base import PGDialect from .base import PGExecutionContext from .hstore import HSTORE +from .pg_catalog import _SpaceVector +from .pg_catalog import INT2VECTOR +from .pg_catalog import OIDVECTOR from ... import exc from ... import types as sqltypes from ... 
import util @@ -66,6 +74,14 @@ class _PsycopgARRAY(PGARRAY): render_bind_cast = True +class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR): + pass + + +class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + class _PGExecutionContext_common_psycopg(PGExecutionContext): def create_server_side_cursor(self): # use server-side cursors: @@ -91,6 +107,8 @@ class _PGDialect_common_psycopg(PGDialect): sqltypes.Numeric: _PsycopgNumeric, HSTORE: _PsycopgHStore, sqltypes.ARRAY: _PsycopgARRAY, + INT2VECTOR: _PsycopgINT2VECTOR, + OIDVECTOR: _PsycopgOIDVECTOR, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 1ec787e1f..d6385a5d6 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -274,6 +274,10 @@ class AsyncpgOID(OID): render_bind_cast = True +class AsyncpgCHAR(sqltypes.CHAR): + render_bind_cast = True + + class PGExecutionContext_asyncpg(PGExecutionContext): def handle_dbapi_exception(self, e): if isinstance( @@ -823,6 +827,7 @@ class PGDialect_asyncpg(PGDialect): sqltypes.Enum: AsyncPgEnum, OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, + sqltypes.CHAR: AsyncpgCHAR, }, ) is_async = True diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 0de8a9c44..8402341f6 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -11,7 +11,7 @@ r""" :name: PostgreSQL :full_support: 9.6, 10, 11, 12, 13, 14 :normal_support: 9.6+ - :best_effort: 8+ + :best_effort: 9+ .. _postgresql_sequences: @@ -1448,23 +1448,52 @@ E.g.:: from __future__ import annotations from collections import defaultdict -import datetime as dt +from functools import lru_cache import re -from typing import Any from . import array as _array from . import dml from . import hstore as _hstore from . import json as _json +from . import pg_catalog from . import ranges as _ranges +from .types import _DECIMAL_TYPES # noqa +from .types import _FLOAT_TYPES # noqa +from .types import _INT_TYPES # noqa +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CreateEnumType # noqa +from .types import DropEnumType # noqa +from .types import ENUM +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MONEY +from .types import OID +from .types import PGBit # noqa +from .types import PGCidr # noqa +from .types import PGInet # noqa +from .types import PGInterval # noqa +from .types import PGMacAddr # noqa +from .types import PGUuid +from .types import REGCLASS +from .types import TIME +from .types import TIMESTAMP +from .types import TSVECTOR from ... import exc from ... import schema +from ... import select from ... import sql from ... 
import util from ...engine import characteristics from ...engine import default from ...engine import interfaces +from ...engine import ObjectKind +from ...engine import ObjectScope from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import bindparam from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -1472,7 +1501,7 @@ from ...sql import expression from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util -from ...sql.ddl import InvokeDDLBase +from ...sql.visitors import InternalTraversal from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -1596,469 +1625,6 @@ RESERVED_WORDS = set( ] ) -_DECIMAL_TYPES = (1231, 1700) -_FLOAT_TYPES = (700, 701, 1021, 1022) -_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) - - -class PGUuid(UUID): - render_bind_cast = True - render_literal_cast = True - - -class BYTEA(sqltypes.LargeBinary[bytes]): - __visit_name__ = "BYTEA" - - -class INET(sqltypes.TypeEngine[str]): - __visit_name__ = "INET" - - -PGInet = INET - - -class CIDR(sqltypes.TypeEngine[str]): - __visit_name__ = "CIDR" - - -PGCidr = CIDR - - -class MACADDR(sqltypes.TypeEngine[str]): - __visit_name__ = "MACADDR" - - -PGMacAddr = MACADDR - - -class MONEY(sqltypes.TypeEngine[str]): - - r"""Provide the PostgreSQL MONEY type. - - Depending on driver, result rows using this type may return a - string value which includes currency symbols. - - For this reason, it may be preferable to provide conversion to a - numerically-based currency datatype using :class:`_types.TypeDecorator`:: - - import re - import decimal - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def process_result_value(self, value: Any, dialect: Any) -> None: - if value is not None: - # adjust this for the currency and numeric - m = re.match(r"\$([\d.]+)", value) - if m: - value = decimal.Decimal(m.group(1)) - return value - - Alternatively, the conversion may be applied as a CAST using - the :meth:`_types.TypeDecorator.column_expression` method as follows:: - - import decimal - from sqlalchemy import cast - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def column_expression(self, column: Any): - return cast(column, Numeric()) - - .. versionadded:: 1.2 - - """ - - __visit_name__ = "MONEY" - - -class OID(sqltypes.TypeEngine[int]): - - """Provide the PostgreSQL OID type. - - .. versionadded:: 0.9.5 - - """ - - __visit_name__ = "OID" - - -class REGCLASS(sqltypes.TypeEngine[str]): - - """Provide the PostgreSQL REGCLASS type. - - .. versionadded:: 1.2.7 - - """ - - __visit_name__ = "REGCLASS" - - -class TIMESTAMP(sqltypes.TIMESTAMP): - def __init__(self, timezone=False, precision=None): - super(TIMESTAMP, self).__init__(timezone=timezone) - self.precision = precision - - -class TIME(sqltypes.TIME): - def __init__(self, timezone=False, precision=None): - super(TIME, self).__init__(timezone=timezone) - self.precision = precision - - -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): - - """PostgreSQL INTERVAL type.""" - - __visit_name__ = "INTERVAL" - native = True - - def __init__(self, precision=None, fields=None): - """Construct an INTERVAL. - - :param precision: optional integer precision value - :param fields: string fields specifier. allows storage of fields - to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, - etc. - - .. 
versionadded:: 1.2 - - """ - self.precision = precision - self.fields = fields - - @classmethod - def adapt_emulated_to_native(cls, interval, **kw): - return INTERVAL(precision=interval.second_precision) - - @property - def _type_affinity(self): - return sqltypes.Interval - - def as_generic(self, allow_nulltype=False): - return sqltypes.Interval(native=True, second_precision=self.precision) - - @property - def python_type(self): - return dt.timedelta - - -PGInterval = INTERVAL - - -class BIT(sqltypes.TypeEngine[int]): - __visit_name__ = "BIT" - - def __init__(self, length=None, varying=False): - if not varying: - # BIT without VARYING defaults to length 1 - self.length = length or 1 - else: - # but BIT VARYING can be unlimited-length, so no default - self.length = length - self.varying = varying - - -PGBit = BIT - - -class TSVECTOR(sqltypes.TypeEngine[Any]): - - """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL - text search type TSVECTOR. - - It can be used to do full text queries on natural language - documents. - - .. versionadded:: 0.9.0 - - .. seealso:: - - :ref:`postgresql_match` - - """ - - __visit_name__ = "TSVECTOR" - - -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): - - """PostgreSQL ENUM type. - - This is a subclass of :class:`_types.Enum` which includes - support for PG's ``CREATE TYPE`` and ``DROP TYPE``. - - When the builtin type :class:`_types.Enum` is used and the - :paramref:`.Enum.native_enum` flag is left at its default of - True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` - type as the implementation, so the special create/drop rules - will be used. - - The create/drop behavior of ENUM is necessarily intricate, due to the - awkward relationship the ENUM type has in relationship to the - parent table, in that it may be "owned" by just a single table, or - may be shared among many tables. - - When using :class:`_types.Enum` or :class:`_postgresql.ENUM` - in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted - corresponding to when the :meth:`_schema.Table.create` and - :meth:`_schema.Table.drop` - methods are called:: - - table = Table('sometable', metadata, - Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) - ) - - table.create(engine) # will emit CREATE ENUM and CREATE TABLE - table.drop(engine) # will emit DROP TABLE and DROP ENUM - - To use a common enumerated type between multiple tables, the best - practice is to declare the :class:`_types.Enum` or - :class:`_postgresql.ENUM` independently, and associate it with the - :class:`_schema.MetaData` object itself:: - - my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) - - t1 = Table('sometable_one', metadata, - Column('some_enum', myenum) - ) - - t2 = Table('sometable_two', metadata, - Column('some_enum', myenum) - ) - - When this pattern is used, care must still be taken at the level - of individual table creates. 
Emitting CREATE TABLE without also - specifying ``checkfirst=True`` will still cause issues:: - - t1.create(engine) # will fail: no such type 'myenum' - - If we specify ``checkfirst=True``, the individual table-level create - operation will check for the ``ENUM`` and create if not exists:: - - # will check if enum exists, and emit CREATE TYPE if not - t1.create(engine, checkfirst=True) - - When using a metadata-level ENUM type, the type will always be created - and dropped if either the metadata-wide create/drop is called:: - - metadata.create_all(engine) # will emit CREATE TYPE - metadata.drop_all(engine) # will emit DROP TYPE - - The type can also be created and dropped directly:: - - my_enum.create(engine) - my_enum.drop(engine) - - .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type - now behaves more strictly with regards to CREATE/DROP. A metadata-level - ENUM type will only be created and dropped at the metadata level, - not the table level, with the exception of - ``table.create(checkfirst=True)``. - The ``table.drop()`` call will now emit a DROP TYPE for a table-level - enumerated type. - - """ - - native_enum = True - - def __init__(self, *enums, **kw): - """Construct an :class:`_postgresql.ENUM`. - - Arguments are the same as that of - :class:`_types.Enum`, but also including - the following parameters. - - :param create_type: Defaults to True. - Indicates that ``CREATE TYPE`` should be - emitted, after optionally checking for the - presence of the type, when the parent - table is being created; and additionally - that ``DROP TYPE`` is called when the table - is dropped. When ``False``, no check - will be performed and no ``CREATE TYPE`` - or ``DROP TYPE`` is emitted, unless - :meth:`~.postgresql.ENUM.create` - or :meth:`~.postgresql.ENUM.drop` - are called directly. - Setting to ``False`` is helpful - when invoking a creation scheme to a SQL file - without access to the actual database - - the :meth:`~.postgresql.ENUM.create` and - :meth:`~.postgresql.ENUM.drop` methods can - be used to emit SQL to a target bind. - - """ - native_enum = kw.pop("native_enum", None) - if native_enum is False: - util.warn( - "the native_enum flag does not apply to the " - "sqlalchemy.dialects.postgresql.ENUM datatype; this type " - "always refers to ENUM. Use sqlalchemy.types.Enum for " - "non-native enum." - ) - self.create_type = kw.pop("create_type", True) - super(ENUM, self).__init__(*enums, **kw) - - @classmethod - def adapt_emulated_to_native(cls, impl, **kw): - """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain - :class:`.Enum`. - - """ - kw.setdefault("validate_strings", impl.validate_strings) - kw.setdefault("name", impl.name) - kw.setdefault("schema", impl.schema) - kw.setdefault("inherit_schema", impl.inherit_schema) - kw.setdefault("metadata", impl.metadata) - kw.setdefault("_create_events", False) - kw.setdefault("values_callable", impl.values_callable) - kw.setdefault("omit_aliases", impl._omit_aliases) - return cls(**kw) - - def create(self, bind=None, checkfirst=True): - """Emit ``CREATE TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL CREATE TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type does not exist already before - creating. 
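As a usage note for the create()/drop() methods documented above, the following is a minimal sketch of driving CREATE TYPE / DROP TYPE by hand; the connection URL and enum name are hypothetical, and checkfirst=True makes each call query the PG catalog first.

from sqlalchemy import create_engine
from sqlalchemy.dialects.postgresql import ENUM

# hypothetical connection URL; any reachable PostgreSQL database will do
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

# a standalone enum type, not tied to a particular Table
mood = ENUM("happy", "sad", "neutral", name="mood_enum")

with engine.begin() as conn:
    # emits CREATE TYPE only if the type is not already present
    mood.create(conn, checkfirst=True)

    # ... tables using the type would be created and used here ...

    # emits DROP TYPE only if the type actually exists
    mood.drop(conn, checkfirst=True)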
- - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) - - def drop(self, bind=None, checkfirst=True): - """Emit ``DROP TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL DROP TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type actually exists before dropping. - - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) - - class EnumGenerator(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_create_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return not self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_create_enum(enum): - return - - self.connection.execute(CreateEnumType(enum)) - - class EnumDropper(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_drop_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_drop_enum(enum): - return - - self.connection.execute(DropEnumType(enum)) - - def get_dbapi_type(self, dbapi): - """dont return dbapi.STRING for ENUM in PostgreSQL, since that's - a different type""" - - return None - - def _check_for_name_in_memos(self, checkfirst, kw): - """Look in the 'ddl runner' for 'memos', then - note our name in that collection. - - This to ensure a particular named enum is operated - upon only once within any kind of create/drop - sequence without relying upon "checkfirst". 
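The memo bookkeeping described above is what allows a single named enum shared by several tables to be created exactly once during a metadata-wide create; a short sketch follows, again with a hypothetical URL and table names.

from sqlalchemy import Column, Integer, MetaData, Table, create_engine
from sqlalchemy.dialects.postgresql import ENUM

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
metadata = MetaData()

# one enum object shared by two tables
status = ENUM("new", "open", "closed", name="status_enum", metadata=metadata)

t1 = Table(
    "tickets", metadata,
    Column("id", Integer, primary_key=True),
    Column("status", status),
)
t2 = Table(
    "tasks", metadata,
    Column("id", Integer, primary_key=True),
    Column("status", status),
)

# emits CREATE TYPE status_enum once, then the two CREATE TABLE statements
metadata.create_all(engine)

# emits the two DROP TABLE statements, then DROP TYPE status_enum once
metadata.drop_all(engine)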
- - """ - if not self.create_type: - return True - if "_ddl_runner" in kw: - ddl_runner = kw["_ddl_runner"] - if "_pg_enums" in ddl_runner.memo: - pg_enums = ddl_runner.memo["_pg_enums"] - else: - pg_enums = ddl_runner.memo["_pg_enums"] = set() - present = (self.schema, self.name) in pg_enums - pg_enums.add((self.schema, self.name)) - return present - else: - return False - - def _on_table_create(self, target, bind, checkfirst=False, **kw): - if ( - checkfirst - or ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - ) - ) and not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_table_drop(self, target, bind, checkfirst=False, **kw): - if ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - and not self._check_for_name_in_memos(checkfirst, kw) - ): - self.drop(bind=bind, checkfirst=checkfirst) - - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.drop(bind=bind, checkfirst=checkfirst) - - colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -2997,8 +2563,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): class PGInspector(reflection.Inspector): + dialect: PGDialect + def get_table_oid(self, table_name, schema=None): - """Return the OID for the given table name.""" + """Return the OID for the given table name. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ with self._operation_context() as conn: return self.dialect.get_table_oid( @@ -3023,9 +2600,10 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._load_enums(conn, schema) + return self.dialect._load_enums( + conn, schema, info_cache=self.info_cache + ) def get_foreign_table_names(self, schema=None): """Return a list of FOREIGN TABLE names. @@ -3038,38 +2616,29 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._get_foreign_table_names(conn, schema) - - def get_view_names(self, schema=None, include=("plain", "materialized")): - """Return all view names in `schema`. + return self.dialect._get_foreign_table_names( + conn, schema, info_cache=self.info_cache + ) - :param schema: Optional, retrieve names from a non-default schema. - For special quoting, use :class:`.quoted_name`. + def has_type(self, type_name, schema=None, **kw): + """Return if the database has the specified type in the provided + schema. - :param include: specify which types of views to return. Passed - as a string value (for a single type) or a tuple (for any number - of types). Defaults to ``('plain', 'materialized')``. + :param type_name: the type to check. + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to '*' to + check in all schemas. - .. versionadded:: 1.1 + .. 
versionadded:: 2.0 """ - with self._operation_context() as conn: - return self.dialect.get_view_names( - conn, schema, info_cache=self.info_cache, include=include + return self.dialect.has_type( + conn, type_name, schema, info_cache=self.info_cache ) -class CreateEnumType(schema._CreateDropBase): - __visit_name__ = "create_enum_type" - - -class DropEnumType(schema._CreateDropBase): - __visit_name__ = "drop_enum_type" - - class PGExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( @@ -3262,35 +2831,14 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - if self.server_version_info <= (8, 2): - self.delete_returning = ( - self.update_returning - ) = self.insert_returning = False - - self.supports_native_enum = self.server_version_info >= (8, 3) - if not self.supports_native_enum: - self.colspecs = self.colspecs.copy() - # pop base Enum type - self.colspecs.pop(sqltypes.Enum, None) - # psycopg2, others may have placed ENUM here as well - self.colspecs.pop(ENUM, None) - # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) - if self.server_version_info < (8, 2): - self._backslash_escapes = False - else: - # ensure this query is not emitted on server version < 8.2 - # as it will fail - std_string = connection.exec_driver_sql( - "show standard_conforming_strings" - ).scalar() - self._backslash_escapes = std_string == "off" - - self._supports_create_index_concurrently = ( - self.server_version_info >= (8, 2) - ) + std_string = connection.exec_driver_sql( + "show standard_conforming_strings" + ).scalar() + self._backslash_escapes = std_string == "off" + self._supports_drop_index_concurrently = self.server_version_info >= ( 9, 2, @@ -3370,122 +2918,100 @@ class PGDialect(default.DefaultDialect): self.do_commit(connection.connection) def do_recover_twophase(self, connection): - resultset = connection.execute( + return connection.scalars( sql.text("SELECT gid FROM pg_prepared_xacts") - ) - return [row[0] for row in resultset] + ).all() def _get_default_schema_name(self, connection): return connection.exec_driver_sql("select current_schema()").scalar() - def has_schema(self, connection, schema): - query = ( - "select nspname from pg_namespace " "where lower(nspname)=:schema" - ) - cursor = connection.execute( - sql.text(query).bindparams( - sql.bindparam( - "schema", - str(schema.lower()), - type_=sqltypes.Unicode, - ) - ) + @reflection.cache + def has_schema(self, connection, schema, **kw): + query = select(pg_catalog.pg_namespace.c.nspname).where( + pg_catalog.pg_namespace.c.nspname == schema ) + return bool(connection.scalar(query)) - return bool(cursor.first()) - - def has_table(self, connection, table_name, schema=None): - self._ensure_has_table_connection(connection) - # seems like case gets folded in pg_class... 
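The rewritten has_schema() above replaces the textual pg_namespace query with a Core select against the pg_catalog table definitions; the sketch below reproduces the same check with an ad-hoc Table object and an exact (case-sensitive) name comparison, assuming a hypothetical connection URL.

from sqlalchemy import Column, MetaData, String, Table, create_engine, select

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

# ad-hoc definition of pg_catalog.pg_namespace, only the column we need
pg_namespace = Table(
    "pg_namespace",
    MetaData(),
    Column("nspname", String),
    schema="pg_catalog",
)

def has_schema(connection, schema_name):
    query = select(pg_namespace.c.nspname).where(
        pg_namespace.c.nspname == schema_name
    )
    return bool(connection.scalar(query))

with engine.connect() as conn:
    print(has_schema(conn, "public"))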
+ def _pg_class_filter_scope_schema( + self, query, schema, scope, pg_class_table=None + ): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + query = query.join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace, + ) + if scope is ObjectScope.DEFAULT: + query = query.where(pg_class_table.c.relpersistence != "t") + elif scope is ObjectScope.TEMPORARY: + query = query.where(pg_class_table.c.relpersistence == "t") if schema is None: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where " - "pg_catalog.pg_table_is_visible(c.oid) " - "and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ) - ) + query = query.where( + pg_catalog.pg_table_is_visible(pg_class_table.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", ) else: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=:schema and " - "relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) - ) - return bool(cursor.first()) - - def has_sequence(self, connection, sequence_name, schema=None): - if schema is None: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(sequence_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + def _pg_class_relkind_condition(self, relkinds, pg_class_table=None): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + # uses the any form instead of in otherwise postgresql complaings + # that 'IN could not convert type character to "char"' + return pg_class_table.c.relkind == sql.any_(_array.array(relkinds)) + + @lru_cache() + def _has_table_query(self, schema): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relname == bindparam("table_name"), + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + return self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY ) - return bool(cursor.first()) + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + query = self._has_table_query(schema) + return bool(connection.scalar(query, {"table_name": table_name})) - def has_type(self, connection, type_name, schema=None): - if schema is not None: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n - WHERE t.typnamespace = n.oid - AND t.typname = :typname - AND n.nspname = :nspname - ) - """ - query = sql.text(query) - else: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t - WHERE t.typname = :typname - AND pg_type_is_visible(t.oid) - ) - """ - query = sql.text(query) - query = query.bindparams( - sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode) + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + query = 
select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relkind == "S", + pg_catalog.pg_class.c.relname == sequence_name, ) - if schema is not None: - query = query.bindparams( - sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + return bool(connection.scalar(query)) + + @reflection.cache + def has_type(self, connection, type_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_type.c.typname) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, ) - cursor = connection.execute(query) - return bool(cursor.scalar()) + .where(pg_catalog.pg_type.c.typname == type_name) + ) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + + return bool(connection.scalar(query)) def _get_server_version_info(self, connection): v = connection.exec_driver_sql("select pg_catalog.version()").scalar() @@ -3502,229 +3028,300 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_table_oid(self, connection, table_name, schema=None, **kw): - """Fetch the oid for schema.table_name. - - Several reflection methods require the table oid. The idea for using - this method is that it can be fetched one time and cached for - subsequent calls. - - """ - table_oid = None - if schema is not None: - schema_where_clause = "n.nspname = :schema" - else: - schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - query = ( - """ - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in - ('r', 'v', 'm', 'f', 'p') - """ - % schema_where_clause + """Fetch the oid for schema.table_name.""" + query = select(pg_catalog.pg_class.c.oid).where( + pg_catalog.pg_class.c.relname == table_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), ) - # Since we're binding to unicode, table_name and schema_name must be - # unicode. 
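For reference, the has_type() query built above corresponds roughly to the textual SQL in this sketch; the URL and type name are hypothetical, and passing schema=None limits the check to types visible on the search path, mirroring the pg_type_is_visible() branch.

from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

def type_exists(connection, type_name, schema=None):
    if schema is None:
        stmt = text(
            "SELECT 1 FROM pg_catalog.pg_type t "
            "JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace "
            "WHERE t.typname = :name "
            "AND pg_type_is_visible(t.oid) "
            "AND n.nspname != 'pg_catalog'"
        )
        params = {"name": type_name}
    else:
        stmt = text(
            "SELECT 1 FROM pg_catalog.pg_type t "
            "JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace "
            "WHERE t.typname = :name AND n.nspname = :schema"
        )
        params = {"name": type_name, "schema": schema}
    return connection.scalar(stmt, params) is not None

with engine.connect() as conn:
    print(type_exists(conn, "mood_enum"))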
- table_name = str(table_name) - if schema is not None: - schema = str(schema) - s = sql.text(query).bindparams(table_name=sqltypes.Unicode) - s = s.columns(oid=sqltypes.Integer) - if schema: - s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode)) - c = connection.execute(s, dict(table_name=table_name, schema=schema)) - table_oid = c.scalar() + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + table_oid = connection.scalar(query) if table_oid is None: - raise exc.NoSuchTableError(table_name) + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) return table_oid @reflection.cache def get_schema_names(self, connection, **kw): - result = connection.execute( - sql.text( - "SELECT nspname FROM pg_namespace " - "WHERE nspname NOT LIKE 'pg_%' " - "ORDER BY nspname" - ).columns(nspname=sqltypes.Unicode) + query = ( + select(pg_catalog.pg_namespace.c.nspname) + .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%")) + .order_by(pg_catalog.pg_namespace.c.nspname) + ) + return connection.scalars(query).all() + + def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope): + query = select(pg_catalog.pg_class.c.relname).where( + self._pg_class_relkind_condition(relkinds) ) - return [name for name, in result] + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + return connection.scalars(query).all() @reflection.cache def get_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + return self._get_relnames_for_relkinds( + connection, + schema=None, + relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def _get_foreign_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind = 'f'" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("f",), scope=ObjectScope.ANY ) - return [name for name, in result] @reflection.cache - def get_view_names( - self, connection, schema=None, include=("plain", "materialized"), **kw - ): + def get_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.DEFAULT, + ) - include_kind = {"plain": "v", "materialized": "m"} - try: - kinds = [include_kind[i] for i in util.to_list(include)] - except KeyError: - raise ValueError( - "include %r unknown, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" % (include,) - ) - if not kinds: - raise ValueError( - "empty include, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" - ) + @reflection.cache + def get_materialized_view_names(self, connection, 
schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_MAT_VIEW, + scope=ObjectScope.DEFAULT, + ) - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind IN (%s)" - % (", ".join("'%s'" % elem for elem in kinds)) - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + @reflection.cache + def get_temp_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + # NOTE: do not include temp materialzied views (that do not + # seem to be a thing at least up to version 14) + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def get_sequence_names(self, connection, schema=None, **kw): - if not schema: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema" - ).bindparams( - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("S",), scope=ObjectScope.ANY ) - return [row[0] for row in cursor] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): - view_def = connection.scalar( - sql.text( - "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relname = :view_name " - "AND c.relkind IN ('v', 'm')" - ).columns(view_def=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name, - view_name=view_name, - ), + query = ( + select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid)) + .select_from(pg_catalog.pg_class) + .where( + pg_catalog.pg_class.c.relname == view_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW + ), + ) ) - return view_def + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + res = connection.scalar(query) + if res is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return res + + def _value_or_raise(self, data, table, schema): + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + return True, {"filter_names": filter_names} + else: + return False, {} + + def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]: + if kind is ObjectKind.ANY: + return pg_catalog.RELKINDS_ALL_TABLE_LIKE + relkinds = () + if ObjectKind.TABLE in kind: + relkinds += pg_catalog.RELKINDS_TABLE + if ObjectKind.VIEW in kind: + relkinds += pg_catalog.RELKINDS_VIEW + if ObjectKind.MATERIALIZED_VIEW in kind: + relkinds += pg_catalog.RELKINDS_MAT_VIEW + return relkinds @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, 
table_name, schema) + @lru_cache() + def _columns_query(self, schema, has_filter_names, scope, kind): + # NOTE: the query with the default and identity options scalar + # subquery is faster than trying to use outer joins for them generated = ( - "a.attgenerated as generated" + pg_catalog.pg_attribute.c.attgenerated.label("generated") if self.server_version_info >= (12,) - else "NULL as generated" + else sql.null().label("generated") ) if self.server_version_info >= (10,): - # a.attidentity != '' is required or it will reflect also - # serial columns as identity. - identity = """\ - (SELECT json_build_object( - 'always', a.attidentity = 'a', - 'start', s.seqstart, - 'increment', s.seqincrement, - 'minvalue', s.seqmin, - 'maxvalue', s.seqmax, - 'cache', s.seqcache, - 'cycle', s.seqcycle) - FROM pg_catalog.pg_sequence s - JOIN pg_catalog.pg_class c on s.seqrelid = c."oid" - WHERE c.relkind = 'S' - AND a.attidentity != '' - AND s.seqrelid = pg_catalog.pg_get_serial_sequence( - a.attrelid::regclass::text, a.attname - )::regclass::oid - ) as identity_options\ - """ + # join lateral performs worse (~2x slower) than a scalar_subquery + identity = ( + select( + sql.func.json_build_object( + "always", + pg_catalog.pg_attribute.c.attidentity == "a", + "start", + pg_catalog.pg_sequence.c.seqstart, + "increment", + pg_catalog.pg_sequence.c.seqincrement, + "minvalue", + pg_catalog.pg_sequence.c.seqmin, + "maxvalue", + pg_catalog.pg_sequence.c.seqmax, + "cache", + pg_catalog.pg_sequence.c.seqcache, + "cycle", + pg_catalog.pg_sequence.c.seqcycle, + ) + ) + .select_from(pg_catalog.pg_sequence) + .where( + # attidentity != '' is required or it will reflect also + # serial columns as identity. + pg_catalog.pg_attribute.c.attidentity != "", + pg_catalog.pg_sequence.c.seqrelid + == sql.cast( + sql.cast( + pg_catalog.pg_get_serial_sequence( + sql.cast( + sql.cast( + pg_catalog.pg_attribute.c.attrelid, + REGCLASS, + ), + TEXT, + ), + pg_catalog.pg_attribute.c.attname, + ), + REGCLASS, + ), + OID, + ), + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("identity_options") + ) else: - identity = "NULL as identity_options" - - SQL_COLS = """ - SELECT a.attname, - pg_catalog.format_type(a.atttypid, a.atttypmod), - ( - SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) - FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum - AND a.atthasdef - ) AS DEFAULT, - a.attnotnull, - a.attrelid as table_oid, - pgd.description as comment, - %s, - %s - FROM pg_catalog.pg_attribute a - LEFT JOIN pg_catalog.pg_description pgd ON ( - pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum) - WHERE a.attrelid = :table_oid - AND a.attnum > 0 AND NOT a.attisdropped - ORDER BY a.attnum - """ % ( - generated, - identity, + identity = sql.null().label("identity_options") + + # join lateral performs the same as scalar_subquery here + default = ( + select( + pg_catalog.pg_get_expr( + pg_catalog.pg_attrdef.c.adbin, + pg_catalog.pg_attrdef.c.adrelid, + ) + ) + .select_from(pg_catalog.pg_attrdef) + .where( + pg_catalog.pg_attrdef.c.adrelid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attrdef.c.adnum + == pg_catalog.pg_attribute.c.attnum, + pg_catalog.pg_attribute.c.atthasdef, + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("default") ) - s = ( - sql.text(SQL_COLS) - .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer)) - .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode) + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + 
pg_catalog.pg_attribute.c.attname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_attribute.c.atttypid, + pg_catalog.pg_attribute.c.atttypmod, + ).label("format_type"), + default, + pg_catalog.pg_attribute.c.attnotnull.label("not_null"), + pg_catalog.pg_class.c.relname.label("table_name"), + pg_catalog.pg_description.c.description.label("comment"), + generated, + identity, + ) + .select_from(pg_catalog.pg_class) + # NOTE: postgresql support table with no user column, meaning + # there is no row with pg_attribute.attnum > 0. use a left outer + # join to avoid filtering these tables. + .outerjoin( + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attribute.c.attnum > 0, + ~pg_catalog.pg_attribute.c.attisdropped, + ), + ) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_description.c.objsubid + == pg_catalog.pg_attribute.c.attnum, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + .order_by( + pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum + ) ) - c = connection.execute(s, dict(table_oid=table_oid)) - rows = c.fetchall() + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_columns( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._columns_query(schema, has_filter_names, scope, kind) + rows = connection.execute(query, params).mappings() # dictionary with (name, ) if default search path or (schema, name) # as keys - domains = self._load_domains(connection) + domains = self._load_domains( + connection, info_cache=kw.get("info_cache") + ) # dictionary with (name, ) if default search path or (schema, name) # as keys @@ -3732,257 +3329,340 @@ class PGDialect(default.DefaultDialect): ((rec["name"],), rec) if rec["visible"] else ((rec["schema"], rec["name"]), rec) - for rec in self._load_enums(connection, schema="*") + for rec in self._load_enums( + connection, schema="*", info_cache=kw.get("info_cache") + ) ) - # format columns - columns = [] - - for ( - name, - format_type, - default_, - notnull, - table_oid, - comment, - generated, - identity, - ) in rows: - column_info = self._get_column_info( - name, - format_type, - default_, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ) - columns.append(column_info) - return columns + columns = self._get_columns_info(rows, domains, enums, schema) + + return columns.items() + + def _get_columns_info(self, rows, domains, enums, schema): + array_type_pattern = re.compile(r"\[\]$") + attype_pattern = re.compile(r"\(.*\)") + charlen_pattern = re.compile(r"\(([\d,]+)\)") + args_pattern = re.compile(r"\((.*)\)") + args_split_pattern = re.compile(r"\s*,\s*") - def _get_column_info( - self, - name, - format_type, - default, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ): def _handle_array_type(attype): return ( # strip '[]' from integer[], etc. - re.sub(r"\[\]$", "", attype), + array_type_pattern.sub("", attype), attype.endswith("[]"), ) - # strip (*) from character varying(5), timestamp(5) - # with time zone, geometry(POLYGON), etc. 
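The compiled regular expressions above drive the parsing of pg_catalog.format_type() output into a base type name, length/precision arguments, and an array flag; the following self-contained sketch applies the same patterns to a made-up sample string.

import re

array_type_pattern = re.compile(r"\[\]$")
attype_pattern = re.compile(r"\(.*\)")
charlen_pattern = re.compile(r"\(([\d,]+)\)")

def parse_format_type(format_type):
    # strip "(...)" from e.g. "character varying(50)"
    attype = attype_pattern.sub("", format_type)
    # detect and strip a trailing "[]" array marker
    is_array = attype.endswith("[]")
    attype = array_type_pattern.sub("", attype)
    # pull out the length / precision digits, if any
    charlen = charlen_pattern.search(format_type)
    args = charlen.group(1) if charlen else None
    return attype, args, is_array

print(parse_format_type("character varying(50)[]"))
# ('character varying', '50', True)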
- attype = re.sub(r"\(.*\)", "", format_type) + columns = defaultdict(list) + for row_dict in rows: + # ensure that each table has an entry, even if it has no columns + if row_dict["name"] is None: + columns[ + (schema, row_dict["table_name"]) + ] = ReflectionDefaults.columns() + continue + table_cols = columns[(schema, row_dict["table_name"])] - # strip '[]' from integer[], etc. and check if an array - attype, is_array = _handle_array_type(attype) + format_type = row_dict["format_type"] + default = row_dict["default"] + name = row_dict["name"] + generated = row_dict["generated"] + identity = row_dict["identity_options"] - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = attype_pattern.sub("", format_type) - nullable = not notnull + # strip '[]' from integer[], etc. and check if an array + attype, is_array = _handle_array_type(attype) - charlen = re.search(r"\(([\d,]+)\)", format_type) - if charlen: - charlen = charlen.group(1) - args = re.search(r"\((.*)\)", format_type) - if args and args.group(1): - args = tuple(re.split(r"\s*,\s*", args.group(1))) - else: - args = () - kwargs = {} + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + nullable = not row_dict["not_null"] - if attype == "numeric": + charlen = charlen_pattern.search(format_type) if charlen: - prec, scale = charlen.split(",") - args = (int(prec), int(scale)) + charlen = charlen.group(1) + args = args_pattern.search(format_type) + if args and args.group(1): + args = tuple(args_split_pattern.split(args.group(1))) else: args = () - elif attype == "double precision": - args = (53,) - elif attype == "integer": - args = () - elif attype in ("timestamp with time zone", "time with time zone"): - kwargs["timezone"] = True - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype in ( - "timestamp without time zone", - "time without time zone", - "time", - ): - kwargs["timezone"] = False - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype == "bit varying": - kwargs["varying"] = True - if charlen: + kwargs = {} + + if attype == "numeric": + if charlen: + prec, scale = charlen.split(",") + args = (int(prec), int(scale)) + else: + args = () + elif attype == "double precision": + args = (53,) + elif attype == "integer": + args = () + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype == "bit varying": + kwargs["varying"] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) + if charlen: + kwargs["precision"] = int(charlen) + if field_match: + kwargs["fields"] = field_match.group(1) + attype = "interval" + args = () + elif charlen: args = (int(charlen),) + + while True: + # looping here to suit nested domains + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif enum_or_domain_key in enums: + enum = enums[enum_or_domain_key] + coltype = ENUM + kwargs["name"] = enum["name"] + if not enum["visible"]: + 
kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + break + elif enum_or_domain_key in domains: + domain = domains[enum_or_domain_key] + attype = domain["attype"] + attype, is_array = _handle_array_type(attype) + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple( + util.quoted_token_parser(attype) + ) + # A table can't override a not null on the domain, + # but can override nullable + nullable = nullable and domain["nullable"] + if domain["default"] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain["default"] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names["_array"](coltype) else: - args = () - elif attype.startswith("interval"): - field_match = re.match(r"interval (.+)", attype, re.I) - if charlen: - kwargs["precision"] = int(charlen) - if field_match: - kwargs["fields"] = field_match.group(1) - attype = "interval" - args = () - elif charlen: - args = (int(charlen),) - - while True: - # looping here to suit nested domains - if attype in self.ischema_names: - coltype = self.ischema_names[attype] - break - elif enum_or_domain_key in enums: - enum = enums[enum_or_domain_key] - coltype = ENUM - kwargs["name"] = enum["name"] - if not enum["visible"]: - kwargs["schema"] = enum["schema"] - args = tuple(enum["labels"]) - break - elif enum_or_domain_key in domains: - domain = domains[enum_or_domain_key] - attype = domain["attype"] - attype, is_array = _handle_array_type(attype) - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) - # A table can't override a not null on the domain, - # but can override nullable - nullable = nullable and domain["nullable"] - if domain["default"] and not default: - # It can, however, override the default - # value, but can't set it to null. - default = domain["default"] - continue + util.warn( + "Did not recognize type '%s' of column '%s'" + % (attype, name) + ) + coltype = sqltypes.NULLTYPE + + # If a zero byte or blank string depending on driver (is also + # absent for older PG versions), then not a generated column. + # Otherwise, s = stored. (Other values might be added in the + # future.) + if generated not in (None, "", b"\x00"): + computed = dict( + sqltext=default, persisted=generated in ("s", b"s") + ) + default = None else: - coltype = None - break + computed = None - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = self.ischema_names["_array"](coltype) - else: - util.warn( - "Did not recognize type '%s' of column '%s'" % (attype, name) + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + if "." not in match.group(2) and schema is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = ( + match.group(1) + + ('"%s"' % schema) + + "." 
+ + match.group(2) + + match.group(3) + ) + + column_info = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": autoincrement or identity is not None, + "comment": row_dict["comment"], + } + if computed is not None: + column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity + + table_cols.append(column_info) + + return columns + + @lru_cache() + def _table_oids_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + oid_q = select( + pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname + ).where(self._pg_class_relkind_condition(relkinds)) + oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope) + + if has_filter_names: + oid_q = oid_q.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) ) - coltype = sqltypes.NULLTYPE - - # If a zero byte or blank string depending on driver (is also absent - # for older PG versions), then not a generated column. Otherwise, s = - # stored. (Other values might be added in the future.) - if generated not in (None, "", b"\x00"): - computed = dict( - sqltext=default, persisted=generated in ("s", b"s") + return oid_q + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("kind", InternalTraversal.dp_plain_obj), + ("scope", InternalTraversal.dp_plain_obj), + ) + def _get_table_oids( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + oid_q = self._table_oids_query(schema, has_filter_names, scope, kind) + result = connection.execute(oid_q, params) + return result.all() + + @util.memoized_property + def _constraint_query(self): + con_sq = ( + select( + pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.conname, + sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( + "attnum" + ), + sql.func.generate_subscripts( + pg_catalog.pg_constraint.c.conkey, 1 + ).label("ord"), ) - default = None - else: - computed = None - - # adjust the default value - autoincrement = False - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - if issubclass(coltype._type_affinity, sqltypes.Integer): - autoincrement = True - # the default is related to a Sequence - sch = schema - if "." not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / - # "quote schema" - default = ( - match.group(1) - + ('"%s"' % sch) - + "." 
- + match.group(2) - + match.group(3) - ) + .where( + pg_catalog.pg_constraint.c.contype == bindparam("contype"), + pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), + ) + .subquery("con") + ) - column_info = dict( - name=name, - type=coltype, - nullable=nullable, - default=default, - autoincrement=autoincrement or identity is not None, - comment=comment, + attr_sq = ( + select( + con_sq.c.conrelid, + con_sq.c.conname, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + con_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid, + ), + ) + .order_by(con_sq.c.conname, con_sq.c.ord) + .subquery("attr") ) - if computed is not None: - column_info["computed"] = computed - if identity is not None: - column_info["identity"] = identity - return column_info - @reflection.cache - def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + return ( + select( + attr_sq.c.conrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + attr_sq.c.conname, + ) + .group_by(attr_sq.c.conrelid, attr_sq.c.conname) + .order_by(attr_sq.c.conrelid, attr_sq.c.conname) ) - if self.server_version_info < (8, 4): - PK_SQL = """ - SELECT a.attname - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_attribute a - on t.oid=a.attrelid AND %s - WHERE - t.oid = :table_oid and ix.indisprimary = 't' - ORDER BY a.attnum - """ % self._pg_index_any( - "a.attnum", "ix.indkey" + def _reflect_constraint( + self, connection, contype, schema, filter_names, scope, kind, **kw + ): + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._constraint_query, + {"oids": [r[0] for r in batch], "contype": contype}, ) - else: - # unnest() and generate_subscripts() both introduced in - # version 8.4 - PK_SQL = """ - SELECT a.attname - FROM pg_attribute a JOIN ( - SELECT unnest(ix.indkey) attnum, - generate_subscripts(ix.indkey, 1) ord - FROM pg_index ix - WHERE ix.indrelid = :table_oid AND ix.indisprimary - ) k ON a.attnum=k.attnum - WHERE a.attrelid = :table_oid - ORDER BY k.ord - """ - t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - cols = [r[0] for r in c.fetchall()] - - PK_CONS_SQL = """ - SELECT conname - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table_oid AND r.contype = 'p' - ORDER BY 1 - """ - t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - name = c.scalar() + result_by_oid = defaultdict(list) + for oid, cols, constraint_name in result: + result_by_oid[oid].append((cols, constraint_name)) + + for oid, tablename in batch: + for_oid = result_by_oid.get(oid, ()) + if for_oid: + for cols, constraint in for_oid: + yield tablename, cols, constraint + else: + yield tablename, None, None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_pk_constraint( + self, connection, schema, filter_names, scope, 
kind, **kw + ): + result = self._reflect_constraint( + connection, "p", schema, filter_names, scope, kind, **kw + ) - return {"constrained_columns": cols, "name": name} + # only a single pk can be present for each table. Return an entry + # even if a table has no primary key + default = ReflectionDefaults.pk_constraint + return ( + ( + (schema, table_name), + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + } + if pk_name is not None + else default(), + ) + for (table_name, cols, pk_name) in result + ) @reflection.cache def get_foreign_keys( @@ -3993,27 +3673,71 @@ class PGDialect(default.DefaultDialect): postgresql_ignore_search_path=False, **kw, ): - preparer = self.identifier_preparer - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + postgresql_ignore_search_path=postgresql_ignore_search_path, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - FK_SQL = """ - SELECT r.conname, - pg_catalog.pg_get_constraintdef(r.oid, true) as condef, - n.nspname as conschema - FROM pg_catalog.pg_constraint r, - pg_namespace n, - pg_class c - - WHERE r.conrelid = :table AND - r.contype = 'f' AND - c.oid = confrelid AND - n.oid = c.relnamespace - ORDER BY 1 - """ - # https://www.postgresql.org/docs/9.0/static/sql-createtable.html - FK_REGEX = re.compile( + @lru_cache() + def _foreing_key_query(self, schema, has_filter_names, scope, kind): + pg_class_ref = pg_catalog.pg_class.alias("cls_ref") + pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref") + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_namespace_ref.c.nspname, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "f", + ), + ) + .outerjoin( + pg_class_ref, + pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid, + ) + .outerjoin( + pg_namespace_ref, + pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + @util.memoized_property + def _fk_regex_pattern(self): + # https://www.postgresql.org/docs/14.0/static/sql-createtable.html + return re.compile( r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" r"[\s]?(ON UPDATE " @@ -4024,12 +3748,33 @@ class PGDialect(default.DefaultDialect): r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" 
) - t = sql.text(FK_SQL).columns( - conname=sqltypes.Unicode, condef=sqltypes.Unicode - ) - c = connection.execute(t, dict(table=table_oid)) - fkeys = [] - for conname, condef, conschema in c.fetchall(): + def get_multi_foreign_keys( + self, + connection, + schema, + filter_names, + scope, + kind, + postgresql_ignore_search_path=False, + **kw, + ): + preparer = self.identifier_preparer + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._foreing_key_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + FK_REGEX = self._fk_regex_pattern + + fkeys = defaultdict(list) + default = ReflectionDefaults.foreign_keys + for table_name, conname, condef, conschema in result: + # ensure that each table has an entry, even if it has + # no foreign keys + if conname is None: + fkeys[(schema, table_name)] = default() + continue + table_fks = fkeys[(schema, table_name)] m = re.search(FK_REGEX, condef).groups() ( @@ -4096,317 +3841,406 @@ class PGDialect(default.DefaultDialect): "referred_columns": referred_columns, "options": options, } - fkeys.append(fkey_d) - return fkeys - - def _pg_index_any(self, col, compare_to): - if self.server_version_info < (8, 1): - # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us - # "In CVS tip you could replace this with "attnum = ANY (indkey)". - # Unfortunately, most array support doesn't work on int2vector in - # pre-8.1 releases, so I think you're kinda stuck with the above - # for now. - # regards, tom lane" - return "(%s)" % " OR ".join( - "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10) - ) - else: - return "%s = ANY(%s)" % (col, compare_to) + table_fks.append(fkey_d) + return fkeys.items() @reflection.cache - def get_indexes(self, connection, table_name, schema, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + def get_indexes(self, connection, table_name, schema=None, **kw): + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - # cast indkey as varchar since it's an int2vector, - # returned as a list by some drivers such as pypostgresql - - if self.server_version_info < (8, 5): - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, ix.indpred, - a.attname, a.attnum, NULL, ix.indkey%s, - %s, %s, am.amname, - NULL as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and %s - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - # version 8.3 here was based on observing the - # cast does not work in PG 8.2.4, does work in 8.3.0. - # nothing in PG changelogs regarding this. 
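get_multi_foreign_keys() above parses the string returned by pg_get_constraintdef() with a regular expression; the sketch below uses a deliberately trimmed-down pattern that only recognizes the referencing columns, the referenced table, and an optional ON DELETE clause, applied to a made-up constraint definition rather than the full option set the dialect handles.

import re

# simplified variant of the dialect's pattern, not the full option set
FK_REGEX = re.compile(
    r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
    r"(?:\s+ON DELETE (CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT))?"
)

condef = "FOREIGN KEY (user_id) REFERENCES public.users(id) ON DELETE CASCADE"

m = FK_REGEX.search(condef)
constrained, referred_schema, referred_table, referred_columns, ondelete = m.groups()
print(constrained, referred_schema, referred_table, referred_columns, ondelete)
# user_id public users id CASCADE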
- "::varchar" if self.server_version_info >= (8, 3) else "", - "ix.indoption::varchar" - if self.server_version_info >= (8, 3) - else "NULL", - "i.reloptions" - if self.server_version_info >= (8, 2) - else "NULL", - self._pg_index_any("a.attnum", "ix.indkey"), + @util.memoized_property + def _index_query(self): + pg_class_index = pg_catalog.pg_class.alias("cls_idx") + # NOTE: repeating oids clause improve query performance + + # subquery to get the columns + idx_sq = ( + select( + pg_catalog.pg_index.c.indexrelid, + pg_catalog.pg_index.c.indrelid, + sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + sql.func.generate_subscripts( + pg_catalog.pg_index.c.indkey, 1 + ).label("ord"), ) - else: - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, - a.attname, a.attnum, c.conrelid, ix.indkey::varchar, - ix.indoption::varchar, i.reloptions, am.amname, - pg_get_expr(ix.indpred, ix.indrelid), - %s as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and a.attnum = ANY(ix.indkey) - left outer join - pg_constraint c - on (ix.indrelid = c.conrelid and - ix.indexrelid = c.conindid and - c.contype in ('p', 'u', 'x')) - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm', 'p') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - "ix.indnkeyatts" - if self.server_version_info >= (11, 0) - else "NULL", + .where( + ~pg_catalog.pg_index.c.indisprimary, + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), ) + .subquery("idx") + ) - t = sql.text(IDX_SQL).columns( - relname=sqltypes.Unicode, attname=sqltypes.Unicode + attr_sq = ( + select( + idx_sq.c.indexrelid, + idx_sq.c.indrelid, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + idx_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, + ), + ) + .where(idx_sq.c.indrelid.in_(bindparam("oids"))) + .order_by(idx_sq.c.indexrelid, idx_sq.c.ord) + .subquery("idx_attr") ) - c = connection.execute(t, dict(table_oid=table_oid)) - indexes = defaultdict(lambda: defaultdict(dict)) + cols_sq = ( + select( + attr_sq.c.indexrelid, + attr_sq.c.indrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + ) + .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid) + .subquery("idx_cols") + ) - sv_idx_name = None - for row in c.fetchall(): - ( - idx_name, - unique, - expr, - col, - col_num, - conrelid, - idx_key, - idx_option, - options, - amname, - filter_definition, - indnkeyatts, - ) = row + if self.server_version_info >= (11, 0): + indnkeyatts = pg_catalog.pg_index.c.indnkeyatts + else: + indnkeyatts = sql.null().label("indnkeyatts") - if expr: - if idx_name != sv_idx_name: - util.warn( - "Skipped unsupported reflection of " - "expression-based index %s" % idx_name - ) - sv_idx_name = idx_name - continue + query = ( + select( + pg_catalog.pg_index.c.indrelid, + pg_class_index.c.relname.label("relname_index"), + pg_catalog.pg_index.c.indisunique, + pg_catalog.pg_index.c.indexprs, + pg_catalog.pg_constraint.c.conrelid.is_not(None).label( + "has_constraint" + ), + pg_catalog.pg_index.c.indoption, + pg_class_index.c.reloptions, + pg_catalog.pg_am.c.amname, + pg_catalog.pg_get_expr( + pg_catalog.pg_index.c.indpred, + pg_catalog.pg_index.c.indrelid, + ).label("filter_definition"), + indnkeyatts, + 
cols_sq.c.cols.label("index_cols"), + ) + .select_from(pg_catalog.pg_index) + .where( + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ~pg_catalog.pg_index.c.indisprimary, + ) + .join( + pg_class_index, + pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, + ) + .join( + pg_catalog.pg_am, + pg_class_index.c.relam == pg_catalog.pg_am.c.oid, + ) + .outerjoin( + cols_sq, + pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid, + ) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_index.c.indrelid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_index.c.indexrelid + == pg_catalog.pg_constraint.c.conindid, + pg_catalog.pg_constraint.c.contype + == sql.any_(_array.array(("p", "u", "x"))), + ), + ) + .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) + ) + return query - has_idx = idx_name in indexes - index = indexes[idx_name] - if col is not None: - index["cols"][col_num] = col - if not has_idx: - idx_keys = idx_key.split() - # "The number of key columns in the index, not counting any - # included columns, which are merely stored and do not - # participate in the index semantics" - if indnkeyatts and idx_keys[indnkeyatts:]: - # this is a "covering index" which has INCLUDE columns - # as well as regular index columns - inc_keys = idx_keys[indnkeyatts:] - idx_keys = idx_keys[:indnkeyatts] - else: - inc_keys = [] + def get_multi_indexes( + self, connection, schema, filter_names, scope, kind, **kw + ): - index["key"] = [int(k.strip()) for k in idx_keys] - index["inc"] = [int(k.strip()) for k in inc_keys] + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) - # (new in pg 8.3) - # "pg_index.indoption" is list of ints, one per column/expr. - # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST - sorting = {} - for col_idx, col_flags in enumerate( - (idx_option or "").split() - ): - col_flags = int(col_flags.strip()) - col_sorting = () - # try to set flags only if they differ from PG defaults... 
- if col_flags & 0x01: - col_sorting += ("desc",) - if not (col_flags & 0x02): - col_sorting += ("nulls_last",) + indexes = defaultdict(list) + default = ReflectionDefaults.indexes + + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._index_query, {"oids": [r[0] for r in batch]} + ).mappings() + + result_by_oid = defaultdict(list) + for row_dict in result: + result_by_oid[row_dict["indrelid"]].append(row_dict) + + for oid, table_name in batch: + if oid not in result_by_oid: + # ensure that each table has an entry, even if reflection + # is skipped because not supported + indexes[(schema, table_name)] = default() + continue + + for row in result_by_oid[oid]: + index_name = row["relname_index"] + + table_indexes = indexes[(schema, table_name)] + + if row["indexprs"]: + tn = ( + table_name + if schema is None + else f"{schema}.{table_name}" + ) + util.warn( + "Skipped unsupported reflection of " + f"expression-based index {index_name} of " + f"table {tn}" + ) + continue + + all_cols = row["index_cols"] + indnkeyatts = row["indnkeyatts"] + # "The number of key columns in the index, not counting any + # included columns, which are merely stored and do not + # participate in the index semantics" + if indnkeyatts and all_cols[indnkeyatts:]: + # this is a "covering index" which has INCLUDE columns + # as well as regular index columns + inc_cols = all_cols[indnkeyatts:] + idx_cols = all_cols[:indnkeyatts] else: - if col_flags & 0x02: - col_sorting += ("nulls_first",) - if col_sorting: - sorting[col_idx] = col_sorting - if sorting: - index["sorting"] = sorting - - index["unique"] = unique - if conrelid is not None: - index["duplicates_constraint"] = idx_name - if options: - index["options"] = dict( - [option.split("=") for option in options] - ) - - # it *might* be nice to include that this is 'btree' in the - # reflection info. But we don't want an Index object - # to have a ``postgresql_using`` in it that is just the - # default, so for the moment leaving this out. - if amname and amname != "btree": - index["amname"] = amname - - if filter_definition: - index["postgresql_where"] = filter_definition + idx_cols = all_cols + inc_cols = [] + + index = { + "name": index_name, + "unique": row["indisunique"], + "column_names": idx_cols, + } + + sorting = {} + for col_index, col_flags in enumerate(row["indoption"]): + col_sorting = () + # try to set flags only if they differ from PG + # defaults... + if col_flags & 0x01: + col_sorting += ("desc",) + if not (col_flags & 0x02): + col_sorting += ("nulls_last",) + else: + if col_flags & 0x02: + col_sorting += ("nulls_first",) + if col_sorting: + sorting[idx_cols[col_index]] = col_sorting + if sorting: + index["column_sorting"] = sorting + if row["has_constraint"]: + index["duplicates_constraint"] = index_name + + dialect_options = {} + if row["reloptions"]: + dialect_options["postgresql_with"] = dict( + [option.split("=") for option in row["reloptions"]] + ) + # it *might* be nice to include that this is 'btree' in the + # reflection info. But we don't want an Index object + # to have a ``postgresql_using`` in it that is just the + # default, so for the moment leaving this out. 
+ amname = row["amname"] + if amname != "btree": + dialect_options["postgresql_using"] = row["amname"] + if row["filter_definition"]: + dialect_options["postgresql_where"] = row[ + "filter_definition" + ] + if self.server_version_info >= (11, 0): + # NOTE: this is legacy, this is part of + # dialect_options now as of #7382 + index["include_columns"] = inc_cols + dialect_options["postgresql_include"] = inc_cols + if dialect_options: + index["dialect_options"] = dialect_options - result = [] - for name, idx in indexes.items(): - entry = { - "name": name, - "unique": idx["unique"], - "column_names": [idx["cols"][i] for i in idx["key"]], - } - if self.server_version_info >= (11, 0): - # NOTE: this is legacy, this is part of dialect_options now - # as of #7382 - entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]] - if "duplicates_constraint" in idx: - entry["duplicates_constraint"] = idx["duplicates_constraint"] - if "sorting" in idx: - entry["column_sorting"] = dict( - (idx["cols"][idx["key"][i]], value) - for i, value in idx["sorting"].items() - ) - if "include_columns" in entry: - entry.setdefault("dialect_options", {})[ - "postgresql_include" - ] = entry["include_columns"] - if "options" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_with" - ] = idx["options"] - if "amname" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_using" - ] = idx["amname"] - if "postgresql_where" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_where" - ] = idx["postgresql_where"] - result.append(entry) - return result + table_indexes.append(index) + return indexes.items() @reflection.cache def get_unique_constraints( self, connection, table_name, schema=None, **kw ): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - UNIQUE_SQL = """ - SELECT - cons.conname as name, - cons.conkey as key, - a.attnum as col_num, - a.attname as col_name - FROM - pg_catalog.pg_constraint cons - join pg_attribute a - on cons.conrelid = a.attrelid AND - a.attnum = ANY(cons.conkey) - WHERE - cons.conrelid = :table_oid AND - cons.contype = 'u' - """ - - t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) + def get_multi_unique_constraints( + self, + connection, + schema, + filter_names, + scope, + kind, + **kw, + ): + result = self._reflect_constraint( + connection, "u", schema, filter_names, scope, kind, **kw + ) - uniques = defaultdict(lambda: defaultdict(dict)) - for row in c.fetchall(): - uc = uniques[row.name] - uc["key"] = row.key - uc["cols"][row.col_num] = row.col_name + # each table can have multiple unique constraints + uniques = defaultdict(list) + default = ReflectionDefaults.unique_constraints + for (table_name, cols, con_name) in result: + # ensure a list is created for each table. 
leave it empty if + # the table has no unique cosntraint + if con_name is None: + uniques[(schema, table_name)] = default() + continue - return [ - {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]} - for name, uc in uniques.items() - ] + uniques[(schema, table_name)].append( + { + "column_names": cols, + "name": con_name, + } + ) + return uniques.items() @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_table_comment( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - COMMENT_SQL = """ - SELECT - pgd.description as table_comment - FROM - pg_catalog.pg_description pgd - WHERE - pgd.objsubid = 0 AND - pgd.objoid = :table_oid - """ + @lru_cache() + def _comment_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_description.c.objoid, + pg_catalog.pg_description.c.objsubid == 0, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query - c = connection.execute( - sql.text(COMMENT_SQL), dict(table_oid=table_oid) + def get_multi_table_comment( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + default = ReflectionDefaults.table_comment + return ( + ( + (schema, table), + {"text": comment} if comment is not None else default(), + ) + for table, comment in result ) - return {"text": c.scalar()} @reflection.cache def get_check_constraints(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_check_constraints( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - CHECK_SQL = """ - SELECT - cons.conname as name, - pg_get_constraintdef(cons.oid) as src - FROM - pg_catalog.pg_constraint cons - WHERE - cons.conrelid = :table_oid AND - cons.contype = 'c' - """ - - c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid)) + @lru_cache() + def _check_constraint_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid + ), + ), + else_=None, + ), + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "c", + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = 
self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query - ret = [] - for name, src in c: + def get_multi_check_constraints( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._check_constraint_query( + schema, has_filter_names, scope, kind + ) + result = connection.execute(query, params) + + check_constraints = defaultdict(list) + default = ReflectionDefaults.check_constraints + for table_name, check_name, src in result: + # only two cases for check_name and src: both null or both defined + if check_name is None and src is None: + check_constraints[(schema, table_name)] = default() + continue # samples: # "CHECK (((a > 1) AND (a < 5)))" # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))" @@ -4424,84 +4258,118 @@ class PGDialect(default.DefaultDialect): sqltext = re.compile( r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL ).sub(r"\1", m.group(1)) - entry = {"name": name, "sqltext": sqltext} + entry = {"name": check_name, "sqltext": sqltext} if m and m.group(2): entry["dialect_options"] = {"not_valid": True} - ret.append(entry) - return ret - - def _load_enums(self, connection, schema=None): - schema = schema or self.default_schema_name - if not self.supports_native_enum: - return {} - - # Load data types for enums: - SQL_ENUMS = """ - SELECT t.typname as "name", - -- no enum defaults in 8.4 at least - -- t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema", - e.enumlabel as "label" - FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid - WHERE t.typtype = 'e' - """ + check_constraints[(schema, table_name)].append(entry) + return check_constraints.items() - if schema != "*": - SQL_ENUMS += "AND n.nspname = :schema " + @lru_cache() + def _enum_query(self, schema): + lbl_sq = ( + select( + pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel + ) + .order_by( + pg_catalog.pg_enum.c.enumtypid, + pg_catalog.pg_enum.c.enumsortorder, + ) + .subquery("lbl") + ) - # e.oid gives us label order within an enum - SQL_ENUMS += 'ORDER BY "schema", "name", e.oid' + lbl_agg_sq = ( + select( + lbl_sq.c.enumtypid, + sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"), + ) + .group_by(lbl_sq.c.enumtypid) + .subquery("lbl_agg") + ) - s = sql.text(SQL_ENUMS).columns( - attname=sqltypes.Unicode, label=sqltypes.Unicode + query = ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + lbl_agg_sq.c.labels.label("labels"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .outerjoin( + lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid + ) + .where(pg_catalog.pg_type.c.typtype == "e") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) ) - if schema != "*": - s = s.bindparams(schema=schema) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + @reflection.cache + 
def _load_enums(self, connection, schema=None, **kw): + if not self.supports_native_enum: + return [] - c = connection.execute(s) + result = connection.execute(self._enum_query(schema)) enums = [] - enum_by_name = {} - for enum in c.fetchall(): - key = (enum.schema, enum.name) - if key in enum_by_name: - enum_by_name[key]["labels"].append(enum.label) - else: - enum_by_name[key] = enum_rec = { - "name": enum.name, - "schema": enum.schema, - "visible": enum.visible, - "labels": [], + for name, visible, schema, labels in result: + enums.append( + { + "name": name, + "schema": schema, + "visible": visible, + "labels": [] if labels is None else labels, } - if enum.label is not None: - enum_rec["labels"].append(enum.label) - enums.append(enum_rec) + ) return enums - def _load_domains(self, connection): - # Load data types for domains: - SQL_DOMAINS = """ - SELECT t.typname as "name", - pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", - not t.typnotnull as "nullable", - t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema" - FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - WHERE t.typtype = 'd' - """ + @util.memoized_property + def _domain_query(self): + return ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_type.c.typbasetype, + pg_catalog.pg_type.c.typtypmod, + ).label("attype"), + (~pg_catalog.pg_type.c.typnotnull).label("nullable"), + pg_catalog.pg_type.c.typdefault.label("default"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .where(pg_catalog.pg_type.c.typtype == "d") + ) - s = sql.text(SQL_DOMAINS) - c = connection.execution_options(future_result=True).execute(s) + @reflection.cache + def _load_domains(self, connection, **kw): + # Load data types for domains: + result = connection.execute(self._domain_query) domains = {} - for domain in c.mappings(): + for domain in result.mappings(): domain = domain # strip (30) from character varying(30) attype = re.search(r"([^\(]+)", domain["attype"]).group(1) diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 6cb97ece4..ce9a3bb6c 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -107,6 +107,8 @@ from .base import PGIdentifierPreparer from .json import JSON from .json import JSONB from .json import JSONPathType +from .pg_catalog import _SpaceVector +from .pg_catalog import OIDVECTOR from ... import exc from ... 
import util from ...engine import processors @@ -245,6 +247,10 @@ class _PGARRAY(PGARRAY): render_bind_cast = True +class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + _server_side_id = util.counter() @@ -376,6 +382,7 @@ class PGDialect_pg8000(PGDialect): sqltypes.BigInteger: _PGBigInteger, sqltypes.Enum: _PGEnum, sqltypes.ARRAY: _PGARRAY, + OIDVECTOR: _PGOIDVECTOR, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py new file mode 100644 index 000000000..a77e7ccf6 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -0,0 +1,292 @@ +# postgresql/pg_catalog.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .array import ARRAY +from .types import OID +from .types import REGCLASS +from ... import Column +from ... import func +from ... import MetaData +from ... import Table +from ...types import BigInteger +from ...types import Boolean +from ...types import CHAR +from ...types import Float +from ...types import Integer +from ...types import SmallInteger +from ...types import String +from ...types import Text +from ...types import TypeDecorator + + +# types +class NAME(TypeDecorator): + impl = String(64, collation="C") + cache_ok = True + + +class PG_NODE_TREE(TypeDecorator): + impl = Text(collation="C") + cache_ok = True + + +class INT2VECTOR(TypeDecorator): + impl = ARRAY(SmallInteger) + cache_ok = True + + +class OIDVECTOR(TypeDecorator): + impl = ARRAY(OID) + cache_ok = True + + +class _SpaceVector: + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return value + return [int(p) for p in value.split(" ")] + + return process + + +REGPROC = REGCLASS # seems an alias + +# functions +_pg_cat = func.pg_catalog +quote_ident = _pg_cat.quote_ident +pg_table_is_visible = _pg_cat.pg_table_is_visible +pg_type_is_visible = _pg_cat.pg_type_is_visible +pg_get_viewdef = _pg_cat.pg_get_viewdef +pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence +format_type = _pg_cat.format_type +pg_get_expr = _pg_cat.pg_get_expr +pg_get_constraintdef = _pg_cat.pg_get_constraintdef + +# constants +RELKINDS_TABLE_NO_FOREIGN = ("r", "p") +RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",) +RELKINDS_VIEW = ("v",) +RELKINDS_MAT_VIEW = ("m",) +RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW + +# tables +pg_catalog_meta = MetaData() + +pg_namespace = Table( + "pg_namespace", + pg_catalog_meta, + Column("oid", OID), + Column("nspname", NAME), + Column("nspowner", OID), + schema="pg_catalog", +) + +pg_class = Table( + "pg_class", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("relname", NAME), + Column("relnamespace", OID), + Column("reltype", OID), + Column("reloftype", OID), + Column("relowner", OID), + Column("relam", OID), + Column("relfilenode", OID), + Column("reltablespace", OID), + Column("relpages", Integer), + Column("reltuples", Float), + Column("relallvisible", Integer, info={"server_version": (9, 2)}), + Column("reltoastrelid", OID), + Column("relhasindex", Boolean), + Column("relisshared", Boolean), + Column("relpersistence", CHAR, info={"server_version": (9, 1)}), + Column("relkind", CHAR), + Column("relnatts", SmallInteger), + Column("relchecks", SmallInteger), + Column("relhasrules", Boolean), + 
Column("relhastriggers", Boolean), + Column("relhassubclass", Boolean), + Column("relrowsecurity", Boolean), + Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}), + Column("relispopulated", Boolean, info={"server_version": (9, 3)}), + Column("relreplident", CHAR, info={"server_version": (9, 4)}), + Column("relispartition", Boolean, info={"server_version": (10,)}), + Column("relrewrite", OID, info={"server_version": (11,)}), + Column("reloptions", ARRAY(Text)), + schema="pg_catalog", +) + +pg_type = Table( + "pg_type", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("typname", NAME), + Column("typnamespace", OID), + Column("typowner", OID), + Column("typlen", SmallInteger), + Column("typbyval", Boolean), + Column("typtype", CHAR), + Column("typcategory", CHAR), + Column("typispreferred", Boolean), + Column("typisdefined", Boolean), + Column("typdelim", CHAR), + Column("typrelid", OID), + Column("typelem", OID), + Column("typarray", OID), + Column("typinput", REGPROC), + Column("typoutput", REGPROC), + Column("typreceive", REGPROC), + Column("typsend", REGPROC), + Column("typmodin", REGPROC), + Column("typmodout", REGPROC), + Column("typanalyze", REGPROC), + Column("typalign", CHAR), + Column("typstorage", CHAR), + Column("typnotnull", Boolean), + Column("typbasetype", OID), + Column("typtypmod", Integer), + Column("typndims", Integer), + Column("typcollation", OID, info={"server_version": (9, 1)}), + Column("typdefault", Text), + schema="pg_catalog", +) + +pg_index = Table( + "pg_index", + pg_catalog_meta, + Column("indexrelid", OID), + Column("indrelid", OID), + Column("indnatts", SmallInteger), + Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}), + Column("indisunique", Boolean), + Column("indisprimary", Boolean), + Column("indisexclusion", Boolean, info={"server_version": (9, 1)}), + Column("indimmediate", Boolean), + Column("indisclustered", Boolean), + Column("indisvalid", Boolean), + Column("indcheckxmin", Boolean), + Column("indisready", Boolean), + Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3 + Column("indisreplident", Boolean), + Column("indkey", INT2VECTOR), + Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1 + Column("indclass", OIDVECTOR), + Column("indoption", INT2VECTOR), + Column("indexprs", PG_NODE_TREE), + Column("indpred", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_attribute = Table( + "pg_attribute", + pg_catalog_meta, + Column("attrelid", OID), + Column("attname", NAME), + Column("atttypid", OID), + Column("attstattarget", Integer), + Column("attlen", SmallInteger), + Column("attnum", SmallInteger), + Column("attndims", Integer), + Column("attcacheoff", Integer), + Column("atttypmod", Integer), + Column("attbyval", Boolean), + Column("attstorage", CHAR), + Column("attalign", CHAR), + Column("attnotnull", Boolean), + Column("atthasdef", Boolean), + Column("atthasmissing", Boolean, info={"server_version": (11,)}), + Column("attidentity", CHAR, info={"server_version": (10,)}), + Column("attgenerated", CHAR, info={"server_version": (12,)}), + Column("attisdropped", Boolean), + Column("attislocal", Boolean), + Column("attinhcount", Integer), + Column("attcollation", OID, info={"server_version": (9, 1)}), + schema="pg_catalog", +) + +pg_constraint = Table( + "pg_constraint", + pg_catalog_meta, + Column("oid", OID), # 9.3 + Column("conname", NAME), + Column("connamespace", OID), + Column("contype", CHAR), + Column("condeferrable", Boolean), + 
Column("condeferred", Boolean), + Column("convalidated", Boolean, info={"server_version": (9, 1)}), + Column("conrelid", OID), + Column("contypid", OID), + Column("conindid", OID), + Column("conparentid", OID, info={"server_version": (11,)}), + Column("confrelid", OID), + Column("confupdtype", CHAR), + Column("confdeltype", CHAR), + Column("confmatchtype", CHAR), + Column("conislocal", Boolean), + Column("coninhcount", Integer), + Column("connoinherit", Boolean, info={"server_version": (9, 2)}), + Column("conkey", ARRAY(SmallInteger)), + Column("confkey", ARRAY(SmallInteger)), + schema="pg_catalog", +) + +pg_sequence = Table( + "pg_sequence", + pg_catalog_meta, + Column("seqrelid", OID), + Column("seqtypid", OID), + Column("seqstart", BigInteger), + Column("seqincrement", BigInteger), + Column("seqmax", BigInteger), + Column("seqmin", BigInteger), + Column("seqcache", BigInteger), + Column("seqcycle", Boolean), + schema="pg_catalog", + info={"server_version": (10,)}, +) + +pg_attrdef = Table( + "pg_attrdef", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("adrelid", OID), + Column("adnum", SmallInteger), + Column("adbin", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_description = Table( + "pg_description", + pg_catalog_meta, + Column("objoid", OID), + Column("classoid", OID), + Column("objsubid", Integer), + Column("description", Text(collation="C")), + schema="pg_catalog", +) + +pg_enum = Table( + "pg_enum", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("enumtypid", OID), + Column("enumsortorder", Float(), info={"server_version": (9, 1)}), + Column("enumlabel", NAME), + schema="pg_catalog", +) + +pg_am = Table( + "pg_am", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("amname", NAME), + Column("amhandler", REGPROC, info={"server_version": (9, 6)}), + Column("amtype", CHAR, info={"server_version": (9, 6)}), + schema="pg_catalog", +) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py new file mode 100644 index 000000000..55735953b --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -0,0 +1,485 @@ +# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import datetime as dt +from typing import Any + +from ... import schema +from ... import util +from ...sql import sqltypes +from ...sql.ddl import InvokeDDLBase + + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + + +class PGUuid(sqltypes.UUID): + render_bind_cast = True + render_literal_cast = True + + +class BYTEA(sqltypes.LargeBinary[bytes]): + __visit_name__ = "BYTEA" + + +class INET(sqltypes.TypeEngine[str]): + __visit_name__ = "INET" + + +PGInet = INET + + +class CIDR(sqltypes.TypeEngine[str]): + __visit_name__ = "CIDR" + + +PGCidr = CIDR + + +class MACADDR(sqltypes.TypeEngine[str]): + __visit_name__ = "MACADDR" + + +PGMacAddr = MACADDR + + +class MONEY(sqltypes.TypeEngine[str]): + + r"""Provide the PostgreSQL MONEY type. + + Depending on driver, result rows using this type may return a + string value which includes currency symbols. 
+ + For this reason, it may be preferable to provide conversion to a + numerically-based currency datatype using :class:`_types.TypeDecorator`:: + + import re + import decimal + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def process_result_value(self, value: Any, dialect: Any) -> None: + if value is not None: + # adjust this for the currency and numeric + m = re.match(r"\$([\d.]+)", value) + if m: + value = decimal.Decimal(m.group(1)) + return value + + Alternatively, the conversion may be applied as a CAST using + the :meth:`_types.TypeDecorator.column_expression` method as follows:: + + import decimal + from sqlalchemy import cast + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def column_expression(self, column: Any): + return cast(column, Numeric()) + + .. versionadded:: 1.2 + + """ + + __visit_name__ = "MONEY" + + +class OID(sqltypes.TypeEngine[int]): + + """Provide the PostgreSQL OID type. + + .. versionadded:: 0.9.5 + + """ + + __visit_name__ = "OID" + + +class REGCLASS(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL REGCLASS type. + + .. versionadded:: 1.2.7 + + """ + + __visit_name__ = "REGCLASS" + + +class TIMESTAMP(sqltypes.TIMESTAMP): + def __init__(self, timezone=False, precision=None): + super(TIMESTAMP, self).__init__(timezone=timezone) + self.precision = precision + + +class TIME(sqltypes.TIME): + def __init__(self, timezone=False, precision=None): + super(TIME, self).__init__(timezone=timezone) + self.precision = precision + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + + """PostgreSQL INTERVAL type.""" + + __visit_name__ = "INTERVAL" + native = True + + def __init__(self, precision=None, fields=None): + """Construct an INTERVAL. + + :param precision: optional integer precision value + :param fields: string fields specifier. allows storage of fields + to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, + etc. + + .. versionadded:: 1.2 + + """ + self.precision = precision + self.fields = fields + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval(native=True, second_precision=self.precision) + + @property + def python_type(self): + return dt.timedelta + + +PGInterval = INTERVAL + + +class BIT(sqltypes.TypeEngine[int]): + __visit_name__ = "BIT" + + def __init__(self, length=None, varying=False): + if not varying: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + else: + # but BIT VARYING can be unlimited-length, so no default + self.length = length + self.varying = varying + + +PGBit = BIT + + +class TSVECTOR(sqltypes.TypeEngine[Any]): + + """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`postgresql_match` + + """ + + __visit_name__ = "TSVECTOR" + + +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): + + """PostgreSQL ENUM type. + + This is a subclass of :class:`_types.Enum` which includes + support for PG's ``CREATE TYPE`` and ``DROP TYPE``. 
+ + When the builtin type :class:`_types.Enum` is used and the + :paramref:`.Enum.native_enum` flag is left at its default of + True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` + type as the implementation, so the special create/drop rules + will be used. + + The create/drop behavior of ENUM is necessarily intricate, due to the + awkward relationship the ENUM type has in relationship to the + parent table, in that it may be "owned" by just a single table, or + may be shared among many tables. + + When using :class:`_types.Enum` or :class:`_postgresql.ENUM` + in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted + corresponding to when the :meth:`_schema.Table.create` and + :meth:`_schema.Table.drop` + methods are called:: + + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + ) + + table.create(engine) # will emit CREATE ENUM and CREATE TABLE + table.drop(engine) # will emit DROP TABLE and DROP ENUM + + To use a common enumerated type between multiple tables, the best + practice is to declare the :class:`_types.Enum` or + :class:`_postgresql.ENUM` independently, and associate it with the + :class:`_schema.MetaData` object itself:: + + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) + + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) + + When this pattern is used, care must still be taken at the level + of individual table creates. Emitting CREATE TABLE without also + specifying ``checkfirst=True`` will still cause issues:: + + t1.create(engine) # will fail: no such type 'myenum' + + If we specify ``checkfirst=True``, the individual table-level create + operation will check for the ``ENUM`` and create if not exists:: + + # will check if enum exists, and emit CREATE TYPE if not + t1.create(engine, checkfirst=True) + + When using a metadata-level ENUM type, the type will always be created + and dropped if either the metadata-wide create/drop is called:: + + metadata.create_all(engine) # will emit CREATE TYPE + metadata.drop_all(engine) # will emit DROP TYPE + + The type can also be created and dropped directly:: + + my_enum.create(engine) + my_enum.drop(engine) + + .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type + now behaves more strictly with regards to CREATE/DROP. A metadata-level + ENUM type will only be created and dropped at the metadata level, + not the table level, with the exception of + ``table.create(checkfirst=True)``. + The ``table.drop()`` call will now emit a DROP TYPE for a table-level + enumerated type. + + """ + + native_enum = True + + def __init__(self, *enums, **kw): + """Construct an :class:`_postgresql.ENUM`. + + Arguments are the same as that of + :class:`_types.Enum`, but also including + the following parameters. + + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be + emitted, after optionally checking for the + presence of the type, when the parent + table is being created; and additionally + that ``DROP TYPE`` is called when the table + is dropped. When ``False``, no check + will be performed and no ``CREATE TYPE`` + or ``DROP TYPE`` is emitted, unless + :meth:`~.postgresql.ENUM.create` + or :meth:`~.postgresql.ENUM.drop` + are called directly. 
+ Setting to ``False`` is helpful + when invoking a creation scheme to a SQL file + without access to the actual database - + the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods can + be used to emit SQL to a target bind. + + """ + native_enum = kw.pop("native_enum", None) + if native_enum is False: + util.warn( + "the native_enum flag does not apply to the " + "sqlalchemy.dialects.postgresql.ENUM datatype; this type " + "always refers to ENUM. Use sqlalchemy.types.Enum for " + "non-native enum." + ) + self.create_type = kw.pop("create_type", True) + super(ENUM, self).__init__(*enums, **kw) + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + return cls(**kw) + + def create(self, bind=None, checkfirst=True): + """Emit ``CREATE TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL CREATE TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Emit ``DROP TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL DROP TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. 
+ + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) + + class EnumGenerator(InvokeDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return not self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_create_enum(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + class EnumDropper(InvokeDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_drop_enum(enum): + return + + self.connection.execute(DropEnumType(enum)) + + def get_dbapi_type(self, dbapi): + """dont return dbapi.STRING for ENUM in PostgreSQL, since that's + a different type""" + + return None + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named enum is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". + + """ + if not self.create_type: + return True + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + if "_pg_enums" in ddl_runner.memo: + pg_enums = ddl_runner.memo["_pg_enums"] + else: + pg_enums = ddl_runner.memo["_pg_enums"] = set() + present = (self.schema, self.name) in pg_enums + pg_enums.add((self.schema, self.name)) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + ) and not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index fdcd1340b..22f003e38 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -867,6 +867,7 @@ from ... 
import util from ...engine import default from ...engine import processors from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import ColumnElement from ...sql import compiler @@ -2053,28 +2054,27 @@ class SQLiteDialect(default.DefaultDialect): return [db[1] for db in dl if db[1] != "temp"] - @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def _format_schema(self, schema, table_name): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + name = f"{qschema}.{table_name}" else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - return [row[0] for row in rs] + name = table_name + return name @reflection.cache - def get_temp_table_names(self, connection, **kw): - s = ( - "SELECT name FROM sqlite_temp_master " - "WHERE type='table' ORDER BY name " - ) - rs = connection.exec_driver_sql(s) + def get_table_names(self, connection, schema=None, **kw): + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names - return [row[0] for row in rs] + @reflection.cache + def get_temp_table_names(self, connection, **kw): + main = "sqlite_temp_master" + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_temp_view_names(self, connection, **kw): @@ -2082,11 +2082,11 @@ class SQLiteDialect(default.DefaultDialect): "SELECT name FROM sqlite_temp_master " "WHERE type='view' ORDER BY name " ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + names = connection.exec_driver_sql(s).scalars().all() + return names - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) info = self._get_table_pragma( @@ -2099,23 +2099,16 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_view_names(self, connection, schema=None, **kw): - if schema is not None: - qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema - else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + master = f"{qschema}.sqlite_master" s = ("SELECT sql FROM %s WHERE name = ? 
AND type='view'") % ( master, ) @@ -2140,6 +2133,10 @@ class SQLiteDialect(default.DefaultDialect): result = rs.fetchall() if result: return result[0].sql + else: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -2186,7 +2183,14 @@ class SQLiteDialect(default.DefaultDialect): tablesql, ) ) - return columns + if columns: + return columns + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.columns() def _get_column_info( self, @@ -2216,7 +2220,6 @@ class SQLiteDialect(default.DefaultDialect): "type": coltype, "nullable": nullable, "default": default, - "autoincrement": "auto", "primary_key": primary_key, } if generated: @@ -2295,13 +2298,16 @@ class SQLiteDialect(default.DefaultDialect): constraint_name = result.group(1) if result else None cols = self.get_columns(connection, table_name, schema, **kw) + # consider only pk columns. This also avoids sorting the cached + # value returned by get_columns + cols = [col for col in cols if col.get("primary_key", 0) > 0] cols.sort(key=lambda col: col.get("primary_key")) - pkeys = [] - for col in cols: - if col["primary_key"]: - pkeys.append(col["name"]) + pkeys = [col["name"] for col in cols] - return {"constrained_columns": pkeys, "name": constraint_name} + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return ReflectionDefaults.pk_constraint() @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -2321,12 +2327,14 @@ class SQLiteDialect(default.DefaultDialect): # original DDL. The referred columns of the foreign key # constraint are therefore the primary key of the referred # table. - referred_pk = self.get_pk_constraint( - connection, rtbl, schema=schema, **kw - ) - # note that if table doesn't exist, we still get back a record, - # just it has no columns in it - referred_columns = referred_pk["constrained_columns"] + try: + referred_pk = self.get_pk_constraint( + connection, rtbl, schema=schema, **kw + ) + referred_columns = referred_pk["constrained_columns"] + except exc.NoSuchTableError: + # ignore not existing parents + referred_columns = [] else: # note we use this list only if this is the first column # in the constraint. for subsequent columns we ignore the @@ -2378,11 +2386,11 @@ class SQLiteDialect(default.DefaultDialect): ) table_data = self._get_table_sql(connection, table_name, schema=schema) - if table_data is None: - # system tables, etc. - return [] def parse_fks(): + if table_data is None: + # system tables, etc. + return FK_PATTERN = ( r"(?:CONSTRAINT (\w+) +)?" r"FOREIGN KEY *\( *(.+?) *\) +" @@ -2453,7 +2461,10 @@ class SQLiteDialect(default.DefaultDialect): # use them as is as it's extremely difficult to parse inline # constraints fkeys.extend(keys_by_signature.values()) - return fkeys + if fkeys: + return fkeys + else: + return ReflectionDefaults.foreign_keys() def _find_cols_in_sig(self, sig): for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): @@ -2480,12 +2491,11 @@ class SQLiteDialect(default.DefaultDialect): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] - unique_constraints = [] def parse_uqs(): + if table_data is None: + return UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? 
+)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' @@ -2513,15 +2523,16 @@ class SQLiteDialect(default.DefaultDialect): unique_constraints.append(parsed_constraint) # NOTE: auto_index_by_sig might not be empty here, # the PRIMARY KEY may have an entry. - return unique_constraints + if unique_constraints: + return unique_constraints + else: + return ReflectionDefaults.unique_constraints() @reflection.cache def get_check_constraints(self, connection, table_name, schema=None, **kw): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" check_constraints = [] @@ -2531,7 +2542,7 @@ class SQLiteDialect(default.DefaultDialect): # necessarily makes assumptions as to how the CREATE TABLE # was emitted. - for match in re.finditer(CHECK_PATTERN, table_data, re.I): + for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): name = match.group(1) if name: @@ -2539,7 +2550,10 @@ class SQLiteDialect(default.DefaultDialect): check_constraints.append({"sqltext": match.group(2), "name": name}) - return check_constraints + if check_constraints: + return check_constraints + else: + return ReflectionDefaults.check_constraints() @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): @@ -2561,7 +2575,7 @@ class SQLiteDialect(default.DefaultDialect): # loop thru unique indexes to get the column names. for idx in list(indexes): pragma_index = self._get_table_pragma( - connection, "index_info", idx["name"] + connection, "index_info", idx["name"], schema=schema ) for row in pragma_index: @@ -2574,7 +2588,23 @@ class SQLiteDialect(default.DefaultDialect): break else: idx["column_names"].append(row[2]) - return indexes + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + if indexes: + return indexes + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.indexes() + + def _is_sys_table(self, table_name): + return table_name in { + "sqlite_schema", + "sqlite_master", + "sqlite_temp_schema", + "sqlite_temp_master", + } @reflection.cache def _get_table_sql(self, connection, table_name, schema=None, **kw): @@ -2590,22 +2620,25 @@ class SQLiteDialect(default.DefaultDialect): " (SELECT * FROM %(schema)ssqlite_master UNION ALL " " SELECT * FROM %(schema)ssqlite_temp_master) " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) except exc.DBAPIError: s = ( "SELECT sql FROM %(schema)ssqlite_master " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) - return rs.scalar() + value = rs.scalar() + if value is None and not self._is_sys_table(table_name): + raise exc.NoSuchTableError(f"{schema_expr}{table_name}") + return value def _get_table_pragma(self, connection, pragma, table_name, schema=None): quote = self.identifier_preparer.quote_identifier if schema is not None: - statements = ["PRAGMA %s." 
% quote(schema)] + statements = [f"PRAGMA {quote(schema)}."] else: # because PRAGMA looks in all attached databases if no schema # given, need to specify "main" schema, however since we want @@ -2615,7 +2648,7 @@ class SQLiteDialect(default.DefaultDialect): qtable = quote(table_name) for statement in statements: - statement = "%s%s(%s)" % (statement, pragma, qtable) + statement = f"{statement}{pragma}({qtable})" cursor = connection.exec_driver_sql(statement) if not cursor._soft_closed: # work around SQLite issue whereby cursor.description |
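
For context on how the new pg_catalog module introduced by this patch is meant to be used, here is a minimal illustrative sketch (not part of the diff): it queries the typed pg_catalog Table objects the same way the reflection queries above do. The connection URL and the "public" schema name are placeholders, and it assumes a reachable PostgreSQL database.

    # sketch only: list table-like relations in a schema using the
    # pg_catalog Table definitions added by this patch
    from sqlalchemy import create_engine, select
    from sqlalchemy.dialects.postgresql import pg_catalog

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test"  # placeholder URL
    )

    # relnames of ordinary/partitioned/foreign tables in "public",
    # mirroring the relkind filtering used by the multi-reflection queries
    query = (
        select(pg_catalog.pg_class.c.relname)
        .join(
            pg_catalog.pg_namespace,
            pg_catalog.pg_class.c.relnamespace
            == pg_catalog.pg_namespace.c.oid,
        )
        .where(
            pg_catalog.pg_class.c.relkind.in_(pg_catalog.RELKINDS_TABLE),
            pg_catalog.pg_namespace.c.nspname == "public",
        )
        .order_by(pg_catalog.pg_class.c.relname)
    )

    with engine.connect() as conn:
        for (relname,) in conn.execute(query):
            print(relname)

The dialect-level get_multi_* methods in the patch build on the same catalog tables, batching table OIDs (3000 at a time in get_multi_indexes) and falling back to ReflectionDefaults entries for tables that have no constraints or indexes, so every requested table still appears in the returned mapping.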
