diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 22 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 11 |
2 files changed, 31 insertions, 2 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index dd5fb7e03..081c06813 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -5396,6 +5396,8 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) PyArrayObject *ap1 = NULL, *ap2 = NULL, *ap_new = NULL; PyObject *new_args, *tmp; PyObject *shape1, *shape2, *newshape; + static PyObject *_numpy_matrix; + errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override); if (errval) { @@ -5428,7 +5430,18 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) if (tmp == NULL) { return NULL; } - ap1 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0); + + npy_cache_import( + "numpy", + "matrix", + &_numpy_matrix); + + if (PyObject_IsInstance(tmp, _numpy_matrix)) { + ap1 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0); + } + else { + ap1 = (PyArrayObject *) PyArray_FROM_O(tmp); + } Py_DECREF(tmp); if (ap1 == NULL) { return NULL; @@ -5437,7 +5450,12 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) if (tmp == NULL) { return NULL; } - ap2 = (PyArrayObject *)PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0); + if (PyObject_IsInstance(tmp, _numpy_matrix)) { + ap2 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0); + } + else { + ap2 = (PyArrayObject *) PyArray_FROM_O(tmp); + } Py_DECREF(tmp); if (ap2 == NULL) { Py_DECREF(ap1); diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index db6b51922..0eedd1e97 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -2924,3 +2924,14 @@ def test_signaling_nan_exceptions(): with assert_no_warnings(): a = np.ndarray(shape=(), dtype='float32', buffer=b'\x00\xe0\xbf\xff') np.isnan(a) + +@pytest.mark.parametrize("arr", [ + np.arange(2), + np.matrix([0, 1]), + np.matrix([[0, 1], [2, 5]]), + ]) +def test_outer_subclass_preserve(arr): + # for gh-8661 + class foo(np.ndarray): pass + actual = np.multiply.outer(arr.view(foo), arr.view(foo)) + assert actual.__class__.__name__ == 'foo' |