summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2019-01-10 10:20:19 -0700
committerGitHub <noreply@github.com>2019-01-10 10:20:19 -0700
commit23d53d90f7918f5c85c8e9423d06daf4cabfa437 (patch)
treef61def84d9629fee980a77be2c58a2eb69f55f6c /numpy
parent1b144ad9d45647960c009ff4fe4e1b18958ff997 (diff)
parent73f1c08f60deb992a6a0912f6e02330bdf2081bd (diff)
downloadnumpy-23d53d90f7918f5c85c8e9423d06daf4cabfa437.tar.gz
Merge pull request #12545 from bashtage/choice-nan-protect
BUG: Ensure probabilities are not NaN in choice
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/mtrand/mtrand.pyx5
-rw-r--r--numpy/random/tests/test_random.py5
2 files changed, 9 insertions, 1 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 134bc2e09..f922c59a5 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -1141,9 +1141,12 @@ cdef class RandomState:
raise ValueError("'p' must be 1-dimensional")
if p.size != pop_size:
raise ValueError("'a' and 'p' must have same size")
+ p_sum = kahan_sum(pix, d)
+ if np.isnan(p_sum):
+ raise ValueError("probabilities contain NaN")
if np.logical_or.reduce(p < 0):
raise ValueError("probabilities are not non-negative")
- if abs(kahan_sum(pix, d) - 1.) > atol:
+ if abs(p_sum - 1.) > atol:
raise ValueError("probabilities do not sum to 1")
shape = size
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index d0bb92a73..d4721bc62 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -448,6 +448,11 @@ class TestRandomDist(object):
assert_equal(np.random.choice(['a', 'b'], size=(3, 0, 4)).shape, (3, 0, 4))
assert_raises(ValueError, np.random.choice, [], 10)
+ def test_choice_nan_probabilities(self):
+ a = np.array([42, 1, 2])
+ p = [None, None, None]
+ assert_raises(ValueError, np.random.choice, a, p=p)
+
def test_bytes(self):
np.random.seed(self.seed)
actual = np.random.bytes(10)