summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-01-02 01:29:38 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-01-02 01:29:38 +0000
commita8b62a02ddc21d622c08ab0b05923fbe71eda36d (patch)
tree019e7f88be4f0a7d30fc5dd2d870d07a09755ff4
parent02a4176a657d54027703de5bbb4d4041ef271fe4 (diff)
downloadsqlalchemy-a8b62a02ddc21d622c08ab0b05923fbe71eda36d.tar.gz
- further fix to new TypeDecorator, so that subclasses of TypeDecorators work properly
- _handle_dbapi_exception() usage changed so that unwrapped exceptions can be rethrown with the original stack trace
-rw-r--r--lib/sqlalchemy/engine/base.py24
-rw-r--r--lib/sqlalchemy/engine/default.py6
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/types.py4
-rw-r--r--test/sql/testtypes.py240
5 files changed, 152 insertions, 126 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index e3433a06b..f2a1cd286 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -732,7 +732,8 @@ class Connection(Connectable):
try:
self.engine.dialect.do_begin(self.connection)
except Exception, e:
- raise self._handle_dbapi_exception(e, None, None, None)
+ self._handle_dbapi_exception(e, None, None, None)
+ raise
def _rollback_impl(self):
if not self.closed and not self.invalidated and self.__connection.is_valid:
@@ -742,7 +743,8 @@ class Connection(Connectable):
self.engine.dialect.do_rollback(self.connection)
self.__transaction = None
except Exception, e:
- raise self._handle_dbapi_exception(e, None, None, None)
+ self._handle_dbapi_exception(e, None, None, None)
+ raise
else:
self.__transaction = None
@@ -753,7 +755,8 @@ class Connection(Connectable):
self.engine.dialect.do_commit(self.connection)
self.__transaction = None
except Exception, e:
- raise self._handle_dbapi_exception(e, None, None, None)
+ self._handle_dbapi_exception(e, None, None, None)
+ raise
def _savepoint_impl(self, name=None):
if name is None:
@@ -914,11 +917,11 @@ class Connection(Connectable):
def _handle_dbapi_exception(self, e, statement, parameters, cursor):
if getattr(self, '_reentrant_error', False):
- return exceptions.DBAPIError.instance(None, None, e)
+ raise exceptions.DBAPIError.instance(None, None, e)
self._reentrant_error = True
try:
if not isinstance(e, self.dialect.dbapi.Error):
- return e
+ return
is_disconnect = self.dialect.is_disconnect(e)
if is_disconnect:
self.invalidate(e)
@@ -929,7 +932,7 @@ class Connection(Connectable):
self._autorollback()
if self.__close_with_result:
self.close()
- return exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
+ raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
finally:
del self._reentrant_error
@@ -937,7 +940,8 @@ class Connection(Connectable):
try:
return self.engine.dialect.create_execution_context(connection=self, **kwargs)
except Exception, e:
- raise self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None)
+ self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None)
+ raise
def _cursor_execute(self, cursor, statement, parameters, context=None):
if self.engine._should_log_info:
@@ -946,7 +950,8 @@ class Connection(Connectable):
try:
self.dialect.do_execute(cursor, statement, parameters, context=context)
except Exception, e:
- raise self._handle_dbapi_exception(e, statement, parameters, cursor)
+ self._handle_dbapi_exception(e, statement, parameters, cursor)
+ raise
def _cursor_executemany(self, cursor, statement, parameters, context=None):
if self.engine._should_log_info:
@@ -955,7 +960,8 @@ class Connection(Connectable):
try:
self.dialect.do_executemany(cursor, statement, parameters, context=context)
except Exception, e:
- raise self._handle_dbapi_exception(e, statement, parameters, cursor)
+ self._handle_dbapi_exception(e, statement, parameters, cursor)
+ raise
# poor man's multimethod/generic function thingy
executors = {
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 16d55e5b8..e78eedd5c 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -337,7 +337,8 @@ class DefaultExecutionContext(base.ExecutionContext):
try:
self.cursor.setinputsizes(*inputsizes)
except Exception, e:
- raise self._connection._handle_dbapi_exception(e, None, None, None)
+ self._connection._handle_dbapi_exception(e, None, None, None)
+ raise
else:
inputsizes = {}
for key in self.compiled.bind_names.values():
@@ -348,7 +349,8 @@ class DefaultExecutionContext(base.ExecutionContext):
try:
self.cursor.setinputsizes(**inputsizes)
except Exception, e:
- raise self._connection._handle_dbapi_exception(e, None, None, None)
+ self._connection._handle_dbapi_exception(e, None, None, None)
+ raise
def __process_defaults(self):
"""generate default values for compiled insert/update statements,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index cf6d14714..9e8b0f488 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -218,8 +218,8 @@ class DefaultCompiler(engine.Compiled):
return pd
else:
return dict([(self.bind_names[bindparam], bindparam.value) for bindparam in self.bind_names])
-
- params = property(lambda self:self.construct_params(), doc="""return a dictionary of bind parameter keys and values""")
+
+ params = property(construct_params)
def default_from(self):
"""Called when a SELECT statement has no froms, and no FROM clause is to be appended.
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 14262d6e0..5ab9ad450 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -239,7 +239,7 @@ class TypeDecorator(AbstractType):
raise NotImplementedError()
def bind_processor(self, dialect):
- if 'process_bind_param' in self.__class__.__dict__:
+ if self.__class__.process_bind_param.func_code is not TypeDecorator.process_bind_param.func_code:
impl_processor = self.impl.bind_processor(dialect)
if impl_processor:
def process(value):
@@ -253,7 +253,7 @@ class TypeDecorator(AbstractType):
return self.impl.bind_processor(dialect)
def result_processor(self, dialect):
- if 'process_result_value' in self.__class__.__dict__:
+ if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code:
impl_processor = self.impl.result_processor(dialect)
if impl_processor:
def process(value):
diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py
index eeb4a373f..fc1da5578 100644
--- a/test/sql/testtypes.py
+++ b/test/sql/testtypes.py
@@ -9,112 +9,6 @@ from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
from testlib import *
-class MyType(types.TypeEngine):
- def get_col_spec(self):
- return "VARCHAR(100)"
- def bind_processor(self, dialect):
- def process(value):
- return "BIND_IN"+ value
- return process
- def result_processor(self, dialect):
- def process(value):
- return value + "BIND_OUT"
- return process
- def adapt(self, typeobj):
- return typeobj()
-
-class MyDecoratedType(types.TypeDecorator):
- impl = String
- def bind_processor(self, dialect):
- impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value)
- def process(value):
- return "BIND_IN"+ impl_processor(value)
- return process
- def result_processor(self, dialect):
- impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value)
- def process(value):
- return impl_processor(value) + "BIND_OUT"
- return process
- def copy(self):
- return MyDecoratedType()
-
-class MyNewUnicodeType(types.TypeDecorator):
- impl = Unicode
-
- def process_bind_param(self, value, dialect):
- return "BIND_IN" + value
-
- def process_result_value(self, value, dialect):
- return value + "BIND_OUT"
-
- def copy(self):
- return MyNewUnicodeType(self.impl.length)
-
-class MyNewIntType(types.TypeDecorator):
- impl = Integer
-
- def process_bind_param(self, value, dialect):
- return value * 10
-
- def process_result_value(self, value, dialect):
- return value * 10
-
- def copy(self):
- return MyNewIntType()
-
-class MyUnicodeType(types.TypeDecorator):
- impl = Unicode
-
- def bind_processor(self, dialect):
- impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value)
-
- def process(value):
- return "BIND_IN"+ impl_processor(value)
- return process
-
- def result_processor(self, dialect):
- impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value)
- def process(value):
- return impl_processor(value) + "BIND_OUT"
- return process
-
- def copy(self):
- return MyUnicodeType(self.impl.length)
-
-class MyPickleType(types.TypeDecorator):
- impl = PickleType
-
- def process_bind_param(self, value, dialect):
- if value:
- value.stuff = 'this is modified stuff'
- return value
-
- def process_result_value(self, value, dialect):
- if value:
- value.stuff = 'this is the right stuff'
- return value
-
-class LegacyType(types.TypeEngine):
- def get_col_spec(self):
- return "VARCHAR(100)"
- def convert_bind_param(self, value, dialect):
- return "BIND_IN"+ value
- def convert_result_value(self, value, dialect):
- return value + "BIND_OUT"
- def adapt(self, typeobj):
- return typeobj()
-
-class LegacyUnicodeType(types.TypeDecorator):
- impl = Unicode
-
- def convert_bind_param(self, value, dialect):
- return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect)
-
- def convert_result_value(self, value, dialect):
- return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT"
-
- def copy(self):
- return LegacyUnicodeType(self.impl.length)
class AdaptTest(PersistTest):
def testadapt(self):
@@ -149,6 +43,11 @@ class AdaptTest(PersistTest):
def testoracletext(self):
dialect = oracle.OracleDialect()
+ class MyDecoratedType(types.TypeDecorator):
+ impl = String
+ def copy(self):
+ return MyDecoratedType()
+
col = Column('', MyDecoratedType)
dialect_type = col.type.dialect_impl(dialect)
assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl)
@@ -215,25 +114,129 @@ class UserDefinedTest(PersistTest):
def testprocessing(self):
global users
- users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12)
- users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15)
- users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9)
+ users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
+ users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
+ users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
l = users.select().execute().fetchall()
- for assertstr, assertint, row in zip(
+ for assertstr, assertint, assertint2, row in zip(
["BIND_INjackBIND_OUT", "BIND_INlalaBIND_OUT", "BIND_INfredBIND_OUT"],
[1200, 1500, 900],
+ [1800, 2250, 1350],
l
):
for col in row[1:8]:
self.assertEquals(col, assertstr)
self.assertEquals(row[8], assertint)
+ self.assertEquals(row[9], assertint2)
for col in (row[4], row[5], row[7]):
assert isinstance(col, unicode)
def setUpAll(self):
global users, metadata
+
+ class MyType(types.TypeEngine):
+ def get_col_spec(self):
+ return "VARCHAR(100)"
+ def bind_processor(self, dialect):
+ def process(value):
+ return "BIND_IN"+ value
+ return process
+ def result_processor(self, dialect):
+ def process(value):
+ return value + "BIND_OUT"
+ return process
+ def adapt(self, typeobj):
+ return typeobj()
+
+ class MyDecoratedType(types.TypeDecorator):
+ impl = String
+ def bind_processor(self, dialect):
+ impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value)
+ def process(value):
+ return "BIND_IN"+ impl_processor(value)
+ return process
+ def result_processor(self, dialect):
+ impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value)
+ def process(value):
+ return impl_processor(value) + "BIND_OUT"
+ return process
+ def copy(self):
+ return MyDecoratedType()
+
+ class MyNewUnicodeType(types.TypeDecorator):
+ impl = Unicode
+
+ def process_bind_param(self, value, dialect):
+ return "BIND_IN" + value
+
+ def process_result_value(self, value, dialect):
+ return value + "BIND_OUT"
+
+ def copy(self):
+ return MyNewUnicodeType(self.impl.length)
+
+ class MyNewIntType(types.TypeDecorator):
+ impl = Integer
+
+ def process_bind_param(self, value, dialect):
+ return value * 10
+
+ def process_result_value(self, value, dialect):
+ return value * 10
+
+ def copy(self):
+ return MyNewIntType()
+
+ class MyNewIntSubClass(MyNewIntType):
+ def process_result_value(self, value, dialect):
+ return value * 15
+
+ def copy(self):
+ return MyNewIntSubClass()
+
+ class MyUnicodeType(types.TypeDecorator):
+ impl = Unicode
+
+ def bind_processor(self, dialect):
+ impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value)
+
+ def process(value):
+ return "BIND_IN"+ impl_processor(value)
+ return process
+
+ def result_processor(self, dialect):
+ impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value)
+ def process(value):
+ return impl_processor(value) + "BIND_OUT"
+ return process
+
+ def copy(self):
+ return MyUnicodeType(self.impl.length)
+
+ class LegacyType(types.TypeEngine):
+ def get_col_spec(self):
+ return "VARCHAR(100)"
+ def convert_bind_param(self, value, dialect):
+ return "BIND_IN"+ value
+ def convert_result_value(self, value, dialect):
+ return value + "BIND_OUT"
+ def adapt(self, typeobj):
+ return typeobj()
+
+ class LegacyUnicodeType(types.TypeDecorator):
+ impl = Unicode
+
+ def convert_bind_param(self, value, dialect):
+ return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect)
+
+ def convert_result_value(self, value, dialect):
+ return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT"
+
+ def copy(self):
+ return LegacyUnicodeType(self.impl.length)
+
metadata = MetaData(testbase.db)
users = Table('type_users', metadata,
Column('user_id', Integer, primary_key = True),
@@ -251,6 +254,7 @@ class UserDefinedTest(PersistTest):
Column('goofy6', LegacyType, nullable = False),
Column('goofy7', MyNewUnicodeType, nullable = False),
Column('goofy8', MyNewIntType, nullable = False),
+ Column('goofy9', MyNewIntSubClass, nullable = False),
)
@@ -396,7 +400,21 @@ class UnicodeTest(AssertMixin):
class BinaryTest(AssertMixin):
def setUpAll(self):
- global binary_table
+ global binary_table, MyPickleType
+
+ class MyPickleType(types.TypeDecorator):
+ impl = PickleType
+
+ def process_bind_param(self, value, dialect):
+ if value:
+ value.stuff = 'this is modified stuff'
+ return value
+
+ def process_result_value(self, value, dialect):
+ if value:
+ value.stuff = 'this is the right stuff'
+ return value
+
binary_table = Table('binary_table', MetaData(testbase.db),
Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
Column('data', Binary),