summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-10-09 21:27:32 -0700
committerGitHub <noreply@github.com>2017-10-09 21:27:32 -0700
commitbb5d666e84e2eb294543a67c6143d7e9124d1c73 (patch)
treec4dc58b728c4c9f171a96f13a6209a71bdf22e5d
parent6af2c42992f872233489ad00c5ac06ce5a430022 (diff)
parent45e00937723064f14752fc88c18e7146e62789a0 (diff)
downloadnumpy-bb5d666e84e2eb294543a67c6143d7e9124d1c73.tar.gz
Merge pull request #9824 from charris/gh-9680
BUG: Fixes for np.random.zipf
-rw-r--r--numpy/random/mtrand/distributions.c30
-rw-r--r--numpy/random/mtrand/mtrand.pyx10
-rw-r--r--numpy/random/tests/test_random.py8
3 files changed, 30 insertions, 18 deletions
diff --git a/numpy/random/mtrand/distributions.c b/numpy/random/mtrand/distributions.c
index 7673f92b4..b7e157915 100644
--- a/numpy/random/mtrand/distributions.c
+++ b/numpy/random/mtrand/distributions.c
@@ -45,6 +45,7 @@
#include <stdio.h>
#include <math.h>
#include <stdlib.h>
+#include <limits.h>
#ifndef min
#define min(x,y) ((x<y)?x:y)
@@ -719,26 +720,31 @@ double rk_wald(rk_state *state, double mean, double scale)
long rk_zipf(rk_state *state, double a)
{
- double T, U, V;
- long X;
double am1, b;
am1 = a - 1.0;
b = pow(2.0, am1);
- do
- {
- U = 1.0-rk_double(state);
+ while (1) {
+ double T, U, V, X;
+
+ U = 1.0 - rk_double(state);
V = rk_double(state);
- X = (long)floor(pow(U, -1.0/am1));
- /* The real result may be above what can be represented in a signed
- * long. It will get casted to -sys.maxint-1. Since this is
- * a straightforward rejection algorithm, we can just reject this value
- * in the rejection condition below. This function then models a Zipf
+ X = floor(pow(U, -1.0/am1));
+ /*
+ * The real result may be above what can be represented in a signed
+ * long. Since this is a straightforward rejection algorithm, we can
+ * just reject this value. This function then models a Zipf
* distribution truncated to sys.maxint.
*/
+ if (X > LONG_MAX || X < 1.0) {
+ continue;
+ }
+
T = pow(1.0 + 1.0/X, am1);
- } while (((V*X*(T-1.0)/(b-1.0)) > (T/b)) || X < 1);
- return X;
+ if (V*X*(T - 1.0)/(b - 1.0) <= T/b) {
+ return (long)X;
+ }
+ }
}
long rk_geometric_search(rk_state *state, double p)
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 9e8a79804..adf820f0d 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -4076,13 +4076,15 @@ cdef class RandomState:
if oa.shape == ():
fa = PyFloat_AsDouble(a)
- if fa <= 1.0:
- raise ValueError("a <= 1.0")
+ # use logic that ensures NaN is rejected.
+ if not fa > 1.0:
+ raise ValueError("'a' must be a valid float > 1.0")
return discd_array_sc(self.internal_state, rk_zipf, size, fa,
self.lock)
- if np.any(np.less_equal(oa, 1.0)):
- raise ValueError("a <= 1.0")
+ # use logic that ensures NaN is rejected.
+ if not np.all(np.greater(oa, 1.0)):
+ raise ValueError("'a' must contain valid floats > 1.0")
return discd_array(self.internal_state, rk_zipf, size, oa, self.lock)
def geometric(self, p, size=None):
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index a530b9e13..e9c9bc492 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -1106,13 +1106,13 @@ class TestBroadcast(object):
assert_raises(ValueError, nonc_f, bad_dfnum, dfden, nonc * 3)
assert_raises(ValueError, nonc_f, dfnum, bad_dfden, nonc * 3)
assert_raises(ValueError, nonc_f, dfnum, dfden, bad_nonc * 3)
-
+
def test_noncentral_f_small_df(self):
self.setSeed()
desired = np.array([6.869638627492048, 0.785880199263955])
actual = np.random.noncentral_f(0.9, 0.9, 2, size=2)
assert_array_almost_equal(actual, desired, decimal=14)
-
+
def test_chisquare(self):
df = [1]
bad_df = [-1]
@@ -1434,6 +1434,10 @@ class TestBroadcast(object):
actual = zipf(a * 3)
assert_array_equal(actual, desired)
assert_raises(ValueError, zipf, bad_a * 3)
+ with np.errstate(invalid='ignore'):
+ assert_raises(ValueError, zipf, np.nan)
+ assert_raises(ValueError, zipf, [0, 0, np.nan])
+
def test_geometric(self):
p = [0.5]