summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormadhulikajc <53166646+madhulikajc@users.noreply.github.com>2020-10-17 06:58:33 -0700
committerGitHub <noreply@github.com>2020-10-17 15:58:33 +0200
commit7b0a764fee6e1614f3249e9082d8c4acf1dc62d5 (patch)
tree40a1010555bb50448d16a63186235453a694efb4
parent2ce4ab4063457f558e5aab1051aedc68ea17eb25 (diff)
downloadnumpy-7b0a764fee6e1614f3249e9082d8c4acf1dc62d5.tar.gz
ENH: add function to get broadcast shape from a given set of shapes. (#17535)
* ENH: add function to get broadcast shape from a given set of shapes. Add new function numpy.broadcast_shape which takes tuples for the shapes to be broadcast against each other. Return the broadcasted shape as a tuple. See #17217 * Perform array allocations of size 0 for provided shape tuples Co-authored-by: Eric Wieser <wieser.eric@gmail.com> * Test for int as input shape Also update docstring to include both ints and tuples of ints as input * Remove unnecessary array_function_dispatch * Add missing set_module * Add release notes. Add versionadded to docstring. Also fix up docstring details. * follow convention for trailing comma Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net> * Change name to broadcast_shapes. Also add test case, and type hint. * follow convention Co-authored-by: Eric Wieser <wieser.eric@gmail.com> * Update docstring Co-authored-by: Eric Wieser <wieser.eric@gmail.com> * Add reference to numpy docs on broadcasting to docstring Also move versionadded * Fix spelling Co-authored-by: Warren Weckesser <warren.weckesser@gmail.com> * Add broadcast_shapes to reference docs and add See Also sections Co-authored-by: Eric Wieser <wieser.eric@gmail.com> Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net> Co-authored-by: Warren Weckesser <warren.weckesser@gmail.com>
-rw-r--r--doc/release/upcoming_changes/17535.new_function.rst15
-rw-r--r--doc/source/reference/routines.other.rst1
-rw-r--r--numpy/__init__.pyi2
-rw-r--r--numpy/core/_add_newdocs.py1
-rw-r--r--numpy/lib/stride_tricks.py59
-rw-r--r--numpy/lib/tests/test_stride_tricks.py65
6 files changed, 139 insertions, 4 deletions
diff --git a/doc/release/upcoming_changes/17535.new_function.rst b/doc/release/upcoming_changes/17535.new_function.rst
new file mode 100644
index 000000000..4c3c11de4
--- /dev/null
+++ b/doc/release/upcoming_changes/17535.new_function.rst
@@ -0,0 +1,15 @@
+`numpy.broadcast_shapes` is a new user-facing function
+------------------------------------------------------
+`broadcast_shapes` gets the resulting shape from
+broadcasting the given shape tuples against each other.
+
+.. code:: python
+
+ >>> np.broadcast_shapes((1, 2), (3, 1))
+ (3, 2)
+
+ >>> np.broadcast_shapes(2, (3, 1))
+ (3, 2)
+
+ >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
+ (5, 6, 7)
diff --git a/doc/source/reference/routines.other.rst b/doc/source/reference/routines.other.rst
index def5b3e3c..aefd680bb 100644
--- a/doc/source/reference/routines.other.rst
+++ b/doc/source/reference/routines.other.rst
@@ -47,6 +47,7 @@ Utility
show_config
deprecate
deprecate_with_doc
+ broadcast_shapes
Matlab-like Functions
---------------------
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 80fb213fd..2fff82d59 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1926,6 +1926,8 @@ def empty(
like: ArrayLike = ...,
) -> ndarray: ...
+def broadcast_shapes(*args: _ShapeLike) -> _Shape: ...
+
#
# Constants
#
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index aa858761d..c9968f122 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -605,6 +605,7 @@ add_newdoc('numpy.core', 'broadcast',
--------
broadcast_arrays
broadcast_to
+ broadcast_shapes
Examples
--------
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
index 502235bdf..d8a8b325e 100644
--- a/numpy/lib/stride_tricks.py
+++ b/numpy/lib/stride_tricks.py
@@ -6,9 +6,9 @@ NumPy reference guide.
"""
import numpy as np
-from numpy.core.overrides import array_function_dispatch
+from numpy.core.overrides import array_function_dispatch, set_module
-__all__ = ['broadcast_to', 'broadcast_arrays']
+__all__ = ['broadcast_to', 'broadcast_arrays', 'broadcast_shapes']
class DummyArray:
@@ -165,6 +165,12 @@ def broadcast_to(array, shape, subok=False):
If the array is not compatible with the new shape according to NumPy's
broadcasting rules.
+ See Also
+ --------
+ broadcast
+ broadcast_arrays
+ broadcast_shapes
+
Notes
-----
.. versionadded:: 1.10.0
@@ -197,6 +203,49 @@ def _broadcast_shape(*args):
return b.shape
+@set_module('numpy')
+def broadcast_shapes(*args):
+ """
+ Broadcast the input shapes into a single shape.
+
+ :ref:`Learn more about broadcasting here <basics.broadcasting>`.
+
+ .. versionadded:: 1.20.0
+
+ Parameters
+ ----------
+ `*args` : tuples of ints, or ints
+ The shapes to be broadcast against each other.
+
+ Returns
+ -------
+ tuple
+ Broadcasted shape.
+
+ Raises
+ ------
+ ValueError
+ If the shapes are not compatible and cannot be broadcast according
+ to NumPy's broadcasting rules.
+
+ See Also
+ --------
+ broadcast
+ broadcast_arrays
+ broadcast_to
+
+ Examples
+ --------
+ >>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
+ (3, 2)
+
+ >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
+ (5, 6, 7)
+ """
+ arrays = [np.empty(x, dtype=[]) for x in args]
+ return _broadcast_shape(*arrays)
+
+
def _broadcast_arrays_dispatcher(*args, subok=None):
return args
@@ -230,6 +279,12 @@ def broadcast_arrays(*args, subok=False):
warning will be emitted. A future version will set the
``writable`` flag False so writing to it will raise an error.
+ See Also
+ --------
+ broadcast
+ broadcast_to
+ broadcast_shapes
+
Examples
--------
>>> x = np.array([[1,2,3]])
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index 9d95eb9d0..10d7a19ab 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -5,7 +5,8 @@ from numpy.testing import (
assert_raises_regex, assert_warns,
)
from numpy.lib.stride_tricks import (
- as_strided, broadcast_arrays, _broadcast_shape, broadcast_to
+ as_strided, broadcast_arrays, _broadcast_shape, broadcast_to,
+ broadcast_shapes,
)
def assert_shapes_correct(input_shapes, expected_shape):
@@ -274,7 +275,9 @@ def test_broadcast_to_raises():
def test_broadcast_shape():
- # broadcast_shape is already exercized indirectly by broadcast_arrays
+ # tests internal _broadcast_shape
+ # _broadcast_shape is already exercised indirectly by broadcast_arrays
+ # _broadcast_shape is also exercised by the public broadcast_shapes function
assert_equal(_broadcast_shape(), ())
assert_equal(_broadcast_shape([1, 2]), (2,))
assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1))
@@ -288,6 +291,64 @@ def test_broadcast_shape():
assert_raises(ValueError, lambda: _broadcast_shape(*bad_args))
+def test_broadcast_shapes_succeeds():
+ # tests public broadcast_shapes
+ data = [
+ [[], ()],
+ [[()], ()],
+ [[(7,)], (7,)],
+ [[(1, 2), (2,)], (1, 2)],
+ [[(1, 1)], (1, 1)],
+ [[(1, 1), (3, 4)], (3, 4)],
+ [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
+ [[(5, 6, 1)], (5, 6, 1)],
+ [[(1, 3), (3, 1)], (3, 3)],
+ [[(1, 0), (0, 0)], (0, 0)],
+ [[(0, 1), (0, 0)], (0, 0)],
+ [[(1, 0), (0, 1)], (0, 0)],
+ [[(1, 1), (0, 0)], (0, 0)],
+ [[(1, 1), (1, 0)], (1, 0)],
+ [[(1, 1), (0, 1)], (0, 1)],
+ [[(), (0,)], (0,)],
+ [[(0,), (0, 0)], (0, 0)],
+ [[(0,), (0, 1)], (0, 0)],
+ [[(1,), (0, 0)], (0, 0)],
+ [[(), (0, 0)], (0, 0)],
+ [[(1, 1), (0,)], (1, 0)],
+ [[(1,), (0, 1)], (0, 1)],
+ [[(1,), (1, 0)], (1, 0)],
+ [[(), (1, 0)], (1, 0)],
+ [[(), (0, 1)], (0, 1)],
+ [[(1,), (3,)], (3,)],
+ [[2, (3, 2)], (3, 2)],
+ ]
+ for input_shapes, target_shape in data:
+ assert_equal(broadcast_shapes(*input_shapes), target_shape)
+
+ assert_equal(broadcast_shapes(*([(1, 2)] * 32)), (1, 2))
+ assert_equal(broadcast_shapes(*([(1, 2)] * 100)), (1, 2))
+
+ # regression tests for gh-5862
+ assert_equal(broadcast_shapes(*([(2,)] * 32)), (2,))
+
+
+def test_broadcast_shapes_raises():
+ # tests public broadcast_shapes
+ data = [
+ [(3,), (4,)],
+ [(2, 3), (2,)],
+ [(3,), (3,), (4,)],
+ [(1, 3, 4), (2, 3, 3)],
+ [(1, 2), (3,1), (3,2), (10, 5)],
+ [2, (2, 3)],
+ ]
+ for input_shapes in data:
+ assert_raises(ValueError, lambda: broadcast_shapes(*input_shapes))
+
+ bad_args = [(2,)] * 32 + [(3,)] * 32
+ assert_raises(ValueError, lambda: broadcast_shapes(*bad_args))
+
+
def test_as_strided():
a = np.array([None])
a_view = as_strided(a)