summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-08-01 15:59:28 -0400
committerCharles Harris <charlesr.harris@gmail.com>2015-08-01 15:59:28 -0400
commitc03abd7833b229039a7cd8268b3091deb457a528 (patch)
treeed2ab3e846070ba4d1ea00aba3131c1b987d48c4
parent038a309d559c265fb04d03d1b3142ca735f42f2d (diff)
parentb6d0263239926e8b14ebc26a0d7b9469fa7866d4 (diff)
downloadnumpy-c03abd7833b229039a7cd8268b3091deb457a528.tar.gz
Merge pull request #6131 from argriffing/choice-precision
MAINT: adjust tolerance for validating the sum of probs in random.choice
-rw-r--r--numpy/random/mtrand/mtrand.pyx8
-rw-r--r--numpy/random/tests/test_regression.py14
2 files changed, 20 insertions, 2 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 8345ac97e..9861204d9 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -1084,6 +1084,12 @@ cdef class RandomState:
if p is not None:
d = len(p)
+
+ atol = np.sqrt(np.finfo(np.float64).eps)
+ if isinstance(p, np.ndarray):
+ if np.issubdtype(p.dtype, np.floating):
+ atol = max(atol, np.sqrt(np.finfo(p.dtype).eps))
+
p = <ndarray>PyArray_ContiguousFromObject(p, NPY_DOUBLE, 1, 1)
pix = <double*>PyArray_DATA(p)
@@ -1093,7 +1099,7 @@ cdef class RandomState:
raise ValueError("a and p must have same size")
if np.logical_or.reduce(p < 0):
raise ValueError("probabilities are not non-negative")
- if abs(kahan_sum(pix, d) - 1.) > 1e-8:
+ if abs(kahan_sum(pix, d) - 1.) > atol:
raise ValueError("probabilities do not sum to 1")
shape = size
diff --git a/numpy/random/tests/test_regression.py b/numpy/random/tests/test_regression.py
index 52be0d4d5..528ef7839 100644
--- a/numpy/random/tests/test_regression.py
+++ b/numpy/random/tests/test_regression.py
@@ -2,7 +2,7 @@ from __future__ import division, absolute_import, print_function
import sys
from numpy.testing import (TestCase, run_module_suite, assert_,
- assert_array_equal)
+ assert_array_equal, assert_raises)
from numpy import random
from numpy.compat import long
import numpy as np
@@ -100,6 +100,18 @@ class TestRegression(TestCase):
x = np.random.beta(0.0001, 0.0001, size=100)
assert_(not np.any(np.isnan(x)), 'Nans in np.random.beta')
+ def test_choice_sum_of_probs_tolerance(self):
+ # The sum of probs should be 1.0 with some tolerance.
+ # For low precision dtypes the tolerance was too tight.
+ # See numpy github issue 6123.
+ np.random.seed(1234)
+ a = [1, 2, 3]
+ counts = [4, 4, 2]
+ for dt in np.float16, np.float32, np.float64:
+ probs = np.array(counts, dtype=dt) / sum(counts)
+ c = np.random.choice(a, p=probs)
+ assert_(c in a)
+ assert_raises(ValueError, np.random.choice, a, p=probs*0.9)
if __name__ == "__main__":
run_module_suite()