diff options
author | Robert Kern <robert.kern@gmail.com> | 2008-07-03 06:42:28 +0000 |
---|---|---|
committer | Robert Kern <robert.kern@gmail.com> | 2008-07-03 06:42:28 +0000 |
commit | f912322ee454fee1d15da01c5fd951ab2fcb5f99 (patch) | |
tree | 9ef9a29f9573775cb3785a78b6838d9f4c3cbf29 /numpy/lib/tests/test_stride_tricks.py | |
parent | a74f0dfbcdfaf0ed5929fed7a27dc8738709828f (diff) | |
download | numpy-f912322ee454fee1d15da01c5fd951ab2fcb5f99.tar.gz |
ENH: Add broadcast_arrays() function to expose broadcasting to pure Python functions that cannot be made to be ufuncs.
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r-- | numpy/lib/tests/test_stride_tricks.py | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py new file mode 100644 index 000000000..955a2cbc7 --- /dev/null +++ b/numpy/lib/tests/test_stride_tricks.py @@ -0,0 +1,206 @@ +from nose.tools import assert_raises +import numpy as np +from numpy.testing import assert_array_equal + +from numpy.lib.stride_tricks import broadcast_arrays + + +def assert_shapes_correct(input_shapes, expected_shape): + """ Broadcast a list of arrays with the given input shapes and check the + common output shape. + """ + inarrays = [np.zeros(s) for s in input_shapes] + outarrays = broadcast_arrays(*inarrays) + outshapes = [a.shape for a in outarrays] + expected = [expected_shape] * len(inarrays) + assert outshapes == expected + +def assert_incompatible_shapes_raise(input_shapes): + """ Broadcast a list of arrays with the given (incompatible) input shapes + and check that they raise a ValueError. + """ + inarrays = [np.zeros(s) for s in input_shapes] + assert_raises(ValueError, broadcast_arrays, *inarrays) + +def assert_same_as_ufunc(shape0, shape1, transposed=False, flipped=False): + """ Broadcast two shapes against each other and check that the data layout + is the same as if a ufunc did the broadcasting. + """ + x0 = np.zeros(shape0, dtype=int) + # Note that multiply.reduce's identity element is 1.0, so when shape1==(), + # this gives the desired n==1. + n = int(np.multiply.reduce(shape1)) + x1 = np.arange(n).reshape(shape1) + if transposed: + x0 = x0.T + x1 = x1.T + if flipped: + x0 = x0[::-1] + x1 = x1[::-1] + # Use the add ufunc to do the broadcasting. Since we're adding 0s to x1, the + # result should be exactly the same as the broadcasted view of x1. + y = x0 + x1 + b0, b1 = broadcast_arrays(x0, x1) + assert_array_equal(y, b1) + + +def test_same(): + x = np.arange(10) + y = np.arange(10) + bx, by = broadcast_arrays(x, y) + assert_array_equal(x, bx) + assert_array_equal(y, by) + +def test_one_off(): + x = np.array([[1,2,3]]) + y = np.array([[1],[2],[3]]) + bx, by = broadcast_arrays(x, y) + bx0 = np.array([[1,2,3],[1,2,3],[1,2,3]]) + by0 = bx0.T + assert_array_equal(bx0, bx) + assert_array_equal(by0, by) + +def test_same_input_shapes(): + """ Check that the final shape is just the input shape. + """ + data = [ + (), + (1,), + (3,), + (0,1), + (0,3), + (1,0), + (3,0), + (1,3), + (3,1), + (3,3), + ] + for shape in data: + input_shapes = [shape] + # Single input. + yield assert_shapes_correct, input_shapes, shape + # Double input. + input_shapes2 = [shape, shape] + yield assert_shapes_correct, input_shapes2, shape + # Triple input. + input_shapes3 = [shape, shape, shape] + yield assert_shapes_correct, input_shapes3, shape + +def test_two_compatible_by_ones_input_shapes(): + """ Check that two different input shapes (of the same length but some have + 1s) broadcast to the correct shape. + """ + data = [ + [[(1,), (3,)], (3,)], + [[(1,3), (3,3)], (3,3)], + [[(3,1), (3,3)], (3,3)], + [[(1,3), (3,1)], (3,3)], + [[(1,1), (3,3)], (3,3)], + [[(1,1), (1,3)], (1,3)], + [[(1,1), (3,1)], (3,1)], + [[(1,0), (0,0)], (0,0)], + [[(0,1), (0,0)], (0,0)], + [[(1,0), (0,1)], (0,0)], + [[(1,1), (0,0)], (0,0)], + [[(1,1), (1,0)], (1,0)], + [[(1,1), (0,1)], (0,1)], + ] + for input_shapes, expected_shape in data: + yield assert_shapes_correct, input_shapes, expected_shape + # Reverse the input shapes since broadcasting should be symmetric. + yield assert_shapes_correct, input_shapes[::-1], expected_shape + +def test_two_compatible_by_prepending_ones_input_shapes(): + """ Check that two different input shapes (of different lengths) broadcast + to the correct shape. + """ + data = [ + [[(), (3,)], (3,)], + [[(3,), (3,3)], (3,3)], + [[(3,), (3,1)], (3,3)], + [[(1,), (3,3)], (3,3)], + [[(), (3,3)], (3,3)], + [[(1,1), (3,)], (1,3)], + [[(1,), (3,1)], (3,1)], + [[(1,), (1,3)], (1,3)], + [[(), (1,3)], (1,3)], + [[(), (3,1)], (3,1)], + [[(), (0,)], (0,)], + [[(0,), (0,0)], (0,0)], + [[(0,), (0,1)], (0,0)], + [[(1,), (0,0)], (0,0)], + [[(), (0,0)], (0,0)], + [[(1,1), (0,)], (1,0)], + [[(1,), (0,1)], (0,1)], + [[(1,), (1,0)], (1,0)], + [[(), (1,0)], (1,0)], + [[(), (0,1)], (0,1)], + ] + for input_shapes, expected_shape in data: + yield assert_shapes_correct, input_shapes, expected_shape + # Reverse the input shapes since broadcasting should be symmetric. + yield assert_shapes_correct, input_shapes[::-1], expected_shape + +def test_incompatible_shapes_raise_valueerror(): + """ Check that a ValueError is raised for incompatible shapes. + """ + data = [ + [(3,), (4,)], + [(2,3), (2,)], + [(3,), (3,), (4,)], + [(1,3,4), (2,3,3)], + ] + for input_shapes in data: + yield assert_incompatible_shapes_raise, input_shapes + # Reverse the input shapes since broadcasting should be symmetric. + yield assert_incompatible_shapes_raise, input_shapes[::-1] + +def test_same_as_ufunc(): + """ Check that the data layout is the same as if a ufunc did the operation. + """ + data = [ + [[(1,), (3,)], (3,)], + [[(1,3), (3,3)], (3,3)], + [[(3,1), (3,3)], (3,3)], + [[(1,3), (3,1)], (3,3)], + [[(1,1), (3,3)], (3,3)], + [[(1,1), (1,3)], (1,3)], + [[(1,1), (3,1)], (3,1)], + [[(1,0), (0,0)], (0,0)], + [[(0,1), (0,0)], (0,0)], + [[(1,0), (0,1)], (0,0)], + [[(1,1), (0,0)], (0,0)], + [[(1,1), (1,0)], (1,0)], + [[(1,1), (0,1)], (0,1)], + [[(), (3,)], (3,)], + [[(3,), (3,3)], (3,3)], + [[(3,), (3,1)], (3,3)], + [[(1,), (3,3)], (3,3)], + [[(), (3,3)], (3,3)], + [[(1,1), (3,)], (1,3)], + [[(1,), (3,1)], (3,1)], + [[(1,), (1,3)], (1,3)], + [[(), (1,3)], (1,3)], + [[(), (3,1)], (3,1)], + [[(), (0,)], (0,)], + [[(0,), (0,0)], (0,0)], + [[(0,), (0,1)], (0,0)], + [[(1,), (0,0)], (0,0)], + [[(), (0,0)], (0,0)], + [[(1,1), (0,)], (1,0)], + [[(1,), (0,1)], (0,1)], + [[(1,), (1,0)], (1,0)], + [[(), (1,0)], (1,0)], + [[(), (0,1)], (0,1)], + ] + for input_shapes, expected_shape in data: + yield assert_same_as_ufunc, input_shapes[0], input_shapes[1] + # Reverse the input shapes since broadcasting should be symmetric. + yield assert_same_as_ufunc, input_shapes[1], input_shapes[0] + # Try them transposed, too. + yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], True + # ... and flipped for non-rank-0 inputs in order to test negative + # strides. + if () not in input_shapes: + yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], False, True + yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], True, True |