summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorKevin Sheppard <kevin.k.sheppard@gmail.com>2019-04-12 17:06:30 +0100
committermattip <matti.picus@gmail.com>2019-05-20 18:45:27 +0300
commitf11921d6f2578a1cc74dc15bd1458c62d180c01f (patch)
treea678b8de077082b5f141d0adf090ccba64595e3e /numpy
parentbb7abf2d5bff82a48dbd773c31f66164abdd849d (diff)
downloadnumpy-f11921d6f2578a1cc74dc15bd1458c62d180c01f.tar.gz
MAINT: Simplify return types
Standardize returns types for Windows and 32-bit platforms on int64 in choice and randint (default). Refactor tomaxint to call randint
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/generator.pyx27
-rw-r--r--numpy/random/tests/test_against_numpy.py1
-rw-r--r--numpy/random/tests/test_generator_mt19937.py12
3 files changed, 19 insertions, 21 deletions
diff --git a/numpy/random/generator.pyx b/numpy/random/generator.pyx
index 29a36787e..6e5f421af 100644
--- a/numpy/random/generator.pyx
+++ b/numpy/random/generator.pyx
@@ -368,26 +368,11 @@ cdef class RandomGenerator:
[ True, True]]])
"""
- cdef np.npy_intp n
- cdef np.ndarray randoms
- cdef int64_t *randoms_data
-
- if size is None:
- with self.lock:
- return random_positive_int(self._brng)
-
- randoms = <np.ndarray>np.empty(size, dtype=np.int64)
- randoms_data = <int64_t*>np.PyArray_DATA(randoms)
- n = np.PyArray_SIZE(randoms)
-
- for i in range(n):
- with self.lock, nogil:
- randoms_data[i] = random_positive_int(self._brng)
- return randoms
+ return self.randint(0, np.iinfo(np.int).max + 1, dtype=np.int, size=size)
- def randint(self, low, high=None, size=None, dtype=int, use_masked=True):
+ def randint(self, low, high=None, size=None, dtype=np.int64, use_masked=True):
"""
- randint(low, high=None, size=None, dtype='l', use_masked=True)
+ randint(low, high=None, size=None, dtype='int64', use_masked=True)
Return random integers from `low` (inclusive) to `high` (exclusive).
@@ -661,9 +646,9 @@ cdef class RandomGenerator:
cdf /= cdf[-1]
uniform_samples = self.random_sample(shape)
idx = cdf.searchsorted(uniform_samples, side='right')
- idx = np.array(idx, copy=False) # searchsorted returns a scalar
+ idx = np.array(idx, copy=False, dtype=np.int64) # searchsorted returns a scalar
else:
- idx = self.randint(0, pop_size, size=shape)
+ idx = self.randint(0, pop_size, size=shape, dtype=np.int64)
else:
if size > pop_size:
raise ValueError("Cannot take a larger sample than "
@@ -692,7 +677,7 @@ cdef class RandomGenerator:
n_uniq += new.size
idx = found
else:
- idx = self.permutation(pop_size)[:size]
+ idx = (self.permutation(pop_size)[:size]).astype(np.int64)
if shape is not None:
idx.shape = shape
diff --git a/numpy/random/tests/test_against_numpy.py b/numpy/random/tests/test_against_numpy.py
index b930fcbff..9846ba38f 100644
--- a/numpy/random/tests/test_against_numpy.py
+++ b/numpy/random/tests/test_against_numpy.py
@@ -183,6 +183,7 @@ class TestAgainstNumPy(object):
self.rs.standard_exponential)
self._is_state_common_legacy()
+ @pytest.mark.xfail(reason='Stream broken for simplicity')
def test_tomaxint(self):
self._set_common_state()
self._is_state_common()
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 895e7fc6c..d76291d1a 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -639,6 +639,18 @@ class TestRandomDist(object):
p = [None, None, None]
assert_raises(ValueError, random.choice, a, p=p)
+ def test_choice_return_type(self):
+ # gh 9867
+ p = np.ones(4) / 4.
+ actual = random.choice(4, 2)
+ assert actual.dtype == np.int64
+ actual = random.choice(4, 2, replace=False)
+ assert actual.dtype == np.int64
+ actual = random.choice(4, 2, p=p)
+ assert actual.dtype == np.int64
+ actual = random.choice(4, 2, p=p, replace=False)
+ assert actual.dtype == np.int64
+
def test_bytes(self):
random.brng.seed(self.seed)
actual = random.bytes(10)