summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/databases/postgres.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-04-02 21:36:11 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-04-02 21:36:11 +0000
commitcdceb3c3714af707bfe3ede10af6536eaf529ca8 (patch)
tree2ccbfb60cd10d995c0309801b0adc4fc3a1f0a44 /lib/sqlalchemy/databases/postgres.py
parent8607de3159fd37923ae99118c499935c4a54d0e2 (diff)
downloadsqlalchemy-cdceb3c3714af707bfe3ede10af6536eaf529ca8.tar.gz
- merged the "execcontext" branch, refactors engine/dialect codepaths
- much more functionality moved into ExecutionContext, which impacted the API used by dialects to some degree - ResultProxy and subclasses now designed sanely - merged patch for #522, Unicode subclasses String directly, MSNVarchar implements for MS-SQL, removed MSUnicode. - String moves its "VARCHAR"/"TEXT" switchy thing into "get_search_list()" function, which VARCHAR and CHAR can override to not return TEXT in any case (didnt do the latter yet) - implements server side cursors for postgres, unit tests, #514 - includes overhaul of dbapi import strategy #480, all dbapi importing happens in dialect method "dbapi()", is only called inside of create_engine() for default and threadlocal strategies. Dialect subclasses have a datamember "dbapi" referencing the loaded module which may be None. - added "mock" engine strategy, doesnt require DBAPI module and gives you a "Connecition" which just sends all executes to a callable. can be used to create string output of create_all()/drop_all().
Diffstat (limited to 'lib/sqlalchemy/databases/postgres.py')
-rw-r--r--lib/sqlalchemy/databases/postgres.py138
1 files changed, 61 insertions, 77 deletions
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index d83607793..2943d163e 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -4,33 +4,28 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, sys, StringIO, string, types, re
-
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
-import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
+import datetime, string, types, re, random
+
+from sqlalchemy import util, sql, schema, ansisql, exceptions
+from sqlalchemy.engine import base, default
import sqlalchemy.types as sqltypes
-import sqlalchemy.exceptions as exceptions
from sqlalchemy.databases import information_schema as ischema
-import re
try:
import mx.DateTime.DateTime as mxDateTime
except:
mxDateTime = None
-try:
- import psycopg2 as psycopg
- #import psycopg2.psycopg1 as psycopg
-except:
+def dbapi():
try:
- import psycopg
- except:
- psycopg = None
-
+ import psycopg2 as psycopg
+ except ImportError, e:
+ try:
+ import psycopg
+ except ImportError, e2:
+ raise e
+ return psycopg
+
class PGInet(sqltypes.TypeEngine):
def get_col_spec(self):
return "INET"
@@ -74,8 +69,8 @@ class PG1DateTime(sqltypes.DateTime):
mx_datetime = mxDateTime(value.year, value.month, value.day,
value.hour, value.minute,
seconds)
- return psycopg.TimestampFromMx(mx_datetime)
- return psycopg.TimestampFromMx(value)
+ return dialect.dbapi.TimestampFromMx(mx_datetime)
+ return dialect.dbapi.TimestampFromMx(value)
else:
return None
@@ -101,7 +96,7 @@ class PG1Date(sqltypes.Date):
# TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
# this one doesnt seem to work with the "emulation" mode
if value is not None:
- return psycopg.DateFromMx(value)
+ return dialect.dbapi.DateFromMx(value)
else:
return None
@@ -219,44 +214,49 @@ def descriptor():
]}
class PGExecutionContext(default.DefaultExecutionContext):
- def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None:
- if not engine.dialect.use_oids:
+
+ def is_select(self):
+ return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I)
+
+ def create_cursor(self):
+ if self.dialect.server_side_cursors and self.is_select():
+ # use server-side cursors:
+ # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+ ident = "c" + hex(random.randint(0, 65535))[2:]
+ return self.connection.connection.cursor(ident)
+ else:
+ return self.connection.connection.cursor()
+
+ def get_result_proxy(self):
+ if self.dialect.server_side_cursors and self.is_select():
+ return base.BufferedRowResultProxy(self)
+ else:
+ return base.ResultProxy(self)
+
+ def post_exec(self):
+ if self.compiled.isinsert and self.last_inserted_ids is None:
+ if not self.dialect.use_oids:
pass
# will raise invalid error when they go to get them
else:
- table = compiled.statement.table
- cursor = proxy()
- if cursor.lastrowid is not None and table is not None and len(table.primary_key):
- s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid)
- c = s.compile(engine=engine)
- cursor = proxy(str(c), c.get_params())
- row = cursor.fetchone()
+ table = self.compiled.statement.table
+ if self.cursor.lastrowid is not None and table is not None and len(table.primary_key):
+ s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid)
+ row = self.connection.execute(s).fetchone()
self._last_inserted_ids = [v for v in row]
-
+ super(PGExecutionContext, self).post_exec()
+
class PGDialect(ansisql.ANSIDialect):
- def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params):
+ def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs):
+ ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
self.use_oids = use_oids
self.server_side_cursors = server_side_cursors
- if module is None:
- #if psycopg is None:
- # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument")
- self.module = psycopg
+ if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'):
+ self.version = 2
else:
- self.module = module
- # figure psycopg version 1 or 2
- try:
- if self.module.__version__.startswith('2'):
- self.version = 2
- else:
- self.version = 1
- except:
self.version = 1
- ansisql.ANSIDialect.__init__(self, **params)
self.use_information_schema = use_information_schema
- # produce consistent paramstyle even if psycopg2 module not present
- if self.module is None:
- self.paramstyle = 'pyformat'
+ self.paramstyle = 'pyformat'
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
@@ -268,16 +268,9 @@ class PGDialect(ansisql.ANSIDialect):
opts.update(url.query)
return ([], opts)
- def create_cursor(self, connection):
- if self.server_side_cursors:
- # use server-side cursors:
- # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
- return connection.cursor('x')
- else:
- return connection.cursor()
- def create_execution_context(self):
- return PGExecutionContext(self)
+ def create_execution_context(self, *args, **kwargs):
+ return PGExecutionContext(self, *args, **kwargs)
def max_identifier_length(self):
return 68
@@ -292,13 +285,13 @@ class PGDialect(ansisql.ANSIDialect):
return PGCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
- return PGSchemaGenerator(*args, **kwargs)
+ return PGSchemaGenerator(self, *args, **kwargs)
def schemadropper(self, *args, **kwargs):
- return PGSchemaDropper(*args, **kwargs)
+ return PGSchemaDropper(self, *args, **kwargs)
- def defaultrunner(self, engine, proxy):
- return PGDefaultRunner(engine, proxy)
+ def defaultrunner(self, connection, **kwargs):
+ return PGDefaultRunner(connection, **kwargs)
def preparer(self):
return PGIdentifierPreparer(self)
@@ -326,7 +319,6 @@ class PGDialect(ansisql.ANSIDialect):
``psycopg2`` is not nice enough to produce this correctly for
an executemany, so we do our own executemany here.
"""
-
rowcount = 0
for param in parameters:
c.execute(statement, param)
@@ -334,9 +326,6 @@ class PGDialect(ansisql.ANSIDialect):
if context is not None:
context._rowcount = rowcount
- def dbapi(self):
- return self.module
-
def has_table(self, connection, table_name, schema=None):
# seems like case gets folded in pg_class...
if schema is None:
@@ -542,7 +531,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
else:
colspec += " SERIAL"
else:
- colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -567,8 +556,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
if column.primary_key:
# passive defaults on primary keys have to be overridden
if isinstance(column.default, schema.PassiveDefault):
- c = self.proxy("select %s" % column.default.arg)
- return c.fetchone()[0]
+ return self.connection.execute_text("select %s" % column.default.arg).scalar()
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
@@ -577,17 +565,13 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- c = self.proxy(exc)
- return c.fetchone()[0]
- else:
- return ansisql.ANSIDefaultRunner.get_column_default(self, column)
- else:
- return ansisql.ANSIDefaultRunner.get_column_default(self, column)
+ return self.connection.execute_text(exc).scalar()
+
+ return super(ansisql.ANSIDefaultRunner, self).get_column_default(column)
def visit_sequence(self, seq):
if not seq.optional:
- c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq))
- return c.fetchone()[0]
+ return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()
else:
return None