From de0521fc22e641be5e819a2fec785c6f89ebca8c Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Fri, 25 Feb 2022 18:05:39 +0100 Subject: MAINT: Let `ndarray.__imatmul__` handle inplace matrix multiplication in the array-api --- numpy/array_api/_array_object.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c4746fad9..592ca09df 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -850,23 +850,13 @@ class Array: """ Performs the operation __imatmul__. """ - # Note: NumPy does not implement __imatmul__. - # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other - - # __imatmul__ can only be allowed when it would not change the shape - # of self. - other_shape = other.shape - if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") - if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: - raise ValueError("@= cannot change the shape of the input array") - self._array[:] = self._array.__matmul__(other._array) - return self + res = self._array.__imatmul__(other._array) + return self.__class__._new(res) def __rmatmul__(self: Array, other: Array, /) -> Array: """ -- cgit v1.2.1 From a6740500576475a43a7121087e9f5f96d48ef1d2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Dec 2022 17:48:32 -0700 Subject: DOC: Some updates to the array_api compat document (#22747) * Add reshape differences to the array API compat document * Add an item to the array API compat document about reverse broadcasting * Make some wording easier to read --- numpy/array_api/_manipulation_functions.py | 1 + 1 file changed, 1 insertion(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index 4f2114ff5..7991f46a2 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -52,6 +52,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: return Array._new(np.transpose(x._array, axes)) +# Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. -- cgit v1.2.1 From 8816c76631af79c46a8cf33ac1a8a79b2717c9ac Mon Sep 17 00:00:00 2001 From: Francesc Elies Date: Tue, 21 Feb 2023 14:50:04 +0100 Subject: TYP,MAINT: Add a missing explicit Any parameter --- numpy/array_api/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c4746fad9..eee117be6 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -56,7 +56,7 @@ class Array: functions, such as asarray(). """ - _array: np.ndarray + _array: np.ndarray[Any, Any] # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. -- cgit v1.2.1 From f07d55b27671a4575e3b9b2fc7ca9ec897d4db9e Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Sun, 26 Feb 2023 07:56:47 +0000 Subject: add support for xp.take --- numpy/array_api/__init__.py | 4 ++++ numpy/array_api/_indexing_functions.py | 18 ++++++++++++++++++ numpy/array_api/tests/test_indexing_functions.py | 24 ++++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 numpy/array_api/_indexing_functions.py create mode 100644 numpy/array_api/tests/test_indexing_functions.py (limited to 'numpy/array_api') diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 5e58ee0a8..e154b9952 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -333,6 +333,10 @@ __all__ += [ "trunc", ] +from ._indexing_functions import take + +__all__ += ["take"] + # linalg is an extension in the array API spec, which is a sub-namespace. Only # a subset of functions in it are imported into the top-level namespace. from . import linalg diff --git a/numpy/array_api/_indexing_functions.py b/numpy/array_api/_indexing_functions.py new file mode 100644 index 000000000..ba56bcd6f --- /dev/null +++ b/numpy/array_api/_indexing_functions.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ._array_object import Array +from ._dtypes import _integer_dtypes + +import numpy as np + +def take(x: Array, indices: Array, /, *, axis: int) -> Array: + """ + Array API compatible wrapper for :py:func:`np.take `. + + See its docstring for more information. + """ + if indices.dtype not in _integer_dtypes: + raise TypeError("Only integer dtypes are allowed in indexing") + if indices.ndim != 1: + raise ValueError("Only 1-dim indices array is supported") + return Array._new(np.take(x._array, indices._array, axis=axis)) diff --git a/numpy/array_api/tests/test_indexing_functions.py b/numpy/array_api/tests/test_indexing_functions.py new file mode 100644 index 000000000..26667e32f --- /dev/null +++ b/numpy/array_api/tests/test_indexing_functions.py @@ -0,0 +1,24 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "x, indices, axis, expected", + [ + ([2, 3], [1, 1, 0], 0, [3, 3, 2]), + ([2, 3], [1, 1, 0], -1, [3, 3, 2]), + ([[2, 3]], [1], -1, [[3]]), + ([[2, 3]], [0, 0], 0, [[2, 3], [2, 3]]), + ], +) +def test_stable_desc_argsort(x, indices, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(x) + indices = xp.asarray(indices) + out = xp.take(x, indices, axis=axis) + assert xp.all(out == xp.asarray(expected)) -- cgit v1.2.1 From 786bd366b22675b1e4067653d4729031c258f35f Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Sun, 26 Feb 2023 08:19:45 +0000 Subject: rename test function --- numpy/array_api/tests/test_indexing_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/tests/test_indexing_functions.py b/numpy/array_api/tests/test_indexing_functions.py index 26667e32f..9e05c6386 100644 --- a/numpy/array_api/tests/test_indexing_functions.py +++ b/numpy/array_api/tests/test_indexing_functions.py @@ -12,7 +12,7 @@ from numpy import array_api as xp ([[2, 3]], [0, 0], 0, [[2, 3], [2, 3]]), ], ) -def test_stable_desc_argsort(x, indices, axis, expected): +def test_take_function(x, indices, axis, expected): """ Indices respect relative order of a descending stable-sort -- cgit v1.2.1 From e570c6f5ff6f5aa966193d51560d3cde30fc09bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Tue, 14 Mar 2023 11:24:19 +0100 Subject: MAINT: cleanup unused Python3.8-only code and references --- numpy/array_api/_typing.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index dfa87b358..3f9b7186a 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -17,14 +17,12 @@ __all__ = [ "PyCapsule", ] -import sys from typing import ( Any, Literal, Sequence, Type, Union, - TYPE_CHECKING, TypeVar, Protocol, ) @@ -51,21 +49,20 @@ class NestedSequence(Protocol[_T_co]): def __len__(self, /) -> int: ... Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] -else: - Dtype = dtype + +Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, +]] + SupportsBufferProtocol = Any PyCapsule = Any -- cgit v1.2.1