diff options
Diffstat (limited to 'numpy/array_api')
-rw-r--r-- | numpy/array_api/_creation_functions.py | 4 | ||||
-rw-r--r-- | numpy/array_api/_data_type_functions.py | 4 | ||||
-rw-r--r-- | numpy/array_api/_typing.py | 30 |
3 files changed, 27 insertions, 11 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index 9f8136267..e36807468 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -134,7 +134,7 @@ def eye( n_cols: Optional[int] = None, /, *, - k: Optional[int] = 0, + k: int = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None, ) -> Array: @@ -232,7 +232,7 @@ def linspace( return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) -def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]: +def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: """ Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`. diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index fd92aa250..7ccbe9469 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: import numpy as np -def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`. @@ -98,7 +98,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: return iinfo_object(ii.bits, ii.max, ii.min) -def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: +def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`. diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py index 5f937a56c..519e8463c 100644 --- a/numpy/array_api/_typing.py +++ b/numpy/array_api/_typing.py @@ -15,10 +15,12 @@ __all__ = [ "PyCapsule", ] -from typing import Any, Literal, Sequence, Type, Union +import sys +from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar -from . import ( - Array, +from ._array_object import Array +from numpy import ( + dtype, int8, int16, int32, @@ -33,12 +35,26 @@ from . import ( # This should really be recursive, but that isn't supported yet. See the # similar comment in numpy/typing/_array_like.py -NestedSequence = Sequence[Sequence[Any]] +_T = TypeVar("_T") +NestedSequence = Sequence[Sequence[_T]] Device = Literal["cpu"] -Dtype = Type[ - Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64] -] +if TYPE_CHECKING or sys.version_info >= (3, 9): + Dtype = dtype[Union[ + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + ]] +else: + Dtype = dtype + SupportsDLPack = Any SupportsBufferProtocol = Any PyCapsule = Any |