diff options
Diffstat (limited to 'numpy/lib/stride_tricks.py')
-rw-r--r-- | numpy/lib/stride_tricks.py | 59 |
1 files changed, 57 insertions, 2 deletions
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index 502235bdf..d8a8b325e 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -6,9 +6,9 @@ NumPy reference guide. """ import numpy as np -from numpy.core.overrides import array_function_dispatch +from numpy.core.overrides import array_function_dispatch, set_module -__all__ = ['broadcast_to', 'broadcast_arrays'] +__all__ = ['broadcast_to', 'broadcast_arrays', 'broadcast_shapes'] class DummyArray: @@ -165,6 +165,12 @@ def broadcast_to(array, shape, subok=False): If the array is not compatible with the new shape according to NumPy's broadcasting rules. + See Also + -------- + broadcast + broadcast_arrays + broadcast_shapes + Notes ----- .. versionadded:: 1.10.0 @@ -197,6 +203,49 @@ def _broadcast_shape(*args): return b.shape +@set_module('numpy') +def broadcast_shapes(*args): + """ + Broadcast the input shapes into a single shape. + + :ref:`Learn more about broadcasting here <basics.broadcasting>`. + + .. versionadded:: 1.20.0 + + Parameters + ---------- + `*args` : tuples of ints, or ints + The shapes to be broadcast against each other. + + Returns + ------- + tuple + Broadcasted shape. + + Raises + ------ + ValueError + If the shapes are not compatible and cannot be broadcast according + to NumPy's broadcasting rules. + + See Also + -------- + broadcast + broadcast_arrays + broadcast_to + + Examples + -------- + >>> np.broadcast_shapes((1, 2), (3, 1), (3, 2)) + (3, 2) + + >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7)) + (5, 6, 7) + """ + arrays = [np.empty(x, dtype=[]) for x in args] + return _broadcast_shape(*arrays) + + def _broadcast_arrays_dispatcher(*args, subok=None): return args @@ -230,6 +279,12 @@ def broadcast_arrays(*args, subok=False): warning will be emitted. A future version will set the ``writable`` flag False so writing to it will raise an error. + See Also + -------- + broadcast + broadcast_to + broadcast_shapes + Examples -------- >>> x = np.array([[1,2,3]]) |