summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_stride_tricks.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-02-26 21:45:17 -0500
committerCharles Harris <charlesr.harris@gmail.com>2015-02-26 21:45:17 -0500
commit06af9918f6bf03b8d818ec834f9fb48db57d1489 (patch)
tree9b8c9ded2271248f1ac139276675b3cc96cfae93 /numpy/lib/tests/test_stride_tricks.py
parent32e23a1d52a05d3a56f693010eaf8d96826db75f (diff)
parent05b5335ecf25e59477956b4f85b9a8edbdf71bcc (diff)
downloadnumpy-06af9918f6bf03b8d818ec834f9fb48db57d1489.tar.gz
Merge pull request #5371 from shoyer/broadcast_to
ENH: add np.broadcast_to and reimplement np.broadcast_arrays
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r--numpy/lib/tests/test_stride_tricks.py90
1 files changed, 88 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index bc7e30ca4..0b73109bc 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -5,8 +5,9 @@ from numpy.testing import (
run_module_suite, assert_equal, assert_array_equal,
assert_raises, assert_
)
-from numpy.lib.stride_tricks import as_strided, broadcast_arrays
-
+from numpy.lib.stride_tricks import (
+ as_strided, broadcast_arrays, _broadcast_shape, broadcast_to
+)
def assert_shapes_correct(input_shapes, expected_shape):
# Broadcast a list of arrays with the given input shapes and check the
@@ -217,6 +218,62 @@ def test_same_as_ufunc():
assert_same_as_ufunc(input_shapes[0], input_shapes[1], False, True)
assert_same_as_ufunc(input_shapes[0], input_shapes[1], True, True)
+
+def test_broadcast_to_succeeds():
+ data = [
+ [np.array(0), (0,), np.array(0)],
+ [np.array(0), (1,), np.zeros(1)],
+ [np.array(0), (3,), np.zeros(3)],
+ [np.ones(1), (1,), np.ones(1)],
+ [np.ones(1), (2,), np.ones(2)],
+ [np.ones(1), (1, 2, 3), np.ones((1, 2, 3))],
+ [np.arange(3), (3,), np.arange(3)],
+ [np.arange(3), (1, 3), np.arange(3).reshape(1, -1)],
+ [np.arange(3), (2, 3), np.array([[0, 1, 2], [0, 1, 2]])],
+ # test if shape is not a tuple
+ [np.ones(0), 0, np.ones(0)],
+ [np.ones(1), 1, np.ones(1)],
+ [np.ones(1), 2, np.ones(2)],
+ # these cases with size 0 are strange, but they reproduce the behavior
+ # of broadcasting with ufuncs (see test_same_as_ufunc above)
+ [np.ones(1), (0,), np.ones(0)],
+ [np.ones((1, 2)), (0, 2), np.ones((0, 2))],
+ [np.ones((2, 1)), (2, 0), np.ones((2, 0))],
+ ]
+ for input_array, shape, expected in data:
+ actual = broadcast_to(input_array, shape)
+ assert_array_equal(expected, actual)
+
+
+def test_broadcast_to_raises():
+ data = [
+ [(0,), ()],
+ [(1,), ()],
+ [(3,), ()],
+ [(3,), (1,)],
+ [(3,), (2,)],
+ [(3,), (4,)],
+ [(1, 2), (2, 1)],
+ [(1, 1), (1,)],
+ [(1,), -1],
+ [(1,), (-1,)],
+ [(1, 2), (-1, 2)],
+ ]
+ for orig_shape, target_shape in data:
+ arr = np.zeros(orig_shape)
+ assert_raises(ValueError, lambda: broadcast_to(arr, target_shape))
+
+
+def test_broadcast_shape():
+ # broadcast_shape is already exercized indirectly by broadcast_arrays
+ assert_raises(ValueError, _broadcast_shape)
+ assert_equal(_broadcast_shape([1, 2]), (2,))
+ assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1))
+ assert_equal(_broadcast_shape(np.ones((1, 1)), np.ones((3, 4))), (3, 4))
+ assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 32)), (1, 2))
+ assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 100)), (1, 2))
+
+
def test_as_strided():
a = np.array([None])
a_view = as_strided(a)
@@ -277,6 +334,35 @@ def test_subclasses():
assert_(type(b_view) is np.ndarray)
assert_(a_view.shape == b_view.shape)
+ # and for broadcast_to
+ shape = (2, 4)
+ a_view = broadcast_to(a, shape)
+ assert_(type(a_view) is np.ndarray)
+ assert_(a_view.shape == shape)
+ a_view = broadcast_to(a, shape, subok=True)
+ assert_(type(a_view) is SimpleSubClass)
+ assert_(a_view.info == 'simple finalized')
+ assert_(a_view.shape == shape)
+
+
+def test_writeable():
+ # broadcast_to should return a readonly array
+ original = np.array([1, 2, 3])
+ result = broadcast_to(original, (2, 3))
+ assert_equal(result.flags.writeable, False)
+ assert_raises(ValueError, result.__setitem__, slice(None), 0)
+
+ # but the result of broadcast_arrays needs to be writeable (for now), to
+ # preserve backwards compatibility
+ for results in [broadcast_arrays(original),
+ broadcast_arrays(0, original)]:
+ for result in results:
+ assert_equal(result.flags.writeable, True)
+ # keep readonly input readonly
+ original.flags.writeable = False
+ _, result = broadcast_arrays(0, original)
+ assert_equal(result.flags.writeable, False)
+
if __name__ == "__main__":
run_module_suite()