diff options
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
| -rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 138 | 
1 files changed, 61 insertions, 77 deletions
| diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d83607793..2943d163e 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,33 +4,28 @@  # This module is part of SQLAlchemy and is released under  # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, sys, StringIO, string, types, re - -import sqlalchemy.util as util -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +import datetime, string, types, re, random + +from sqlalchemy import util, sql, schema, ansisql, exceptions +from sqlalchemy.engine import base, default  import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions  from sqlalchemy.databases import information_schema as ischema -import re  try:      import mx.DateTime.DateTime as mxDateTime  except:      mxDateTime = None -try: -    import psycopg2 as psycopg -    #import psycopg2.psycopg1 as psycopg -except: +def dbapi():      try: -        import psycopg -    except: -        psycopg = None - +        import psycopg2 as psycopg +    except ImportError, e: +        try: +            import psycopg +        except ImportError, e2: +            raise e +    return psycopg +      class PGInet(sqltypes.TypeEngine):      def get_col_spec(self):          return "INET" @@ -74,8 +69,8 @@ class PG1DateTime(sqltypes.DateTime):                  mx_datetime = mxDateTime(value.year, value.month, value.day,                                           value.hour, value.minute,                                           seconds) -                return psycopg.TimestampFromMx(mx_datetime) -            return psycopg.TimestampFromMx(value) +                return dialect.dbapi.TimestampFromMx(mx_datetime) +            return dialect.dbapi.TimestampFromMx(value)          else:              return None @@ -101,7 +96,7 @@ class PG1Date(sqltypes.Date):          # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime          # this one doesnt seem to work with the "emulation" mode          if value is not None: -            return psycopg.DateFromMx(value) +            return dialect.dbapi.DateFromMx(value)          else:              return None @@ -219,44 +214,49 @@ def descriptor():      ]}  class PGExecutionContext(default.DefaultExecutionContext): -    def post_exec(self, engine, proxy, compiled, parameters, **kwargs): -        if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None: -            if not engine.dialect.use_oids: + +    def is_select(self): +        return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I) +     +    def create_cursor(self): +        if self.dialect.server_side_cursors and self.is_select(): +            # use server-side cursors: +            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html +            ident = "c" + hex(random.randint(0, 65535))[2:] +            return self.connection.connection.cursor(ident) +        else: +            return self.connection.connection.cursor() + +    def get_result_proxy(self): +        if self.dialect.server_side_cursors and self.is_select(): +            return base.BufferedRowResultProxy(self) +        else: +            return base.ResultProxy(self) +     +    def post_exec(self): +        if self.compiled.isinsert and self.last_inserted_ids is None: +            if not self.dialect.use_oids:                  pass                  # will raise invalid error when they go to get them              else: -                table = compiled.statement.table -                cursor = proxy() -                if cursor.lastrowid is not None and table is not None and len(table.primary_key): -                    s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) -                    c = s.compile(engine=engine) -                    cursor = proxy(str(c), c.get_params()) -                    row = cursor.fetchone() +                table = self.compiled.statement.table +                if self.cursor.lastrowid is not None and table is not None and len(table.primary_key): +                    s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid) +                    row = self.connection.execute(s).fetchone()                  self._last_inserted_ids = [v for v in row] - +        super(PGExecutionContext, self).post_exec() +          class PGDialect(ansisql.ANSIDialect): -    def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params): +    def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs): +        ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)          self.use_oids = use_oids          self.server_side_cursors = server_side_cursors -        if module is None: -            #if psycopg is None: -            #    raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") -            self.module = psycopg +        if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'): +            self.version = 2          else: -            self.module = module -        # figure psycopg version 1 or 2 -        try: -            if self.module.__version__.startswith('2'): -                self.version = 2 -            else: -                self.version = 1 -        except:              self.version = 1 -        ansisql.ANSIDialect.__init__(self, **params)          self.use_information_schema = use_information_schema -        # produce consistent paramstyle even if psycopg2 module not present -        if self.module is None: -            self.paramstyle = 'pyformat' +        self.paramstyle = 'pyformat'      def create_connect_args(self, url):          opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) @@ -268,16 +268,9 @@ class PGDialect(ansisql.ANSIDialect):          opts.update(url.query)          return ([], opts) -    def create_cursor(self, connection): -        if self.server_side_cursors: -            # use server-side cursors: -            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html -            return connection.cursor('x') -        else: -            return connection.cursor() -    def create_execution_context(self): -        return PGExecutionContext(self) +    def create_execution_context(self, *args, **kwargs): +        return PGExecutionContext(self, *args, **kwargs)      def max_identifier_length(self):          return 68 @@ -292,13 +285,13 @@ class PGDialect(ansisql.ANSIDialect):          return PGCompiler(self, statement, bindparams, **kwargs)      def schemagenerator(self, *args, **kwargs): -        return PGSchemaGenerator(*args, **kwargs) +        return PGSchemaGenerator(self, *args, **kwargs)      def schemadropper(self, *args, **kwargs): -        return PGSchemaDropper(*args, **kwargs) +        return PGSchemaDropper(self, *args, **kwargs) -    def defaultrunner(self, engine, proxy): -        return PGDefaultRunner(engine, proxy) +    def defaultrunner(self, connection, **kwargs): +        return PGDefaultRunner(connection, **kwargs)      def preparer(self):          return PGIdentifierPreparer(self) @@ -326,7 +319,6 @@ class PGDialect(ansisql.ANSIDialect):          ``psycopg2`` is not nice enough to produce this correctly for          an executemany, so we do our own executemany here.          """ -          rowcount = 0          for param in parameters:              c.execute(statement, param) @@ -334,9 +326,6 @@ class PGDialect(ansisql.ANSIDialect):          if context is not None:              context._rowcount = rowcount -    def dbapi(self): -        return self.module -      def has_table(self, connection, table_name, schema=None):          # seems like case gets folded in pg_class...          if schema is None: @@ -542,7 +531,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):              else:                  colspec += " SERIAL"          else: -            colspec += " " + column.type.engine_impl(self.engine).get_col_spec() +            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()              default = self.get_column_default_string(column)              if default is not None:                  colspec += " DEFAULT " + default @@ -567,8 +556,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):          if column.primary_key:              # passive defaults on primary keys have to be overridden              if isinstance(column.default, schema.PassiveDefault): -                c = self.proxy("select %s" % column.default.arg) -                return c.fetchone()[0] +                return self.connection.execute_text("select %s" % column.default.arg).scalar()              elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):                  sch = column.table.schema                  # TODO: this has to build into the Sequence object so we can get the quoting @@ -577,17 +565,13 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):                      exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)                  else:                      exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) -                c = self.proxy(exc) -                return c.fetchone()[0] -            else: -                return ansisql.ANSIDefaultRunner.get_column_default(self, column) -        else: -            return ansisql.ANSIDefaultRunner.get_column_default(self, column) +                return self.connection.execute_text(exc).scalar() + +        return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)      def visit_sequence(self, seq):          if not seq.optional: -            c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq)) -            return c.fetchone()[0] +            return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()          else:              return None | 
