summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_shape_base.py
diff options
context:
space:
mode:
authorGanesh Kathiresan <ganesh3597@gmail.com>2022-04-04 19:30:02 +0530
committerGanesh Kathiresan <ganesh3597@gmail.com>2022-04-05 20:49:35 +0530
commitc68a8b64e203db7c1310b575f7460da5f13c8dcd (patch)
treef2a0b9e37198e97b83c29db86090722599e4df48 /numpy/lib/tests/test_shape_base.py
parent730f3154f48e33f22b2ea8814eb10a45aa273e17 (diff)
downloadnumpy-c68a8b64e203db7c1310b575f7460da5f13c8dcd.tar.gz
TST: `np.kron` tests refinement
* Added `mat` cases to smoke tests * Changed type checks to handle new change which uses ufuncs order for result determination * Added cases for `ma` to check subclass info retention
Diffstat (limited to 'numpy/lib/tests/test_shape_base.py')
-rw-r--r--numpy/lib/tests/test_shape_base.py24
1 files changed, 17 insertions, 7 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 564cdfeea..07960627b 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -646,21 +646,31 @@ class TestSqueeze:
class TestKron:
def test_return_type(self):
class myarray(np.ndarray):
- __array_priority__ = 0.0
+ __array_priority__ = 1.0
a = np.ones([2, 2])
ma = myarray(a.shape, a.dtype, a.data)
assert_equal(type(kron(a, a)), np.ndarray)
assert_equal(type(kron(ma, ma)), myarray)
- assert_equal(type(kron(a, ma)), np.ndarray)
+ assert_equal(type(kron(a, ma)), myarray)
assert_equal(type(kron(ma, a)), myarray)
- def test_kron_smoke(self):
- a = np.ones([3, 3])
- b = np.ones([3, 3])
- k = np.ones([9, 9])
+ @pytest.mark.parametrize(
+ "array_class", [np.asarray, np.mat]
+ )
+ def test_kron_smoke(self, array_class):
+ a = array_class(np.ones([3, 3]))
+ b = array_class(np.ones([3, 3]))
+ k = array_class(np.ones([9, 9]))
+
+ assert_array_equal(np.kron(a, b), k)
+
+ def test_kron_ma(self):
+ x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])
+ k = np.ma.array(np.diag([1, 4, 4, 16]),
+ mask=~np.array(np.identity(4), dtype=bool))
- assert np.array_equal(np.kron(a, b), k), "Smoke test for kron failed"
+ assert_array_equal(k, np.kron(x, x))
@pytest.mark.parametrize(
"shape_a,shape_b", [