summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_stride_tricks.py
diff options
context:
space:
mode:
authorRobert Kern <robert.kern@gmail.com>2008-07-03 06:42:28 +0000
committerRobert Kern <robert.kern@gmail.com>2008-07-03 06:42:28 +0000
commitf912322ee454fee1d15da01c5fd951ab2fcb5f99 (patch)
tree9ef9a29f9573775cb3785a78b6838d9f4c3cbf29 /numpy/lib/tests/test_stride_tricks.py
parenta74f0dfbcdfaf0ed5929fed7a27dc8738709828f (diff)
downloadnumpy-f912322ee454fee1d15da01c5fd951ab2fcb5f99.tar.gz
ENH: Add broadcast_arrays() function to expose broadcasting to pure Python functions that cannot be made to be ufuncs.
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r--numpy/lib/tests/test_stride_tricks.py206
1 files changed, 206 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
new file mode 100644
index 000000000..955a2cbc7
--- /dev/null
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -0,0 +1,206 @@
+from nose.tools import assert_raises
+import numpy as np
+from numpy.testing import assert_array_equal
+
+from numpy.lib.stride_tricks import broadcast_arrays
+
+
+def assert_shapes_correct(input_shapes, expected_shape):
+ """ Broadcast a list of arrays with the given input shapes and check the
+ common output shape.
+ """
+ inarrays = [np.zeros(s) for s in input_shapes]
+ outarrays = broadcast_arrays(*inarrays)
+ outshapes = [a.shape for a in outarrays]
+ expected = [expected_shape] * len(inarrays)
+ assert outshapes == expected
+
+def assert_incompatible_shapes_raise(input_shapes):
+ """ Broadcast a list of arrays with the given (incompatible) input shapes
+ and check that they raise a ValueError.
+ """
+ inarrays = [np.zeros(s) for s in input_shapes]
+ assert_raises(ValueError, broadcast_arrays, *inarrays)
+
+def assert_same_as_ufunc(shape0, shape1, transposed=False, flipped=False):
+ """ Broadcast two shapes against each other and check that the data layout
+ is the same as if a ufunc did the broadcasting.
+ """
+ x0 = np.zeros(shape0, dtype=int)
+ # Note that multiply.reduce's identity element is 1.0, so when shape1==(),
+ # this gives the desired n==1.
+ n = int(np.multiply.reduce(shape1))
+ x1 = np.arange(n).reshape(shape1)
+ if transposed:
+ x0 = x0.T
+ x1 = x1.T
+ if flipped:
+ x0 = x0[::-1]
+ x1 = x1[::-1]
+ # Use the add ufunc to do the broadcasting. Since we're adding 0s to x1, the
+ # result should be exactly the same as the broadcasted view of x1.
+ y = x0 + x1
+ b0, b1 = broadcast_arrays(x0, x1)
+ assert_array_equal(y, b1)
+
+
+def test_same():
+ x = np.arange(10)
+ y = np.arange(10)
+ bx, by = broadcast_arrays(x, y)
+ assert_array_equal(x, bx)
+ assert_array_equal(y, by)
+
+def test_one_off():
+ x = np.array([[1,2,3]])
+ y = np.array([[1],[2],[3]])
+ bx, by = broadcast_arrays(x, y)
+ bx0 = np.array([[1,2,3],[1,2,3],[1,2,3]])
+ by0 = bx0.T
+ assert_array_equal(bx0, bx)
+ assert_array_equal(by0, by)
+
+def test_same_input_shapes():
+ """ Check that the final shape is just the input shape.
+ """
+ data = [
+ (),
+ (1,),
+ (3,),
+ (0,1),
+ (0,3),
+ (1,0),
+ (3,0),
+ (1,3),
+ (3,1),
+ (3,3),
+ ]
+ for shape in data:
+ input_shapes = [shape]
+ # Single input.
+ yield assert_shapes_correct, input_shapes, shape
+ # Double input.
+ input_shapes2 = [shape, shape]
+ yield assert_shapes_correct, input_shapes2, shape
+ # Triple input.
+ input_shapes3 = [shape, shape, shape]
+ yield assert_shapes_correct, input_shapes3, shape
+
+def test_two_compatible_by_ones_input_shapes():
+ """ Check that two different input shapes (of the same length but some have
+ 1s) broadcast to the correct shape.
+ """
+ data = [
+ [[(1,), (3,)], (3,)],
+ [[(1,3), (3,3)], (3,3)],
+ [[(3,1), (3,3)], (3,3)],
+ [[(1,3), (3,1)], (3,3)],
+ [[(1,1), (3,3)], (3,3)],
+ [[(1,1), (1,3)], (1,3)],
+ [[(1,1), (3,1)], (3,1)],
+ [[(1,0), (0,0)], (0,0)],
+ [[(0,1), (0,0)], (0,0)],
+ [[(1,0), (0,1)], (0,0)],
+ [[(1,1), (0,0)], (0,0)],
+ [[(1,1), (1,0)], (1,0)],
+ [[(1,1), (0,1)], (0,1)],
+ ]
+ for input_shapes, expected_shape in data:
+ yield assert_shapes_correct, input_shapes, expected_shape
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_shapes_correct, input_shapes[::-1], expected_shape
+
+def test_two_compatible_by_prepending_ones_input_shapes():
+ """ Check that two different input shapes (of different lengths) broadcast
+ to the correct shape.
+ """
+ data = [
+ [[(), (3,)], (3,)],
+ [[(3,), (3,3)], (3,3)],
+ [[(3,), (3,1)], (3,3)],
+ [[(1,), (3,3)], (3,3)],
+ [[(), (3,3)], (3,3)],
+ [[(1,1), (3,)], (1,3)],
+ [[(1,), (3,1)], (3,1)],
+ [[(1,), (1,3)], (1,3)],
+ [[(), (1,3)], (1,3)],
+ [[(), (3,1)], (3,1)],
+ [[(), (0,)], (0,)],
+ [[(0,), (0,0)], (0,0)],
+ [[(0,), (0,1)], (0,0)],
+ [[(1,), (0,0)], (0,0)],
+ [[(), (0,0)], (0,0)],
+ [[(1,1), (0,)], (1,0)],
+ [[(1,), (0,1)], (0,1)],
+ [[(1,), (1,0)], (1,0)],
+ [[(), (1,0)], (1,0)],
+ [[(), (0,1)], (0,1)],
+ ]
+ for input_shapes, expected_shape in data:
+ yield assert_shapes_correct, input_shapes, expected_shape
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_shapes_correct, input_shapes[::-1], expected_shape
+
+def test_incompatible_shapes_raise_valueerror():
+ """ Check that a ValueError is raised for incompatible shapes.
+ """
+ data = [
+ [(3,), (4,)],
+ [(2,3), (2,)],
+ [(3,), (3,), (4,)],
+ [(1,3,4), (2,3,3)],
+ ]
+ for input_shapes in data:
+ yield assert_incompatible_shapes_raise, input_shapes
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_incompatible_shapes_raise, input_shapes[::-1]
+
+def test_same_as_ufunc():
+ """ Check that the data layout is the same as if a ufunc did the operation.
+ """
+ data = [
+ [[(1,), (3,)], (3,)],
+ [[(1,3), (3,3)], (3,3)],
+ [[(3,1), (3,3)], (3,3)],
+ [[(1,3), (3,1)], (3,3)],
+ [[(1,1), (3,3)], (3,3)],
+ [[(1,1), (1,3)], (1,3)],
+ [[(1,1), (3,1)], (3,1)],
+ [[(1,0), (0,0)], (0,0)],
+ [[(0,1), (0,0)], (0,0)],
+ [[(1,0), (0,1)], (0,0)],
+ [[(1,1), (0,0)], (0,0)],
+ [[(1,1), (1,0)], (1,0)],
+ [[(1,1), (0,1)], (0,1)],
+ [[(), (3,)], (3,)],
+ [[(3,), (3,3)], (3,3)],
+ [[(3,), (3,1)], (3,3)],
+ [[(1,), (3,3)], (3,3)],
+ [[(), (3,3)], (3,3)],
+ [[(1,1), (3,)], (1,3)],
+ [[(1,), (3,1)], (3,1)],
+ [[(1,), (1,3)], (1,3)],
+ [[(), (1,3)], (1,3)],
+ [[(), (3,1)], (3,1)],
+ [[(), (0,)], (0,)],
+ [[(0,), (0,0)], (0,0)],
+ [[(0,), (0,1)], (0,0)],
+ [[(1,), (0,0)], (0,0)],
+ [[(), (0,0)], (0,0)],
+ [[(1,1), (0,)], (1,0)],
+ [[(1,), (0,1)], (0,1)],
+ [[(1,), (1,0)], (1,0)],
+ [[(), (1,0)], (1,0)],
+ [[(), (0,1)], (0,1)],
+ ]
+ for input_shapes, expected_shape in data:
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1]
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_same_as_ufunc, input_shapes[1], input_shapes[0]
+ # Try them transposed, too.
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], True
+ # ... and flipped for non-rank-0 inputs in order to test negative
+ # strides.
+ if () not in input_shapes:
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], False, True
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], True, True