summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-06-13 01:45:21 -0700
committerEric Wieser <wieser.eric@gmail.com>2018-06-13 01:45:59 -0700
commite6763ed907de05640fe0583333dfb52b68f2d902 (patch)
tree446d39ce7a2ad625fdbd0673e87d184765693abe /numpy/random
parent3d9a082d146195944071a81ed0eae9d16976961a (diff)
downloadnumpy-e6763ed907de05640fe0583333dfb52b68f2d902.tar.gz
MAINT: Don't use dtype strings when the dtypes themselves can be used
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/mtrand/mtrand.pyx51
1 files changed, 25 insertions, 26 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index b45b3146f..43b1060a0 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -573,21 +573,21 @@ def _shape_from_size(size, d):
shape = tuple(size) + (d,)
return shape
-# Look up table for randint functions keyed by type name. The stored data
-# is a tuple (lbnd, ubnd, func), where lbnd is the smallest value for the
-# type, ubnd is one greater than the largest value, and func is the
+# Look up table for randint functions keyed by dtype.
+# The stored data is a tuple (lbnd, ubnd, func), where lbnd is the smallest
+# value for the type, ubnd is one greater than the largest value, and func is the
# function to call.
_randint_type = {
- 'bool': (0, 2, _rand_bool),
- 'int8': (-2**7, 2**7, _rand_int8),
- 'int16': (-2**15, 2**15, _rand_int16),
- 'int32': (-2**31, 2**31, _rand_int32),
- 'int64': (-2**63, 2**63, _rand_int64),
- 'uint8': (0, 2**8, _rand_uint8),
- 'uint16': (0, 2**16, _rand_uint16),
- 'uint32': (0, 2**32, _rand_uint32),
- 'uint64': (0, 2**64, _rand_uint64)
- }
+ np.dtype(np.bool_): (0, 2, _rand_bool),
+ np.dtype(np.int8): (-2**7, 2**7, _rand_int8),
+ np.dtype(np.int16): (-2**15, 2**15, _rand_int16),
+ np.dtype(np.int32): (-2**31, 2**31, _rand_int32),
+ np.dtype(np.int64): (-2**63, 2**63, _rand_int64),
+ np.dtype(np.uint8): (0, 2**8, _rand_uint8),
+ np.dtype(np.uint16): (0, 2**16, _rand_uint16),
+ np.dtype(np.uint32): (0, 2**32, _rand_uint32),
+ np.dtype(np.uint64): (0, 2**64, _rand_uint64)
+}
cdef class RandomState:
@@ -969,13 +969,12 @@ cdef class RandomState:
high = low
low = 0
- # '_randint_type' is defined in
- # 'generate_randint_helpers.py'
- key = np.dtype(dtype).name
- if key not in _randint_type:
- raise TypeError('Unsupported dtype "%s" for randint' % key)
-
- lowbnd, highbnd, randfunc = _randint_type[key]
+ raw_dtype = dtype
+ dtype = np.dtype(dtype)
+ try:
+ lowbnd, highbnd, randfunc = _randint_type[dtype]
+ except KeyError:
+ raise TypeError('Unsupported dtype "%s" for randint' % dtype)
# TODO: Do not cast these inputs to Python int
#
@@ -986,20 +985,20 @@ cdef class RandomState:
ihigh = int(high)
if ilow < lowbnd:
- raise ValueError("low is out of bounds for %s" % (key,))
+ raise ValueError("low is out of bounds for %s" % dtype)
if ihigh > highbnd:
- raise ValueError("high is out of bounds for %s" % (key,))
+ raise ValueError("high is out of bounds for %s" % dtype)
if ilow >= ihigh:
raise ValueError("low >= high")
with self.lock:
ret = randfunc(ilow, ihigh - 1, size, self.state_address)
- if size is None:
- if dtype in (np.bool, np.int, np.long):
- return dtype(ret)
+ # back-compat: keep python scalars when a python type is passed
+ if size is None and raw_dtype in (bool, int, np.long):
+ return raw_dtype(ret)
- return ret
+ return ret
def bytes(self, npy_intp length):
"""