diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-03-02 23:45:35 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-03-02 23:45:35 +0000 |
commit | b5050beb73b2e50b122c36e7dcdc06abffd472f2 (patch) | |
tree | 6679019ff418d6c346d5bd4cdc4aab4a73d9303e | |
parent | 2d052d43518a0f4d9751db7e699cfebd3724c1e5 (diff) | |
parent | 57dc36a01b2b334a996f73f6a78b3bfbe4d9f2ec (diff) | |
download | sqlalchemy-b5050beb73b2e50b122c36e7dcdc06abffd472f2.tar.gz |
Merge "Ensure all nested exception throws have a cause"
50 files changed, 768 insertions, 435 deletions
diff --git a/doc/build/changelog/unreleased_13/4849.rst b/doc/build/changelog/unreleased_13/4849.rst new file mode 100644 index 000000000..5a649dc33 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4849.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, general, py3k + :tickets: 4849 + + Applied an explicit "cause" to most if not all internally raised exceptions + that are raised from within an internal exception catch, to avoid + misleading stacktraces that suggest an error within the handling of an + exception. While it would be preferable to suppress the internally caught + exception in the way that the ``__suppress_context__`` attribute would, + there does not as yet seem to be a way to do this without suppressing an + enclosing user constructed context, so for now it exposes the internally + caught exception as the cause so that full information about the context + of the error is maintained. diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c index 3c44010b8..f622e6a28 100644 --- a/lib/sqlalchemy/cextension/resultproxy.c +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -288,7 +288,7 @@ BaseRow_getitem_by_object(BaseRow *self, PyObject *key, int asmapping) if (record == NULL) { record = PyObject_CallMethod(self->parent, "_key_fallback", - "O", key); + "OO", key, Py_None); if (record == NULL) return NULL; key_fallback = 1; // boolean to indicate record is a new reference diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e0bf16793..6ea8cbcb8 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2968,7 +2968,7 @@ class MySQLDialect(default.DefaultDialect): ).execute(st) except exc.DBAPIError as e: if self._extract_error_code(e.orig) == 1146: - raise exc.NoSuchTableError(full_name) + util.raise_(exc.NoSuchTableError(full_name), replace_context=e) else: raise row = self._compat_first(rp, charset=charset) @@ -2992,11 +2992,16 @@ class MySQLDialect(default.DefaultDialect): except exc.DBAPIError as e: code = self._extract_error_code(e.orig) if code == 1146: - raise exc.NoSuchTableError(full_name) + util.raise_( + exc.NoSuchTableError(full_name), replace_context=e + ) elif code == 1356: - raise exc.UnreflectableTableError( - "Table or view named %s could not be " - "reflected: %s" % (full_name, e) + util.raise_( + exc.UnreflectableTableError( + "Table or view named %s could not be " + "reflected: %s" % (full_name, e) + ), + replace_context=e, ) else: raise diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 0b6afc337..1b1c9b0ba 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -763,11 +763,14 @@ class PGDialect_psycopg2(PGDialect): def set_isolation_level(self, connection, level): try: level = self._isolation_lookup[level.replace("_", " ")] - except KeyError: - raise exc.ArgumentError( - "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" - % (level, self.name, ", ".join(self._isolation_lookup)) + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ), + replace_context=err, ) connection.set_isolation_level(level) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index d04b543cd..b1a83bf92 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -997,9 +997,12 @@ class SQLiteCompiler(compiler.SQLCompiler): self.extract_map[extract.field], self.process(extract.expr, **kw), ) - except KeyError: - raise exc.CompileError( - "%s is not a valid extract argument." % extract.field + except KeyError as err: + util.raise_( + exc.CompileError( + "%s is not a valid extract argument." % extract.field + ), + replace_context=err, ) def limit_clause(self, select, **kw): @@ -1537,11 +1540,14 @@ class SQLiteDialect(default.DefaultDialect): def set_isolation_level(self, connection, level): try: isolation_level = self._isolation_lookup[level.replace("_", " ")] - except KeyError: - raise exc.ArgumentError( - "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" - % (level, self.name, ", ".join(self._isolation_lookup)) + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ), + replace_context=err, ) cursor = connection.cursor() cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ce6c2e9c6..449f386ce 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -996,8 +996,10 @@ class Connection(Connectable): return self._execute_text(object_, multiparams, params) try: meth = object_._execute_on_connection - except AttributeError: - raise exc.ObjectNotExecutableError(object_) + except AttributeError as err: + util.raise_( + exc.ObjectNotExecutableError(object_), replace_context=err + ) else: return meth(self, multiparams, params) @@ -1400,7 +1402,7 @@ class Connection(Connectable): invalidate_pool_on_disconnect = not is_exit_exception if self._reentrant_error: - util.raise_from_cause( + util.raise_( exc.DBAPIError.instance( statement, parameters, @@ -1412,7 +1414,8 @@ class Connection(Connectable): if context is not None else None, ), - exc_info, + with_traceback=exc_info[2], + from_=e, ) self._reentrant_error = True try: @@ -1502,11 +1505,13 @@ class Connection(Connectable): self._autorollback() if newraise: - util.raise_from_cause(newraise, exc_info) + util.raise_(newraise, with_traceback=exc_info[2], from_=e) elif should_wrap: - util.raise_from_cause(sqlalchemy_exception, exc_info) + util.raise_( + sqlalchemy_exception, with_traceback=exc_info[2], from_=e + ) else: - util.reraise(*exc_info) + util.raise_(exc_info[1], with_traceback=exc_info[2]) finally: del self._reentrant_error @@ -1573,11 +1578,13 @@ class Connection(Connectable): ) = ctx.is_disconnect if newraise: - util.raise_from_cause(newraise, exc_info) + util.raise_(newraise, with_traceback=exc_info[2], from_=e) elif should_wrap: - util.raise_from_cause(sqlalchemy_exception, exc_info) + util.raise_( + sqlalchemy_exception, with_traceback=exc_info[2], from_=e + ) else: - util.reraise(*exc_info) + util.raise_(exc_info[1], with_traceback=exc_info[2]) def _run_ddl_visitor(self, visitorcallable, element, **kwargs): """run a DDL visitor. @@ -2329,7 +2336,9 @@ class Engine(Connectable, log.Identified): e, dialect, self ) else: - util.reraise(*sys.exc_info()) + util.raise_( + sys.exc_info()[1], with_traceback=sys.exc_info()[2] + ) def raw_connection(self, _connection=None): """Return a "raw" DBAPI connection from the connection pool. diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 1a63c307b..7db9eecae 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -53,11 +53,11 @@ class ResultMetaData(object): def _has_key(self, key): return key in self._keymap - def _key_fallback(self, key): + def _key_fallback(self, key, err): if isinstance(key, int): - raise IndexError(key) + util.raise_(IndexError(key), replace_context=err) else: - raise KeyError(key) + util.raise_(KeyError(key), replace_context=err) class SimpleResultMetaData(ResultMetaData): @@ -546,11 +546,14 @@ class CursorResultMetaData(ResultMetaData): ) in self._colnames_from_description(context, cursor_description): yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated - def _key_fallback(self, key, raiseerr=True): + def _key_fallback(self, key, err, raiseerr=True): if raiseerr: - raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" - % util.string_or_unprintable(key) + util.raise_( + exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" + % util.string_or_unprintable(key) + ), + replace_context=err, ) else: return None @@ -570,8 +573,8 @@ class CursorResultMetaData(ResultMetaData): def _getter(self, key, raiseerr=True): try: rec = self._keymap[key] - except KeyError: - rec = self._key_fallback(key, raiseerr) + except KeyError as ke: + rec = self._key_fallback(key, ke, raiseerr) if rec is None: return None @@ -598,8 +601,8 @@ class CursorResultMetaData(ResultMetaData): for key in keys: try: rec = self._keymap[key] - except KeyError: - rec = self._key_fallback(key, raiseerr) + except KeyError as ke: + rec = self._key_fallback(key, ke, raiseerr) if rec is None: return None @@ -656,9 +659,9 @@ class LegacyCursorResultMetaData(CursorResultMetaData): ) return True else: - return self._key_fallback(key, False) is not None + return self._key_fallback(key, None, False) is not None - def _key_fallback(self, key, raiseerr=True): + def _key_fallback(self, key, err, raiseerr=True): map_ = self._keymap result = None @@ -714,9 +717,12 @@ class LegacyCursorResultMetaData(CursorResultMetaData): ) if result is None: if raiseerr: - raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" - % util.string_or_unprintable(key) + util.raise_( + exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" + % util.string_or_unprintable(key) + ), + replace_context=err, ) else: return None @@ -736,7 +742,7 @@ class LegacyCursorResultMetaData(CursorResultMetaData): if key in self._keymap: return True else: - return self._key_fallback(key, False) is not None + return self._key_fallback(key, None, False) is not None class CursorFetchStrategy(object): @@ -807,9 +813,12 @@ class NoCursorDQLFetchStrategy(CursorFetchStrategy): def fetchall(self): return self._non_result([]) - def _non_result(self, default): + def _non_result(self, default, err=None): if self.closed: - raise exc.ResourceClosedError("This result object is closed.") + util.raise_( + exc.ResourceClosedError("This result object is closed."), + replace_context=err, + ) else: return default @@ -843,10 +852,13 @@ class NoCursorDMLFetchStrategy(CursorFetchStrategy): def fetchall(self): return self._non_result([]) - def _non_result(self, default): - raise exc.ResourceClosedError( - "This result object does not return rows. " - "It has been closed automatically." + def _non_result(self, default, err=None): + util.raise_( + exc.ResourceClosedError( + "This result object does not return rows. " + "It has been closed automatically." + ), + replace_context=err, ) @@ -1123,24 +1135,24 @@ class BaseResult(object): def _getter(self, key, raiseerr=True): try: getter = self._metadata._getter - except AttributeError: - return self.cursor_strategy._non_result(None) + except AttributeError as err: + return self.cursor_strategy._non_result(None, err) else: return getter(key, raiseerr) def _tuple_getter(self, key, raiseerr=True): try: getter = self._metadata._tuple_getter - except AttributeError: - return self.cursor_strategy._non_result(None) + except AttributeError as err: + return self.cursor_strategy._non_result(None, err) else: return getter(key, raiseerr) def _has_key(self, key): try: has_key = self._metadata._has_key - except AttributeError: - return self.cursor_strategy._non_result(None) + except AttributeError as err: + return self.cursor_strategy._non_result(None, err) else: return has_key(key) diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 55d8c2249..b58b350e2 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -84,8 +84,8 @@ except ImportError: def _subscript_impl(self, key, ismapping): try: rec = self._keymap[key] - except KeyError: - rec = self._parent._key_fallback(key) + except KeyError as ke: + rec = self._parent._key_fallback(key, ke) except TypeError: # the non-C version detects a slice using TypeError. # this is pretty inefficient for the slice use case @@ -119,7 +119,7 @@ except ImportError: try: return self._get_by_key_impl_mapping(name) except KeyError as e: - raise AttributeError(e.args[0]) + util.raise_(AttributeError(e.args[0]), replace_context=e) class Row(BaseRow, collections_abc.Sequence): diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 41346fc4e..f00b642db 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -244,6 +244,10 @@ class AssociationProxy(interfaces.InspectionAttrInfo): try: inst = class_.__dict__[self.key + "_inst"] except KeyError: + inst = None + + # avoid exception context + if inst is None: owner = self._calc_owner(class_) if owner is not None: inst = AssociationProxyInstance.for_proxy(self, owner, obj) @@ -358,9 +362,12 @@ class AssociationProxyInstance(object): # this was never asserted before but this should be made clear. if not isinstance(prop, orm.RelationshipProperty): - raise NotImplementedError( - "association proxy to a non-relationship " - "intermediary is not supported" + util.raise_( + NotImplementedError( + "association proxy to a non-relationship " + "intermediary is not supported" + ), + replace_context=None, ) target_class = prop.mapper.class_ @@ -1323,10 +1330,13 @@ class _AssociationDict(_AssociationCollection): try: for k, v in seq_or_map: self[k] = v - except ValueError: - raise ValueError( - "dictionary update sequence " - "requires 2-element tuples" + except ValueError as err: + util.raise_( + ValueError( + "dictionary update sequence " + "requires 2-element tuples" + ), + replace_context=err, ) for key, value in kw: diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index cafe69093..cf67387e4 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -504,9 +504,12 @@ class Result(object): """ try: ret = self.one_or_none() - except orm_exc.MultipleResultsFound: - raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()" + except orm_exc.MultipleResultsFound as err: + util.raise_( + orm_exc.MultipleResultsFound( + "Multiple rows were found for one()" + ), + replace_context=err, ) else: if ret is None: diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index c27907cdc..b8b6f8dc0 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -398,6 +398,7 @@ Example usage:: """ from .. import exc +from .. import util from ..sql import sqltypes from ..sql import visitors @@ -422,10 +423,13 @@ def compiles(class_, *specs): def _wrap_existing_dispatch(element, compiler, **kw): try: return existing_dispatch(element, compiler, **kw) - except exc.UnsupportedCompilationError: - raise exc.CompileError( - "%s construct has no default " - "compilation handler." % type(element) + except exc.UnsupportedCompilationError as uce: + util.raise_( + exc.CompileError( + "%s construct has no default " + "compilation handler." % type(element) + ), + from_=uce, ) existing.specs["default"] = _wrap_existing_dispatch @@ -470,10 +474,13 @@ class _dispatcher(object): if not fn: try: fn = self.specs["default"] - except KeyError: - raise exc.CompileError( - "%s construct has no default " - "compilation handler." % type(element) + except KeyError as ke: + util.raise_( + exc.CompileError( + "%s construct has no default " + "compilation handler." % type(element) + ), + replace_context=ke, ) # if compilation includes add_to_result_map, collect add_to_result_map diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 7ff30b807..93e643cf5 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -298,12 +298,15 @@ class _class_resolver(object): else: return x except NameError as n: - raise exc.InvalidRequestError( - "When initializing mapper %s, expression %r failed to " - "locate a name (%r). If this is a class name, consider " - "adding this relationship() to the %r class after " - "both dependent classes have been defined." - % (self.prop.parent, self.arg, n.args[0], self.cls) + util.raise_( + exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." + % (self.prop.parent, self.arg, n.args[0], self.cls) + ), + from_=n, ) diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index f2e0501bb..6eb7e1185 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -223,7 +223,8 @@ The above query will render:: """ # noqa from __future__ import absolute_import -from sqlalchemy import inspect +from .. import inspect +from .. import util from ..ext.hybrid import hybrid_property from ..orm.attributes import flag_modified @@ -301,9 +302,9 @@ class index_property(hybrid_property): # noqa self.datatype = dict self.onebased = onebased - def _fget_default(self): + def _fget_default(self, err=None): if self.default == self._NO_DEFAULT_ARGUMENT: - raise AttributeError(self.attr_name) + util.raise_(AttributeError(self.attr_name), replace_context=err) else: return self.default @@ -314,8 +315,8 @@ class index_property(hybrid_property): # noqa return self._fget_default() try: value = column_value[self.index] - except (KeyError, IndexError): - return self._fget_default() + except (KeyError, IndexError) as err: + return self._fget_default(err) else: return value @@ -337,8 +338,8 @@ class index_property(hybrid_property): # noqa raise AttributeError(self.attr_name) try: del column_value[self.index] - except KeyError: - raise AttributeError(self.attr_name) + except KeyError as err: + util.raise_(AttributeError(self.attr_name), replace_context=err) else: setattr(instance, attr_name, column_value) flag_modified(instance, attr_name) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 66a18da99..a959b0a40 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -231,16 +231,19 @@ class QueryableAttribute( def __getattr__(self, key): try: return getattr(self.comparator, key) - except AttributeError: - raise AttributeError( - "Neither %r object nor %r object associated with %s " - "has an attribute %r" - % ( - type(self).__name__, - type(self.comparator).__name__, - self, - key, - ) + except AttributeError as err: + util.raise_( + AttributeError( + "Neither %r object nor %r object associated with %s " + "has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + self, + key, + ) + ), + replace_context=err, ) def __str__(self): @@ -373,31 +376,39 @@ def create_proxied_attribute(descriptor): comparator.""" try: return getattr(descriptor, attribute) - except AttributeError: + except AttributeError as err: if attribute == "comparator": - raise AttributeError("comparator") + util.raise_( + AttributeError("comparator"), replace_context=err + ) try: # comparator itself might be unreachable comparator = self.comparator - except AttributeError: - raise AttributeError( - "Neither %r object nor unconfigured comparator " - "object associated with %s has an attribute %r" - % (type(descriptor).__name__, self, attribute) + except AttributeError as err2: + util.raise_( + AttributeError( + "Neither %r object nor unconfigured comparator " + "object associated with %s has an attribute %r" + % (type(descriptor).__name__, self, attribute) + ), + replace_context=err2, ) else: try: return getattr(comparator, attribute) - except AttributeError: - raise AttributeError( - "Neither %r object nor %r object " - "associated with %s has an attribute %r" - % ( - type(descriptor).__name__, - type(comparator).__name__, - self, - attribute, - ) + except AttributeError as err3: + util.raise_( + AttributeError( + "Neither %r object nor %r object " + "associated with %s has an attribute %r" + % ( + type(descriptor).__name__, + type(comparator).__name__, + self, + attribute, + ) + ), + replace_context=err3, ) Proxy.__name__ = type(descriptor).__name__ + "Proxy" @@ -713,12 +724,15 @@ class AttributeImpl(object): elif value is ATTR_WAS_SET: try: return dict_[key] - except KeyError: + except KeyError as err: # TODO: no test coverage here. - raise KeyError( - "Deferred loader for attribute " - "%r failed to populate " - "correctly" % key + util.raise_( + KeyError( + "Deferred loader for attribute " + "%r failed to populate " + "correctly" % key + ), + replace_context=err, ) elif value is not ATTR_EMPTY: return self.set_committed_value(state, dict_, value) diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 571107a38..a31745aec 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -387,9 +387,12 @@ def _entity_descriptor(entity, key): try: return getattr(entity, key) - except AttributeError: - raise sa_exc.InvalidRequestError( - "Entity '%s' has no property '%s'" % (description, key) + except AttributeError as err: + util.raise_( + sa_exc.InvalidRequestError( + "Entity '%s' has no property '%s'" % (description, key) + ), + replace_context=err, ) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index f75c7d3ba..57c192a5d 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -557,12 +557,15 @@ class StrategizedProperty(MapperProperty): try: return self._strategies[key] except KeyError: - cls = self._strategy_lookup(self, *key) - # this previously was setting self._strategies[cls], that's - # a bad idea; should use strategy key at all times because every - # strategy has multiple keys at this point - self._strategies[key] = strategy = cls(self, key) - return strategy + pass + + # run outside to prevent transfer of exception context + cls = self._strategy_lookup(self, *key) + # this previously was setting self._strategies[cls], that's + # a bad idea; should use strategy key at all times because every + # strategy has multiple keys at this point + self._strategies[key] = strategy = cls(self, key) + return strategy def setup(self, context, query_entity, path, adapter, **kwargs): loader = self._get_context_loader(context, path) diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 193980e6c..d943ebb19 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -99,9 +99,9 @@ def instances(query, cursor, context): if not query._yield_per: break - except Exception as err: - cursor.close() - util.raise_from_cause(err) + except Exception: + with util.safe_reraise(): + cursor.close() @util.dependencies("sqlalchemy.orm.query") diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 0d87a9c40..91e3251e2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1483,11 +1483,14 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): # it to mapped ColumnProperty try: self.polymorphic_on = self._props[self.polymorphic_on] - except KeyError: - raise sa_exc.ArgumentError( - "Can't determine polymorphic_on " - "value '%s' - no attribute is " - "mapped to this name." % self.polymorphic_on + except KeyError as err: + util.raise_( + sa_exc.ArgumentError( + "Can't determine polymorphic_on " + "value '%s' - no attribute is " + "mapped to this name." % self.polymorphic_on + ), + replace_context=err, ) if self.polymorphic_on in self._columntoproperty: @@ -1987,9 +1990,12 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): try: return self._props[key] - except KeyError: - raise sa_exc.InvalidRequestError( - "Mapper '%s' has no property '%s'" % (self, key) + except KeyError as err: + util.raise_( + sa_exc.InvalidRequestError( + "Mapper '%s' has no property '%s'" % (self, key) + ), + replace_context=err, ) def get_property_by_column(self, column): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 3b274a389..46c84d4bd 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1635,9 +1635,12 @@ def _sort_states(mapper, states): persistent, key=mapper._persistent_sortkey_fn ) except TypeError as err: - raise sa_exc.InvalidRequestError( - "Could not sort objects by primary key; primary key " - "values must be sortable in Python (was: %s)" % err + util.raise_( + sa_exc.InvalidRequestError( + "Could not sort objects by primary key; primary key " + "values must be sortable in Python (was: %s)" % err + ), + replace_context=err, ) return ( sorted(pending, key=operator.attrgetter("insert_order")) @@ -1681,10 +1684,13 @@ class BulkUD(object): def _factory(cls, lookup, synchronize_session, *arg): try: klass = lookup[synchronize_session] - except KeyError: - raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) for x in lookup))) + except KeyError as err: + util.raise_( + sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are %s" % (", ".join(sorted(repr(x) for x in lookup))) + ), + replace_context=err, ) else: return klass(*arg) @@ -1768,10 +1774,13 @@ class BulkEvaluate(BulkUD): self._additional_evaluators(evaluator_compiler) except evaluator.UnevaluatableError as err: - raise sa_exc.InvalidRequestError( - 'Could not evaluate current criteria in Python: "%s". ' - "Specify 'fetch' or False for the " - "synchronize_session parameter." % err + util.raise_( + sa_exc.InvalidRequestError( + 'Could not evaluate current criteria in Python: "%s". ' + "Specify 'fetch' or False for the " + "synchronize_session parameter." % err + ), + from_=err, ) # TODO: detect when the where clause is a trivial primary key match diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d237aa3bf..e29e6eeee 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1019,15 +1019,18 @@ class Query(Generative): for prop in mapper._identity_key_props ) - except KeyError: - raise sa_exc.InvalidRequestError( - "Incorrect names of values in identifier to formulate " - "primary key for query.get(); primary key attribute names" - " are %s" - % ",".join( - "'%s'" % prop.key - for prop in mapper._identity_key_props - ) + except KeyError as err: + util.raise_( + sa_exc.InvalidRequestError( + "Incorrect names of values in identifier to formulate " + "primary key for query.get(); primary key attribute " + "names are %s" + % ",".join( + "'%s'" % prop.key + for prop in mapper._identity_key_props + ) + ), + replace_context=err, ) if ( @@ -3292,9 +3295,12 @@ class Query(Generative): """ try: ret = self.one_or_none() - except orm_exc.MultipleResultsFound: - raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()" + except orm_exc.MultipleResultsFound as err: + util.raise_( + orm_exc.MultipleResultsFound( + "Multiple rows were found for one()" + ), + replace_context=err, ) else: if ret is None: diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index b82a3d271..2995baf5f 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2484,50 +2484,64 @@ class JoinCondition(object): a_subset=self.parent_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) - except sa_exc.NoForeignKeysError: + except sa_exc.NoForeignKeysError as nfe: if self.secondary is not None: - raise sa_exc.NoForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are no foreign keys " - "linking these tables via secondary table '%s'. " - "Ensure that referencing columns are associated " - "with a ForeignKey or ForeignKeyConstraint, or " - "specify 'primaryjoin' and 'secondaryjoin' " - "expressions." % (self.prop, self.secondary) + util.raise_( + sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables via secondary table '%s'. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify 'primaryjoin' and 'secondaryjoin' " + "expressions." % (self.prop, self.secondary) + ), + from_=nfe, ) else: - raise sa_exc.NoForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are no foreign keys " - "linking these tables. " - "Ensure that referencing columns are associated " - "with a ForeignKey or ForeignKeyConstraint, or " - "specify a 'primaryjoin' expression." % self.prop + util.raise_( + sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify a 'primaryjoin' expression." % self.prop + ), + from_=nfe, ) - except sa_exc.AmbiguousForeignKeysError: + except sa_exc.AmbiguousForeignKeysError as afe: if self.secondary is not None: - raise sa_exc.AmbiguousForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are multiple foreign key " - "paths linking the tables via secondary table '%s'. " - "Specify the 'foreign_keys' " - "argument, providing a list of those columns which " - "should be counted as containing a foreign key " - "reference from the secondary table to each of the " - "parent and child tables." % (self.prop, self.secondary) + util.raise_( + sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables via secondary table '%s'. " + "Specify the 'foreign_keys' " + "argument, providing a list of those columns which " + "should be counted as containing a foreign key " + "reference from the secondary table to each of the " + "parent and child tables." + % (self.prop, self.secondary) + ), + from_=afe, ) else: - raise sa_exc.AmbiguousForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are multiple foreign key " - "paths linking the tables. Specify the " - "'foreign_keys' argument, providing a list of those " - "columns which should be counted as containing a " - "foreign key reference to the parent table." % self.prop + util.raise_( + sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables. Specify the " + "'foreign_keys' argument, providing a list of those " + "columns which should be counted as containing a " + "foreign key reference to the parent table." + % self.prop + ), + from_=afe, ) @property diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 095033951..74e546483 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -575,7 +575,7 @@ class SessionTransaction(object): self._parent._rollback_exception = sys.exc_info()[1] if rollback_err: - util.reraise(*rollback_err) + util.raise_(rollback_err[1], with_traceback=rollback_err[2]) sess.dispatch.after_soft_rollback(sess, self) @@ -1362,10 +1362,13 @@ class Session(_SessionClassMethods): def _add_bind(self, key, bind): try: insp = inspect(key) - except sa_exc.NoInspectionAvailable: + except sa_exc.NoInspectionAvailable as err: if not isinstance(key, type): - raise sa_exc.ArgumentError( - "Not an acceptable bind target: %s" % key + util.raise_( + sa_exc.ArgumentError( + "Not an acceptable bind target: %s" % key + ), + replace_context=err, ) else: self.__binds[key] = bind @@ -1515,9 +1518,11 @@ class Session(_SessionClassMethods): if mapper is not None: try: mapper = inspect(mapper) - except sa_exc.NoInspectionAvailable: + except sa_exc.NoInspectionAvailable as err: if isinstance(mapper, type): - raise exc.UnmappedClassError(mapper) + util.raise_( + exc.UnmappedClassError(mapper), replace_context=err, + ) else: raise @@ -1656,7 +1661,7 @@ class Session(_SessionClassMethods): "consider using a session.no_autoflush block if this " "flush is occurring prematurely" ) - util.raise_from_cause(e) + util.raise_(e, with_traceback=sys.exc_info[2]) def refresh( self, @@ -1711,8 +1716,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._expire_state(state, attribute_names) @@ -1817,8 +1824,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._expire_state(state, attribute_names) def _expire_state(self, state, attribute_names): @@ -1872,8 +1881,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) if state.session_id is not self.hash_key: raise sa_exc.InvalidRequestError( "Instance %s is not present in this Session" % state_str(state) @@ -2024,8 +2035,10 @@ class Session(_SessionClassMethods): try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._save_or_update_state(state) @@ -2059,8 +2072,10 @@ class Session(_SessionClassMethods): try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._delete_impl(state, instance, head=True) @@ -2490,8 +2505,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) return self._contains_state(state) def __iter__(self): @@ -2586,8 +2603,11 @@ class Session(_SessionClassMethods): for o in objects: try: state = attributes.instance_state(o) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(o) + + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(o), replace_context=err, + ) objset.add(state) else: objset = None @@ -3450,8 +3470,10 @@ def object_session(instance): try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) else: return _state_session(state) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 0c72f3b37..4f7d996d4 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -252,11 +252,14 @@ class Load(HasCacheKey, Generative, MapperOption): # use getattr on the class to work around # synonyms, hybrids, etc. attr = getattr(ent.class_, attr) - except AttributeError: + except AttributeError as err: if raiseerr: - raise sa_exc.ArgumentError( - 'Can\'t find property named "%s" on ' - "%s in this Query." % (attr, ent) + util.raise_( + sa_exc.ArgumentError( + 'Can\'t find property named "%s" on ' + "%s in this Query." % (attr, ent) + ), + replace_context=err, ) else: return None diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 198e64f4f..ceaf54e5d 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -13,6 +13,7 @@ between instances based on join conditions. from . import attributes from . import exc from . import util as orm_util +from .. import util def populate( @@ -34,15 +35,15 @@ def populate( value = source.manager[prop.key].impl.get( source, source_dict, attributes.PASSIVE_OFF ) - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err) try: # inline of dest_mapper._set_state_attr_by_column prop = dest_mapper._columntoproperty[r] dest.manager[prop.key].impl.set(dest, dest_dict, value, None) - except exc.UnmappedColumnError: - _raise_col_to_prop(True, source_mapper, l, dest_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err) # technically the "r.primary_key" check isn't # needed here, but we check for this condition to limit @@ -64,8 +65,8 @@ def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs): try: prop = source_mapper._columntoproperty[l] value = source_dict[prop.key] - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, source_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err) try: prop = source_mapper._columntoproperty[r] @@ -88,8 +89,8 @@ def clear(dest, dest_mapper, synchronize_pairs): ) try: dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) - except exc.UnmappedColumnError: - _raise_col_to_prop(True, None, l, dest_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, None, l, dest_mapper, r, err) def update(source, source_mapper, dest, old_prefix, synchronize_pairs): @@ -101,8 +102,8 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs): value = source_mapper._get_state_attr_by_column( source, source.dict, l, passive=attributes.PASSIVE_OFF ) - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, None, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) dest[r.key] = value dest[old_prefix + r.key] = oldvalue @@ -113,8 +114,8 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs): value = source_mapper._get_state_attr_by_column( source, source.dict, l, passive=attributes.PASSIVE_OFF ) - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, None, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) dict_[r.key] = value @@ -127,8 +128,8 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs): for l, r in synchronize_pairs: try: prop = source_mapper._columntoproperty[l] - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, None, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) history = uowcommit.get_attribute_history( source, prop.key, attributes.PASSIVE_NO_INITIALIZE ) @@ -139,22 +140,28 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs): def _raise_col_to_prop( - isdest, source_mapper, source_column, dest_mapper, dest_column + isdest, source_mapper, source_column, dest_mapper, dest_column, err ): if isdest: - raise exc.UnmappedColumnError( - "Can't execute sync rule for " - "destination column '%s'; mapper '%s' does not map " - "this column. Try using an explicit `foreign_keys` " - "collection which does not include this column (or use " - "a viewonly=True relation)." % (dest_column, dest_mapper) + util.raise_( + exc.UnmappedColumnError( + "Can't execute sync rule for " + "destination column '%s'; mapper '%s' does not map " + "this column. Try using an explicit `foreign_keys` " + "collection which does not include this column (or use " + "a viewonly=True relation)." % (dest_column, dest_mapper) + ), + replace_context=err, ) else: - raise exc.UnmappedColumnError( - "Can't execute sync rule for " - "source column '%s'; mapper '%s' does not map this " - "column. Try using an explicit `foreign_keys` " - "collection which does not include destination column " - "'%s' (or use a viewonly=True relation)." - % (source_column, source_mapper, dest_column) + util.raise_( + exc.UnmappedColumnError( + "Can't execute sync rule for " + "source column '%s'; mapper '%s' does not map this " + "column. Try using an explicit `foreign_keys` " + "collection which does not include destination column " + "'%s' (or use a viewonly=True relation)." + % (source_column, source_mapper, dest_column) + ), + replace_context=err, ) diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index b53f0d7dd..17d5ba15f 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -578,8 +578,8 @@ class _ConnectionRecord(object): self.connection = connection self.fresh = True except Exception as e: - pool.logger.debug("Error on connect(): %s", e) - raise + with util.safe_reraise(): + pool.logger.debug("Error on connect(): %s", e) else: if first_connect_check: pool.dispatch.first_connect.for_modify( diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py index 67f1564ec..8618d5e2a 100644 --- a/lib/sqlalchemy/processors.py +++ b/lib/sqlalchemy/processors.py @@ -32,10 +32,13 @@ def str_to_datetime_processor_factory(regexp, type_): else: try: m = rmatch(value) - except TypeError: - raise ValueError( - "Couldn't parse %s string '%r' " - "- value is not a string." % (type_.__name__, value) + except TypeError as err: + util.raise_( + ValueError( + "Couldn't parse %s string '%r' " + "- value is not a string." % (type_.__name__, value) + ), + from_=err, ) if m is None: raise ValueError( diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a7324c45f..2d336360f 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -128,8 +128,8 @@ class _DialectArgView(util.collections_abc.MutableMapping): def _key(self, key): try: dialect, value_key = key.split("_", 1) - except ValueError: - raise KeyError(key) + except ValueError as err: + util.raise_(KeyError(key), replace_context=err) else: return dialect, value_key @@ -138,17 +138,20 @@ class _DialectArgView(util.collections_abc.MutableMapping): try: opt = self.obj.dialect_options[dialect] - except exc.NoSuchModuleError: - raise KeyError(key) + except exc.NoSuchModuleError as err: + util.raise_(KeyError(key), replace_context=err) else: return opt[value_key] def __setitem__(self, key, value): try: dialect, value_key = self._key(key) - except KeyError: - raise exc.ArgumentError( - "Keys must be of the form <dialectname>_<argname>" + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Keys must be of the form <dialectname>_<argname>" + ), + replace_context=err, ) else: self.obj.dialect_options[dialect][value_key] = value @@ -634,17 +637,17 @@ class ColumnCollection(object): def __getitem__(self, key): try: return self._index[key] - except KeyError: + except KeyError as err: if isinstance(key, util.int_types): - raise IndexError(key) + util.raise_(IndexError(key), replace_context=err) else: raise def __getattr__(self, key): try: return self._index[key] - except KeyError: - raise AttributeError(key) + except KeyError as err: + util.raise_(AttributeError(key), replace_context=err) def __contains__(self, key): if key not in self._index: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index b3bf4e93b..fc841bb4b 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -133,7 +133,13 @@ class RoleImpl(object): self._raise_for_expected(element, argname, resolved) def _raise_for_expected( - self, element, argname=None, resolved=None, advice=None, code=None + self, + element, + argname=None, + resolved=None, + advice=None, + code=None, + err=None, ): if argname: msg = "%s expected for argument %r; got %r." % ( @@ -147,7 +153,7 @@ class RoleImpl(object): if advice: msg += " " + advice - raise exc.ArgumentError(msg, code=code) + util.raise_(exc.ArgumentError(msg, code=code), replace_context=err) class _Deannotate(object): @@ -201,16 +207,19 @@ class _ColumnCoercions(object): def _no_text_coercion( - element, argname=None, exc_cls=exc.ArgumentError, extra=None + element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None ): - raise exc_cls( - "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " - "explicitly declared as text(%(expr)r)" - % { - "expr": util.ellipses_string(element), - "argname": "for argument %s" % (argname,) if argname else "", - "extra": "%s " % extra if extra else "", - } + util.raise_( + exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ), + replace_context=err, ) @@ -290,8 +299,8 @@ class ExpressionElementImpl( return elements.BindParameter( name, element, type_, unique=True ) - except exc.ArgumentError: - self._raise_for_expected(element) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) class BinaryElementImpl( @@ -302,8 +311,8 @@ class BinaryElementImpl( ): try: return expr._bind_param(operator, element, type_=bindparam_type) - except exc.ArgumentError: - self._raise_for_expected(element) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) def _post_coercion(self, resolved, expr, **kw): if ( diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9c1f50ce1..d31cf67f8 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1074,7 +1074,7 @@ class SQLCompiler(Compiled): col = only_froms[element.element] else: col = with_cols[element.element] - except KeyError: + except KeyError as err: coercions._no_text_coercion( element.element, extra=( @@ -1082,6 +1082,7 @@ class SQLCompiler(Compiled): "GROUP BY / DISTINCT etc." ), exc_cls=exc.CompileError, + err=err, ) else: kwargs["render_label_as_label"] = col @@ -1671,8 +1672,11 @@ class SQLCompiler(Compiled): else: try: opstring = OPERATORS[operator_] - except KeyError: - raise exc.UnsupportedCompilationError(self, operator_) + except KeyError as err: + util.raise_( + exc.UnsupportedCompilationError(self, operator_), + replace_context=err, + ) else: return self._generate_generic_binary( binary, opstring, from_linter=from_linter, **kw @@ -3286,11 +3290,12 @@ class DDLCompiler(Compiled): if column.primary_key: first_pk = True except exc.CompileError as ce: - util.raise_from_cause( + util.raise_( exc.CompileError( util.u("(in table '%s', column '%s'): %s") % (table.description, column.name, ce.args[0]) - ) + ), + from_=ce, ) const = self.create_table_constraints( diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 31bcc34a4..5a2095604 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -801,7 +801,7 @@ class SchemaDropper(DDLBase): ) collection = [(t, ()) for t in unsorted_tables] else: - util.raise_from_cause( + util.raise_( exc.CircularDependencyError( err2.args[0], err2.cycles, @@ -818,7 +818,8 @@ class SchemaDropper(DDLBase): sorted([t.fullname for t in err2.cycles]) ) ), - ) + ), + from_=err2, ) seq_coll = [ diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index df690c383..d0babb1be 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -747,10 +747,13 @@ class ColumnElement( def comparator(self): try: comparator_factory = self.type.comparator_factory - except AttributeError: - raise TypeError( - "Object %r associated with '.type' attribute " - "is not a TypeEngine class or object" % self.type + except AttributeError as err: + util.raise_( + TypeError( + "Object %r associated with '.type' attribute " + "is not a TypeEngine class or object" % self.type + ), + replace_context=err, ) else: return comparator_factory(self) @@ -758,10 +761,17 @@ class ColumnElement( def __getattr__(self, key): try: return getattr(self.comparator, key) - except AttributeError: - raise AttributeError( - "Neither %r object nor %r object has an attribute %r" - % (type(self).__name__, type(self.comparator).__name__, key) + except AttributeError as err: + util.raise_( + AttributeError( + "Neither %r object nor %r object has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + key, + ) + ), + replace_context=err, ) def operate(self, op, *other, **kwargs): @@ -1742,10 +1752,13 @@ class TextClause( # a unique/anonymous key in any case, so use the _orig_key # so that a text() construct can support unique parameters existing = new_params[bind._orig_key] - except KeyError: - raise exc.ArgumentError( - "This text() construct doesn't define a " - "bound parameter named %r" % bind._orig_key + except KeyError as err: + util.raise_( + exc.ArgumentError( + "This text() construct doesn't define a " + "bound parameter named %r" % bind._orig_key + ), + replace_context=err, ) else: new_params[existing._orig_key] = bind @@ -1753,10 +1766,13 @@ class TextClause( for key, value in names_to_values.items(): try: existing = new_params[key] - except KeyError: - raise exc.ArgumentError( - "This text() construct doesn't define a " - "bound parameter named %r" % key + except KeyError as err: + util.raise_( + exc.ArgumentError( + "This text() construct doesn't define a " + "bound parameter named %r" % key + ), + replace_context=err, ) else: new_params[key] = existing._with_value(value) @@ -3665,9 +3681,12 @@ class Over(ColumnElement): else: try: lower = int(range_[0]) - except ValueError: - raise exc.ArgumentError( - "Integer or None expected for range value" + except ValueError as err: + util.raise_( + exc.ArgumentError( + "Integer or None expected for range value" + ), + replace_context=err, ) else: if lower == 0: @@ -3678,9 +3697,12 @@ class Over(ColumnElement): else: try: upper = int(range_[1]) - except ValueError: - raise exc.ArgumentError( - "Integer or None expected for range value" + except ValueError as err: + util.raise_( + exc.ArgumentError( + "Integer or None expected for range value" + ), + replace_context=err, ) else: if upper == 0: diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index e6d3a6b05..5445a1bce 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -107,12 +107,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): if item is not None: try: spwd = item._set_parent_with_dispatch - except AttributeError: - util.raise_from_cause( + except AttributeError as err: + util.raise_( exc.ArgumentError( "'SchemaItem' object, such as a 'Column' or a " "'Constraint' expected, got %r" % item - ) + ), + replace_context=err, ) else: spwd(self) @@ -1569,15 +1570,16 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): _proxies=[self], *fk ) - except TypeError: - util.raise_from_cause( + except TypeError as err: + util.raise_( TypeError( "Could not create a copy of this %r object. " "Ensure the class includes a _constructor() " "attribute or method which accepts the " "standard Column constructor arguments, or " "references the Column class itself." % self.__class__ - ) + ), + from_=err, ) c.table = selectable @@ -3187,10 +3189,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): try: ColumnCollectionConstraint._set_parent(self, table) except KeyError as ke: - raise exc.ArgumentError( - "Can't create ForeignKeyConstraint " - "on table '%s': no column " - "named '%s' is present." % (table.description, ke.args[0]) + util.raise_( + exc.ArgumentError( + "Can't create ForeignKeyConstraint " + "on table '%s': no column " + "named '%s' is present." % (table.description, ke.args[0]) + ), + from_=ke, ) for col, fk in zip(self.columns, self.elements): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index b8d88e160..b972c13be 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2620,10 +2620,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): return None try: value = clause._limit_offset_value - except AttributeError: - raise exc.CompileError( - "This SELECT structure does not use a simple " - "integer value for %s" % attrname + except AttributeError as err: + util.raise_( + exc.CompileError( + "This SELECT structure does not use a simple " + "integer value for %s" % attrname + ), + replace_context=err, ) else: return util.asint(value) @@ -3489,10 +3492,13 @@ class Select( try: cols_present = bool(columns) - except TypeError: - raise exc.ArgumentError( - "columns argument to select() must " - "be a Python list or other iterable" + except TypeError as err: + util.raise_( + exc.ArgumentError( + "columns argument to select() must " + "be a Python list or other iterable" + ), + from_=err, ) if cols_present: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 22c80cc91..e4a029a3e 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1462,7 +1462,7 @@ class Enum(Emulated, String, SchemaType): def _db_value_for_elem(self, elem): try: return self._valid_lookup[elem] - except KeyError: + except KeyError as err: # for unknown string values, we return as is. While we can # validate these if we wanted, that does not allow for lesser-used # end-user use cases, such as using a LIKE comparison with an enum, @@ -1476,8 +1476,11 @@ class Enum(Emulated, String, SchemaType): ): return elem else: - raise LookupError( - '"%s" is not among the defined enum values' % elem + util.raise_( + LookupError( + '"%s" is not among the defined enum values' % elem + ), + replace_context=err, ) class Comparator(String.Comparator): @@ -1496,9 +1499,12 @@ class Enum(Emulated, String, SchemaType): def _object_value_for_elem(self, elem): try: return self._object_lookup[elem] - except KeyError: - raise LookupError( - '"%s" is not among the defined enum values' % elem + except KeyError as err: + util.raise_( + LookupError( + '"%s" is not among the defined enum values' % elem + ), + replace_context=err, ) def __repr__(self): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index c6c860844..739f96195 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -479,9 +479,12 @@ class TypeEngine(Traversible): try: return dialect._type_memos[self]["literal"] except KeyError: - d = self._dialect_info(dialect) - d["literal"] = lp = d["impl"].literal_processor(dialect) - return lp + pass + # avoid KeyError context coming into literal_processor() function + # raises + d = self._dialect_info(dialect) + d["literal"] = lp = d["impl"].literal_processor(dialect) + return lp def _cached_bind_processor(self, dialect): """Return a dialect-specific bind processor for this type.""" @@ -489,9 +492,12 @@ class TypeEngine(Traversible): try: return dialect._type_memos[self]["bind"] except KeyError: - d = self._dialect_info(dialect) - d["bind"] = bp = d["impl"].bind_processor(dialect) - return bp + pass + # avoid KeyError context coming into bind_processor() function + # raises + d = self._dialect_info(dialect) + d["bind"] = bp = d["impl"].bind_processor(dialect) + return bp def _cached_result_processor(self, dialect, coltype): """Return a dialect-specific result processor for this type.""" @@ -499,21 +505,27 @@ class TypeEngine(Traversible): try: return dialect._type_memos[self][coltype] except KeyError: - d = self._dialect_info(dialect) - # key assumption: DBAPI type codes are - # constants. Else this dictionary would - # grow unbounded. - d[coltype] = rp = d["impl"].result_processor(dialect, coltype) - return rp + pass + # avoid KeyError context coming into result_processor() function + # raises + d = self._dialect_info(dialect) + # key assumption: DBAPI type codes are + # constants. Else this dictionary would + # grow unbounded. + d[coltype] = rp = d["impl"].result_processor(dialect, coltype) + return rp def _cached_custom_processor(self, dialect, key, fn): try: return dialect._type_memos[self][key] except KeyError: - d = self._dialect_info(dialect) - impl = d["impl"] - d[key] = result = fn(impl) - return result + pass + # avoid KeyError context coming into fn() function + # raises + d = self._dialect_info(dialect) + impl = d["impl"] + d[key] = result = fn(impl) + return result def _dialect_info(self, dialect): """Return a dialect-specific registry which diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 77e6b53a8..fda48c657 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -62,9 +62,10 @@ def _generate_compiler_dispatch(cls): "def _compiler_dispatch(self, visitor, **kw):\n" " try:\n" " meth = visitor.visit_%(name)s\n" - " except AttributeError:\n" - " util.raise_from_cause(\n" - " exc.UnsupportedCompilationError(visitor, cls))\n" + " except AttributeError as err:\n" + " util.raise_(\n" + " exc.UnsupportedCompilationError(visitor, cls), \n" + " replace_context=err)\n" " else:\n" " return meth(self, **kw)\n" ) % {"name": visit_name} diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 582901579..79b7f9eb3 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -9,7 +9,9 @@ from . import config # noqa from . import mock # noqa from .assertions import assert_raises # noqa +from .assertions import assert_raises_context_ok # noqa from .assertions import assert_raises_message # noqa +from .assertions import assert_raises_message_context_ok # noqa from .assertions import assert_raises_return # noqa from .assertions import AssertsCompiledSQL # noqa from .assertions import AssertsExecutionResults # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index f5325b0cb..c97202516 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -9,6 +9,7 @@ from __future__ import absolute_import import contextlib import re +import sys import warnings from . import assertsql @@ -258,41 +259,80 @@ def eq_ignore_whitespace(a, b, msg=None): assert a == b, msg or "%r != %r" % (a, b) +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. we want + these exceptions in a cause chain. + + """ + + if not util.py3k: + return + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." + % (exception, exception.__context__) + ) + + def assert_raises(except_cls, callable_, *args, **kw): - try: - callable_(*args, **kw) - success = False - except except_cls: - success = True + _assert_raises(except_cls, callable_, args, kw, check_context=True) - # assert outside the block so it works for AssertionError too ! - assert success, "Callable did not raise an exception" + +def assert_raises_context_ok(except_cls, callable_, *args, **kw): + _assert_raises( + except_cls, callable_, args, kw, + ) def assert_raises_return(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_raises_message_context_ok( + except_cls, msg, callable_, *args, **kwargs +): + _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def _assert_raises( + except_cls, callable_, args, kwargs, msg=None, check_context=False +): ret_err = None + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] try: - callable_(*args, **kw) + callable_(*args, **kwargs) success = False except except_cls as err: - success = True ret_err = err + success = True + if msg is not None: + assert re.search( + msg, util.text_type(err), re.UNICODE + ), "%r !~ %s" % (msg, err,) + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(util.text_type(err).encode("utf-8")) # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" - return ret_err - -def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): - try: - callable_(*args, **kwargs) - assert False, "Callable did not raise an exception" - except except_cls as e: - assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( - msg, - e, - ) - print(util.text_type(e).encode("utf-8")) + return ret_err class AssertsCompiledSQL(object): diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 0c05bf9e9..1a23ebf41 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -9,6 +9,7 @@ import contextlib import operator import re +import sys from . import config from .. import util @@ -145,7 +146,7 @@ class compound(object): ) break else: - util.raise_from_cause(ex) + util.raise_(ex, with_traceback=sys.exc_info()[2]) def _expect_success(self, config, name="block"): if not self.fails: diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index b4610a1b0..819d18018 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -67,6 +67,7 @@ from .compat import py33 # noqa from .compat import py36 # noqa from .compat import py3k # noqa from .compat import quote_plus # noqa +from .compat import raise_ # noqa from .compat import raise_from_cause # noqa from .compat import reduce # noqa from .compat import reraise # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 104e8e03d..31654b97c 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -147,13 +147,42 @@ if py3k: def cmp(a, b): return (a > b) - (a < b) - def reraise(tp, value, tb=None, cause=None): - if cause is not None: - assert cause is not value, "Same cause emitted" - value.__cause__ = cause - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value + def raise_( + exception, with_traceback=None, replace_context=None, from_=False + ): + r"""implement "raise" with cause support. + + :param exception: exception to raise + :param with_traceback: will call exception.with_traceback() + :param replace_context: an as-yet-unsupported feature. This is + an exception object which we are "replacing", e.g., it's our + "cause" but we don't want it printed. Basically just what + ``__suppress_context__`` does but we don't want to suppress + the enclosing context, if any. So for now we make it the + cause. + :param from\_: the cause. this actually sets the cause and doesn't + hope to hide it someday. + + """ + if with_traceback is not None: + exception = exception.with_traceback(with_traceback) + + if from_ is not False: + exception.__cause__ = from_ + elif replace_context is not None: + # no good solution here, we would like to have the exception + # have only the context of replace_context.__context__ so that the + # intermediary exception does not change, but we can't figure + # that out. + exception.__cause__ = replace_context + + try: + raise exception + finally: + # credit to + # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/ + # as the __traceback__ object creates a cycle + del exception, replace_context, from_, with_traceback def u(s): return s @@ -257,13 +286,13 @@ else: else: return text - # not as nice as that of Py3K, but at least preserves - # the code line where the issue occurred exec( - "def reraise(tp, value, tb=None, cause=None):\n" - " if cause is not None:\n" - " assert cause is not value, 'Same cause emitted'\n" - " raise tp, value, tb\n" + "def raise_(exception, with_traceback=None, replace_context=None, " + "from_=False):\n" + " if with_traceback:\n" + " raise type(exception), exception, with_traceback\n" + " else:\n" + " raise exception\n" ) TYPE_CHECKING = False @@ -370,6 +399,8 @@ else: def raise_from_cause(exception, exc_info=None): + r"""legacy. use raise\_()""" + if exc_info is None: exc_info = sys.exc_info() exc_type, exc_value, exc_tb = exc_info @@ -377,6 +408,12 @@ def raise_from_cause(exception, exc_info=None): reraise(type(exception), exception, tb=exc_tb, cause=cause) +def reraise(tp, value, tb=None, cause=None): + r"""legacy. use raise\_()""" + + raise_(value, with_traceback=tb, from_=cause) + + def with_metaclass(meta, *bases): """Create a base class with a metaclass. diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 41a9698c7..09aa94bf2 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -65,7 +65,9 @@ class safe_reraise(object): exc_type, exc_value, exc_tb = self._exc_info self._exc_info = None # remove potential circular references if not self.warn_only: - compat.reraise(exc_type, exc_value, exc_tb) + compat.raise_( + exc_value, with_traceback=exc_tb, + ) else: if not compat.py3k and self._exc_info and self._exc_info[1]: # emulate Py3K's behavior of telling us when an exception @@ -76,7 +78,7 @@ class safe_reraise(object): "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1]) ) self._exc_info = None # remove potential circular references - compat.reraise(type_, value, traceback) + compat.raise_(value, with_traceback=traceback) def string_or_unprintable(element): diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 55890cd06..8f84acde8 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1124,6 +1124,20 @@ class CycleTest(_fixtures.FixtureTest): go() + def test_raise_from(self): + @assert_cycles() + def go(): + try: + try: + raise KeyError("foo") + except KeyError as ke: + + util.raise_(Exception("oops"), from_=ke) + except Exception as err: # noqa + pass + + go() + def test_query_alias(self): User, Address = self.classes("User", "Address") configure_mappers() diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 87908f016..73a1a8b6f 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -111,7 +111,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): "some other column", Integer ) - @profiling.function_call_count() + @profiling.function_call_count(variance=0.10) def go(): c1 in row diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 48e464a01..183e157e5 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -3,7 +3,6 @@ import copy import datetime import inspect -import sys from sqlalchemy import exc from sqlalchemy import sql @@ -2899,20 +2898,29 @@ class ReraiseTest(fixtures.TestBase): except MyException as err: is_(err.__cause__, None) - def test_reraise_disallow_same_cause(self): + def test_raise_from_cause_legacy(self): class MyException(Exception): pass + class MyOtherException(Exception): + pass + + me = MyException("exc on") + def go(): try: - raise MyException("exc one") - except Exception as err: - type_, value, tb = sys.exc_info() - util.reraise(type_, err, tb, value) + raise me + except Exception: + util.raise_from_cause(MyOtherException("exc two")) - assert_raises_message(AssertionError, "Same cause emitted", go) + try: + go() + assert False + except MyOtherException as moe: + if testing.requires.python3.enabled: + is_(moe.__cause__, me) - def test_raise_from_cause(self): + def test_raise_from(self): class MyException(Exception): pass @@ -2924,8 +2932,8 @@ class ReraiseTest(fixtures.TestBase): def go(): try: raise me - except Exception: - util.raise_from_cause(MyOtherException("exc two")) + except Exception as err: + util.raise_(MyOtherException("exc two"), from_=err) try: go() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index bad6c1603..0dd4f1301 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -709,12 +709,13 @@ class ExecuteTest(fixtures.TestBase): return super(MockCursor, self).execute(stmt, params, **kw) eng = engines.proxying_engine(cursor_cls=MockCursor) - assert_raises_message( - tsa.exc.SAWarning, - "Exception attempting to detect unicode returns", - eng.connect, - ) - assert eng.dialect.returns_unicode_strings in (True, False) + with testing.expect_warnings( + "Exception attempting to detect unicode returns" + ): + eng.connect() + + # because plain varchar passed, we don't know the correct answer + eq_(eng.dialect.returns_unicode_strings, "conditional") eng.dispose() def test_works_after_dispose(self): diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index cfe20f5ec..72e0fa186 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -10,6 +10,7 @@ from sqlalchemy import pool from sqlalchemy import select from sqlalchemy import testing from sqlalchemy.testing import assert_raises +from sqlalchemy.testing import assert_raises_context_ok from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -1256,7 +1257,7 @@ class QueuePoolTest(PoolTestBase): eq_(p.checkedout(), 0) eq_(p._overflow, 0) dbapi.shutdown(True) - assert_raises(Exception, p.connect) + assert_raises_context_ok(Exception, p.connect) eq_(p._overflow, 0) eq_(p.checkedout(), 0) # and not 1 diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 205c1fb31..000be1a70 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -14,6 +14,7 @@ from sqlalchemy import util from sqlalchemy.engine import url from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import assert_raises_message_context_ok from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings @@ -255,7 +256,7 @@ class PrePingMockTest(fixtures.TestBase): self.dbapi.shutdown("execute", stop=True) - assert_raises_message( + assert_raises_message_context_ok( MockDisconnect, "database is stopped", pool.connect ) @@ -835,7 +836,7 @@ class CursorErrTest(fixtures.TestBase): def test_cursor_shutdown_in_initialize(self): db = self._fixture(True, True) - assert_raises_message( + assert_raises_message_context_ok( exc.SAWarning, "Exception attempting to detect", db.connect ) eq_( diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 301614061..579f1aece 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -24,6 +24,7 @@ from sqlalchemy.testing import eq_regex from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ +from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import mock @@ -596,13 +597,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): testing.db.dialect.ischema_names = {} try: m2 = MetaData(testing.db) - assert_raises(sa.exc.SAWarning, Table, "test", m2, autoload=True) - @testing.emits_warning("Did not recognize type") - def warns(): - m3 = MetaData(testing.db) - t3 = Table("test", m3, autoload=True) - assert t3.c.foo.type.__class__ == sa.types.NullType + with testing.expect_warnings("Did not recognize type"): + t3 = Table("test", m2, autoload_with=testing.db) + is_(t3.c.foo.type.__class__, sa.types.NullType) finally: testing.db.dialect.ischema_names = ischema_names diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 3f4333750..8ef272a9e 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -4093,16 +4093,11 @@ class DialectKWArgTest(fixtures.TestBase): def test_unknown_dialect_warning(self): with self._fixture(): - assert_raises_message( - exc.SAWarning, + with testing.expect_warnings( "Can't validate argument 'unknown_y'; can't locate " "any SQLAlchemy dialect named 'unknown'", - Index, - "a", - "b", - "c", - unknown_y=True, - ) + ): + Index("a", "b", "c", unknown_y=True) def test_participating_bad_kw(self): with self._fixture(): |