diff options
Diffstat (limited to 'numpy/array_api')
| -rw-r--r-- | numpy/array_api/_array_object.py | 5 | ||||
| -rw-r--r-- | numpy/array_api/_data_type_functions.py | 3 | ||||
| -rw-r--r-- | numpy/array_api/_sorting_functions.py | 3 | ||||
| -rw-r--r-- | numpy/array_api/linalg.py | 10 |
4 files changed, 16 insertions, 5 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c86cfb3a6..6cf9ec6f3 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -177,6 +177,8 @@ class Array: integer that is too large to fit in a NumPy integer dtype, or TypeError when the scalar type is incompatible with the dtype of self. """ + # Note: Only Python scalar types that match the array dtype are + # allowed. if isinstance(scalar, bool): if self.dtype not in _boolean_dtypes: raise TypeError( @@ -195,6 +197,9 @@ class Array: else: raise TypeError("'scalar' must be a Python scalar") + # Note: scalars are unconditionally cast to the same dtype as the + # array. + # Note: the spec only specifies integer-dtype/int promotion # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 1fb6062f6..7026bd489 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -56,7 +56,8 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: raise TypeError(f"{from_=}, but should be an array_api array or dtype") if to not in _all_dtypes: raise TypeError(f"{to=}, but should be a dtype") - # Note: We avoid np.can_cast() as it has discrepancies with the array API. + # Note: We avoid np.can_cast() as it has discrepancies with the array API, + # since NumPy allows cross-kind casting (e.g., NumPy allows bool -> int8). # See https://github.com/numpy/numpy/issues/20870 try: # We promote `from_` and `to` together. We then check if the promoted diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index b2a11872f..afbb412f7 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -5,6 +5,7 @@ from ._array_object import Array import numpy as np +# Note: the descending keyword argument is new in this function def argsort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: @@ -31,7 +32,7 @@ def argsort( res = max_i - res return Array._new(res) - +# Note: the descending keyword argument is new in this function def sort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index 8d7ba659e..f422e1c27 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -89,7 +89,6 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) -# Note: the keyword argument name upper is different from np.linalg.eigh def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`. @@ -106,7 +105,6 @@ def eigh(x: Array, /) -> EighResult: return EighResult(*map(Array._new, np.linalg.eigh(x._array))) -# Note: the keyword argument name upper is different from np.linalg.eigvalsh def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`. @@ -346,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -366,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trace') # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') return tensordot(x1, x2, axes=((axis,), (axis,))) @@ -380,7 +384,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. -def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: +def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`. |
