summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_array_object.py5
-rw-r--r--numpy/array_api/_data_type_functions.py3
-rw-r--r--numpy/array_api/_sorting_functions.py3
-rw-r--r--numpy/array_api/linalg.py10
4 files changed, 16 insertions, 5 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index c86cfb3a6..6cf9ec6f3 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -177,6 +177,8 @@ class Array:
integer that is too large to fit in a NumPy integer dtype, or
TypeError when the scalar type is incompatible with the dtype of self.
"""
+ # Note: Only Python scalar types that match the array dtype are
+ # allowed.
if isinstance(scalar, bool):
if self.dtype not in _boolean_dtypes:
raise TypeError(
@@ -195,6 +197,9 @@ class Array:
else:
raise TypeError("'scalar' must be a Python scalar")
+ # Note: scalars are unconditionally cast to the same dtype as the
+ # array.
+
# Note: the spec only specifies integer-dtype/int promotion
# behavior for integers within the bounds of the integer dtype.
# Outside of those bounds we use the default NumPy behavior (either
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index 1fb6062f6..7026bd489 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -56,7 +56,8 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
if to not in _all_dtypes:
raise TypeError(f"{to=}, but should be a dtype")
- # Note: We avoid np.can_cast() as it has discrepancies with the array API.
+ # Note: We avoid np.can_cast() as it has discrepancies with the array API,
+ # since NumPy allows cross-kind casting (e.g., NumPy allows bool -> int8).
# See https://github.com/numpy/numpy/issues/20870
try:
# We promote `from_` and `to` together. We then check if the promoted
diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py
index b2a11872f..afbb412f7 100644
--- a/numpy/array_api/_sorting_functions.py
+++ b/numpy/array_api/_sorting_functions.py
@@ -5,6 +5,7 @@ from ._array_object import Array
import numpy as np
+# Note: the descending keyword argument is new in this function
def argsort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
@@ -31,7 +32,7 @@ def argsort(
res = max_i - res
return Array._new(res)
-
+# Note: the descending keyword argument is new in this function
def sort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py
index 8d7ba659e..f422e1c27 100644
--- a/numpy/array_api/linalg.py
+++ b/numpy/array_api/linalg.py
@@ -89,7 +89,6 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array:
return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
-# Note: the keyword argument name upper is different from np.linalg.eigh
def eigh(x: Array, /) -> EighResult:
"""
Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.
@@ -106,7 +105,6 @@ def eigh(x: Array, /) -> EighResult:
return EighResult(*map(Array._new, np.linalg.eigh(x._array)))
-# Note: the keyword argument name upper is different from np.linalg.eigvalsh
def eigvalsh(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.
@@ -346,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# np.linalg.svd(compute_uv=False).
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svdvals')
return Array._new(np.linalg.svd(x._array, compute_uv=False))
# Note: tensordot is the numpy top-level namespace but not in np.linalg
@@ -366,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in trace')
# Note: trace always operates on the last two axes, whereas np.trace
# operates on the first two axes by default
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
# Note: vecdot is not in NumPy
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in vecdot')
return tensordot(x1, x2, axes=((axis,), (axis,)))
@@ -380,7 +384,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
# The type for ord should be Optional[Union[int, float, Literal[np.inf,
# -np.inf]]] but Literal does not support floating-point literals.
-def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
+def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.