diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2018-06-13 01:45:21 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-06-13 01:45:59 -0700 |
commit | e6763ed907de05640fe0583333dfb52b68f2d902 (patch) | |
tree | 446d39ce7a2ad625fdbd0673e87d184765693abe /numpy/random | |
parent | 3d9a082d146195944071a81ed0eae9d16976961a (diff) | |
download | numpy-e6763ed907de05640fe0583333dfb52b68f2d902.tar.gz |
MAINT: Don't use dtype strings when the dtypes themselves can be used
Diffstat (limited to 'numpy/random')
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 51 |
1 files changed, 25 insertions, 26 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index b45b3146f..43b1060a0 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -573,21 +573,21 @@ def _shape_from_size(size, d): shape = tuple(size) + (d,) return shape -# Look up table for randint functions keyed by type name. The stored data -# is a tuple (lbnd, ubnd, func), where lbnd is the smallest value for the -# type, ubnd is one greater than the largest value, and func is the +# Look up table for randint functions keyed by dtype. +# The stored data is a tuple (lbnd, ubnd, func), where lbnd is the smallest +# value for the type, ubnd is one greater than the largest value, and func is the # function to call. _randint_type = { - 'bool': (0, 2, _rand_bool), - 'int8': (-2**7, 2**7, _rand_int8), - 'int16': (-2**15, 2**15, _rand_int16), - 'int32': (-2**31, 2**31, _rand_int32), - 'int64': (-2**63, 2**63, _rand_int64), - 'uint8': (0, 2**8, _rand_uint8), - 'uint16': (0, 2**16, _rand_uint16), - 'uint32': (0, 2**32, _rand_uint32), - 'uint64': (0, 2**64, _rand_uint64) - } + np.dtype(np.bool_): (0, 2, _rand_bool), + np.dtype(np.int8): (-2**7, 2**7, _rand_int8), + np.dtype(np.int16): (-2**15, 2**15, _rand_int16), + np.dtype(np.int32): (-2**31, 2**31, _rand_int32), + np.dtype(np.int64): (-2**63, 2**63, _rand_int64), + np.dtype(np.uint8): (0, 2**8, _rand_uint8), + np.dtype(np.uint16): (0, 2**16, _rand_uint16), + np.dtype(np.uint32): (0, 2**32, _rand_uint32), + np.dtype(np.uint64): (0, 2**64, _rand_uint64) +} cdef class RandomState: @@ -969,13 +969,12 @@ cdef class RandomState: high = low low = 0 - # '_randint_type' is defined in - # 'generate_randint_helpers.py' - key = np.dtype(dtype).name - if key not in _randint_type: - raise TypeError('Unsupported dtype "%s" for randint' % key) - - lowbnd, highbnd, randfunc = _randint_type[key] + raw_dtype = dtype + dtype = np.dtype(dtype) + try: + lowbnd, highbnd, randfunc = _randint_type[dtype] + except KeyError: + raise TypeError('Unsupported dtype "%s" for randint' % dtype) # TODO: Do not cast these inputs to Python int # @@ -986,20 +985,20 @@ cdef class RandomState: ihigh = int(high) if ilow < lowbnd: - raise ValueError("low is out of bounds for %s" % (key,)) + raise ValueError("low is out of bounds for %s" % dtype) if ihigh > highbnd: - raise ValueError("high is out of bounds for %s" % (key,)) + raise ValueError("high is out of bounds for %s" % dtype) if ilow >= ihigh: raise ValueError("low >= high") with self.lock: ret = randfunc(ilow, ihigh - 1, size, self.state_address) - if size is None: - if dtype in (np.bool, np.int, np.long): - return dtype(ret) + # back-compat: keep python scalars when a python type is passed + if size is None and raw_dtype in (bool, int, np.long): + return raw_dtype(ret) - return ret + return ret def bytes(self, npy_intp length): """ |