summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-06-26 21:14:55 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-06-26 21:14:55 +0000
commitbc7a2ecaade93a802543c7494427b7b8b73820ae (patch)
tree9b316c58835d6004b4b5722b704c6dec87bd0701 /numpy
parentce9de277644a87d960c0813687f6e4d65827ed65 (diff)
downloadnumpy-bc7a2ecaade93a802543c7494427b7b8b73820ae.tar.gz
Fix missing error checks.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarraymodule.c23
-rw-r--r--numpy/core/tests/test_numeric.py8
2 files changed, 19 insertions, 12 deletions
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index d64c33e52..cf9ea694f 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -216,6 +216,7 @@ power_of_ten(int n)
static PyObject *
PyArray_Round(PyArrayObject *a, int decimals)
{
+ PyObject *f, *ret=NULL, *tmp;
if (PyArray_ISCOMPLEX(a)) {
PyObject *part;
PyObject *round_part;
@@ -258,7 +259,6 @@ PyArray_Round(PyArrayObject *a, int decimals)
return PyArray_GenericUnaryFunction((PyAO *)a, n_ops.rint);
}
if (decimals > 0) {
- PyObject *f, *ret, *tmp;
if (PyArray_ISINTEGER(a)) {
Py_INCREF(a);
return (PyObject *)a;
@@ -266,48 +266,53 @@ PyArray_Round(PyArrayObject *a, int decimals)
f = PyFloat_FromDouble(power_of_ten(decimals));
if (f==NULL) return NULL;
ret = PyNumber_Multiply((PyObject *)a, f);
- if (ret==NULL) {Py_DECREF(f); return NULL;}
+ if (ret==NULL) goto finish;
if (PyArray_IsScalar(ret, Generic)) {
/* array scalars cannot be modified inplace */
tmp = PyObject_CallFunction(n_ops.rint, "O", ret);
Py_DECREF(ret);
+ if (tmp == NULL) {ret=NULL; goto finish;}
ret = PyObject_CallFunction(n_ops.divide, "OO",
tmp, f);
Py_DECREF(tmp);
} else {
tmp = PyObject_CallFunction(n_ops.rint, "OO", ret, ret);
- Py_DECREF(tmp);
+ if (tmp == NULL) {Py_DECREF(ret); ret=NULL; goto finish;}
tmp = PyObject_CallFunction(n_ops.divide, "OOO", ret,
f, ret);
+ if (tmp == NULL) {Py_DECREF(ret); ret=NULL; goto finish;}
Py_DECREF(tmp);
}
- Py_DECREF(f);
- return ret;
}
else {
/* remaining case: decimals < 0 */
- PyObject *f, *ret, *tmp;
f = PyFloat_FromDouble(power_of_ten(-decimals));
if (f==NULL) return NULL;
ret = PyNumber_Divide((PyObject *)a, f);
- if (ret==NULL) {Py_DECREF(f); return NULL;}
+ if (ret==NULL) goto finish;
if (PyArray_IsScalar(ret, Generic)) {
/* array scalars cannot be modified inplace */
tmp = PyObject_CallFunction(n_ops.rint, "O", ret);
Py_DECREF(ret);
+ if (tmp == NULL) {ret=NULL; goto finish;}
ret = PyObject_CallFunction(n_ops.multiply, "OO",
tmp, f);
Py_DECREF(tmp);
} else {
tmp = PyObject_CallFunction(n_ops.rint, "OO", ret, ret);
+ if (tmp == NULL) {Py_DECREF(ret); ret=NULL; goto finish;}
Py_DECREF(tmp);
tmp = PyObject_CallFunction(n_ops.multiply, "OOO", ret,
f, ret);
+ if (tmp==NULL) {Py_DECREF(ret); ret=NULL; goto finish;}
Py_DECREF(tmp);
}
- Py_DECREF(f);
- return ret;
}
+
+ finish:
+ Py_DECREF(f);
+ return ret;
+
}
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index be5a3ab4d..f6354b70b 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -130,12 +130,14 @@ class test_dot(NumpyTestCase):
assert (c1.shape == c2.shape)
assert_almost_equal(c1, c2, decimal=self.N)
- def check_vecobject(self,level=2):
+ def check_vecobject(self):
U_non_cont = transpose([[1.,1.],[1.,2.]])
U_cont = ascontiguousarray(U_non_cont)
x = array([Vec([1.,0.]),Vec([0.,1.])])
- assert_almost_equal(dot(U_cont,x),
- dot(U_non_cont,x))
+ zeros = array([Vec([0.,0.]),Vec([0.,0.])])
+ zeros_test = dot(U_cont,x) - dot(U_non_cont,x)
+ assert_equal(zeros[0].array, zeros_test[0].array)
+ assert_equal(zeros[1].array, zeros_test[1].array)
class test_bool_scalar(NumpyTestCase):