summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_random.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/random/tests/test_random.py')
-rw-r--r--numpy/random/tests/test_random.py87
1 files changed, 86 insertions, 1 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 7ec71e2e5..199509361 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -1345,7 +1345,6 @@ class TestBroadcast(TestCase):
assert_raises(ValueError, logseries, bad_p_one * 3)
assert_raises(ValueError, logseries, bad_p_two * 3)
-
class TestThread(TestCase):
# make sure each state produces the same sequence even in threads
def setUp(self):
@@ -1388,5 +1387,91 @@ class TestThread(TestCase):
out[...] = state.multinomial(10, [1/6.]*6, size=10000)
self.check_function(gen_random, sz=(10000, 6))
+# See Issue #4263
+class TestSingleEltArrayInput(TestCase):
+ def setUp(self):
+ self.argOne = np.array([2])
+ self.argTwo = np.array([3])
+ self.argThree = np.array([4])
+ self.tgtShape = (1,)
+
+ def test_one_arg_funcs(self):
+ funcs = (np.random.exponential, np.random.standard_gamma,
+ np.random.chisquare, np.random.standard_t,
+ np.random.pareto, np.random.weibull,
+ np.random.power, np.random.rayleigh,
+ np.random.poisson, np.random.zipf,
+ np.random.geometric, np.random.logseries)
+
+ probfuncs = (np.random.geometric, np.random.logseries)
+
+ for func in funcs:
+ if func in probfuncs: # p < 1.0
+ out = func(np.array([0.5]))
+
+ else:
+ out = func(self.argOne)
+
+ self.assertEqual(out.shape, self.tgtShape)
+
+ def test_two_arg_funcs(self):
+ funcs = (np.random.uniform, np.random.normal,
+ np.random.beta, np.random.gamma,
+ np.random.f, np.random.noncentral_chisquare,
+ np.random.vonmises, np.random.laplace,
+ np.random.gumbel, np.random.logistic,
+ np.random.lognormal, np.random.wald,
+ np.random.binomial, np.random.negative_binomial)
+
+ probfuncs = (np.random.binomial, np.random.negative_binomial)
+
+ for func in funcs:
+ if func in probfuncs: # p <= 1
+ argTwo = np.array([0.5])
+
+ else:
+ argTwo = self.argTwo
+
+ out = func(self.argOne, argTwo)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne[0], argTwo)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne, argTwo[0])
+ self.assertEqual(out.shape, self.tgtShape)
+
+# TODO: Uncomment once randint can broadcast arguments
+# def test_randint(self):
+# itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16,
+# np.int32, np.uint32, np.int64, np.uint64]
+# func = np.random.randint
+# high = np.array([1])
+# low = np.array([0])
+#
+# for dt in itype:
+# out = func(low, high, dtype=dt)
+# self.assert_equal(out.shape, self.tgtShape)
+#
+# out = func(low[0], high, dtype=dt)
+# self.assert_equal(out.shape, self.tgtShape)
+#
+# out = func(low, high[0], dtype=dt)
+# self.assert_equal(out.shape, self.tgtShape)
+
+ def test_three_arg_funcs(self):
+ funcs = [np.random.noncentral_f, np.random.triangular,
+ np.random.hypergeometric]
+
+ for func in funcs:
+ out = func(self.argOne, self.argTwo, self.argThree)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne[0], self.argTwo, self.argThree)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne, self.argTwo[0], self.argThree)
+ self.assertEqual(out.shape, self.tgtShape)
+
if __name__ == "__main__":
run_module_suite()