diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-08-01 15:59:28 -0400 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-08-01 15:59:28 -0400 |
commit | c03abd7833b229039a7cd8268b3091deb457a528 (patch) | |
tree | ed2ab3e846070ba4d1ea00aba3131c1b987d48c4 | |
parent | 038a309d559c265fb04d03d1b3142ca735f42f2d (diff) | |
parent | b6d0263239926e8b14ebc26a0d7b9469fa7866d4 (diff) | |
download | numpy-c03abd7833b229039a7cd8268b3091deb457a528.tar.gz |
Merge pull request #6131 from argriffing/choice-precision
MAINT: adjust tolerance for validating the sum of probs in random.choice
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 8 | ||||
-rw-r--r-- | numpy/random/tests/test_regression.py | 14 |
2 files changed, 20 insertions, 2 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index 8345ac97e..9861204d9 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -1084,6 +1084,12 @@ cdef class RandomState: if p is not None: d = len(p) + + atol = np.sqrt(np.finfo(np.float64).eps) + if isinstance(p, np.ndarray): + if np.issubdtype(p.dtype, np.floating): + atol = max(atol, np.sqrt(np.finfo(p.dtype).eps)) + p = <ndarray>PyArray_ContiguousFromObject(p, NPY_DOUBLE, 1, 1) pix = <double*>PyArray_DATA(p) @@ -1093,7 +1099,7 @@ cdef class RandomState: raise ValueError("a and p must have same size") if np.logical_or.reduce(p < 0): raise ValueError("probabilities are not non-negative") - if abs(kahan_sum(pix, d) - 1.) > 1e-8: + if abs(kahan_sum(pix, d) - 1.) > atol: raise ValueError("probabilities do not sum to 1") shape = size diff --git a/numpy/random/tests/test_regression.py b/numpy/random/tests/test_regression.py index 52be0d4d5..528ef7839 100644 --- a/numpy/random/tests/test_regression.py +++ b/numpy/random/tests/test_regression.py @@ -2,7 +2,7 @@ from __future__ import division, absolute_import, print_function import sys from numpy.testing import (TestCase, run_module_suite, assert_, - assert_array_equal) + assert_array_equal, assert_raises) from numpy import random from numpy.compat import long import numpy as np @@ -100,6 +100,18 @@ class TestRegression(TestCase): x = np.random.beta(0.0001, 0.0001, size=100) assert_(not np.any(np.isnan(x)), 'Nans in np.random.beta') + def test_choice_sum_of_probs_tolerance(self): + # The sum of probs should be 1.0 with some tolerance. + # For low precision dtypes the tolerance was too tight. + # See numpy github issue 6123. + np.random.seed(1234) + a = [1, 2, 3] + counts = [4, 4, 2] + for dt in np.float16, np.float32, np.float64: + probs = np.array(counts, dtype=dt) / sum(counts) + c = np.random.choice(a, p=probs) + assert_(c in a) + assert_raises(ValueError, np.random.choice, a, p=probs*0.9) if __name__ == "__main__": run_module_suite() |