diff options
| author | Charles Harris <charlesr.harris@gmail.com> | 2019-01-10 10:20:19 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-01-10 10:20:19 -0700 |
| commit | 23d53d90f7918f5c85c8e9423d06daf4cabfa437 (patch) | |
| tree | f61def84d9629fee980a77be2c58a2eb69f55f6c /numpy | |
| parent | 1b144ad9d45647960c009ff4fe4e1b18958ff997 (diff) | |
| parent | 73f1c08f60deb992a6a0912f6e02330bdf2081bd (diff) | |
| download | numpy-23d53d90f7918f5c85c8e9423d06daf4cabfa437.tar.gz | |
Merge pull request #12545 from bashtage/choice-nan-protect
BUG: Ensure probabilities are not NaN in choice
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 5 | ||||
| -rw-r--r-- | numpy/random/tests/test_random.py | 5 |
2 files changed, 9 insertions, 1 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index 134bc2e09..f922c59a5 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -1141,9 +1141,12 @@ cdef class RandomState: raise ValueError("'p' must be 1-dimensional") if p.size != pop_size: raise ValueError("'a' and 'p' must have same size") + p_sum = kahan_sum(pix, d) + if np.isnan(p_sum): + raise ValueError("probabilities contain NaN") if np.logical_or.reduce(p < 0): raise ValueError("probabilities are not non-negative") - if abs(kahan_sum(pix, d) - 1.) > atol: + if abs(p_sum - 1.) > atol: raise ValueError("probabilities do not sum to 1") shape = size diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index d0bb92a73..d4721bc62 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -448,6 +448,11 @@ class TestRandomDist(object): assert_equal(np.random.choice(['a', 'b'], size=(3, 0, 4)).shape, (3, 0, 4)) assert_raises(ValueError, np.random.choice, [], 10) + def test_choice_nan_probabilities(self): + a = np.array([42, 1, 2]) + p = [None, None, None] + assert_raises(ValueError, np.random.choice, a, p=p) + def test_bytes(self): np.random.seed(self.seed) actual = np.random.bytes(10) |
