diff options
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 47 |
1 files changed, 41 insertions, 6 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 0a211ff61..02b931f90 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -617,11 +617,14 @@ for key in _errdict.keys(): _errdict_rev[_errdict[key]] = key del key -def seterr(divide=None, over=None, under=None, invalid=None): +def seterr(all=None, divide=None, over=None, under=None, invalid=None): """Set how floating-point errors are handled. Valid values for each type of error are the strings "ignore", "warn", "raise", and "call". Returns the old settings. + If 'all' is specified, values that are not otherwise specified + will be set to 'all', otherwise they will retain their old + values. Note that operations on integer scalar types (such as int16) are handled like floating point, and are affected by these settings. @@ -630,19 +633,24 @@ def seterr(divide=None, over=None, under=None, invalid=None): >>> seterr(over='raise') {'over': 'ignore', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'} + >>> seterr(all='warn', over='raise') + {'over': 'raise', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'} >>> int16(32000) * int16(3) Traceback (most recent call last): File "<stdin>", line 1, in ? FloatingPointError: overflow encountered in short_scalars + >>> seterr(all='ignore') + {'over': 'ignore', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'} + """ pyvals = umath.geterrobj() old = geterr() - if divide is None: divide = old['divide'] - if over is None: over = old['over'] - if under is None: under = old['under'] - if invalid is None: invalid = old['invalid'] + if divide is None: divide = all or old['divide'] + if over is None: over = all or old['over'] + if under is None: under = all or old['under'] + if invalid is None: invalid = all or old['invalid'] maskvalue = ((_errdict[divide] << SHIFT_DIVIDEBYZERO) + (_errdict[over] << SHIFT_OVERFLOW ) + @@ -653,6 +661,7 @@ def seterr(divide=None, over=None, under=None, invalid=None): umath.seterrobj(pyvals) return old + def geterr(): """Get the current way of handling floating-point errors. @@ -718,12 +727,38 @@ def geterrcall(): return umath.geterrobj()[2] class errstate(object): + """with errstate(**state): --> operations in following block use given state. + + # Set error handling to known state. + >>> _ = seterr(invalid='raise', divide='raise', over='raise', under='ignore') + + |>> a = -arange(3) + |>> with errstate(invalid='ignore'): + ... print sqrt(a) + [ 0. -1.#IND -1.#IND] + |>> print sqrt(a.astype(complex)) + [ 0. +0.00000000e+00j 0. +1.00000000e+00j 0. +1.41421356e+00j] + |>> print sqrt(a) + Traceback (most recent call last): + ... + FloatingPointError: invalid encountered in sqrt + |>> with errstate(divide='ignore'): + ... print a/0 + [0 0 0] + |>> print a/0 + Traceback (most recent call last): + ... + FloatingPointError: divide by zero encountered in divide + + """ + # Note that we don't want to run the above doctests because they will fail + # without a from __future__ import with_statement def __init__(self, **kwargs): self.kwargs = kwargs def __enter__(self): self.oldstate = seterr(**self.kwargs) def __exit__(self, *exc_info): - numpy.seterr(**self.oldstate) + seterr(**self.oldstate) def _setdef(): defval = [UFUNC_BUFSIZE_DEFAULT, ERR_DEFAULT, None] |