summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-03-02 16:08:41 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-03-02 16:08:41 -0700
commitf1f9dca83213aa4b0e6495900482f1a180a014ae (patch)
tree34e80bc180b3116693eba9baf1dd0c6d7c22e292 /numpy/_array_api
parentd40985cbb50aeadf276a1a9332e455ddc096e113 (diff)
downloadnumpy-f1f9dca83213aa4b0e6495900482f1a180a014ae.tar.gz
Make the array API manipulation functions use the array API ndarray object
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_manipulation_functions.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py
index e312b18c5..413dbb1b1 100644
--- a/numpy/_array_api/_manipulation_functions.py
+++ b/numpy/_array_api/_manipulation_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._types import Optional, Tuple, Union, array
+from ._array_object import ndarray
import numpy as np
@@ -10,8 +11,9 @@ def concat(arrays: Tuple[array], /, *, axis: Optional[int] = 0) -> array:
See its docstring for more information.
"""
+ arrays = tuple(a._array for a in arrays)
# Note: the function name is different here
- return np.concatenate(arrays, axis=axis)
+ return ndarray._new(np.concatenate(arrays, axis=axis))
def expand_dims(x: array, axis: int, /) -> array:
"""
@@ -19,7 +21,7 @@ def expand_dims(x: array, axis: int, /) -> array:
See its docstring for more information.
"""
- return np.expand_dims._implementation(x, axis)
+ return ndarray._new(np.expand_dims._implementation(x._array, axis))
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
"""
@@ -27,7 +29,7 @@ def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
See its docstring for more information.
"""
- return np.flip._implementation(x, axis=axis)
+ return ndarray._new(np.flip._implementation(x._array, axis=axis))
def reshape(x: array, shape: Tuple[int, ...], /) -> array:
"""
@@ -35,7 +37,7 @@ def reshape(x: array, shape: Tuple[int, ...], /) -> array:
See its docstring for more information.
"""
- return np.reshape._implementation(x, shape)
+ return ndarray._new(np.reshape._implementation(x._array, shape))
def roll(x: array, shift: Union[int, Tuple[int, ...]], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
"""
@@ -43,7 +45,7 @@ def roll(x: array, shift: Union[int, Tuple[int, ...]], /, *, axis: Optional[Unio
See its docstring for more information.
"""
- return np.roll._implementation(x, shift, axis=axis)
+ return ndarray._new(np.roll._implementation(x._array, shift, axis=axis))
def squeeze(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
"""
@@ -51,7 +53,7 @@ def squeeze(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None)
See its docstring for more information.
"""
- return np.squeeze._implementation(x, axis=axis)
+ return ndarray._array(np.squeeze._implementation(x._array, axis=axis))
def stack(arrays: Tuple[array], /, *, axis: int = 0) -> array:
"""
@@ -59,4 +61,5 @@ def stack(arrays: Tuple[array], /, *, axis: int = 0) -> array:
See its docstring for more information.
"""
- return np.stack._implementation(arrays, axis=axis)
+ arrays = tuple(a._array for a in arrays)
+ return ndarray._array(np.stack._implementation(arrays, axis=axis))