diff options
| -rw-r--r-- | numpy/random/tests/test_random.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 773b63653..ecda5dab5 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -1712,23 +1712,22 @@ class TestSingleEltArrayInput: out = func(self.argOne, argTwo[0]) assert_equal(out.shape, self.tgtShape) -# TODO: Uncomment once randint can broadcast arguments -# def test_randint(self): -# itype = [bool, np.int8, np.uint8, np.int16, np.uint16, -# np.int32, np.uint32, np.int64, np.uint64] -# func = np.random.randint -# high = np.array([1]) -# low = np.array([0]) -# -# for dt in itype: -# out = func(low, high, dtype=dt) -# self.assert_equal(out.shape, self.tgtShape) -# -# out = func(low[0], high, dtype=dt) -# self.assert_equal(out.shape, self.tgtShape) -# -# out = func(low, high[0], dtype=dt) -# self.assert_equal(out.shape, self.tgtShape) + def test_randint(self): + itype = [bool, np.int8, np.uint8, np.int16, np.uint16, + np.int32, np.uint32, np.int64, np.uint64] + func = np.random.randint + high = np.array([1]) + low = np.array([0]) + + for dt in itype: + out = func(low, high, dtype=dt) + assert_equal(out.shape, self.tgtShape) + + out = func(low[0], high, dtype=dt) + assert_equal(out.shape, self.tgtShape) + + out = func(low, high[0], dtype=dt) + assert_equal(out.shape, self.tgtShape) def test_three_arg_funcs(self): funcs = [np.random.noncentral_f, np.random.triangular, |
