diff options
author | gfyoung <gfyoung@mit.edu> | 2016-01-18 20:29:34 +0000 |
---|---|---|
committer | gfyoung <gfyoung@mit.edu> | 2016-01-24 03:20:20 +0000 |
commit | 61f872265b67b313058a07533eaed88f4170ff2c (patch) | |
tree | 176f5c2f8fd052e0dc5853427390feb43dbe91e4 /numpy/random/tests/test_random.py | |
parent | 9849922aa4ace91906878df51053a32e2719a722 (diff) | |
download | numpy-61f872265b67b313058a07533eaed88f4170ff2c.tar.gz |
BUG: One element array inputs get one element arrays returned in np.random
Fixes bug in np.random methods that would return scalars
when passed one-element array inputs. This is because
one-element ndarrays can be cast to integers / floats, which
is what functions like PyFloat_AsDouble do before converting
to the intended data type.
This commit changes the check used to determine whether the
inputs are purely scalar by converting all inputs to arrays
and checking if the resulting shape is an empty tuple (scalar)
or not (array).
Closes gh-4263.
Diffstat (limited to 'numpy/random/tests/test_random.py')
-rw-r--r-- | numpy/random/tests/test_random.py | 87 |
1 files changed, 86 insertions, 1 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 7ec71e2e5..199509361 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -1345,7 +1345,6 @@ class TestBroadcast(TestCase): assert_raises(ValueError, logseries, bad_p_one * 3) assert_raises(ValueError, logseries, bad_p_two * 3) - class TestThread(TestCase): # make sure each state produces the same sequence even in threads def setUp(self): @@ -1388,5 +1387,91 @@ class TestThread(TestCase): out[...] = state.multinomial(10, [1/6.]*6, size=10000) self.check_function(gen_random, sz=(10000, 6)) +# See Issue #4263 +class TestSingleEltArrayInput(TestCase): + def setUp(self): + self.argOne = np.array([2]) + self.argTwo = np.array([3]) + self.argThree = np.array([4]) + self.tgtShape = (1,) + + def test_one_arg_funcs(self): + funcs = (np.random.exponential, np.random.standard_gamma, + np.random.chisquare, np.random.standard_t, + np.random.pareto, np.random.weibull, + np.random.power, np.random.rayleigh, + np.random.poisson, np.random.zipf, + np.random.geometric, np.random.logseries) + + probfuncs = (np.random.geometric, np.random.logseries) + + for func in funcs: + if func in probfuncs: # p < 1.0 + out = func(np.array([0.5])) + + else: + out = func(self.argOne) + + self.assertEqual(out.shape, self.tgtShape) + + def test_two_arg_funcs(self): + funcs = (np.random.uniform, np.random.normal, + np.random.beta, np.random.gamma, + np.random.f, np.random.noncentral_chisquare, + np.random.vonmises, np.random.laplace, + np.random.gumbel, np.random.logistic, + np.random.lognormal, np.random.wald, + np.random.binomial, np.random.negative_binomial) + + probfuncs = (np.random.binomial, np.random.negative_binomial) + + for func in funcs: + if func in probfuncs: # p <= 1 + argTwo = np.array([0.5]) + + else: + argTwo = self.argTwo + + out = func(self.argOne, argTwo) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne[0], argTwo) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne, argTwo[0]) + self.assertEqual(out.shape, self.tgtShape) + +# TODO: Uncomment once randint can broadcast arguments +# def test_randint(self): +# itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16, +# np.int32, np.uint32, np.int64, np.uint64] +# func = np.random.randint +# high = np.array([1]) +# low = np.array([0]) +# +# for dt in itype: +# out = func(low, high, dtype=dt) +# self.assert_equal(out.shape, self.tgtShape) +# +# out = func(low[0], high, dtype=dt) +# self.assert_equal(out.shape, self.tgtShape) +# +# out = func(low, high[0], dtype=dt) +# self.assert_equal(out.shape, self.tgtShape) + + def test_three_arg_funcs(self): + funcs = [np.random.noncentral_f, np.random.triangular, + np.random.hypergeometric] + + for func in funcs: + out = func(self.argOne, self.argTwo, self.argThree) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne[0], self.argTwo, self.argThree) + self.assertEqual(out.shape, self.tgtShape) + + out = func(self.argOne, self.argTwo[0], self.argThree) + self.assertEqual(out.shape, self.tgtShape) + if __name__ == "__main__": run_module_suite() |