diff options
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 564cdfeea..07960627b 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -646,21 +646,31 @@ class TestSqueeze: class TestKron: def test_return_type(self): class myarray(np.ndarray): - __array_priority__ = 0.0 + __array_priority__ = 1.0 a = np.ones([2, 2]) ma = myarray(a.shape, a.dtype, a.data) assert_equal(type(kron(a, a)), np.ndarray) assert_equal(type(kron(ma, ma)), myarray) - assert_equal(type(kron(a, ma)), np.ndarray) + assert_equal(type(kron(a, ma)), myarray) assert_equal(type(kron(ma, a)), myarray) - def test_kron_smoke(self): - a = np.ones([3, 3]) - b = np.ones([3, 3]) - k = np.ones([9, 9]) + @pytest.mark.parametrize( + "array_class", [np.asarray, np.mat] + ) + def test_kron_smoke(self, array_class): + a = array_class(np.ones([3, 3])) + b = array_class(np.ones([3, 3])) + k = array_class(np.ones([9, 9])) + + assert_array_equal(np.kron(a, b), k) + + def test_kron_ma(self): + x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]]) + k = np.ma.array(np.diag([1, 4, 4, 16]), + mask=~np.array(np.identity(4), dtype=bool)) - assert np.array_equal(np.kron(a, b), k), "Smoke test for kron failed" + assert_array_equal(k, np.kron(x, x)) @pytest.mark.parametrize( "shape_a,shape_b", [ |