diff options
Diffstat (limited to 'numpy/random/tests/test_random.py')
-rw-r--r-- | numpy/random/tests/test_random.py | 87 |
1 files changed, 86 insertions, 1 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 7ec71e2e5..199509361 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -1345,7 +1345,6 @@ class TestBroadcast(TestCase): assert_raises(ValueError, logseries, bad_p_one * 3) assert_raises(ValueError, logseries, bad_p_two * 3) - class TestThread(TestCase): # make sure each state produces the same sequence even in threads def setUp(self): @@ -1388,5 +1387,91 @@ class TestThread(TestCase): out[...] = state.multinomial(10, [1/6.]*6, size=10000) self.check_function(gen_random, sz=(10000, 6)) +# See Issue #4263 +class TestSingleEltArrayInput(TestCase): + def setUp(self): + self.argOne = np.array([2]) + self.argTwo = np.array([3]) + self.argThree = np.array([4]) + self.tgtShape = (1,) + + def test_one_arg_funcs(self): + funcs = (np.random.exponential, np.random.standard_gamma, + np.random.chisquare, np.random.standard_t, + np.random.pareto, np.random.weibull, + np.random.power, np.random.rayleigh, + np.random.poisson, np.random.zipf, + np.random.geometric, np.random.logseries) + + probfuncs = (np.random.geometric, np.random.logseries) + + for func in funcs: + if func in probfuncs: # p < 1.0 + out = func(np.array([0.5])) + + else: + out = func(self.argOne) + + self.assertEqual(out.shape, self.tgtShape) + + def test_two_arg_funcs(self): + funcs = (np.random.uniform, np.random.normal, + np.random.beta, np.random.gamma, + np.random.f, np.random.noncentral_chisquare, + np.random.vonmises, np.random.laplace, + np.random.gumbel, np.random.logistic, + np.random.lognormal, np.random.wald, + np.random.binomial, np.random.negative_binomial) + + probfuncs = (np.random.binomial, np.random.negative_binomial) + + for func in funcs: + if func in probfuncs: # p <= 1 + argTwo = np.array([0.5]) + + else: + argTwo = self.argTwo + + out = func(self.argOne, argTwo) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne[0], argTwo) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne, argTwo[0]) + self.assertEqual(out.shape, self.tgtShape) + +# TODO: Uncomment once randint can broadcast arguments +# def test_randint(self): +# itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16, +# np.int32, np.uint32, np.int64, np.uint64] +# func = np.random.randint +# high = np.array([1]) +# low = np.array([0]) +# +# for dt in itype: +# out = func(low, high, dtype=dt) +# self.assert_equal(out.shape, self.tgtShape) +# +# out = func(low[0], high, dtype=dt) +# self.assert_equal(out.shape, self.tgtShape) +# +# out = func(low, high[0], dtype=dt) +# self.assert_equal(out.shape, self.tgtShape) + + def test_three_arg_funcs(self): + funcs = [np.random.noncentral_f, np.random.triangular, + np.random.hypergeometric] + + for func in funcs: + out = func(self.argOne, self.argTwo, self.argThree) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne[0], self.argTwo, self.argThree) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne, self.argTwo[0], self.argThree) + self.assertEqual(out.shape, self.tgtShape) + if __name__ == "__main__": run_module_suite() |