diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/lib/shape_base.py | 16 | ||||
| -rw-r--r-- | numpy/lib/tests/test_shape_base.py | 24 |
2 files changed, 23 insertions, 17 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 581da0598..8371747c0 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1159,9 +1159,11 @@ def kron(a, b): bs = (1,)*max(0, nda-ndb) + bs # Compute the product - a_arr = _nx.asarray(a).reshape(a.size, 1) - b_arr = _nx.asarray(b).reshape(1, b.size) - result = a_arr * b_arr + a_arr = a.reshape(a.size, 1) + b_arr = b.reshape(1, b.size) + is_any_mat = isinstance(a_arr, matrix) or isinstance(b_arr, matrix) + # In case of `mat`, convert result to `array` + result = _nx.multiply(a_arr, b_arr, subok=(not is_any_mat)) # Reshape back result = result.reshape(as_+bs) @@ -1169,13 +1171,7 @@ def kron(a, b): result = result.transpose(transposer) result = result.reshape(_nx.multiply(as_, bs)) - wrapper = get_array_prepare(a, b) - if wrapper is not None: - result = wrapper(result) - wrapper = get_array_wrap(a, b) - if wrapper is not None: - result = wrapper(result) - return result + return result if not is_any_mat else matrix(result, copy=False) def _tile_dispatcher(A, reps): diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 564cdfeea..07960627b 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -646,21 +646,31 @@ class TestSqueeze: class TestKron: def test_return_type(self): class myarray(np.ndarray): - __array_priority__ = 0.0 + __array_priority__ = 1.0 a = np.ones([2, 2]) ma = myarray(a.shape, a.dtype, a.data) assert_equal(type(kron(a, a)), np.ndarray) assert_equal(type(kron(ma, ma)), myarray) - assert_equal(type(kron(a, ma)), np.ndarray) + assert_equal(type(kron(a, ma)), myarray) assert_equal(type(kron(ma, a)), myarray) - def test_kron_smoke(self): - a = np.ones([3, 3]) - b = np.ones([3, 3]) - k = np.ones([9, 9]) + @pytest.mark.parametrize( + "array_class", [np.asarray, np.mat] + ) + def test_kron_smoke(self, array_class): + a = array_class(np.ones([3, 3])) + b = array_class(np.ones([3, 3])) + k = array_class(np.ones([9, 9])) + + assert_array_equal(np.kron(a, b), k) + + def test_kron_ma(self): + x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]]) + k = np.ma.array(np.diag([1, 4, 4, 16]), + mask=~np.array(np.identity(4), dtype=bool)) - assert np.array_equal(np.kron(a, b), k), "Smoke test for kron failed" + assert_array_equal(k, np.kron(x, x)) @pytest.mark.parametrize( "shape_a,shape_b", [ |
