summaryrefslogtreecommitdiff
path: root/numpy/array_api/_manipulation_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_manipulation_functions.py')
-rw-r--r--numpy/array_api/_manipulation_functions.py18
1 files changed, 16 insertions, 2 deletions
diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py
index 33f5d5a28..c11866261 100644
--- a/numpy/array_api/_manipulation_functions.py
+++ b/numpy/array_api/_manipulation_functions.py
@@ -8,7 +8,9 @@ from typing import List, Optional, Tuple, Union
import numpy as np
# Note: the function name is different here
-def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
+def concat(
+ arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
@@ -20,6 +22,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i
arrays = tuple(a._array for a in arrays)
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
+
def expand_dims(x: Array, /, *, axis: int) -> Array:
"""
Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
@@ -28,6 +31,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array:
"""
return Array._new(np.expand_dims(x._array, axis))
+
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
@@ -36,6 +40,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
"""
return Array._new(np.flip(x._array, axis=axis))
+
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
@@ -44,7 +49,14 @@ def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
"""
return Array._new(np.reshape(x._array, shape))
-def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
+
+def roll(
+ x: Array,
+ /,
+ shift: Union[int, Tuple[int, ...]],
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+) -> Array:
"""
Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
@@ -52,6 +64,7 @@ def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio
"""
return Array._new(np.roll(x._array, shift, axis=axis))
+
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
@@ -60,6 +73,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
"""
return Array._new(np.squeeze(x._array, axis=axis))
+
def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
"""
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.