diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 252 |
1 files changed, 189 insertions, 63 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 6ccf7190e..11bd3830d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1,5 +1,5 @@ # postgresql/base.py -# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file> +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -131,6 +131,44 @@ use the :meth:`._UpdateBase.returning` method on a per-statement basis:: where(table.c.name=='foo') print result.fetchall() +.. _postgresql_match: + +Full Text Search +---------------- + +SQLAlchemy makes available the Postgresql ``@@`` operator via the +:meth:`.ColumnElement.match` method on any textual column expression. +On a Postgresql dialect, an expression like the following:: + + select([sometable.c.text.match("search string")]) + +will emit to the database:: + + SELECT text @@ to_tsquery('search string') FROM table + +The Postgresql text search functions such as ``to_tsquery()`` +and ``to_tsvector()`` are available +explicitly using the standard :attr:`.func` construct. For example:: + + select([ + func.to_tsvector('fat cats ate rats').match('cat & rat') + ]) + +Emits the equivalent of:: + + SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat') + +The :class:`.postgresql.TSVECTOR` type can provide for explicit CAST:: + + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy import select, cast + select([cast("some text", TSVECTOR)]) + +produces a statement equivalent to:: + + SELECT CAST('some text' AS TSVECTOR) AS anon_1 + + FROM ONLY ... ------------------------ @@ -210,7 +248,7 @@ import re from ... import sql, schema, exc, util from ...engine import default, reflection -from ...sql import compiler, expression, util as sql_util, operators +from ...sql import compiler, expression, operators from ... import types as sqltypes try: @@ -230,7 +268,7 @@ RESERVED_WORDS = set( "default", "deferrable", "desc", "distinct", "do", "else", "end", "except", "false", "fetch", "for", "foreign", "from", "grant", "group", "having", "in", "initially", "intersect", "into", "leading", "limit", - "localtime", "localtimestamp", "new", "not", "null", "off", "offset", + "localtime", "localtimestamp", "new", "not", "null", "of", "off", "offset", "old", "on", "only", "or", "order", "placing", "primary", "references", "returning", "select", "session_user", "some", "symmetric", "table", "then", "to", "trailing", "true", "union", "unique", "user", "using", @@ -368,6 +406,23 @@ class UUID(sqltypes.TypeEngine): PGUuid = UUID +class TSVECTOR(sqltypes.TypeEngine): + """The :class:`.postgresql.TSVECTOR` type implements the Postgresql + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`postgresql_match` + + """ + __visit_name__ = 'TSVECTOR' + + class _Slice(expression.ColumnElement): __visit_name__ = 'slice' @@ -913,6 +968,7 @@ ischema_names = { 'interval': INTERVAL, 'interval year to month': INTERVAL, 'interval day to second': INTERVAL, + 'tsvector' : TSVECTOR } @@ -954,25 +1010,30 @@ class PGCompiler(compiler.SQLCompiler): def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) + return '%s ILIKE %s' % \ (self.process(binary.left, **kw), self.process(binary.right, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT ILIKE %s' % \ (self.process(binary.left, **kw), self.process(binary.right, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def render_literal_value(self, value, type_): value = super(PGCompiler, self).render_literal_value(value, type_) - # TODO: need to inspect "standard_conforming_strings" + if self.dialect._backslash_escapes: value = value.replace('\\', '\\\\') return value @@ -1009,14 +1070,25 @@ class PGCompiler(compiler.SQLCompiler): return "" def for_update_clause(self, select): - if select.for_update == 'nowait': - return " FOR UPDATE NOWAIT" - elif select.for_update == 'read': - return " FOR SHARE" - elif select.for_update == 'read_nowait': - return " FOR SHARE NOWAIT" + + if select._for_update_arg.read: + tmp = " FOR SHARE" else: - return super(PGCompiler, self).for_update_clause(select) + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tables = util.OrderedSet( + c.table if isinstance(c, expression.ColumnClause) + else c for c in select._for_update_arg.of) + tmp += " OF " + ", ".join( + self.process(table, ashint=True) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + + return tmp def returning_clause(self, stmt, returning_cols): @@ -1039,12 +1111,15 @@ class PGCompiler(compiler.SQLCompiler): class PGDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) impl_type = column.type.dialect_impl(self.dialect) if column.primary_key and \ column is column.table._autoincrement_column and \ - not isinstance(impl_type, sqltypes.SmallInteger) and \ ( + self.dialect.supports_smallserial or + not isinstance(impl_type, sqltypes.SmallInteger) + ) and ( column.default is None or ( isinstance(column.default, schema.Sequence) and @@ -1052,6 +1127,8 @@ class PGDDLCompiler(compiler.DDLCompiler): )): if isinstance(impl_type, sqltypes.BigInteger): colspec += " BIGSERIAL" + elif isinstance(impl_type, sqltypes.SmallInteger): + colspec += " SMALLSERIAL" else: colspec += " SERIAL" else: @@ -1069,7 +1146,9 @@ class PGDDLCompiler(compiler.DDLCompiler): return "CREATE TYPE %s AS ENUM (%s)" % ( self.preparer.format_type(type_), - ",".join("'%s'" % e for e in type_.enums) + ", ".join( + self.sql_compiler.process(sql.literal(e), literal_binds=True) + for e in type_.enums) ) def visit_drop_enum_type(self, drop): @@ -1092,31 +1171,29 @@ class PGDDLCompiler(compiler.DDLCompiler): preparer.format_table(index.table) ) - if 'postgresql_using' in index.kwargs: - using = index.kwargs['postgresql_using'] - text += "USING %s " % preparer.quote(using, index.quote) + using = index.dialect_options['postgresql']['using'] + if using: + text += "USING %s " % preparer.quote(using) - ops = index.kwargs.get('postgresql_ops', {}) + ops = index.dialect_options["postgresql"]["ops"] text += "(%s)" \ % ( ', '.join([ - self.sql_compiler.process(expr, include_table=False) + - - + self.sql_compiler.process( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, literal_binds=True) + (c.key in ops and (' ' + ops[c.key]) or '') - - for expr, c in zip(index.expressions, index.columns)]) ) - if 'postgresql_where' in index.kwargs: - whereclause = index.kwargs['postgresql_where'] - else: - whereclause = None + whereclause = index.dialect_options["postgresql"]["where"] if whereclause is not None: - whereclause = sql_util.expression_as_ddl(whereclause) - where_compiled = self.sql_compiler.process(whereclause) + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, + literal_binds=True) text += " WHERE " + where_compiled return text @@ -1128,16 +1205,20 @@ class PGDDLCompiler(compiler.DDLCompiler): elements = [] for c in constraint.columns: op = constraint.operators[c.name] - elements.append(self.preparer.quote(c.name, c.quote)+' WITH '+op) + elements.append(self.preparer.quote(c.name) + ' WITH '+op) text += "EXCLUDE USING %s (%s)" % (constraint.using, ', '.join(elements)) if constraint.where is not None: - sqltext = sql_util.expression_as_ddl(constraint.where) - text += ' WHERE (%s)' % self.sql_compiler.process(sqltext) + text += ' WHERE (%s)' % self.sql_compiler.process( + constraint.where, + literal_binds=True) text += self.define_constraint_deferrability(constraint) return text class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_TSVECTOR(self, type): + return "TSVECTOR" + def visit_INET(self, type_): return "INET" @@ -1162,6 +1243,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_HSTORE(self, type_): return "HSTORE" + def visit_JSON(self, type_): + return "JSON" + def visit_INT4RANGE(self, type_): return "INT4RANGE" @@ -1250,9 +1334,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): if not type_.name: raise exc.CompileError("Postgresql ENUM type requires a name.") - name = self.quote(type_.name, type_.quote) + name = self.quote(type_.name) if not self.omit_schema and use_schema and type_.schema is not None: - name = self.quote_schema(type_.schema, type_.quote) + "." + name + name = self.quote_schema(type_.schema) + "." + name return name @@ -1328,6 +1412,7 @@ class PGDialect(default.DefaultDialect): supports_native_enum = True supports_native_boolean = True + supports_smallserial = True supports_sequences = True sequences_optional = True @@ -1349,12 +1434,22 @@ class PGDialect(default.DefaultDialect): inspector = PGInspector isolation_level = None - # TODO: need to inspect "standard_conforming_strings" + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": {} + }) + ] + _backslash_escapes = True - def __init__(self, isolation_level=None, **kwargs): + def __init__(self, isolation_level=None, json_serializer=None, + json_deserializer=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer def initialize(self, connection): super(PGDialect, self).initialize(connection) @@ -1368,6 +1463,13 @@ class PGDialect(default.DefaultDialect): # psycopg2, others may have placed ENUM here as well self.colspecs.pop(ENUM, None) + # http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 + self.supports_smallserial = self.server_version_info >= (9, 2) + + self._backslash_escapes = connection.scalar( + "show standard_conforming_strings" + ) == 'off' + def on_connect(self): if self.isolation_level is not None: def connect(conn): @@ -1515,12 +1617,6 @@ class PGDialect(default.DefaultDialect): return bool(cursor.first()) def has_type(self, connection, type_name, schema=None): - bindparams = [ - sql.bindparam('typname', - util.text_type(type_name), type_=sqltypes.Unicode), - sql.bindparam('nspname', - util.text_type(schema), type_=sqltypes.Unicode), - ] if schema is not None: query = """ SELECT EXISTS ( @@ -1530,6 +1626,7 @@ class PGDialect(default.DefaultDialect): AND n.nspname = :nspname ) """ + query = sql.text(query) else: query = """ SELECT EXISTS ( @@ -1538,13 +1635,23 @@ class PGDialect(default.DefaultDialect): AND pg_type_is_visible(t.oid) ) """ - cursor = connection.execute(sql.text(query, bindparams=bindparams)) + query = sql.text(query) + query = query.bindparams( + sql.bindparam('typname', + util.text_type(type_name), type_=sqltypes.Unicode), + ) + if schema is not None: + query = query.bindparams( + sql.bindparam('nspname', + util.text_type(schema), type_=sqltypes.Unicode), + ) + cursor = connection.execute(query) return bool(cursor.scalar()) def _get_server_version_info(self, connection): v = connection.execute("select version()").scalar() m = re.match( - '(?:PostgreSQL|EnterpriseDB) ' + '.*(?:PostgreSQL|EnterpriseDB) ' '(\d+)\.(\d+)(?:\.(\d+))?(?:\.\d+)?(?:devel)?', v) if not m: @@ -1578,12 +1685,10 @@ class PGDialect(default.DefaultDialect): table_name = util.text_type(table_name) if schema is not None: schema = util.text_type(schema) - s = sql.text(query, bindparams=[ - sql.bindparam('table_name', type_=sqltypes.Unicode), - sql.bindparam('schema', type_=sqltypes.Unicode) - ], - typemap={'oid': sqltypes.Integer} - ) + s = sql.text(query).bindparams(table_name=sqltypes.Unicode) + s = s.columns(oid=sqltypes.Integer) + if schema: + s = s.bindparams(sql.bindparam('schema', type_=sqltypes.Unicode)) c = connection.execute(s, table_name=table_name, schema=schema) table_oid = c.scalar() if table_oid is None: @@ -1675,8 +1780,7 @@ class PGDialect(default.DefaultDialect): SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), - (SELECT substring(pg_catalog.pg_get_expr(d.adbin, d.adrelid) - for 128) + (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) @@ -1883,6 +1987,15 @@ class PGDialect(default.DefaultDialect): n.oid = c.relnamespace ORDER BY 1 """ + # http://www.postgresql.org/docs/9.0/static/sql-createtable.html + FK_REGEX = re.compile( + r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)' + r'[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?' + r'[\s]?(ON UPDATE (CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' + r'[\s]?(ON DELETE (CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' + r'[\s]?(DEFERRABLE|NOT DEFERRABLE)?' + r'[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?' + ) t = sql.text(FK_SQL, typemap={ 'conname': sqltypes.Unicode, @@ -1890,15 +2003,18 @@ class PGDialect(default.DefaultDialect): c = connection.execute(t, table=table_oid) fkeys = [] for conname, condef, conschema in c.fetchall(): - m = re.search('FOREIGN KEY \((.*?)\) REFERENCES ' - '(?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() + m = re.search(FK_REGEX, condef).groups() constrained_columns, referred_schema, \ - referred_table, referred_columns = m + referred_table, referred_columns, \ + _, match, _, onupdate, _, ondelete, \ + deferrable, _, initially = m + if deferrable is not None: + deferrable = True if deferrable == 'DEFERRABLE' else False constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] if referred_schema: - referred_schema =\ + referred_schema = \ preparer._unquote_identifier(referred_schema) elif schema is not None and schema == conschema: # no schema was returned by pg_get_constraintdef(). This @@ -1916,7 +2032,14 @@ class PGDialect(default.DefaultDialect): 'constrained_columns': constrained_columns, 'referred_schema': referred_schema, 'referred_table': referred_table, - 'referred_columns': referred_columns + 'referred_columns': referred_columns, + 'options': { + 'onupdate': onupdate, + 'ondelete': ondelete, + 'deferrable': deferrable, + 'initially': initially, + 'match': match + } } fkeys.append(fkey_d) return fkeys @@ -1926,11 +2049,14 @@ class PGDialect(default.DefaultDialect): table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) + # cast indkey as varchar since it's an int2vector, + # returned as a list by some drivers such as pypostgresql + IDX_SQL = """ SELECT i.relname as relname, ix.indisunique, ix.indexprs, ix.indpred, - a.attname, a.attnum, ix.indkey + a.attname, a.attnum, ix.indkey::varchar FROM pg_class t join pg_index ix on t.oid = ix.indrelid |