summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/postgres.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
commit8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch)
treeae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/databases/postgres.py
parent7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff)
downloadsqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r--lib/sqlalchemy/databases/postgres.py889
1 files changed, 0 insertions, 889 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
deleted file mode 100644
index 154d971e3..000000000
--- a/lib/sqlalchemy/databases/postgres.py
+++ /dev/null
@@ -1,889 +0,0 @@
-# postgres.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""Support for the PostgreSQL database.
-
-Driver
-------
-
-The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
-The dialect has several behaviors which are specifically tailored towards compatibility
-with this module.
-
-Note that psycopg1 is **not** supported.
-
-Connecting
-----------
-
-URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`.
-
-PostgreSQL-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
-
-* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
- this feature. What this essentially means from a psycopg2 point of view is that the cursor is
- created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
- are not immediately pre-fetched and buffered after statement execution, but are instead left
- on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
- uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
- at a time are fetched over the wire to reduce conversational overhead.
-
-Sequences/SERIAL
-----------------
-
-PostgreSQL supports sequences, and SQLAlchemy uses these as the default means of creating
-new primary key values for integer-based primary key columns. When creating tables,
-SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns,
-which generates a sequence corresponding to the column and associated with it based on
-a naming convention.
-
-To specify a specific named sequence to be used for primary key generation, use the
-:func:`~sqlalchemy.schema.Sequence` construct::
-
- Table('sometable', metadata,
- Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
- )
-
-Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of
-having the "last insert identifier" available, the sequence is executed independently
-beforehand and the new value is retrieved, to be used in the subsequent insert. Note
-that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using
-"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior
-is used.
-
-PostgreSQL 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports
-as well. A future release of SQLA will use this feature by default in lieu of
-sequence pre-execution in order to retrieve new primary key values, when available.
-
-INSERT/UPDATE...RETURNING
--------------------------
-
-The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes,
-but must be explicitly enabled on a per-statement basis::
-
- # INSERT..RETURNING
- result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
- values(name='foo')
- print result.fetchall()
-
- # UPDATE..RETURNING
- result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
- where(table.c.name=='foo').values(name='bar')
- print result.fetchall()
-
-Indexes
--------
-
-PostgreSQL supports partial indexes. To create them pass a postgres_where
-option to the Index constructor::
-
- Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
-
-Transactions
-------------
-
-The PostgreSQL dialect fully supports SAVEPOINT and two-phase commit operations.
-
-
-"""
-
-import decimal, random, re, string
-
-from sqlalchemy import sql, schema, exc, util
-from sqlalchemy.engine import base, default
-from sqlalchemy.sql import compiler, expression
-from sqlalchemy.sql import operators as sql_operators
-from sqlalchemy import types as sqltypes
-
-
class PGInet(sqltypes.TypeEngine):
    """PostgreSQL INET type (host address)."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "INET"
-
class PGCidr(sqltypes.TypeEngine):
    """PostgreSQL CIDR type (network specification)."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "CIDR"
-
class PGMacAddr(sqltypes.TypeEngine):
    """PostgreSQL MACADDR type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "MACADDR"
-
class PGNumeric(sqltypes.Numeric):
    """PostgreSQL NUMERIC type."""

    def get_col_spec(self):
        """Render NUMERIC, with (precision, scale) when precision is set."""
        if self.precision:
            return "NUMERIC(%s, %s)" % (self.precision, self.scale)
        return "NUMERIC"

    def bind_processor(self, dialect):
        # psycopg2 accepts Decimal values directly; no bind-side conversion.
        return None

    def result_processor(self, dialect):
        """Convert Decimal results to float when asdecimal is False."""
        if self.asdecimal:
            return None

        def to_float(value):
            if isinstance(value, decimal.Decimal):
                return float(value)
            return value

        return to_float
-
class PGFloat(sqltypes.Float):
    """PostgreSQL FLOAT type."""

    def get_col_spec(self):
        """Render FLOAT, with precision when one is set."""
        if self.precision:
            return "FLOAT(%s)" % self.precision
        return "FLOAT"
-
-
class PGInteger(sqltypes.Integer):
    """PostgreSQL INTEGER type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "INTEGER"
-
class PGSmallInteger(sqltypes.Smallinteger):
    """PostgreSQL SMALLINT type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "SMALLINT"
-
class PGBigInteger(PGInteger):
    """PostgreSQL BIGINT type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "BIGINT"
-
class PGDateTime(sqltypes.DateTime):
    """PostgreSQL TIMESTAMP type, with or without time zone."""

    def get_col_spec(self):
        """Render TIMESTAMP WITH/WITHOUT TIME ZONE per the timezone flag."""
        if self.timezone:
            return "TIMESTAMP WITH TIME ZONE"
        return "TIMESTAMP WITHOUT TIME ZONE"
-
class PGDate(sqltypes.Date):
    """PostgreSQL DATE type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "DATE"
-
class PGTime(sqltypes.Time):
    """PostgreSQL TIME type, with or without time zone."""

    def get_col_spec(self):
        """Render TIME WITH/WITHOUT TIME ZONE per the timezone flag."""
        if self.timezone:
            return "TIME WITH TIME ZONE"
        return "TIME WITHOUT TIME ZONE"
-
class PGInterval(sqltypes.TypeEngine):
    """PostgreSQL INTERVAL type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "INTERVAL"
-
class PGText(sqltypes.Text):
    """PostgreSQL TEXT type (unbounded string)."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "TEXT"
-
class PGString(sqltypes.String):
    """PostgreSQL VARCHAR type."""

    def get_col_spec(self):
        """Render VARCHAR, with a length when one is set."""
        if not self.length:
            return "VARCHAR"
        return "VARCHAR(%d)" % self.length
-
class PGChar(sqltypes.CHAR):
    """PostgreSQL CHAR type."""

    def get_col_spec(self):
        """Render CHAR, with a length when one is set."""
        if not self.length:
            return "CHAR"
        return "CHAR(%d)" % self.length
-
class PGBinary(sqltypes.Binary):
    """PostgreSQL BYTEA (binary) type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "BYTEA"
-
class PGBoolean(sqltypes.Boolean):
    """PostgreSQL BOOLEAN type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "BOOLEAN"
-
class PGBit(sqltypes.TypeEngine):
    """PostgreSQL BIT type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "BIT"
-
class PGUuid(sqltypes.TypeEngine):
    """PostgreSQL UUID type."""

    def get_col_spec(self):
        """Return the DDL type name."""
        return "UUID"
-
class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
    """PostgreSQL ARRAY type wrapping an item type; treated as mutable by default."""

    def __init__(self, item_type, mutable=True):
        # accept either a type class or an instance; instantiate classes here
        if isinstance(item_type, type):
            item_type = item_type()
        self.item_type = item_type
        self.mutable = mutable

    def copy_value(self, value):
        """Return a shallow copy of ``value`` when mutable, else the value itself."""
        if value is None:
            return None
        elif self.mutable:
            return list(value)
        else:
            return value

    def compare_values(self, x, y):
        # plain equality; lists compare element-wise
        return x == y

    def is_mutable(self):
        return self.mutable

    def dialect_impl(self, dialect, **kwargs):
        # shallow-copy self via __new__ (skipping __init__), then swap in the
        # dialect-specific implementation of the item type
        impl = self.__class__.__new__(self.__class__)
        impl.__dict__.update(self.__dict__)
        impl.item_type = self.item_type.dialect_impl(dialect)
        return impl

    def bind_processor(self, dialect):
        """Return a processor applying the item type's bind processing to nested values."""
        item_proc = self.item_type.bind_processor(dialect)
        def process(value):
            if value is None:
                return value
            def convert_item(item):
                # NOTE: the bind side accepts tuples as well as lists
                if isinstance(item, (list, tuple)):
                    return [convert_item(child) for child in item]
                else:
                    if item_proc:
                        return item_proc(item)
                    else:
                        return item
            return [convert_item(item) for item in value]
        return process

    def result_processor(self, dialect):
        """Return a processor applying the item type's result processing to nested lists."""
        item_proc = self.item_type.result_processor(dialect)
        def process(value):
            if value is None:
                return value
            def convert_item(item):
                if isinstance(item, list):
                    return [convert_item(child) for child in item]
                else:
                    if item_proc:
                        return item_proc(item)
                    else:
                        return item
            return [convert_item(item) for item in value]
        return process

    def get_col_spec(self):
        # e.g. "INTEGER[]"
        return self.item_type.get_col_spec() + '[]'
-
# Map generic SQLAlchemy types to their PG-specific subclasses; consulted by
# PGDialect.type_descriptor() via sqltypes.adapt_type().
colspecs = {
    sqltypes.Integer : PGInteger,
    sqltypes.Smallinteger : PGSmallInteger,
    sqltypes.Numeric : PGNumeric,
    sqltypes.Float : PGFloat,
    sqltypes.DateTime : PGDateTime,
    sqltypes.Date : PGDate,
    sqltypes.Time : PGTime,
    sqltypes.String : PGString,
    sqltypes.Binary : PGBinary,
    sqltypes.Boolean : PGBoolean,
    sqltypes.Text : PGText,
    sqltypes.CHAR: PGChar,
}
-
# Map type names as reported by the PG information schema / format_type()
# to dialect type classes; used during table reflection.
ischema_names = {
    'integer' : PGInteger,
    'bigint' : PGBigInteger,
    'smallint' : PGSmallInteger,
    'character varying' : PGString,
    'character' : PGChar,
    '"char"' : PGChar,
    'name': PGChar,
    'text' : PGText,
    'numeric' : PGNumeric,
    'float' : PGFloat,
    'real' : PGFloat,
    'inet': PGInet,
    'cidr': PGCidr,
    'uuid':PGUuid,
    'bit':PGBit,
    'macaddr': PGMacAddr,
    'double precision' : PGFloat,
    'timestamp' : PGDateTime,
    'timestamp with time zone' : PGDateTime,
    'timestamp without time zone' : PGDateTime,
    'time with time zone' : PGTime,
    'time without time zone' : PGTime,
    'date' : PGDate,
    'time': PGTime,
    'bytea' : PGBinary,
    'boolean' : PGBoolean,
    'interval':PGInterval,
}
-
# Matches raw-text statements eligible for a server-side cursor (SELECTs).
# TODO: filter out 'FOR UPDATE' statements
SERVER_SIDE_CURSOR_RE = re.compile(
    r'\s*SELECT',
    re.I | re.UNICODE)
-
class PGExecutionContext(default.DefaultExecutionContext):
    """Execution context adding optional psycopg2 server-side (named) cursors."""

    def create_cursor(self):
        """Return a DBAPI cursor, named (server-side) when the statement qualifies."""
        # TODO: coverage for server side cursors + select.for_update()
        # Server-side only when enabled on the dialect AND either a compiled
        # SELECT-like statement without FOR UPDATE, or raw/text SQL that looks
        # like a SELECT.
        is_server_side = \
            self.dialect.server_side_cursors and \
            ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)
              and not getattr(self.compiled.statement, 'for_update', False)) \
            or \
            (
                (not self.compiled or isinstance(self.compiled.statement, expression._TextClause))
                and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
            )

        self.__is_server_side = is_server_side
        if is_server_side:
            # use server-side cursors:
            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
            # cursor name must be unique on the connection; derive from id() + randomness
            ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
            return self._connection.connection.cursor(ident)
        else:
            return self._connection.connection.cursor()

    def get_result_proxy(self):
        """Use a row-buffering proxy for server-side cursors (fetches rows in groups)."""
        if self.__is_server_side:
            return base.BufferedRowResultProxy(self)
        else:
            return base.ResultProxy(self)
-
class PGDialect(default.DefaultDialect):
    """SQLAlchemy dialect for PostgreSQL using the psycopg2 DBAPI."""

    name = 'postgres'
    supports_alter = True
    supports_unicode_statements = False
    max_identifier_length = 63  # PostgreSQL's NAMEDATALEN - 1
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = False
    preexecute_pk_sequences = True
    supports_pk_autoincrement = False
    default_paramstyle = 'pyformat'
    supports_default_values = True
    supports_empty_insert = False

    def __init__(self, server_side_cursors=False, **kwargs):
        """server_side_cursors: enable named (server-side) cursors for SELECTs."""
        default.DefaultDialect.__init__(self, **kwargs)
        self.server_side_cursors = server_side_cursors

    def dbapi(cls):
        # imported lazily so this module can be imported without psycopg2 installed
        import psycopg2 as psycopg
        return psycopg
    dbapi = classmethod(dbapi)

    def create_connect_args(self, url):
        """Translate an engine URL into psycopg2 connect() keyword arguments."""
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)

    def type_descriptor(self, typeobj):
        return sqltypes.adapt_type(typeobj, colspecs)

    def do_begin_twophase(self, connection, xid):
        # a two-phase begin is an ordinary BEGIN; the xid is used at prepare time
        self.do_begin(connection.connection)

    def do_prepare_twophase(self, connection, xid):
        connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)]))

    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
        if is_prepared:
            if recover:
                #FIXME: ugly hack to get out of transaction context when committing recoverable transactions
                # Must find out a way how to make the dbapi not open a transaction.
                connection.execute(sql.text("ROLLBACK"))
            connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
            connection.execute(sql.text("BEGIN"))
            self.do_rollback(connection.connection)
        else:
            self.do_rollback(connection.connection)

    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
        if is_prepared:
            if recover:
                connection.execute(sql.text("ROLLBACK"))
            connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
            connection.execute(sql.text("BEGIN"))
            self.do_rollback(connection.connection)
        else:
            self.do_commit(connection.connection)

    def do_recover_twophase(self, connection):
        """Return the ids of prepared transactions known to the server."""
        resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
        return [row[0] for row in resultset]

    def get_default_schema_name(self, connection):
        return connection.scalar("select current_schema()", None)
    # memoized per-connection: the default schema is stable for a connection's lifetime
    get_default_schema_name = base.connection_memoize(
        ('dialect', 'default_schema_name'))(get_default_schema_name)

    def last_inserted_ids(self):
        if self.context.last_inserted_ids is None:
            raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without PostgreSQL OIDs enabled")
        else:
            return self.context.last_inserted_ids

    def has_table(self, connection, table_name, schema=None):
        """Return True if a relation named ``table_name`` exists (case-folded)."""
        # seems like case gets folded in pg_class...
        if schema is None:
            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)});
        else:
            cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema});
        return bool( not not cursor.rowcount )

    def has_sequence(self, connection, sequence_name):
        """Return True if a sequence named ``sequence_name`` exists."""
        # NOTE(review): the ``relname = %(seqname)s`` comparison sits inside the
        # pg_namespace subquery, where it resolves as a correlated reference to
        # the outer pg_class row -- it appears to work but reads oddly; confirm
        # before restructuring.
        cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)})
        return bool(not not cursor.rowcount)

    def is_disconnect(self, e):
        """Heuristically detect connection-loss errors from psycopg2 message text."""
        if isinstance(e, self.dbapi.OperationalError):
            return 'closed the connection' in str(e) or 'connection not open' in str(e)
        elif isinstance(e, self.dbapi.InterfaceError):
            return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
        elif isinstance(e, self.dbapi.ProgrammingError):
            # yes, it really says "losed", not "closed"
            return "losed the connection unexpectedly" in str(e)
        else:
            return False

    def table_names(self, connection, schema):
        """Return the names of ordinary tables in ``schema``."""
        # NOTE(review): ``schema`` is interpolated directly into the SQL via %;
        # assumes a trusted schema name -- confirm callers never pass user input.
        s = """
        SELECT relname
        FROM pg_class c
        WHERE relkind = 'r'
        AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
        """ % locals()
        return [row[0].decode(self.encoding) for row in connection.execute(s)]

    def server_version_info(self, connection):
        """Return the server version as a (major, minor, patch) tuple."""
        v = connection.execute("select version()").scalar()
        m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v)
        if not m:
            raise AssertionError("Could not determine version from string '%s'" % v)
        return tuple([int(x) for x in m.group(1, 2, 3)])

    def reflecttable(self, connection, table, include_columns):
        """Reflect columns, primary key, foreign keys and indexes of ``table``
        from the pg_catalog system tables, populating the given Table object."""
        preparer = self.identifier_preparer
        if table.schema is not None:
            schema_where_clause = "n.nspname = :schema"
            schemaname = table.schema
            if isinstance(schemaname, str):
                schemaname = schemaname.decode(self.encoding)
        else:
            schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
            schemaname = None

        SQL_COLS = """
            SELECT a.attname,
            pg_catalog.format_type(a.atttypid, a.atttypmod),
            (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
            WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef)
            AS DEFAULT,
            a.attnotnull, a.attnum, a.attrelid as table_oid
            FROM pg_catalog.pg_attribute a
            WHERE a.attrelid = (
                SELECT c.oid
                FROM pg_catalog.pg_class c
                LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE (%s)
                AND c.relname = :table_name AND c.relkind in ('r','v')
            ) AND a.attnum > 0 AND NOT a.attisdropped
            ORDER BY a.attnum
        """ % schema_where_clause

        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
        tablename = table.name
        if isinstance(tablename, str):
            tablename = tablename.decode(self.encoding)
        c = connection.execute(s, table_name=tablename, schema=schemaname)
        rows = c.fetchall()

        if not rows:
            raise exc.NoSuchTableError(table.name)

        domains = self._load_domains(connection)

        for name, format_type, default, notnull, attnum, table_oid in rows:
            if include_columns and name not in include_columns:
                continue

            ## strip (30) from character varying(30)
            attype = re.search('([^\([]+)', format_type).group(1)
            nullable = not notnull
            is_array = format_type.endswith('[]')

            try:
                charlen = re.search('\(([\d,]+)\)', format_type).group(1)
            except:
                charlen = False

            numericprec = False
            numericscale = False
            if attype == 'numeric':
                if charlen is False:
                    numericprec, numericscale = (None, None)
                else:
                    numericprec, numericscale = charlen.split(',')
                charlen = False
            if attype == 'double precision':
                numericprec, numericscale = (53, False)
                charlen = False
            if attype == 'integer':
                numericprec, numericscale = (32, 0)
                charlen = False

            # collect positional type args: None is passed through, False means "absent"
            args = []
            for a in (charlen, numericprec, numericscale):
                if a is None:
                    args.append(None)
                elif a is not False:
                    args.append(int(a))

            kwargs = {}
            if attype == 'timestamp with time zone':
                kwargs['timezone'] = True
            elif attype == 'timestamp without time zone':
                kwargs['timezone'] = False

            coltype = None
            if attype in ischema_names:
                coltype = ischema_names[attype]
            else:
                if attype in domains:
                    domain = domains[attype]
                    if domain['attype'] in ischema_names:
                        # A table can't override whether the domain is nullable.
                        nullable = domain['nullable']

                        if domain['default'] and not default:
                            # It can, however, override the default value, but can't set it to null.
                            default = domain['default']
                        coltype = ischema_names[domain['attype']]

            if coltype:
                coltype = coltype(*args, **kwargs)
                if is_array:
                    coltype = PGArray(coltype)
            else:
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (attype, name))
                coltype = sqltypes.NULLTYPE

            colargs = []
            if default is not None:
                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
                if match is not None:
                    # the default is related to a Sequence
                    sch = table.schema
                    if '.' not in match.group(2) and sch is not None:
                        # unconditionally quote the schema name. this could
                        # later be enhanced to obey quoting rules / "quote schema"
                        default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
                colargs.append(schema.DefaultClause(sql.text(default)))
            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))


        # Primary keys
        PK_SQL = """
            SELECT attname FROM pg_attribute
            WHERE attrelid = (
                SELECT indexrelid FROM pg_index i
                WHERE i.indrelid = :table
                AND i.indisprimary = 't')
            ORDER BY attnum
        """
        t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        for row in c.fetchall():
            pk = row[0]
            if pk in table.c:
                col = table.c[pk]
                table.primary_key.add(col)
                if col.default is None:
                    col.autoincrement = False

        # Foreign keys
        FK_SQL = """
            SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
            FROM pg_catalog.pg_constraint r
            WHERE r.conrelid = :table AND r.contype = 'f'
            ORDER BY 1
        """

        t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        for conname, condef in c.fetchall():
            m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
            (constrained_columns, referred_schema, referred_table, referred_columns) = m
            constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
            if referred_schema:
                referred_schema = preparer._unquote_identifier(referred_schema)
            elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
                # no schema (i.e. its the default schema), and the table we're
                # reflecting has the default schema explicit, then use that.
                # i.e. try to use the user's conventions
                referred_schema = table.schema
            referred_table = preparer._unquote_identifier(referred_table)
            # NOTE(review): this split uses r'\s*,\s' (no trailing '*'), unlike the
            # constrained-columns split above; "a,b" with no space would not split.
            # Presumably pg_get_constraintdef always emits ", " -- confirm before changing.
            referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]

            refspec = []
            if referred_schema is not None:
                # reflect the referred table so the FK target columns exist
                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
                             autoload_with=connection)
                for column in referred_columns:
                    refspec.append(".".join([referred_schema, referred_table, column]))
            else:
                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
                for column in referred_columns:
                    refspec.append(".".join([referred_table, column]))

            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))

        # Indexes
        IDX_SQL = """
            SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
            a.attname
            FROM pg_index i, pg_class c, pg_attribute a
            WHERE i.indrelid = :table AND i.indexrelid = c.oid
            AND a.attrelid = i.indexrelid AND i.indisprimary = 'f'
            ORDER BY c.relname, a.attnum
        """
        t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode})
        c = connection.execute(t, table=table_oid)
        indexes = {}
        sv_idx_name = None
        for row in c.fetchall():
            idx_name, unique, expr, prd, col = row

            # expression-based indexes can't be reflected; warn once per index
            if expr:
                if not idx_name == sv_idx_name:
                    util.warn(
                        "Skipped unsupported reflection of expression-based index %s"
                        % idx_name)
                sv_idx_name = idx_name
                continue
            if prd and not idx_name == sv_idx_name:
                util.warn(
                    "Predicate of partial index %s ignored during reflection"
                    % idx_name)
                sv_idx_name = idx_name

            if not indexes.has_key(idx_name):
                indexes[idx_name] = [unique, []]
            indexes[idx_name][1].append(col)

        for name, (unique, columns) in indexes.items():
            schema.Index(name, *[table.columns[c] for c in columns],
                         **dict(unique=unique))



    def _load_domains(self, connection):
        """Return DOMAIN definitions keyed by (possibly schema-qualified) name."""
        ## Load data types for domains:
        SQL_DOMAINS = """
            SELECT t.typname as "name",
            pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
            not t.typnotnull as "nullable",
            t.typdefault as "default",
            pg_catalog.pg_type_is_visible(t.oid) as "visible",
            n.nspname as "schema"
            FROM pg_catalog.pg_type t
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
            LEFT JOIN pg_catalog.pg_constraint r ON t.oid = r.contypid
            WHERE t.typtype = 'd'
        """

        s = sql.text(SQL_DOMAINS, typemap={'attname':sqltypes.Unicode})
        c = connection.execute(s)

        domains = {}
        for domain in c.fetchall():
            ## strip (30) from character varying(30)
            attype = re.search('([^\(]+)', domain['attype']).group(1)
            if domain['visible']:
                # 'visible' just means whether or not the domain is in a
                # schema that's on the search path -- or not overridden by
                # a schema with higher precedence. If it's not visible,
                # it will be prefixed with the schema-name when it's used.
                name = domain['name']
            else:
                name = "%s.%s" % (domain['schema'], domain['name'])

            domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']}

        return domains
-
-
class PGCompiler(compiler.DefaultCompiler):
    """Statement compiler adding PG operators, LIMIT/OFFSET rules and RETURNING."""

    operators = compiler.DefaultCompiler.operators.copy()
    operators.update(
        {
            # '%' is doubled because psycopg2 uses the pyformat paramstyle
            sql_operators.mod : '%%',
            sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
            sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
            sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
        }
    )

    functions = compiler.DefaultCompiler.functions.copy()
    functions.update (
        {
            'TIMESTAMP':util.deprecated(message="Use a literal string 'timestamp <value>' instead")(lambda x:'TIMESTAMP %s' % x),
        }
    )

    def visit_sequence(self, seq):
        """Render nextval() for non-optional sequences; optional ones render nothing."""
        if seq.optional:
            return None
        else:
            return "nextval('%s')" % self.preparer.format_sequence(seq)

    def post_process_text(self, text):
        # escape literal '%' for the pyformat paramstyle, warning on pre-escaped text
        if '%%' in text:
            util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.")
        return text.replace('%', '%%')

    def limit_clause(self, select):
        """Render LIMIT/OFFSET; PG needs LIMIT ALL when only OFFSET is present."""
        text = ""
        if select._limit is not None:
            text += " \n LIMIT " + str(select._limit)
        if select._offset is not None:
            if select._limit is None:
                text += " \n LIMIT ALL"
            text += " OFFSET " + str(select._offset)
        return text

    def get_select_precolumns(self, select):
        """Support plain DISTINCT and PG's DISTINCT ON (col, ...) forms."""
        if select._distinct:
            if isinstance(select._distinct, bool):
                return "DISTINCT "
            elif isinstance(select._distinct, (list, tuple)):
                return "DISTINCT ON (" + ', '.join(
                    [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
                )+ ") "
            else:
                return "DISTINCT ON (" + unicode(select._distinct) + ") "
        else:
            return ""

    def for_update_clause(self, select):
        """Support for_update='nowait' in addition to the standard clause."""
        if select.for_update == 'nowait':
            return " FOR UPDATE NOWAIT"
        else:
            return super(PGCompiler, self).for_update_clause(select)

    def _append_returning(self, text, stmt):
        """Append a RETURNING clause listing the 'postgres_returning' columns."""
        returning_cols = stmt.kwargs['postgres_returning']
        def flatten_columnlist(collist):
            # expand selectables (e.g. whole tables) into their individual columns
            for c in collist:
                if isinstance(c, expression.Selectable):
                    for co in c.columns:
                        yield co
                else:
                    yield c
        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
        text += ' RETURNING ' + string.join(columns, ', ')
        return text

    def visit_update(self, update_stmt):
        """Compile UPDATE, appending RETURNING when 'postgres_returning' is given."""
        text = super(PGCompiler, self).visit_update(update_stmt)
        if 'postgres_returning' in update_stmt.kwargs:
            return self._append_returning(text, update_stmt)
        else:
            return text

    def visit_insert(self, insert_stmt):
        """Compile INSERT, appending RETURNING when 'postgres_returning' is given."""
        text = super(PGCompiler, self).visit_insert(insert_stmt)
        if 'postgres_returning' in insert_stmt.kwargs:
            return self._append_returning(text, insert_stmt)
        else:
            return text

    def visit_extract(self, extract, **kwargs):
        # cast to timestamp so EXTRACT applies uniformly to date/time expressions
        field = self.extract_map.get(extract.field, extract.field)
        return "EXTRACT(%s FROM %s::timestamp)" % (
            field, self.process(extract.expr))
-
-
class PGSchemaGenerator(compiler.SchemaGenerator):
    """DDL generator emitting SERIAL/BIGSERIAL columns and partial indexes."""

    def get_column_specification(self, column, **kwargs):
        """Render a column DDL clause, using SERIAL for qualifying integer PKs."""
        colspec = self.preparer.format_column(column)
        # SERIAL applies to integer (non-small) primary keys with no foreign keys,
        # autoincrement enabled, and no explicit default (or only an optional Sequence).
        if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
            if isinstance(column.type, PGBigInteger):
                colspec += " BIGSERIAL"
            else:
                colspec += " SERIAL"
        else:
            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
            # DEFAULT is only rendered for non-SERIAL columns
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"
        return colspec

    def visit_sequence(self, sequence):
        """Emit CREATE SEQUENCE for non-optional sequences (honoring checkfirst)."""
        if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
            self.execute()

    def visit_index(self, index):
        """Emit CREATE [UNIQUE] INDEX, adding WHERE ... for postgres_where partial indexes."""
        preparer = self.preparer
        self.append("CREATE ")
        if index.unique:
            self.append("UNIQUE ")
        self.append("INDEX %s ON %s (%s)" \
                    % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
                       preparer.format_table(index.table),
                       string.join([preparer.format_column(c) for c in index.columns], ', ')))
        whereclause = index.kwargs.get('postgres_where', None)
        if whereclause is not None:
            compiler = self._compile(whereclause, None)
            # this might belong to the compiler class
            # inline the bound values directly into the WHERE text
            inlined_clause = str(compiler) % dict(
                [(key,bind.value) for key,bind in compiler.binds.iteritems()])
            self.append(" WHERE " + inlined_clause)
        self.execute()
-
class PGSchemaDropper(compiler.SchemaDropper):
    """DDL dropper emitting DROP SEQUENCE for non-optional sequences."""

    def visit_sequence(self, sequence):
        """Emit DROP SEQUENCE unless the sequence is optional or absent under checkfirst."""
        if sequence.optional:
            return
        # only query the server when checkfirst is enabled
        if self.checkfirst and not self.dialect.has_sequence(self.connection, sequence.name):
            return
        self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
        self.execute()
-
class PGDefaultRunner(base.DefaultRunner):
    """Executes column defaults, pre-fetching SERIAL/sequence values for PKs."""

    def __init__(self, context):
        base.DefaultRunner.__init__(self, context)
        # create a cursor which won't conflict with a server-side cursor
        self.cursor = context._connection.connection.cursor()

    def get_column_default(self, column, isinsert=True):
        """Return the default value for ``column``, pre-executing sequence
        nextval() / server defaults for primary key columns."""
        if column.primary_key:
            # pre-execute passive defaults on primary keys
            if (isinstance(column.server_default, schema.DefaultClause) and
                    column.server_default.arg is not None):
                return self.execute_string("select %s" % column.server_default.arg)
            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                sch = column.table.schema
                # TODO: this has to build into the Sequence object so we can get the quoting
                # logic from it
                # NOTE: renamed from ``exc`` -- that name shadowed the
                # sqlalchemy ``exc`` module imported at the top of this file.
                if sch is not None:
                    seq_sql = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
                else:
                    seq_sql = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
                return self.execute_string(seq_sql.encode(self.dialect.encoding))

        return super(PGDefaultRunner, self).get_column_default(column)

    def visit_sequence(self, seq):
        """Execute nextval() for a non-optional sequence; optional ones return None."""
        if not seq.optional:
            return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
        else:
            return None
-
class PGIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer understanding PG's doubled-quote escaping."""

    def _unquote_identifier(self, value):
        """Strip surrounding quotes and collapse doubled quote characters."""
        if value[0] != self.initial_quote:
            return value
        return value[1:-1].replace('""', '"')
-
# Assemble the dialect: SQLAlchemy imports ``dialect`` from this module and
# reads the compiler/schema/runner component classes off of it.
dialect = PGDialect
dialect.statement_compiler = PGCompiler
dialect.schemagenerator = PGSchemaGenerator
dialect.schemadropper = PGSchemaDropper
dialect.preparer = PGIdentifierPreparer
dialect.defaultrunner = PGDefaultRunner
dialect.execution_ctx_cls = PGExecutionContext