summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_random.py
diff options
context:
space:
mode:
authorgfyoung <gfyoung@mit.edu>2016-01-18 20:29:34 +0000
committergfyoung <gfyoung@mit.edu>2016-01-24 03:20:20 +0000
commit61f872265b67b313058a07533eaed88f4170ff2c (patch)
tree176f5c2f8fd052e0dc5853427390feb43dbe91e4 /numpy/random/tests/test_random.py
parent9849922aa4ace91906878df51053a32e2719a722 (diff)
downloadnumpy-61f872265b67b313058a07533eaed88f4170ff2c.tar.gz
BUG: One element array inputs get one element arrays returned in np.random
Fixes bug in np.random methods that would return scalars when passed one-element array inputs. This is because one-element ndarrays can be cast to integers / floats, which is what functions like PyFloat_AsDouble do before converting to the intended data type. This commit changes the check used to determine whether the inputs are purely scalar by converting all inputs to arrays and checking if the resulting shape is an empty tuple (scalar) or not (array). Closes gh-4263.
Diffstat (limited to 'numpy/random/tests/test_random.py')
-rw-r--r--numpy/random/tests/test_random.py87
1 files changed, 86 insertions, 1 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 7ec71e2e5..199509361 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -1345,7 +1345,6 @@ class TestBroadcast(TestCase):
assert_raises(ValueError, logseries, bad_p_one * 3)
assert_raises(ValueError, logseries, bad_p_two * 3)
-
class TestThread(TestCase):
# make sure each state produces the same sequence even in threads
def setUp(self):
@@ -1388,5 +1387,91 @@ class TestThread(TestCase):
out[...] = state.multinomial(10, [1/6.]*6, size=10000)
self.check_function(gen_random, sz=(10000, 6))
+# See Issue #4263
+class TestSingleEltArrayInput(TestCase):
+ def setUp(self):
+ self.argOne = np.array([2])
+ self.argTwo = np.array([3])
+ self.argThree = np.array([4])
+ self.tgtShape = (1,)
+
+ def test_one_arg_funcs(self):
+ funcs = (np.random.exponential, np.random.standard_gamma,
+ np.random.chisquare, np.random.standard_t,
+ np.random.pareto, np.random.weibull,
+ np.random.power, np.random.rayleigh,
+ np.random.poisson, np.random.zipf,
+ np.random.geometric, np.random.logseries)
+
+ probfuncs = (np.random.geometric, np.random.logseries)
+
+ for func in funcs:
+ if func in probfuncs: # p < 1.0
+ out = func(np.array([0.5]))
+
+ else:
+ out = func(self.argOne)
+
+ self.assertEqual(out.shape, self.tgtShape)
+
+ def test_two_arg_funcs(self):
+ funcs = (np.random.uniform, np.random.normal,
+ np.random.beta, np.random.gamma,
+ np.random.f, np.random.noncentral_chisquare,
+ np.random.vonmises, np.random.laplace,
+ np.random.gumbel, np.random.logistic,
+ np.random.lognormal, np.random.wald,
+ np.random.binomial, np.random.negative_binomial)
+
+ probfuncs = (np.random.binomial, np.random.negative_binomial)
+
+ for func in funcs:
+ if func in probfuncs: # p <= 1
+ argTwo = np.array([0.5])
+
+ else:
+ argTwo = self.argTwo
+
+ out = func(self.argOne, argTwo)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne[0], argTwo)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne, argTwo[0])
+ self.assertEqual(out.shape, self.tgtShape)
+
+# TODO: Uncomment once randint can broadcast arguments
+# def test_randint(self):
+# itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16,
+# np.int32, np.uint32, np.int64, np.uint64]
+# func = np.random.randint
+# high = np.array([1])
+# low = np.array([0])
+#
+# for dt in itype:
+# out = func(low, high, dtype=dt)
+# self.assert_equal(out.shape, self.tgtShape)
+#
+# out = func(low[0], high, dtype=dt)
+# self.assert_equal(out.shape, self.tgtShape)
+#
+# out = func(low, high[0], dtype=dt)
+# self.assert_equal(out.shape, self.tgtShape)
+
+ def test_three_arg_funcs(self):
+ funcs = [np.random.noncentral_f, np.random.triangular,
+ np.random.hypergeometric]
+
+ for func in funcs:
+ out = func(self.argOne, self.argTwo, self.argThree)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne[0], self.argTwo, self.argThree)
+ self.assertEqual(out.shape, self.tgtShape)
+
+ out = func(self.argOne, self.argTwo[0], self.argThree)
+ self.assertEqual(out.shape, self.tgtShape)
+
if __name__ == "__main__":
run_module_suite()