summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorKevin Sheppard <bashtage@users.noreply.github.com>2021-11-13 04:17:50 +0000
committerGitHub <noreply@github.com>2021-11-12 23:17:50 -0500
commitaa52dee8dcaba963faa0ec4e36a85d97bbc8e2de (patch)
treeb576c411c99f7ee2a93d0e86349a6a6a5ed18f36 /numpy/random/tests
parent491564d5fb5930beabbe351179c6e8ec9c628c23 (diff)
downloadnumpy-aa52dee8dcaba963faa0ec4e36a85d97bbc8e2de.tar.gz
ENH: random: Add broadcast support to Generator.multinomial (#16740)
xref github issue #15201
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py64
1 files changed, 58 insertions, 6 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index d057122f1..e5411b8ef 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -136,12 +136,6 @@ class TestMultinomial:
contig = random.multinomial(100, pvals=np.ascontiguousarray(pvals))
assert_array_equal(non_contig, contig)
- def test_multidimensional_pvals(self):
- assert_raises(ValueError, random.multinomial, 10, [[0, 1]])
- assert_raises(ValueError, random.multinomial, 10, [[0], [1]])
- assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]])
- assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]]))
-
def test_multinomial_pvals_float32(self):
x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09,
1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
@@ -2361,6 +2355,64 @@ class TestBroadcast:
[2, 3, 6, 4, 2, 3]], dtype=np.int64)
assert_array_equal(actual, desired)
+ random = Generator(MT19937(self.seed))
+ actual = random.multinomial([5, 20], [[1 / 6.] * 6] * 2)
+ desired = np.array([[0, 0, 2, 1, 2, 0],
+ [2, 3, 6, 4, 2, 3]], dtype=np.int64)
+ assert_array_equal(actual, desired)
+
+ random = Generator(MT19937(self.seed))
+ actual = random.multinomial([[5], [20]], [[1 / 6.] * 6] * 2)
+ desired = np.array([[[0, 0, 2, 1, 2, 0],
+ [0, 0, 2, 1, 1, 1]],
+ [[4, 2, 3, 3, 5, 3],
+ [7, 2, 2, 1, 4, 4]]], dtype=np.int64)
+ assert_array_equal(actual, desired)
+
+ @pytest.mark.parametrize("n", [10,
+ np.array([10, 10]),
+ np.array([[[10]], [[10]]])
+ ]
+ )
+ def test_multinomial_pval_broadcast(self, n):
+ random = Generator(MT19937(self.seed))
+ pvals = np.array([1 / 4] * 4)
+ actual = random.multinomial(n, pvals)
+ n_shape = tuple() if isinstance(n, int) else n.shape
+ expected_shape = n_shape + (4,)
+ assert actual.shape == expected_shape
+ pvals = np.vstack([pvals, pvals])
+ actual = random.multinomial(n, pvals)
+ expected_shape = np.broadcast_shapes(n_shape, pvals.shape[:-1]) + (4,)
+ assert actual.shape == expected_shape
+
+ pvals = np.vstack([[pvals], [pvals]])
+ actual = random.multinomial(n, pvals)
+ expected_shape = np.broadcast_shapes(n_shape, pvals.shape[:-1])
+ assert actual.shape == expected_shape + (4,)
+ actual = random.multinomial(n, pvals, size=(3, 2) + expected_shape)
+ assert actual.shape == (3, 2) + expected_shape + (4,)
+
+ with pytest.raises(ValueError):
+ # Ensure that size is not broadcast
+ actual = random.multinomial(n, pvals, size=(1,) * 6)
+
+ def test_invalid_pvals_broadcast(self):
+ random = Generator(MT19937(self.seed))
+ pvals = [[1 / 6] * 6, [1 / 4] * 6]
+ assert_raises(ValueError, random.multinomial, 1, pvals)
+ assert_raises(ValueError, random.multinomial, 6, 0.5)
+
+ def test_empty_outputs(self):
+ random = Generator(MT19937(self.seed))
+ actual = random.multinomial(np.empty((10, 0, 6), "i8"), [1 / 6] * 6)
+ assert actual.shape == (10, 0, 6, 6)
+ actual = random.multinomial(12, np.empty((10, 0, 10)))
+ assert actual.shape == (10, 0, 10)
+ actual = random.multinomial(np.empty((3, 0, 7), "i8"),
+ np.empty((3, 0, 7, 4)))
+ assert actual.shape == (3, 0, 7, 4)
+
class TestThread:
# make sure each state produces the same sequence even in threads