Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r--   lib/sqlalchemy/dialects/mssql/base.py                   172
-rw-r--r--   lib/sqlalchemy/dialects/mysql/base.py                    47
-rw-r--r--   lib/sqlalchemy/dialects/oracle/base.py                 2240
-rw-r--r--   lib/sqlalchemy/dialects/oracle/cx_oracle.py               4
-rw-r--r--   lib/sqlalchemy/dialects/oracle/dictionary.py            495
-rw-r--r--   lib/sqlalchemy/dialects/oracle/provision.py              54
-rw-r--r--   lib/sqlalchemy/dialects/oracle/types.py                 233
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/__init__.py           31
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/_psycopg_common.py    18
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/asyncpg.py             5
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/base.py             2624
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/pg8000.py              7
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/pg_catalog.py        292
-rw-r--r--   lib/sqlalchemy/dialects/postgresql/types.py             485
-rw-r--r--   lib/sqlalchemy/dialects/sqlite/base.py                  161
15 files changed, 4459 insertions(+), 2409 deletions(-)
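The hunks that follow rework dialect-level schema reflection: cached has_table()/has_sequence(), ReflectionDefaults for empty results, and Oracle data-dictionary queries with db link support. As orientation, a minimal sketch of the Inspector calls these dialect methods serve; the URL, schema and table names are placeholders and not taken from the diff:

from sqlalchemy import create_engine, inspect

engine = create_engine("oracle+cx_oracle://scott:tiger@host/?service_name=xe")
insp = inspect(engine)

print(insp.get_table_names(schema="scott"))            # list of table names
print(insp.get_columns("some_table", schema="scott"))  # list of column dicts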
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 12f495d6e..2a4362ccb 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -831,6 +831,7 @@ from ... import util
from ...engine import cursor as _cursor
from ...engine import default
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
@@ -3010,55 +3011,16 @@ class MSDialect(default.DefaultDialect):
return self.schema_name
@_db_plus_owner
- def has_table(self, connection, tablename, dbname, owner, schema):
+ def has_table(self, connection, tablename, dbname, owner, schema, **kw):
self._ensure_has_table_connection(connection)
- if tablename.startswith("#"): # temporary table
- # mssql does not support temporary views
- # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed
- tables = ischema.mssql_temp_table_columns
- s = sql.select(tables.c.table_name).where(
- tables.c.table_name.like(
- self._temp_table_name_like_pattern(tablename)
- )
- )
-
- # #7168: fetch all (not just first match) in case some other #temp
- # table with the same name happens to appear first
- table_names = connection.execute(s).scalars().fetchall()
- # #6910: verify it's not a temp table from another session
- for table_name in table_names:
- if bool(
- connection.scalar(
- text("SELECT object_id(:table_name)"),
- {"table_name": "tempdb.dbo.[{}]".format(table_name)},
- )
- ):
- return True
- else:
- return False
- else:
- tables = ischema.tables
-
- s = sql.select(tables.c.table_name).where(
- sql.and_(
- sql.or_(
- tables.c.table_type == "BASE TABLE",
- tables.c.table_type == "VIEW",
- ),
- tables.c.table_name == tablename,
- )
- )
-
- if owner:
- s = s.where(tables.c.table_schema == owner)
-
- c = connection.execute(s)
-
- return c.first() is not None
+ return self._internal_has_table(connection, tablename, owner, **kw)
+ @reflection.cache
@_db_plus_owner
- def has_sequence(self, connection, sequencename, dbname, owner, schema):
+ def has_sequence(
+ self, connection, sequencename, dbname, owner, schema, **kw
+ ):
sequences = ischema.sequences
s = sql.select(sequences.c.sequence_name).where(
@@ -3128,6 +3090,60 @@ class MSDialect(default.DefaultDialect):
return view_names
@reflection.cache
+ def _internal_has_table(self, connection, tablename, owner, **kw):
+ if tablename.startswith("#"): # temporary table
+ # mssql does not support temporary views
+ # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed
+ tables = ischema.mssql_temp_table_columns
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
+ )
+ )
+
+ # #7168: fetch all (not just first match) in case some other #temp
+ # table with the same name happens to appear first
+ table_names = connection.scalars(s).all()
+ # #6910: verify it's not a temp table from another session
+ for table_name in table_names:
+ if bool(
+ connection.scalar(
+ text("SELECT object_id(:table_name)"),
+ {"table_name": "tempdb.dbo.[{}]".format(table_name)},
+ )
+ ):
+ return True
+ else:
+ return False
+ else:
+ tables = ischema.tables
+
+ s = sql.select(tables.c.table_name).where(
+ sql.and_(
+ sql.or_(
+ tables.c.table_type == "BASE TABLE",
+ tables.c.table_type == "VIEW",
+ ),
+ tables.c.table_name == tablename,
+ )
+ )
+
+ if owner:
+ s = s.where(tables.c.table_schema == owner)
+
+ c = connection.execute(s)
+
+ return c.first() is not None
+
+ def _default_or_error(self, connection, tablename, owner, method, **kw):
+ # TODO: try to avoid having to run a separate query here
+ if self._internal_has_table(connection, tablename, owner, **kw):
+ return method()
+ else:
+ raise exc.NoSuchTableError(f"{owner}.{tablename}")
+
+ @reflection.cache
@_db_plus_owner
def get_indexes(self, connection, tablename, dbname, owner, schema, **kw):
filter_definition = (
@@ -3138,14 +3154,14 @@ class MSDialect(default.DefaultDialect):
rp = connection.execution_options(future_result=True).execute(
sql.text(
"select ind.index_id, ind.is_unique, ind.name, "
- "%s "
+ f"{filter_definition} "
"from sys.indexes as ind join sys.tables as tab on "
"ind.object_id=tab.object_id "
"join sys.schemas as sch on sch.schema_id=tab.schema_id "
"where tab.name = :tabname "
"and sch.name=:schname "
- "and ind.is_primary_key=0 and ind.type != 0"
- % filter_definition
+ "and ind.is_primary_key=0 and ind.type != 0 "
+ "order by ind.name "
)
.bindparams(
sql.bindparam("tabname", tablename, ischema.CoerceUnicode()),
@@ -3203,31 +3219,34 @@ class MSDialect(default.DefaultDialect):
"mssql_include"
] = index_info["include_columns"]
- return list(indexes.values())
+ if indexes:
+ return list(indexes.values())
+ else:
+ return self._default_or_error(
+ connection, tablename, owner, ReflectionDefaults.indexes, **kw
+ )
@reflection.cache
@_db_plus_owner
def get_view_definition(
self, connection, viewname, dbname, owner, schema, **kw
):
- rp = connection.execute(
+ view_def = connection.execute(
sql.text(
- "select definition from sys.sql_modules as mod, "
- "sys.views as views, "
- "sys.schemas as sch"
- " where "
- "mod.object_id=views.object_id and "
- "views.schema_id=sch.schema_id and "
- "views.name=:viewname and sch.name=:schname"
+ "select mod.definition "
+ "from sys.sql_modules as mod "
+ "join sys.views as views on mod.object_id = views.object_id "
+ "join sys.schemas as sch on views.schema_id = sch.schema_id "
+ "where views.name=:viewname and sch.name=:schname"
).bindparams(
sql.bindparam("viewname", viewname, ischema.CoerceUnicode()),
sql.bindparam("schname", owner, ischema.CoerceUnicode()),
)
- )
-
- if rp:
- view_def = rp.scalar()
+ ).scalar()
+ if view_def:
return view_def
+ else:
+ raise exc.NoSuchTableError(f"{owner}.{viewname}")
def _temp_table_name_like_pattern(self, tablename):
# LIKE uses '%' to match zero or more characters and '_' to match any
@@ -3417,7 +3436,12 @@ class MSDialect(default.DefaultDialect):
cols.append(cdict)
- return cols
+ if cols:
+ return cols
+ else:
+ return self._default_or_error(
+ connection, tablename, owner, ReflectionDefaults.columns, **kw
+ )
@reflection.cache
@_db_plus_owner
@@ -3450,7 +3474,16 @@ class MSDialect(default.DefaultDialect):
pkeys.append(row["COLUMN_NAME"])
if constraint_name is None:
constraint_name = row[C.c.constraint_name.name]
- return {"constrained_columns": pkeys, "name": constraint_name}
+ if pkeys:
+ return {"constrained_columns": pkeys, "name": constraint_name}
+ else:
+ return self._default_or_error(
+ connection,
+ tablename,
+ owner,
+ ReflectionDefaults.pk_constraint,
+ **kw,
+ )
@reflection.cache
@_db_plus_owner
@@ -3591,7 +3624,7 @@ index_info AS (
fkeys = util.defaultdict(fkey_rec)
- for r in connection.execute(s).fetchall():
+ for r in connection.execute(s).all():
(
_, # constraint schema
rfknm,
@@ -3632,4 +3665,13 @@ index_info AS (
local_cols.append(scol)
remote_cols.append(rcol)
- return list(fkeys.values())
+ if fkeys:
+ return list(fkeys.values())
+ else:
+ return self._default_or_error(
+ connection,
+ tablename,
+ owner,
+ ReflectionDefaults.foreign_keys,
+ **kw,
+ )
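A hedged sketch of the behavior the _default_or_error() helper above gives the MSSQL reflection methods: an existing table with no indexes, keys or constraints yields the corresponding ReflectionDefaults value, while a missing table raises NoSuchTableError instead of an empty result. The connection URL and table names are placeholders:

from sqlalchemy import create_engine, inspect
from sqlalchemy.exc import NoSuchTableError

engine = create_engine("mssql+pyodbc://user:password@my_dsn")
insp = inspect(engine)

# [] when the table exists but has no indexes
print(insp.get_indexes("plain_table", schema="dbo"))

try:
    insp.get_indexes("no_such_table", schema="dbo")
except NoSuchTableError as err:
    print("missing tables now raise:", err)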
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 7c9a68236..502371be9 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1056,6 +1056,7 @@ from ... import sql
from ... import util
from ...engine import default
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
@@ -2648,7 +2649,8 @@ class MySQLDialect(default.DefaultDialect):
def _get_default_schema_name(self, connection):
return connection.exec_driver_sql("SELECT DATABASE()").scalar()
- def has_table(self, connection, table_name, schema=None):
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
if schema is None:
@@ -2670,7 +2672,8 @@ class MySQLDialect(default.DefaultDialect):
)
return bool(rs.scalar())
- def has_sequence(self, connection, sequence_name, schema=None):
+ @reflection.cache
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
if not self.supports_sequences:
self._sequences_not_supported()
if not schema:
@@ -2847,14 +2850,20 @@ class MySQLDialect(default.DefaultDialect):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return parsed_state.table_options
+ if parsed_state.table_options:
+ return parsed_state.table_options
+ else:
+ return ReflectionDefaults.table_options()
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return parsed_state.columns
+ if parsed_state.columns:
+ return parsed_state.columns
+ else:
+ return ReflectionDefaults.columns()
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
@@ -2866,7 +2875,7 @@ class MySQLDialect(default.DefaultDialect):
# There can be only one.
cols = [s[0] for s in key["columns"]]
return {"constrained_columns": cols, "name": None}
- return {"constrained_columns": [], "name": None}
+ return ReflectionDefaults.pk_constraint()
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
@@ -2909,7 +2918,7 @@ class MySQLDialect(default.DefaultDialect):
if self._needs_correct_for_88718_96365:
self._correct_for_mysql_bugs_88718_96365(fkeys, connection)
- return fkeys
+ return fkeys if fkeys else ReflectionDefaults.foreign_keys()
def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection):
# Foreign key is always in lower case (MySQL 8.0)
@@ -3000,21 +3009,22 @@ class MySQLDialect(default.DefaultDialect):
connection, table_name, schema, **kw
)
- return [
+ cks = [
{"name": spec["name"], "sqltext": spec["sqltext"]}
for spec in parsed_state.ck_constraints
]
+ return cks if cks else ReflectionDefaults.check_constraints()
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return {
- "text": parsed_state.table_options.get(
- "%s_comment" % self.name, None
- )
- }
+ comment = parsed_state.table_options.get(f"{self.name}_comment", None)
+ if comment is not None:
+ return {"text": comment}
+ else:
+ return ReflectionDefaults.table_comment()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
@@ -3058,7 +3068,8 @@ class MySQLDialect(default.DefaultDialect):
if flavor:
index_d["type"] = flavor
indexes.append(index_d)
- return indexes
+ indexes.sort(key=lambda d: d["name"] or "~") # sort None as last
+ return indexes if indexes else ReflectionDefaults.indexes()
@reflection.cache
def get_unique_constraints(
@@ -3068,7 +3079,7 @@ class MySQLDialect(default.DefaultDialect):
connection, table_name, schema, **kw
)
- return [
+ ucs = [
{
"name": key["name"],
"column_names": [col[0] for col in key["columns"]],
@@ -3077,6 +3088,11 @@ class MySQLDialect(default.DefaultDialect):
for key in parsed_state.keys
if key["type"] == "UNIQUE"
]
+ ucs.sort(key=lambda d: d["name"] or "~") # sort None as last
+ if ucs:
+ return ucs
+ else:
+ return ReflectionDefaults.unique_constraints()
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
@@ -3088,6 +3104,9 @@ class MySQLDialect(default.DefaultDialect):
sql = self._show_create_table(
connection, None, charset, full_name=full_name
)
+ if sql.upper().startswith("CREATE TABLE"):
+ # it's a table, not a view
+ raise exc.NoSuchTableError(full_name)
return sql
def _parsed_state_or_create(
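The `or "~"` sort key used above for MySQL indexes and unique constraints is worth a one-line illustration: "~" sorts after ordinary identifier characters, so unnamed entries land last while named ones come back in a deterministic order:

indexes = [{"name": None}, {"name": "ix_b"}, {"name": "ix_a"}]
indexes.sort(key=lambda d: d["name"] or "~")  # same key as in get_indexes() above
print([d["name"] for d in indexes])           # ['ix_a', 'ix_b', None]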
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index faac0deb7..fee098889 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -518,21 +518,52 @@ columns for non-unique indexes, all but the last column for unique indexes).
""" # noqa
-from itertools import groupby
+from __future__ import annotations
+
+from collections import defaultdict
+from functools import lru_cache
+from functools import wraps
import re
+from . import dictionary
+from .types import _OracleBoolean
+from .types import _OracleDate
+from .types import BFILE
+from .types import BINARY_DOUBLE
+from .types import BINARY_FLOAT
+from .types import DATE
+from .types import FLOAT
+from .types import INTERVAL
+from .types import LONG
+from .types import NCLOB
+from .types import NUMBER
+from .types import NVARCHAR2 # noqa
+from .types import OracleRaw # noqa
+from .types import RAW
+from .types import ROWID # noqa
+from .types import VARCHAR2 # noqa
from ... import Computed
from ... import exc
from ... import schema as sa_schema
from ... import sql
from ... import util
from ...engine import default
+from ...engine import ObjectKind
+from ...engine import ObjectScope
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
+from ...sql import and_
+from ...sql import bindparam
from ...sql import compiler
from ...sql import expression
+from ...sql import func
+from ...sql import null
+from ...sql import or_
+from ...sql import select
from ...sql import sqltypes
from ...sql import util as sql_util
from ...sql import visitors
+from ...sql.visitors import InternalTraversal
from ...types import BLOB
from ...types import CHAR
from ...types import CLOB
@@ -561,229 +592,6 @@ NO_ARG_FNS = set(
)
-class RAW(sqltypes._Binary):
- __visit_name__ = "RAW"
-
-
-OracleRaw = RAW
-
-
-class NCLOB(sqltypes.Text):
- __visit_name__ = "NCLOB"
-
-
-class VARCHAR2(VARCHAR):
- __visit_name__ = "VARCHAR2"
-
-
-NVARCHAR2 = NVARCHAR
-
-
-class NUMBER(sqltypes.Numeric, sqltypes.Integer):
- __visit_name__ = "NUMBER"
-
- def __init__(self, precision=None, scale=None, asdecimal=None):
- if asdecimal is None:
- asdecimal = bool(scale and scale > 0)
-
- super(NUMBER, self).__init__(
- precision=precision, scale=scale, asdecimal=asdecimal
- )
-
- def adapt(self, impltype):
- ret = super(NUMBER, self).adapt(impltype)
- # leave a hint for the DBAPI handler
- ret._is_oracle_number = True
- return ret
-
- @property
- def _type_affinity(self):
- if bool(self.scale and self.scale > 0):
- return sqltypes.Numeric
- else:
- return sqltypes.Integer
-
-
-class FLOAT(sqltypes.FLOAT):
- """Oracle FLOAT.
-
- This is the same as :class:`_sqltypes.FLOAT` except that
- an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
- parameter is accepted, and
- the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
-
- Oracle FLOAT types indicate precision in terms of "binary precision", which
- defaults to 126. For a REAL type, the value is 63. This parameter does not
- cleanly map to a specific number of decimal places but is roughly
- equivalent to the desired number of decimal places divided by 0.30103.
-
- .. versionadded:: 2.0
-
- """
-
- __visit_name__ = "FLOAT"
-
- def __init__(
- self,
- binary_precision=None,
- asdecimal=False,
- decimal_return_scale=None,
- ):
- r"""
- Construct a FLOAT
-
- :param binary_precision: Oracle binary precision value to be rendered
- in DDL. This may be approximated to the number of decimal characters
- using the formula "decimal precision = 0.30103 * binary precision".
- The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
-
- :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
-
- :param decimal_return_scale: See
- :paramref:`_sqltypes.Float.decimal_return_scale`
-
- """
- super().__init__(
- asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
- )
- self.binary_precision = binary_precision
-
-
-class BINARY_DOUBLE(sqltypes.Float):
- __visit_name__ = "BINARY_DOUBLE"
-
-
-class BINARY_FLOAT(sqltypes.Float):
- __visit_name__ = "BINARY_FLOAT"
-
-
-class BFILE(sqltypes.LargeBinary):
- __visit_name__ = "BFILE"
-
-
-class LONG(sqltypes.Text):
- __visit_name__ = "LONG"
-
-
-class _OracleDateLiteralRender:
- def _literal_processor_datetime(self, dialect):
- def process(value):
- if value is not None:
- if getattr(value, "microsecond", None):
- value = (
- f"""TO_TIMESTAMP"""
- f"""('{value.isoformat().replace("T", " ")}', """
- """'YYYY-MM-DD HH24:MI:SS.FF')"""
- )
- else:
- value = (
- f"""TO_DATE"""
- f"""('{value.isoformat().replace("T", " ")}', """
- """'YYYY-MM-DD HH24:MI:SS')"""
- )
- return value
-
- return process
-
- def _literal_processor_date(self, dialect):
- def process(value):
- if value is not None:
- if getattr(value, "microsecond", None):
- value = (
- f"""TO_TIMESTAMP"""
- f"""('{value.isoformat().split("T")[0]}', """
- """'YYYY-MM-DD')"""
- )
- else:
- value = (
- f"""TO_DATE"""
- f"""('{value.isoformat().split("T")[0]}', """
- """'YYYY-MM-DD')"""
- )
- return value
-
- return process
-
-
-class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
- """Provide the oracle DATE type.
-
- This type has no special Python behavior, except that it subclasses
- :class:`_types.DateTime`; this is to suit the fact that the Oracle
- ``DATE`` type supports a time value.
-
- .. versionadded:: 0.9.4
-
- """
-
- __visit_name__ = "DATE"
-
- def literal_processor(self, dialect):
- return self._literal_processor_datetime(dialect)
-
- def _compare_type_affinity(self, other):
- return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
-
-
-class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
- def literal_processor(self, dialect):
- return self._literal_processor_date(dialect)
-
-
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
- __visit_name__ = "INTERVAL"
-
- def __init__(self, day_precision=None, second_precision=None):
- """Construct an INTERVAL.
-
- Note that only DAY TO SECOND intervals are currently supported.
- This is due to a lack of support for YEAR TO MONTH intervals
- within available DBAPIs.
-
- :param day_precision: the day precision value. this is the number of
- digits to store for the day field. Defaults to "2"
- :param second_precision: the second precision value. this is the
- number of digits to store for the fractional seconds field.
- Defaults to "6".
-
- """
- self.day_precision = day_precision
- self.second_precision = second_precision
-
- @classmethod
- def _adapt_from_generic_interval(cls, interval):
- return INTERVAL(
- day_precision=interval.day_precision,
- second_precision=interval.second_precision,
- )
-
- @property
- def _type_affinity(self):
- return sqltypes.Interval
-
- def as_generic(self, allow_nulltype=False):
- return sqltypes.Interval(
- native=True,
- second_precision=self.second_precision,
- day_precision=self.day_precision,
- )
-
-
-class ROWID(sqltypes.TypeEngine):
- """Oracle ROWID type.
-
- When used in a cast() or similar, generates ROWID.
-
- """
-
- __visit_name__ = "ROWID"
-
-
-class _OracleBoolean(sqltypes.Boolean):
- def get_dbapi_type(self, dbapi):
- return dbapi.NUMBER
-
-
colspecs = {
sqltypes.Boolean: _OracleBoolean,
sqltypes.Interval: INTERVAL,
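The Oracle type classes removed above (RAW, NCLOB, NUMBER, FLOAT, BINARY_DOUBLE, INTERVAL, ROWID, ...) now live in the new oracle/types.py module and are re-imported at the top of this file, so the public import path presumably stays unchanged. A small sketch under that assumption:

from sqlalchemy.dialects.oracle import FLOAT, NUMBER

# binary precision ~= decimal precision / 0.30103; Oracle's default is 126
weight = FLOAT(binary_precision=126)
amount = NUMBER(precision=10, scale=2)  # asdecimal defaults to True when scale > 0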
@@ -1541,6 +1349,13 @@ class OracleExecutionContext(default.DefaultExecutionContext):
type_,
)
+ def pre_exec(self):
+ if self.statement and "_oracle_dblink" in self.execution_options:
+ self.statement = self.statement.replace(
+ dictionary.DB_LINK_PLACEHOLDER,
+ self.execution_options["_oracle_dblink"],
+ )
+
class OracleDialect(default.DefaultDialect):
name = "oracle"
@@ -1675,6 +1490,10 @@ class OracleDialect(default.DefaultDialect):
# it may work also on versions before the 18
return self.server_version_info and self.server_version_info >= (18,)
+ @property
+ def _supports_except_all(self):
+ return self.server_version_info and self.server_version_info >= (21,)
+
def do_release_savepoint(self, connection, name):
# Oracle does not support RELEASE SAVEPOINT
pass
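The _supports_except_all flag added above gates a pattern used later in get_table_names(): EXCEPT ALL on Oracle 21 and newer, plain EXCEPT (compiled to MINUS) otherwise. A self-contained sketch of that construction; the table objects here are stand-ins for the dictionary tables:

from sqlalchemy import column, select, table

all_tables = table("all_tables", column("table_name"))
all_mviews = table("all_mviews", column("mview_name"))

supports_except_all = True  # assumption: server_version_info >= (21,)
query = select(all_tables.c.table_name)
mat_query = select(all_mviews.c.mview_name.label("table_name"))
query = (
    query.except_all(mat_query)
    if supports_except_all
    else query.except_(mat_query)
)
print(query)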
@@ -1700,45 +1519,99 @@ class OracleDialect(default.DefaultDialect):
except:
return "READ COMMITTED"
- def has_table(self, connection, table_name, schema=None):
+ def _execute_reflection(
+ self, connection, query, dblink, returns_long, params=None
+ ):
+ if dblink and not dblink.startswith("@"):
+ dblink = f"@{dblink}"
+ execution_options = {
+ # handle db links
+ "_oracle_dblink": dblink or "",
+ # override any schema translate map
+ "schema_translate_map": None,
+ }
+
+ if dblink and returns_long:
+ # Oracle seems to error with
+ # "ORA-00997: illegal use of LONG datatype" when returning
+ # LONG columns via a dblink in a query with bind params
+ # This type seems to be very hard to cast into something else, so it
+ # seems easier to just render the parameters as literals in this case
+ def visit_bindparam(bindparam):
+ bindparam.literal_execute = True
+
+ query = visitors.cloned_traverse(
+ query, {}, {"bindparam": visit_bindparam}
+ )
+ return connection.execute(
+ query, params, execution_options=execution_options
+ )
+
+ @util.memoized_property
+ def _has_table_query(self):
+ # materialized views are returned by all_tables
+ tables = (
+ select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.owner,
+ )
+ .union_all(
+ select(
+ dictionary.all_views.c.view_name.label("table_name"),
+ dictionary.all_views.c.owner,
+ )
+ )
+ .subquery("tables_and_views")
+ )
+
+ query = select(tables.c.table_name).where(
+ tables.c.table_name == bindparam("table_name"),
+ tables.c.owner == bindparam("owner"),
+ )
+ return query
+
+ @reflection.cache
+ def has_table(
+ self, connection, table_name, schema=None, dblink=None, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
self._ensure_has_table_connection(connection)
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- """SELECT table_name FROM all_tables
- WHERE table_name = CAST(:name AS VARCHAR2(128))
- AND owner = CAST(:schema_name AS VARCHAR2(128))
- UNION ALL
- SELECT view_name FROM all_views
- WHERE view_name = CAST(:name AS VARCHAR2(128))
- AND owner = CAST(:schema_name AS VARCHAR2(128))
- """
- ),
- dict(
- name=self.denormalize_name(table_name),
- schema_name=self.denormalize_name(schema),
- ),
+ params = {
+ "table_name": self.denormalize_name(table_name),
+ "owner": self.denormalize_name(schema),
+ }
+ cursor = self._execute_reflection(
+ connection,
+ self._has_table_query,
+ dblink,
+ returns_long=False,
+ params=params,
)
- return cursor.first() is not None
+ return bool(cursor.scalar())
- def has_sequence(self, connection, sequence_name, schema=None):
+ @reflection.cache
+ def has_sequence(
+ self, connection, sequence_name, schema=None, dblink=None, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT sequence_name FROM all_sequences "
- "WHERE sequence_name = :name AND "
- "sequence_owner = :schema_name"
- ),
- dict(
- name=self.denormalize_name(sequence_name),
- schema_name=self.denormalize_name(schema),
- ),
+
+ query = select(dictionary.all_sequences.c.sequence_name).where(
+ dictionary.all_sequences.c.sequence_name
+ == self.denormalize_name(sequence_name),
+ dictionary.all_sequences.c.sequence_owner
+ == self.denormalize_name(schema),
)
- return cursor.first() is not None
+
+ cursor = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ )
+ return bool(cursor.scalar())
def _get_default_schema_name(self, connection):
return self.normalize_name(
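The rewritten has_table() / has_sequence() above accept a ``dblink`` keyword; _execute_reflection() splices the link name into the dictionary queries via the placeholder handled in pre_exec(). A hedged sketch of using it through the inspector, assuming the keyword is passed through and that the link exists on the server; all names are placeholders:

from sqlalchemy import create_engine, inspect

engine = create_engine("oracle+cx_oracle://scott:tiger@host/?service_name=xe")
insp = inspect(engine)

print(insp.has_table("remote_table", schema="remote_owner", dblink="my_link"))
print(insp.has_sequence("remote_seq", schema="remote_owner", dblink="my_link"))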
@@ -1747,329 +1620,633 @@ class OracleDialect(default.DefaultDialect):
).scalar()
)
- def _resolve_synonym(
- self,
- connection,
- desired_owner=None,
- desired_synonym=None,
- desired_table=None,
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("dblink", InternalTraversal.dp_string),
+ )
+ def _get_synonyms(self, connection, schema, filter_names, dblink, **kw):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = select(
+ dictionary.all_synonyms.c.synonym_name,
+ dictionary.all_synonyms.c.table_name,
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.db_link,
+ ).where(dictionary.all_synonyms.c.owner == owner)
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_synonyms.c.synonym_name.in_(
+ params["filter_names"]
+ )
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).mappings()
+ return result.all()
+
+ @lru_cache()
+ def _all_objects_query(
+ self, owner, scope, kind, has_filter_names, has_mat_views
):
- """search for a local synonym matching the given desired owner/name.
-
- if desired_owner is None, attempts to locate a distinct owner.
-
- returns the actual name, owner, dblink name, and synonym name if
- found.
- """
-
- q = (
- "SELECT owner, table_owner, table_name, db_link, "
- "synonym_name FROM all_synonyms WHERE "
+ query = (
+ select(dictionary.all_objects.c.object_name)
+ .select_from(dictionary.all_objects)
+ .where(dictionary.all_objects.c.owner == owner)
)
- clauses = []
- params = {}
- if desired_synonym:
- clauses.append(
- "synonym_name = CAST(:synonym_name AS VARCHAR2(128))"
+
+ # NOTE: materialized views are listed in all_objects twice;
+ # once as MATERIALIZED VIEW and once as TABLE
+ if kind is ObjectKind.ANY:
+ # materialized views are also listed as tables so there is no
+ # need to add them to the in_.
+ query = query.where(
+ dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW"))
)
- params["synonym_name"] = desired_synonym
- if desired_owner:
- clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))")
- params["desired_owner"] = desired_owner
- if desired_table:
- clauses.append("table_name = CAST(:tname AS VARCHAR2(128))")
- params["tname"] = desired_table
-
- q += " AND ".join(clauses)
-
- result = connection.execution_options(future_result=True).execute(
- sql.text(q), params
- )
- if desired_owner:
- row = result.mappings().first()
- if row:
- return (
- row["table_name"],
- row["table_owner"],
- row["db_link"],
- row["synonym_name"],
- )
- else:
- return None, None, None, None
else:
- rows = result.mappings().all()
- if len(rows) > 1:
- raise AssertionError(
- "There are multiple tables visible to the schema, you "
- "must specify owner"
- )
- elif len(rows) == 1:
- row = rows[0]
- return (
- row["table_name"],
- row["table_owner"],
- row["db_link"],
- row["synonym_name"],
- )
- else:
- return None, None, None, None
+ object_type = []
+ if ObjectKind.VIEW in kind:
+ object_type.append("VIEW")
+ if (
+ ObjectKind.MATERIALIZED_VIEW in kind
+ and ObjectKind.TABLE not in kind
+ ):
+ # materialized views are also listed as tables so there is no
+ # need to add them to the in_ if also selecting tables.
+ object_type.append("MATERIALIZED VIEW")
+ if ObjectKind.TABLE in kind:
+ object_type.append("TABLE")
+ if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind:
+ # materialized views are also listed as tables,
+ # so they need to be filtered out
+ # EXCEPT ALL / MINUS profiles as faster than using
+ # NOT EXISTS or NOT IN with a subquery, but it's in
+ # general faster to get the mat view names and exclude
+ # them only when needed
+ query = query.where(
+ dictionary.all_objects.c.object_name.not_in(
+ bindparam("mat_views")
+ )
+ )
+ query = query.where(
+ dictionary.all_objects.c.object_type.in_(object_type)
+ )
- @reflection.cache
- def _prepare_reflection_args(
- self,
- connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
- **kw,
- ):
+ # handles scope
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(dictionary.all_objects.c.temporary == "N")
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(dictionary.all_objects.c.temporary == "Y")
- if resolve_synonyms:
- actual_name, owner, dblink, synonym = self._resolve_synonym(
- connection,
- desired_owner=self.denormalize_name(schema),
- desired_synonym=self.denormalize_name(table_name),
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_objects.c.object_name.in_(
+ bindparam("filter_names")
+ )
)
- else:
- actual_name, owner, dblink, synonym = None, None, None, None
- if not actual_name:
- actual_name = self.denormalize_name(table_name)
-
- if dblink:
- # using user_db_links here since all_db_links appears
- # to have more restricted permissions.
- # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
- # will need to hear from more users if we are doing
- # the right thing here. See [ticket:2619]
- owner = connection.scalar(
- sql.text(
- "SELECT username FROM user_db_links " "WHERE db_link=:link"
- ),
- dict(link=dblink),
+ return query
+
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("scope", InternalTraversal.dp_plain_obj),
+ ("kind", InternalTraversal.dp_plain_obj),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("dblink", InternalTraversal.dp_string),
+ )
+ def _get_all_objects(
+ self, connection, schema, scope, kind, filter_names, dblink, **kw
+ ):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ has_mat_views = False
+ if (
+ ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # see note in _all_objects_query
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
)
- dblink = "@" + dblink
- elif not owner:
- owner = self.denormalize_name(schema or self.default_schema_name)
+ if mat_views:
+ params["mat_views"] = mat_views
+ has_mat_views = True
+
+ query = self._all_objects_query(
+ owner, scope, kind, has_filter_names, has_mat_views
+ )
- return (actual_name, owner, dblink or "", synonym)
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ ).scalars()
- @reflection.cache
- def get_schema_names(self, connection, **kw):
- s = "SELECT username FROM all_users ORDER BY username"
- cursor = connection.exec_driver_sql(s)
- return [self.normalize_name(row[0]) for row in cursor]
+ return result.all()
+
+ def _handle_synonyms_decorator(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ return self._handle_synonyms(fn, *args, **kwargs)
+
+ return wrapper
+
+ def _handle_synonyms(self, fn, connection, *args, **kwargs):
+ if not kwargs.get("oracle_resolve_synonyms", False):
+ return fn(self, connection, *args, **kwargs)
+
+ original_kw = kwargs.copy()
+ schema = kwargs.pop("schema", None)
+ result = self._get_synonyms(
+ connection,
+ schema=schema,
+ filter_names=kwargs.pop("filter_names", None),
+ dblink=kwargs.pop("dblink", None),
+ info_cache=kwargs.get("info_cache", None),
+ )
+
+ dblinks_owners = defaultdict(dict)
+ for row in result:
+ key = row["db_link"], row["table_owner"]
+ tn = self.normalize_name(row["table_name"])
+ dblinks_owners[key][tn] = row["synonym_name"]
+
+ if not dblinks_owners:
+ # No synonym, do the plain thing
+ return fn(self, connection, *args, **original_kw)
+
+ data = {}
+ for (dblink, table_owner), mapping in dblinks_owners.items():
+ call_kw = {
+ **original_kw,
+ "schema": table_owner,
+ "dblink": self.normalize_name(dblink),
+ "filter_names": mapping.keys(),
+ }
+ call_result = fn(self, connection, *args, **call_kw)
+ for (_, tn), value in call_result:
+ synonym_name = self.normalize_name(mapping[tn])
+ data[(schema, synonym_name)] = value
+ return data.items()
@reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
- schema = self.denormalize_name(schema or self.default_schema_name)
+ def get_schema_names(self, connection, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ query = select(dictionary.all_users.c.username).order_by(
+ dictionary.all_users.c.username
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
# note that table_names() isn't loading DBLINKed or synonym'ed tables
if schema is None:
schema = self.default_schema_name
- sql_str = "SELECT table_name FROM all_tables WHERE "
+ den_schema = self.denormalize_name(schema)
+ if kw.get("oracle_resolve_synonyms", False):
+ tables = (
+ select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.owner,
+ dictionary.all_tables.c.iot_name,
+ dictionary.all_tables.c.duration,
+ dictionary.all_tables.c.tablespace_name,
+ )
+ .union_all(
+ select(
+ dictionary.all_synonyms.c.synonym_name.label(
+ "table_name"
+ ),
+ dictionary.all_synonyms.c.owner,
+ dictionary.all_tables.c.iot_name,
+ dictionary.all_tables.c.duration,
+ dictionary.all_tables.c.tablespace_name,
+ )
+ .select_from(dictionary.all_tables)
+ .join(
+ dictionary.all_synonyms,
+ and_(
+ dictionary.all_tables.c.table_name
+ == dictionary.all_synonyms.c.table_name,
+ dictionary.all_tables.c.owner
+ == func.coalesce(
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.owner,
+ ),
+ ),
+ )
+ )
+ .subquery("available_tables")
+ )
+ else:
+ tables = dictionary.all_tables
+
+ query = select(tables.c.table_name)
if self.exclude_tablespaces:
- sql_str += (
- "nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND "
- % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ query = query.where(
+ func.coalesce(
+ tables.c.tablespace_name, "no tablespace"
+ ).not_in(self.exclude_tablespaces)
)
- sql_str += (
- "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL"
+ query = query.where(
+ tables.c.owner == den_schema,
+ tables.c.iot_name.is_(null()),
+ tables.c.duration.is_(null()),
)
- cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
- return [self.normalize_name(row[0]) for row in cursor]
+ # remove materialized views
+ mat_query = select(
+ dictionary.all_mviews.c.mview_name.label("table_name")
+ ).where(dictionary.all_mviews.c.owner == den_schema)
+
+ query = (
+ query.except_all(mat_query)
+ if self._supports_except_all
+ else query.except_(mat_query)
+ )
+
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_temp_table_names(self, connection, **kw):
+ def get_temp_table_names(self, connection, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
schema = self.denormalize_name(self.default_schema_name)
- sql_str = "SELECT table_name FROM all_tables WHERE "
+ query = select(dictionary.all_tables.c.table_name)
if self.exclude_tablespaces:
- sql_str += (
- "nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND "
- % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ query = query.where(
+ func.coalesce(
+ dictionary.all_tables.c.tablespace_name, "no tablespace"
+ ).not_in(self.exclude_tablespaces)
)
- sql_str += (
- "OWNER = :owner "
- "AND IOT_NAME IS NULL "
- "AND DURATION IS NOT NULL"
+ query = query.where(
+ dictionary.all_tables.c.owner == schema,
+ dictionary.all_tables.c.iot_name.is_(null()),
+ dictionary.all_tables.c.duration.is_not(null()),
)
- cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_view_names(self, connection, schema=None, **kw):
- schema = self.denormalize_name(schema or self.default_schema_name)
- s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
- cursor = connection.execute(
- s, dict(owner=self.denormalize_name(schema))
+ def get_materialized_view_names(
+ self, connection, schema=None, dblink=None, _normalize=True, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ if not schema:
+ schema = self.default_schema_name
+
+ query = select(dictionary.all_mviews.c.mview_name).where(
+ dictionary.all_mviews.c.owner == self.denormalize_name(schema)
)
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ if _normalize:
+ return [self.normalize_name(row) for row in result]
+ else:
+ return result.all()
@reflection.cache
- def get_sequence_names(self, connection, schema=None, **kw):
+ def get_view_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT sequence_name FROM all_sequences "
- "WHERE sequence_owner = :schema_name"
- ),
- dict(schema_name=self.denormalize_name(schema)),
+
+ query = select(dictionary.all_views.c.view_name).where(
+ dictionary.all_views.c.owner == self.denormalize_name(schema)
)
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_table_options(self, connection, table_name, schema=None, **kw):
- options = {}
+ def get_sequence_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ if not schema:
+ schema = self.default_schema_name
+ query = select(dictionary.all_sequences.c.sequence_name).where(
+ dictionary.all_sequences.c.sequence_owner
+ == self.denormalize_name(schema)
+ )
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ def _value_or_raise(self, data, table, schema):
+ table = self.normalize_name(str(table))
+ try:
+ return dict(data)[(schema, table)]
+ except KeyError:
+ raise exc.NoSuchTableError(
+ f"{schema}.{table}" if schema else table
+ ) from None
+
+ def _prepare_filter_names(self, filter_names):
+ if filter_names:
+ fn = [self.denormalize_name(name) for name in filter_names]
+ return True, {"filter_names": fn}
+ else:
+ return False, {}
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_table_options(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- params = {"table_name": table_name}
+ @lru_cache()
+ def _table_options_query(
+ self, owner, scope, kind, has_filter_names, has_mat_views
+ ):
+ query = select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.compression,
+ dictionary.all_tables.c.compress_for,
+ ).where(dictionary.all_tables.c.owner == owner)
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_tables.c.table_name.in_(
+ bindparam("filter_names")
+ )
+ )
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(dictionary.all_tables.c.duration.is_(null()))
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(
+ dictionary.all_tables.c.duration.is_not(null())
+ )
- columns = ["table_name"]
- if self._supports_table_compression:
- columns.append("compression")
- if self._supports_table_compress_for:
- columns.append("compress_for")
+ if (
+ has_mat_views
+ and ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # can't use EXCEPT ALL / MINUS here because we don't have an
+ # excludable row vs. the query above
+ # outerjoin + where null works better on oracle 21 but 11 does
+ # not like it at all. this is the next best thing
+
+ query = query.where(
+ dictionary.all_tables.c.table_name.not_in(
+ bindparam("mat_views")
+ )
+ )
+ elif (
+ ObjectKind.TABLE not in kind
+ and ObjectKind.MATERIALIZED_VIEW in kind
+ ):
+ query = query.where(
+ dictionary.all_tables.c.table_name.in_(bindparam("mat_views"))
+ )
+ return query
- text = (
- "SELECT %(columns)s "
- "FROM ALL_TABLES%(dblink)s "
- "WHERE table_name = CAST(:table_name AS VARCHAR(128))"
- )
+ @_handle_synonyms_decorator
+ def get_multi_table_options(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ owner = self.denormalize_name(schema or self.default_schema_name)
- if schema is not None:
- params["owner"] = schema
- text += " AND owner = CAST(:owner AS VARCHAR(128)) "
- text = text % {"dblink": dblink, "columns": ", ".join(columns)}
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ has_mat_views = False
- result = connection.execute(sql.text(text), params)
+ if (
+ ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # see note in _table_options_query
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
+ )
+ if mat_views:
+ params["mat_views"] = mat_views
+ has_mat_views = True
+ elif (
+ ObjectKind.TABLE not in kind
+ and ObjectKind.MATERIALIZED_VIEW in kind
+ ):
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
+ )
+ params["mat_views"] = mat_views
- enabled = dict(DISABLED=False, ENABLED=True)
+ options = {}
+ default = ReflectionDefaults.table_options
- row = result.first()
- if row:
- if "compression" in row._fields and enabled.get(
- row.compression, False
- ):
- if "compress_for" in row._fields:
- options["oracle_compress"] = row.compress_for
+ if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind:
+ query = self._table_options_query(
+ owner, scope, kind, has_filter_names, has_mat_views
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ )
+
+ for table, compression, compress_for in result:
+ if compression == "ENABLED":
+ data = {"oracle_compress": compress_for}
else:
- options["oracle_compress"] = True
+ data = default()
+ options[(schema, self.normalize_name(table))] = data
+ if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope:
+ # add the views (no temporary views)
+ for view in self.get_view_names(connection, schema, dblink, **kw):
+ if not filter_names or view in filter_names:
+ options[(schema, view)] = default()
- return options
+ return options.items()
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
- kw arguments can be:
+ data = self.get_multi_columns(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ def _run_batches(
+ self, connection, query, dblink, returns_long, mappings, all_objects
+ ):
+ each_batch = 500
+ batches = list(all_objects)
+ while batches:
+ batch = batches[0:each_batch]
+ batches[0:each_batch] = []
+
+ result = self._execute_reflection(
+ connection,
+ query,
+ dblink,
+ returns_long=returns_long,
+ params={"all_objects": batch},
+ )
+ if mappings:
+ yield from result.mappings()
+ else:
+ yield from result
+
+ @lru_cache()
+ def _column_query(self, owner):
+ all_cols = dictionary.all_tab_cols
+ all_comments = dictionary.all_col_comments
+ all_ids = dictionary.all_tab_identity_cols
- oracle_resolve_synonyms
+ if self.server_version_info >= (12,):
+ add_cols = (
+ all_cols.c.default_on_null,
+ sql.case(
+ (all_ids.c.table_name.is_(None), sql.null()),
+ else_=all_ids.c.generation_type
+ + ","
+ + all_ids.c.identity_options,
+ ).label("identity_options"),
+ )
+ join_identity_cols = True
+ else:
+ add_cols = (
+ sql.null().label("default_on_null"),
+ sql.null().label("identity_options"),
+ )
+ join_identity_cols = False
+
+ # NOTE: Oracle cannot create tables/views without columns, and a
+ # table cannot have all of its columns hidden:
+ # ORA-54039: table must have at least one column that is not invisible
+ # all_tab_cols returns data for tables/views/mat-views.
+ # all_tab_cols does not return recycled tables
+
+ query = (
+ select(
+ all_cols.c.table_name,
+ all_cols.c.column_name,
+ all_cols.c.data_type,
+ all_cols.c.char_length,
+ all_cols.c.data_precision,
+ all_cols.c.data_scale,
+ all_cols.c.nullable,
+ all_cols.c.data_default,
+ all_comments.c.comments,
+ all_cols.c.virtual_column,
+ *add_cols,
+ ).select_from(all_cols)
+ # NOTE: all_col_comments has a row for each column even if no
+ # comment is present, so a join could be performed, but there
+ # seems to be no difference compared to an outer join
+ .outerjoin(
+ all_comments,
+ and_(
+ all_cols.c.table_name == all_comments.c.table_name,
+ all_cols.c.column_name == all_comments.c.column_name,
+ all_cols.c.owner == all_comments.c.owner,
+ ),
+ )
+ )
+ if join_identity_cols:
+ query = query.outerjoin(
+ all_ids,
+ and_(
+ all_cols.c.table_name == all_ids.c.table_name,
+ all_cols.c.column_name == all_ids.c.column_name,
+ all_cols.c.owner == all_ids.c.owner,
+ ),
+ )
- dblink
+ query = query.where(
+ all_cols.c.table_name.in_(bindparam("all_objects")),
+ all_cols.c.hidden_column == "NO",
+ all_cols.c.owner == owner,
+ ).order_by(all_cols.c.table_name, all_cols.c.column_id)
+ return query
+ @_handle_synonyms_decorator
+ def get_multi_columns(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = self._column_query(owner)
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
+ if (
+ filter_names
+ and kind is ObjectKind.ANY
+ and scope is ObjectScope.ANY
+ ):
+ all_objects = [self.denormalize_name(n) for n in filter_names]
+ else:
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ columns = defaultdict(list)
+
+ # all_tab_cols.data_default is LONG
+ result = self._run_batches(
connection,
- table_name,
- schema,
- resolve_synonyms,
+ query,
dblink,
- info_cache=info_cache,
+ returns_long=True,
+ mappings=True,
+ all_objects=all_objects,
)
- columns = []
- if self._supports_char_length:
- char_length_col = "char_length"
- else:
- char_length_col = "data_length"
- if self.server_version_info >= (12,):
- identity_cols = """\
- col.default_on_null,
- (
- SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS
- FROM ALL_TAB_IDENTITY_COLS%(dblink)s id
- WHERE col.table_name = id.table_name
- AND col.column_name = id.column_name
- AND col.owner = id.owner
- ) AS identity_options""" % {
- "dblink": dblink
- }
- else:
- identity_cols = "NULL as default_on_null, NULL as identity_options"
-
- params = {"table_name": table_name}
-
- text = """
- SELECT
- col.column_name,
- col.data_type,
- col.%(char_length_col)s,
- col.data_precision,
- col.data_scale,
- col.nullable,
- col.data_default,
- com.comments,
- col.virtual_column,
- %(identity_cols)s
- FROM all_tab_cols%(dblink)s col
- LEFT JOIN all_col_comments%(dblink)s com
- ON col.table_name = com.table_name
- AND col.column_name = com.column_name
- AND col.owner = com.owner
- WHERE col.table_name = CAST(:table_name AS VARCHAR2(128))
- AND col.hidden_column = 'NO'
- """
- if schema is not None:
- params["owner"] = schema
- text += " AND col.owner = :owner "
- text += " ORDER BY col.column_id"
- text = text % {
- "dblink": dblink,
- "char_length_col": char_length_col,
- "identity_cols": identity_cols,
- }
-
- c = connection.execute(sql.text(text), params)
-
- for row in c:
- colname = self.normalize_name(row[0])
- orig_colname = row[0]
- coltype = row[1]
- length = row[2]
- precision = row[3]
- scale = row[4]
- nullable = row[5] == "Y"
- default = row[6]
- comment = row[7]
- generated = row[8]
- default_on_nul = row[9]
- identity_options = row[10]
+ for row_dict in result:
+ table_name = self.normalize_name(row_dict["table_name"])
+ orig_colname = row_dict["column_name"]
+ colname = self.normalize_name(orig_colname)
+ coltype = row_dict["data_type"]
+ precision = row_dict["data_precision"]
if coltype == "NUMBER":
+ scale = row_dict["data_scale"]
if precision is None and scale == 0:
coltype = INTEGER()
else:
@@ -2089,7 +2266,9 @@ class OracleDialect(default.DefaultDialect):
coltype = FLOAT(binary_precision=precision)
elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"):
- coltype = self.ischema_names.get(coltype)(length)
+ coltype = self.ischema_names.get(coltype)(
+ row_dict["char_length"]
+ )
elif "WITH TIME ZONE" in coltype:
coltype = TIMESTAMP(timezone=True)
else:
@@ -2103,15 +2282,17 @@ class OracleDialect(default.DefaultDialect):
)
coltype = sqltypes.NULLTYPE
- if generated == "YES":
+ default = row_dict["data_default"]
+ if row_dict["virtual_column"] == "YES":
computed = dict(sqltext=default)
default = None
else:
computed = None
+ identity_options = row_dict["identity_options"]
if identity_options is not None:
identity = self._parse_identity_options(
- identity_options, default_on_nul
+ identity_options, row_dict["default_on_null"]
)
default = None
else:
@@ -2120,10 +2301,9 @@ class OracleDialect(default.DefaultDialect):
cdict = {
"name": colname,
"type": coltype,
- "nullable": nullable,
+ "nullable": row_dict["nullable"] == "Y",
"default": default,
- "autoincrement": "auto",
- "comment": comment,
+ "comment": row_dict["comments"],
}
if orig_colname.lower() == orig_colname:
cdict["quote"] = True
@@ -2132,10 +2312,17 @@ class OracleDialect(default.DefaultDialect):
if identity is not None:
cdict["identity"] = identity
- columns.append(cdict)
- return columns
+ columns[(schema, table_name)].append(cdict)
- def _parse_identity_options(self, identity_options, default_on_nul):
+ # NOTE: default not needed since all tables have columns
+ # default = ReflectionDefaults.columns
+ # return (
+ # (key, value if value else default())
+ # for key, value in columns.items()
+ # )
+ return columns.items()
+
+ def _parse_identity_options(self, identity_options, default_on_null):
# identity_options is a string that starts with 'ALWAYS,' or
# 'BY DEFAULT,' and continues with
# START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1,
@@ -2144,7 +2331,7 @@ class OracleDialect(default.DefaultDialect):
parts = [p.strip() for p in identity_options.split(",")]
identity = {
"always": parts[0] == "ALWAYS",
- "on_null": default_on_nul == "YES",
+ "on_null": default_on_null == "YES",
}
for part in parts[1:]:
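For reference, an illustration (values invented) of the ALL_TAB_IDENTITY_COLS string described in the comment above, together with the comma split the surrounding code performs and one way to read the key/value parts (the dict below is illustrative, not the dialect's own parsing):

identity_options = (
    "ALWAYS, START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1"
)
parts = [p.strip() for p in identity_options.split(",")]
print(parts[0] == "ALWAYS")                    # -> True
print(dict(p.split(": ") for p in parts[1:]))  # -> {'START WITH': '1', ...}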
@@ -2168,384 +2355,641 @@ class OracleDialect(default.DefaultDialect):
return identity
@reflection.cache
- def get_table_comment(
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_table_comment(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _comment_query(self, owner, scope, kind, has_filter_names):
+ # NOTE: all_tab_comments / all_mview_comments have a row for every
+ # object even if it has no comment
+ queries = []
+ if ObjectKind.TABLE in kind or ObjectKind.VIEW in kind:
+ # all_tab_comments returns also plain views
+ tbl_view = select(
+ dictionary.all_tab_comments.c.table_name,
+ dictionary.all_tab_comments.c.comments,
+ ).where(
+ dictionary.all_tab_comments.c.owner == owner,
+ dictionary.all_tab_comments.c.table_name.not_like("BIN$%"),
+ )
+ if ObjectKind.VIEW not in kind:
+ tbl_view = tbl_view.where(
+ dictionary.all_tab_comments.c.table_type == "TABLE"
+ )
+ elif ObjectKind.TABLE not in kind:
+ tbl_view = tbl_view.where(
+ dictionary.all_tab_comments.c.table_type == "VIEW"
+ )
+ queries.append(tbl_view)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ mat_view = select(
+ dictionary.all_mview_comments.c.mview_name.label("table_name"),
+ dictionary.all_mview_comments.c.comments,
+ ).where(
+ dictionary.all_mview_comments.c.owner == owner,
+ dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"),
+ )
+ queries.append(mat_view)
+ if len(queries) == 1:
+ query = queries[0]
+ else:
+ union = sql.union_all(*queries).subquery("tables_and_views")
+ query = select(union.c.table_name, union.c.comments)
+
+ name_col = query.selected_columns.table_name
+
+ if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY):
+ temp = "Y" if scope is ObjectScope.TEMPORARY else "N"
+ # need distinct since materialized views are also listed
+ # as tables in all_objects
+ query = query.distinct().join(
+ dictionary.all_objects,
+ and_(
+ dictionary.all_objects.c.owner == owner,
+ dictionary.all_objects.c.object_name == name_col,
+ dictionary.all_objects.c.temporary == temp,
+ ),
+ )
+ if has_filter_names:
+ query = query.where(name_col.in_(bindparam("filter_names")))
+ return query
+
+ @_handle_synonyms_decorator
+ def get_multi_table_comment(
self,
connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
**kw,
):
-
- info_cache = kw.get("info_cache")
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
- connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
- )
-
- if not schema:
- schema = self.default_schema_name
-
- COMMENT_SQL = """
- SELECT comments
- FROM all_tab_comments
- WHERE table_name = CAST(:table_name AS VARCHAR(128))
- AND owner = CAST(:schema_name AS VARCHAR(128))
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._comment_query(owner, scope, kind, has_filter_names)
- c = connection.execute(
- sql.text(COMMENT_SQL),
- dict(table_name=table_name, schema_name=schema),
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ )
+ default = ReflectionDefaults.table_comment
+ # materialized views by default seem to have a comment like
+ # "snapshot table for snapshot owner.mat_view_name"
+ ignore_mat_view = "snapshot table for snapshot "
+ return (
+ (
+ (schema, self.normalize_name(table)),
+ {"text": comment}
+ if comment is not None
+ and not comment.startswith(ignore_mat_view)
+ else default(),
+ )
+ for table, comment in result
)
- return {"text": c.scalar()}
@reflection.cache
- def get_indexes(
- self,
- connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
- **kw,
- ):
-
- info_cache = kw.get("info_cache")
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_indexes(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
- indexes = []
-
- params = {"table_name": table_name}
- text = (
- "SELECT a.index_name, a.column_name, "
- "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "
- "\nFROM ALL_IND_COLUMNS%(dblink)s a, "
- "\nALL_INDEXES%(dblink)s b "
- "\nWHERE "
- "\na.index_name = b.index_name "
- "\nAND a.table_owner = b.table_owner "
- "\nAND a.table_name = b.table_name "
- "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))"
+ return self._value_or_raise(data, table_name, schema)
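As a usage sketch only (the engine URL and object names below are hypothetical), the per-table method above is normally reached through the inspector, which forwards the documented keyword arguments to the dialect:

from sqlalchemy import create_engine, inspect

engine = create_engine("oracle+cx_oracle://scott:tiger@dsn")  # hypothetical DSN
insp = inspect(engine)
# dblink / oracle_resolve_synonyms are passed through **kw to the dialect
indexes = insp.get_indexes(
    "some_table",
    schema="other_owner",
    dblink="mylink",
    oracle_resolve_synonyms=True,
)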
+
+ @lru_cache()
+ def _index_query(self, owner):
+ return (
+ select(
+ dictionary.all_ind_columns.c.table_name,
+ dictionary.all_ind_columns.c.index_name,
+ dictionary.all_ind_columns.c.column_name,
+ dictionary.all_indexes.c.index_type,
+ dictionary.all_indexes.c.uniqueness,
+ dictionary.all_indexes.c.compression,
+ dictionary.all_indexes.c.prefix_length,
+ )
+ .select_from(dictionary.all_ind_columns)
+ .join(
+ dictionary.all_indexes,
+ sql.and_(
+ dictionary.all_ind_columns.c.index_name
+ == dictionary.all_indexes.c.index_name,
+ dictionary.all_ind_columns.c.table_owner
+ == dictionary.all_indexes.c.table_owner,
+ # NOTE: this condition on table_name is not required
+ # but it improves the query performance noticeably
+ dictionary.all_ind_columns.c.table_name
+ == dictionary.all_indexes.c.table_name,
+ ),
+ )
+ .where(
+ dictionary.all_ind_columns.c.table_owner == owner,
+ dictionary.all_ind_columns.c.table_name.in_(
+ bindparam("all_objects")
+ ),
+ )
+ .order_by(
+ dictionary.all_ind_columns.c.index_name,
+ dictionary.all_ind_columns.c.column_position,
+ )
)
- if schema is not None:
- params["schema"] = schema
- text += "AND a.table_owner = :schema "
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("dblink", InternalTraversal.dp_string),
+ ("all_objects", InternalTraversal.dp_string_list),
+ )
+ def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw):
+ owner = self.denormalize_name(schema or self.default_schema_name)
- text += "ORDER BY a.index_name, a.column_position"
+ query = self._index_query(owner)
- text = text % {"dblink": dblink}
+ pks = {
+ row_dict["constraint_name"]
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ )
+ if row_dict["constraint_type"] == "P"
+ }
- q = sql.text(text)
- rp = connection.execute(q, params)
- indexes = []
- last_index_name = None
- pk_constraint = self.get_pk_constraint(
+ result = self._run_batches(
connection,
- table_name,
- schema,
- resolve_synonyms=resolve_synonyms,
- dblink=dblink,
- info_cache=kw.get("info_cache"),
+ query,
+ dblink,
+ returns_long=False,
+ mappings=True,
+ all_objects=all_objects,
)
- uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
- enabled = dict(DISABLED=False, ENABLED=True)
+ return [
+ row_dict
+ for row_dict in result
+ if row_dict["index_name"] not in pks
+ ]
- oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
+ @_handle_synonyms_decorator
+ def get_multi_indexes(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- index = None
- for rset in rp:
- index_name_normalized = self.normalize_name(rset.index_name)
+ uniqueness = {"NONUNIQUE": False, "UNIQUE": True}
+ enabled = {"DISABLED": False, "ENABLED": True}
+ is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"}
- # skip primary key index. This is refined as of
- # [ticket:5421]. Note that ALL_INDEXES.GENERATED will by "Y"
- # if the name of this index was generated by Oracle, however
- # if a named primary key constraint was created then this flag
- # is false.
- if (
- pk_constraint
- and index_name_normalized == pk_constraint["name"]
- ):
- continue
+ oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
- if rset.index_name != last_index_name:
- index = dict(
- name=index_name_normalized,
- column_names=[],
- dialect_options={},
- )
- indexes.append(index)
- index["unique"] = uniqueness.get(rset.uniqueness, False)
+ indexes = defaultdict(dict)
+
+ for row_dict in self._get_indexes_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ index_name = self.normalize_name(row_dict["index_name"])
+ table_name = self.normalize_name(row_dict["table_name"])
+ table_indexes = indexes[(schema, table_name)]
+
+ if index_name not in table_indexes:
+ table_indexes[index_name] = index_dict = {
+ "name": index_name,
+ "column_names": [],
+ "dialect_options": {},
+ "unique": uniqueness.get(row_dict["uniqueness"], False),
+ }
+ do = index_dict["dialect_options"]
+ if row_dict["index_type"] in is_bitmap:
+ do["oracle_bitmap"] = True
+ if enabled.get(row_dict["compression"], False):
+ do["oracle_compress"] = row_dict["prefix_length"]
- if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"):
- index["dialect_options"]["oracle_bitmap"] = True
- if enabled.get(rset.compression, False):
- index["dialect_options"][
- "oracle_compress"
- ] = rset.prefix_length
+ else:
+ index_dict = table_indexes[index_name]
# filter out Oracle SYS_NC names. could also do an outer join
- # to the all_tab_columns table and check for real col names there.
- if not oracle_sys_col.match(rset.column_name):
- index["column_names"].append(
- self.normalize_name(rset.column_name)
+ # to the all_tab_columns table and check for real col names
+ # there.
+ if not oracle_sys_col.match(row_dict["column_name"]):
+ index_dict["column_names"].append(
+ self.normalize_name(row_dict["column_name"])
)
- last_index_name = rset.index_name
- return indexes
+ default = ReflectionDefaults.indexes
- @reflection.cache
- def _get_constraint_data(
- self, connection, table_name, schema=None, dblink="", **kw
- ):
-
- params = {"table_name": table_name}
-
- text = (
- "SELECT"
- "\nac.constraint_name," # 0
- "\nac.constraint_type," # 1
- "\nloc.column_name AS local_column," # 2
- "\nrem.table_name AS remote_table," # 3
- "\nrem.column_name AS remote_column," # 4
- "\nrem.owner AS remote_owner," # 5
- "\nloc.position as loc_pos," # 6
- "\nrem.position as rem_pos," # 7
- "\nac.search_condition," # 8
- "\nac.delete_rule" # 9
- "\nFROM all_constraints%(dblink)s ac,"
- "\nall_cons_columns%(dblink)s loc,"
- "\nall_cons_columns%(dblink)s rem"
- "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))"
- "\nAND ac.constraint_type IN ('R','P', 'U', 'C')"
- )
-
- if schema is not None:
- params["owner"] = schema
- text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))"
-
- text += (
- "\nAND ac.owner = loc.owner"
- "\nAND ac.constraint_name = loc.constraint_name"
- "\nAND ac.r_owner = rem.owner(+)"
- "\nAND ac.r_constraint_name = rem.constraint_name(+)"
- "\nAND (rem.position IS NULL or loc.position=rem.position)"
- "\nORDER BY ac.constraint_name, loc.position"
+ return (
+ (key, list(indexes[key].values()) if key in indexes else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
)
- text = text % {"dblink": dblink}
- rp = connection.execute(sql.text(text), params)
- constraint_data = rp.fetchall()
- return constraint_data
-
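For orientation, the multi-object methods above yield ``((schema, table_name), value)`` pairs; at the inspector level they are exposed as dictionaries. A hedged sketch, assuming the 2.0 inspector API and hypothetical connection details:

from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import ObjectKind, ObjectScope

engine = create_engine("oracle+cx_oracle://scott:tiger@dsn")  # hypothetical DSN
insp = inspect(engine)
# mapping of (schema, table_name) -> list of index dictionaries
all_indexes = insp.get_multi_indexes(
    schema=None, kind=ObjectKind.ANY, scope=ObjectScope.ANY
)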
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_pk_constraint(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
- pkeys = []
- constraint_name = None
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _constraint_query(self, owner):
+ local = dictionary.all_cons_columns.alias("local")
+ remote = dictionary.all_cons_columns.alias("remote")
+ return (
+ select(
+ dictionary.all_constraints.c.table_name,
+ dictionary.all_constraints.c.constraint_type,
+ dictionary.all_constraints.c.constraint_name,
+ local.c.column_name.label("local_column"),
+ remote.c.table_name.label("remote_table"),
+ remote.c.column_name.label("remote_column"),
+ remote.c.owner.label("remote_owner"),
+ dictionary.all_constraints.c.search_condition,
+ dictionary.all_constraints.c.delete_rule,
+ )
+ .select_from(dictionary.all_constraints)
+ .join(
+ local,
+ and_(
+ local.c.owner == dictionary.all_constraints.c.owner,
+ dictionary.all_constraints.c.constraint_name
+ == local.c.constraint_name,
+ ),
+ )
+ .outerjoin(
+ remote,
+ and_(
+ dictionary.all_constraints.c.r_owner == remote.c.owner,
+ dictionary.all_constraints.c.r_constraint_name
+ == remote.c.constraint_name,
+ or_(
+ remote.c.position.is_(sql.null()),
+ local.c.position == remote.c.position,
+ ),
+ ),
+ )
+ .where(
+ dictionary.all_constraints.c.owner == owner,
+ dictionary.all_constraints.c.table_name.in_(
+ bindparam("all_objects")
+ ),
+ dictionary.all_constraints.c.constraint_type.in_(
+ ("R", "P", "U", "C")
+ ),
+ )
+ .order_by(
+ dictionary.all_constraints.c.constraint_name, local.c.position
+ )
)
- for row in constraint_data:
- (
- cons_name,
- cons_type,
- local_column,
- remote_table,
- remote_column,
- remote_owner,
- ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
- if cons_type == "P":
- if constraint_name is None:
- constraint_name = self.normalize_name(cons_name)
- pkeys.append(local_column)
- return {"constrained_columns": pkeys, "name": constraint_name}
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("dblink", InternalTraversal.dp_string),
+ ("all_objects", InternalTraversal.dp_string_list),
+ )
+ def _get_all_constraint_rows(
+ self, connection, schema, dblink, all_objects, **kw
+ ):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = self._constraint_query(owner)
- @reflection.cache
- def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # since the result is cached, a list must be created
+ values = list(
+ self._run_batches(
+ connection,
+ query,
+ dblink,
+ returns_long=False,
+ mappings=True,
+ all_objects=all_objects,
+ )
+ )
+ return values
+
+ @_handle_synonyms_decorator
+ def get_multi_pk_constraint(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- kw arguments can be:
+ primary_keys = defaultdict(dict)
+ default = ReflectionDefaults.pk_constraint
- oracle_resolve_synonyms
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "P":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ column_name = self.normalize_name(row_dict["local_column"])
+
+ table_pk = primary_keys[(schema, table_name)]
+ if not table_pk:
+ table_pk["name"] = constraint_name
+ table_pk["constrained_columns"] = [column_name]
+ else:
+ table_pk["constrained_columns"].append(column_name)
- dblink
+ return (
+ (key, primary_keys[key] if key in primary_keys else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
+ @reflection.cache
+ def get_foreign_keys(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
- requested_schema = schema # to check later on
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ data = self.get_multi_foreign_keys(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_foreign_keys(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- def fkey_rec():
- return {
- "name": None,
- "constrained_columns": [],
- "referred_schema": None,
- "referred_table": None,
- "referred_columns": [],
- "options": {},
- }
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- fkeys = util.defaultdict(fkey_rec)
+ owner = self.denormalize_name(schema or self.default_schema_name)
- for row in constraint_data:
- (
- cons_name,
- cons_type,
- local_column,
- remote_table,
- remote_column,
- remote_owner,
- ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
-
- cons_name = self.normalize_name(cons_name)
-
- if cons_type == "R":
- if remote_table is None:
- # ticket 363
- util.warn(
- (
- "Got 'None' querying 'table_name' from "
- "all_cons_columns%(dblink)s - does the user have "
- "proper rights to the table?"
- )
- % {"dblink": dblink}
- )
- continue
+ all_remote_owners = set()
+ fkeys = defaultdict(dict)
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "R":
+ continue
+
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ table_fkey = fkeys[(schema, table_name)]
+
+ assert constraint_name is not None
- rec = fkeys[cons_name]
- rec["name"] = cons_name
- local_cols, remote_cols = (
- rec["constrained_columns"],
- rec["referred_columns"],
+ local_column = self.normalize_name(row_dict["local_column"])
+ remote_table = self.normalize_name(row_dict["remote_table"])
+ remote_column = self.normalize_name(row_dict["remote_column"])
+ remote_owner_orig = row_dict["remote_owner"]
+ remote_owner = self.normalize_name(remote_owner_orig)
+ if remote_owner_orig is not None:
+ all_remote_owners.add(remote_owner_orig)
+
+ if remote_table is None:
+ # ticket 363
+ if dblink and not dblink.startswith("@"):
+ dblink = f"@{dblink}"
+ util.warn(
+ "Got 'None' querying 'table_name' from "
+ f"all_cons_columns{dblink or ''} - does the user have "
+ "proper rights to the table?"
)
+ continue
- if not rec["referred_table"]:
- if resolve_synonyms:
- (
- ref_remote_name,
- ref_remote_owner,
- ref_dblink,
- ref_synonym,
- ) = self._resolve_synonym(
- connection,
- desired_owner=self.denormalize_name(remote_owner),
- desired_table=self.denormalize_name(remote_table),
- )
- if ref_synonym:
- remote_table = self.normalize_name(ref_synonym)
- remote_owner = self.normalize_name(
- ref_remote_owner
- )
+ if constraint_name not in table_fkey:
+ table_fkey[constraint_name] = fkey = {
+ "name": constraint_name,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": remote_table,
+ "referred_columns": [],
+ "options": {},
+ }
- rec["referred_table"] = remote_table
+ if resolve_synonyms:
+ # will be removed below
+ fkey["_ref_schema"] = remote_owner
- if (
- requested_schema is not None
- or self.denormalize_name(remote_owner) != schema
- ):
- rec["referred_schema"] = remote_owner
+ if schema is not None or remote_owner_orig != owner:
+ fkey["referred_schema"] = remote_owner
+
+ delete_rule = row_dict["delete_rule"]
+ if delete_rule != "NO ACTION":
+ fkey["options"]["ondelete"] = delete_rule
+
+ else:
+ fkey = table_fkey[constraint_name]
+
+ fkey["constrained_columns"].append(local_column)
+ fkey["referred_columns"].append(remote_column)
+
+ if resolve_synonyms and all_remote_owners:
+ query = select(
+ dictionary.all_synonyms.c.owner,
+ dictionary.all_synonyms.c.table_name,
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.synonym_name,
+ ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners))
+
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).mappings()
- if row[9] != "NO ACTION":
- rec["options"]["ondelete"] = row[9]
+ remote_owners_lut = {}
+ for row in result:
+ synonym_owner = self.normalize_name(row["owner"])
+ table_name = self.normalize_name(row["table_name"])
- local_cols.append(local_column)
- remote_cols.append(remote_column)
+ remote_owners_lut[(synonym_owner, table_name)] = (
+ row["table_owner"],
+ row["synonym_name"],
+ )
+
+ empty = (None, None)
+ for table_fkeys in fkeys.values():
+ for table_fkey in table_fkeys.values():
+ key = (
+ table_fkey.pop("_ref_schema"),
+ table_fkey["referred_table"],
+ )
+ remote_owner, syn_name = remote_owners_lut.get(key, empty)
+ if syn_name:
+ sn = self.normalize_name(syn_name)
+ table_fkey["referred_table"] = sn
+ if schema is not None or remote_owner != owner:
+ ro = self.normalize_name(remote_owner)
+ table_fkey["referred_schema"] = ro
+ else:
+ table_fkey["referred_schema"] = None
+ default = ReflectionDefaults.foreign_keys
- return list(fkeys.values())
+ return (
+ (key, list(fkeys[key].values()) if key in fkeys else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
@reflection.cache
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_unique_constraints(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_unique_constraints(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- unique_keys = filter(lambda x: x[1] == "U", constraint_data)
- uniques_group = groupby(unique_keys, lambda x: x[0])
+ unique_cons = defaultdict(dict)
index_names = {
- ix["name"]
- for ix in self.get_indexes(connection, table_name, schema=schema)
+ row_dict["index_name"]
+ for row_dict in self._get_indexes_rows(
+ connection, schema, dblink, all_objects, **kw
+ )
}
- return [
- {
- "name": name,
- "column_names": cols,
- "duplicates_index": name if name in index_names else None,
- }
- for name, cols in [
- [
- self.normalize_name(i[0]),
- [self.normalize_name(x[2]) for x in i[1]],
- ]
- for i in uniques_group
- ]
- ]
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "U":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name_orig = row_dict["constraint_name"]
+ constraint_name = self.normalize_name(constraint_name_orig)
+ column_name = self.normalize_name(row_dict["local_column"])
+ table_uc = unique_cons[(schema, table_name)]
+
+ assert constraint_name is not None
+
+ if constraint_name not in table_uc:
+ table_uc[constraint_name] = uc = {
+ "name": constraint_name,
+ "column_names": [],
+ "duplicates_index": constraint_name
+ if constraint_name_orig in index_names
+ else None,
+ }
+ else:
+ uc = table_uc[constraint_name]
+
+ uc["column_names"].append(column_name)
+
+ default = ReflectionDefaults.unique_constraints
+
+ return (
+ (
+ key,
+ list(unique_cons[key].values())
+ if key in unique_cons
+ else default(),
+ )
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
@reflection.cache
def get_view_definition(
@@ -2553,65 +2997,129 @@ class OracleDialect(default.DefaultDialect):
connection,
view_name,
schema=None,
- resolve_synonyms=False,
- dblink="",
+ dblink=None,
**kw,
):
- info_cache = kw.get("info_cache")
- (view_name, schema, dblink, synonym) = self._prepare_reflection_args(
- connection,
- view_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ if kw.get("oracle_resolve_synonyms", False):
+ synonyms = self._get_synonyms(
+ connection, schema, filter_names=[view_name], dblink=dblink
+ )
+ if synonyms:
+ assert len(synonyms) == 1
+ row_dict = synonyms[0]
+ dblink = self.normalize_name(row_dict["db_link"])
+ schema = row_dict["table_owner"]
+ view_name = row_dict["table_name"]
+
+ name = self.denormalize_name(view_name)
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = (
+ select(dictionary.all_views.c.text)
+ .where(
+ dictionary.all_views.c.view_name == name,
+ dictionary.all_views.c.owner == owner,
+ )
+ .union_all(
+ select(dictionary.all_mviews.c.query).where(
+ dictionary.all_mviews.c.mview_name == name,
+ dictionary.all_mviews.c.owner == owner,
+ )
+ )
)
- params = {"view_name": view_name}
- text = "SELECT text FROM all_views WHERE view_name=:view_name"
-
- if schema is not None:
- text += " AND owner = :schema"
- params["schema"] = schema
-
- rp = connection.execute(sql.text(text), params).scalar()
- if rp:
- return rp
+ rp = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalar()
+ if rp is None:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
else:
- return None
+ return rp
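With the rewrite above, the definition of a plain view or of a materialized view is returned, and ``NoSuchTableError`` is raised when neither exists. A minimal sketch, assuming hypothetical names:

from sqlalchemy import create_engine, inspect

engine = create_engine("oracle+cx_oracle://scott:tiger@dsn")  # hypothetical DSN
insp = inspect(engine)
ddl = insp.get_view_definition("some_view", schema="other_owner")
# raises NoSuchTableError if no view or materialized view by that name exists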
@reflection.cache
def get_check_constraints(
self, connection, table_name, schema=None, include_all=False, **kw
):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_check_constraints(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ include_all=include_all,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_check_constraints(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ dblink=None,
+ scope,
+ kind,
+ include_all=False,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- check_constraints = filter(lambda x: x[1] == "C", constraint_data)
+ not_null = re.compile(r"..+?. IS NOT NULL$")
- return [
- {"name": self.normalize_name(cons[0]), "sqltext": cons[8]}
- for cons in check_constraints
- if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8])
- ]
+ check_constraints = defaultdict(list)
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "C":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ search_condition = row_dict["search_condition"]
+
+ table_checks = check_constraints[(schema, table_name)]
+ if constraint_name is not None and (
+ include_all or not not_null.match(search_condition)
+ ):
+ table_checks.append(
+ {"name": constraint_name, "sqltext": search_condition}
+ )
+
+ default = ReflectionDefaults.check_constraints
+
+ return (
+ (
+ key,
+ check_constraints[key]
+ if key in check_constraints
+ else default(),
+ )
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
+
+ def _list_dblinks(self, connection, dblink=None):
+ query = select(dictionary.all_db_links.c.db_link)
+ links = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(link) for link in links]
class _OuterJoinColumn(sql.ClauseElement):
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 25e93632c..d2ee0a96e 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -431,6 +431,7 @@ from . import base as oracle
from .base import OracleCompiler
from .base import OracleDialect
from .base import OracleExecutionContext
+from .types import _OracleDateLiteralRender
from ... import exc
from ... import util
from ...engine import cursor as _cursor
@@ -573,7 +574,7 @@ class _CXOracleDate(oracle._OracleDate):
return process
-class _CXOracleTIMESTAMP(oracle._OracleDateLiteralRender, sqltypes.TIMESTAMP):
+class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP):
def literal_processor(self, dialect):
return self._literal_processor_datetime(dialect)
@@ -812,6 +813,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
return None
def pre_exec(self):
+ super().pre_exec()
if not getattr(self.compiled, "_oracle_cx_sql_compiler", False):
return
diff --git a/lib/sqlalchemy/dialects/oracle/dictionary.py b/lib/sqlalchemy/dialects/oracle/dictionary.py
new file mode 100644
index 000000000..ac7a350da
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/dictionary.py
@@ -0,0 +1,495 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from .types import DATE
+from .types import LONG
+from .types import NUMBER
+from .types import RAW
+from .types import VARCHAR2
+from ... import Column
+from ... import MetaData
+from ... import Table
+from ... import table
+from ...sql.sqltypes import CHAR
+
+# constants
+DB_LINK_PLACEHOLDER = "__$sa_dblink$__"
+# tables
+dual = table("dual")
+dictionary_meta = MetaData()
+
+# NOTE: all the tables in dictionary_meta are aliased, because Oracle does
+# not like using the full table@dblink name for every column in a query and
+# complains with ORA-00960: ambiguous column naming in select list
+all_tables = Table(
+ "all_tables" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("tablespace_name", VARCHAR2(30)),
+ Column("cluster_name", VARCHAR2(128)),
+ Column("iot_name", VARCHAR2(128)),
+ Column("status", VARCHAR2(8)),
+ Column("pct_free", NUMBER),
+ Column("pct_used", NUMBER),
+ Column("ini_trans", NUMBER),
+ Column("max_trans", NUMBER),
+ Column("initial_extent", NUMBER),
+ Column("next_extent", NUMBER),
+ Column("min_extents", NUMBER),
+ Column("max_extents", NUMBER),
+ Column("pct_increase", NUMBER),
+ Column("freelists", NUMBER),
+ Column("freelist_groups", NUMBER),
+ Column("logging", VARCHAR2(3)),
+ Column("backed_up", VARCHAR2(1)),
+ Column("num_rows", NUMBER),
+ Column("blocks", NUMBER),
+ Column("empty_blocks", NUMBER),
+ Column("avg_space", NUMBER),
+ Column("chain_cnt", NUMBER),
+ Column("avg_row_len", NUMBER),
+ Column("avg_space_freelist_blocks", NUMBER),
+ Column("num_freelist_blocks", NUMBER),
+ Column("degree", VARCHAR2(10)),
+ Column("instances", VARCHAR2(10)),
+ Column("cache", VARCHAR2(5)),
+ Column("table_lock", VARCHAR2(8)),
+ Column("sample_size", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("partitioned", VARCHAR2(3)),
+ Column("iot_type", VARCHAR2(12)),
+ Column("temporary", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("nested", VARCHAR2(3)),
+ Column("buffer_pool", VARCHAR2(7)),
+ Column("flash_cache", VARCHAR2(7)),
+ Column("cell_flash_cache", VARCHAR2(7)),
+ Column("row_movement", VARCHAR2(8)),
+ Column("global_stats", VARCHAR2(3)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("duration", VARCHAR2(15)),
+ Column("skip_corrupt", VARCHAR2(8)),
+ Column("monitoring", VARCHAR2(3)),
+ Column("cluster_owner", VARCHAR2(128)),
+ Column("dependencies", VARCHAR2(8)),
+ Column("compression", VARCHAR2(8)),
+ Column("compress_for", VARCHAR2(30)),
+ Column("dropped", VARCHAR2(3)),
+ Column("read_only", VARCHAR2(3)),
+ Column("segment_created", VARCHAR2(3)),
+ Column("result_cache", VARCHAR2(7)),
+ Column("clustering", VARCHAR2(3)),
+ Column("activity_tracking", VARCHAR2(23)),
+ Column("dml_timestamp", VARCHAR2(25)),
+ Column("has_identity", VARCHAR2(3)),
+ Column("container_data", VARCHAR2(3)),
+ Column("inmemory", VARCHAR2(8)),
+ Column("inmemory_priority", VARCHAR2(8)),
+ Column("inmemory_distribute", VARCHAR2(15)),
+ Column("inmemory_compression", VARCHAR2(17)),
+ Column("inmemory_duplicate", VARCHAR2(13)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("duplicated", VARCHAR2(1)),
+ Column("sharded", VARCHAR2(1)),
+ Column("externally_sharded", VARCHAR2(1)),
+ Column("externally_duplicated", VARCHAR2(1)),
+ Column("external", VARCHAR2(3)),
+ Column("hybrid", VARCHAR2(3)),
+ Column("cellmemory", VARCHAR2(24)),
+ Column("containers_default", VARCHAR2(3)),
+ Column("container_map", VARCHAR2(3)),
+ Column("extended_data_link", VARCHAR2(3)),
+ Column("extended_data_link_map", VARCHAR2(3)),
+ Column("inmemory_service", VARCHAR2(12)),
+ Column("inmemory_service_name", VARCHAR2(1000)),
+ Column("container_map_object", VARCHAR2(3)),
+ Column("memoptimize_read", VARCHAR2(8)),
+ Column("memoptimize_write", VARCHAR2(8)),
+ Column("has_sensitive_column", VARCHAR2(3)),
+ Column("admit_null", VARCHAR2(3)),
+ Column("data_link_dml_enabled", VARCHAR2(3)),
+ Column("logical_replication", VARCHAR2(8)),
+).alias("a_tables")
+
+all_views = Table(
+ "all_views" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("view_name", VARCHAR2(128), nullable=False),
+ Column("text_length", NUMBER),
+ Column("text", LONG),
+ Column("text_vc", VARCHAR2(4000)),
+ Column("type_text_length", NUMBER),
+ Column("type_text", VARCHAR2(4000)),
+ Column("oid_text_length", NUMBER),
+ Column("oid_text", VARCHAR2(4000)),
+ Column("view_type_owner", VARCHAR2(128)),
+ Column("view_type", VARCHAR2(128)),
+ Column("superview_name", VARCHAR2(128)),
+ Column("editioning_view", VARCHAR2(1)),
+ Column("read_only", VARCHAR2(1)),
+ Column("container_data", VARCHAR2(1)),
+ Column("bequeath", VARCHAR2(12)),
+ Column("origin_con_id", VARCHAR2(256)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("containers_default", VARCHAR2(3)),
+ Column("container_map", VARCHAR2(3)),
+ Column("extended_data_link", VARCHAR2(3)),
+ Column("extended_data_link_map", VARCHAR2(3)),
+ Column("has_sensitive_column", VARCHAR2(3)),
+ Column("admit_null", VARCHAR2(3)),
+ Column("pdb_local_only", VARCHAR2(3)),
+).alias("a_views")
+
+all_sequences = Table(
+ "all_sequences" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("sequence_owner", VARCHAR2(128), nullable=False),
+ Column("sequence_name", VARCHAR2(128), nullable=False),
+ Column("min_value", NUMBER),
+ Column("max_value", NUMBER),
+ Column("increment_by", NUMBER, nullable=False),
+ Column("cycle_flag", VARCHAR2(1)),
+ Column("order_flag", VARCHAR2(1)),
+ Column("cache_size", NUMBER, nullable=False),
+ Column("last_number", NUMBER, nullable=False),
+ Column("scale_flag", VARCHAR2(1)),
+ Column("extend_flag", VARCHAR2(1)),
+ Column("sharded_flag", VARCHAR2(1)),
+ Column("session_flag", VARCHAR2(1)),
+ Column("keep_value", VARCHAR2(1)),
+).alias("a_sequences")
+
+all_users = Table(
+ "all_users" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("username", VARCHAR2(128), nullable=False),
+ Column("user_id", NUMBER, nullable=False),
+ Column("created", DATE, nullable=False),
+ Column("common", VARCHAR2(3)),
+ Column("oracle_maintained", VARCHAR2(1)),
+ Column("inherited", VARCHAR2(3)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("implicit", VARCHAR2(3)),
+ Column("all_shard", VARCHAR2(3)),
+ Column("external_shard", VARCHAR2(3)),
+).alias("a_users")
+
+all_mviews = Table(
+ "all_mviews" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("mview_name", VARCHAR2(128), nullable=False),
+ Column("container_name", VARCHAR2(128), nullable=False),
+ Column("query", LONG),
+ Column("query_len", NUMBER(38)),
+ Column("updatable", VARCHAR2(1)),
+ Column("update_log", VARCHAR2(128)),
+ Column("master_rollback_seg", VARCHAR2(128)),
+ Column("master_link", VARCHAR2(128)),
+ Column("rewrite_enabled", VARCHAR2(1)),
+ Column("rewrite_capability", VARCHAR2(9)),
+ Column("refresh_mode", VARCHAR2(6)),
+ Column("refresh_method", VARCHAR2(8)),
+ Column("build_mode", VARCHAR2(9)),
+ Column("fast_refreshable", VARCHAR2(18)),
+ Column("last_refresh_type", VARCHAR2(8)),
+ Column("last_refresh_date", DATE),
+ Column("last_refresh_end_time", DATE),
+ Column("staleness", VARCHAR2(19)),
+ Column("after_fast_refresh", VARCHAR2(19)),
+ Column("unknown_prebuilt", VARCHAR2(1)),
+ Column("unknown_plsql_func", VARCHAR2(1)),
+ Column("unknown_external_table", VARCHAR2(1)),
+ Column("unknown_consider_fresh", VARCHAR2(1)),
+ Column("unknown_import", VARCHAR2(1)),
+ Column("unknown_trusted_fd", VARCHAR2(1)),
+ Column("compile_state", VARCHAR2(19)),
+ Column("use_no_index", VARCHAR2(1)),
+ Column("stale_since", DATE),
+ Column("num_pct_tables", NUMBER),
+ Column("num_fresh_pct_regions", NUMBER),
+ Column("num_stale_pct_regions", NUMBER),
+ Column("segment_created", VARCHAR2(3)),
+ Column("evaluation_edition", VARCHAR2(128)),
+ Column("unusable_before", VARCHAR2(128)),
+ Column("unusable_beginning", VARCHAR2(128)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("on_query_computation", VARCHAR2(1)),
+ Column("auto", VARCHAR2(3)),
+).alias("a_mviews")
+
+all_tab_identity_cols = Table(
+ "all_tab_identity_cols" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("generation_type", VARCHAR2(10)),
+ Column("sequence_name", VARCHAR2(128), nullable=False),
+ Column("identity_options", VARCHAR2(298)),
+).alias("a_tab_identity_cols")
+
+all_tab_cols = Table(
+ "all_tab_cols" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("data_type", VARCHAR2(128)),
+ Column("data_type_mod", VARCHAR2(3)),
+ Column("data_type_owner", VARCHAR2(128)),
+ Column("data_length", NUMBER, nullable=False),
+ Column("data_precision", NUMBER),
+ Column("data_scale", NUMBER),
+ Column("nullable", VARCHAR2(1)),
+ Column("column_id", NUMBER),
+ Column("default_length", NUMBER),
+ Column("data_default", LONG),
+ Column("num_distinct", NUMBER),
+ Column("low_value", RAW(1000)),
+ Column("high_value", RAW(1000)),
+ Column("density", NUMBER),
+ Column("num_nulls", NUMBER),
+ Column("num_buckets", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("sample_size", NUMBER),
+ Column("character_set_name", VARCHAR2(44)),
+ Column("char_col_decl_length", NUMBER),
+ Column("global_stats", VARCHAR2(3)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("avg_col_len", NUMBER),
+ Column("char_length", NUMBER),
+ Column("char_used", VARCHAR2(1)),
+ Column("v80_fmt_image", VARCHAR2(3)),
+ Column("data_upgraded", VARCHAR2(3)),
+ Column("hidden_column", VARCHAR2(3)),
+ Column("virtual_column", VARCHAR2(3)),
+ Column("segment_column_id", NUMBER),
+ Column("internal_column_id", NUMBER, nullable=False),
+ Column("histogram", VARCHAR2(15)),
+ Column("qualified_col_name", VARCHAR2(4000)),
+ Column("user_generated", VARCHAR2(3)),
+ Column("default_on_null", VARCHAR2(3)),
+ Column("identity_column", VARCHAR2(3)),
+ Column("evaluation_edition", VARCHAR2(128)),
+ Column("unusable_before", VARCHAR2(128)),
+ Column("unusable_beginning", VARCHAR2(128)),
+ Column("collation", VARCHAR2(100)),
+ Column("collated_column_id", NUMBER),
+).alias("a_tab_cols")
+
+all_tab_comments = Table(
+ "all_tab_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("table_type", VARCHAR2(11)),
+ Column("comments", VARCHAR2(4000)),
+ Column("origin_con_id", NUMBER),
+).alias("a_tab_comments")
+
+all_col_comments = Table(
+ "all_col_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("comments", VARCHAR2(4000)),
+ Column("origin_con_id", NUMBER),
+).alias("a_col_comments")
+
+all_mview_comments = Table(
+ "all_mview_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("mview_name", VARCHAR2(128), nullable=False),
+ Column("comments", VARCHAR2(4000)),
+).alias("a_mview_comments")
+
+all_ind_columns = Table(
+ "all_ind_columns" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("index_owner", VARCHAR2(128), nullable=False),
+ Column("index_name", VARCHAR2(128), nullable=False),
+ Column("table_owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(4000)),
+ Column("column_position", NUMBER, nullable=False),
+ Column("column_length", NUMBER, nullable=False),
+ Column("char_length", NUMBER),
+ Column("descend", VARCHAR2(4)),
+ Column("collated_column_id", NUMBER),
+).alias("a_ind_columns")
+
+all_indexes = Table(
+ "all_indexes" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("index_name", VARCHAR2(128), nullable=False),
+ Column("index_type", VARCHAR2(27)),
+ Column("table_owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("table_type", CHAR(11)),
+ Column("uniqueness", VARCHAR2(9)),
+ Column("compression", VARCHAR2(13)),
+ Column("prefix_length", NUMBER),
+ Column("tablespace_name", VARCHAR2(30)),
+ Column("ini_trans", NUMBER),
+ Column("max_trans", NUMBER),
+ Column("initial_extent", NUMBER),
+ Column("next_extent", NUMBER),
+ Column("min_extents", NUMBER),
+ Column("max_extents", NUMBER),
+ Column("pct_increase", NUMBER),
+ Column("pct_threshold", NUMBER),
+ Column("include_column", NUMBER),
+ Column("freelists", NUMBER),
+ Column("freelist_groups", NUMBER),
+ Column("pct_free", NUMBER),
+ Column("logging", VARCHAR2(3)),
+ Column("blevel", NUMBER),
+ Column("leaf_blocks", NUMBER),
+ Column("distinct_keys", NUMBER),
+ Column("avg_leaf_blocks_per_key", NUMBER),
+ Column("avg_data_blocks_per_key", NUMBER),
+ Column("clustering_factor", NUMBER),
+ Column("status", VARCHAR2(8)),
+ Column("num_rows", NUMBER),
+ Column("sample_size", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("degree", VARCHAR2(40)),
+ Column("instances", VARCHAR2(40)),
+ Column("partitioned", VARCHAR2(3)),
+ Column("temporary", VARCHAR2(1)),
+ Column("generated", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("buffer_pool", VARCHAR2(7)),
+ Column("flash_cache", VARCHAR2(7)),
+ Column("cell_flash_cache", VARCHAR2(7)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("duration", VARCHAR2(15)),
+ Column("pct_direct_access", NUMBER),
+ Column("ityp_owner", VARCHAR2(128)),
+ Column("ityp_name", VARCHAR2(128)),
+ Column("parameters", VARCHAR2(1000)),
+ Column("global_stats", VARCHAR2(3)),
+ Column("domidx_status", VARCHAR2(12)),
+ Column("domidx_opstatus", VARCHAR2(6)),
+ Column("funcidx_status", VARCHAR2(8)),
+ Column("join_index", VARCHAR2(3)),
+ Column("iot_redundant_pkey_elim", VARCHAR2(3)),
+ Column("dropped", VARCHAR2(3)),
+ Column("visibility", VARCHAR2(9)),
+ Column("domidx_management", VARCHAR2(14)),
+ Column("segment_created", VARCHAR2(3)),
+ Column("orphaned_entries", VARCHAR2(3)),
+ Column("indexing", VARCHAR2(7)),
+ Column("auto", VARCHAR2(3)),
+).alias("a_indexes")
+
+all_constraints = Table(
+ "all_constraints" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128)),
+ Column("constraint_name", VARCHAR2(128)),
+ Column("constraint_type", VARCHAR2(1)),
+ Column("table_name", VARCHAR2(128)),
+ Column("search_condition", LONG),
+ Column("search_condition_vc", VARCHAR2(4000)),
+ Column("r_owner", VARCHAR2(128)),
+ Column("r_constraint_name", VARCHAR2(128)),
+ Column("delete_rule", VARCHAR2(9)),
+ Column("status", VARCHAR2(8)),
+ Column("deferrable", VARCHAR2(14)),
+ Column("deferred", VARCHAR2(9)),
+ Column("validated", VARCHAR2(13)),
+ Column("generated", VARCHAR2(14)),
+ Column("bad", VARCHAR2(3)),
+ Column("rely", VARCHAR2(4)),
+ Column("last_change", DATE),
+ Column("index_owner", VARCHAR2(128)),
+ Column("index_name", VARCHAR2(128)),
+ Column("invalid", VARCHAR2(7)),
+ Column("view_related", VARCHAR2(14)),
+ Column("origin_con_id", VARCHAR2(256)),
+).alias("a_constraints")
+
+all_cons_columns = Table(
+ "all_cons_columns" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("constraint_name", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(4000)),
+ Column("position", NUMBER),
+).alias("a_cons_columns")
+
+# TODO: figure out if this note is still relevant, since it is not mentioned at
+# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html
+# original note:
+# using user_db_links here since all_db_links appears
+# to have more restricted permissions.
+# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
+# will need to hear from more users if we are doing
+# the right thing here. See [ticket:2619]
+all_db_links = Table(
+ "all_db_links" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("db_link", VARCHAR2(128), nullable=False),
+ Column("username", VARCHAR2(128)),
+ Column("host", VARCHAR2(2000)),
+ Column("created", DATE, nullable=False),
+ Column("hidden", VARCHAR2(3)),
+ Column("shard_internal", VARCHAR2(3)),
+ Column("valid", VARCHAR2(3)),
+ Column("intra_cdb", VARCHAR2(3)),
+).alias("a_db_links")
+
+all_synonyms = Table(
+ "all_synonyms" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128)),
+ Column("synonym_name", VARCHAR2(128)),
+ Column("table_owner", VARCHAR2(128)),
+ Column("table_name", VARCHAR2(128)),
+ Column("db_link", VARCHAR2(128)),
+ Column("origin_con_id", VARCHAR2(256)),
+).alias("a_synonyms")
+
+all_objects = Table(
+ "all_objects" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("object_name", VARCHAR2(128), nullable=False),
+ Column("subobject_name", VARCHAR2(128)),
+ Column("object_id", NUMBER, nullable=False),
+ Column("data_object_id", NUMBER),
+ Column("object_type", VARCHAR2(23)),
+ Column("created", DATE, nullable=False),
+ Column("last_ddl_time", DATE, nullable=False),
+ Column("timestamp", VARCHAR2(19)),
+ Column("status", VARCHAR2(7)),
+ Column("temporary", VARCHAR2(1)),
+ Column("generated", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("namespace", NUMBER, nullable=False),
+ Column("edition_name", VARCHAR2(128)),
+ Column("sharing", VARCHAR2(13)),
+ Column("editionable", VARCHAR2(1)),
+ Column("oracle_maintained", VARCHAR2(1)),
+ Column("application", VARCHAR2(1)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("duplicated", VARCHAR2(1)),
+ Column("sharded", VARCHAR2(1)),
+ Column("created_appid", NUMBER),
+ Column("created_vsnid", NUMBER),
+ Column("modified_appid", NUMBER),
+ Column("modified_vsnid", NUMBER),
+).alias("a_objects")
diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py
index cba3b5be4..75b7a7aa9 100644
--- a/lib/sqlalchemy/dialects/oracle/provision.py
+++ b/lib/sqlalchemy/dialects/oracle/provision.py
@@ -2,9 +2,12 @@
from ... import create_engine
from ... import exc
+from ... import inspect
from ...engine import url as sa_url
from ...testing.provision import configure_follower
from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_post_tables
+from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
@@ -28,6 +31,10 @@ def _oracle_create_db(cfg, eng, ident):
conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
+ # these are needed to create materialized views
+ conn.exec_driver_sql("grant create table to %s" % ident)
+ conn.exec_driver_sql("grant create table to %s_ts1" % ident)
+ conn.exec_driver_sql("grant create table to %s_ts2" % ident)
@configure_follower.for_db("oracle")
@@ -46,6 +53,30 @@ def _ora_drop_ignore(conn, dbname):
return False
+@drop_all_schema_objects_pre_tables.for_db("oracle")
+def _ora_drop_all_schema_objects_pre_tables(cfg, eng):
+ _purge_recyclebin(eng)
+ _purge_recyclebin(eng, cfg.test_schema)
+
+
+@drop_all_schema_objects_post_tables.for_db("oracle")
+def _ora_drop_all_schema_objects_post_tables(cfg, eng):
+
+ with eng.begin() as conn:
+ for syn in conn.dialect._get_synonyms(conn, None, None, None):
+ conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}")
+
+ for syn in conn.dialect._get_synonyms(
+ conn, cfg.test_schema, None, None
+ ):
+ conn.exec_driver_sql(
+ f"drop synonym {cfg.test_schema}.{syn['synonym_name']}"
+ )
+
+ for tmp_table in inspect(conn).get_temp_table_names():
+ conn.exec_driver_sql(f"drop table {tmp_table}")
+
+
@drop_db.for_db("oracle")
def _oracle_drop_db(cfg, eng, ident):
with eng.begin() as conn:
@@ -60,13 +91,10 @@ def _oracle_drop_db(cfg, eng, ident):
@stop_test_class_outside_fixtures.for_db("oracle")
-def stop_test_class_outside_fixtures(config, db, cls):
+def _ora_stop_test_class_outside_fixtures(config, db, cls):
try:
- with db.begin() as conn:
- # run magic command to get rid of identity sequences
- # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
- conn.exec_driver_sql("purge recyclebin")
+ _purge_recyclebin(db)
except exc.DatabaseError as err:
log.warning("purge recyclebin command failed: %s", err)
@@ -85,6 +113,22 @@ def stop_test_class_outside_fixtures(config, db, cls):
_all_conns.clear()
+def _purge_recyclebin(eng, schema=None):
+ with eng.begin() as conn:
+ if schema is None:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
+ conn.exec_driver_sql("purge recyclebin")
+ else:
+ # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501
+ for owner, object_name, type_ in conn.exec_driver_sql(
+ "select owner, object_name,type from "
+ "dba_recyclebin where owner=:schema and type='TABLE'",
+ {"schema": conn.dialect.denormalize_name(schema)},
+ ).all():
+ conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"')
+
+
_all_conns = set()
diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py
new file mode 100644
index 000000000..60a8ebcb5
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/types.py
@@ -0,0 +1,233 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from ...sql import sqltypes
+from ...types import NVARCHAR
+from ...types import VARCHAR
+
+
+class RAW(sqltypes._Binary):
+ __visit_name__ = "RAW"
+
+
+OracleRaw = RAW
+
+
+class NCLOB(sqltypes.Text):
+ __visit_name__ = "NCLOB"
+
+
+class VARCHAR2(VARCHAR):
+ __visit_name__ = "VARCHAR2"
+
+
+NVARCHAR2 = NVARCHAR
+
+
+class NUMBER(sqltypes.Numeric, sqltypes.Integer):
+ __visit_name__ = "NUMBER"
+
+ def __init__(self, precision=None, scale=None, asdecimal=None):
+ if asdecimal is None:
+ asdecimal = bool(scale and scale > 0)
+
+ super(NUMBER, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal
+ )
+
+ def adapt(self, impltype):
+ ret = super(NUMBER, self).adapt(impltype)
+ # leave a hint for the DBAPI handler
+ ret._is_oracle_number = True
+ return ret
+
+ @property
+ def _type_affinity(self):
+ if bool(self.scale and self.scale > 0):
+ return sqltypes.Numeric
+ else:
+ return sqltypes.Integer
+
+
+class FLOAT(sqltypes.FLOAT):
+ """Oracle FLOAT.
+
+ This is the same as :class:`_sqltypes.FLOAT` except that
+ an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
+ parameter is accepted, and
+ the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
+
+ Oracle FLOAT types indicate precision in terms of "binary precision", which
+ defaults to 126. For a REAL type, the value is 63. This parameter does not
+ cleanly map to a specific number of decimal places; it is roughly
+ equivalent to the desired number of decimal places divided by 0.30103.
+
+ .. versionadded:: 2.0
+
+ """
+
+ __visit_name__ = "FLOAT"
+
+ def __init__(
+ self,
+ binary_precision=None,
+ asdecimal=False,
+ decimal_return_scale=None,
+ ):
+ r"""
+ Construct a FLOAT
+
+ :param binary_precision: Oracle binary precision value to be rendered
+ in DDL. This may be approximated to the number of decimal characters
+ using the formula "decimal precision = 0.30103 * binary precision".
+ The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
+
+ :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
+
+ :param decimal_return_scale: See
+ :paramref:`_sqltypes.Float.decimal_return_scale`
+
+ """
+ super().__init__(
+ asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
+ )
+ self.binary_precision = binary_precision
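As a worked example of the formula above: for roughly five decimal digits of precision, ``binary_precision`` would be about 5 / 0.30103 ≈ 16.6, so 16 or 17. A minimal sketch (table and column names are hypothetical):

from sqlalchemy import Column, MetaData, Table
from sqlalchemy.dialects.oracle import FLOAT

measurements = Table(
    "measurements",
    MetaData(),
    # about 5 decimal digits of precision
    Column("value", FLOAT(binary_precision=16)),
)
# expected Oracle DDL for the column: value FLOAT(16)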
+
+
+class BINARY_DOUBLE(sqltypes.Float):
+ __visit_name__ = "BINARY_DOUBLE"
+
+
+class BINARY_FLOAT(sqltypes.Float):
+ __visit_name__ = "BINARY_FLOAT"
+
+
+class BFILE(sqltypes.LargeBinary):
+ __visit_name__ = "BFILE"
+
+
+class LONG(sqltypes.Text):
+ __visit_name__ = "LONG"
+
+
+class _OracleDateLiteralRender:
+ def _literal_processor_datetime(self, dialect):
+ def process(value):
+ if value is not None:
+ if getattr(value, "microsecond", None):
+ value = (
+ f"""TO_TIMESTAMP"""
+ f"""('{value.isoformat().replace("T", " ")}', """
+ """'YYYY-MM-DD HH24:MI:SS.FF')"""
+ )
+ else:
+ value = (
+ f"""TO_DATE"""
+ f"""('{value.isoformat().replace("T", " ")}', """
+ """'YYYY-MM-DD HH24:MI:SS')"""
+ )
+ return value
+
+ return process
+
+ def _literal_processor_date(self, dialect):
+ def process(value):
+ if value is not None:
+ if getattr(value, "microsecond", None):
+ value = (
+ f"""TO_TIMESTAMP"""
+ f"""('{value.isoformat().split("T")[0]}', """
+ """'YYYY-MM-DD')"""
+ )
+ else:
+ value = (
+ f"""TO_DATE"""
+ f"""('{value.isoformat().split("T")[0]}', """
+ """'YYYY-MM-DD')"""
+ )
+ return value
+
+ return process
+
+
+class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
+ """Provide the oracle DATE type.
+
+ This type has no special Python behavior, except that it subclasses
+ :class:`_types.DateTime`; this is to suit the fact that the Oracle
+ ``DATE`` type supports a time value.
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ __visit_name__ = "DATE"
+
+ def literal_processor(self, dialect):
+ return self._literal_processor_datetime(dialect)
+
+ def _compare_type_affinity(self, other):
+ return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
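Based on the processors above, a quick sketch of the literals the ``DATE`` type renders (the dialect argument is unused by the inner closure, so ``None`` is passed here purely for illustration):

import datetime

from sqlalchemy.dialects.oracle.types import DATE

process = DATE().literal_processor(None)
print(process(datetime.datetime(2022, 1, 2, 3, 4, 5)))
# TO_DATE('2022-01-02 03:04:05', 'YYYY-MM-DD HH24:MI:SS')
print(process(datetime.datetime(2022, 1, 2, 3, 4, 5, 6)))
# TO_TIMESTAMP('2022-01-02 03:04:05.000006', 'YYYY-MM-DD HH24:MI:SS.FF')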
+
+
+class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
+ def literal_processor(self, dialect):
+ return self._literal_processor_date(dialect)
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+ __visit_name__ = "INTERVAL"
+
+ def __init__(self, day_precision=None, second_precision=None):
+ """Construct an INTERVAL.
+
+ Note that only DAY TO SECOND intervals are currently supported.
+ This is due to a lack of support for YEAR TO MONTH intervals
+ within available DBAPIs.
+
+ :param day_precision: the day precision value. This is the number of
+ digits to store for the day field. Defaults to "2".
+ :param second_precision: the second precision value. This is the
+ number of digits to store for the fractional seconds field.
+ Defaults to "6".
+
+ """
+ self.day_precision = day_precision
+ self.second_precision = second_precision
+
+ @classmethod
+ def _adapt_from_generic_interval(cls, interval):
+ return INTERVAL(
+ day_precision=interval.day_precision,
+ second_precision=interval.second_precision,
+ )
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(
+ native=True,
+ second_precision=self.second_precision,
+ day_precision=self.day_precision,
+ )
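A small declarative sketch of the type above (table and column names are hypothetical):

from sqlalchemy import Column, MetaData, Table
from sqlalchemy.dialects.oracle import INTERVAL

jobs = Table(
    "job_runs",
    MetaData(),
    Column("elapsed", INTERVAL(day_precision=2, second_precision=6)),
)
# expected Oracle DDL for the column: elapsed INTERVAL DAY(2) TO SECOND(6)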
+
+
+class ROWID(sqltypes.TypeEngine):
+ """Oracle ROWID type.
+
+ When used in a cast() or similar, generates ROWID.
+
+ """
+
+ __visit_name__ = "ROWID"
+
+
+class _OracleBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
index c2472fb55..85bbf8c5b 100644
--- a/lib/sqlalchemy/dialects/postgresql/__init__.py
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -19,31 +19,16 @@ from .array import Any
from .array import ARRAY
from .array import array
from .base import BIGINT
-from .base import BIT
from .base import BOOLEAN
-from .base import BYTEA
from .base import CHAR
-from .base import CIDR
-from .base import CreateEnumType
from .base import DATE
from .base import DOUBLE_PRECISION
-from .base import DropEnumType
-from .base import ENUM
from .base import FLOAT
-from .base import INET
from .base import INTEGER
-from .base import INTERVAL
-from .base import MACADDR
-from .base import MONEY
from .base import NUMERIC
-from .base import OID
from .base import REAL
-from .base import REGCLASS
from .base import SMALLINT
from .base import TEXT
-from .base import TIME
-from .base import TIMESTAMP
-from .base import TSVECTOR
from .base import UUID
from .base import VARCHAR
from .dml import Insert
@@ -61,7 +46,21 @@ from .ranges import INT8RANGE
from .ranges import NUMRANGE
from .ranges import TSRANGE
from .ranges import TSTZRANGE
-from ...util import compat
+from .types import BIT
+from .types import BYTEA
+from .types import CIDR
+from .types import CreateEnumType
+from .types import DropEnumType
+from .types import ENUM
+from .types import INET
+from .types import INTERVAL
+from .types import MACADDR
+from .types import MONEY
+from .types import OID
+from .types import REGCLASS
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TSVECTOR
# Alias psycopg also as psycopg_async
psycopg_async = type(
diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
index e831f2ed9..8dcd36c6d 100644
--- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
+++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
@@ -1,3 +1,8 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import decimal
@@ -9,6 +14,9 @@ from .base import _INT_TYPES
from .base import PGDialect
from .base import PGExecutionContext
from .hstore import HSTORE
+from .pg_catalog import _SpaceVector
+from .pg_catalog import INT2VECTOR
+from .pg_catalog import OIDVECTOR
from ... import exc
from ... import types as sqltypes
from ... import util
@@ -66,6 +74,14 @@ class _PsycopgARRAY(PGARRAY):
render_bind_cast = True
+class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR):
+ pass
+
+
+class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR):
+ pass
+
+
class _PGExecutionContext_common_psycopg(PGExecutionContext):
def create_server_side_cursor(self):
# use server-side cursors:
@@ -91,6 +107,8 @@ class _PGDialect_common_psycopg(PGDialect):
sqltypes.Numeric: _PsycopgNumeric,
HSTORE: _PsycopgHStore,
sqltypes.ARRAY: _PsycopgARRAY,
+ INT2VECTOR: _PsycopgINT2VECTOR,
+ OIDVECTOR: _PsycopgOIDVECTOR,
},
)
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 1ec787e1f..d6385a5d6 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -274,6 +274,10 @@ class AsyncpgOID(OID):
render_bind_cast = True
+class AsyncpgCHAR(sqltypes.CHAR):
+ render_bind_cast = True
+
+
class PGExecutionContext_asyncpg(PGExecutionContext):
def handle_dbapi_exception(self, e):
if isinstance(
@@ -823,6 +827,7 @@ class PGDialect_asyncpg(PGDialect):
sqltypes.Enum: AsyncPgEnum,
OID: AsyncpgOID,
REGCLASS: AsyncpgREGCLASS,
+ sqltypes.CHAR: AsyncpgCHAR,
},
)
is_async = True
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 0de8a9c44..8402341f6 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -11,7 +11,7 @@ r"""
:name: PostgreSQL
:full_support: 9.6, 10, 11, 12, 13, 14
:normal_support: 9.6+
- :best_effort: 8+
+ :best_effort: 9+
.. _postgresql_sequences:
@@ -1448,23 +1448,52 @@ E.g.::
from __future__ import annotations
from collections import defaultdict
-import datetime as dt
+from functools import lru_cache
import re
-from typing import Any
from . import array as _array
from . import dml
from . import hstore as _hstore
from . import json as _json
+from . import pg_catalog
from . import ranges as _ranges
+from .types import _DECIMAL_TYPES # noqa
+from .types import _FLOAT_TYPES # noqa
+from .types import _INT_TYPES # noqa
+from .types import BIT
+from .types import BYTEA
+from .types import CIDR
+from .types import CreateEnumType # noqa
+from .types import DropEnumType # noqa
+from .types import ENUM
+from .types import INET
+from .types import INTERVAL
+from .types import MACADDR
+from .types import MONEY
+from .types import OID
+from .types import PGBit # noqa
+from .types import PGCidr # noqa
+from .types import PGInet # noqa
+from .types import PGInterval # noqa
+from .types import PGMacAddr # noqa
+from .types import PGUuid
+from .types import REGCLASS
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TSVECTOR
from ... import exc
from ... import schema
+from ... import select
from ... import sql
from ... import util
from ...engine import characteristics
from ...engine import default
from ...engine import interfaces
+from ...engine import ObjectKind
+from ...engine import ObjectScope
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
+from ...sql import bindparam
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
@@ -1472,7 +1501,7 @@ from ...sql import expression
from ...sql import roles
from ...sql import sqltypes
from ...sql import util as sql_util
-from ...sql.ddl import InvokeDDLBase
+from ...sql.visitors import InternalTraversal
from ...types import BIGINT
from ...types import BOOLEAN
from ...types import CHAR
@@ -1596,469 +1625,6 @@ RESERVED_WORDS = set(
]
)
-_DECIMAL_TYPES = (1231, 1700)
-_FLOAT_TYPES = (700, 701, 1021, 1022)
-_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
-
-
-class PGUuid(UUID):
- render_bind_cast = True
- render_literal_cast = True
-
-
-class BYTEA(sqltypes.LargeBinary[bytes]):
- __visit_name__ = "BYTEA"
-
-
-class INET(sqltypes.TypeEngine[str]):
- __visit_name__ = "INET"
-
-
-PGInet = INET
-
-
-class CIDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "CIDR"
-
-
-PGCidr = CIDR
-
-
-class MACADDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "MACADDR"
-
-
-PGMacAddr = MACADDR
-
-
-class MONEY(sqltypes.TypeEngine[str]):
-
- r"""Provide the PostgreSQL MONEY type.
-
- Depending on driver, result rows using this type may return a
- string value which includes currency symbols.
-
- For this reason, it may be preferable to provide conversion to a
- numerically-based currency datatype using :class:`_types.TypeDecorator`::
-
- import re
- import decimal
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def process_result_value(self, value: Any, dialect: Any) -> None:
- if value is not None:
- # adjust this for the currency and numeric
- m = re.match(r"\$([\d.]+)", value)
- if m:
- value = decimal.Decimal(m.group(1))
- return value
-
- Alternatively, the conversion may be applied as a CAST using
- the :meth:`_types.TypeDecorator.column_expression` method as follows::
-
- import decimal
- from sqlalchemy import cast
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def column_expression(self, column: Any):
- return cast(column, Numeric())
-
- .. versionadded:: 1.2
-
- """
-
- __visit_name__ = "MONEY"
-
-
-class OID(sqltypes.TypeEngine[int]):
-
- """Provide the PostgreSQL OID type.
-
- .. versionadded:: 0.9.5
-
- """
-
- __visit_name__ = "OID"
-
-
-class REGCLASS(sqltypes.TypeEngine[str]):
-
- """Provide the PostgreSQL REGCLASS type.
-
- .. versionadded:: 1.2.7
-
- """
-
- __visit_name__ = "REGCLASS"
-
-
-class TIMESTAMP(sqltypes.TIMESTAMP):
- def __init__(self, timezone=False, precision=None):
- super(TIMESTAMP, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class TIME(sqltypes.TIME):
- def __init__(self, timezone=False, precision=None):
- super(TIME, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
-
- """PostgreSQL INTERVAL type."""
-
- __visit_name__ = "INTERVAL"
- native = True
-
- def __init__(self, precision=None, fields=None):
- """Construct an INTERVAL.
-
- :param precision: optional integer precision value
- :param fields: string fields specifier. allows storage of fields
- to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
- etc.
-
- .. versionadded:: 1.2
-
- """
- self.precision = precision
- self.fields = fields
-
- @classmethod
- def adapt_emulated_to_native(cls, interval, **kw):
- return INTERVAL(precision=interval.second_precision)
-
- @property
- def _type_affinity(self):
- return sqltypes.Interval
-
- def as_generic(self, allow_nulltype=False):
- return sqltypes.Interval(native=True, second_precision=self.precision)
-
- @property
- def python_type(self):
- return dt.timedelta
-
-
-PGInterval = INTERVAL
-
-
-class BIT(sqltypes.TypeEngine[int]):
- __visit_name__ = "BIT"
-
- def __init__(self, length=None, varying=False):
- if not varying:
- # BIT without VARYING defaults to length 1
- self.length = length or 1
- else:
- # but BIT VARYING can be unlimited-length, so no default
- self.length = length
- self.varying = varying
-
-
-PGBit = BIT
-
-
-class TSVECTOR(sqltypes.TypeEngine[Any]):
-
- """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
- text search type TSVECTOR.
-
- It can be used to do full text queries on natural language
- documents.
-
- .. versionadded:: 0.9.0
-
- .. seealso::
-
- :ref:`postgresql_match`
-
- """
-
- __visit_name__ = "TSVECTOR"
-
-
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
-
- """PostgreSQL ENUM type.
-
- This is a subclass of :class:`_types.Enum` which includes
- support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
-
- When the builtin type :class:`_types.Enum` is used and the
- :paramref:`.Enum.native_enum` flag is left at its default of
- True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
- type as the implementation, so the special create/drop rules
- will be used.
-
- The create/drop behavior of ENUM is necessarily intricate, due to the
- awkward relationship the ENUM type has in relationship to the
- parent table, in that it may be "owned" by just a single table, or
- may be shared among many tables.
-
- When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
- in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
- corresponding to when the :meth:`_schema.Table.create` and
- :meth:`_schema.Table.drop`
- methods are called::
-
- table = Table('sometable', metadata,
- Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
- )
-
- table.create(engine) # will emit CREATE ENUM and CREATE TABLE
- table.drop(engine) # will emit DROP TABLE and DROP ENUM
-
- To use a common enumerated type between multiple tables, the best
- practice is to declare the :class:`_types.Enum` or
- :class:`_postgresql.ENUM` independently, and associate it with the
- :class:`_schema.MetaData` object itself::
-
- my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
-
- t1 = Table('sometable_one', metadata,
- Column('some_enum', myenum)
- )
-
- t2 = Table('sometable_two', metadata,
- Column('some_enum', myenum)
- )
-
- When this pattern is used, care must still be taken at the level
- of individual table creates. Emitting CREATE TABLE without also
- specifying ``checkfirst=True`` will still cause issues::
-
- t1.create(engine) # will fail: no such type 'myenum'
-
- If we specify ``checkfirst=True``, the individual table-level create
- operation will check for the ``ENUM`` and create if not exists::
-
- # will check if enum exists, and emit CREATE TYPE if not
- t1.create(engine, checkfirst=True)
-
- When using a metadata-level ENUM type, the type will always be created
- and dropped if either the metadata-wide create/drop is called::
-
- metadata.create_all(engine) # will emit CREATE TYPE
- metadata.drop_all(engine) # will emit DROP TYPE
-
- The type can also be created and dropped directly::
-
- my_enum.create(engine)
- my_enum.drop(engine)
-
- .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
- now behaves more strictly with regards to CREATE/DROP. A metadata-level
- ENUM type will only be created and dropped at the metadata level,
- not the table level, with the exception of
- ``table.create(checkfirst=True)``.
- The ``table.drop()`` call will now emit a DROP TYPE for a table-level
- enumerated type.
-
- """
-
- native_enum = True
-
- def __init__(self, *enums, **kw):
- """Construct an :class:`_postgresql.ENUM`.
-
- Arguments are the same as that of
- :class:`_types.Enum`, but also including
- the following parameters.
-
- :param create_type: Defaults to True.
- Indicates that ``CREATE TYPE`` should be
- emitted, after optionally checking for the
- presence of the type, when the parent
- table is being created; and additionally
- that ``DROP TYPE`` is called when the table
- is dropped. When ``False``, no check
- will be performed and no ``CREATE TYPE``
- or ``DROP TYPE`` is emitted, unless
- :meth:`~.postgresql.ENUM.create`
- or :meth:`~.postgresql.ENUM.drop`
- are called directly.
- Setting to ``False`` is helpful
- when invoking a creation scheme to a SQL file
- without access to the actual database -
- the :meth:`~.postgresql.ENUM.create` and
- :meth:`~.postgresql.ENUM.drop` methods can
- be used to emit SQL to a target bind.
-
- """
- native_enum = kw.pop("native_enum", None)
- if native_enum is False:
- util.warn(
- "the native_enum flag does not apply to the "
- "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
- "always refers to ENUM. Use sqlalchemy.types.Enum for "
- "non-native enum."
- )
- self.create_type = kw.pop("create_type", True)
- super(ENUM, self).__init__(*enums, **kw)
-
- @classmethod
- def adapt_emulated_to_native(cls, impl, **kw):
- """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
- :class:`.Enum`.
-
- """
- kw.setdefault("validate_strings", impl.validate_strings)
- kw.setdefault("name", impl.name)
- kw.setdefault("schema", impl.schema)
- kw.setdefault("inherit_schema", impl.inherit_schema)
- kw.setdefault("metadata", impl.metadata)
- kw.setdefault("_create_events", False)
- kw.setdefault("values_callable", impl.values_callable)
- kw.setdefault("omit_aliases", impl._omit_aliases)
- return cls(**kw)
-
- def create(self, bind=None, checkfirst=True):
- """Emit ``CREATE TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL CREATE TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type does not exist already before
- creating.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
-
- def drop(self, bind=None, checkfirst=True):
- """Emit ``DROP TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL DROP TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type actually exists before dropping.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
-
- class EnumGenerator(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_create_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return not self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_create_enum(enum):
- return
-
- self.connection.execute(CreateEnumType(enum))
-
- class EnumDropper(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_drop_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_drop_enum(enum):
- return
-
- self.connection.execute(DropEnumType(enum))
-
- def get_dbapi_type(self, dbapi):
- """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
- a different type"""
-
- return None
-
- def _check_for_name_in_memos(self, checkfirst, kw):
- """Look in the 'ddl runner' for 'memos', then
- note our name in that collection.
-
- This to ensure a particular named enum is operated
- upon only once within any kind of create/drop
- sequence without relying upon "checkfirst".
-
- """
- if not self.create_type:
- return True
- if "_ddl_runner" in kw:
- ddl_runner = kw["_ddl_runner"]
- if "_pg_enums" in ddl_runner.memo:
- pg_enums = ddl_runner.memo["_pg_enums"]
- else:
- pg_enums = ddl_runner.memo["_pg_enums"] = set()
- present = (self.schema, self.name) in pg_enums
- pg_enums.add((self.schema, self.name))
- return present
- else:
- return False
-
- def _on_table_create(self, target, bind, checkfirst=False, **kw):
- if (
- checkfirst
- or (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- )
- ) and not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_table_drop(self, target, bind, checkfirst=False, **kw):
- if (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- and not self._check_for_name_in_memos(checkfirst, kw)
- ):
- self.drop(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.drop(bind=bind, checkfirst=checkfirst)
-
-
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
@@ -2997,8 +2563,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
class PGInspector(reflection.Inspector):
+ dialect: PGDialect
+
def get_table_oid(self, table_name, schema=None):
- """Return the OID for the given table name."""
+ """Return the OID for the given table name.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
with self._operation_context() as conn:
return self.dialect.get_table_oid(
@@ -3023,9 +2600,10 @@ class PGInspector(reflection.Inspector):
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._load_enums(conn, schema)
+ return self.dialect._load_enums(
+ conn, schema, info_cache=self.info_cache
+ )
def get_foreign_table_names(self, schema=None):
"""Return a list of FOREIGN TABLE names.
@@ -3038,38 +2616,29 @@ class PGInspector(reflection.Inspector):
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._get_foreign_table_names(conn, schema)
-
- def get_view_names(self, schema=None, include=("plain", "materialized")):
- """Return all view names in `schema`.
+ return self.dialect._get_foreign_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
- :param schema: Optional, retrieve names from a non-default schema.
- For special quoting, use :class:`.quoted_name`.
+ def has_type(self, type_name, schema=None, **kw):
+ """Return if the database has the specified type in the provided
+ schema.
- :param include: specify which types of views to return. Passed
- as a string value (for a single type) or a tuple (for any number
- of types). Defaults to ``('plain', 'materialized')``.
+ :param type_name: the type to check.
+ :param schema: schema name. If None, the default schema
+ (typically 'public') is used. May also be set to '*' to
+ check in all schemas.
- .. versionadded:: 1.1
+ .. versionadded:: 2.0
"""
-
with self._operation_context() as conn:
- return self.dialect.get_view_names(
- conn, schema, info_cache=self.info_cache, include=include
+ return self.dialect.has_type(
+ conn, type_name, schema, info_cache=self.info_cache
)
-class CreateEnumType(schema._CreateDropBase):
- __visit_name__ = "create_enum_type"
-
-
-class DropEnumType(schema._CreateDropBase):
- __visit_name__ = "drop_enum_type"
-
-
class PGExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
return self._execute_scalar(
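A minimal sketch of the inspector-level has_type() method documented in the hunk above; the engine URL and the enum name are hypothetical:

from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
insp = inspect(engine)

# default schema lookup, then the '*' wildcard to search all schemas
print(insp.has_type("myenum"))
print(insp.has_type("myenum", schema="*"))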
@@ -3262,35 +2831,14 @@ class PGDialect(default.DefaultDialect):
def initialize(self, connection):
super(PGDialect, self).initialize(connection)
- if self.server_version_info <= (8, 2):
- self.delete_returning = (
- self.update_returning
- ) = self.insert_returning = False
-
- self.supports_native_enum = self.server_version_info >= (8, 3)
- if not self.supports_native_enum:
- self.colspecs = self.colspecs.copy()
- # pop base Enum type
- self.colspecs.pop(sqltypes.Enum, None)
- # psycopg2, others may have placed ENUM here as well
- self.colspecs.pop(ENUM, None)
-
# https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
self.supports_smallserial = self.server_version_info >= (9, 2)
- if self.server_version_info < (8, 2):
- self._backslash_escapes = False
- else:
- # ensure this query is not emitted on server version < 8.2
- # as it will fail
- std_string = connection.exec_driver_sql(
- "show standard_conforming_strings"
- ).scalar()
- self._backslash_escapes = std_string == "off"
-
- self._supports_create_index_concurrently = (
- self.server_version_info >= (8, 2)
- )
+ std_string = connection.exec_driver_sql(
+ "show standard_conforming_strings"
+ ).scalar()
+ self._backslash_escapes = std_string == "off"
+
self._supports_drop_index_concurrently = self.server_version_info >= (
9,
2,
@@ -3370,122 +2918,100 @@ class PGDialect(default.DefaultDialect):
self.do_commit(connection.connection)
def do_recover_twophase(self, connection):
- resultset = connection.execute(
+ return connection.scalars(
sql.text("SELECT gid FROM pg_prepared_xacts")
- )
- return [row[0] for row in resultset]
+ ).all()
def _get_default_schema_name(self, connection):
return connection.exec_driver_sql("select current_schema()").scalar()
- def has_schema(self, connection, schema):
- query = (
- "select nspname from pg_namespace " "where lower(nspname)=:schema"
- )
- cursor = connection.execute(
- sql.text(query).bindparams(
- sql.bindparam(
- "schema",
- str(schema.lower()),
- type_=sqltypes.Unicode,
- )
- )
+ @reflection.cache
+ def has_schema(self, connection, schema, **kw):
+ query = select(pg_catalog.pg_namespace.c.nspname).where(
+ pg_catalog.pg_namespace.c.nspname == schema
)
+ return bool(connection.scalar(query))
- return bool(cursor.first())
-
- def has_table(self, connection, table_name, schema=None):
- self._ensure_has_table_connection(connection)
- # seems like case gets folded in pg_class...
+ def _pg_class_filter_scope_schema(
+ self, query, schema, scope, pg_class_table=None
+ ):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ query = query.join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace,
+ )
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(pg_class_table.c.relpersistence != "t")
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(pg_class_table.c.relpersistence == "t")
if schema is None:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where "
- "pg_catalog.pg_table_is_visible(c.oid) "
- "and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- )
- )
+ query = query.where(
+ pg_catalog.pg_table_is_visible(pg_class_table.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
)
else:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where n.nspname=:schema and "
- "relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
- )
- return bool(cursor.first())
-
- def has_sequence(self, connection, sequence_name, schema=None):
- if schema is None:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(sequence_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ def _pg_class_relkind_condition(self, relkinds, pg_class_table=None):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ # uses the any form instead of in, otherwise postgresql complains
+ # that 'IN could not convert type character to "char"'
+ return pg_class_table.c.relkind == sql.any_(_array.array(relkinds))
+
+ @lru_cache()
+ def _has_table_query(self, schema):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relname == bindparam("table_name"),
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
+ )
+ return self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
)
- return bool(cursor.first())
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
+ self._ensure_has_table_connection(connection)
+ query = self._has_table_query(schema)
+ return bool(connection.scalar(query, {"table_name": table_name}))
- def has_type(self, connection, type_name, schema=None):
- if schema is not None:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
- WHERE t.typnamespace = n.oid
- AND t.typname = :typname
- AND n.nspname = :nspname
- )
- """
- query = sql.text(query)
- else:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t
- WHERE t.typname = :typname
- AND pg_type_is_visible(t.oid)
- )
- """
- query = sql.text(query)
- query = query.bindparams(
- sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode)
+ @reflection.cache
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relkind == "S",
+ pg_catalog.pg_class.c.relname == sequence_name,
)
- if schema is not None:
- query = query.bindparams(
- sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode)
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ return bool(connection.scalar(query))
+
+ @reflection.cache
+ def has_type(self, connection, type_name, schema=None, **kw):
+ query = (
+ select(pg_catalog.pg_type.c.typname)
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
)
- cursor = connection.execute(query)
- return bool(cursor.scalar())
+ .where(pg_catalog.pg_type.c.typname == type_name)
+ )
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+
+ return bool(connection.scalar(query))
def _get_server_version_info(self, connection):
v = connection.exec_driver_sql("select pg_catalog.version()").scalar()
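The has_table() / has_sequence() checks above now run cached queries against pg_catalog rather than hand-written SQL strings; a minimal sketch of exercising them through the inspector, with a hypothetical engine URL and object names:

from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
insp = inspect(engine)

# both calls are served by the pg_catalog-based queries shown above
print(insp.has_table("my_table"))
print(insp.has_sequence("my_seq", schema="public"))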
@@ -3502,229 +3028,300 @@ class PGDialect(default.DefaultDialect):
@reflection.cache
def get_table_oid(self, connection, table_name, schema=None, **kw):
- """Fetch the oid for schema.table_name.
-
- Several reflection methods require the table oid. The idea for using
- this method is that it can be fetched one time and cached for
- subsequent calls.
-
- """
- table_oid = None
- if schema is not None:
- schema_where_clause = "n.nspname = :schema"
- else:
- schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
- query = (
- """
- SELECT c.oid
- FROM pg_catalog.pg_class c
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
- WHERE (%s)
- AND c.relname = :table_name AND c.relkind in
- ('r', 'v', 'm', 'f', 'p')
- """
- % schema_where_clause
+ """Fetch the oid for schema.table_name."""
+ query = select(pg_catalog.pg_class.c.oid).where(
+ pg_catalog.pg_class.c.relname == table_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
)
- # Since we're binding to unicode, table_name and schema_name must be
- # unicode.
- table_name = str(table_name)
- if schema is not None:
- schema = str(schema)
- s = sql.text(query).bindparams(table_name=sqltypes.Unicode)
- s = s.columns(oid=sqltypes.Integer)
- if schema:
- s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode))
- c = connection.execute(s, dict(table_name=table_name, schema=schema))
- table_oid = c.scalar()
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ table_oid = connection.scalar(query)
if table_oid is None:
- raise exc.NoSuchTableError(table_name)
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
return table_oid
@reflection.cache
def get_schema_names(self, connection, **kw):
- result = connection.execute(
- sql.text(
- "SELECT nspname FROM pg_namespace "
- "WHERE nspname NOT LIKE 'pg_%' "
- "ORDER BY nspname"
- ).columns(nspname=sqltypes.Unicode)
+ query = (
+ select(pg_catalog.pg_namespace.c.nspname)
+ .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%"))
+ .order_by(pg_catalog.pg_namespace.c.nspname)
+ )
+ return connection.scalars(query).all()
+
+ def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ self._pg_class_relkind_condition(relkinds)
)
- return [name for name, in result]
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ return connection.scalars(query).all()
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.DEFAULT,
+ )
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema=None,
+ relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def _get_foreign_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind = 'f'"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("f",), scope=ObjectScope.ANY
)
- return [name for name, in result]
@reflection.cache
- def get_view_names(
- self, connection, schema=None, include=("plain", "materialized"), **kw
- ):
+ def get_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- include_kind = {"plain": "v", "materialized": "m"}
- try:
- kinds = [include_kind[i] for i in util.to_list(include)]
- except KeyError:
- raise ValueError(
- "include %r unknown, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'" % (include,)
- )
- if not kinds:
- raise ValueError(
- "empty include, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'"
- )
+ @reflection.cache
+ def get_materialized_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_MAT_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind IN (%s)"
- % (", ".join("'%s'" % elem for elem in kinds))
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ @reflection.cache
+ def get_temp_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ # NOTE: do not include temp materialized views (they do not
+ # seem to be a thing, at least up to version 14)
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def get_sequence_names(self, connection, schema=None, **kw):
- if not schema:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema"
- ).bindparams(
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("S",), scope=ObjectScope.ANY
)
- return [row[0] for row in cursor]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
- view_def = connection.scalar(
- sql.text(
- "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relname = :view_name "
- "AND c.relkind IN ('v', 'm')"
- ).columns(view_def=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name,
- view_name=view_name,
- ),
+ query = (
+ select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid))
+ .select_from(pg_catalog.pg_class)
+ .where(
+ pg_catalog.pg_class.c.relname == view_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW
+ ),
+ )
)
- return view_def
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ res = connection.scalar(query)
+ if res is None:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
+ else:
+ return res
+
+ def _value_or_raise(self, data, table, schema):
+ try:
+ return dict(data)[(schema, table)]
+ except KeyError:
+ raise exc.NoSuchTableError(
+ f"{schema}.{table}" if schema else table
+ ) from None
+
+ def _prepare_filter_names(self, filter_names):
+ if filter_names:
+ return True, {"filter_names": filter_names}
+ else:
+ return False, {}
+
+ def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]:
+ if kind is ObjectKind.ANY:
+ return pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ relkinds = ()
+ if ObjectKind.TABLE in kind:
+ relkinds += pg_catalog.RELKINDS_TABLE
+ if ObjectKind.VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_VIEW
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_MAT_VIEW
+ return relkinds
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
-
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_columns(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
+ @lru_cache()
+ def _columns_query(self, schema, has_filter_names, scope, kind):
+ # NOTE: the query with the default and identity options scalar
+ # subquery is faster than trying to use outer joins for them
generated = (
- "a.attgenerated as generated"
+ pg_catalog.pg_attribute.c.attgenerated.label("generated")
if self.server_version_info >= (12,)
- else "NULL as generated"
+ else sql.null().label("generated")
)
if self.server_version_info >= (10,):
- # a.attidentity != '' is required or it will reflect also
- # serial columns as identity.
- identity = """\
- (SELECT json_build_object(
- 'always', a.attidentity = 'a',
- 'start', s.seqstart,
- 'increment', s.seqincrement,
- 'minvalue', s.seqmin,
- 'maxvalue', s.seqmax,
- 'cache', s.seqcache,
- 'cycle', s.seqcycle)
- FROM pg_catalog.pg_sequence s
- JOIN pg_catalog.pg_class c on s.seqrelid = c."oid"
- WHERE c.relkind = 'S'
- AND a.attidentity != ''
- AND s.seqrelid = pg_catalog.pg_get_serial_sequence(
- a.attrelid::regclass::text, a.attname
- )::regclass::oid
- ) as identity_options\
- """
+ # join lateral performs worse (~2x slower) than a scalar_subquery
+ identity = (
+ select(
+ sql.func.json_build_object(
+ "always",
+ pg_catalog.pg_attribute.c.attidentity == "a",
+ "start",
+ pg_catalog.pg_sequence.c.seqstart,
+ "increment",
+ pg_catalog.pg_sequence.c.seqincrement,
+ "minvalue",
+ pg_catalog.pg_sequence.c.seqmin,
+ "maxvalue",
+ pg_catalog.pg_sequence.c.seqmax,
+ "cache",
+ pg_catalog.pg_sequence.c.seqcache,
+ "cycle",
+ pg_catalog.pg_sequence.c.seqcycle,
+ )
+ )
+ .select_from(pg_catalog.pg_sequence)
+ .where(
+ # attidentity != '' is required, or serial columns would
+ # also be reflected as identity columns.
+ pg_catalog.pg_attribute.c.attidentity != "",
+ pg_catalog.pg_sequence.c.seqrelid
+ == sql.cast(
+ sql.cast(
+ pg_catalog.pg_get_serial_sequence(
+ sql.cast(
+ sql.cast(
+ pg_catalog.pg_attribute.c.attrelid,
+ REGCLASS,
+ ),
+ TEXT,
+ ),
+ pg_catalog.pg_attribute.c.attname,
+ ),
+ REGCLASS,
+ ),
+ OID,
+ ),
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("identity_options")
+ )
else:
- identity = "NULL as identity_options"
-
- SQL_COLS = """
- SELECT a.attname,
- pg_catalog.format_type(a.atttypid, a.atttypmod),
- (
- SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid)
- FROM pg_catalog.pg_attrdef d
- WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum
- AND a.atthasdef
- ) AS DEFAULT,
- a.attnotnull,
- a.attrelid as table_oid,
- pgd.description as comment,
- %s,
- %s
- FROM pg_catalog.pg_attribute a
- LEFT JOIN pg_catalog.pg_description pgd ON (
- pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum)
- WHERE a.attrelid = :table_oid
- AND a.attnum > 0 AND NOT a.attisdropped
- ORDER BY a.attnum
- """ % (
- generated,
- identity,
+ identity = sql.null().label("identity_options")
+
+ # join lateral performs the same as scalar_subquery here
+ default = (
+ select(
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_attrdef.c.adbin,
+ pg_catalog.pg_attrdef.c.adrelid,
+ )
+ )
+ .select_from(pg_catalog.pg_attrdef)
+ .where(
+ pg_catalog.pg_attrdef.c.adrelid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attrdef.c.adnum
+ == pg_catalog.pg_attribute.c.attnum,
+ pg_catalog.pg_attribute.c.atthasdef,
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("default")
)
- s = (
- sql.text(SQL_COLS)
- .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
- .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_attribute.c.attname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_attribute.c.atttypid,
+ pg_catalog.pg_attribute.c.atttypmod,
+ ).label("format_type"),
+ default,
+ pg_catalog.pg_attribute.c.attnotnull.label("not_null"),
+ pg_catalog.pg_class.c.relname.label("table_name"),
+ pg_catalog.pg_description.c.description.label("comment"),
+ generated,
+ identity,
+ )
+ .select_from(pg_catalog.pg_class)
+ # NOTE: postgresql supports tables with no user columns, meaning
+ # there is no row with pg_attribute.attnum > 0. use a left outer
+ # join to avoid filtering out these tables.
+ .outerjoin(
+ pg_catalog.pg_attribute,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attribute.c.attnum > 0,
+ ~pg_catalog.pg_attribute.c.attisdropped,
+ ),
+ )
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_description.c.objoid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_description.c.objsubid
+ == pg_catalog.pg_attribute.c.attnum,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ .order_by(
+ pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum
+ )
)
- c = connection.execute(s, dict(table_oid=table_oid))
- rows = c.fetchall()
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ def get_multi_columns(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._columns_query(schema, has_filter_names, scope, kind)
+ rows = connection.execute(query, params).mappings()
# dictionary with (name, ) if default search path or (schema, name)
# as keys
- domains = self._load_domains(connection)
+ domains = self._load_domains(
+ connection, info_cache=kw.get("info_cache")
+ )
# dictionary with (name, ) if default search path or (schema, name)
# as keys
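get_columns() above now delegates to the batched get_multi_columns() path; a minimal sketch of the per-table entry point that exercises it, with a hypothetical engine URL and table name:

from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
insp = inspect(engine)

# internally this invokes the dialect's get_multi_columns() with
# filter_names=["my_table"] and returns that single table's column dicts
for col in insp.get_columns("my_table", schema="public"):
    print(col["name"], col["type"], col["nullable"])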
@@ -3732,257 +3329,340 @@ class PGDialect(default.DefaultDialect):
((rec["name"],), rec)
if rec["visible"]
else ((rec["schema"], rec["name"]), rec)
- for rec in self._load_enums(connection, schema="*")
+ for rec in self._load_enums(
+ connection, schema="*", info_cache=kw.get("info_cache")
+ )
)
- # format columns
- columns = []
-
- for (
- name,
- format_type,
- default_,
- notnull,
- table_oid,
- comment,
- generated,
- identity,
- ) in rows:
- column_info = self._get_column_info(
- name,
- format_type,
- default_,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- )
- columns.append(column_info)
- return columns
+ columns = self._get_columns_info(rows, domains, enums, schema)
+
+ return columns.items()
+
+ def _get_columns_info(self, rows, domains, enums, schema):
+ array_type_pattern = re.compile(r"\[\]$")
+ attype_pattern = re.compile(r"\(.*\)")
+ charlen_pattern = re.compile(r"\(([\d,]+)\)")
+ args_pattern = re.compile(r"\((.*)\)")
+ args_split_pattern = re.compile(r"\s*,\s*")
- def _get_column_info(
- self,
- name,
- format_type,
- default,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- ):
def _handle_array_type(attype):
return (
# strip '[]' from integer[], etc.
- re.sub(r"\[\]$", "", attype),
+ array_type_pattern.sub("", attype),
attype.endswith("[]"),
)
- # strip (*) from character varying(5), timestamp(5)
- # with time zone, geometry(POLYGON), etc.
- attype = re.sub(r"\(.*\)", "", format_type)
+ columns = defaultdict(list)
+ for row_dict in rows:
+ # ensure that each table has an entry, even if it has no columns
+ if row_dict["name"] is None:
+ columns[
+ (schema, row_dict["table_name"])
+ ] = ReflectionDefaults.columns()
+ continue
+ table_cols = columns[(schema, row_dict["table_name"])]
- # strip '[]' from integer[], etc. and check if an array
- attype, is_array = _handle_array_type(attype)
+ format_type = row_dict["format_type"]
+ default = row_dict["default"]
+ name = row_dict["name"]
+ generated = row_dict["generated"]
+ identity = row_dict["identity_options"]
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+ # strip (*) from character varying(5), timestamp(5)
+ # with time zone, geometry(POLYGON), etc.
+ attype = attype_pattern.sub("", format_type)
- nullable = not notnull
+ # strip '[]' from integer[], etc. and check if an array
+ attype, is_array = _handle_array_type(attype)
- charlen = re.search(r"\(([\d,]+)\)", format_type)
- if charlen:
- charlen = charlen.group(1)
- args = re.search(r"\((.*)\)", format_type)
- if args and args.group(1):
- args = tuple(re.split(r"\s*,\s*", args.group(1)))
- else:
- args = ()
- kwargs = {}
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+ nullable = not row_dict["not_null"]
- if attype == "numeric":
+ charlen = charlen_pattern.search(format_type)
if charlen:
- prec, scale = charlen.split(",")
- args = (int(prec), int(scale))
+ charlen = charlen.group(1)
+ args = args_pattern.search(format_type)
+ if args and args.group(1):
+ args = tuple(args_split_pattern.split(args.group(1)))
else:
args = ()
- elif attype == "double precision":
- args = (53,)
- elif attype == "integer":
- args = ()
- elif attype in ("timestamp with time zone", "time with time zone"):
- kwargs["timezone"] = True
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype in (
- "timestamp without time zone",
- "time without time zone",
- "time",
- ):
- kwargs["timezone"] = False
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype == "bit varying":
- kwargs["varying"] = True
- if charlen:
+ kwargs = {}
+
+ if attype == "numeric":
+ if charlen:
+ prec, scale = charlen.split(",")
+ args = (int(prec), int(scale))
+ else:
+ args = ()
+ elif attype == "double precision":
+ args = (53,)
+ elif attype == "integer":
+ args = ()
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype == "bit varying":
+ kwargs["varying"] = True
+ if charlen:
+ args = (int(charlen),)
+ else:
+ args = ()
+ elif attype.startswith("interval"):
+ field_match = re.match(r"interval (.+)", attype, re.I)
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ if field_match:
+ kwargs["fields"] = field_match.group(1)
+ attype = "interval"
+ args = ()
+ elif charlen:
args = (int(charlen),)
+
+ while True:
+ # looping here to suit nested domains
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
+ break
+ elif enum_or_domain_key in enums:
+ enum = enums[enum_or_domain_key]
+ coltype = ENUM
+ kwargs["name"] = enum["name"]
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
+ break
+ elif enum_or_domain_key in domains:
+ domain = domains[enum_or_domain_key]
+ attype = domain["attype"]
+ attype, is_array = _handle_array_type(attype)
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(
+ util.quoted_token_parser(attype)
+ )
+ # A table can't override a not null on the domain,
+ # but can override nullable
+ nullable = nullable and domain["nullable"]
+ if domain["default"] and not default:
+ # It can, however, override the default
+ # value, but can't set it to null.
+ default = domain["default"]
+ continue
+ else:
+ coltype = None
+ break
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = self.ischema_names["_array"](coltype)
else:
- args = ()
- elif attype.startswith("interval"):
- field_match = re.match(r"interval (.+)", attype, re.I)
- if charlen:
- kwargs["precision"] = int(charlen)
- if field_match:
- kwargs["fields"] = field_match.group(1)
- attype = "interval"
- args = ()
- elif charlen:
- args = (int(charlen),)
-
- while True:
- # looping here to suit nested domains
- if attype in self.ischema_names:
- coltype = self.ischema_names[attype]
- break
- elif enum_or_domain_key in enums:
- enum = enums[enum_or_domain_key]
- coltype = ENUM
- kwargs["name"] = enum["name"]
- if not enum["visible"]:
- kwargs["schema"] = enum["schema"]
- args = tuple(enum["labels"])
- break
- elif enum_or_domain_key in domains:
- domain = domains[enum_or_domain_key]
- attype = domain["attype"]
- attype, is_array = _handle_array_type(attype)
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
- # A table can't override a not null on the domain,
- # but can override nullable
- nullable = nullable and domain["nullable"]
- if domain["default"] and not default:
- # It can, however, override the default
- # value, but can't set it to null.
- default = domain["default"]
- continue
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (attype, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ # If this is a zero byte or blank string (depending on driver;
+ # it is also absent for older PG versions), the column is not a
+ # generated column. Otherwise, "s" means stored. (Other values
+ # might be added in the future.)
+ if generated not in (None, "", b"\x00"):
+ computed = dict(
+ sqltext=default, persisted=generated in ("s", b"s")
+ )
+ default = None
else:
- coltype = None
- break
+ computed = None
- if coltype:
- coltype = coltype(*args, **kwargs)
- if is_array:
- coltype = self.ischema_names["_array"](coltype)
- else:
- util.warn(
- "Did not recognize type '%s' of column '%s'" % (attype, name)
+ # adjust the default value
+ autoincrement = False
+ if default is not None:
+ match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+ if match is not None:
+ if issubclass(coltype._type_affinity, sqltypes.Integer):
+ autoincrement = True
+ # the default is related to a Sequence
+ if "." not in match.group(2) and schema is not None:
+ # unconditionally quote the schema name. this could
+ # later be enhanced to obey quoting rules /
+ # "quote schema"
+ default = (
+ match.group(1)
+ + ('"%s"' % schema)
+ + "."
+ + match.group(2)
+ + match.group(3)
+ )
+
+ column_info = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": autoincrement or identity is not None,
+ "comment": row_dict["comment"],
+ }
+ if computed is not None:
+ column_info["computed"] = computed
+ if identity is not None:
+ column_info["identity"] = identity
+
+ table_cols.append(column_info)
+
+ return columns
+
+ @lru_cache()
+ def _table_oids_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ oid_q = select(
+ pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname
+ ).where(self._pg_class_relkind_condition(relkinds))
+ oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope)
+
+ if has_filter_names:
+ oid_q = oid_q.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
)
- coltype = sqltypes.NULLTYPE
-
- # If a zero byte or blank string depending on driver (is also absent
- # for older PG versions), then not a generated column. Otherwise, s =
- # stored. (Other values might be added in the future.)
- if generated not in (None, "", b"\x00"):
- computed = dict(
- sqltext=default, persisted=generated in ("s", b"s")
+ return oid_q
+
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("kind", InternalTraversal.dp_plain_obj),
+ ("scope", InternalTraversal.dp_plain_obj),
+ )
+ def _get_table_oids(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ oid_q = self._table_oids_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(oid_q, params)
+ return result.all()
+
+ @util.memoized_property
+ def _constraint_query(self):
+ con_sq = (
+ select(
+ pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.conname,
+ sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
+ "attnum"
+ ),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_constraint.c.conkey, 1
+ ).label("ord"),
)
- default = None
- else:
- computed = None
-
- # adjust the default value
- autoincrement = False
- if default is not None:
- match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
- if match is not None:
- if issubclass(coltype._type_affinity, sqltypes.Integer):
- autoincrement = True
- # the default is related to a Sequence
- sch = schema
- if "." not in match.group(2) and sch is not None:
- # unconditionally quote the schema name. this could
- # later be enhanced to obey quoting rules /
- # "quote schema"
- default = (
- match.group(1)
- + ('"%s"' % sch)
- + "."
- + match.group(2)
- + match.group(3)
- )
+ .where(
+ pg_catalog.pg_constraint.c.contype == bindparam("contype"),
+ pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")),
+ )
+ .subquery("con")
+ )
- column_info = dict(
- name=name,
- type=coltype,
- nullable=nullable,
- default=default,
- autoincrement=autoincrement or identity is not None,
- comment=comment,
+ attr_sq = (
+ select(
+ con_sq.c.conrelid,
+ con_sq.c.conname,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ con_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid,
+ ),
+ )
+ .order_by(con_sq.c.conname, con_sq.c.ord)
+ .subquery("attr")
)
- if computed is not None:
- column_info["computed"] = computed
- if identity is not None:
- column_info["identity"] = identity
- return column_info
- @reflection.cache
- def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ return (
+ select(
+ attr_sq.c.conrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ attr_sq.c.conname,
+ )
+ .group_by(attr_sq.c.conrelid, attr_sq.c.conname)
+ .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
)
- if self.server_version_info < (8, 4):
- PK_SQL = """
- SELECT a.attname
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_attribute a
- on t.oid=a.attrelid AND %s
- WHERE
- t.oid = :table_oid and ix.indisprimary = 't'
- ORDER BY a.attnum
- """ % self._pg_index_any(
- "a.attnum", "ix.indkey"
+ def _reflect_constraint(
+ self, connection, contype, schema, filter_names, scope, kind, **kw
+ ):
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
+ batches = list(table_oids)
+
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._constraint_query,
+ {"oids": [r[0] for r in batch], "contype": contype},
)
- else:
- # unnest() and generate_subscripts() both introduced in
- # version 8.4
- PK_SQL = """
- SELECT a.attname
- FROM pg_attribute a JOIN (
- SELECT unnest(ix.indkey) attnum,
- generate_subscripts(ix.indkey, 1) ord
- FROM pg_index ix
- WHERE ix.indrelid = :table_oid AND ix.indisprimary
- ) k ON a.attnum=k.attnum
- WHERE a.attrelid = :table_oid
- ORDER BY k.ord
- """
- t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- cols = [r[0] for r in c.fetchall()]
-
- PK_CONS_SQL = """
- SELECT conname
- FROM pg_catalog.pg_constraint r
- WHERE r.conrelid = :table_oid AND r.contype = 'p'
- ORDER BY 1
- """
- t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- name = c.scalar()
+ result_by_oid = defaultdict(list)
+ for oid, cols, constraint_name in result:
+ result_by_oid[oid].append((cols, constraint_name))
+
+ for oid, tablename in batch:
+ for_oid = result_by_oid.get(oid, ())
+ if for_oid:
+ for cols, constraint in for_oid:
+ yield tablename, cols, constraint
+ else:
+ yield tablename, None, None
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_pk_constraint(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ def get_multi_pk_constraint(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ result = self._reflect_constraint(
+ connection, "p", schema, filter_names, scope, kind, **kw
+ )
- return {"constrained_columns": cols, "name": name}
+ # only a single pk can be present for each table. Return an entry
+ # even if a table has no primary key
+ default = ReflectionDefaults.pk_constraint
+ return (
+ (
+ (schema, table_name),
+ {
+ "constrained_columns": [] if cols is None else cols,
+ "name": pk_name,
+ }
+ if pk_name is not None
+ else default(),
+ )
+ for (table_name, cols, pk_name) in result
+ )
@reflection.cache
def get_foreign_keys(
@@ -3993,27 +3673,71 @@ class PGDialect(default.DefaultDialect):
postgresql_ignore_search_path=False,
**kw,
):
- preparer = self.identifier_preparer
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_foreign_keys(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ postgresql_ignore_search_path=postgresql_ignore_search_path,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- FK_SQL = """
- SELECT r.conname,
- pg_catalog.pg_get_constraintdef(r.oid, true) as condef,
- n.nspname as conschema
- FROM pg_catalog.pg_constraint r,
- pg_namespace n,
- pg_class c
-
- WHERE r.conrelid = :table AND
- r.contype = 'f' AND
- c.oid = confrelid AND
- n.oid = c.relnamespace
- ORDER BY 1
- """
- # https://www.postgresql.org/docs/9.0/static/sql-createtable.html
- FK_REGEX = re.compile(
+ @lru_cache()
+ def _foreign_key_query(self, schema, has_filter_names, scope, kind):
+ pg_class_ref = pg_catalog.pg_class.alias("cls_ref")
+ pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref")
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid, True
+ ),
+ ),
+ else_=None,
+ ),
+ pg_namespace_ref.c.nspname,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "f",
+ ),
+ )
+ .outerjoin(
+ pg_class_ref,
+ pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid,
+ )
+ .outerjoin(
+ pg_namespace_ref,
+ pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid,
+ )
+ .order_by(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ @util.memoized_property
+ def _fk_regex_pattern(self):
+ # https://www.postgresql.org/docs/14.0/static/sql-createtable.html
+ return re.compile(
r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
r"[\s]?(ON UPDATE "
@@ -4024,12 +3748,33 @@ class PGDialect(default.DefaultDialect):
r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
)
- t = sql.text(FK_SQL).columns(
- conname=sqltypes.Unicode, condef=sqltypes.Unicode
- )
- c = connection.execute(t, dict(table=table_oid))
- fkeys = []
- for conname, condef, conschema in c.fetchall():
+ def get_multi_foreign_keys(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ postgresql_ignore_search_path=False,
+ **kw,
+ ):
+ preparer = self.identifier_preparer
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._foreing_key_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ FK_REGEX = self._fk_regex_pattern
+
+ fkeys = defaultdict(list)
+ default = ReflectionDefaults.foreign_keys
+ for table_name, conname, condef, conschema in result:
+ # ensure that each table has an entry, even if it has
+ # no foreign keys
+ if conname is None:
+ fkeys[(schema, table_name)] = default()
+ continue
+ table_fks = fkeys[(schema, table_name)]
m = re.search(FK_REGEX, condef).groups()
(
@@ -4096,317 +3841,406 @@ class PGDialect(default.DefaultDialect):
"referred_columns": referred_columns,
"options": options,
}
- fkeys.append(fkey_d)
- return fkeys
-
- def _pg_index_any(self, col, compare_to):
- if self.server_version_info < (8, 1):
- # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us
- # "In CVS tip you could replace this with "attnum = ANY (indkey)".
- # Unfortunately, most array support doesn't work on int2vector in
- # pre-8.1 releases, so I think you're kinda stuck with the above
- # for now.
- # regards, tom lane"
- return "(%s)" % " OR ".join(
- "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10)
- )
- else:
- return "%s = ANY(%s)" % (col, compare_to)
+ table_fks.append(fkey_d)
+ return fkeys.items()
@reflection.cache
- def get_indexes(self, connection, table_name, schema, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_indexes(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- # cast indkey as varchar since it's an int2vector,
- # returned as a list by some drivers such as pypostgresql
-
- if self.server_version_info < (8, 5):
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs, ix.indpred,
- a.attname, a.attnum, NULL, ix.indkey%s,
- %s, %s, am.amname,
- NULL as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and %s
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- # version 8.3 here was based on observing the
- # cast does not work in PG 8.2.4, does work in 8.3.0.
- # nothing in PG changelogs regarding this.
- "::varchar" if self.server_version_info >= (8, 3) else "",
- "ix.indoption::varchar"
- if self.server_version_info >= (8, 3)
- else "NULL",
- "i.reloptions"
- if self.server_version_info >= (8, 2)
- else "NULL",
- self._pg_index_any("a.attnum", "ix.indkey"),
+ @util.memoized_property
+ def _index_query(self):
+ pg_class_index = pg_catalog.pg_class.alias("cls_idx")
+ # NOTE: repeating the oids clause improves query performance
+
+ # subquery to get the columns
+ idx_sq = (
+ select(
+ pg_catalog.pg_index.c.indexrelid,
+ pg_catalog.pg_index.c.indrelid,
+ sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_index.c.indkey, 1
+ ).label("ord"),
)
- else:
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs,
- a.attname, a.attnum, c.conrelid, ix.indkey::varchar,
- ix.indoption::varchar, i.reloptions, am.amname,
- pg_get_expr(ix.indpred, ix.indrelid),
- %s as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and a.attnum = ANY(ix.indkey)
- left outer join
- pg_constraint c
- on (ix.indrelid = c.conrelid and
- ix.indexrelid = c.conindid and
- c.contype in ('p', 'u', 'x'))
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm', 'p')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- "ix.indnkeyatts"
- if self.server_version_info >= (11, 0)
- else "NULL",
+ .where(
+ ~pg_catalog.pg_index.c.indisprimary,
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
)
+ .subquery("idx")
+ )
- t = sql.text(IDX_SQL).columns(
- relname=sqltypes.Unicode, attname=sqltypes.Unicode
+ attr_sq = (
+ select(
+ idx_sq.c.indexrelid,
+ idx_sq.c.indrelid,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ idx_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid,
+ ),
+ )
+ .where(idx_sq.c.indrelid.in_(bindparam("oids")))
+ .order_by(idx_sq.c.indexrelid, idx_sq.c.ord)
+ .subquery("idx_attr")
)
- c = connection.execute(t, dict(table_oid=table_oid))
- indexes = defaultdict(lambda: defaultdict(dict))
+ cols_sq = (
+ select(
+ attr_sq.c.indexrelid,
+ attr_sq.c.indrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ )
+ .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid)
+ .subquery("idx_cols")
+ )
- sv_idx_name = None
- for row in c.fetchall():
- (
- idx_name,
- unique,
- expr,
- col,
- col_num,
- conrelid,
- idx_key,
- idx_option,
- options,
- amname,
- filter_definition,
- indnkeyatts,
- ) = row
+ if self.server_version_info >= (11, 0):
+ indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
+ else:
+ indnkeyatts = sql.null().label("indnkeyatts")
- if expr:
- if idx_name != sv_idx_name:
- util.warn(
- "Skipped unsupported reflection of "
- "expression-based index %s" % idx_name
- )
- sv_idx_name = idx_name
- continue
+ query = (
+ select(
+ pg_catalog.pg_index.c.indrelid,
+ pg_class_index.c.relname.label("relname_index"),
+ pg_catalog.pg_index.c.indisunique,
+ pg_catalog.pg_index.c.indexprs,
+ pg_catalog.pg_constraint.c.conrelid.is_not(None).label(
+ "has_constraint"
+ ),
+ pg_catalog.pg_index.c.indoption,
+ pg_class_index.c.reloptions,
+ pg_catalog.pg_am.c.amname,
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_index.c.indpred,
+ pg_catalog.pg_index.c.indrelid,
+ ).label("filter_definition"),
+ indnkeyatts,
+ cols_sq.c.cols.label("index_cols"),
+ )
+ .select_from(pg_catalog.pg_index)
+ .where(
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
+ ~pg_catalog.pg_index.c.indisprimary,
+ )
+ .join(
+ pg_class_index,
+ pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid,
+ )
+ .join(
+ pg_catalog.pg_am,
+ pg_class_index.c.relam == pg_catalog.pg_am.c.oid,
+ )
+ .outerjoin(
+ cols_sq,
+ pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid,
+ )
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_index.c.indrelid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_index.c.indexrelid
+ == pg_catalog.pg_constraint.c.conindid,
+ pg_catalog.pg_constraint.c.contype
+ == sql.any_(_array.array(("p", "u", "x"))),
+ ),
+ )
+ .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname)
+ )
+ return query
- has_idx = idx_name in indexes
- index = indexes[idx_name]
- if col is not None:
- index["cols"][col_num] = col
- if not has_idx:
- idx_keys = idx_key.split()
- # "The number of key columns in the index, not counting any
- # included columns, which are merely stored and do not
- # participate in the index semantics"
- if indnkeyatts and idx_keys[indnkeyatts:]:
- # this is a "covering index" which has INCLUDE columns
- # as well as regular index columns
- inc_keys = idx_keys[indnkeyatts:]
- idx_keys = idx_keys[:indnkeyatts]
- else:
- inc_keys = []
+ def get_multi_indexes(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
- index["key"] = [int(k.strip()) for k in idx_keys]
- index["inc"] = [int(k.strip()) for k in inc_keys]
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
- # (new in pg 8.3)
- # "pg_index.indoption" is list of ints, one per column/expr.
- # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST
- sorting = {}
- for col_idx, col_flags in enumerate(
- (idx_option or "").split()
- ):
- col_flags = int(col_flags.strip())
- col_sorting = ()
- # try to set flags only if they differ from PG defaults...
- if col_flags & 0x01:
- col_sorting += ("desc",)
- if not (col_flags & 0x02):
- col_sorting += ("nulls_last",)
+ indexes = defaultdict(list)
+ default = ReflectionDefaults.indexes
+
+ batches = list(table_oids)
+
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._index_query, {"oids": [r[0] for r in batch]}
+ ).mappings()
+
+ result_by_oid = defaultdict(list)
+ for row_dict in result:
+ result_by_oid[row_dict["indrelid"]].append(row_dict)
+
+ for oid, table_name in batch:
+ if oid not in result_by_oid:
+ # ensure that each table has an entry, even if reflection
+ # is skipped because it is not supported
+ indexes[(schema, table_name)] = default()
+ continue
+
+ for row in result_by_oid[oid]:
+ index_name = row["relname_index"]
+
+ table_indexes = indexes[(schema, table_name)]
+
+ if row["indexprs"]:
+ tn = (
+ table_name
+ if schema is None
+ else f"{schema}.{table_name}"
+ )
+ util.warn(
+ "Skipped unsupported reflection of "
+ f"expression-based index {index_name} of "
+ f"table {tn}"
+ )
+ continue
+
+ all_cols = row["index_cols"]
+ indnkeyatts = row["indnkeyatts"]
+ # "The number of key columns in the index, not counting any
+ # included columns, which are merely stored and do not
+ # participate in the index semantics"
+ if indnkeyatts and all_cols[indnkeyatts:]:
+ # this is a "covering index" which has INCLUDE columns
+ # as well as regular index columns
+ inc_cols = all_cols[indnkeyatts:]
+ idx_cols = all_cols[:indnkeyatts]
else:
- if col_flags & 0x02:
- col_sorting += ("nulls_first",)
- if col_sorting:
- sorting[col_idx] = col_sorting
- if sorting:
- index["sorting"] = sorting
-
- index["unique"] = unique
- if conrelid is not None:
- index["duplicates_constraint"] = idx_name
- if options:
- index["options"] = dict(
- [option.split("=") for option in options]
- )
-
- # it *might* be nice to include that this is 'btree' in the
- # reflection info. But we don't want an Index object
- # to have a ``postgresql_using`` in it that is just the
- # default, so for the moment leaving this out.
- if amname and amname != "btree":
- index["amname"] = amname
-
- if filter_definition:
- index["postgresql_where"] = filter_definition
+ idx_cols = all_cols
+ inc_cols = []
+
+ index = {
+ "name": index_name,
+ "unique": row["indisunique"],
+ "column_names": idx_cols,
+ }
+
+ sorting = {}
+ for col_index, col_flags in enumerate(row["indoption"]):
+ col_sorting = ()
+ # try to set flags only if they differ from PG
+ # defaults...
+ if col_flags & 0x01:
+ col_sorting += ("desc",)
+ if not (col_flags & 0x02):
+ col_sorting += ("nulls_last",)
+ else:
+ if col_flags & 0x02:
+ col_sorting += ("nulls_first",)
+ if col_sorting:
+ sorting[idx_cols[col_index]] = col_sorting
+ if sorting:
+ index["column_sorting"] = sorting
+ if row["has_constraint"]:
+ index["duplicates_constraint"] = index_name
+
+ dialect_options = {}
+ if row["reloptions"]:
+ dialect_options["postgresql_with"] = dict(
+ [option.split("=") for option in row["reloptions"]]
+ )
+ # it *might* be nice to include that this is 'btree' in the
+ # reflection info. But we don't want an Index object
+ # to have a ``postgresql_using`` in it that is just the
+ # default, so for the moment leaving this out.
+ amname = row["amname"]
+ if amname != "btree":
+ dialect_options["postgresql_using"] = row["amname"]
+ if row["filter_definition"]:
+ dialect_options["postgresql_where"] = row[
+ "filter_definition"
+ ]
+ if self.server_version_info >= (11, 0):
+ # NOTE: this is legacy, this is part of
+ # dialect_options now as of #7382
+ index["include_columns"] = inc_cols
+ dialect_options["postgresql_include"] = inc_cols
+ if dialect_options:
+ index["dialect_options"] = dialect_options
- result = []
- for name, idx in indexes.items():
- entry = {
- "name": name,
- "unique": idx["unique"],
- "column_names": [idx["cols"][i] for i in idx["key"]],
- }
- if self.server_version_info >= (11, 0):
- # NOTE: this is legacy, this is part of dialect_options now
- # as of #7382
- entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]]
- if "duplicates_constraint" in idx:
- entry["duplicates_constraint"] = idx["duplicates_constraint"]
- if "sorting" in idx:
- entry["column_sorting"] = dict(
- (idx["cols"][idx["key"][i]], value)
- for i, value in idx["sorting"].items()
- )
- if "include_columns" in entry:
- entry.setdefault("dialect_options", {})[
- "postgresql_include"
- ] = entry["include_columns"]
- if "options" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_with"
- ] = idx["options"]
- if "amname" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_using"
- ] = idx["amname"]
- if "postgresql_where" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_where"
- ] = idx["postgresql_where"]
- result.append(entry)
- return result
+ table_indexes.append(index)
+ return indexes.items()
@reflection.cache
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_unique_constraints(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- UNIQUE_SQL = """
- SELECT
- cons.conname as name,
- cons.conkey as key,
- a.attnum as col_num,
- a.attname as col_name
- FROM
- pg_catalog.pg_constraint cons
- join pg_attribute a
- on cons.conrelid = a.attrelid AND
- a.attnum = ANY(cons.conkey)
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'u'
- """
-
- t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
+ def get_multi_unique_constraints(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ **kw,
+ ):
+ result = self._reflect_constraint(
+ connection, "u", schema, filter_names, scope, kind, **kw
+ )
- uniques = defaultdict(lambda: defaultdict(dict))
- for row in c.fetchall():
- uc = uniques[row.name]
- uc["key"] = row.key
- uc["cols"][row.col_num] = row.col_name
+ # each table can have multiple unique constraints
+ uniques = defaultdict(list)
+ default = ReflectionDefaults.unique_constraints
+ for (table_name, cols, con_name) in result:
+ # ensure a list is created for each table. leave it empty if
+ # the table has no unique constraint
+ if con_name is None:
+ uniques[(schema, table_name)] = default()
+ continue
- return [
- {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
- for name, uc in uniques.items()
- ]
+ uniques[(schema, table_name)].append(
+ {
+ "column_names": cols,
+ "name": con_name,
+ }
+ )
+ return uniques.items()
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_table_comment(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- COMMENT_SQL = """
- SELECT
- pgd.description as table_comment
- FROM
- pg_catalog.pg_description pgd
- WHERE
- pgd.objsubid = 0 AND
- pgd.objoid = :table_oid
- """
+ @lru_cache()
+ def _comment_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_description.c.description,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_description.c.objoid,
+ pg_catalog.pg_description.c.objsubid == 0,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- c = connection.execute(
- sql.text(COMMENT_SQL), dict(table_oid=table_oid)
+ def get_multi_table_comment(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._comment_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ default = ReflectionDefaults.table_comment
+ return (
+ (
+ (schema, table),
+ {"text": comment} if comment is not None else default(),
+ )
+ for table, comment in result
)
- return {"text": c.scalar()}
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_check_constraints(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- CHECK_SQL = """
- SELECT
- cons.conname as name,
- pg_get_constraintdef(cons.oid) as src
- FROM
- pg_catalog.pg_constraint cons
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'c'
- """
-
- c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
+ @lru_cache()
+ def _check_constraint_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid
+ ),
+ ),
+ else_=None,
+ ),
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "c",
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- ret = []
- for name, src in c:
+ def get_multi_check_constraints(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._check_constraint_query(
+ schema, has_filter_names, scope, kind
+ )
+ result = connection.execute(query, params)
+
+ check_constraints = defaultdict(list)
+ default = ReflectionDefaults.check_constraints
+ for table_name, check_name, src in result:
+ # only two cases for check_name and src: both null or both defined
+ if check_name is None and src is None:
+ check_constraints[(schema, table_name)] = default()
+ continue
# samples:
# "CHECK (((a > 1) AND (a < 5)))"
# "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))"
@@ -4424,84 +4258,118 @@ class PGDialect(default.DefaultDialect):
sqltext = re.compile(
r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL
).sub(r"\1", m.group(1))
- entry = {"name": name, "sqltext": sqltext}
+ entry = {"name": check_name, "sqltext": sqltext}
if m and m.group(2):
entry["dialect_options"] = {"not_valid": True}
- ret.append(entry)
- return ret
-
- def _load_enums(self, connection, schema=None):
- schema = schema or self.default_schema_name
- if not self.supports_native_enum:
- return {}
-
- # Load data types for enums:
- SQL_ENUMS = """
- SELECT t.typname as "name",
- -- no enum defaults in 8.4 at least
- -- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema",
- e.enumlabel as "label"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
- WHERE t.typtype = 'e'
- """
+ check_constraints[(schema, table_name)].append(entry)
+ return check_constraints.items()
- if schema != "*":
- SQL_ENUMS += "AND n.nspname = :schema "
+ @lru_cache()
+ def _enum_query(self, schema):
+ lbl_sq = (
+ select(
+ pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel
+ )
+ .order_by(
+ pg_catalog.pg_enum.c.enumtypid,
+ pg_catalog.pg_enum.c.enumsortorder,
+ )
+ .subquery("lbl")
+ )
- # e.oid gives us label order within an enum
- SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
+ lbl_agg_sq = (
+ select(
+ lbl_sq.c.enumtypid,
+ sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"),
+ )
+ .group_by(lbl_sq.c.enumtypid)
+ .subquery("lbl_agg")
+ )
- s = sql.text(SQL_ENUMS).columns(
- attname=sqltypes.Unicode, label=sqltypes.Unicode
+ query = (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ lbl_agg_sq.c.labels.label("labels"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .outerjoin(
+ lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid
+ )
+ .where(pg_catalog.pg_type.c.typtype == "e")
+ .order_by(
+ pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
+ )
)
- if schema != "*":
- s = s.bindparams(schema=schema)
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ @reflection.cache
+ def _load_enums(self, connection, schema=None, **kw):
+ if not self.supports_native_enum:
+ return []
- c = connection.execute(s)
+ result = connection.execute(self._enum_query(schema))
enums = []
- enum_by_name = {}
- for enum in c.fetchall():
- key = (enum.schema, enum.name)
- if key in enum_by_name:
- enum_by_name[key]["labels"].append(enum.label)
- else:
- enum_by_name[key] = enum_rec = {
- "name": enum.name,
- "schema": enum.schema,
- "visible": enum.visible,
- "labels": [],
+ for name, visible, schema, labels in result:
+ enums.append(
+ {
+ "name": name,
+ "schema": schema,
+ "visible": visible,
+ "labels": [] if labels is None else labels,
}
- if enum.label is not None:
- enum_rec["labels"].append(enum.label)
- enums.append(enum_rec)
+ )
return enums
- def _load_domains(self, connection):
- # Load data types for domains:
- SQL_DOMAINS = """
- SELECT t.typname as "name",
- pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
- not t.typnotnull as "nullable",
- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- WHERE t.typtype = 'd'
- """
+ @util.memoized_property
+ def _domain_query(self):
+ return (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_type.c.typbasetype,
+ pg_catalog.pg_type.c.typtypmod,
+ ).label("attype"),
+ (~pg_catalog.pg_type.c.typnotnull).label("nullable"),
+ pg_catalog.pg_type.c.typdefault.label("default"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .where(pg_catalog.pg_type.c.typtype == "d")
+ )
- s = sql.text(SQL_DOMAINS)
- c = connection.execution_options(future_result=True).execute(s)
+ @reflection.cache
+ def _load_domains(self, connection, **kw):
+ # Load data types for domains:
+ result = connection.execute(self._domain_query)
domains = {}
- for domain in c.mappings():
+ for domain in result.mappings():
domain = domain
# strip (30) from character varying(30)
attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
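The get_multi_* methods above return an iterable of ((schema, table_name), value) pairs rather than a single per-table result; the single-table get_* entry points filter on one name and unwrap that value. A minimal sketch of consuming such pairs, using hypothetical sample data shaped like get_multi_pk_constraint() output (the helper name is illustrative, not part of the changeset):

def collect_multi(results):
    """Fold ((schema, table), value) pairs into a plain dict."""
    return {key: value for key, value in results}

# hypothetical sample shaped like get_multi_pk_constraint() results
sample = [
    (("public", "users"), {"constrained_columns": ["id"], "name": "users_pkey"}),
    (("public", "log"), {"constrained_columns": [], "name": None}),
]
by_table = collect_multi(sample)
assert by_table[("public", "users")]["name"] == "users_pkey"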
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index 6cb97ece4..ce9a3bb6c 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -107,6 +107,8 @@ from .base import PGIdentifierPreparer
from .json import JSON
from .json import JSONB
from .json import JSONPathType
+from .pg_catalog import _SpaceVector
+from .pg_catalog import OIDVECTOR
from ... import exc
from ... import util
from ...engine import processors
@@ -245,6 +247,10 @@ class _PGARRAY(PGARRAY):
render_bind_cast = True
+class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
+ pass
+
+
_server_side_id = util.counter()
@@ -376,6 +382,7 @@ class PGDialect_pg8000(PGDialect):
sqltypes.BigInteger: _PGBigInteger,
sqltypes.Enum: _PGEnum,
sqltypes.ARRAY: _PGARRAY,
+ OIDVECTOR: _PGOIDVECTOR,
},
)
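The _PGOIDVECTOR entry applies the _SpaceVector result processor to OIDVECTOR for pg8000, which returns oidvector/int2vector columns as space-separated strings. A standalone sketch of that conversion (the function name is illustrative, not the dialect's internals):

from typing import List, Optional

def parse_space_vector(value: Optional[str]) -> Optional[List[int]]:
    # same shape of conversion as _SpaceVector.result_processor:
    # "1 3 7" (as returned for indkey) becomes [1, 3, 7]
    if value is None:
        return None
    return [int(part) for part in value.split(" ")]

assert parse_space_vector("1 3 7") == [1, 3, 7]
assert parse_space_vector(None) is None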
diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py
new file mode 100644
index 000000000..a77e7ccf6
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py
@@ -0,0 +1,292 @@
+# postgresql/pg_catalog.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from .array import ARRAY
+from .types import OID
+from .types import REGCLASS
+from ... import Column
+from ... import func
+from ... import MetaData
+from ... import Table
+from ...types import BigInteger
+from ...types import Boolean
+from ...types import CHAR
+from ...types import Float
+from ...types import Integer
+from ...types import SmallInteger
+from ...types import String
+from ...types import Text
+from ...types import TypeDecorator
+
+
+# types
+class NAME(TypeDecorator):
+ impl = String(64, collation="C")
+ cache_ok = True
+
+
+class PG_NODE_TREE(TypeDecorator):
+ impl = Text(collation="C")
+ cache_ok = True
+
+
+class INT2VECTOR(TypeDecorator):
+ impl = ARRAY(SmallInteger)
+ cache_ok = True
+
+
+class OIDVECTOR(TypeDecorator):
+ impl = ARRAY(OID)
+ cache_ok = True
+
+
+class _SpaceVector:
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is None:
+ return value
+ return [int(p) for p in value.split(" ")]
+
+ return process
+
+
+REGPROC = REGCLASS # seems to be an alias
+
+# functions
+_pg_cat = func.pg_catalog
+quote_ident = _pg_cat.quote_ident
+pg_table_is_visible = _pg_cat.pg_table_is_visible
+pg_type_is_visible = _pg_cat.pg_type_is_visible
+pg_get_viewdef = _pg_cat.pg_get_viewdef
+pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence
+format_type = _pg_cat.format_type
+pg_get_expr = _pg_cat.pg_get_expr
+pg_get_constraintdef = _pg_cat.pg_get_constraintdef
+
+# constants
+RELKINDS_TABLE_NO_FOREIGN = ("r", "p")
+RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",)
+RELKINDS_VIEW = ("v",)
+RELKINDS_MAT_VIEW = ("m",)
+RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
+
+# tables
+pg_catalog_meta = MetaData()
+
+pg_namespace = Table(
+ "pg_namespace",
+ pg_catalog_meta,
+ Column("oid", OID),
+ Column("nspname", NAME),
+ Column("nspowner", OID),
+ schema="pg_catalog",
+)
+
+pg_class = Table(
+ "pg_class",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("relname", NAME),
+ Column("relnamespace", OID),
+ Column("reltype", OID),
+ Column("reloftype", OID),
+ Column("relowner", OID),
+ Column("relam", OID),
+ Column("relfilenode", OID),
+ Column("reltablespace", OID),
+ Column("relpages", Integer),
+ Column("reltuples", Float),
+ Column("relallvisible", Integer, info={"server_version": (9, 2)}),
+ Column("reltoastrelid", OID),
+ Column("relhasindex", Boolean),
+ Column("relisshared", Boolean),
+ Column("relpersistence", CHAR, info={"server_version": (9, 1)}),
+ Column("relkind", CHAR),
+ Column("relnatts", SmallInteger),
+ Column("relchecks", SmallInteger),
+ Column("relhasrules", Boolean),
+ Column("relhastriggers", Boolean),
+ Column("relhassubclass", Boolean),
+ Column("relrowsecurity", Boolean),
+ Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}),
+ Column("relispopulated", Boolean, info={"server_version": (9, 3)}),
+ Column("relreplident", CHAR, info={"server_version": (9, 4)}),
+ Column("relispartition", Boolean, info={"server_version": (10,)}),
+ Column("relrewrite", OID, info={"server_version": (11,)}),
+ Column("reloptions", ARRAY(Text)),
+ schema="pg_catalog",
+)
+
+pg_type = Table(
+ "pg_type",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("typname", NAME),
+ Column("typnamespace", OID),
+ Column("typowner", OID),
+ Column("typlen", SmallInteger),
+ Column("typbyval", Boolean),
+ Column("typtype", CHAR),
+ Column("typcategory", CHAR),
+ Column("typispreferred", Boolean),
+ Column("typisdefined", Boolean),
+ Column("typdelim", CHAR),
+ Column("typrelid", OID),
+ Column("typelem", OID),
+ Column("typarray", OID),
+ Column("typinput", REGPROC),
+ Column("typoutput", REGPROC),
+ Column("typreceive", REGPROC),
+ Column("typsend", REGPROC),
+ Column("typmodin", REGPROC),
+ Column("typmodout", REGPROC),
+ Column("typanalyze", REGPROC),
+ Column("typalign", CHAR),
+ Column("typstorage", CHAR),
+ Column("typnotnull", Boolean),
+ Column("typbasetype", OID),
+ Column("typtypmod", Integer),
+ Column("typndims", Integer),
+ Column("typcollation", OID, info={"server_version": (9, 1)}),
+ Column("typdefault", Text),
+ schema="pg_catalog",
+)
+
+pg_index = Table(
+ "pg_index",
+ pg_catalog_meta,
+ Column("indexrelid", OID),
+ Column("indrelid", OID),
+ Column("indnatts", SmallInteger),
+ Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}),
+ Column("indisunique", Boolean),
+ Column("indisprimary", Boolean),
+ Column("indisexclusion", Boolean, info={"server_version": (9, 1)}),
+ Column("indimmediate", Boolean),
+ Column("indisclustered", Boolean),
+ Column("indisvalid", Boolean),
+ Column("indcheckxmin", Boolean),
+ Column("indisready", Boolean),
+ Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3
+ Column("indisreplident", Boolean),
+ Column("indkey", INT2VECTOR),
+ Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1
+ Column("indclass", OIDVECTOR),
+ Column("indoption", INT2VECTOR),
+ Column("indexprs", PG_NODE_TREE),
+ Column("indpred", PG_NODE_TREE),
+ schema="pg_catalog",
+)
+
+pg_attribute = Table(
+ "pg_attribute",
+ pg_catalog_meta,
+ Column("attrelid", OID),
+ Column("attname", NAME),
+ Column("atttypid", OID),
+ Column("attstattarget", Integer),
+ Column("attlen", SmallInteger),
+ Column("attnum", SmallInteger),
+ Column("attndims", Integer),
+ Column("attcacheoff", Integer),
+ Column("atttypmod", Integer),
+ Column("attbyval", Boolean),
+ Column("attstorage", CHAR),
+ Column("attalign", CHAR),
+ Column("attnotnull", Boolean),
+ Column("atthasdef", Boolean),
+ Column("atthasmissing", Boolean, info={"server_version": (11,)}),
+ Column("attidentity", CHAR, info={"server_version": (10,)}),
+ Column("attgenerated", CHAR, info={"server_version": (12,)}),
+ Column("attisdropped", Boolean),
+ Column("attislocal", Boolean),
+ Column("attinhcount", Integer),
+ Column("attcollation", OID, info={"server_version": (9, 1)}),
+ schema="pg_catalog",
+)
+
+pg_constraint = Table(
+ "pg_constraint",
+ pg_catalog_meta,
+ Column("oid", OID), # 9.3
+ Column("conname", NAME),
+ Column("connamespace", OID),
+ Column("contype", CHAR),
+ Column("condeferrable", Boolean),
+ Column("condeferred", Boolean),
+ Column("convalidated", Boolean, info={"server_version": (9, 1)}),
+ Column("conrelid", OID),
+ Column("contypid", OID),
+ Column("conindid", OID),
+ Column("conparentid", OID, info={"server_version": (11,)}),
+ Column("confrelid", OID),
+ Column("confupdtype", CHAR),
+ Column("confdeltype", CHAR),
+ Column("confmatchtype", CHAR),
+ Column("conislocal", Boolean),
+ Column("coninhcount", Integer),
+ Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
+ Column("conkey", ARRAY(SmallInteger)),
+ Column("confkey", ARRAY(SmallInteger)),
+ schema="pg_catalog",
+)
+
+pg_sequence = Table(
+ "pg_sequence",
+ pg_catalog_meta,
+ Column("seqrelid", OID),
+ Column("seqtypid", OID),
+ Column("seqstart", BigInteger),
+ Column("seqincrement", BigInteger),
+ Column("seqmax", BigInteger),
+ Column("seqmin", BigInteger),
+ Column("seqcache", BigInteger),
+ Column("seqcycle", Boolean),
+ schema="pg_catalog",
+ info={"server_version": (10,)},
+)
+
+pg_attrdef = Table(
+ "pg_attrdef",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("adrelid", OID),
+ Column("adnum", SmallInteger),
+ Column("adbin", PG_NODE_TREE),
+ schema="pg_catalog",
+)
+
+pg_description = Table(
+ "pg_description",
+ pg_catalog_meta,
+ Column("objoid", OID),
+ Column("classoid", OID),
+ Column("objsubid", Integer),
+ Column("description", Text(collation="C")),
+ schema="pg_catalog",
+)
+
+pg_enum = Table(
+ "pg_enum",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("enumtypid", OID),
+ Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
+ Column("enumlabel", NAME),
+ schema="pg_catalog",
+)
+
+pg_am = Table(
+ "pg_am",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("amname", NAME),
+ Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
+ Column("amtype", CHAR, info={"server_version": (9, 6)}),
+ schema="pg_catalog",
+)
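These pg_catalog Table definitions let the reflection code build its queries with Core constructs instead of raw SQL strings. A minimal sketch of composing such a query against this internal module (it only prints the generated SQL; executing it would need a live PostgreSQL connection):

from sqlalchemy import bindparam, select
from sqlalchemy.dialects.postgresql import pg_catalog

# list ordinary/partitioned/foreign tables in a given schema
query = (
    select(pg_catalog.pg_class.c.relname)
    .join(
        pg_catalog.pg_namespace,
        pg_catalog.pg_class.c.relnamespace == pg_catalog.pg_namespace.c.oid,
    )
    .where(
        pg_catalog.pg_class.c.relkind.in_(pg_catalog.RELKINDS_TABLE),
        pg_catalog.pg_namespace.c.nspname == bindparam("schema"),
    )
)

print(query)  # connection.execute(query, {"schema": "public"}) would run it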
diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py
new file mode 100644
index 000000000..55735953b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/types.py
@@ -0,0 +1,485 @@
+# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+import datetime as dt
+from typing import Any
+
+from ... import schema
+from ... import util
+from ...sql import sqltypes
+from ...sql.ddl import InvokeDDLBase
+
+
+_DECIMAL_TYPES = (1231, 1700)
+_FLOAT_TYPES = (700, 701, 1021, 1022)
+_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
+
+
+class PGUuid(sqltypes.UUID):
+ render_bind_cast = True
+ render_literal_cast = True
+
+
+class BYTEA(sqltypes.LargeBinary[bytes]):
+ __visit_name__ = "BYTEA"
+
+
+class INET(sqltypes.TypeEngine[str]):
+ __visit_name__ = "INET"
+
+
+PGInet = INET
+
+
+class CIDR(sqltypes.TypeEngine[str]):
+ __visit_name__ = "CIDR"
+
+
+PGCidr = CIDR
+
+
+class MACADDR(sqltypes.TypeEngine[str]):
+ __visit_name__ = "MACADDR"
+
+
+PGMacAddr = MACADDR
+
+
+class MONEY(sqltypes.TypeEngine[str]):
+
+ r"""Provide the PostgreSQL MONEY type.
+
+ Depending on driver, result rows using this type may return a
+ string value which includes currency symbols.
+
+ For this reason, it may be preferable to provide conversion to a
+ numerically-based currency datatype using :class:`_types.TypeDecorator`::
+
+ import re
+ import decimal
+ from typing import Any
+ from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+ def process_result_value(self, value: Any, dialect: Any) -> Any:
+ if value is not None:
+ # adjust this for the currency and numeric
+ m = re.match(r"\$([\d.]+)", value)
+ if m:
+ value = decimal.Decimal(m.group(1))
+ return value
+
+ Alternatively, the conversion may be applied as a CAST using
+ the :meth:`_types.TypeDecorator.column_expression` method as follows::
+
+ import decimal
+ from typing import Any
+ from sqlalchemy import cast
+ from sqlalchemy import Numeric
+ from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+ def column_expression(self, column: Any):
+ return cast(column, Numeric())
+
+ .. versionadded:: 1.2
+
+ """
+
+ __visit_name__ = "MONEY"
+
+
+class OID(sqltypes.TypeEngine[int]):
+
+ """Provide the PostgreSQL OID type.
+
+ .. versionadded:: 0.9.5
+
+ """
+
+ __visit_name__ = "OID"
+
+
+class REGCLASS(sqltypes.TypeEngine[str]):
+
+ """Provide the PostgreSQL REGCLASS type.
+
+ .. versionadded:: 1.2.7
+
+ """
+
+ __visit_name__ = "REGCLASS"
+
+
+class TIMESTAMP(sqltypes.TIMESTAMP):
+ def __init__(self, timezone=False, precision=None):
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class TIME(sqltypes.TIME):
+ def __init__(self, timezone=False, precision=None):
+ super(TIME, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+
+ """PostgreSQL INTERVAL type."""
+
+ __visit_name__ = "INTERVAL"
+ native = True
+
+ def __init__(self, precision=None, fields=None):
+ """Construct an INTERVAL.
+
+ :param precision: optional integer precision value
+ :param fields: string fields specifier. allows storage of fields
+ to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
+ etc.
+
+ .. versionadded:: 1.2
+
+ """
+ self.precision = precision
+ self.fields = fields
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+ return INTERVAL(precision=interval.second_precision)
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(native=True, second_precision=self.precision)
+
+ @property
+ def python_type(self):
+ return dt.timedelta
+
+
+PGInterval = INTERVAL
+
+
+class BIT(sqltypes.TypeEngine[int]):
+ __visit_name__ = "BIT"
+
+ def __init__(self, length=None, varying=False):
+ if not varying:
+ # BIT without VARYING defaults to length 1
+ self.length = length or 1
+ else:
+ # but BIT VARYING can be unlimited-length, so no default
+ self.length = length
+ self.varying = varying
+
+
+PGBit = BIT
+
+
+class TSVECTOR(sqltypes.TypeEngine[Any]):
+
+ """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
+ text search type TSVECTOR.
+
+ It can be used to do full text queries on natural language
+ documents.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`postgresql_match`
+
+ """
+
+ __visit_name__ = "TSVECTOR"
+
+
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
+
+ """PostgreSQL ENUM type.
+
+ This is a subclass of :class:`_types.Enum` which includes
+ support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
+
+ When the builtin type :class:`_types.Enum` is used and the
+ :paramref:`.Enum.native_enum` flag is left at its default of
+ True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
+ type as the implementation, so the special create/drop rules
+ will be used.
+
+ The create/drop behavior of ENUM is necessarily intricate, due to the
+ awkward relationship the ENUM type has with the parent table, in
+ that it may be "owned" by just a single table, or may be shared
+ among many tables.
+
+ When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
+ in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
+ corresponding to when the :meth:`_schema.Table.create` and
+ :meth:`_schema.Table.drop`
+ methods are called::
+
+ table = Table('sometable', metadata,
+ Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
+ )
+
+ table.create(engine) # will emit CREATE ENUM and CREATE TABLE
+ table.drop(engine) # will emit DROP TABLE and DROP ENUM
+
+ To use a common enumerated type between multiple tables, the best
+ practice is to declare the :class:`_types.Enum` or
+ :class:`_postgresql.ENUM` independently, and associate it with the
+ :class:`_schema.MetaData` object itself::
+
+ my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
+
+ t1 = Table('sometable_one', metadata,
+ Column('some_enum', myenum)
+ )
+
+ t2 = Table('sometable_two', metadata,
+ Column('some_enum', myenum)
+ )
+
+ When this pattern is used, care must still be taken at the level
+ of individual table creates. Emitting CREATE TABLE without also
+ specifying ``checkfirst=True`` will still cause issues::
+
+ t1.create(engine) # will fail: no such type 'myenum'
+
+ If we specify ``checkfirst=True``, the individual table-level create
+ operation will check for the ``ENUM`` and create if not exists::
+
+ # will check if enum exists, and emit CREATE TYPE if not
+ t1.create(engine, checkfirst=True)
+
+ When using a metadata-level ENUM type, the type will always be created
+ and dropped when the metadata-wide create/drop is called::
+
+ metadata.create_all(engine) # will emit CREATE TYPE
+ metadata.drop_all(engine) # will emit DROP TYPE
+
+ The type can also be created and dropped directly::
+
+ my_enum.create(engine)
+ my_enum.drop(engine)
+
+ .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
+ now behaves more strictly with regards to CREATE/DROP. A metadata-level
+ ENUM type will only be created and dropped at the metadata level,
+ not the table level, with the exception of
+ ``table.create(checkfirst=True)``.
+ The ``table.drop()`` call will now emit a DROP TYPE for a table-level
+ enumerated type.
+
+ """
+
+ native_enum = True
+
+ def __init__(self, *enums, **kw):
+ """Construct an :class:`_postgresql.ENUM`.
+
+ Arguments are the same as that of
+ :class:`_types.Enum`, but also including
+ the following parameters.
+
+ :param create_type: Defaults to True.
+ Indicates that ``CREATE TYPE`` should be
+ emitted, after optionally checking for the
+ presence of the type, when the parent
+ table is being created; and additionally
+ that ``DROP TYPE`` is called when the table
+ is dropped. When ``False``, no check
+ will be performed and no ``CREATE TYPE``
+ or ``DROP TYPE`` is emitted, unless
+ :meth:`~.postgresql.ENUM.create`
+ or :meth:`~.postgresql.ENUM.drop`
+ are called directly.
+ Setting to ``False`` is helpful
+ when invoking a creation scheme to a SQL file
+ without access to the actual database -
+ the :meth:`~.postgresql.ENUM.create` and
+ :meth:`~.postgresql.ENUM.drop` methods can
+ be used to emit SQL to a target bind.
+
+ """
+ native_enum = kw.pop("native_enum", None)
+ if native_enum is False:
+ util.warn(
+ "the native_enum flag does not apply to the "
+ "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
+ "always refers to ENUM. Use sqlalchemy.types.Enum for "
+ "non-native enum."
+ )
+ self.create_type = kw.pop("create_type", True)
+ super(ENUM, self).__init__(*enums, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ kw.setdefault("name", impl.name)
+ kw.setdefault("schema", impl.schema)
+ kw.setdefault("inherit_schema", impl.inherit_schema)
+ kw.setdefault("metadata", impl.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("values_callable", impl.values_callable)
+ kw.setdefault("omit_aliases", impl._omit_aliases)
+ return cls(**kw)
+
+ def create(self, bind=None, checkfirst=True):
+ """Emit ``CREATE TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL CREATE TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will first be performed to see
+ whether the type already exists before
+ creating it.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=True):
+ """Emit ``DROP TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL DROP TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will first be performed to see
+ whether the type actually exists before dropping it.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+ class EnumGenerator(InvokeDDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_create_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return not self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_create_enum(enum):
+ return
+
+ self.connection.execute(CreateEnumType(enum))
+
+ class EnumDropper(InvokeDDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_drop_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_drop_enum(enum):
+ return
+
+ self.connection.execute(DropEnumType(enum))
+
+ def get_dbapi_type(self, dbapi):
+ """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
+ a different type"""
+
+ return None
+
+ def _check_for_name_in_memos(self, checkfirst, kw):
+ """Look in the 'ddl runner' for 'memos', then
+ note our name in that collection.
+
+ This is to ensure a particular named enum is operated
+ upon only once within any kind of create/drop
+ sequence without relying upon "checkfirst".
+
+ """
+ if not self.create_type:
+ return True
+ if "_ddl_runner" in kw:
+ ddl_runner = kw["_ddl_runner"]
+ if "_pg_enums" in ddl_runner.memo:
+ pg_enums = ddl_runner.memo["_pg_enums"]
+ else:
+ pg_enums = ddl_runner.memo["_pg_enums"] = set()
+ present = (self.schema, self.name) in pg_enums
+ pg_enums.add((self.schema, self.name))
+ return present
+ else:
+ return False
+
+ def _on_table_create(self, target, bind, checkfirst=False, **kw):
+ if (
+ checkfirst
+ or (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ )
+ ) and not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+ if (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ and not self._check_for_name_in_memos(checkfirst, kw)
+ ):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+
+class CreateEnumType(schema._CreateDropBase):
+ __visit_name__ = "create_enum_type"
+
+
+class DropEnumType(schema._CreateDropBase):
+ __visit_name__ = "drop_enum_type"
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index fdcd1340b..22f003e38 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -867,6 +867,7 @@ from ... import util
from ...engine import default
from ...engine import processors
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import ColumnElement
from ...sql import compiler
@@ -2053,28 +2054,27 @@ class SQLiteDialect(default.DefaultDialect):
return [db[1] for db in dl if db[1] != "temp"]
- @reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
+ def _format_schema(self, schema, table_name):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ name = f"{qschema}.{table_name}"
else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
- return [row[0] for row in rs]
+ name = table_name
+ return name
@reflection.cache
- def get_temp_table_names(self, connection, **kw):
- s = (
- "SELECT name FROM sqlite_temp_master "
- "WHERE type='table' ORDER BY name "
- )
- rs = connection.exec_driver_sql(s)
+ def get_table_names(self, connection, schema=None, **kw):
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- return [row[0] for row in rs]
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ main = "sqlite_temp_master"
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_temp_view_names(self, connection, **kw):
@@ -2082,11 +2082,11 @@ class SQLiteDialect(default.DefaultDialect):
"SELECT name FROM sqlite_temp_master "
"WHERE type='view' ORDER BY name "
)
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- def has_table(self, connection, table_name, schema=None):
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
info = self._get_table_pragma(
@@ -2099,23 +2099,16 @@ class SQLiteDialect(default.DefaultDialect):
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
- if schema is not None:
- qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
- else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ master = f"{qschema}.sqlite_master"
s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % (
master,
)
@@ -2140,6 +2133,10 @@ class SQLiteDialect(default.DefaultDialect):
result = rs.fetchall()
if result:
return result[0].sql
+ else:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
@@ -2186,7 +2183,14 @@ class SQLiteDialect(default.DefaultDialect):
tablesql,
)
)
- return columns
+ if columns:
+ return columns
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.columns()
def _get_column_info(
self,
@@ -2216,7 +2220,6 @@ class SQLiteDialect(default.DefaultDialect):
"type": coltype,
"nullable": nullable,
"default": default,
- "autoincrement": "auto",
"primary_key": primary_key,
}
if generated:
@@ -2295,13 +2298,16 @@ class SQLiteDialect(default.DefaultDialect):
constraint_name = result.group(1) if result else None
cols = self.get_columns(connection, table_name, schema, **kw)
+ # consider only pk columns. This also avoids sorting the cached
+ # value returned by get_columns
+ cols = [col for col in cols if col.get("primary_key", 0) > 0]
cols.sort(key=lambda col: col.get("primary_key"))
- pkeys = []
- for col in cols:
- if col["primary_key"]:
- pkeys.append(col["name"])
+ pkeys = [col["name"] for col in cols]
- return {"constrained_columns": pkeys, "name": constraint_name}
+ if pkeys:
+ return {"constrained_columns": pkeys, "name": constraint_name}
+ else:
+ return ReflectionDefaults.pk_constraint()
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
@@ -2321,12 +2327,14 @@ class SQLiteDialect(default.DefaultDialect):
# original DDL. The referred columns of the foreign key
# constraint are therefore the primary key of the referred
# table.
- referred_pk = self.get_pk_constraint(
- connection, rtbl, schema=schema, **kw
- )
- # note that if table doesn't exist, we still get back a record,
- # just it has no columns in it
- referred_columns = referred_pk["constrained_columns"]
+ try:
+ referred_pk = self.get_pk_constraint(
+ connection, rtbl, schema=schema, **kw
+ )
+ referred_columns = referred_pk["constrained_columns"]
+ except exc.NoSuchTableError:
+ # ignore not existing parents
+ referred_columns = []
else:
# note we use this list only if this is the first column
# in the constraint. for subsequent columns we ignore the
@@ -2378,11 +2386,11 @@ class SQLiteDialect(default.DefaultDialect):
)
table_data = self._get_table_sql(connection, table_name, schema=schema)
- if table_data is None:
- # system tables, etc.
- return []
def parse_fks():
+ if table_data is None:
+ # system tables, etc.
+ return
FK_PATTERN = (
r"(?:CONSTRAINT (\w+) +)?"
r"FOREIGN KEY *\( *(.+?) *\) +"
@@ -2453,7 +2461,10 @@ class SQLiteDialect(default.DefaultDialect):
# use them as is as it's extremely difficult to parse inline
# constraints
fkeys.extend(keys_by_signature.values())
- return fkeys
+ if fkeys:
+ return fkeys
+ else:
+ return ReflectionDefaults.foreign_keys()
def _find_cols_in_sig(self, sig):
for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I):
@@ -2480,12 +2491,11 @@ class SQLiteDialect(default.DefaultDialect):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
-
unique_constraints = []
def parse_uqs():
+ if table_data is None:
+ return
UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
INLINE_UNIQUE_PATTERN = (
r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) '
@@ -2513,15 +2523,16 @@ class SQLiteDialect(default.DefaultDialect):
unique_constraints.append(parsed_constraint)
# NOTE: auto_index_by_sig might not be empty here,
# the PRIMARY KEY may have an entry.
- return unique_constraints
+ if unique_constraints:
+ return unique_constraints
+ else:
+ return ReflectionDefaults.unique_constraints()
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *"
check_constraints = []
@@ -2531,7 +2542,7 @@ class SQLiteDialect(default.DefaultDialect):
# necessarily makes assumptions as to how the CREATE TABLE
# was emitted.
- for match in re.finditer(CHECK_PATTERN, table_data, re.I):
+ for match in re.finditer(CHECK_PATTERN, table_data or "", re.I):
name = match.group(1)
if name:
@@ -2539,7 +2550,10 @@ class SQLiteDialect(default.DefaultDialect):
check_constraints.append({"sqltext": match.group(2), "name": name})
- return check_constraints
+ if check_constraints:
+ return check_constraints
+ else:
+ return ReflectionDefaults.check_constraints()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
@@ -2561,7 +2575,7 @@ class SQLiteDialect(default.DefaultDialect):
# loop thru unique indexes to get the column names.
for idx in list(indexes):
pragma_index = self._get_table_pragma(
- connection, "index_info", idx["name"]
+ connection, "index_info", idx["name"], schema=schema
)
for row in pragma_index:
@@ -2574,7 +2588,23 @@ class SQLiteDialect(default.DefaultDialect):
break
else:
idx["column_names"].append(row[2])
- return indexes
+ indexes.sort(key=lambda d: d["name"] or "~") # sort None as last
+ if indexes:
+ return indexes
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.indexes()
+
+ def _is_sys_table(self, table_name):
+ return table_name in {
+ "sqlite_schema",
+ "sqlite_master",
+ "sqlite_temp_schema",
+ "sqlite_temp_master",
+ }
@reflection.cache
def _get_table_sql(self, connection, table_name, schema=None, **kw):
@@ -2590,22 +2620,25 @@ class SQLiteDialect(default.DefaultDialect):
" (SELECT * FROM %(schema)ssqlite_master UNION ALL "
" SELECT * FROM %(schema)ssqlite_temp_master) "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
except exc.DBAPIError:
s = (
"SELECT sql FROM %(schema)ssqlite_master "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
- return rs.scalar()
+ value = rs.scalar()
+ if value is None and not self._is_sys_table(table_name):
+ raise exc.NoSuchTableError(f"{schema_expr}{table_name}")
+ return value
def _get_table_pragma(self, connection, pragma, table_name, schema=None):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
- statements = ["PRAGMA %s." % quote(schema)]
+ statements = [f"PRAGMA {quote(schema)}."]
else:
# because PRAGMA looks in all attached databases if no schema
# given, need to specify "main" schema, however since we want
@@ -2615,7 +2648,7 @@ class SQLiteDialect(default.DefaultDialect):
qtable = quote(table_name)
for statement in statements:
- statement = "%s%s(%s)" % (statement, pragma, qtable)
+ statement = f"{statement}{pragma}({qtable})"
cursor = connection.exec_driver_sql(statement)
if not cursor._soft_closed:
# work around SQLite issue whereby cursor.description