author    Mike Bayer <mike_mp@zzzcomputing.com>  2007-07-27 04:08:53 +0000
committer Mike Bayer <mike_mp@zzzcomputing.com>  2007-07-27 04:08:53 +0000
commit    ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch)
tree      c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/databases/postgres.py
parent    3a8e235af64e36b3b711df1f069d32359fe6c967 (diff)
download  sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r--  lib/sqlalchemy/databases/postgres.py  293
1 file changed, 137 insertions(+), 156 deletions(-)
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index d3726fc1f..b192c4778 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,12 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, string, types, re, random, warnings
+import re, random, warnings, operator
-from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy import sql, schema, ansisql, exceptions
from sqlalchemy.engine import base, default
import sqlalchemy.types as sqltypes
from sqlalchemy.databases import information_schema as ischema
+from decimal import Decimal
try:
    import mx.DateTime.DateTime as mxDateTime
@@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric):
        else:
            return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+    def convert_bind_param(self, value, dialect):
+        return value
+
+    def convert_result_value(self, value, dialect):
+        if not self.asdecimal and isinstance(value, Decimal):
+            return float(value)
+        else:
+            return value
+
class PGFloat(sqltypes.Float):
    def get_col_spec(self):
        if not self.precision:
@@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float):
        else:
            return "FLOAT(%(precision)s)" % {'precision': self.precision}
+
class PGInteger(sqltypes.Integer):
    def get_col_spec(self):
        return "INTEGER"
@@ -47,74 +58,15 @@ class PGBigInteger(PGInteger):
    def get_col_spec(self):
        return "BIGINT"
-class PG2DateTime(sqltypes.DateTime):
+class PGDateTime(sqltypes.DateTime):
    def get_col_spec(self):
        return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-class PG1DateTime(sqltypes.DateTime):
-    def convert_bind_param(self, value, dialect):
-        if value is not None:
-            if isinstance(value, datetime.datetime):
-                seconds = float(str(value.second) + "."
-                                + str(value.microsecond))
-                mx_datetime = mxDateTime(value.year, value.month, value.day,
-                                         value.hour, value.minute,
-                                         seconds)
-                return dialect.dbapi.TimestampFromMx(mx_datetime)
-            return dialect.dbapi.TimestampFromMx(value)
-        else:
-            return None
-
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        second_parts = str(value.second).split(".")
-        seconds = int(second_parts[0])
-        microseconds = int(float(second_parts[1]))
-        return datetime.datetime(value.year, value.month, value.day,
-                                 value.hour, value.minute, seconds,
-                                 microseconds)
-
-    def get_col_spec(self):
-        return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG2Date(sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
-
-class PG1Date(sqltypes.Date):
-    def convert_bind_param(self, value, dialect):
-        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
-        # this one doesnt seem to work with the "emulation" mode
-        if value is not None:
-            return dialect.dbapi.DateFromMx(value)
-        else:
-            return None
-
-    def convert_result_value(self, value, dialect):
-        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
-        return value
-
+class PGDate(sqltypes.Date):
    def get_col_spec(self):
        return "DATE"
-class PG2Time(sqltypes.Time):
-    def get_col_spec(self):
-        return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PG1Time(sqltypes.Time):
-    def convert_bind_param(self, value, dialect):
-        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
-        # this one doesnt seem to work with the "emulation" mode
-        if value is not None:
-            return psycopg.TimeFromMx(value)
-        else:
-            return None
-
-    def convert_result_value(self, value, dialect):
-        # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
-        return value
-
+class PGTime(sqltypes.Time):
    def get_col_spec(self):
        return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
@@ -142,28 +94,55 @@ class PGBoolean(sqltypes.Boolean):
    def get_col_spec(self):
        return "BOOLEAN"
-pg2_colspecs = {
+class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
+    def __init__(self, item_type):
+        if isinstance(item_type, type):
+            item_type = item_type()
+        self.item_type = item_type
+
+    def dialect_impl(self, dialect):
+        impl = self.__class__.__new__(self.__class__)
+        impl.__dict__.update(self.__dict__)
+        impl.item_type = self.item_type.dialect_impl(dialect)
+        return impl
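+    # conversion recurses through nested lists so multi-dimensional ARRAY
+    # values are converted element by element using the item type.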
+    def convert_bind_param(self, value, dialect):
+        if value is None:
+            return value
+        def convert_item(item):
+            if isinstance(item, (list, tuple)):
+                return [convert_item(child) for child in item]
+            else:
+                return self.item_type.convert_bind_param(item, dialect)
+        return [convert_item(item) for item in value]
+    def convert_result_value(self, value, dialect):
+        if value is None:
+            return value
+        def convert_item(item):
+            if isinstance(item, list):
+                return [convert_item(child) for child in item]
+            else:
+                return self.item_type.convert_result_value(item, dialect)
+        # could special-case the common case where item_type.convert_result_value is the default identity function
+        return [convert_item(item) for item in value]
+    def get_col_spec(self):
+        return self.item_type.get_col_spec() + '[]'
+
+colspecs = {
    sqltypes.Integer : PGInteger,
    sqltypes.Smallinteger : PGSmallInteger,
    sqltypes.Numeric : PGNumeric,
    sqltypes.Float : PGFloat,
-    sqltypes.DateTime : PG2DateTime,
-    sqltypes.Date : PG2Date,
-    sqltypes.Time : PG2Time,
+    sqltypes.DateTime : PGDateTime,
+    sqltypes.Date : PGDate,
+    sqltypes.Time : PGTime,
    sqltypes.String : PGString,
    sqltypes.Binary : PGBinary,
    sqltypes.Boolean : PGBoolean,
    sqltypes.TEXT : PGText,
    sqltypes.CHAR: PGChar,
}
-pg1_colspecs = pg2_colspecs.copy()
-pg1_colspecs.update({
-    sqltypes.DateTime : PG1DateTime,
-    sqltypes.Date : PG1Date,
-    sqltypes.Time : PG1Time
-    })
-
-pg2_ischema_names = {
+
+ischema_names = {
    'integer' : PGInteger,
    'bigint' : PGBigInteger,
    'smallint' : PGSmallInteger,
@@ -175,24 +154,17 @@ pg2_ischema_names = {
    'real' : PGFloat,
    'inet': PGInet,
    'double precision' : PGFloat,
-    'timestamp' : PG2DateTime,
-    'timestamp with time zone' : PG2DateTime,
-    'timestamp without time zone' : PG2DateTime,
-    'time with time zone' : PG2Time,
-    'time without time zone' : PG2Time,
-    'date' : PG2Date,
-    'time': PG2Time,
+    'timestamp' : PGDateTime,
+    'timestamp with time zone' : PGDateTime,
+    'timestamp without time zone' : PGDateTime,
+    'time with time zone' : PGTime,
+    'time without time zone' : PGTime,
+    'date' : PGDate,
+    'time': PGTime,
    'bytea' : PGBinary,
    'boolean' : PGBoolean,
    'interval':PGInterval,
}
-pg1_ischema_names = pg2_ischema_names.copy()
-pg1_ischema_names.update({
-    'timestamp with time zone' : PG1DateTime,
-    'timestamp without time zone' : PG1DateTime,
-    'date' : PG1Date,
-    'time' : PG1Time
-    })
def descriptor():
    return {'name':'postgres',
@@ -206,11 +178,11 @@ def descriptor():
class PGExecutionContext(default.DefaultExecutionContext):
-    def is_select(self):
-        return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
-
+    def _is_server_side(self):
+        return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I)
+
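+    # only plain SELECT statements go to a server-side (named) cursor;
+    # FOR UPDATE selects stay on a regular cursor.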
    def create_cursor(self):
-        if self.dialect.server_side_cursors and self.is_select():
+        if self._is_server_side():
            # use server-side cursors:
            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
            ident = "c" + hex(random.randint(0, 65535))[2:]
@@ -219,7 +191,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
            return self.connection.connection.cursor()
    def get_result_proxy(self):
-        if self.dialect.server_side_cursors and self.is_select():
+        if self._is_server_side():
            return base.BufferedRowResultProxy(self)
        else:
            return base.ResultProxy(self)
@@ -242,31 +214,18 @@ class PGDialect(ansisql.ANSIDialect):
        ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
        self.use_oids = use_oids
        self.server_side_cursors = server_side_cursors
-        if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
-            self.version = 2
-        else:
-            self.version = 1
        self.use_information_schema = use_information_schema
        self.paramstyle = 'pyformat'
    def dbapi(cls):
-        try:
-            import psycopg2 as psycopg
-        except ImportError, e:
-            try:
-                import psycopg
-            except ImportError, e2:
-                raise e
+        import psycopg2 as psycopg
        return psycopg
    dbapi = classmethod(dbapi)
    def create_connect_args(self, url):
        opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
        if opts.has_key('port'):
-            if self.version == 2:
-                opts['port'] = int(opts['port'])
-            else:
-                opts['port'] = str(opts['port'])
+            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)
@@ -278,10 +237,7 @@ class PGDialect(ansisql.ANSIDialect):
        return 63
    def type_descriptor(self, typeobj):
-        if self.version == 2:
-            return sqltypes.adapt_type(typeobj, pg2_colspecs)
-        else:
-            return sqltypes.adapt_type(typeobj, pg1_colspecs)
+        return sqltypes.adapt_type(typeobj, colspecs)
    def compiler(self, statement, bindparams, **kwargs):
        return PGCompiler(self, statement, bindparams, **kwargs)
@@ -292,8 +248,36 @@ class PGDialect(ansisql.ANSIDialect):
    def schemadropper(self, *args, **kwargs):
        return PGSchemaDropper(self, *args, **kwargs)
-    def defaultrunner(self, connection, **kwargs):
-        return PGDefaultRunner(connection, **kwargs)
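+    # two-phase commit maps onto PostgreSQL's PREPARE TRANSACTION /
+    # COMMIT PREPARED / ROLLBACK PREPARED statements; pg_prepared_xacts
+    # lists transactions that can be recovered.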
+    def do_begin_twophase(self, connection, xid):
+        self.do_begin(connection.connection)
+
+    def do_prepare_twophase(self, connection, xid):
+        connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+
+    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+        if is_prepared:
+            if recover:
+                # FIXME: ugly hack to get out of transaction context when committing recoverable transactions.
+                # Must find a way to make the dbapi not open a transaction.
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ if is_prepared:
+ if recover:
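+                # same workaround as in do_rollback_twophase above: exit the
+                # dbapi-opened transaction before issuing COMMIT PREPARED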
+ connection.execute(sql.text("ROLLBACK"))
+ connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+ else:
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
+ return [row[0] for row in resultset]
+
+ def defaultrunner(self, context, **kwargs):
+ return PGDefaultRunner(context, **kwargs)
    def preparer(self):
        return PGIdentifierPreparer(self)
@@ -351,14 +335,9 @@ class PGDialect(ansisql.ANSIDialect):
        else:
            return False
-    def reflecttable(self, connection, table):
-        if self.version == 2:
-            ischema_names = pg2_ischema_names
-        else:
-            ischema_names = pg1_ischema_names
-
+    def reflecttable(self, connection, table, include_columns):
        if self.use_information_schema:
-            ischema.reflecttable(connection, table, ischema_names)
+            ischema.reflecttable(connection, table, include_columns, ischema_names)
        else:
            preparer = self.identifier_preparer
            if table.schema is not None:
@@ -387,7 +366,7 @@ class PGDialect(ansisql.ANSIDialect):
        ORDER BY a.attnum
        """ % schema_where_clause
-        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
+        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
        c = connection.execute(s, table_name=table.name,
                               schema=table.schema)
        rows = c.fetchall()
@@ -398,9 +377,13 @@ class PGDialect(ansisql.ANSIDialect):
        domains = self._load_domains(connection)
        for name, format_type, default, notnull, attnum, table_oid in rows:
+            if include_columns and name not in include_columns:
+                continue
+
            ## strip (30) from character varying(30)
-            attype = re.search('([^\(]+)', format_type).group(1)
+            attype = re.search('([^\([]+)', format_type).group(1)
            nullable = not notnull
+            is_array = format_type.endswith('[]')
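+            # the regex above now also stops at '[', so a format_type like
+            # 'integer[]' yields base type 'integer'; is_array triggers the
+            # PGArray wrapping further below.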
            try:
                charlen = re.search('\(([\d,]+)\)', format_type).group(1)
@@ -453,6 +436,8 @@ class PGDialect(ansisql.ANSIDialect):
            if coltype:
                coltype = coltype(*args, **kwargs)
+                if is_array:
+                    coltype = PGArray(coltype)
            else:
                warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
                coltype = sqltypes.NULLTYPE
@@ -517,7 +502,6 @@ class PGDialect(ansisql.ANSIDialect):
            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))
    def _load_domains(self, connection):
-
        ## Load data types for domains:
        SQL_DOMAINS = """
            SELECT t.typname as "name",
@@ -554,49 +538,46 @@ class PGDialect(ansisql.ANSIDialect):
-
class PGCompiler(ansisql.ANSICompiler):
-    def visit_insert_column(self, column, parameters):
-        # all column primary key inserts must be explicitly present
-        if column.primary_key:
-            parameters[column.key] = None
+    operators = ansisql.ANSICompiler.operators.copy()
+    operators.update(
+        {
+            operator.mod : '%%'
+        }
+    )
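+    # '%' is doubled because with the pyformat paramstyle a bare '%' in the
+    # statement would be taken as the start of a bind parameter marker.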
-    def visit_insert_sequence(self, column, sequence, parameters):
-        """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures
-        that the column is present in the generated column list"""
-        parameters.setdefault(column.key, None)
+    def uses_sequences_for_inserts(self):
+        return True
    def limit_clause(self, select):
        text = ""
-        if select.limit is not None:
-            text += " \n LIMIT " + str(select.limit)
-        if select.offset is not None:
-            if select.limit is None:
+        if select._limit is not None:
+            text += " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
                text += " \n LIMIT ALL"
-            text += " OFFSET " + str(select.offset)
+            text += " OFFSET " + str(select._offset)
        return text
-    def visit_select_precolumns(self, select):
-        if select.distinct:
-            if type(select.distinct) == bool:
+    def get_select_precolumns(self, select):
+        if select._distinct:
+            if type(select._distinct) == bool:
                return "DISTINCT "
-            if type(select.distinct) == list:
+            if type(select._distinct) == list:
                dist_set = "DISTINCT ON ("
-                for col in select.distinct:
+                for col in select._distinct:
                    dist_set += self.strings[col] + ", "
                dist_set = dist_set[:-2] + ") "
                return dist_set
-            return "DISTINCT ON (" + str(select.distinct) + ") "
+            return "DISTINCT ON (" + str(select._distinct) + ") "
        else:
            return ""
-    def binary_operator_string(self, binary):
-        if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
-            return '||'
-        elif binary.operator == '%':
-            return '%%'
+    def for_update_clause(self, select):
+        if select.for_update == 'nowait':
+            return " FOR UPDATE NOWAIT"
        else:
-            return ansisql.ANSICompiler.binary_operator_string(self, binary)
+            return super(PGCompiler, self).for_update_clause(select)
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
    def get_column_specification(self, column, **kwargs):
@@ -617,13 +598,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
        return colspec
    def visit_sequence(self, sequence):
-        if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)):
+        if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
            self.execute()
class PGSchemaDropper(ansisql.ANSISchemaDropper):
    def visit_sequence(self, sequence):
-        if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)):
+        if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
            self.append("DROP SEQUENCE %s" % sequence.name)
            self.execute()
@@ -632,7 +613,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
        if column.primary_key:
            # passive defaults on primary keys have to be overridden
            if isinstance(column.default, schema.PassiveDefault):
-                return self.connection.execute_text("select %s" % column.default.arg).scalar()
+                return self.connection.execute("select %s" % column.default.arg).scalar()
            elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                sch = column.table.schema
                # TODO: this has to build into the Sequence object so we can get the quoting
@@ -641,7 +622,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.connection.execute_text(exc).scalar()
+ return self.connection.execute(exc).scalar()
return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
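
A minimal usage sketch for the features this commit adds; the table and column
names are hypothetical, and the snippet assumes a psycopg2 DSN:

    from sqlalchemy import MetaData, Table, Column, Integer, String, create_engine
    from sqlalchemy.databases.postgres import PGArray

    # server_side_cursors=True enables the named-cursor path in PGExecutionContext
    engine = create_engine('postgres://scott:tiger@localhost/test',
                           server_side_cursors=True)
    meta = MetaData()

    # PGArray wraps an item type and renders e.g. INTEGER[] in DDL;
    # reflection maps 'integer[]' columns back to PGArray-wrapped types.
    t = Table('tagged', meta,
              Column('id', Integer, primary_key=True),
              Column('scores', PGArray(Integer)),
              Column('labels', PGArray(String(20))))
    meta.create_all(engine)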