summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/_generator.pyx14
-rw-r--r--numpy/random/tests/test_generator_mt19937.py8
2 files changed, 21 insertions, 1 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index f25f16a8a..856b1cc21 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -3756,7 +3756,19 @@ cdef class Generator:
pix = <double*>np.PyArray_DATA(parr)
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
- raise ValueError("sum(pvals[:-1]) > 1.0")
+ msg = "sum(pvals[:-1]) > 1.0"
+ # When floating, but not float dtype, and close, improve the error
+ # 1.0001 works for float16 and float32
+ if (isinstance(pvals, np.ndarray) and
+ pvals.dtype != float and
+ np.issubdtype(pvals.dtype, np.floating) and
+ pvals.sum() < 1.0001):
+ msg = ("sum(pvals[:-1].astype(np.float64)) > 1.0. pvals are "
+ "cast to 64-bit floating point values prior to "
+ "checking the constraint. Changes in precision when "
+ "casting may produce violations even if "
+ "pvals[:-1].sum() <= 1.")
+ raise ValueError(msg)
if np.PyArray_NDIM(on) != 0: # vector
check_array_constraint(on, 'n', CONS_NON_NEGATIVE)
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index f6bd985b4..9de044774 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -142,6 +142,14 @@ class TestMultinomial:
assert_raises(ValueError, random.multinomial, 10, [[[0], [1]], [[1], [0]]])
assert_raises(ValueError, random.multinomial, 10, np.array([[0, 1], [1, 0]]))
+ def test_multinomial_pvals_float32(self):
+ x = np.array([9.9e-01, 9.9e-01, 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09,
+ 1.0e-09, 1.0e-09, 1.0e-09, 1.0e-09], dtype=np.float32)
+ pvals = x / x.sum()
+ random = Generator(MT19937(1432985819))
+ match = r"[\w\s]*pvals are cast to 64-bit floating"
+ with pytest.raises(ValueError, match=match):
+ random.multinomial(1, pvals)
class TestMultivariateHypergeometric: