summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2022-11-09 15:19:48 +0100
committerSebastian Berg <sebastianb@nvidia.com>2022-11-09 15:31:25 +0100
commiteae88ff5f9cc3e3caea5e849600e0f8db52a67db (patch)
tree036ffa2ca60b510bb4c60c2fc28fb553e4eed653 /numpy
parent1aa73ca93afd2cd2cac3b564051f1153a2c05ee8 (diff)
downloadnumpy-eae88ff5f9cc3e3caea5e849600e0f8db52a67db.tar.gz
BUG: Fix use and errorchecking of ObjectType use
This should be replaced really, it is pretty bad API use, and doesn't work well (up to being incorrect probably). But working on other things (trying to make promotion strict and thus saner), I realized that the use was always wrong: we cannot pass 0 since 0 means `bool`, what was always meant was passing no-type. So fixing this, and adding the error check everywhere. Checking for `PyErr_Occurred()` may have been necessary at some point, but is not anymore.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/_multiarray_tests.c.src13
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c40
-rw-r--r--numpy/core/tests/test_multiarray.py24
3 files changed, 62 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src
index b22b2c14d..1fd28e721 100644
--- a/numpy/core/src/multiarray/_multiarray_tests.c.src
+++ b/numpy/core/src/multiarray/_multiarray_tests.c.src
@@ -177,8 +177,14 @@ test_neighborhood_iterator(PyObject* NPY_UNUSED(self), PyObject* args)
return NULL;
}
- typenum = PyArray_ObjectType(x, 0);
+ typenum = PyArray_ObjectType(x, NPY_NOTYPE);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
typenum = PyArray_ObjectType(fill, typenum);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
ax = (PyArrayObject*)PyArray_FromObject(x, typenum, 1, 10);
if (ax == NULL) {
@@ -343,7 +349,10 @@ test_neighborhood_iterator_oob(PyObject* NPY_UNUSED(self), PyObject* args)
return NULL;
}
- typenum = PyArray_ObjectType(x, 0);
+ typenum = PyArray_ObjectType(x, NPY_NOTYPE);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
ax = (PyArrayObject*)PyArray_FromObject(x, typenum, 1, 10);
if (ax == NULL) {
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index dda8831c5..b2925f758 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -899,11 +899,15 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
int i;
PyObject* ret = NULL;
- typenum = PyArray_ObjectType(op1, 0);
- if (typenum == NPY_NOTYPE && PyErr_Occurred()) {
+ typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
+ if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(op2, typenum);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
+
typec = PyArray_DescrFromType(typenum);
if (typec == NULL) {
if (!PyErr_Occurred()) {
@@ -991,11 +995,15 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
PyArray_Descr *typec = NULL;
NPY_BEGIN_THREADS_DEF;
- typenum = PyArray_ObjectType(op1, 0);
- if (typenum == NPY_NOTYPE && PyErr_Occurred()) {
+ typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
+ if (typenum == NPY_NOTYPE) {
return NULL;
}
typenum = PyArray_ObjectType(op2, typenum);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
+
typec = PyArray_DescrFromType(typenum);
if (typec == NULL) {
if (!PyErr_Occurred()) {
@@ -1373,8 +1381,14 @@ PyArray_Correlate2(PyObject *op1, PyObject *op2, int mode)
int inverted;
int st;
- typenum = PyArray_ObjectType(op1, 0);
+ typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
typenum = PyArray_ObjectType(op2, typenum);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
typec = PyArray_DescrFromType(typenum);
Py_INCREF(typec);
@@ -1440,8 +1454,14 @@ PyArray_Correlate(PyObject *op1, PyObject *op2, int mode)
int unused;
PyArray_Descr *typec;
- typenum = PyArray_ObjectType(op1, 0);
+ typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
typenum = PyArray_ObjectType(op2, typenum);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
typec = PyArray_DescrFromType(typenum);
Py_INCREF(typec);
@@ -2541,8 +2561,14 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
* Conjugating dot product using the BLAS for vectors.
* Flattens both op1 and op2 before dotting.
*/
- typenum = PyArray_ObjectType(op1, 0);
+ typenum = PyArray_ObjectType(op1, NPY_NOTYPE);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
typenum = PyArray_ObjectType(op2, typenum);
+ if (typenum == NPY_NOTYPE) {
+ return NULL;
+ }
type = PyArray_DescrFromType(typenum);
Py_INCREF(type);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 027384fba..15619bcb3 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1257,9 +1257,9 @@ class TestStructured:
# The main importance is that it does not return True:
with pytest.raises(TypeError):
x == y
-
+
def test_empty_structured_array_comparison(self):
- # Check that comparison works on empty arrays with nontrivially
+ # Check that comparison works on empty arrays with nontrivially
# shaped fields
a = np.zeros(0, [('a', '<f8', (1, 1))])
assert_equal(a, a)
@@ -2232,7 +2232,7 @@ class TestMethods:
assert_c(a.copy('C'))
assert_fortran(a.copy('F'))
assert_c(a.copy('A'))
-
+
@pytest.mark.parametrize("dtype", ['O', np.int32, 'i,O'])
def test__deepcopy__(self, dtype):
# Force the entry of NULLs into array
@@ -2441,7 +2441,7 @@ class TestMethods:
np.array([0, 1, np.nan]),
])
def test_searchsorted_floats(self, a):
- # test for floats arrays containing nans. Explicitly test
+ # test for floats arrays containing nans. Explicitly test
# half, single, and double precision floats to verify that
# the NaN-handling is correct.
msg = "Test real (%s) searchsorted with nans, side='l'" % a.dtype
@@ -2457,7 +2457,7 @@ class TestMethods:
assert_equal(y, 2)
def test_searchsorted_complex(self):
- # test for complex arrays containing nans.
+ # test for complex arrays containing nans.
# The search sorted routines use the compare functions for the
# array type, so this checks if that is consistent with the sort
# order.
@@ -2479,7 +2479,7 @@ class TestMethods:
a = np.array([0, 128], dtype='>i4')
b = a.searchsorted(np.array(128, dtype='>i4'))
assert_equal(b, 1, msg)
-
+
def test_searchsorted_n_elements(self):
# Check 0 elements
a = np.ones(0)
@@ -6731,6 +6731,18 @@ class TestDot:
res = np.dot(data, data)
assert res == 2**30+100
+ def test_dtype_discovery_fails(self):
+ # See gh-14247, error checking was missing for failed dtype discovery
+ class BadObject(object):
+ def __array__(self):
+ raise TypeError("just this tiny mint leaf")
+
+ with pytest.raises(TypeError):
+ np.dot(BadObject(), BadObject())
+
+ with pytest.raises(TypeError):
+ np.dot(3.0, BadObject())
+
class MatmulCommon:
"""Common tests for '@' operator and numpy.matmul.