summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/tests/test_multiarray.py32
1 files changed, 27 insertions, 5 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 775c06bac..a2979096b 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -7182,17 +7182,39 @@ class TestMatmulInplace:
else:
np.testing.assert_array_equal(a, ref)
- def test_large_matrix(self) -> None:
- a = np.arange(10**6).reshape(-1, 10).astype(np.float64)
+ SHAPES = {
+ "2d_large": ((10**5, 10), (10, 10)),
+ "3d_large": ((10**4, 10, 10), (1, 10, 10)),
+ "2d_broadcast": ((3, 3), (3, 1)),
+ "2d_broadcast_reverse": ((1, 3), (3, 3)),
+ "3d_broadcast1": ((3, 3, 3), (1, 3, 1)),
+ "3d_broadcast2": ((3, 3, 3), (1, 3, 3)),
+ "3d_broadcast3": ((3, 3, 3), (3, 3, 1)),
+ "3d_broadcast_reverse1": ((1, 3, 3), (3, 3, 3)),
+ "3d_broadcast_reverse2": ((3, 1, 3), (3, 3, 3)),
+ "3d_broadcast_reverse3": ((1, 1, 3), (3, 3, 3)),
+ }
+
+ @pytest.mark.parametrize("a_shape,b_shape", SHAPES.values(), ids=SHAPES)
+ def test_shapes(self, a_shape: tuple[int, ...], b_shape: tuple[int, ...]):
+ a_size = np.product(a_shape)
+ a = np.arange(a_size).reshape(a_shape).astype(np.float64)
a_id = id(a)
- b = np.arange(10**2).reshape(10, 10)
+
+ b_size = np.product(b_shape)
+ b = np.arange(b_size).reshape(b_shape)
ref = a @ b
- a @= b
+ if ref.shape != a_shape:
+ with pytest.raises(ValueError):
+ a @= b
+ return
+ else:
+ a @= b
assert id(a) == a_id
assert a.dtype.type == np.float64
- assert a.shape == (10**5, 10)
+ assert a.shape == a_shape
np.testing.assert_allclose(a, ref)