diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/tests/test_multiarray.py | 32 |
1 files changed, 27 insertions, 5 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 775c06bac..a2979096b 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -7182,17 +7182,39 @@ class TestMatmulInplace: else: np.testing.assert_array_equal(a, ref) - def test_large_matrix(self) -> None: - a = np.arange(10**6).reshape(-1, 10).astype(np.float64) + SHAPES = { + "2d_large": ((10**5, 10), (10, 10)), + "3d_large": ((10**4, 10, 10), (1, 10, 10)), + "2d_broadcast": ((3, 3), (3, 1)), + "2d_broadcast_reverse": ((1, 3), (3, 3)), + "3d_broadcast1": ((3, 3, 3), (1, 3, 1)), + "3d_broadcast2": ((3, 3, 3), (1, 3, 3)), + "3d_broadcast3": ((3, 3, 3), (3, 3, 1)), + "3d_broadcast_reverse1": ((1, 3, 3), (3, 3, 3)), + "3d_broadcast_reverse2": ((3, 1, 3), (3, 3, 3)), + "3d_broadcast_reverse3": ((1, 1, 3), (3, 3, 3)), + } + + @pytest.mark.parametrize("a_shape,b_shape", SHAPES.values(), ids=SHAPES) + def test_shapes(self, a_shape: tuple[int, ...], b_shape: tuple[int, ...]): + a_size = np.product(a_shape) + a = np.arange(a_size).reshape(a_shape).astype(np.float64) a_id = id(a) - b = np.arange(10**2).reshape(10, 10) + + b_size = np.product(b_shape) + b = np.arange(b_size).reshape(b_shape) ref = a @ b - a @= b + if ref.shape != a_shape: + with pytest.raises(ValueError): + a @= b + return + else: + a @= b assert id(a) == a_id assert a.dtype.type == np.float64 - assert a.shape == (10**5, 10) + assert a.shape == a_shape np.testing.assert_allclose(a, ref) |
