From 633b11ca28e0c740215c4797db27363cde5246ba Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Thu, 17 Aug 2017 13:02:35 +0100 Subject: BUG: Missing dirichlet input validation Dirichlet does not validate inputs and hangs when values are zero. Adds check that values are strictly positive as required by the distribution. closes #2089 --- numpy/random/mtrand/mtrand.pyx | 7 +++++++ numpy/random/tests/test_random.py | 5 +++++ 2 files changed, 12 insertions(+) (limited to 'numpy') diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index 68905f7ad..c4b4f7748 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -4666,6 +4666,11 @@ cdef class RandomState: samples : ndarray, The drawn samples, of shape (size, alpha.ndim). + Raises + ------- + ValueError + If any value in alpha is less than or equal to zero + Notes ----- .. math:: X \\approx \\prod_{i=1}^{k}{x^{\\alpha_i-1}_i} @@ -4731,6 +4736,8 @@ cdef class RandomState: k = len(alpha) alpha_arr = PyArray_ContiguousFromObject(alpha, NPY_DOUBLE, 1, 1) + if np.any(np.less_equal(alpha_arr, 0)): + raise ValueError('alpha <= 0') alpha_data = PyArray_DATA(alpha_arr) shape = _shape_from_size(size, k) diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 3a1d8af51..39c45889f 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -525,6 +525,11 @@ class TestRandomDist(object): assert_raises(TypeError, np.random.dirichlet, p, float(1)) + def test_dirichlet_bad_alpha(self): + # gh-2089 + alpha = np.array([5.4e-01, -1.0e-16]) + assert_raises(ValueError, np.random.mtrand.dirichlet, alpha) + def test_exponential(self): np.random.seed(self.seed) actual = np.random.exponential(1.1234, size=(3, 2)) -- cgit v1.2.1