diff options
author | madhulikajc <53166646+madhulikajc@users.noreply.github.com> | 2020-10-17 06:58:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-17 15:58:33 +0200 |
commit | 7b0a764fee6e1614f3249e9082d8c4acf1dc62d5 (patch) | |
tree | 40a1010555bb50448d16a63186235453a694efb4 | |
parent | 2ce4ab4063457f558e5aab1051aedc68ea17eb25 (diff) | |
download | numpy-7b0a764fee6e1614f3249e9082d8c4acf1dc62d5.tar.gz |
ENH: add function to get broadcast shape from a given set of shapes. (#17535)
* ENH: add function to get broadcast shape from a given set of shapes.
Add new function numpy.broadcast_shape which takes tuples
for the shapes to be broadcast against each other.
Return the broadcasted shape as a tuple.
See #17217
* Perform array allocations of size 0 for provided shape tuples
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
* Test for int as input shape
Also update docstring to include both ints and tuples of ints as input
* Remove unnecessary array_function_dispatch
* Add missing set_module
* Add release notes. Add versionadded to docstring.
Also fix up docstring details.
* follow convention for trailing comma
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
* Change name to broadcast_shapes. Also add test case, and type hint.
* follow convention
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
* Update docstring
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
* Add reference to numpy docs on broadcasting to docstring
Also move versionadded
* Fix spelling
Co-authored-by: Warren Weckesser <warren.weckesser@gmail.com>
* Add broadcast_shapes to reference docs and add See Also sections
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Co-authored-by: Warren Weckesser <warren.weckesser@gmail.com>
-rw-r--r-- | doc/release/upcoming_changes/17535.new_function.rst | 15 | ||||
-rw-r--r-- | doc/source/reference/routines.other.rst | 1 | ||||
-rw-r--r-- | numpy/__init__.pyi | 2 | ||||
-rw-r--r-- | numpy/core/_add_newdocs.py | 1 | ||||
-rw-r--r-- | numpy/lib/stride_tricks.py | 59 | ||||
-rw-r--r-- | numpy/lib/tests/test_stride_tricks.py | 65 |
6 files changed, 139 insertions, 4 deletions
diff --git a/doc/release/upcoming_changes/17535.new_function.rst b/doc/release/upcoming_changes/17535.new_function.rst new file mode 100644 index 000000000..4c3c11de4 --- /dev/null +++ b/doc/release/upcoming_changes/17535.new_function.rst @@ -0,0 +1,15 @@ +`numpy.broadcast_shapes` is a new user-facing function +------------------------------------------------------ +`broadcast_shapes` gets the resulting shape from +broadcasting the given shape tuples against each other. + +.. code:: python + + >>> np.broadcast_shapes((1, 2), (3, 1)) + (3, 2) + + >>> np.broadcast_shapes(2, (3, 1)) + (3, 2) + + >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7)) + (5, 6, 7) diff --git a/doc/source/reference/routines.other.rst b/doc/source/reference/routines.other.rst index def5b3e3c..aefd680bb 100644 --- a/doc/source/reference/routines.other.rst +++ b/doc/source/reference/routines.other.rst @@ -47,6 +47,7 @@ Utility show_config deprecate deprecate_with_doc + broadcast_shapes Matlab-like Functions --------------------- diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 80fb213fd..2fff82d59 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -1926,6 +1926,8 @@ def empty( like: ArrayLike = ..., ) -> ndarray: ... +def broadcast_shapes(*args: _ShapeLike) -> _Shape: ... + # # Constants # diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py index aa858761d..c9968f122 100644 --- a/numpy/core/_add_newdocs.py +++ b/numpy/core/_add_newdocs.py @@ -605,6 +605,7 @@ add_newdoc('numpy.core', 'broadcast', -------- broadcast_arrays broadcast_to + broadcast_shapes Examples -------- diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index 502235bdf..d8a8b325e 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -6,9 +6,9 @@ NumPy reference guide. """ import numpy as np -from numpy.core.overrides import array_function_dispatch +from numpy.core.overrides import array_function_dispatch, set_module -__all__ = ['broadcast_to', 'broadcast_arrays'] +__all__ = ['broadcast_to', 'broadcast_arrays', 'broadcast_shapes'] class DummyArray: @@ -165,6 +165,12 @@ def broadcast_to(array, shape, subok=False): If the array is not compatible with the new shape according to NumPy's broadcasting rules. + See Also + -------- + broadcast + broadcast_arrays + broadcast_shapes + Notes ----- .. versionadded:: 1.10.0 @@ -197,6 +203,49 @@ def _broadcast_shape(*args): return b.shape +@set_module('numpy') +def broadcast_shapes(*args): + """ + Broadcast the input shapes into a single shape. + + :ref:`Learn more about broadcasting here <basics.broadcasting>`. + + .. versionadded:: 1.20.0 + + Parameters + ---------- + `*args` : tuples of ints, or ints + The shapes to be broadcast against each other. + + Returns + ------- + tuple + Broadcasted shape. + + Raises + ------ + ValueError + If the shapes are not compatible and cannot be broadcast according + to NumPy's broadcasting rules. + + See Also + -------- + broadcast + broadcast_arrays + broadcast_to + + Examples + -------- + >>> np.broadcast_shapes((1, 2), (3, 1), (3, 2)) + (3, 2) + + >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7)) + (5, 6, 7) + """ + arrays = [np.empty(x, dtype=[]) for x in args] + return _broadcast_shape(*arrays) + + def _broadcast_arrays_dispatcher(*args, subok=None): return args @@ -230,6 +279,12 @@ def broadcast_arrays(*args, subok=False): warning will be emitted. A future version will set the ``writable`` flag False so writing to it will raise an error. + See Also + -------- + broadcast + broadcast_to + broadcast_shapes + Examples -------- >>> x = np.array([[1,2,3]]) diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py index 9d95eb9d0..10d7a19ab 100644 --- a/numpy/lib/tests/test_stride_tricks.py +++ b/numpy/lib/tests/test_stride_tricks.py @@ -5,7 +5,8 @@ from numpy.testing import ( assert_raises_regex, assert_warns, ) from numpy.lib.stride_tricks import ( - as_strided, broadcast_arrays, _broadcast_shape, broadcast_to + as_strided, broadcast_arrays, _broadcast_shape, broadcast_to, + broadcast_shapes, ) def assert_shapes_correct(input_shapes, expected_shape): @@ -274,7 +275,9 @@ def test_broadcast_to_raises(): def test_broadcast_shape(): - # broadcast_shape is already exercized indirectly by broadcast_arrays + # tests internal _broadcast_shape + # _broadcast_shape is already exercised indirectly by broadcast_arrays + # _broadcast_shape is also exercised by the public broadcast_shapes function assert_equal(_broadcast_shape(), ()) assert_equal(_broadcast_shape([1, 2]), (2,)) assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1)) @@ -288,6 +291,64 @@ def test_broadcast_shape(): assert_raises(ValueError, lambda: _broadcast_shape(*bad_args)) +def test_broadcast_shapes_succeeds(): + # tests public broadcast_shapes + data = [ + [[], ()], + [[()], ()], + [[(7,)], (7,)], + [[(1, 2), (2,)], (1, 2)], + [[(1, 1)], (1, 1)], + [[(1, 1), (3, 4)], (3, 4)], + [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)], + [[(5, 6, 1)], (5, 6, 1)], + [[(1, 3), (3, 1)], (3, 3)], + [[(1, 0), (0, 0)], (0, 0)], + [[(0, 1), (0, 0)], (0, 0)], + [[(1, 0), (0, 1)], (0, 0)], + [[(1, 1), (0, 0)], (0, 0)], + [[(1, 1), (1, 0)], (1, 0)], + [[(1, 1), (0, 1)], (0, 1)], + [[(), (0,)], (0,)], + [[(0,), (0, 0)], (0, 0)], + [[(0,), (0, 1)], (0, 0)], + [[(1,), (0, 0)], (0, 0)], + [[(), (0, 0)], (0, 0)], + [[(1, 1), (0,)], (1, 0)], + [[(1,), (0, 1)], (0, 1)], + [[(1,), (1, 0)], (1, 0)], + [[(), (1, 0)], (1, 0)], + [[(), (0, 1)], (0, 1)], + [[(1,), (3,)], (3,)], + [[2, (3, 2)], (3, 2)], + ] + for input_shapes, target_shape in data: + assert_equal(broadcast_shapes(*input_shapes), target_shape) + + assert_equal(broadcast_shapes(*([(1, 2)] * 32)), (1, 2)) + assert_equal(broadcast_shapes(*([(1, 2)] * 100)), (1, 2)) + + # regression tests for gh-5862 + assert_equal(broadcast_shapes(*([(2,)] * 32)), (2,)) + + +def test_broadcast_shapes_raises(): + # tests public broadcast_shapes + data = [ + [(3,), (4,)], + [(2, 3), (2,)], + [(3,), (3,), (4,)], + [(1, 3, 4), (2, 3, 3)], + [(1, 2), (3,1), (3,2), (10, 5)], + [2, (2, 3)], + ] + for input_shapes in data: + assert_raises(ValueError, lambda: broadcast_shapes(*input_shapes)) + + bad_args = [(2,)] * 32 + [(3,)] * 32 + assert_raises(ValueError, lambda: broadcast_shapes(*bad_args)) + + def test_as_strided(): a = np.array([None]) a_view = as_strided(a) |