summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/random/mtrand/mtrand.pyx7
-rw-r--r--numpy/random/tests/test_random.py5
2 files changed, 12 insertions, 0 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 68905f7ad..c4b4f7748 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -4666,6 +4666,11 @@ cdef class RandomState:
samples : ndarray,
The drawn samples, of shape (size, alpha.ndim).
+ Raises
+ -------
+ ValueError
+ If any value in alpha is less than or equal to zero
+
Notes
-----
.. math:: X \\approx \\prod_{i=1}^{k}{x^{\\alpha_i-1}_i}
@@ -4731,6 +4736,8 @@ cdef class RandomState:
k = len(alpha)
alpha_arr = <ndarray>PyArray_ContiguousFromObject(alpha, NPY_DOUBLE, 1, 1)
+ if np.any(np.less_equal(alpha_arr, 0)):
+ raise ValueError('alpha <= 0')
alpha_data = <double*>PyArray_DATA(alpha_arr)
shape = _shape_from_size(size, k)
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 3a1d8af51..39c45889f 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -525,6 +525,11 @@ class TestRandomDist(object):
assert_raises(TypeError, np.random.dirichlet, p, float(1))
+ def test_dirichlet_bad_alpha(self):
+ # gh-2089
+ alpha = np.array([5.4e-01, -1.0e-16])
+ assert_raises(ValueError, np.random.mtrand.dirichlet, alpha)
+
def test_exponential(self):
np.random.seed(self.seed)
actual = np.random.exponential(1.1234, size=(3, 2))