summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2013-09-29 13:35:39 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2013-10-19 19:33:12 +0200
commit6d6dc6d56ed97421d1f106882854058e4fa4fdb3 (patch)
treec7e9199c00c29b5075e4f9e3c4914cf213aa9018 /numpy/core
parente79296bf9cd9872ab3a4c54c05cce8f5ac410bf7 (diff)
downloadnumpy-6d6dc6d56ed97421d1f106882854058e4fa4fdb3.tar.gz
MAINT: refactor ufunc error object handling
_get_global_ext_obj: retrieves global ufunc object _get_bufsize_errmask: get only bufsize and errormask from ufunc object _extract_pyvals: handle NULL extobj PyUFunc_GetPyValues implemented as _get_global_ext_obj +_extract_pyvals drop unused first_error variable. fix errobj memory leak in previous commit. add some test for the extobj and warning path, the warning tests are disabled like the raising path as they fail on a bunch of platforms.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/umath/ufunc_object.c164
-rw-r--r--numpy/core/tests/test_numeric.py51
2 files changed, 143 insertions, 72 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 46ce8e17b..e419a6611 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -73,11 +73,16 @@ static int
_does_loop_use_arrays(void *data);
static int
+_extract_pyvals(PyObject *ref, char *name, int *bufsize,
+ int *errmask, PyObject **errobj);
+
+static int
assign_reduce_identity_zero(PyArrayObject *result, void *data);
static int
assign_reduce_identity_one(PyArrayObject *result, void *data);
+
/*
* fpstatus is the ufunc_formatted hardware status
* errmask is the handling mask specified by the user.
@@ -225,6 +230,53 @@ PyUFunc_clearfperr()
PyUFunc_getfperr();
}
+
+#if USE_USE_DEFAULTS==1
+static int PyUFunc_NUM_NODEFAULTS = 0;
+#endif
+static PyObject *PyUFunc_PYVALS_NAME = NULL;
+
+static PyObject *
+_get_global_ext_obj(char * name)
+{
+ PyObject *thedict;
+ PyObject *ref = NULL;
+
+#if USE_USE_DEFAULTS==1
+ if (PyUFunc_NUM_NODEFAULTS != 0) {
+#endif
+ if (PyUFunc_PYVALS_NAME == NULL) {
+ PyUFunc_PYVALS_NAME = PyUString_InternFromString(UFUNC_PYVALS_NAME);
+ }
+ thedict = PyThreadState_GetDict();
+ if (thedict == NULL) {
+ thedict = PyEval_GetBuiltins();
+ }
+ ref = PyDict_GetItem(thedict, PyUFunc_PYVALS_NAME);
+#if USE_USE_DEFAULTS==1
+ }
+#endif
+
+ return ref;
+}
+
+
+static int
+_get_bufsize_errmask(PyObject * extobj, char * ufunc_name,
+ int *buffersize, int *errormask)
+{
+ /* Get the buffersize and errormask */
+ if (extobj == NULL) {
+ extobj = _get_global_ext_obj(ufunc_name);
+ }
+ if (_extract_pyvals(extobj, ufunc_name,
+ buffersize, errormask, NULL) < 0) {
+ return -1;
+ }
+
+ return 0;
+}
+
/*
* This function analyzes the input arguments
* and determines an appropriate __array_prepare__ function to call
@@ -363,16 +415,12 @@ _find_array_prepare(PyObject *args, PyObject *kwds,
return;
}
-#if USE_USE_DEFAULTS==1
-static int PyUFunc_NUM_NODEFAULTS = 0;
-#endif
-static PyObject *PyUFunc_PYVALS_NAME = NULL;
-
-
/*
* Extracts some values from the global pyvals tuple.
+ * all destinations may be NULL, in which case they are not retrieved
* ref - should hold the global tuple
* name - is the name of the ufunc (ufuncobj->name)
+ *
* bufsize - receives the buffer size to use
* errmask - receives the bitmask for error handling
* errobj - receives the python object to call with the error,
@@ -384,6 +432,13 @@ _extract_pyvals(PyObject *ref, char *name, int *bufsize,
{
PyObject *retval;
+ if (ref == NULL) {
+ *errmask = UFUNC_ERR_DEFAULT;
+ *errobj = Py_BuildValue("NO", PyBytes_FromString(name), Py_None);
+ *bufsize = NPY_BUFSIZE;
+ return 0;
+ }
+
if (!PyList_Check(ref) || (PyList_GET_SIZE(ref)!=3)) {
PyErr_Format(PyExc_TypeError,
"%s must be a length 3 list.", UFUNC_PYVALS_NAME);
@@ -445,7 +500,6 @@ _extract_pyvals(PyObject *ref, char *name, int *bufsize,
}
-
/*UFUNC_API
*
* On return, if errobj is populated with a non-NULL value, the caller
@@ -455,28 +509,8 @@ NPY_NO_EXPORT int
PyUFunc_GetPyValues(char *name, int *bufsize, int *errmask, PyObject **errobj)
{
PyObject *thedict;
- PyObject *ref = NULL;
+ PyObject *ref = _get_global_ext_obj(name);
-#if USE_USE_DEFAULTS==1
- if (PyUFunc_NUM_NODEFAULTS != 0) {
-#endif
- if (PyUFunc_PYVALS_NAME == NULL) {
- PyUFunc_PYVALS_NAME = PyUString_InternFromString(UFUNC_PYVALS_NAME);
- }
- thedict = PyThreadState_GetDict();
- if (thedict == NULL) {
- thedict = PyEval_GetBuiltins();
- }
- ref = PyDict_GetItem(thedict, PyUFunc_PYVALS_NAME);
-#if USE_USE_DEFAULTS==1
- }
-#endif
- if (ref == NULL) {
- *errmask = UFUNC_ERR_DEFAULT;
- *errobj = Py_BuildValue("NO", PyBytes_FromString(name), Py_None);
- *bufsize = NPY_BUFSIZE;
- return 0;
- }
return _extract_pyvals(ref, name, bufsize, errmask, errobj);
}
@@ -1721,10 +1755,20 @@ make_arr_prep_args(npy_intp nin, PyObject *args, PyObject *kwds)
}
}
+/*
+ * check the floating point status
+ * - errmask: mask of status to check
+ * - extobj: ufunc pyvals object
+ * may be null, in which case the thread global one is fetched
+ * - ufunc_name: name of ufunc
+ */
static int
-_check_ufunc_err(int errmask, PyObject *extobj, char* ufunc_name, int *first) {
+_check_ufunc_fperr(int errmask, PyObject *extobj, char* ufunc_name) {
int fperr;
PyObject *errobj = NULL;
+ int ret;
+ int first = 1;
+
if (!errmask) {
return 0;
}
@@ -1735,23 +1779,21 @@ _check_ufunc_err(int errmask, PyObject *extobj, char* ufunc_name, int *first) {
/* Get error object globals */
if (extobj == NULL) {
- if (PyUFunc_GetPyValues(ufunc_name,
- NULL, NULL, &errobj) < 0) {
- Py_XDECREF(errobj);
- return -1;
- }
+ extobj = _get_global_ext_obj(ufunc_name);
}
- else {
- if (_extract_pyvals(extobj, ufunc_name,
- NULL, NULL, &errobj) < 0) {
- Py_XDECREF(errobj);
- return -1;
- }
+ if (_extract_pyvals(extobj, ufunc_name,
+ NULL, NULL, &errobj) < 0) {
+ Py_XDECREF(errobj);
+ return -1;
}
- return PyUFunc_handlefperr(errmask, errobj, fperr, first);
+ ret = PyUFunc_handlefperr(errmask, errobj, fperr, &first);
+ Py_XDECREF(errobj);
+
+ return ret;
}
+
static int
PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
PyObject *args, PyObject *kwds,
@@ -1777,7 +1819,6 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
/* These parameters come from extobj= or from a TLS global */
int buffersize = 0, errormask = 0;
- int first_error = 1;
/* The selected inner loop */
PyUFuncGenericFunction innerloop = NULL;
@@ -1981,19 +2022,9 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
}
/* Get the buffersize and errormask */
- if (extobj == NULL) {
- if (PyUFunc_GetPyValues(NULL,
- &buffersize, &errormask, NULL) < 0) {
- retval = -1;
- goto fail;
- }
- }
- else {
- if (_extract_pyvals(extobj, NULL,
- &buffersize, &errormask, NULL) < 0) {
- retval = -1;
- goto fail;
- }
+ if (_get_bufsize_errmask(extobj, ufunc_name, &buffersize, &errormask) < 0) {
+ retval = -1;
+ goto fail;
}
NPY_UF_DBG_PRINT("Finding inner loop\n");
@@ -2239,7 +2270,8 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
}
/* Check whether any errors occurred during the loop */
- if (PyErr_Occurred() || _check_ufunc_err(errormask, extobj, ufunc_name, &first_error) < 0) {
+ if (PyErr_Occurred() ||
+ _check_ufunc_fperr(errormask, extobj, ufunc_name) < 0) {
retval = -1;
goto fail;
}
@@ -2296,7 +2328,6 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
/* These parameters come from extobj= or from a TLS global */
int buffersize = 0, errormask = 0;
- int first_error = 1;
/* The mask provided in the 'where=' parameter */
PyArrayObject *wheremask = NULL;
@@ -2359,19 +2390,9 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
}
/* Get the buffersize and errormask */
- if (extobj == NULL) {
- if (PyUFunc_GetPyValues(NULL,
- &buffersize, &errormask, NULL) < 0) {
- retval = -1;
- goto fail;
- }
- }
- else {
- if (_extract_pyvals(extobj, NULL,
- &buffersize, &errormask, NULL) < 0) {
- retval = -1;
- goto fail;
- }
+ if (_get_bufsize_errmask(extobj, ufunc_name, &buffersize, &errormask) < 0) {
+ retval = -1;
+ goto fail;
}
NPY_UF_DBG_PRINT("Finding inner loop\n");
@@ -2481,7 +2502,8 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
}
/* Check whether any errors occurred during the loop */
- if (PyErr_Occurred() || _check_ufunc_err(errormask, extobj, ufunc_name, &first_error) < 0) {
+ if (PyErr_Occurred() ||
+ _check_ufunc_fperr(errormask, extobj, ufunc_name) < 0) {
retval = -1;
goto fail;
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 913599e09..5a3de8edd 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -409,6 +409,36 @@ class TestSeterr(TestCase):
seterr(divide='ignore')
array([1.]) / array([0.])
+ def test_errobj(self):
+ olderrobj = np.geterrobj()
+ self.called = 0
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ with errstate(divide='warn'):
+ np.seterrobj([20000, 1, None])
+ array([1.]) / array([0.])
+ self.assertEqual(len(w), 1)
+
+ def log_err(*args):
+ self.called += 1
+ extobj_err = args
+ assert (len(extobj_err) == 2)
+ assert ("divide" in extobj_err[0])
+
+ with errstate(divide='ignore'):
+ np.seterrobj([20000, 3, log_err])
+ array([1.]) / array([0.])
+ self.assertEqual(self.called, 1)
+
+ np.seterrobj(olderrobj)
+ with errstate(divide='ignore'):
+ np.divide(1., 0., extobj=[20000, 3, log_err])
+ self.assertEqual(self.called, 2)
+ finally:
+ np.seterrobj(olderrobj)
+ del self.called
+
class TestFloatExceptions(TestCase):
def assert_raises_fpe(self, fpeerr, flop, x, y):
@@ -433,7 +463,7 @@ class TestFloatExceptions(TestCase):
self.assert_raises_fpe(fpeerr, flop, sc1, sc2[()]);
self.assert_raises_fpe(fpeerr, flop, sc1[()], sc2[()]);
- @dec.knownfailureif(True, "See ticket 1755")
+ @dec.knownfailureif(True, "See ticket #2350")
def test_floating_exceptions(self):
# Test basic arithmetic function errors
with np.errstate(all='raise'):
@@ -488,6 +518,25 @@ class TestFloatExceptions(TestCase):
self.assert_raises_fpe(invalid,
lambda a, b:a*b, ftype(0), ftype(np.inf))
+ def test_warnings(self):
+ # test warning code path
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ with np.errstate(all="warn"):
+ np.divide(1, 0.)
+ self.assertEqual(len(w), 1)
+ self.assertTrue("divide by zero" in str(w[0].message))
+ np.array(1e300) * np.array(1e300)
+ self.assertEqual(len(w), 2)
+ self.assertTrue("overflow" in str(w[-1].message))
+ np.array(np.inf) - np.array(np.inf)
+ self.assertEqual(len(w), 3)
+ self.assertTrue("invalid value" in str(w[-1].message))
+ np.array(1e-300) * np.array(1e-300)
+ self.assertEqual(len(w), 4)
+ self.assertTrue("underflow" in str(w[-1].message))
+
+
class TestTypes(TestCase):
def check_promotion_cases(self, promote_func):
#Tests that the scalars get coerced correctly.