summaryrefslogtreecommitdiff
path: root/numpy/array_api/__init__.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-09-25 17:34:22 -0500
committerGitHub <noreply@github.com>2021-09-25 16:34:22 -0600
commit2d112a98ed7597c4120b31908384ae09b0304659 (patch)
treea03fcf59a0ca9cfff10ca2b346bd4c9d37268451 /numpy/array_api/__init__.py
parentac78192390943d90ebae2f4e209e194914d0bc97 (diff)
downloadnumpy-2d112a98ed7597c4120b31908384ae09b0304659.tar.gz
ENH: Updates to numpy.array_api (#19937)
* Add __index__ to array_api and update __int__, __bool__, and __float__ The spec specifies that they should only work on arrays with corresponding dtypes. __index__ is new in the spec since the initial PR, and works identically to np.array.__index__. * Add the to_device method to the array_api This method is new since #18585. It does nothing in NumPy since NumPy does not support non-CPU devices. * Update transpose methods in the array_api transpose() was renamed to matrix_transpose() and now operates on stacks of matrices. A function to permute dimensions will be added once it is finalized in the spec. The attribute mT was added and the T attribute was updated to only operate on 2-dimensional arrays as per the spec. * Restrict input dtypes in the array API statistical functions * Add the dtype parameter to the array API sum() and prod() * Add the function permute_dims() to the array_api namespace permute_dims() is the replacement for transpose(), which was split into permute_dims() and matrix_transpose(). * Add tril and triu to the array API namespace * Fix the array_api Array.__repr__ to indent the array properly * Make the Device type in the array_api just accept the string "cpu"
Diffstat (limited to 'numpy/array_api/__init__.py')
-rw-r--r--numpy/array_api/__init__.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index 790157504..d8b29057e 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -143,6 +143,8 @@ from ._creation_functions import (
meshgrid,
ones,
ones_like,
+ tril,
+ triu,
zeros,
zeros_like,
)
@@ -160,6 +162,8 @@ __all__ += [
"meshgrid",
"ones",
"ones_like",
+ "tril",
+ "triu",
"zeros",
"zeros_like",
]
@@ -333,21 +337,22 @@ __all__ += [
# from ._linear_algebra_functions import einsum
# __all__ += ['einsum']
-from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot
+from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
-__all__ += ["matmul", "tensordot", "transpose", "vecdot"]
+__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
from ._manipulation_functions import (
concat,
expand_dims,
flip,
+ permute_dims,
reshape,
roll,
squeeze,
stack,
)
-__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"]
+__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
from ._searching_functions import argmax, argmin, nonzero, where