diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/array_api/_sorting_functions.py | 17 | ||||
-rw-r--r-- | numpy/array_api/tests/test_sorting_functions.py | 23 | ||||
-rw-r--r-- | numpy/core/_asarray.pyi | 8 | ||||
-rw-r--r-- | numpy/core/multiarray.pyi | 101 | ||||
-rw-r--r-- | numpy/core/numeric.pyi | 21 | ||||
-rw-r--r-- | numpy/f2py/symbolic.py | 2 | ||||
-rw-r--r-- | numpy/f2py/tests/test_symbolic.py | 2 | ||||
-rw-r--r-- | numpy/lib/npyio.pyi | 20 | ||||
-rw-r--r-- | numpy/lib/twodim_base.pyi | 13 | ||||
-rw-r--r-- | numpy/typing/__init__.py | 1 | ||||
-rw-r--r-- | numpy/typing/_array_like.py | 14 | ||||
-rw-r--r-- | numpy/typing/tests/data/fail/array_constructors.pyi | 2 | ||||
-rw-r--r-- | numpy/typing/tests/data/reveal/array_constructors.pyi | 1 |
13 files changed, 143 insertions, 82 deletions
diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index 9cd49786c..b2a11872f 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -15,9 +15,20 @@ def argsort( """ # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" - res = np.argsort(x._array, axis=axis, kind=kind) - if descending: - res = np.flip(res, axis=axis) + if not descending: + res = np.argsort(x._array, axis=axis, kind=kind) + else: + # As NumPy has no native descending sort, we imitate it here. Note that + # simply flipping the results of np.argsort(x._array, ...) would not + # respect the relative order like it would in native descending sorts. + res = np.flip( + np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind), + axis=axis, + ) + # Rely on flip()/argsort() to validate axis + normalised_axis = axis if axis >= 0 else x.ndim + axis + max_i = x.shape[normalised_axis] - 1 + res = max_i - res return Array._new(res) diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py new file mode 100644 index 000000000..9848bbfeb --- /dev/null +++ b/numpy/array_api/tests/test_sorting_functions.py @@ -0,0 +1,23 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "obj, axis, expected", + [ + ([0, 0], -1, [0, 1]), + ([0, 1, 0], -1, [1, 0, 2]), + ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]), + ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]), + ], +) +def test_stable_desc_argsort(obj, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(obj) + out = xp.argsort(x, axis=axis, stable=True, descending=True) + assert xp.all(out == xp.asarray(expected)) diff --git a/numpy/core/_asarray.pyi b/numpy/core/_asarray.pyi index 0da2de912..51b794130 100644 --- a/numpy/core/_asarray.pyi +++ b/numpy/core/_asarray.pyi @@ -2,7 +2,7 @@ from collections.abc import Iterable from typing import TypeVar, Union, overload, Literal from numpy import ndarray -from numpy.typing import ArrayLike, DTypeLike +from numpy.typing import DTypeLike, _SupportsArrayFunc _ArrayType = TypeVar("_ArrayType", bound=ndarray) @@ -22,7 +22,7 @@ def require( dtype: None = ..., requirements: None | _Requirements | Iterable[_Requirements] = ..., *, - like: ArrayLike = ... + like: _SupportsArrayFunc = ... ) -> _ArrayType: ... @overload def require( @@ -30,7 +30,7 @@ def require( dtype: DTypeLike = ..., requirements: _E | Iterable[_RequirementsWithE] = ..., *, - like: ArrayLike = ... + like: _SupportsArrayFunc = ... ) -> ndarray: ... @overload def require( @@ -38,5 +38,5 @@ def require( dtype: DTypeLike = ..., requirements: None | _Requirements | Iterable[_Requirements] = ..., *, - like: ArrayLike = ... + like: _SupportsArrayFunc = ... ) -> ndarray: ... diff --git a/numpy/core/multiarray.pyi b/numpy/core/multiarray.pyi index 423aed85e..f2d3622d2 100644 --- a/numpy/core/multiarray.pyi +++ b/numpy/core/multiarray.pyi @@ -61,6 +61,7 @@ from numpy.typing import ( NDArray, ArrayLike, _SupportsArray, + _SupportsArrayFunc, _NestedSequence, _FiniteNestedSequence, _ArrayLikeBool_co, @@ -177,7 +178,7 @@ def array( order: _OrderKACF = ..., subok: L[True], ndmin: int = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> _ArrayType: ... @overload def array( @@ -188,7 +189,7 @@ def array( order: _OrderKACF = ..., subok: bool = ..., ndmin: int = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def array( @@ -199,7 +200,7 @@ def array( order: _OrderKACF = ..., subok: bool = ..., ndmin: int = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload def array( @@ -210,7 +211,7 @@ def array( order: _OrderKACF = ..., subok: bool = ..., ndmin: int = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def array( @@ -221,7 +222,7 @@ def array( order: _OrderKACF = ..., subok: bool = ..., ndmin: int = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -230,7 +231,7 @@ def zeros( dtype: None = ..., order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def zeros( @@ -238,7 +239,7 @@ def zeros( dtype: _DTypeLike[_SCT], order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def zeros( @@ -246,7 +247,7 @@ def zeros( dtype: DTypeLike, order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -255,7 +256,7 @@ def empty( dtype: None = ..., order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def empty( @@ -263,7 +264,7 @@ def empty( dtype: _DTypeLike[_SCT], order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def empty( @@ -271,7 +272,7 @@ def empty( dtype: DTypeLike, order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -468,7 +469,7 @@ def asarray( dtype: None = ..., order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def asarray( @@ -476,7 +477,7 @@ def asarray( dtype: None = ..., order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload def asarray( @@ -484,7 +485,7 @@ def asarray( dtype: _DTypeLike[_SCT], order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def asarray( @@ -492,7 +493,7 @@ def asarray( dtype: DTypeLike, order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -501,7 +502,7 @@ def asanyarray( dtype: None = ..., order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> _ArrayType: ... @overload def asanyarray( @@ -509,7 +510,7 @@ def asanyarray( dtype: None = ..., order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def asanyarray( @@ -517,7 +518,7 @@ def asanyarray( dtype: None = ..., order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload def asanyarray( @@ -525,7 +526,7 @@ def asanyarray( dtype: _DTypeLike[_SCT], order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def asanyarray( @@ -533,7 +534,7 @@ def asanyarray( dtype: DTypeLike, order: _OrderKACF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -541,28 +542,28 @@ def ascontiguousarray( a: _ArrayLike[_SCT], dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def ascontiguousarray( a: object, dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload def ascontiguousarray( a: Any, dtype: _DTypeLike[_SCT], *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def ascontiguousarray( a: Any, dtype: DTypeLike, *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -570,28 +571,28 @@ def asfortranarray( a: _ArrayLike[_SCT], dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def asfortranarray( a: object, dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload def asfortranarray( a: Any, dtype: _DTypeLike[_SCT], *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def asfortranarray( a: Any, dtype: DTypeLike, *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... # In practice `list[Any]` is list with an int, int and a valid @@ -609,7 +610,7 @@ def fromstring( count: SupportsIndex = ..., *, sep: str, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def fromstring( @@ -618,7 +619,7 @@ def fromstring( count: SupportsIndex = ..., *, sep: str, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def fromstring( @@ -627,7 +628,7 @@ def fromstring( count: SupportsIndex = ..., *, sep: str, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... def frompyfunc( @@ -646,7 +647,7 @@ def fromfile( sep: str = ..., offset: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def fromfile( @@ -656,7 +657,7 @@ def fromfile( sep: str = ..., offset: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def fromfile( @@ -666,7 +667,7 @@ def fromfile( sep: str = ..., offset: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -675,7 +676,7 @@ def fromiter( dtype: _DTypeLike[_SCT], count: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def fromiter( @@ -683,7 +684,7 @@ def fromiter( dtype: DTypeLike, count: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -693,7 +694,7 @@ def frombuffer( count: SupportsIndex = ..., offset: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def frombuffer( @@ -702,7 +703,7 @@ def frombuffer( count: SupportsIndex = ..., offset: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def frombuffer( @@ -711,7 +712,7 @@ def frombuffer( count: SupportsIndex = ..., offset: SupportsIndex = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -719,7 +720,7 @@ def arange( # type: ignore[misc] stop: _IntLike_co, /, *, dtype: None = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[signedinteger[Any]]: ... @overload def arange( # type: ignore[misc] @@ -728,14 +729,14 @@ def arange( # type: ignore[misc] step: _IntLike_co = ..., dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[signedinteger[Any]]: ... @overload def arange( # type: ignore[misc] stop: _FloatLike_co, /, *, dtype: None = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[floating[Any]]: ... @overload def arange( # type: ignore[misc] @@ -744,14 +745,14 @@ def arange( # type: ignore[misc] step: _FloatLike_co = ..., dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[floating[Any]]: ... @overload def arange( stop: _TD64Like_co, /, *, dtype: None = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[timedelta64]: ... @overload def arange( @@ -760,7 +761,7 @@ def arange( step: _TD64Like_co = ..., dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[timedelta64]: ... @overload def arange( # both start and stop must always be specified for datetime64 @@ -769,14 +770,14 @@ def arange( # both start and stop must always be specified for datetime64 step: datetime64 = ..., dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[datetime64]: ... @overload def arange( stop: Any, /, *, dtype: _DTypeLike[_SCT], - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def arange( @@ -785,14 +786,14 @@ def arange( step: Any = ..., dtype: _DTypeLike[_SCT] = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def arange( stop: Any, /, *, dtype: DTypeLike, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload def arange( @@ -801,7 +802,7 @@ def arange( step: Any = ..., dtype: DTypeLike = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... def datetime_data( diff --git a/numpy/core/numeric.pyi b/numpy/core/numeric.pyi index 8b92abab4..d5e28d24c 100644 --- a/numpy/core/numeric.pyi +++ b/numpy/core/numeric.pyi @@ -37,6 +37,7 @@ from numpy.typing import ( _SupportsDType, _FiniteNestedSequence, _SupportsArray, + _SupportsArrayFunc, _ScalarLike_co, _ArrayLikeBool_co, _ArrayLikeUInt_co, @@ -108,7 +109,7 @@ def ones( dtype: None = ..., order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def ones( @@ -116,7 +117,7 @@ def ones( dtype: _DTypeLike[_SCT], order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def ones( @@ -124,7 +125,7 @@ def ones( dtype: DTypeLike, order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -175,7 +176,7 @@ def full( dtype: None = ..., order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload def full( @@ -184,7 +185,7 @@ def full( dtype: _DTypeLike[_SCT], order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def full( @@ -193,7 +194,7 @@ def full( dtype: DTypeLike, order: _OrderCF = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -563,7 +564,7 @@ def fromfunction( shape: Sequence[int], *, dtype: DTypeLike = ..., - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., **kwargs: Any, ) -> _T: ... @@ -584,21 +585,21 @@ def identity( n: int, dtype: None = ..., *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def identity( n: int, dtype: _DTypeLike[_SCT], *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def identity( n: int, dtype: DTypeLike, *, - like: ArrayLike = ..., + like: _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... def allclose( diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py index 1b7b35458..c2ab0f140 100644 --- a/numpy/f2py/symbolic.py +++ b/numpy/f2py/symbolic.py @@ -381,7 +381,7 @@ class Expr: language=language) for a in self.data] if language is Language.C: - r = f'({cond} ? {expr1} : {expr2})' + r = f'({cond}?{expr1}:{expr2})' elif language is Language.Python: r = f'({expr1} if {cond} else {expr2})' elif language is Language.Fortran: diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py index e8dec72f0..845278311 100644 --- a/numpy/f2py/tests/test_symbolic.py +++ b/numpy/f2py/tests/test_symbolic.py @@ -201,7 +201,7 @@ class TestSymbolic(util.F2PyTest): assert (x + (x - y) / (x + y) + n).tostring(language=language) == "123 + x + (x - y) / (x + y)" - assert as_ternary(x, y, z).tostring(language=language) == "(x ? y : z)" + assert as_ternary(x, y, z).tostring(language=language) == "(x?y:z)" assert as_eq(x, y).tostring(language=language) == "x == y" assert as_ne(x, y).tostring(language=language) == "x != y" assert as_lt(x, y).tostring(language=language) == "x < y" diff --git a/numpy/lib/npyio.pyi b/numpy/lib/npyio.pyi index 60684c846..4f11fa807 100644 --- a/numpy/lib/npyio.pyi +++ b/numpy/lib/npyio.pyi @@ -27,7 +27,13 @@ from numpy import ( ) from numpy.ma.mrecords import MaskedRecords -from numpy.typing import ArrayLike, DTypeLike, NDArray, _SupportsDType +from numpy.typing import ( + ArrayLike, + DTypeLike, + NDArray, + _SupportsDType, + _SupportsArrayFunc, +) from numpy.core.multiarray import ( packbits as packbits, @@ -144,7 +150,7 @@ def loadtxt( encoding: None | str = ..., max_rows: None | int = ..., *, - like: None | ArrayLike = ... + like: None | _SupportsArrayFunc = ... ) -> NDArray[float64]: ... @overload def loadtxt( @@ -160,7 +166,7 @@ def loadtxt( encoding: None | str = ..., max_rows: None | int = ..., *, - like: None | ArrayLike = ... + like: None | _SupportsArrayFunc = ... ) -> NDArray[_SCT]: ... @overload def loadtxt( @@ -176,7 +182,7 @@ def loadtxt( encoding: None | str = ..., max_rows: None | int = ..., *, - like: None | ArrayLike = ... + like: None | _SupportsArrayFunc = ... ) -> NDArray[Any]: ... def savetxt( @@ -233,7 +239,7 @@ def genfromtxt( encoding: str = ..., *, ndmin: L[0, 1, 2] = ..., - like: None | ArrayLike = ..., + like: None | _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def genfromtxt( @@ -262,7 +268,7 @@ def genfromtxt( encoding: str = ..., *, ndmin: L[0, 1, 2] = ..., - like: None | ArrayLike = ..., + like: None | _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def genfromtxt( @@ -291,7 +297,7 @@ def genfromtxt( encoding: str = ..., *, ndmin: L[0, 1, 2] = ..., - like: None | ArrayLike = ..., + like: None | _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload diff --git a/numpy/lib/twodim_base.pyi b/numpy/lib/twodim_base.pyi index 76d7e5a9d..491862408 100644 --- a/numpy/lib/twodim_base.pyi +++ b/numpy/lib/twodim_base.pyi @@ -31,6 +31,7 @@ from numpy.typing import ( NDArray, _FiniteNestedSequence, _SupportsArray, + _SupportsArrayFunc, _ArrayLikeInt_co, _ArrayLikeFloat_co, _ArrayLikeComplex_co, @@ -73,7 +74,7 @@ def eye( dtype: None = ..., order: _OrderCF = ..., *, - like: None | ArrayLike = ..., + like: None | _SupportsArrayFunc = ..., ) -> NDArray[float64]: ... @overload def eye( @@ -83,7 +84,7 @@ def eye( dtype: _DTypeLike[_SCT] = ..., order: _OrderCF = ..., *, - like: None | ArrayLike = ..., + like: None | _SupportsArrayFunc = ..., ) -> NDArray[_SCT]: ... @overload def eye( @@ -93,7 +94,7 @@ def eye( dtype: DTypeLike = ..., order: _OrderCF = ..., *, - like: None | ArrayLike = ..., + like: None | _SupportsArrayFunc = ..., ) -> NDArray[Any]: ... @overload @@ -113,7 +114,7 @@ def tri( k: int = ..., dtype: None = ..., *, - like: None | ArrayLike = ... + like: None | _SupportsArrayFunc = ... ) -> NDArray[float64]: ... @overload def tri( @@ -122,7 +123,7 @@ def tri( k: int = ..., dtype: _DTypeLike[_SCT] = ..., *, - like: None | ArrayLike = ... + like: None | _SupportsArrayFunc = ... ) -> NDArray[_SCT]: ... @overload def tri( @@ -131,7 +132,7 @@ def tri( k: int = ..., dtype: DTypeLike = ..., *, - like: None | ArrayLike = ... + like: None | _SupportsArrayFunc = ... ) -> NDArray[Any]: ... @overload diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py index d5cfbf5ac..72ac750ae 100644 --- a/numpy/typing/__init__.py +++ b/numpy/typing/__init__.py @@ -343,6 +343,7 @@ from ._array_like import ( _ArrayLike, _FiniteNestedSequence, _SupportsArray, + _SupportsArrayFunc, _ArrayLikeInt, _ArrayLikeBool_co, _ArrayLikeUInt_co, diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py index 02e5ee573..bba545b7b 100644 --- a/numpy/typing/_array_like.py +++ b/numpy/typing/_array_like.py @@ -1,5 +1,8 @@ from __future__ import annotations +# NOTE: Import `Sequence` from `typing` as we it is needed for a type-alias, +# not an annotation +from collections.abc import Collection, Callable from typing import Any, Sequence, Protocol, Union, TypeVar from numpy import ( ndarray, @@ -34,6 +37,17 @@ class _SupportsArray(Protocol[_DType_co]): def __array__(self) -> ndarray[Any, _DType_co]: ... +class _SupportsArrayFunc(Protocol): + """A protocol class representing `~class.__array_function__`.""" + def __array_function__( + self, + func: Callable[..., Any], + types: Collection[type[Any]], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> object: ... + + # TODO: Wait until mypy supports recursive objects in combination with typevars _FiniteNestedSequence = Union[ _T, diff --git a/numpy/typing/tests/data/fail/array_constructors.pyi b/numpy/typing/tests/data/fail/array_constructors.pyi index 065b7d8a0..278894631 100644 --- a/numpy/typing/tests/data/fail/array_constructors.pyi +++ b/numpy/typing/tests/data/fail/array_constructors.pyi @@ -29,3 +29,5 @@ np.geomspace(None, 'bob') # E: No overload variant np.stack(generator) # E: No overload variant np.hstack({1, 2}) # E: No overload variant np.vstack(1) # E: No overload variant + +np.array([1], like=1) # E: No overload variant diff --git a/numpy/typing/tests/data/reveal/array_constructors.pyi b/numpy/typing/tests/data/reveal/array_constructors.pyi index c3b0c3457..448ed7e8b 100644 --- a/numpy/typing/tests/data/reveal/array_constructors.pyi +++ b/numpy/typing/tests/data/reveal/array_constructors.pyi @@ -28,6 +28,7 @@ reveal_type(np.array(B, subok=True)) # E: SubClass[{float64}] reveal_type(np.array([1, 1.0])) # E: ndarray[Any, dtype[Any]] reveal_type(np.array(A, dtype=np.int64)) # E: ndarray[Any, dtype[{int64}]] reveal_type(np.array(A, dtype='c16')) # E: ndarray[Any, dtype[Any]] +reveal_type(np.array(A, like=A)) # E: ndarray[Any, dtype[{float64}]] reveal_type(np.zeros([1, 5, 6])) # E: ndarray[Any, dtype[{float64}]] reveal_type(np.zeros([1, 5, 6], dtype=np.int64)) # E: ndarray[Any, dtype[{int64}]] |