diff options
author | Kevin Sheppard <bashtage@users.noreply.github.com> | 2021-12-02 14:26:11 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-02 08:26:11 -0600 |
commit | 0e10696f55576441fd820279d9ec10cd9f2a4c5d (patch) | |
tree | a11b2326c15508a00a9fdadf5106d23f1dc22ca2 | |
parent | 742f13f7ee6ff8ed56fc468c9ef57b3853141768 (diff) | |
download | numpy-0e10696f55576441fd820279d9ec10cd9f2a4c5d.tar.gz |
BUG: Protect divide by 0 in multinomial (#20490)
Guard against empty pvals
closes #20483
* MAINT: Correct exception type
Raise TypeError for enhanced backward compat when pvals
is scalar
-rw-r--r-- | numpy/random/mtrand.pyx | 17 | ||||
-rw-r--r-- | numpy/random/tests/test_randomstate_regression.py | 13 |
2 files changed, 23 insertions, 7 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 3e13503d0..ce09a041c 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -4234,18 +4234,21 @@ cdef class RandomState: ValueError: pvals < 0, pvals > 1 or pvals contains NaNs """ - cdef np.npy_intp d, i, sz, offset + cdef np.npy_intp d, i, sz, offset, niter cdef np.ndarray parr, mnarr cdef double *pix cdef long *mnix cdef long ni - d = len(pvals) parr = <np.ndarray>np.PyArray_FROMANY( - pvals, np.NPY_DOUBLE, 1, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS) + pvals, np.NPY_DOUBLE, 0, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS) + if np.PyArray_NDIM(parr) == 0: + raise TypeError("pvals must be a 1-d sequence") + d = np.PyArray_SIZE(parr) pix = <double*>np.PyArray_DATA(parr) check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1) - if kahan_sum(pix, d-1) > (1.0 + 1e-12): + # Only check if pvals is non-empty due no checks in kahan_sum + if d and kahan_sum(pix, d-1) > (1.0 + 1e-12): # When floating, but not float dtype, and close, improve the error # 1.0001 works for float16 and float32 if (isinstance(pvals, np.ndarray) @@ -4260,7 +4263,6 @@ cdef class RandomState: else: msg = "sum(pvals[:-1]) > 1.0" raise ValueError(msg) - if size is None: shape = (d,) else: @@ -4268,7 +4270,6 @@ cdef class RandomState: shape = (operator.index(size), d) except: shape = tuple(size) + (d,) - multin = np.zeros(shape, dtype=int) mnarr = <np.ndarray>multin mnix = <long*>np.PyArray_DATA(mnarr) @@ -4276,8 +4277,10 @@ cdef class RandomState: ni = n check_constraint(ni, 'n', CONS_NON_NEGATIVE) offset = 0 + # gh-20483: Avoids divide by 0 + niter = sz // d if d else 0 with self.lock, nogil: - for i in range(sz // d): + for i in range(niter): legacy_random_multinomial(&self._bitgen, ni, &mnix[offset], pix, d, &self._binomial) offset += d diff --git a/numpy/random/tests/test_randomstate_regression.py b/numpy/random/tests/test_randomstate_regression.py index 595fb5fd3..7ad19ab55 100644 --- a/numpy/random/tests/test_randomstate_regression.py +++ b/numpy/random/tests/test_randomstate_regression.py @@ -201,3 +201,16 @@ class TestRegression: [3, 4, 2, 3, 3, 1, 5, 3, 1, 3]]) assert_array_equal(random.binomial([[0], [10]], 0.25, size=(2, 10)), expected) + + +def test_multinomial_empty(): + # gh-20483 + # Ensure that empty p-vals are correctly handled + assert random.multinomial(10, []).shape == (0,) + assert random.multinomial(3, [], size=(7, 5, 3)).shape == (7, 5, 3, 0) + + +def test_multinomial_1d_pval(): + # gh-20483 + with pytest.raises(TypeError, match="pvals must be a 1-d"): + random.multinomial(10, 0.3) |