summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Sheppard <bashtage@users.noreply.github.com>2021-12-02 14:26:11 +0000
committerGitHub <noreply@github.com>2021-12-02 08:26:11 -0600
commit0e10696f55576441fd820279d9ec10cd9f2a4c5d (patch)
treea11b2326c15508a00a9fdadf5106d23f1dc22ca2
parent742f13f7ee6ff8ed56fc468c9ef57b3853141768 (diff)
downloadnumpy-0e10696f55576441fd820279d9ec10cd9f2a4c5d.tar.gz
BUG: Protect divide by 0 in multinomial (#20490)
Guard against empty pvals closes #20483 * MAINT: Correct exception type Raise TypeError for enhanced backward compat when pvals is scalar
-rw-r--r--numpy/random/mtrand.pyx17
-rw-r--r--numpy/random/tests/test_randomstate_regression.py13
2 files changed, 23 insertions, 7 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 3e13503d0..ce09a041c 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -4234,18 +4234,21 @@ cdef class RandomState:
ValueError: pvals < 0, pvals > 1 or pvals contains NaNs
"""
- cdef np.npy_intp d, i, sz, offset
+ cdef np.npy_intp d, i, sz, offset, niter
cdef np.ndarray parr, mnarr
cdef double *pix
cdef long *mnix
cdef long ni
- d = len(pvals)
parr = <np.ndarray>np.PyArray_FROMANY(
- pvals, np.NPY_DOUBLE, 1, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
+ pvals, np.NPY_DOUBLE, 0, 1, np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
+ if np.PyArray_NDIM(parr) == 0:
+ raise TypeError("pvals must be a 1-d sequence")
+ d = np.PyArray_SIZE(parr)
pix = <double*>np.PyArray_DATA(parr)
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
- if kahan_sum(pix, d-1) > (1.0 + 1e-12):
+ # Only check if pvals is non-empty due no checks in kahan_sum
+ if d and kahan_sum(pix, d-1) > (1.0 + 1e-12):
# When floating, but not float dtype, and close, improve the error
# 1.0001 works for float16 and float32
if (isinstance(pvals, np.ndarray)
@@ -4260,7 +4263,6 @@ cdef class RandomState:
else:
msg = "sum(pvals[:-1]) > 1.0"
raise ValueError(msg)
-
if size is None:
shape = (d,)
else:
@@ -4268,7 +4270,6 @@ cdef class RandomState:
shape = (operator.index(size), d)
except:
shape = tuple(size) + (d,)
-
multin = np.zeros(shape, dtype=int)
mnarr = <np.ndarray>multin
mnix = <long*>np.PyArray_DATA(mnarr)
@@ -4276,8 +4277,10 @@ cdef class RandomState:
ni = n
check_constraint(ni, 'n', CONS_NON_NEGATIVE)
offset = 0
+ # gh-20483: Avoids divide by 0
+ niter = sz // d if d else 0
with self.lock, nogil:
- for i in range(sz // d):
+ for i in range(niter):
legacy_random_multinomial(&self._bitgen, ni, &mnix[offset], pix, d, &self._binomial)
offset += d
diff --git a/numpy/random/tests/test_randomstate_regression.py b/numpy/random/tests/test_randomstate_regression.py
index 595fb5fd3..7ad19ab55 100644
--- a/numpy/random/tests/test_randomstate_regression.py
+++ b/numpy/random/tests/test_randomstate_regression.py
@@ -201,3 +201,16 @@ class TestRegression:
[3, 4, 2, 3, 3, 1, 5, 3, 1, 3]])
assert_array_equal(random.binomial([[0], [10]], 0.25, size=(2, 10)),
expected)
+
+
+def test_multinomial_empty():
+ # gh-20483
+ # Ensure that empty p-vals are correctly handled
+ assert random.multinomial(10, []).shape == (0,)
+ assert random.multinomial(3, [], size=(7, 5, 3)).shape == (7, 5, 3, 0)
+
+
+def test_multinomial_1d_pval():
+ # gh-20483
+ with pytest.raises(TypeError, match="pvals must be a 1-d"):
+ random.multinomial(10, 0.3)