diff options
author | Kevin Sheppard <kevin.k.sheppard@gmail.com> | 2020-07-02 18:41:11 +0100 |
---|---|---|
committer | Kevin Sheppard <kevin.sheppard@gmail.com> | 2021-02-26 22:41:03 +0000 |
commit | 9e068f4032dab58cf310cfb663fa4930371cecf3 (patch) | |
tree | d1c7c45578fb9a45b58081a9396f0fe86b22d0d7 /numpy | |
parent | 7d3b555ca383a20dcf4618ce3e3d0392ad556b04 (diff) | |
download | numpy-9e068f4032dab58cf310cfb663fa4930371cecf3.tar.gz |
ENH: Improve error message in multinomial
Improve error message when the sum of pvals is larger than 1
when the input data is an ndarray
closes #8317
xref #16732
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/random/_generator.pyx | 14 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 8 |
2 files changed, 21 insertions, 1 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index f25f16a8a..856b1cc21 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -3756,7 +3756,19 @@ cdef class Generator: pix = <double*>np.PyArray_DATA(parr) check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1) if kahan_sum(pix, d-1) > (1.0 + 1e-12): - raise ValueError("sum(pvals[:-1]) > 1.0") + msg = "sum(pvals[:-1]) > 1.0" + # When floating, but not float dtype, and close, improve the error + # 1.0001 works for float16 and float32 + if (isinstance(pvals, np.ndarray) and + pvals.dtype != float and + np.issubdtype(pvals.dtype, np.floating) and + pvals.sum() < 1.0001): + msg = ("sum(pvals[:-1].astype(np.float64)) > 1.0. pvals are " + "cast to 64-bit floating point values prior to " + "checking the constraint. Changes in precision when " + "casting may produce violations even if " + "pvals[:-1].sum() <= 1.") + raise ValueError(msg) if np.PyArray_NDIM(on) != 0: # vector check_array_constraint(on, 'n', CONS_NON_NEGATIVE) diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index f6bd985b4..9de044774 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -142,6 +142,14 @@ class TestMultinomial: assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]]) assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]])) + def test_multinomial_pvals_float32(self): + x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09, + 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32) + pvals = x / x.sum() + random = Generator(MT19937(1432985819)) + match = r"[\w\s]*pvals are cast to 64-bit floating" + with pytest.raises(ValueError, match=match): + random.multinomial(1, pvals) class TestMultivariateHypergeometric: |