diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 55 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 36 |
2 files changed, 68 insertions, 23 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index b35f377d7..9ee44e918 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -5388,13 +5388,11 @@ ufunc_traverse(PyUFuncObject *self, visitproc visit, void *arg) static PyObject * ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) { - int i; int errval; PyObject *override = NULL; PyObject *ret; PyArrayObject *ap1 = NULL, *ap2 = NULL, *ap_new = NULL; PyObject *new_args, *tmp; - PyObject *shape1, *shape2, *newshape; static PyObject *_numpy_matrix; @@ -5460,34 +5458,45 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) Py_DECREF(ap1); return NULL; } - /* Construct new shape tuple */ - shape1 = PyTuple_New(PyArray_NDIM(ap1)); - if (shape1 == NULL) { - goto fail; - } - for (i = 0; i < PyArray_NDIM(ap1); i++) { - PyTuple_SET_ITEM(shape1, i, - PyLong_FromLongLong((npy_longlong)PyArray_DIMS(ap1)[i])); - } - shape2 = PyTuple_New(PyArray_NDIM(ap2)); - for (i = 0; i < PyArray_NDIM(ap2); i++) { - PyTuple_SET_ITEM(shape2, i, PyInt_FromLong((long) 1)); + /* Construct new shape from ap1 and ap2 and then reshape */ + PyArray_Dims newdims; + npy_intp newshape[NPY_MAXDIMS]; + newdims.len = PyArray_NDIM(ap1) + PyArray_NDIM(ap2); + newdims.ptr = newshape; + + if (newdims.len > NPY_MAXDIMS) { + PyErr_Format(PyExc_ValueError, + "maximum supported dimension for an ndarray is %d, but " + "`%s.outer()` result would have %d.", + NPY_MAXDIMS, ufunc->name, newdims.len); + return NPY_FAIL; } - if (shape2 == NULL) { - Py_DECREF(shape1); + if (newdims.ptr == NULL) { goto fail; } - newshape = PyNumber_Add(shape1, shape2); - Py_DECREF(shape1); - Py_DECREF(shape2); - if (newshape == NULL) { - goto fail; + memcpy(newshape, PyArray_DIMS(ap1), PyArray_NDIM(ap1) * sizeof(npy_intp)); + for (int i = PyArray_NDIM(ap1); i < newdims.len; i++) { + newshape[i] = 1; } - ap_new = (PyArrayObject *)PyArray_Reshape(ap1, newshape); - Py_DECREF(newshape); + + ap_new = (PyArrayObject *)PyArray_Newshape(ap1, &newdims, NPY_CORDER); if (ap_new == NULL) { goto fail; } + if (PyArray_NDIM(ap_new) != newdims.len || + !PyArray_CompareLists(PyArray_DIMS(ap_new), newshape, newdims.len)) { + PyErr_Format(PyExc_TypeError, + "%s.outer() called with ndarray-subclass of type '%s' " + "which modified its shape after a reshape. `outer()` relies " + "on reshaping the inputs and is for example not supported for " + "the 'np.matrix' class (the usage of matrix is generally " + "discouraged). " + "To work around this issue, please convert the inputs to " + "numpy arrays.", + ufunc->name, Py_TYPE(ap_new)->tp_name); + goto fail; + } + new_args = Py_BuildValue("(OO)", ap_new, ap2); Py_DECREF(ap1); Py_DECREF(ap2); diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 941d99521..facb00aa3 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -3270,3 +3270,39 @@ def test_outer_subclass_preserve(arr): class foo(np.ndarray): pass actual = np.multiply.outer(arr.view(foo), arr.view(foo)) assert actual.__class__.__name__ == 'foo' + +def test_outer_bad_subclass(): + class BadArr1(np.ndarray): + def __array_finalize__(self, obj): + # The outer call reshapes to 3 dims, try to do a bad reshape. + if self.ndim == 3: + self.shape = self.shape + (1,) + + def __array_prepare__(self, obj, context=None): + return obj + + class BadArr2(np.ndarray): + def __array_finalize__(self, obj): + if isinstance(obj, BadArr2): + # outer inserts 1-sized dims. In that case disturb them. + if self.shape[-1] == 1: + self.shape = self.shape[::-1] + + def __array_prepare__(self, obj, context=None): + return obj + + for cls in [BadArr1, BadArr2]: + arr = np.ones((2, 3)).view(cls) + with assert_raises(TypeError) as a: + # The first array gets reshaped (not the second one) + np.add.outer(arr, [1, 2]) + + # This actually works, since we only see the reshaping error: + arr = np.ones((2, 3)).view(cls) + assert type(np.add.outer([1, 2], arr)) is cls + +def test_outer_exceeds_maxdims(): + deep = np.ones((1,) * 17) + with assert_raises(ValueError): + np.add.outer(deep, deep) + |