diff options
author | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-17 18:56:35 -0700 |
---|---|---|
committer | Jaime Fernandez <jaime.frio@gmail.com> | 2015-05-17 18:56:35 -0700 |
commit | 0c00f6910db141b6d514ded9a98857464a075838 (patch) | |
tree | fe039b7f5e0c227cf2e442f6f9fddf3c2a5dac7c | |
parent | c5cf20ab673c76f4b3ef7f85b4d1fd4149237772 (diff) | |
download | numpy-0c00f6910db141b6d514ded9a98857464a075838.tar.gz |
TST: np.broadcast should accept itself as input
-rw-r--r-- | numpy/core/tests/test_numeric.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index ee304a7af..7400366ac 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2226,6 +2226,7 @@ class TestCross(TestCase): for axisc in range(-2, 2): assert_equal(np.cross(u, u, axisc=axisc).shape, (3, 4)) + def test_outer_out_param(): arr1 = np.ones((5,)) arr2 = np.ones((2,)) @@ -2236,6 +2237,7 @@ def test_outer_out_param(): assert_equal(res1, out1) assert_equal(np.outer(arr2, arr3, out2), out2) + class TestRequire(object): flag_names = ['C', 'C_CONTIGUOUS', 'CONTIGUOUS', 'F', 'F_CONTIGUOUS', 'FORTRAN', @@ -2310,5 +2312,31 @@ class TestRequire(object): yield self.set_and_check_flag, flag, None, a +class TestBroadcast(TestCase): + def test_broadcast_in_args(self): + # gh-5881 + arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)), + np.empty((5, 1, 7))] + mits = [np.broadcast(*arrs), + np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])), + np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])] + for mit in mits: + assert_equal(mit.shape, (5, 6, 7)) + assert_equal(mit.nd, 3) + assert_equal(mit.numiter, 4) + for a, ia in zip(arrs, mit.iters): + assert_(a is ia.base) + + def test_number_of_arguments(self): + arr = np.empty((5,)) + for j in range(35): + arrs = [arr] * j + if j < 2 or j > 32: + assert_raises(ValueError, np.broadcast, *arrs) + else: + mit = np.broadcast(*arrs) + assert_equal(mit.numiter, j) + + if __name__ == "__main__": run_module_suite() |