diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-01-02 01:29:38 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-01-02 01:29:38 +0000 |
| commit | a8b62a02ddc21d622c08ab0b05923fbe71eda36d (patch) | |
| tree | 019e7f88be4f0a7d30fc5dd2d870d07a09755ff4 | |
| parent | 02a4176a657d54027703de5bbb4d4041ef271fe4 (diff) | |
| download | sqlalchemy-a8b62a02ddc21d622c08ab0b05923fbe71eda36d.tar.gz | |
- further fix to new TypeDecorator, so that subclasses of TypeDecorators work properly
- _handle_dbapi_exception() usage changed so that unwrapped exceptions can be rethrown with the original stack trace
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/types.py | 4 | ||||
| -rw-r--r-- | test/sql/testtypes.py | 240 |
5 files changed, 152 insertions, 126 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index e3433a06b..f2a1cd286 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -732,7 +732,8 @@ class Connection(Connectable): try: self.engine.dialect.do_begin(self.connection) except Exception, e: - raise self._handle_dbapi_exception(e, None, None, None) + self._handle_dbapi_exception(e, None, None, None) + raise def _rollback_impl(self): if not self.closed and not self.invalidated and self.__connection.is_valid: @@ -742,7 +743,8 @@ class Connection(Connectable): self.engine.dialect.do_rollback(self.connection) self.__transaction = None except Exception, e: - raise self._handle_dbapi_exception(e, None, None, None) + self._handle_dbapi_exception(e, None, None, None) + raise else: self.__transaction = None @@ -753,7 +755,8 @@ class Connection(Connectable): self.engine.dialect.do_commit(self.connection) self.__transaction = None except Exception, e: - raise self._handle_dbapi_exception(e, None, None, None) + self._handle_dbapi_exception(e, None, None, None) + raise def _savepoint_impl(self, name=None): if name is None: @@ -914,11 +917,11 @@ class Connection(Connectable): def _handle_dbapi_exception(self, e, statement, parameters, cursor): if getattr(self, '_reentrant_error', False): - return exceptions.DBAPIError.instance(None, None, e) + raise exceptions.DBAPIError.instance(None, None, e) self._reentrant_error = True try: if not isinstance(e, self.dialect.dbapi.Error): - return e + return is_disconnect = self.dialect.is_disconnect(e) if is_disconnect: self.invalidate(e) @@ -929,7 +932,7 @@ class Connection(Connectable): self._autorollback() if self.__close_with_result: self.close() - return exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) + raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) finally: del self._reentrant_error @@ -937,7 +940,8 @@ class Connection(Connectable): try: return self.engine.dialect.create_execution_context(connection=self, **kwargs) except Exception, e: - raise self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None) + self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None) + raise def _cursor_execute(self, cursor, statement, parameters, context=None): if self.engine._should_log_info: @@ -946,7 +950,8 @@ class Connection(Connectable): try: self.dialect.do_execute(cursor, statement, parameters, context=context) except Exception, e: - raise self._handle_dbapi_exception(e, statement, parameters, cursor) + self._handle_dbapi_exception(e, statement, parameters, cursor) + raise def _cursor_executemany(self, cursor, statement, parameters, context=None): if self.engine._should_log_info: @@ -955,7 +960,8 @@ class Connection(Connectable): try: self.dialect.do_executemany(cursor, statement, parameters, context=context) except Exception, e: - raise self._handle_dbapi_exception(e, statement, parameters, cursor) + self._handle_dbapi_exception(e, statement, parameters, cursor) + raise # poor man's multimethod/generic function thingy executors = { diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 16d55e5b8..e78eedd5c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -337,7 +337,8 @@ class DefaultExecutionContext(base.ExecutionContext): try: self.cursor.setinputsizes(*inputsizes) except Exception, e: - raise self._connection._handle_dbapi_exception(e, None, None, None) + self._connection._handle_dbapi_exception(e, None, None, None) + raise else: inputsizes = {} for key in self.compiled.bind_names.values(): @@ -348,7 +349,8 @@ class DefaultExecutionContext(base.ExecutionContext): try: self.cursor.setinputsizes(**inputsizes) except Exception, e: - raise self._connection._handle_dbapi_exception(e, None, None, None) + self._connection._handle_dbapi_exception(e, None, None, None) + raise def __process_defaults(self): """generate default values for compiled insert/update statements, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cf6d14714..9e8b0f488 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -218,8 +218,8 @@ class DefaultCompiler(engine.Compiled): return pd else: return dict([(self.bind_names[bindparam], bindparam.value) for bindparam in self.bind_names]) - - params = property(lambda self:self.construct_params(), doc="""return a dictionary of bind parameter keys and values""") + + params = property(construct_params) def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 14262d6e0..5ab9ad450 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -239,7 +239,7 @@ class TypeDecorator(AbstractType): raise NotImplementedError() def bind_processor(self, dialect): - if 'process_bind_param' in self.__class__.__dict__: + if self.__class__.process_bind_param.func_code is not TypeDecorator.process_bind_param.func_code: impl_processor = self.impl.bind_processor(dialect) if impl_processor: def process(value): @@ -253,7 +253,7 @@ class TypeDecorator(AbstractType): return self.impl.bind_processor(dialect) def result_processor(self, dialect): - if 'process_result_value' in self.__class__.__dict__: + if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code: impl_processor = self.impl.result_processor(dialect) if impl_processor: def process(value): diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index eeb4a373f..fc1da5578 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -9,112 +9,6 @@ from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird from testlib import * -class MyType(types.TypeEngine): - def get_col_spec(self): - return "VARCHAR(100)" - def bind_processor(self, dialect): - def process(value): - return "BIND_IN"+ value - return process - def result_processor(self, dialect): - def process(value): - return value + "BIND_OUT" - return process - def adapt(self, typeobj): - return typeobj() - -class MyDecoratedType(types.TypeDecorator): - impl = String - def bind_processor(self, dialect): - impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value) - def process(value): - return "BIND_IN"+ impl_processor(value) - return process - def result_processor(self, dialect): - impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value) - def process(value): - return impl_processor(value) + "BIND_OUT" - return process - def copy(self): - return MyDecoratedType() - -class MyNewUnicodeType(types.TypeDecorator): - impl = Unicode - - def process_bind_param(self, value, dialect): - return "BIND_IN" + value - - def process_result_value(self, value, dialect): - return value + "BIND_OUT" - - def copy(self): - return MyNewUnicodeType(self.impl.length) - -class MyNewIntType(types.TypeDecorator): - impl = Integer - - def process_bind_param(self, value, dialect): - return value * 10 - - def process_result_value(self, value, dialect): - return value * 10 - - def copy(self): - return MyNewIntType() - -class MyUnicodeType(types.TypeDecorator): - impl = Unicode - - def bind_processor(self, dialect): - impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value) - - def process(value): - return "BIND_IN"+ impl_processor(value) - return process - - def result_processor(self, dialect): - impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value) - def process(value): - return impl_processor(value) + "BIND_OUT" - return process - - def copy(self): - return MyUnicodeType(self.impl.length) - -class MyPickleType(types.TypeDecorator): - impl = PickleType - - def process_bind_param(self, value, dialect): - if value: - value.stuff = 'this is modified stuff' - return value - - def process_result_value(self, value, dialect): - if value: - value.stuff = 'this is the right stuff' - return value - -class LegacyType(types.TypeEngine): - def get_col_spec(self): - return "VARCHAR(100)" - def convert_bind_param(self, value, dialect): - return "BIND_IN"+ value - def convert_result_value(self, value, dialect): - return value + "BIND_OUT" - def adapt(self, typeobj): - return typeobj() - -class LegacyUnicodeType(types.TypeDecorator): - impl = Unicode - - def convert_bind_param(self, value, dialect): - return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect) - - def convert_result_value(self, value, dialect): - return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT" - - def copy(self): - return LegacyUnicodeType(self.impl.length) class AdaptTest(PersistTest): def testadapt(self): @@ -149,6 +43,11 @@ class AdaptTest(PersistTest): def testoracletext(self): dialect = oracle.OracleDialect() + class MyDecoratedType(types.TypeDecorator): + impl = String + def copy(self): + return MyDecoratedType() + col = Column('', MyDecoratedType) dialect_type = col.type.dialect_impl(dialect) assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl) @@ -215,25 +114,129 @@ class UserDefinedTest(PersistTest): def testprocessing(self): global users - users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12) - users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15) - users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9) + users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12) + users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15) + users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9) l = users.select().execute().fetchall() - for assertstr, assertint, row in zip( + for assertstr, assertint, assertint2, row in zip( ["BIND_INjackBIND_OUT", "BIND_INlalaBIND_OUT", "BIND_INfredBIND_OUT"], [1200, 1500, 900], + [1800, 2250, 1350], l ): for col in row[1:8]: self.assertEquals(col, assertstr) self.assertEquals(row[8], assertint) + self.assertEquals(row[9], assertint2) for col in (row[4], row[5], row[7]): assert isinstance(col, unicode) def setUpAll(self): global users, metadata + + class MyType(types.TypeEngine): + def get_col_spec(self): + return "VARCHAR(100)" + def bind_processor(self, dialect): + def process(value): + return "BIND_IN"+ value + return process + def result_processor(self, dialect): + def process(value): + return value + "BIND_OUT" + return process + def adapt(self, typeobj): + return typeobj() + + class MyDecoratedType(types.TypeDecorator): + impl = String + def bind_processor(self, dialect): + impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value) + def process(value): + return "BIND_IN"+ impl_processor(value) + return process + def result_processor(self, dialect): + impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value) + def process(value): + return impl_processor(value) + "BIND_OUT" + return process + def copy(self): + return MyDecoratedType() + + class MyNewUnicodeType(types.TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + return "BIND_IN" + value + + def process_result_value(self, value, dialect): + return value + "BIND_OUT" + + def copy(self): + return MyNewUnicodeType(self.impl.length) + + class MyNewIntType(types.TypeDecorator): + impl = Integer + + def process_bind_param(self, value, dialect): + return value * 10 + + def process_result_value(self, value, dialect): + return value * 10 + + def copy(self): + return MyNewIntType() + + class MyNewIntSubClass(MyNewIntType): + def process_result_value(self, value, dialect): + return value * 15 + + def copy(self): + return MyNewIntSubClass() + + class MyUnicodeType(types.TypeDecorator): + impl = Unicode + + def bind_processor(self, dialect): + impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value) + + def process(value): + return "BIND_IN"+ impl_processor(value) + return process + + def result_processor(self, dialect): + impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value) + def process(value): + return impl_processor(value) + "BIND_OUT" + return process + + def copy(self): + return MyUnicodeType(self.impl.length) + + class LegacyType(types.TypeEngine): + def get_col_spec(self): + return "VARCHAR(100)" + def convert_bind_param(self, value, dialect): + return "BIND_IN"+ value + def convert_result_value(self, value, dialect): + return value + "BIND_OUT" + def adapt(self, typeobj): + return typeobj() + + class LegacyUnicodeType(types.TypeDecorator): + impl = Unicode + + def convert_bind_param(self, value, dialect): + return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect) + + def convert_result_value(self, value, dialect): + return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT" + + def copy(self): + return LegacyUnicodeType(self.impl.length) + metadata = MetaData(testbase.db) users = Table('type_users', metadata, Column('user_id', Integer, primary_key = True), @@ -251,6 +254,7 @@ class UserDefinedTest(PersistTest): Column('goofy6', LegacyType, nullable = False), Column('goofy7', MyNewUnicodeType, nullable = False), Column('goofy8', MyNewIntType, nullable = False), + Column('goofy9', MyNewIntSubClass, nullable = False), ) @@ -396,7 +400,21 @@ class UnicodeTest(AssertMixin): class BinaryTest(AssertMixin): def setUpAll(self): - global binary_table + global binary_table, MyPickleType + + class MyPickleType(types.TypeDecorator): + impl = PickleType + + def process_bind_param(self, value, dialect): + if value: + value.stuff = 'this is modified stuff' + return value + + def process_result_value(self, value, dialect): + if value: + value.stuff = 'this is the right stuff' + return value + binary_table = Table('binary_table', MetaData(testbase.db), Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), Column('data', Binary), |
