Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py  898
1 file changed, 898 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
new file mode 100644
index 000000000..874907abc
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -0,0 +1,898 @@
+# postgresql.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Support for the PostgreSQL database.
+
+For information on connecting using specific drivers, see the documentation section
+regarding that driver.
+
+Sequences/SERIAL
+----------------
+
+PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating
+new primary key values for integer-based primary key columns. When creating tables,
+SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns,
+which generates a sequence whose name is derived from the table and column names and
+which is associated with the column.
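+
+For example, an integer primary key column with no explicit default renders with
+``SERIAL`` (the names and DDL below are an illustrative sketch)::
+
+    mytable = Table('mytable', metadata,
+        Column('id', Integer, primary_key=True)
+    )
+
+    # metadata.create_all(engine) then emits, approximately:
+    # CREATE TABLE mytable (
+    #     id SERIAL NOT NULL,
+    #     PRIMARY KEY (id)
+    # )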
+
+To use a particular named sequence for primary key generation, use the
+:func:`~sqlalchemy.schema.Sequence` construct::
+
+ Table('sometable', metadata,
+ Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
+ )
+
+Currently, when SQLAlchemy issues a single INSERT statement, the sequence is executed
+independently beforehand and the new value is retrieved, to be used in the subsequent
+INSERT, fulfilling the contract of having the "last insert identifier" available. Note
+that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using
+"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior
+is used.
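+
+A sketch of the two execution styles (``connection`` and the table and column
+names here are illustrative)::
+
+    # single execution: the sequence is pre-executed, i.e.
+    # "select nextval('mytable_id_seq')" is issued first, and the returned
+    # value supplies the 'id' parameter of the INSERT
+    connection.execute(mytable.insert(), name='foo')
+
+    # "executemany": no pre-execution; the SERIAL column's default fires
+    # within the database for each row
+    connection.execute(mytable.insert(), [{'name': 'foo'}, {'name': 'bar'}])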
+
+PostgreSQL 8.3 supports an ``INSERT...RETURNING`` syntax, which SQLAlchemy supports
+as well. A future release of SQLAlchemy will use this feature by default in lieu of
+sequence pre-execution in order to retrieve new primary key values, when available.
+
+INSERT/UPDATE...RETURNING
+-------------------------
+
+The dialect supports PG 8.3's ``INSERT...RETURNING`` and ``UPDATE...RETURNING`` syntaxes,
+but must be explicitly enabled on a per-statement basis::
+
+ # INSERT..RETURNING
+ result = table.insert(postgresql_returning=[table.c.col1, table.c.col2]).\\
+ values(name='foo')
+ print result.fetchall()
+
+ # UPDATE..RETURNING
+ result = table.update(postgresql_returning=[table.c.col1, table.c.col2]).\\
+ where(table.c.name=='foo').values(name='bar')
+ print result.fetchall()
+
+Indexes
+-------
+
+PostgreSQL supports partial indexes. To create them, pass a ``postgresql_where``
+option to the ``Index`` constructor::
+
+    Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10)
+
+"""
+
+import re
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import sql, schema, exc, util
+from sqlalchemy.engine import base, default, reflection
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
+
+from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \
+ CHAR, TEXT, FLOAT, NUMERIC, \
+ TIMESTAMP, TIME, DATE, BOOLEAN
+
+class REAL(sqltypes.Float):
+ __visit_name__ = "REAL"
+
+class BYTEA(sqltypes.Binary):
+ __visit_name__ = 'BYTEA'
+
+class DOUBLE_PRECISION(sqltypes.Float):
+ __visit_name__ = 'DOUBLE_PRECISION'
+
+class INET(sqltypes.TypeEngine):
+ __visit_name__ = "INET"
+PGInet = INET
+
+class CIDR(sqltypes.TypeEngine):
+ __visit_name__ = "CIDR"
+PGCidr = CIDR
+
+class MACADDR(sqltypes.TypeEngine):
+ __visit_name__ = "MACADDR"
+PGMacAddr = MACADDR
+
+class INTERVAL(sqltypes.TypeEngine):
+ __visit_name__ = 'INTERVAL'
+PGInterval = INTERVAL
+
+class BIT(sqltypes.TypeEngine):
+ __visit_name__ = 'BIT'
+PGBit = BIT
+
+class UUID(sqltypes.TypeEngine):
+ __visit_name__ = 'UUID'
+PGUuid = UUID
+
+class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
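+    """Postgresql ARRAY type.
+
+    Represents values as Python lists.  ``item_type`` may be a type class or
+    an instance; ``mutable=True`` (the default) marks the type as mutable, so
+    that values are copied in order to detect in-place changes.  A usage
+    sketch (column name illustrative)::
+
+        Column('intarr', ARRAY(Integer))
+
+    """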
+ __visit_name__ = 'ARRAY'
+
+ def __init__(self, item_type, mutable=True):
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+ self.mutable = mutable
+
+ def copy_value(self, value):
+ if value is None:
+ return None
+ elif self.mutable:
+ return list(value)
+ else:
+ return value
+
+ def compare_values(self, x, y):
+ return x == y
+
+ def is_mutable(self):
+ return self.mutable
+
+ def dialect_impl(self, dialect, **kwargs):
+ impl = self.__class__.__new__(self.__class__)
+ impl.__dict__.update(self.__dict__)
+ impl.item_type = self.item_type.dialect_impl(dialect)
+ return impl
+
+ def bind_processor(self, dialect):
+ item_proc = self.item_type.bind_processor(dialect)
+ def process(value):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, (list, tuple)):
+ return [convert_item(child) for child in item]
+ else:
+ if item_proc:
+ return item_proc(item)
+ else:
+ return item
+ return [convert_item(item) for item in value]
+ return process
+
+ def result_processor(self, dialect):
+ item_proc = self.item_type.result_processor(dialect)
+ def process(value):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, list):
+ return [convert_item(child) for child in item]
+ else:
+ if item_proc:
+ return item_proc(item)
+ else:
+ return item
+ return [convert_item(item) for item in value]
+ return process
+PGArray = ARRAY
+
+colspecs = {
+ sqltypes.Interval:INTERVAL
+}
+
+ischema_names = {
+ 'integer' : INTEGER,
+ 'bigint' : BIGINT,
+ 'smallint' : SMALLINT,
+ 'character varying' : VARCHAR,
+ 'character' : CHAR,
+ '"char"' : sqltypes.String,
+ 'name' : sqltypes.String,
+ 'text' : TEXT,
+ 'numeric' : NUMERIC,
+ 'float' : FLOAT,
+ 'real' : REAL,
+ 'inet': INET,
+ 'cidr': CIDR,
+ 'uuid': UUID,
+ 'bit':BIT,
+ 'macaddr': MACADDR,
+ 'double precision' : DOUBLE_PRECISION,
+ 'timestamp' : TIMESTAMP,
+ 'timestamp with time zone' : TIMESTAMP,
+ 'timestamp without time zone' : TIMESTAMP,
+ 'time with time zone' : TIME,
+ 'time without time zone' : TIME,
+ 'date' : DATE,
+ 'time': TIME,
+ 'bytea' : BYTEA,
+ 'boolean' : BOOLEAN,
+ 'interval':INTERVAL,
+}
+
+
+
+class PGCompiler(compiler.SQLCompiler):
+
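+    # render the MATCH operator using PG full-text search syntax,
+    # i.e. <document> @@ to_tsquery(<query>)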
+ def visit_match_op(self, binary, **kw):
+ return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_ilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_notilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def post_process_text(self, text):
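+        # the dialect's paramstyle is 'pyformat', so literal '%' signs in
+        # text() constructs must be doubled before reaching the DBAPI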
+ if '%%' in text:
+ util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() expressions to '%%'.")
+ return text.replace('%', '%%')
+
+ def visit_sequence(self, seq):
+ if seq.optional:
+ return None
+ else:
+ return "nextval('%s')" % self.preparer.format_sequence(seq)
+
+ def limit_clause(self, select):
+ text = ""
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
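+                # only an OFFSET was given; emit an explicit "LIMIT ALL"
+                # ahead of the OFFSET clause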
+ text += " \n LIMIT ALL"
+ text += " OFFSET " + str(select._offset)
+ return text
+
+ def get_select_precolumns(self, select):
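+        # support PG's DISTINCT ON: select._distinct may be given as a
+        # boolean, a list/tuple of expressions, or a single expression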
+ if select._distinct:
+ if isinstance(select._distinct, bool):
+ return "DISTINCT "
+ elif isinstance(select._distinct, (list, tuple)):
+ return "DISTINCT ON (" + ', '.join(
+ [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
+ )+ ") "
+ else:
+ return "DISTINCT ON (" + unicode(select._distinct) + ") "
+ else:
+ return ""
+
+ def for_update_clause(self, select):
+ if select.for_update == 'nowait':
+ return " FOR UPDATE NOWAIT"
+ else:
+ return super(PGCompiler, self).for_update_clause(select)
+
+ def returning_clause(self, stmt, returning_cols):
+
+ columns = [
+ self.process(
+ self.label_select_column(None, c, asfrom=False),
+ within_columns_clause=True,
+ result_map=self.result_map)
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return 'RETURNING ' + ', '.join(columns)
+
+ def visit_extract(self, extract, **kwargs):
+ field = self.extract_map.get(extract.field, extract.field)
+ return "EXTRACT(%s FROM %s::timestamp)" % (
+ field, self.process(extract.expr))
+
+class PGDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column)
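+        # render SERIAL/BIGSERIAL in place of an integer type for primary key
+        # columns that have no foreign keys, are flagged autoincrement, and
+        # have no default other than an optional Sequence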
+ if column.primary_key and \
+ len(column.foreign_keys)==0 and \
+ column.autoincrement and \
+ isinstance(column.type, sqltypes.Integer) and \
+ not isinstance(column.type, sqltypes.SmallInteger) and \
+ (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ if isinstance(column.type, sqltypes.BigInteger):
+ colspec += " BIGSERIAL"
+ else:
+ colspec += " SERIAL"
+ else:
+ colspec += " " + self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+ return colspec
+
+ def visit_create_sequence(self, create):
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+
+ def visit_drop_sequence(self, drop):
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+ def visit_create_index(self, create):
+ preparer = self.preparer
+ index = create.element
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX %s ON %s (%s)" \
+ % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+ preparer.format_table(index.table),
+ ', '.join([preparer.format_column(c) for c in index.columns]))
+
+ if "postgres_where" in index.kwargs:
+ whereclause = index.kwargs['postgres_where']
+ util.warn_deprecated("The 'postgres_where' argument has been renamed to 'postgresql_where'.")
+ elif 'postgresql_where' in index.kwargs:
+ whereclause = index.kwargs['postgresql_where']
+ else:
+ whereclause = None
+
+ if whereclause is not None:
+ compiler = self._compile(whereclause, None)
+ # this might belong to the compiler class
+ inlined_clause = str(compiler) % dict(
+ [(key,bind.value) for key,bind in compiler.binds.iteritems()])
+ text += " WHERE " + inlined_clause
+ return text
+
+
+class PGDefaultRunner(base.DefaultRunner):
+ def __init__(self, context):
+ base.DefaultRunner.__init__(self, context)
+        # create a cursor which won't conflict with a server-side cursor
+ self.cursor = context._connection.connection.cursor()
+
+ def get_column_default(self, column, isinsert=True):
+ if column.primary_key:
+ # pre-execute passive defaults on primary keys
+ if (isinstance(column.server_default, schema.DefaultClause) and
+ column.server_default.arg is not None):
+ return self.execute_string("select %s" % column.server_default.arg)
+ elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) \
+ and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ sch = column.table.schema
+ # TODO: this has to build into the Sequence object so we can get the quoting
+ # logic from it
+ if sch is not None:
+ exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
+ else:
+ exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
+
+ if self.dialect.supports_unicode_statements:
+ return self.execute_string(exc)
+ else:
+ return self.execute_string(exc.encode(self.dialect.encoding))
+
+ return super(PGDefaultRunner, self).get_column_default(column)
+
+ def visit_sequence(self, seq):
+ if not seq.optional:
+ return self.execute_string(("select nextval('%s')" % \
+ self.dialect.identifier_preparer.format_sequence(seq)))
+ else:
+ return None
+
+class PGTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_INET(self, type_):
+ return "INET"
+
+ def visit_CIDR(self, type_):
+ return "CIDR"
+
+ def visit_MACADDR(self, type_):
+ return "MACADDR"
+
+ def visit_FLOAT(self, type_):
+ if not type_.precision:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {'precision': type_.precision}
+
+ def visit_DOUBLE_PRECISION(self, type_):
+ return "DOUBLE PRECISION"
+
+ def visit_BIGINT(self, type_):
+ return "BIGINT"
+
+ def visit_datetime(self, type_):
+ return self.visit_TIMESTAMP(type_)
+
+ def visit_TIMESTAMP(self, type_):
+ return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+ def visit_TIME(self, type_):
+ return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+ def visit_INTERVAL(self, type_):
+ return "INTERVAL"
+
+ def visit_BIT(self, type_):
+ return "BIT"
+
+ def visit_UUID(self, type_):
+ return "UUID"
+
+ def visit_binary(self, type_):
+ return self.visit_BYTEA(type_)
+
+ def visit_BYTEA(self, type_):
+ return "BYTEA"
+
+ def visit_REAL(self, type_):
+ return "REAL"
+
+ def visit_ARRAY(self, type_):
+ return self.process(type_.item_type) + '[]'
+
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
+ def _unquote_identifier(self, value):
+ if value[0] == self.initial_quote:
+ value = value[1:-1].replace('""','"')
+ return value
+
+class PGInspector(reflection.Inspector):
+
+ def __init__(self, conn):
+ reflection.Inspector.__init__(self, conn)
+
+ def get_table_oid(self, table_name, schema=None):
+ """Return the oid from `table_name` and `schema`."""
+
+ return self.dialect.get_table_oid(self.conn, table_name, schema,
+ info_cache=self.info_cache)
+
+
+class PGDialect(default.DefaultDialect):
+ name = 'postgresql'
+ supports_alter = True
+ max_identifier_length = 63
+ supports_sane_rowcount = True
+
+ supports_sequences = True
+ sequences_optional = True
+ preexecute_autoincrement_sequences = True
+ postfetch_lastrowid = False
+
+ supports_default_values = True
+ supports_empty_insert = False
+ default_paramstyle = 'pyformat'
+ ischema_names = ischema_names
+ colspecs = colspecs
+
+ statement_compiler = PGCompiler
+ ddl_compiler = PGDDLCompiler
+ type_compiler = PGTypeCompiler
+ preparer = PGIdentifierPreparer
+ defaultrunner = PGDefaultRunner
+ inspector = PGInspector
+ isolation_level = None
+
+ def __init__(self, isolation_level=None, **kwargs):
+ default.DefaultDialect.__init__(self, **kwargs)
+ self.isolation_level = isolation_level
+
+ def initialize(self, connection):
+ super(PGDialect, self).initialize(connection)
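+        # native RETURNING requires PG 8.3 or greater; also honor an
+        # implicit_returning=False set explicitly on the dialect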
+ self.implicit_returning = self.server_version_info > (8, 3) and \
+ self.__dict__.get('implicit_returning', True)
+
+ def visit_pool(self, pool):
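+        # when an isolation level is configured, install a pool listener
+        # which issues SET SESSION CHARACTERISTICS on each new connection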
+ if self.isolation_level is not None:
+ class SetIsolationLevel(object):
+ def __init__(self, isolation_level):
+ self.isolation_level = isolation_level
+
+ def connect(self, conn, rec):
+ cursor = conn.cursor()
+ cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s"
+ % self.isolation_level)
+ cursor.execute("COMMIT")
+ cursor.close()
+ pool.add_listener(SetIsolationLevel(self.isolation_level))
+
+ def do_begin_twophase(self, connection, xid):
+ self.do_begin(connection.connection)
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute("PREPARE TRANSACTION '%s'" % xid)
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+                # FIXME: ugly hack to get out of transaction context when committing recoverable transactions
+                # Must find a way to keep the dbapi from opening a transaction.
+ connection.execute("ROLLBACK")
+ connection.execute("ROLLBACK PREPARED '%s'" % xid)
+ connection.execute("BEGIN")
+ self.do_rollback(connection.connection)
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
+ connection.execute("ROLLBACK")
+ connection.execute("COMMIT PREPARED '%s'" % xid)
+ connection.execute("BEGIN")
+ self.do_rollback(connection.connection)
+ else:
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
+ return [row[0] for row in resultset]
+
+ def get_default_schema_name(self, connection):
+ return connection.scalar("select current_schema()")
+
+ def has_table(self, connection, table_name, schema=None):
+ # seems like case gets folded in pg_class...
+ if schema is None:
+ cursor = connection.execute(
+ sql.text("select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name",
+ bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)]
+ )
+ )
+ else:
+ cursor = connection.execute(
+ sql.text("select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name",
+ bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode),
+ sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)]
+ )
+ )
+ return bool(cursor.fetchone())
+
+ def has_sequence(self, connection, sequence_name):
+ cursor = connection.execute(
+ sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND "
+ "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' "
+ "AND nspname != 'information_schema' AND relname = :seqname)",
+ bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)]
+ ))
+ return bool(cursor.fetchone())
+
+ def table_names(self, connection, schema):
+ result = connection.execute(
+ sql.text(u"""SELECT relname
+ FROM pg_class c
+ WHERE relkind = 'r'
+ AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema,
+ typemap = {'relname':sqltypes.Unicode}
+ )
+ )
+ return [row[0] for row in result]
+
+ def _get_server_version_info(self, connection):
+ v = connection.execute("select version()").scalar()
+ m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v)
+ if not m:
+ raise AssertionError("Could not determine version from string '%s'" % v)
+ return tuple([int(x) for x in m.group(1, 2, 3)])
+
+ @reflection.cache
+ def get_table_oid(self, connection, table_name, schema=None, **kw):
+ """Fetch the oid for schema.table_name.
+
+ Several reflection methods require the table oid. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+ table_oid = None
+ if schema is not None:
+ schema_where_clause = "n.nspname = :schema"
+ else:
+ schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
+ query = """
+ SELECT c.oid
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE (%s)
+ AND c.relname = :table_name AND c.relkind in ('r','v')
+ """ % schema_where_clause
+ # Since we're binding to unicode, table_name and schema_name must be
+ # unicode.
+ table_name = unicode(table_name)
+ if schema is not None:
+ schema = unicode(schema)
+ s = sql.text(query, bindparams=[
+ sql.bindparam('table_name', type_=sqltypes.Unicode),
+ sql.bindparam('schema', type_=sqltypes.Unicode)
+ ],
+ typemap={'oid':sqltypes.Integer}
+ )
+ c = connection.execute(s, table_name=table_name, schema=schema)
+ table_oid = c.scalar()
+ if table_oid is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_oid
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = """
+ SELECT nspname
+ FROM pg_namespace
+ ORDER BY nspname
+ """
+ rp = connection.execute(s)
+ # what about system tables?
+ schema_names = [row[0].decode(self.encoding) for row in rp \
+ if not row[0].startswith('pg_')]
+ return schema_names
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+ table_names = self.table_names(connection, current_schema)
+ return table_names
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+ s = """
+ SELECT relname
+ FROM pg_class c
+ WHERE relkind = 'v'
+ AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
+ """ % dict(schema=current_schema)
+ view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
+ return view_names
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+ s = """
+ SELECT definition FROM pg_views
+ WHERE schemaname = :schema
+ AND viewname = :view_name
+ """
+ rp = connection.execute(sql.text(s),
+ view_name=view_name, schema=current_schema)
+ if rp:
+ view_def = rp.scalar().decode(self.encoding)
+ return view_def
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
+ SQL_COLS = """
+ SELECT a.attname,
+ pg_catalog.format_type(a.atttypid, a.atttypmod),
+ (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
+ WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef)
+ AS DEFAULT,
+ a.attnotnull, a.attnum, a.attrelid as table_oid
+ FROM pg_catalog.pg_attribute a
+ WHERE a.attrelid = :table_oid
+ AND a.attnum > 0 AND NOT a.attisdropped
+ ORDER BY a.attnum
+ """
+ s = sql.text(SQL_COLS,
+ bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)],
+ typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}
+ )
+ c = connection.execute(s, table_oid=table_oid)
+ rows = c.fetchall()
+ domains = self._load_domains(connection)
+ # format columns
+ columns = []
+ for name, format_type, default, notnull, attnum, table_oid in rows:
+ ## strip (30) from character varying(30)
+ attype = re.search('([^\([]+)', format_type).group(1)
+ nullable = not notnull
+ is_array = format_type.endswith('[]')
+            try:
+                charlen = re.search('\(([\d,]+)\)', format_type).group(1)
+            except AttributeError:
+                # re.search() returned None; no parenthesized length present
+                charlen = False
+ numericprec = False
+ numericscale = False
+ if attype == 'numeric':
+ if charlen is False:
+ numericprec, numericscale = (None, None)
+ else:
+ numericprec, numericscale = charlen.split(',')
+ charlen = False
+ if attype == 'double precision':
+ numericprec, numericscale = (53, False)
+ charlen = False
+ if attype == 'integer':
+ numericprec, numericscale = (32, 0)
+ charlen = False
+ args = []
+ for a in (charlen, numericprec, numericscale):
+ if a is None:
+ args.append(None)
+ elif a is not False:
+ args.append(int(a))
+ kwargs = {}
+ if attype == 'timestamp with time zone':
+ kwargs['timezone'] = True
+ elif attype == 'timestamp without time zone':
+ kwargs['timezone'] = False
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
+ else:
+ if attype in domains:
+ domain = domains[attype]
+ if domain['attype'] in self.ischema_names:
+ # A table can't override whether the domain is nullable.
+ nullable = domain['nullable']
+ if domain['default'] and not default:
+ # It can, however, override the default value, but can't set it to null.
+ default = domain['default']
+ coltype = self.ischema_names[domain['attype']]
+ else:
+ coltype = None
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = PGArray(coltype)
+ else:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (attype, name))
+ coltype = sqltypes.NULLTYPE
+ # adjust the default value
+ autoincrement = False
+ if default is not None:
+ match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+ if match is not None:
+ autoincrement = True
+ # the default is related to a Sequence
+ sch = schema
+ if '.' not in match.group(2) and sch is not None:
+ # unconditionally quote the schema name. this could
+ # later be enhanced to obey quoting rules / "quote schema"
+ default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
+
+ column_info = dict(name=name, type=coltype, nullable=nullable,
+ default=default, autoincrement=autoincrement)
+ columns.append(column_info)
+ return columns
+
+ @reflection.cache
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
+ PK_SQL = """
+ SELECT attname FROM pg_attribute
+ WHERE attrelid = (
+ SELECT indexrelid FROM pg_index i
+ WHERE i.indrelid = :table_oid
+ AND i.indisprimary = 't')
+ ORDER BY attnum
+ """
+ t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
+ c = connection.execute(t, table_oid=table_oid)
+ primary_keys = [r[0] for r in c.fetchall()]
+ return primary_keys
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ preparer = self.identifier_preparer
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
+ FK_SQL = """
+ SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
+ FROM pg_catalog.pg_constraint r
+ WHERE r.conrelid = :table AND r.contype = 'f'
+ ORDER BY 1
+ """
+
+ t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
+ c = connection.execute(t, table=table_oid)
+ fkeys = []
+ for conname, condef in c.fetchall():
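+            # parse the textual definition emitted by pg_get_constraintdef(),
+            # e.g. FOREIGN KEY (a, b) REFERENCES otherschema.othertable(x, y)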
+ m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
+ (constrained_columns, referred_schema, referred_table, referred_columns) = m
+ constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
+ if referred_schema:
+ referred_schema = preparer._unquote_identifier(referred_schema)
+            elif schema is not None and schema == self.get_default_schema_name(connection):
+                # the constraint itself gives no schema (i.e. it's in the
+                # default schema), but the table being reflected names the
+                # default schema explicitly; use it, following the user's
+                # conventions
+                referred_schema = schema
+ referred_table = preparer._unquote_identifier(referred_table)
+            referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', referred_columns)]
+ fkey_d = {
+ 'name' : conname,
+ 'constrained_columns' : constrained_columns,
+ 'referred_schema' : referred_schema,
+ 'referred_table' : referred_table,
+ 'referred_columns' : referred_columns
+ }
+ fkeys.append(fkey_d)
+ return fkeys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema, **kw):
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
+ IDX_SQL = """
+ SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
+ a.attname
+ FROM pg_index i, pg_class c, pg_attribute a
+ WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid
+ AND a.attrelid = i.indexrelid AND i.indisprimary = 'f'
+ ORDER BY c.relname, a.attnum
+ """
+ t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode})
+ c = connection.execute(t, table_oid=table_oid)
+ index_names = {}
+ indexes = []
+ sv_idx_name = None
+ for row in c.fetchall():
+ idx_name, unique, expr, prd, col = row
+ if expr:
+ if idx_name != sv_idx_name:
+ util.warn(
+ "Skipped unsupported reflection of expression-based index %s"
+ % idx_name)
+ sv_idx_name = idx_name
+ continue
+ if prd and not idx_name == sv_idx_name:
+ util.warn(
+ "Predicate of partial index %s ignored during reflection"
+ % idx_name)
+ sv_idx_name = idx_name
+ if idx_name in index_names:
+ index_d = index_names[idx_name]
+ else:
+ index_d = {'column_names':[]}
+ indexes.append(index_d)
+ index_names[idx_name] = index_d
+ index_d['name'] = idx_name
+ index_d['column_names'].append(col)
+ index_d['unique'] = unique
+ return indexes
+
+ def _load_domains(self, connection):
+ ## Load data types for domains:
+ SQL_DOMAINS = """
+ SELECT t.typname as "name",
+ pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
+ not t.typnotnull as "nullable",
+ t.typdefault as "default",
+ pg_catalog.pg_type_is_visible(t.oid) as "visible",
+ n.nspname as "schema"
+ FROM pg_catalog.pg_type t
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+ LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid
+ WHERE t.typtype = 'd'
+ """
+
+        s = sql.text(SQL_DOMAINS, typemap={'name':sqltypes.Unicode})
+ c = connection.execute(s)
+
+ domains = {}
+ for domain in c.fetchall():
+ ## strip (30) from character varying(30)
+ attype = re.search('([^\(]+)', domain['attype']).group(1)
+ if domain['visible']:
+ # 'visible' just means whether or not the domain is in a
+                # schema that's on the search path -- or not overridden by
+                # a schema with higher precedence. If it's not visible,
+ # it will be prefixed with the schema-name when it's used.
+ name = domain['name']
+ else:
+ name = "%s.%s" % (domain['schema'], domain['name'])
+
+ domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']}
+
+ return domains
+