summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/umath/ufunc_object.c22
-rw-r--r--numpy/core/tests/test_umath.py11
2 files changed, 31 insertions, 2 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index dd5fb7e03..081c06813 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -5396,6 +5396,8 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyArrayObject *ap1 = NULL, *ap2 = NULL, *ap_new = NULL;
PyObject *new_args, *tmp;
PyObject *shape1, *shape2, *newshape;
+ static PyObject *_numpy_matrix;
+
errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override);
if (errval) {
@@ -5428,7 +5430,18 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
if (tmp == NULL) {
return NULL;
}
- ap1 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
+
+ npy_cache_import(
+ "numpy",
+ "matrix",
+ &_numpy_matrix);
+
+ if (PyObject_IsInstance(tmp, _numpy_matrix)) {
+ ap1 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
+ }
+ else {
+ ap1 = (PyArrayObject *) PyArray_FROM_O(tmp);
+ }
Py_DECREF(tmp);
if (ap1 == NULL) {
return NULL;
@@ -5437,7 +5450,12 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
if (tmp == NULL) {
return NULL;
}
- ap2 = (PyArrayObject *)PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
+ if (PyObject_IsInstance(tmp, _numpy_matrix)) {
+ ap2 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
+ }
+ else {
+ ap2 = (PyArrayObject *) PyArray_FROM_O(tmp);
+ }
Py_DECREF(tmp);
if (ap2 == NULL) {
Py_DECREF(ap1);
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index db6b51922..0eedd1e97 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -2924,3 +2924,14 @@ def test_signaling_nan_exceptions():
with assert_no_warnings():
a = np.ndarray(shape=(), dtype='float32', buffer=b'\x00\xe0\xbf\xff')
np.isnan(a)
+
+@pytest.mark.parametrize("arr", [
+ np.arange(2),
+ np.matrix([0, 1]),
+ np.matrix([[0, 1], [2, 5]]),
+ ])
+def test_outer_subclass_preserve(arr):
+ # for gh-8661
+ class foo(np.ndarray): pass
+ actual = np.multiply.outer(arr.view(foo), arr.view(foo))
+ assert actual.__class__.__name__ == 'foo'