summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_creation_functions.py4
-rw-r--r--numpy/array_api/_data_type_functions.py4
-rw-r--r--numpy/array_api/_typing.py30
3 files changed, 27 insertions, 11 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index 9f8136267..e36807468 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -134,7 +134,7 @@ def eye(
n_cols: Optional[int] = None,
/,
*,
- k: Optional[int] = 0,
+ k: int = 0,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
) -> Array:
@@ -232,7 +232,7 @@ def linspace(
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
-def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]:
+def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index fd92aa250..7ccbe9469 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
import numpy as np
-def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]:
+def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
@@ -98,7 +98,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
return iinfo_object(ii.bits, ii.max, ii.min)
-def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype:
+def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index 5f937a56c..519e8463c 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -15,10 +15,12 @@ __all__ = [
"PyCapsule",
]
-from typing import Any, Literal, Sequence, Type, Union
+import sys
+from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar
-from . import (
- Array,
+from ._array_object import Array
+from numpy import (
+ dtype,
int8,
int16,
int32,
@@ -33,12 +35,26 @@ from . import (
# This should really be recursive, but that isn't supported yet. See the
# similar comment in numpy/typing/_array_like.py
-NestedSequence = Sequence[Sequence[Any]]
+_T = TypeVar("_T")
+NestedSequence = Sequence[Sequence[_T]]
Device = Literal["cpu"]
-Dtype = Type[
- Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]
-]
+if TYPE_CHECKING or sys.version_info >= (3, 9):
+ Dtype = dtype[Union[
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ ]]
+else:
+ Dtype = dtype
+
SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any