summaryrefslogtreecommitdiff
path: root/numpy/array_api/_creation_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_creation_functions.py')
-rw-r--r--numpy/array_api/_creation_functions.py34
1 files changed, 31 insertions, 3 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index 2d6cf4414..d760bf2fc 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -22,7 +22,7 @@ def _check_valid_dtype(dtype):
# Note: Only spelling dtypes as the dtype objects is supported.
# We use this instead of "dtype in _all_dtypes" because the dtype objects
- # define equality with the sorts of things we want to disallw.
+ # define equality with the sorts of things we want to disallow.
for d in (None,) + _all_dtypes:
if dtype is d:
return
@@ -134,7 +134,7 @@ def eye(
n_cols: Optional[int] = None,
/,
*,
- k: Optional[int] = 0,
+ k: int = 0,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
) -> Array:
@@ -232,7 +232,7 @@ def linspace(
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
-def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]:
+def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
@@ -281,6 +281,34 @@ def ones_like(
return Array._new(np.ones_like(x._array, dtype=dtype))
+def tril(x: Array, /, *, k: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ if x.ndim < 2:
+ # Note: Unlike np.tril, x must be at least 2-D
+ raise ValueError("x must be at least 2-dimensional for tril")
+ return Array._new(np.tril(x._array, k=k))
+
+
+def triu(x: Array, /, *, k: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ if x.ndim < 2:
+ # Note: Unlike np.triu, x must be at least 2-D
+ raise ValueError("x must be at least 2-dimensional for triu")
+ return Array._new(np.triu(x._array, k=k))
+
+
def zeros(
shape: Union[int, Tuple[int, ...]],
*,