summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/tests/test_shape_base.py24
1 files changed, 17 insertions, 7 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 564cdfeea..07960627b 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -646,21 +646,31 @@ class TestSqueeze:
class TestKron:
def test_return_type(self):
class myarray(np.ndarray):
- __array_priority__ = 0.0
+ __array_priority__ = 1.0
a = np.ones([2, 2])
ma = myarray(a.shape, a.dtype, a.data)
assert_equal(type(kron(a, a)), np.ndarray)
assert_equal(type(kron(ma, ma)), myarray)
- assert_equal(type(kron(a, ma)), np.ndarray)
+ assert_equal(type(kron(a, ma)), myarray)
assert_equal(type(kron(ma, a)), myarray)
- def test_kron_smoke(self):
- a = np.ones([3, 3])
- b = np.ones([3, 3])
- k = np.ones([9, 9])
+ @pytest.mark.parametrize(
+ "array_class", [np.asarray, np.mat]
+ )
+ def test_kron_smoke(self, array_class):
+ a = array_class(np.ones([3, 3]))
+ b = array_class(np.ones([3, 3]))
+ k = array_class(np.ones([9, 9]))
+
+ assert_array_equal(np.kron(a, b), k)
+
+ def test_kron_ma(self):
+ x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])
+ k = np.ma.array(np.diag([1, 4, 4, 16]),
+ mask=~np.array(np.identity(4), dtype=bool))
- assert np.array_equal(np.kron(a, b), k), "Smoke test for kron failed"
+ assert_array_equal(k, np.kron(x, x))
@pytest.mark.parametrize(
"shape_a,shape_b", [