summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/shape_base.py16
-rw-r--r--numpy/lib/tests/test_shape_base.py24
2 files changed, 23 insertions, 17 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 581da0598..8371747c0 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1159,9 +1159,11 @@ def kron(a, b):
bs = (1,)*max(0, nda-ndb) + bs
# Compute the product
- a_arr = _nx.asarray(a).reshape(a.size, 1)
- b_arr = _nx.asarray(b).reshape(1, b.size)
- result = a_arr * b_arr
+ a_arr = a.reshape(a.size, 1)
+ b_arr = b.reshape(1, b.size)
+ is_any_mat = isinstance(a_arr, matrix) or isinstance(b_arr, matrix)
+ # In case of `mat`, convert result to `array`
+ result = _nx.multiply(a_arr, b_arr, subok=(not is_any_mat))
# Reshape back
result = result.reshape(as_+bs)
@@ -1169,13 +1171,7 @@ def kron(a, b):
result = result.transpose(transposer)
result = result.reshape(_nx.multiply(as_, bs))
- wrapper = get_array_prepare(a, b)
- if wrapper is not None:
- result = wrapper(result)
- wrapper = get_array_wrap(a, b)
- if wrapper is not None:
- result = wrapper(result)
- return result
+ return result if not is_any_mat else matrix(result, copy=False)
def _tile_dispatcher(A, reps):
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 564cdfeea..07960627b 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -646,21 +646,31 @@ class TestSqueeze:
class TestKron:
def test_return_type(self):
class myarray(np.ndarray):
- __array_priority__ = 0.0
+ __array_priority__ = 1.0
a = np.ones([2, 2])
ma = myarray(a.shape, a.dtype, a.data)
assert_equal(type(kron(a, a)), np.ndarray)
assert_equal(type(kron(ma, ma)), myarray)
- assert_equal(type(kron(a, ma)), np.ndarray)
+ assert_equal(type(kron(a, ma)), myarray)
assert_equal(type(kron(ma, a)), myarray)
- def test_kron_smoke(self):
- a = np.ones([3, 3])
- b = np.ones([3, 3])
- k = np.ones([9, 9])
+ @pytest.mark.parametrize(
+ "array_class", [np.asarray, np.mat]
+ )
+ def test_kron_smoke(self, array_class):
+ a = array_class(np.ones([3, 3]))
+ b = array_class(np.ones([3, 3]))
+ k = array_class(np.ones([9, 9]))
+
+ assert_array_equal(np.kron(a, b), k)
+
+ def test_kron_ma(self):
+ x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])
+ k = np.ma.array(np.diag([1, 4, 4, 16]),
+ mask=~np.array(np.identity(4), dtype=bool))
- assert np.array_equal(np.kron(a, b), k), "Smoke test for kron failed"
+ assert_array_equal(k, np.kron(x, x))
@pytest.mark.parametrize(
"shape_a,shape_b", [