diff options
author | Stephan Hoyer <shoyer@climate.com> | 2015-01-14 23:41:30 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@climate.com> | 2015-02-26 18:15:44 -0800 |
commit | 05b5335ecf25e59477956b4f85b9a8edbdf71bcc (patch) | |
tree | 9f336ce41f3992cb955dbd5d0412015c9134c5c9 /numpy/lib/tests/test_stride_tricks.py | |
parent | 2e016ac65aceab4e08217794d6be7b365793976a (diff) | |
download | numpy-05b5335ecf25e59477956b4f85b9a8edbdf71bcc.tar.gz |
ENH: add broadcast_to function
Per the mailing list discussion [1], I have implemented a new function
`broadcast_to` that broadcasts an array to a given shape according to
numpy's broadcasting rules.
[1] http://mail.scipy.org/pipermail/numpy-discussion/2014-December/071796.html
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r-- | numpy/lib/tests/test_stride_tricks.py | 90 |
1 files changed, 88 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py index bc7e30ca4..0b73109bc 100644 --- a/numpy/lib/tests/test_stride_tricks.py +++ b/numpy/lib/tests/test_stride_tricks.py @@ -5,8 +5,9 @@ from numpy.testing import ( run_module_suite, assert_equal, assert_array_equal, assert_raises, assert_ ) -from numpy.lib.stride_tricks import as_strided, broadcast_arrays - +from numpy.lib.stride_tricks import ( + as_strided, broadcast_arrays, _broadcast_shape, broadcast_to +) def assert_shapes_correct(input_shapes, expected_shape): # Broadcast a list of arrays with the given input shapes and check the @@ -217,6 +218,62 @@ def test_same_as_ufunc(): assert_same_as_ufunc(input_shapes[0], input_shapes[1], False, True) assert_same_as_ufunc(input_shapes[0], input_shapes[1], True, True) + +def test_broadcast_to_succeeds(): + data = [ + [np.array(0), (0,), np.array(0)], + [np.array(0), (1,), np.zeros(1)], + [np.array(0), (3,), np.zeros(3)], + [np.ones(1), (1,), np.ones(1)], + [np.ones(1), (2,), np.ones(2)], + [np.ones(1), (1, 2, 3), np.ones((1, 2, 3))], + [np.arange(3), (3,), np.arange(3)], + [np.arange(3), (1, 3), np.arange(3).reshape(1, -1)], + [np.arange(3), (2, 3), np.array([[0, 1, 2], [0, 1, 2]])], + # test if shape is not a tuple + [np.ones(0), 0, np.ones(0)], + [np.ones(1), 1, np.ones(1)], + [np.ones(1), 2, np.ones(2)], + # these cases with size 0 are strange, but they reproduce the behavior + # of broadcasting with ufuncs (see test_same_as_ufunc above) + [np.ones(1), (0,), np.ones(0)], + [np.ones((1, 2)), (0, 2), np.ones((0, 2))], + [np.ones((2, 1)), (2, 0), np.ones((2, 0))], + ] + for input_array, shape, expected in data: + actual = broadcast_to(input_array, shape) + assert_array_equal(expected, actual) + + +def test_broadcast_to_raises(): + data = [ + [(0,), ()], + [(1,), ()], + [(3,), ()], + [(3,), (1,)], + [(3,), (2,)], + [(3,), (4,)], + [(1, 2), (2, 1)], + [(1, 1), (1,)], + [(1,), -1], + [(1,), (-1,)], + [(1, 2), (-1, 2)], + ] + for orig_shape, target_shape in data: + arr = np.zeros(orig_shape) + assert_raises(ValueError, lambda: broadcast_to(arr, target_shape)) + + +def test_broadcast_shape(): + # broadcast_shape is already exercized indirectly by broadcast_arrays + assert_raises(ValueError, _broadcast_shape) + assert_equal(_broadcast_shape([1, 2]), (2,)) + assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1)) + assert_equal(_broadcast_shape(np.ones((1, 1)), np.ones((3, 4))), (3, 4)) + assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 32)), (1, 2)) + assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 100)), (1, 2)) + + def test_as_strided(): a = np.array([None]) a_view = as_strided(a) @@ -277,6 +334,35 @@ def test_subclasses(): assert_(type(b_view) is np.ndarray) assert_(a_view.shape == b_view.shape) + # and for broadcast_to + shape = (2, 4) + a_view = broadcast_to(a, shape) + assert_(type(a_view) is np.ndarray) + assert_(a_view.shape == shape) + a_view = broadcast_to(a, shape, subok=True) + assert_(type(a_view) is SimpleSubClass) + assert_(a_view.info == 'simple finalized') + assert_(a_view.shape == shape) + + +def test_writeable(): + # broadcast_to should return a readonly array + original = np.array([1, 2, 3]) + result = broadcast_to(original, (2, 3)) + assert_equal(result.flags.writeable, False) + assert_raises(ValueError, result.__setitem__, slice(None), 0) + + # but the result of broadcast_arrays needs to be writeable (for now), to + # preserve backwards compatibility + for results in [broadcast_arrays(original), + broadcast_arrays(0, original)]: + for result in results: + assert_equal(result.flags.writeable, True) + # keep readonly input readonly + original.flags.writeable = False + _, result = broadcast_arrays(0, original) + assert_equal(result.flags.writeable, False) + if __name__ == "__main__": run_module_suite() |