diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 04:08:53 +0000 |
commit | ed4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch) | |
tree | c1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/databases/postgres.py | |
parent | 3a8e235af64e36b3b711df1f069d32359fe6c967 (diff) | |
download | sqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz |
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 293 |
1 files changed, 137 insertions, 156 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d3726fc1f..b192c4778 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,12 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, string, types, re, random, warnings +import re, random, warnings, operator -from sqlalchemy import util, sql, schema, ansisql, exceptions +from sqlalchemy import sql, schema, ansisql, exceptions from sqlalchemy.engine import base, default import sqlalchemy.types as sqltypes from sqlalchemy.databases import information_schema as ischema +from decimal import Decimal try: import mx.DateTime.DateTime as mxDateTime @@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class PGFloat(sqltypes.Float): def get_col_spec(self): if not self.precision: @@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float): else: return "FLOAT(%(precision)s)" % {'precision': self.precision} + class PGInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -47,74 +58,15 @@ class PGBigInteger(PGInteger): def get_col_spec(self): return "BIGINT" -class PG2DateTime(sqltypes.DateTime): +class PGDateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" -class PG1DateTime(sqltypes.DateTime): - def convert_bind_param(self, value, dialect): - if value is not None: - if isinstance(value, datetime.datetime): - seconds = float(str(value.second) + "." - + str(value.microsecond)) - mx_datetime = mxDateTime(value.year, value.month, value.day, - value.hour, value.minute, - seconds) - return dialect.dbapi.TimestampFromMx(mx_datetime) - return dialect.dbapi.TimestampFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - if value is None: - return None - second_parts = str(value.second).split(".") - seconds = int(second_parts[0]) - microseconds = int(float(second_parts[1])) - return datetime.datetime(value.year, value.month, value.day, - value.hour, value.minute, seconds, - microseconds) - - def get_col_spec(self): - return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PG2Date(sqltypes.Date): - def get_col_spec(self): - return "DATE" - -class PG1Date(sqltypes.Date): - def convert_bind_param(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - # this one doesnt seem to work with the "emulation" mode - if value is not None: - return dialect.dbapi.DateFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - return value - +class PGDate(sqltypes.Date): def get_col_spec(self): return "DATE" -class PG2Time(sqltypes.Time): - def get_col_spec(self): - return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PG1Time(sqltypes.Time): - def convert_bind_param(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - # this one doesnt seem to work with the "emulation" mode - if value is not None: - return psycopg.TimeFromMx(value) - else: - return None - - def convert_result_value(self, value, dialect): - # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime - return value - +class PGTime(sqltypes.Time): def get_col_spec(self): return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" @@ -142,28 +94,55 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" -pg2_colspecs = { +class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): + def __init__(self, item_type): + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + + def dialect_impl(self, dialect): + impl = self.__class__.__new__(self.__class__) + impl.__dict__.update(self.__dict__) + impl.item_type = self.item_type.dialect_impl(dialect) + return impl + def convert_bind_param(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_bind_param(item, dialect) + return [convert_item(item) for item in value] + def convert_result_value(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_result_value(item, dialect) + # Could specialcase when item_type.convert_result_value is the default identity func + return [convert_item(item) for item in value] + def get_col_spec(self): + return self.item_type.get_col_spec() + '[]' + +colspecs = { sqltypes.Integer : PGInteger, sqltypes.Smallinteger : PGSmallInteger, sqltypes.Numeric : PGNumeric, sqltypes.Float : PGFloat, - sqltypes.DateTime : PG2DateTime, - sqltypes.Date : PG2Date, - sqltypes.Time : PG2Time, + sqltypes.DateTime : PGDateTime, + sqltypes.Date : PGDate, + sqltypes.Time : PGTime, sqltypes.String : PGString, sqltypes.Binary : PGBinary, sqltypes.Boolean : PGBoolean, sqltypes.TEXT : PGText, sqltypes.CHAR: PGChar, } -pg1_colspecs = pg2_colspecs.copy() -pg1_colspecs.update({ - sqltypes.DateTime : PG1DateTime, - sqltypes.Date : PG1Date, - sqltypes.Time : PG1Time - }) - -pg2_ischema_names = { + +ischema_names = { 'integer' : PGInteger, 'bigint' : PGBigInteger, 'smallint' : PGSmallInteger, @@ -175,24 +154,17 @@ pg2_ischema_names = { 'real' : PGFloat, 'inet': PGInet, 'double precision' : PGFloat, - 'timestamp' : PG2DateTime, - 'timestamp with time zone' : PG2DateTime, - 'timestamp without time zone' : PG2DateTime, - 'time with time zone' : PG2Time, - 'time without time zone' : PG2Time, - 'date' : PG2Date, - 'time': PG2Time, + 'timestamp' : PGDateTime, + 'timestamp with time zone' : PGDateTime, + 'timestamp without time zone' : PGDateTime, + 'time with time zone' : PGTime, + 'time without time zone' : PGTime, + 'date' : PGDate, + 'time': PGTime, 'bytea' : PGBinary, 'boolean' : PGBoolean, 'interval':PGInterval, } -pg1_ischema_names = pg2_ischema_names.copy() -pg1_ischema_names.update({ - 'timestamp with time zone' : PG1DateTime, - 'timestamp without time zone' : PG1DateTime, - 'date' : PG1Date, - 'time' : PG1Time - }) def descriptor(): return {'name':'postgres', @@ -206,11 +178,11 @@ def descriptor(): class PGExecutionContext(default.DefaultExecutionContext): - def is_select(self): - return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I) - + def _is_server_side(self): + return self.dialect.server_side_cursors and self.is_select() and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I) + def create_cursor(self): - if self.dialect.server_side_cursors and self.is_select(): + if self._is_server_side(): # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html ident = "c" + hex(random.randint(0, 65535))[2:] @@ -219,7 +191,7 @@ class PGExecutionContext(default.DefaultExecutionContext): return self.connection.connection.cursor() def get_result_proxy(self): - if self.dialect.server_side_cursors and self.is_select(): + if self._is_server_side(): return base.BufferedRowResultProxy(self) else: return base.ResultProxy(self) @@ -242,31 +214,18 @@ class PGDialect(ansisql.ANSIDialect): ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors - if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'): - self.version = 2 - else: - self.version = 1 self.use_information_schema = use_information_schema self.paramstyle = 'pyformat' def dbapi(cls): - try: - import psycopg2 as psycopg - except ImportError, e: - try: - import psycopg - except ImportError, e2: - raise e + import psycopg2 as psycopg return psycopg dbapi = classmethod(dbapi) def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) if opts.has_key('port'): - if self.version == 2: - opts['port'] = int(opts['port']) - else: - opts['port'] = str(opts['port']) + opts['port'] = int(opts['port']) opts.update(url.query) return ([], opts) @@ -278,10 +237,7 @@ class PGDialect(ansisql.ANSIDialect): return 63 def type_descriptor(self, typeobj): - if self.version == 2: - return sqltypes.adapt_type(typeobj, pg2_colspecs) - else: - return sqltypes.adapt_type(typeobj, pg1_colspecs) + return sqltypes.adapt_type(typeobj, colspecs) def compiler(self, statement, bindparams, **kwargs): return PGCompiler(self, statement, bindparams, **kwargs) @@ -292,8 +248,36 @@ class PGDialect(ansisql.ANSIDialect): def schemadropper(self, *args, **kwargs): return PGSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, connection, **kwargs): - return PGDefaultRunner(connection, **kwargs) + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions + # Must find out a way how to make the dbapi not open a transaction. + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) + return [row[0] for row in resultset] + + def defaultrunner(self, context, **kwargs): + return PGDefaultRunner(context, **kwargs) def preparer(self): return PGIdentifierPreparer(self) @@ -351,14 +335,9 @@ class PGDialect(ansisql.ANSIDialect): else: return False - def reflecttable(self, connection, table): - if self.version == 2: - ischema_names = pg2_ischema_names - else: - ischema_names = pg1_ischema_names - + def reflecttable(self, connection, table, include_columns): if self.use_information_schema: - ischema.reflecttable(connection, table, ischema_names) + ischema.reflecttable(connection, table, include_columns, ischema_names) else: preparer = self.identifier_preparer if table.schema is not None: @@ -387,7 +366,7 @@ class PGDialect(ansisql.ANSIDialect): ORDER BY a.attnum """ % schema_where_clause - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) + s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode}) c = connection.execute(s, table_name=table.name, schema=table.schema) rows = c.fetchall() @@ -398,9 +377,13 @@ class PGDialect(ansisql.ANSIDialect): domains = self._load_domains(connection) for name, format_type, default, notnull, attnum, table_oid in rows: + if include_columns and name not in include_columns: + continue + ## strip (30) from character varying(30) - attype = re.search('([^\(]+)', format_type).group(1) + attype = re.search('([^\([]+)', format_type).group(1) nullable = not notnull + is_array = format_type.endswith('[]') try: charlen = re.search('\(([\d,]+)\)', format_type).group(1) @@ -453,6 +436,8 @@ class PGDialect(ansisql.ANSIDialect): if coltype: coltype = coltype(*args, **kwargs) + if is_array: + coltype = PGArray(coltype) else: warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name))) coltype = sqltypes.NULLTYPE @@ -517,7 +502,6 @@ class PGDialect(ansisql.ANSIDialect): table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname)) def _load_domains(self, connection): - ## Load data types for domains: SQL_DOMAINS = """ SELECT t.typname as "name", @@ -554,49 +538,46 @@ class PGDialect(ansisql.ANSIDialect): - class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column, parameters): - # all column primary key inserts must be explicitly present - if column.primary_key: - parameters[column.key] = None + operators = ansisql.ANSICompiler.operators.copy() + operators.update( + { + operator.mod : '%%' + } + ) - def visit_insert_sequence(self, column, sequence, parameters): - """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures - that the column is present in the generated column list""" - parameters.setdefault(column.key, None) + def uses_sequences_for_inserts(self): + return True def limit_clause(self, select): text = "" - if select.limit is not None: - text += " \n LIMIT " + str(select.limit) - if select.offset is not None: - if select.limit is None: + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: text += " \n LIMIT ALL" - text += " OFFSET " + str(select.offset) + text += " OFFSET " + str(select._offset) return text - def visit_select_precolumns(self, select): - if select.distinct: - if type(select.distinct) == bool: + def get_select_precolumns(self, select): + if select._distinct: + if type(select._distinct) == bool: return "DISTINCT " - if type(select.distinct) == list: + if type(select._distinct) == list: dist_set = "DISTINCT ON (" - for col in select.distinct: + for col in select._distinct: dist_set += self.strings[col] + ", " dist_set = dist_set[:-2] + ") " return dist_set - return "DISTINCT ON (" + str(select.distinct) + ") " + return "DISTINCT ON (" + str(select._distinct) + ") " else: return "" - def binary_operator_string(self, binary): - if isinstance(binary.type, sqltypes.String) and binary.operator == '+': - return '||' - elif binary.operator == '%': - return '%%' + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) + return super(PGCompiler, self).for_update_clause(select) class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): @@ -617,13 +598,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)): + if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class PGSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)): + if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() @@ -632,7 +613,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): - return self.connection.execute_text("select %s" % column.default.arg).scalar() + return self.connection.execute("select %s" % column.default.arg).scalar() elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -641,7 +622,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.connection.execute_text(exc).scalar() + return self.connection.execute(exc).scalar() return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) |