summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/random/tests/test_random.py33
1 files changed, 16 insertions, 17 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 773b63653..ecda5dab5 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -1712,23 +1712,22 @@ class TestSingleEltArrayInput:
out = func(self.argOne, argTwo[0])
assert_equal(out.shape, self.tgtShape)
-# TODO: Uncomment once randint can broadcast arguments
-# def test_randint(self):
-# itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
-# np.int32, np.uint32, np.int64, np.uint64]
-# func = np.random.randint
-# high = np.array([1])
-# low = np.array([0])
-#
-# for dt in itype:
-# out = func(low, high, dtype=dt)
-# self.assert_equal(out.shape, self.tgtShape)
-#
-# out = func(low[0], high, dtype=dt)
-# self.assert_equal(out.shape, self.tgtShape)
-#
-# out = func(low, high[0], dtype=dt)
-# self.assert_equal(out.shape, self.tgtShape)
+ def test_randint(self):
+ itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
+ np.int32, np.uint32, np.int64, np.uint64]
+ func = np.random.randint
+ high = np.array([1])
+ low = np.array([0])
+
+ for dt in itype:
+ out = func(low, high, dtype=dt)
+ assert_equal(out.shape, self.tgtShape)
+
+ out = func(low[0], high, dtype=dt)
+ assert_equal(out.shape, self.tgtShape)
+
+ out = func(low, high[0], dtype=dt)
+ assert_equal(out.shape, self.tgtShape)
def test_three_arg_funcs(self):
funcs = [np.random.noncentral_f, np.random.triangular,