author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-05-25 14:20:23 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-05-25 14:20:23 +0000 |
commit | bb79e2e871d0a4585164c1a6ed626d96d0231975 (patch) | |
tree | 6d457ba6c36c408b45db24ec3c29e147fe7504ff /lib/sqlalchemy/databases/postgres.py | |
parent | 4fc3a0648699c2b441251ba4e1d37a9107bd1986 (diff) | |
download | sqlalchemy-bb79e2e871d0a4585164c1a6ed626d96d0231975.tar.gz |
merged 0.2 branch into trunk; 0.1 now in sqlalchemy/branches/rel_0_1
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 134 |
1 file changed, 66 insertions, 68 deletions
```diff
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index a92cb340d..b6917c035 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -9,11 +9,11 @@ import datetime, sys, StringIO, string, types, re
 import sqlalchemy.util as util
 import sqlalchemy.sql as sql
 import sqlalchemy.engine as engine
+import sqlalchemy.engine.default as default
 import sqlalchemy.schema as schema
 import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
-from sqlalchemy.exceptions import *
-from sqlalchemy import *
+import sqlalchemy.exceptions as exceptions
 import information_schema as ischema
 
 try:
@@ -47,7 +47,7 @@ class PG2DateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "TIMESTAMP"
 class PG1DateTime(sqltypes.DateTime):
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         if value is not None:
             if isinstance(value, datetime.datetime):
                 seconds = float(str(value.second) + "."
@@ -59,7 +59,7 @@ class PG1DateTime(sqltypes.DateTime):
             return psycopg.TimestampFromMx(value)
         else:
             return None
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         if value is None:
             return None
         second_parts = str(value.second).split(".")
@@ -68,21 +68,20 @@ class PG1DateTime(sqltypes.DateTime):
         return datetime.datetime(value.year, value.month, value.day,
                                  value.hour, value.minute, seconds,
                                  microseconds)
-
     def get_col_spec(self):
         return "TIMESTAMP"
 class PG2Date(sqltypes.Date):
     def get_col_spec(self):
         return "DATE"
 class PG1Date(sqltypes.Date):
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
         # this one doesnt seem to work with the "emulation" mode
         if value is not None:
             return psycopg.DateFromMx(value)
         else:
             return None
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
         return value
     def get_col_spec(self):
@@ -91,14 +90,14 @@ class PG2Time(sqltypes.Time):
     def get_col_spec(self):
         return "TIME"
 class PG1Time(sqltypes.Time):
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
         # this one doesnt seem to work with the "emulation" mode
         if value is not None:
             return psycopg.TimeFromMx(value)
         else:
             return None
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
         return value
     def get_col_spec(self):
@@ -175,18 +174,35 @@ def descriptor():
     return {'name':'postgres',
     'description':'PostGres',
     'arguments':[
-        ('user',"Database Username",None),
+        ('username',"Database Username",None),
         ('password',"Database Password",None),
         ('database',"Database Name",None),
         ('host',"Hostname", None),
     ]}
 
-class PGSQLEngine(ansisql.ANSISQLEngine):
-    def __init__(self, opts, module=None, use_oids=False, **params):
+class PGExecutionContext(default.DefaultExecutionContext):
+
+    def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
+        if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None:
+            if not engine.dialect.use_oids:
+                pass
+                # will raise invalid error when they go to get them
+            else:
+                table = compiled.statement.table
+                cursor = proxy()
+                if cursor.lastrowid is not None and table is not None and len(table.primary_key):
+                    s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid)
+                    c = s.compile(engine=engine)
+                    cursor = proxy(str(c), c.get_params())
+                    row = cursor.fetchone()
+                    self._last_inserted_ids = [v for v in row]
+
+class PGDialect(ansisql.ANSIDialect):
+    def __init__(self, module=None, use_oids=False, **params):
         self.use_oids = use_oids
         if module is None:
-            if psycopg is None:
-                raise ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
+            #if psycopg is None:
+            #    raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
             self.module = psycopg
         else:
             self.module = module
@@ -198,17 +214,19 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
             self.version = 1
         except:
             self.version = 1
-        self.opts = self._translate_connect_args(('host', 'database', 'user', 'password'), opts)
-        if self.opts.has_key('port'):
+        ansisql.ANSIDialect.__init__(self, **params)
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
+        if opts.has_key('port'):
             if self.version == 2:
-                self.opts['port'] = int(self.opts['port'])
+                opts['port'] = int(opts['port'])
             else:
-                self.opts['port'] = str(self.opts['port'])
-
-        ansisql.ANSISQLEngine.__init__(self, **params)
-
-    def connect_args(self):
-        return [[], self.opts]
+                opts['port'] = str(opts['port'])
+        return ([], opts)
+
+    def create_execution_context(self):
+        return PGExecutionContext(self)
 
     def type_descriptor(self, typeobj):
         if self.version == 2:
@@ -217,25 +235,22 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
             return sqltypes.adapt_type(typeobj, pg1_colspecs)
 
     def compiler(self, statement, bindparams, **kwargs):
-        return PGCompiler(statement, bindparams, engine=self, **kwargs)
-
-    def schemagenerator(self, **params):
-        return PGSchemaGenerator(self, **params)
-
-    def schemadropper(self, **params):
-        return PGSchemaDropper(self, **params)
-
-    def defaultrunner(self, proxy=None):
-        return PGDefaultRunner(self, proxy)
+        return PGCompiler(self, statement, bindparams, **kwargs)
+    def schemagenerator(self, *args, **kwargs):
+        return PGSchemaGenerator(*args, **kwargs)
+    def schemadropper(self, *args, **kwargs):
+        return PGSchemaDropper(*args, **kwargs)
+    def defaultrunner(self, engine, proxy):
+        return PGDefaultRunner(engine, proxy)
 
-    def get_default_schema_name(self):
+    def get_default_schema_name(self, connection):
         if not hasattr(self, '_default_schema_name'):
-            self._default_schema_name = text("select current_schema()", self).scalar()
+            self._default_schema_name = connection.scalar("select current_schema()", None)
         return self._default_schema_name
 
     def last_inserted_ids(self):
         if self.context.last_inserted_ids is None:
-            raise InvalidRequestError("no INSERT executed, or cant use cursor.lastrowid without Postgres OIDs enabled")
+            raise exceptions.InvalidRequestError("no INSERT executed, or cant use cursor.lastrowid without Postgres OIDs enabled")
         else:
             return self.context.last_inserted_ids
 
@@ -245,51 +260,32 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         else:
             return None
 
-    def pre_exec(self, proxy, statement, parameters, **kwargs):
-        return
-
-    def post_exec(self, proxy, compiled, parameters, **kwargs):
-        if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None:
-            if not self.use_oids:
-                pass
-                # will raise invalid error when they go to get them
-            else:
-                table = compiled.statement.table
-                cursor = proxy()
-                if cursor.lastrowid is not None and table is not None and len(table.primary_key):
-                    s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid)
-                    c = s.compile()
-                    cursor = proxy(str(c), c.get_params())
-                    row = cursor.fetchone()
-                    self.context.last_inserted_ids = [v for v in row]
-
-    def _executemany(self, c, statement, parameters):
+    def do_executemany(self, c, statement, parameters, context=None):
         """we need accurate rowcounts for updates, inserts and deletes.  psycopg2 is not nice enough
         to produce this correctly for an executemany, so we do our own executemany here."""
         rowcount = 0
         for param in parameters:
-            try:
-                c.execute(statement, param)
-            except Exception, e:
-                raise exceptions.SQLError(statement, param, e)
+            c.execute(statement, param)
             rowcount += c.rowcount
-        self.context.rowcount = rowcount
+        if context is not None:
+            context._rowcount = rowcount
 
     def dbapi(self):
         return self.module
 
-    def reflecttable(self, table):
+    def has_table(self, connection, table_name):
+        cursor = connection.execute("""select relname from pg_class where lower(relname) = %(name)s""", {'name':table_name.lower()})
+        return bool( not not cursor.rowcount )
+
+    def reflecttable(self, connection, table):
         if self.version == 2:
             ischema_names = pg2_ischema_names
         else:
             ischema_names = pg1_ischema_names
-        # give ischema the given table's engine with which to look up
-        # other tables, not 'self', since it could be a ProxyEngine
-        ischema.reflecttable(table.engine, table, ischema_names)
+        ischema.reflecttable(connection, table, ischema_names)
 
 class PGCompiler(ansisql.ANSICompiler):
-
     def visit_insert_column(self, column, parameters):
         # Postgres advises against OID usage and turns it off in 8.1,
@@ -322,7 +318,7 @@ class PGCompiler(ansisql.ANSICompiler):
             return "DISTINCT ON (" + str(select.distinct) + ") "
         else:
             return ""
-            
+
     def binary_operator_string(self, binary):
         if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
             return '||'
@@ -333,7 +329,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
 
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
-        if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+        if column.primary_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
         else:
             colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
@@ -367,7 +363,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
         if isinstance(column.default, schema.PassiveDefault):
             c = self.proxy("select %s" % column.default.arg)
             return c.fetchone()[0]
-        elif isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+        elif isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
            sch = column.table.schema
             if sch is not None:
                 exc = "select nextval('%s.%s_%s_seq')" % (sch, column.table.name, column.name)
@@ -386,3 +382,5 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
             return c.fetchone()[0]
         else:
             return None
+
+dialect = PGDialect
```
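
As context for the refactor above (not part of the commit itself): the monolithic PGSQLEngine is split into PGDialect plus PGExecutionContext, and connection settings now come from a database URL handled by `create_connect_args(url)` rather than an opts dict. Below is a minimal, hypothetical usage sketch assuming the 0.2-era `create_engine()` URL form; the credentials, host, and database name are placeholders.

```python
# Hypothetical usage sketch for the 0.2 engine/dialect split (placeholder URL values).
from sqlalchemy import create_engine

# The 'postgres' URL scheme selects this module's dialect; PGDialect.create_connect_args(url)
# turns the URL into the (args, kwargs) handed to the psycopg DBAPI module.
engine = create_engine('postgres://scott:tiger@localhost/test')

# Work now flows through a Connection object, which is what the reworked
# reflecttable(), has_table(), and get_default_schema_name() receive.
conn = engine.connect()
print conn.scalar("select current_schema()", None)  # same call the dialect makes internally
```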