summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/array_api/_sorting_functions.py17
-rw-r--r--numpy/array_api/tests/test_sorting_functions.py23
-rw-r--r--numpy/core/_asarray.pyi8
-rw-r--r--numpy/core/multiarray.pyi101
-rw-r--r--numpy/core/numeric.pyi21
-rw-r--r--numpy/f2py/symbolic.py2
-rw-r--r--numpy/f2py/tests/test_symbolic.py2
-rw-r--r--numpy/lib/npyio.pyi20
-rw-r--r--numpy/lib/twodim_base.pyi13
-rw-r--r--numpy/typing/__init__.py1
-rw-r--r--numpy/typing/_array_like.py14
-rw-r--r--numpy/typing/tests/data/fail/array_constructors.pyi2
-rw-r--r--numpy/typing/tests/data/reveal/array_constructors.pyi1
13 files changed, 143 insertions, 82 deletions
diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py
index 9cd49786c..b2a11872f 100644
--- a/numpy/array_api/_sorting_functions.py
+++ b/numpy/array_api/_sorting_functions.py
@@ -15,9 +15,20 @@ def argsort(
"""
# Note: this keyword argument is different, and the default is different.
kind = "stable" if stable else "quicksort"
- res = np.argsort(x._array, axis=axis, kind=kind)
- if descending:
- res = np.flip(res, axis=axis)
+ if not descending:
+ res = np.argsort(x._array, axis=axis, kind=kind)
+ else:
+ # As NumPy has no native descending sort, we imitate it here. Note that
+ # simply flipping the results of np.argsort(x._array, ...) would not
+ # respect the relative order like it would in native descending sorts.
+ res = np.flip(
+ np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind),
+ axis=axis,
+ )
+ # Rely on flip()/argsort() to validate axis
+ normalised_axis = axis if axis >= 0 else x.ndim + axis
+ max_i = x.shape[normalised_axis] - 1
+ res = max_i - res
return Array._new(res)
diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py
new file mode 100644
index 000000000..9848bbfeb
--- /dev/null
+++ b/numpy/array_api/tests/test_sorting_functions.py
@@ -0,0 +1,23 @@
+import pytest
+
+from numpy import array_api as xp
+
+
+@pytest.mark.parametrize(
+ "obj, axis, expected",
+ [
+ ([0, 0], -1, [0, 1]),
+ ([0, 1, 0], -1, [1, 0, 2]),
+ ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
+ ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
+ ],
+)
+def test_stable_desc_argsort(obj, axis, expected):
+ """
+ Indices respect relative order of a descending stable-sort
+
+ See https://github.com/numpy/numpy/issues/20778
+ """
+ x = xp.asarray(obj)
+ out = xp.argsort(x, axis=axis, stable=True, descending=True)
+ assert xp.all(out == xp.asarray(expected))
diff --git a/numpy/core/_asarray.pyi b/numpy/core/_asarray.pyi
index 0da2de912..51b794130 100644
--- a/numpy/core/_asarray.pyi
+++ b/numpy/core/_asarray.pyi
@@ -2,7 +2,7 @@ from collections.abc import Iterable
from typing import TypeVar, Union, overload, Literal
from numpy import ndarray
-from numpy.typing import ArrayLike, DTypeLike
+from numpy.typing import DTypeLike, _SupportsArrayFunc
_ArrayType = TypeVar("_ArrayType", bound=ndarray)
@@ -22,7 +22,7 @@ def require(
dtype: None = ...,
requirements: None | _Requirements | Iterable[_Requirements] = ...,
*,
- like: ArrayLike = ...
+ like: _SupportsArrayFunc = ...
) -> _ArrayType: ...
@overload
def require(
@@ -30,7 +30,7 @@ def require(
dtype: DTypeLike = ...,
requirements: _E | Iterable[_RequirementsWithE] = ...,
*,
- like: ArrayLike = ...
+ like: _SupportsArrayFunc = ...
) -> ndarray: ...
@overload
def require(
@@ -38,5 +38,5 @@ def require(
dtype: DTypeLike = ...,
requirements: None | _Requirements | Iterable[_Requirements] = ...,
*,
- like: ArrayLike = ...
+ like: _SupportsArrayFunc = ...
) -> ndarray: ...
diff --git a/numpy/core/multiarray.pyi b/numpy/core/multiarray.pyi
index 423aed85e..f2d3622d2 100644
--- a/numpy/core/multiarray.pyi
+++ b/numpy/core/multiarray.pyi
@@ -61,6 +61,7 @@ from numpy.typing import (
NDArray,
ArrayLike,
_SupportsArray,
+ _SupportsArrayFunc,
_NestedSequence,
_FiniteNestedSequence,
_ArrayLikeBool_co,
@@ -177,7 +178,7 @@ def array(
order: _OrderKACF = ...,
subok: L[True],
ndmin: int = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> _ArrayType: ...
@overload
def array(
@@ -188,7 +189,7 @@ def array(
order: _OrderKACF = ...,
subok: bool = ...,
ndmin: int = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def array(
@@ -199,7 +200,7 @@ def array(
order: _OrderKACF = ...,
subok: bool = ...,
ndmin: int = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
def array(
@@ -210,7 +211,7 @@ def array(
order: _OrderKACF = ...,
subok: bool = ...,
ndmin: int = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def array(
@@ -221,7 +222,7 @@ def array(
order: _OrderKACF = ...,
subok: bool = ...,
ndmin: int = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -230,7 +231,7 @@ def zeros(
dtype: None = ...,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def zeros(
@@ -238,7 +239,7 @@ def zeros(
dtype: _DTypeLike[_SCT],
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def zeros(
@@ -246,7 +247,7 @@ def zeros(
dtype: DTypeLike,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -255,7 +256,7 @@ def empty(
dtype: None = ...,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def empty(
@@ -263,7 +264,7 @@ def empty(
dtype: _DTypeLike[_SCT],
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def empty(
@@ -271,7 +272,7 @@ def empty(
dtype: DTypeLike,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -468,7 +469,7 @@ def asarray(
dtype: None = ...,
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def asarray(
@@ -476,7 +477,7 @@ def asarray(
dtype: None = ...,
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
def asarray(
@@ -484,7 +485,7 @@ def asarray(
dtype: _DTypeLike[_SCT],
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def asarray(
@@ -492,7 +493,7 @@ def asarray(
dtype: DTypeLike,
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -501,7 +502,7 @@ def asanyarray(
dtype: None = ...,
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> _ArrayType: ...
@overload
def asanyarray(
@@ -509,7 +510,7 @@ def asanyarray(
dtype: None = ...,
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def asanyarray(
@@ -517,7 +518,7 @@ def asanyarray(
dtype: None = ...,
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
def asanyarray(
@@ -525,7 +526,7 @@ def asanyarray(
dtype: _DTypeLike[_SCT],
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def asanyarray(
@@ -533,7 +534,7 @@ def asanyarray(
dtype: DTypeLike,
order: _OrderKACF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -541,28 +542,28 @@ def ascontiguousarray(
a: _ArrayLike[_SCT],
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def ascontiguousarray(
a: object,
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
def ascontiguousarray(
a: Any,
dtype: _DTypeLike[_SCT],
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def ascontiguousarray(
a: Any,
dtype: DTypeLike,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -570,28 +571,28 @@ def asfortranarray(
a: _ArrayLike[_SCT],
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def asfortranarray(
a: object,
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
def asfortranarray(
a: Any,
dtype: _DTypeLike[_SCT],
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def asfortranarray(
a: Any,
dtype: DTypeLike,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
# In practice `list[Any]` is list with an int, int and a valid
@@ -609,7 +610,7 @@ def fromstring(
count: SupportsIndex = ...,
*,
sep: str,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def fromstring(
@@ -618,7 +619,7 @@ def fromstring(
count: SupportsIndex = ...,
*,
sep: str,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def fromstring(
@@ -627,7 +628,7 @@ def fromstring(
count: SupportsIndex = ...,
*,
sep: str,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
def frompyfunc(
@@ -646,7 +647,7 @@ def fromfile(
sep: str = ...,
offset: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def fromfile(
@@ -656,7 +657,7 @@ def fromfile(
sep: str = ...,
offset: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def fromfile(
@@ -666,7 +667,7 @@ def fromfile(
sep: str = ...,
offset: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -675,7 +676,7 @@ def fromiter(
dtype: _DTypeLike[_SCT],
count: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def fromiter(
@@ -683,7 +684,7 @@ def fromiter(
dtype: DTypeLike,
count: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -693,7 +694,7 @@ def frombuffer(
count: SupportsIndex = ...,
offset: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def frombuffer(
@@ -702,7 +703,7 @@ def frombuffer(
count: SupportsIndex = ...,
offset: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def frombuffer(
@@ -711,7 +712,7 @@ def frombuffer(
count: SupportsIndex = ...,
offset: SupportsIndex = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -719,7 +720,7 @@ def arange( # type: ignore[misc]
stop: _IntLike_co,
/, *,
dtype: None = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[signedinteger[Any]]: ...
@overload
def arange( # type: ignore[misc]
@@ -728,14 +729,14 @@ def arange( # type: ignore[misc]
step: _IntLike_co = ...,
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[signedinteger[Any]]: ...
@overload
def arange( # type: ignore[misc]
stop: _FloatLike_co,
/, *,
dtype: None = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[floating[Any]]: ...
@overload
def arange( # type: ignore[misc]
@@ -744,14 +745,14 @@ def arange( # type: ignore[misc]
step: _FloatLike_co = ...,
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[floating[Any]]: ...
@overload
def arange(
stop: _TD64Like_co,
/, *,
dtype: None = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[timedelta64]: ...
@overload
def arange(
@@ -760,7 +761,7 @@ def arange(
step: _TD64Like_co = ...,
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[timedelta64]: ...
@overload
def arange( # both start and stop must always be specified for datetime64
@@ -769,14 +770,14 @@ def arange( # both start and stop must always be specified for datetime64
step: datetime64 = ...,
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[datetime64]: ...
@overload
def arange(
stop: Any,
/, *,
dtype: _DTypeLike[_SCT],
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def arange(
@@ -785,14 +786,14 @@ def arange(
step: Any = ...,
dtype: _DTypeLike[_SCT] = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def arange(
stop: Any, /,
*,
dtype: DTypeLike,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
def arange(
@@ -801,7 +802,7 @@ def arange(
step: Any = ...,
dtype: DTypeLike = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
def datetime_data(
diff --git a/numpy/core/numeric.pyi b/numpy/core/numeric.pyi
index 8b92abab4..d5e28d24c 100644
--- a/numpy/core/numeric.pyi
+++ b/numpy/core/numeric.pyi
@@ -37,6 +37,7 @@ from numpy.typing import (
_SupportsDType,
_FiniteNestedSequence,
_SupportsArray,
+ _SupportsArrayFunc,
_ScalarLike_co,
_ArrayLikeBool_co,
_ArrayLikeUInt_co,
@@ -108,7 +109,7 @@ def ones(
dtype: None = ...,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def ones(
@@ -116,7 +117,7 @@ def ones(
dtype: _DTypeLike[_SCT],
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def ones(
@@ -124,7 +125,7 @@ def ones(
dtype: DTypeLike,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -175,7 +176,7 @@ def full(
dtype: None = ...,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
def full(
@@ -184,7 +185,7 @@ def full(
dtype: _DTypeLike[_SCT],
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def full(
@@ -193,7 +194,7 @@ def full(
dtype: DTypeLike,
order: _OrderCF = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -563,7 +564,7 @@ def fromfunction(
shape: Sequence[int],
*,
dtype: DTypeLike = ...,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
**kwargs: Any,
) -> _T: ...
@@ -584,21 +585,21 @@ def identity(
n: int,
dtype: None = ...,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def identity(
n: int,
dtype: _DTypeLike[_SCT],
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def identity(
n: int,
dtype: DTypeLike,
*,
- like: ArrayLike = ...,
+ like: _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
def allclose(
diff --git a/numpy/f2py/symbolic.py b/numpy/f2py/symbolic.py
index 1b7b35458..c2ab0f140 100644
--- a/numpy/f2py/symbolic.py
+++ b/numpy/f2py/symbolic.py
@@ -381,7 +381,7 @@ class Expr:
language=language)
for a in self.data]
if language is Language.C:
- r = f'({cond} ? {expr1} : {expr2})'
+ r = f'({cond}?{expr1}:{expr2})'
elif language is Language.Python:
r = f'({expr1} if {cond} else {expr2})'
elif language is Language.Fortran:
diff --git a/numpy/f2py/tests/test_symbolic.py b/numpy/f2py/tests/test_symbolic.py
index e8dec72f0..845278311 100644
--- a/numpy/f2py/tests/test_symbolic.py
+++ b/numpy/f2py/tests/test_symbolic.py
@@ -201,7 +201,7 @@ class TestSymbolic(util.F2PyTest):
assert (x + (x - y) / (x + y) +
n).tostring(language=language) == "123 + x + (x - y) / (x + y)"
- assert as_ternary(x, y, z).tostring(language=language) == "(x ? y : z)"
+ assert as_ternary(x, y, z).tostring(language=language) == "(x?y:z)"
assert as_eq(x, y).tostring(language=language) == "x == y"
assert as_ne(x, y).tostring(language=language) == "x != y"
assert as_lt(x, y).tostring(language=language) == "x < y"
diff --git a/numpy/lib/npyio.pyi b/numpy/lib/npyio.pyi
index 60684c846..4f11fa807 100644
--- a/numpy/lib/npyio.pyi
+++ b/numpy/lib/npyio.pyi
@@ -27,7 +27,13 @@ from numpy import (
)
from numpy.ma.mrecords import MaskedRecords
-from numpy.typing import ArrayLike, DTypeLike, NDArray, _SupportsDType
+from numpy.typing import (
+ ArrayLike,
+ DTypeLike,
+ NDArray,
+ _SupportsDType,
+ _SupportsArrayFunc,
+)
from numpy.core.multiarray import (
packbits as packbits,
@@ -144,7 +150,7 @@ def loadtxt(
encoding: None | str = ...,
max_rows: None | int = ...,
*,
- like: None | ArrayLike = ...
+ like: None | _SupportsArrayFunc = ...
) -> NDArray[float64]: ...
@overload
def loadtxt(
@@ -160,7 +166,7 @@ def loadtxt(
encoding: None | str = ...,
max_rows: None | int = ...,
*,
- like: None | ArrayLike = ...
+ like: None | _SupportsArrayFunc = ...
) -> NDArray[_SCT]: ...
@overload
def loadtxt(
@@ -176,7 +182,7 @@ def loadtxt(
encoding: None | str = ...,
max_rows: None | int = ...,
*,
- like: None | ArrayLike = ...
+ like: None | _SupportsArrayFunc = ...
) -> NDArray[Any]: ...
def savetxt(
@@ -233,7 +239,7 @@ def genfromtxt(
encoding: str = ...,
*,
ndmin: L[0, 1, 2] = ...,
- like: None | ArrayLike = ...,
+ like: None | _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def genfromtxt(
@@ -262,7 +268,7 @@ def genfromtxt(
encoding: str = ...,
*,
ndmin: L[0, 1, 2] = ...,
- like: None | ArrayLike = ...,
+ like: None | _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def genfromtxt(
@@ -291,7 +297,7 @@ def genfromtxt(
encoding: str = ...,
*,
ndmin: L[0, 1, 2] = ...,
- like: None | ArrayLike = ...,
+ like: None | _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
diff --git a/numpy/lib/twodim_base.pyi b/numpy/lib/twodim_base.pyi
index 76d7e5a9d..491862408 100644
--- a/numpy/lib/twodim_base.pyi
+++ b/numpy/lib/twodim_base.pyi
@@ -31,6 +31,7 @@ from numpy.typing import (
NDArray,
_FiniteNestedSequence,
_SupportsArray,
+ _SupportsArrayFunc,
_ArrayLikeInt_co,
_ArrayLikeFloat_co,
_ArrayLikeComplex_co,
@@ -73,7 +74,7 @@ def eye(
dtype: None = ...,
order: _OrderCF = ...,
*,
- like: None | ArrayLike = ...,
+ like: None | _SupportsArrayFunc = ...,
) -> NDArray[float64]: ...
@overload
def eye(
@@ -83,7 +84,7 @@ def eye(
dtype: _DTypeLike[_SCT] = ...,
order: _OrderCF = ...,
*,
- like: None | ArrayLike = ...,
+ like: None | _SupportsArrayFunc = ...,
) -> NDArray[_SCT]: ...
@overload
def eye(
@@ -93,7 +94,7 @@ def eye(
dtype: DTypeLike = ...,
order: _OrderCF = ...,
*,
- like: None | ArrayLike = ...,
+ like: None | _SupportsArrayFunc = ...,
) -> NDArray[Any]: ...
@overload
@@ -113,7 +114,7 @@ def tri(
k: int = ...,
dtype: None = ...,
*,
- like: None | ArrayLike = ...
+ like: None | _SupportsArrayFunc = ...
) -> NDArray[float64]: ...
@overload
def tri(
@@ -122,7 +123,7 @@ def tri(
k: int = ...,
dtype: _DTypeLike[_SCT] = ...,
*,
- like: None | ArrayLike = ...
+ like: None | _SupportsArrayFunc = ...
) -> NDArray[_SCT]: ...
@overload
def tri(
@@ -131,7 +132,7 @@ def tri(
k: int = ...,
dtype: DTypeLike = ...,
*,
- like: None | ArrayLike = ...
+ like: None | _SupportsArrayFunc = ...
) -> NDArray[Any]: ...
@overload
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index d5cfbf5ac..72ac750ae 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -343,6 +343,7 @@ from ._array_like import (
_ArrayLike,
_FiniteNestedSequence,
_SupportsArray,
+ _SupportsArrayFunc,
_ArrayLikeInt,
_ArrayLikeBool_co,
_ArrayLikeUInt_co,
diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py
index 02e5ee573..bba545b7b 100644
--- a/numpy/typing/_array_like.py
+++ b/numpy/typing/_array_like.py
@@ -1,5 +1,8 @@
from __future__ import annotations
+# NOTE: Import `Sequence` from `typing` as we it is needed for a type-alias,
+# not an annotation
+from collections.abc import Collection, Callable
from typing import Any, Sequence, Protocol, Union, TypeVar
from numpy import (
ndarray,
@@ -34,6 +37,17 @@ class _SupportsArray(Protocol[_DType_co]):
def __array__(self) -> ndarray[Any, _DType_co]: ...
+class _SupportsArrayFunc(Protocol):
+ """A protocol class representing `~class.__array_function__`."""
+ def __array_function__(
+ self,
+ func: Callable[..., Any],
+ types: Collection[type[Any]],
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+ ) -> object: ...
+
+
# TODO: Wait until mypy supports recursive objects in combination with typevars
_FiniteNestedSequence = Union[
_T,
diff --git a/numpy/typing/tests/data/fail/array_constructors.pyi b/numpy/typing/tests/data/fail/array_constructors.pyi
index 065b7d8a0..278894631 100644
--- a/numpy/typing/tests/data/fail/array_constructors.pyi
+++ b/numpy/typing/tests/data/fail/array_constructors.pyi
@@ -29,3 +29,5 @@ np.geomspace(None, 'bob') # E: No overload variant
np.stack(generator) # E: No overload variant
np.hstack({1, 2}) # E: No overload variant
np.vstack(1) # E: No overload variant
+
+np.array([1], like=1) # E: No overload variant
diff --git a/numpy/typing/tests/data/reveal/array_constructors.pyi b/numpy/typing/tests/data/reveal/array_constructors.pyi
index c3b0c3457..448ed7e8b 100644
--- a/numpy/typing/tests/data/reveal/array_constructors.pyi
+++ b/numpy/typing/tests/data/reveal/array_constructors.pyi
@@ -28,6 +28,7 @@ reveal_type(np.array(B, subok=True)) # E: SubClass[{float64}]
reveal_type(np.array([1, 1.0])) # E: ndarray[Any, dtype[Any]]
reveal_type(np.array(A, dtype=np.int64)) # E: ndarray[Any, dtype[{int64}]]
reveal_type(np.array(A, dtype='c16')) # E: ndarray[Any, dtype[Any]]
+reveal_type(np.array(A, like=A)) # E: ndarray[Any, dtype[{float64}]]
reveal_type(np.zeros([1, 5, 6])) # E: ndarray[Any, dtype[{float64}]]
reveal_type(np.zeros([1, 5, 6], dtype=np.int64)) # E: ndarray[Any, dtype[{int64}]]