Diffstat (limited to 'numpy')
-rw-r--r--  numpy/random/_generator.pyx                   85
-rw-r--r--  numpy/random/tests/test_generator_mt19937.py  25
2 files changed, 97 insertions(+), 13 deletions(-)
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index b976d51c6..27cb2859e 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -4007,10 +4007,12 @@ cdef class Generator:
# return val
cdef np.npy_intp k, totsize, i, j
- cdef np.ndarray alpha_arr, val_arr
+ cdef np.ndarray alpha_arr, val_arr, alpha_csum_arr
+ cdef double csum
cdef double *alpha_data
+ cdef double *alpha_csum_data
cdef double *val_data
- cdef double acc, invacc
+ cdef double acc, invacc, v
k = len(alpha)
alpha_arr = <np.ndarray>np.PyArray_FROMANY(
@@ -4034,17 +4036,74 @@ cdef class Generator:
i = 0
totsize = np.PyArray_SIZE(val_arr)
- with self.lock, nogil:
- while i < totsize:
- acc = 0.0
- for j in range(k):
- val_data[i+j] = random_standard_gamma(&self._bitgen,
- alpha_data[j])
- acc = acc + val_data[i + j]
- invacc = 1/acc
- for j in range(k):
- val_data[i + j] = val_data[i + j] * invacc
- i = i + k
+
+ # Select one of the following two algorithms for the generation
+ # of Dirichlet random variates (RVs)
+ #
+ # A) Small alpha case: Use the stick-breaking approach with beta
+ # random variates (RVs).
+ # B) Standard case: Perform unit normalisation of a vector
+ # of gamma random variates
+ #
+ # A) prevents NaNs resulting from 0/0 that may occur in B):
+ # when all values in the vector ':math:\\alpha' are smaller
+ # than 1, there is a nonzero probability that all
+ # generated gamma RVs will be 0. When that happens, the
+ # normalization process ends up computing 0/0, giving nan. A)
+ # does not use divisions, so that a situation in which 0/0 has
+ # to be computed cannot occur. A) is slower than B) as
+ # generation of beta RVs is slower than generation of gamma
+ # RVs. A) is selected whenever `alpha.max() < t`, where `t <
+ # 1` is a threshold that controls the probability of
+ # generating a NaN value when B) is used. For a given
+ # threshold `t` this probability can be bounded by
+ # `gammainc(t, d)` where `gammainc` is the regularized
+ # incomplete gamma function and `d` is the smallest positive
+ # floating point number that can be represented with a given
+ # precision. For the chosen threshold `t=0.1` this probability
+ # is smaller than `1.8e-31` for double precision floating
+ # point numbers.
+
+ if (k > 0) and (alpha_arr.max() < 0.1):
+ # Small alpha case: Use stick-breaking approach with beta
+ # random variates (RVs).
+ # alpha_csum_data will hold the cumulative sum, right to
+ # left, of alpha_arr.
+ # Use a numpy array for memory management only. We could just as
+ # well have malloc'd alpha_csum_data. alpha_arr is a C-contiguous
+ # double array, therefore so is alpha_csum_arr.
+ alpha_csum_arr = np.empty_like(alpha_arr)
+ alpha_csum_data = <double*>np.PyArray_DATA(alpha_csum_arr)
+ csum = 0.0
+ for j in range(k - 1, -1, -1):
+ csum += alpha_data[j]
+ alpha_csum_data[j] = csum
+
+ with self.lock, nogil:
+ while i < totsize:
+ acc = 1.
+ for j in range(k - 1):
+ v = random_beta(&self._bitgen, alpha_data[j],
+ alpha_csum_data[j + 1])
+ val_data[i + j] = acc * v
+ acc *= (1. - v)
+ val_data[i + k - 1] = acc
+ i = i + k
+
+ else:
+ # Standard case: Unit normalisation of a vector of gamma random
+ # variates
+ with self.lock, nogil:
+ while i < totsize:
+ acc = 0.
+ for j in range(k):
+ val_data[i + j] = random_standard_gamma(&self._bitgen,
+ alpha_data[j])
+ acc = acc + val_data[i + j]
+ invacc = 1. / acc
+ for j in range(k):
+ val_data[i + j] = val_data[i + j] * invacc
+ i = i + k
return diric
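
For readers of the patch, the stick-breaking recursion added above can be sketched in plain NumPy. This is a minimal illustration of the two branches, not the Cython implementation; the helper names and the rng object below are introduced only for this example.

import numpy as np


def stick_breaking_dirichlet(rng, alpha):
    # Algorithm A) above: component j takes a Beta(alpha[j], alpha[j+1:].sum())
    # fraction of whatever is left of the stick; the last component gets the
    # remainder. No division is performed, so 0/0 cannot occur.
    alpha = np.asarray(alpha, dtype=np.float64)
    k = alpha.shape[0]
    csum = np.cumsum(alpha[::-1])[::-1]  # csum[j] = alpha[j] + ... + alpha[k-1]
    out = np.empty(k)
    acc = 1.0
    for j in range(k - 1):
        v = rng.beta(alpha[j], csum[j + 1])
        out[j] = acc * v
        acc *= 1.0 - v
    out[k - 1] = acc
    return out


def gamma_dirichlet(rng, alpha):
    # Algorithm B) above: normalise a vector of gamma variates. For very
    # small alpha every gamma draw can underflow to 0.0, and the
    # normalisation then computes 0/0, giving nan.
    g = rng.standard_gamma(np.asarray(alpha, dtype=np.float64))
    return g / g.sum()


rng = np.random.default_rng(12345)
print(stick_breaking_dirichlet(rng, [0.02, 0.04, 0.03]))  # always finite
print(gamma_dirichlet(rng, [1e-9, 1e-12]))                # usually [nan nan]

For alpha on the order of 1e-9, as in test_dirichlet_small_alpha below, the gamma draws underflow to zero with overwhelming probability, which is exactly the failure mode the alpha.max() < 0.1 branch avoids.
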
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 08b44e4db..a28b7ca11 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -1103,6 +1103,31 @@ class TestRandomDist:
size=(3, 2))
assert_array_almost_equal(non_contig, contig)
+ def test_dirichlet_small_alpha(self):
+ eps = 1.0e-9 # 1.0e-10 -> runtime x 10; 1e-11 -> runtime x 200, etc.
+ alpha = eps * np.array([1., 1.0e-3])
+ random = Generator(MT19937(self.seed))
+ actual = random.dirichlet(alpha, size=(3, 2))
+ expected = np.array([
+ [[1., 0.],
+ [1., 0.]],
+ [[1., 0.],
+ [1., 0.]],
+ [[1., 0.],
+ [1., 0.]]
+ ])
+ assert_array_almost_equal(actual, expected, decimal=15)
+
+ @pytest.mark.slow
+ def test_dirichlet_moderately_small_alpha(self):
+ # Use alpha.max() < 0.1 to trigger the stick-breaking code path
+ alpha = np.array([0.02, 0.04, 0.03])
+ exact_mean = alpha / alpha.sum()
+ random = Generator(MT19937(self.seed))
+ sample = random.dirichlet(alpha, size=20000000)
+ sample_mean = sample.mean(axis=0)
+ assert_allclose(sample_mean, exact_mean, rtol=1e-3)
+
def test_exponential(self):
random = Generator(MT19937(self.seed))
actual = random.exponential(1.1234, size=(3, 2))
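
As a closing note on the bound quoted in the comment block of the first hunk: the 1.8e-31 figure can be reproduced with SciPy's regularized lower incomplete gamma function, reading d as the smallest positive normal double. This is only a sanity check of the stated number under that assumption, not part of the patch or its tests.

import numpy as np
from scipy.special import gammainc  # regularized lower incomplete gamma P(a, x)

t = 0.1                        # threshold used by the alpha.max() < 0.1 branch
d = np.finfo(np.float64).tiny  # smallest positive normal double, ~2.2e-308

# Probability that a Gamma(t) variate falls at or below d, i.e. that a single
# gamma draw underflows; this bounds the chance of hitting 0/0 in the
# gamma-normalisation path once the threshold t is applied.
print(gammainc(t, d))          # ~1.8e-31

test_dirichlet_moderately_small_alpha above checks the other direction: with 2e7 samples the stick-breaking branch reproduces the exact Dirichlet mean alpha / alpha.sum() to within rtol=1e-3.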