diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-04-02 21:36:11 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-04-02 21:36:11 +0000 |
commit | cdceb3c3714af707bfe3ede10af6536eaf529ca8 (patch) | |
tree | 2ccbfb60cd10d995c0309801b0adc4fc3a1f0a44 /lib/sqlalchemy/databases/postgres.py | |
parent | 8607de3159fd37923ae99118c499935c4a54d0e2 (diff) | |
download | sqlalchemy-cdceb3c3714af707bfe3ede10af6536eaf529ca8.tar.gz |
- merged the "execcontext" branch, refactors engine/dialect codepaths
- much more functionality moved into ExecutionContext, which impacted
the API used by dialects to some degree
- ResultProxy and subclasses now designed sanely
- merged patch for #522, Unicode subclasses String directly,
MSNVarchar implements for MS-SQL, removed MSUnicode.
- String moves its "VARCHAR"/"TEXT" switchy thing into
"get_search_list()" function, which VARCHAR and CHAR can override
to not return TEXT in any case (didnt do the latter yet)
- implements server side cursors for postgres, unit tests, #514
- includes overhaul of dbapi import strategy #480, all dbapi
importing happens in dialect method "dbapi()", is only called
inside of create_engine() for default and threadlocal strategies.
Dialect subclasses have a datamember "dbapi" referencing the loaded
module which may be None.
- added "mock" engine strategy, doesnt require DBAPI module and
gives you a "Connecition" which just sends all executes to a callable.
can be used to create string output of create_all()/drop_all().
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 138 |
1 files changed, 61 insertions, 77 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d83607793..2943d163e 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,33 +4,28 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, sys, StringIO, string, types, re - -import sqlalchemy.util as util -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +import datetime, string, types, re, random + +from sqlalchemy import util, sql, schema, ansisql, exceptions +from sqlalchemy.engine import base, default import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions from sqlalchemy.databases import information_schema as ischema -import re try: import mx.DateTime.DateTime as mxDateTime except: mxDateTime = None -try: - import psycopg2 as psycopg - #import psycopg2.psycopg1 as psycopg -except: +def dbapi(): try: - import psycopg - except: - psycopg = None - + import psycopg2 as psycopg + except ImportError, e: + try: + import psycopg + except ImportError, e2: + raise e + return psycopg + class PGInet(sqltypes.TypeEngine): def get_col_spec(self): return "INET" @@ -74,8 +69,8 @@ class PG1DateTime(sqltypes.DateTime): mx_datetime = mxDateTime(value.year, value.month, value.day, value.hour, value.minute, seconds) - return psycopg.TimestampFromMx(mx_datetime) - return psycopg.TimestampFromMx(value) + return dialect.dbapi.TimestampFromMx(mx_datetime) + return dialect.dbapi.TimestampFromMx(value) else: return None @@ -101,7 +96,7 @@ class PG1Date(sqltypes.Date): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: - return psycopg.DateFromMx(value) + return dialect.dbapi.DateFromMx(value) else: return None @@ -219,44 +214,49 @@ def descriptor(): ]} class PGExecutionContext(default.DefaultExecutionContext): - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None: - if not engine.dialect.use_oids: + + def is_select(self): + return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I) + + def create_cursor(self): + if self.dialect.server_side_cursors and self.is_select(): + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c" + hex(random.randint(0, 65535))[2:] + return self.connection.connection.cursor(ident) + else: + return self.connection.connection.cursor() + + def get_result_proxy(self): + if self.dialect.server_side_cursors and self.is_select(): + return base.BufferedRowResultProxy(self) + else: + return base.ResultProxy(self) + + def post_exec(self): + if self.compiled.isinsert and self.last_inserted_ids is None: + if not self.dialect.use_oids: pass # will raise invalid error when they go to get them else: - table = compiled.statement.table - cursor = proxy() - if cursor.lastrowid is not None and table is not None and len(table.primary_key): - s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) - c = s.compile(engine=engine) - cursor = proxy(str(c), c.get_params()) - row = cursor.fetchone() + table = self.compiled.statement.table + if self.cursor.lastrowid is not None and table is not None and len(table.primary_key): + s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid) + row = self.connection.execute(s).fetchone() self._last_inserted_ids = [v for v in row] - + super(PGExecutionContext, self).post_exec() + class PGDialect(ansisql.ANSIDialect): - def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params): + def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs): + ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors - if module is None: - #if psycopg is None: - # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") - self.module = psycopg + if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'): + self.version = 2 else: - self.module = module - # figure psycopg version 1 or 2 - try: - if self.module.__version__.startswith('2'): - self.version = 2 - else: - self.version = 1 - except: self.version = 1 - ansisql.ANSIDialect.__init__(self, **params) self.use_information_schema = use_information_schema - # produce consistent paramstyle even if psycopg2 module not present - if self.module is None: - self.paramstyle = 'pyformat' + self.paramstyle = 'pyformat' def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) @@ -268,16 +268,9 @@ class PGDialect(ansisql.ANSIDialect): opts.update(url.query) return ([], opts) - def create_cursor(self, connection): - if self.server_side_cursors: - # use server-side cursors: - # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - return connection.cursor('x') - else: - return connection.cursor() - def create_execution_context(self): - return PGExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return PGExecutionContext(self, *args, **kwargs) def max_identifier_length(self): return 68 @@ -292,13 +285,13 @@ class PGDialect(ansisql.ANSIDialect): return PGCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return PGSchemaGenerator(*args, **kwargs) + return PGSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return PGSchemaDropper(*args, **kwargs) + return PGSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return PGDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return PGDefaultRunner(connection, **kwargs) def preparer(self): return PGIdentifierPreparer(self) @@ -326,7 +319,6 @@ class PGDialect(ansisql.ANSIDialect): ``psycopg2`` is not nice enough to produce this correctly for an executemany, so we do our own executemany here. """ - rowcount = 0 for param in parameters: c.execute(statement, param) @@ -334,9 +326,6 @@ class PGDialect(ansisql.ANSIDialect): if context is not None: context._rowcount = rowcount - def dbapi(self): - return self.module - def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... if schema is None: @@ -542,7 +531,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): else: colspec += " SERIAL" else: - colspec += " " + column.type.engine_impl(self.engine).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -567,8 +556,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): - c = self.proxy("select %s" % column.default.arg) - return c.fetchone()[0] + return self.connection.execute_text("select %s" % column.default.arg).scalar() elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -577,17 +565,13 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - c = self.proxy(exc) - return c.fetchone()[0] - else: - return ansisql.ANSIDefaultRunner.get_column_default(self, column) - else: - return ansisql.ANSIDefaultRunner.get_column_default(self, column) + return self.connection.execute_text(exc).scalar() + + return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: - c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq)) - return c.fetchone()[0] + return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar() else: return None |