summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/__init__.pyi12
-rw-r--r--numpy/core/numerictypes.py17
-rw-r--r--numpy/core/numerictypes.pyi127
-rw-r--r--numpy/typing/tests/data/pass/numerictypes.py18
-rw-r--r--numpy/typing/tests/data/reveal/numerictypes.py18
5 files changed, 167 insertions, 25 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 2fa5b3d41..08050a524 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -327,6 +327,10 @@ from numpy.core.numerictypes import (
issubdtype as issubdtype,
sctype2char as sctype2char,
find_common_type as find_common_type,
+ nbytes as nbytes,
+ cast as cast,
+ ScalarType as ScalarType,
+ typecodes as typecodes,
)
from numpy.core.shape_base import (
@@ -504,14 +508,6 @@ class vectorize:
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
def __getattr__(self, key: str) -> Any: ...
-# Placeholders for miscellaneous objects
-# NOTE: `cast` and `nbytes` are in fact instances of a `dict` subclass that
-# converts passed `DTypeLike` objects into the actual keys (`np.generic`)
-ScalarType: Tuple[Type[Any], ...]
-cast: Dict[DTypeLike, Callable[..., ndarray[Any, dtype[Any]]]]
-nbytes: Dict[DTypeLike, int]
-typecodes: Dict[str, str]
-
# Placeholders for Python-based functions
def angle(z, deg=...): ...
def append(arr, values, axis=...): ...
diff --git a/numpy/core/numerictypes.py b/numpy/core/numerictypes.py
index 93af5c95d..12f424fd4 100644
--- a/numpy/core/numerictypes.py
+++ b/numpy/core/numerictypes.py
@@ -79,7 +79,6 @@ Exported symbols include:
\\-> object_ (not used much) (kind=O)
"""
-import types as _types
import numbers
import warnings
@@ -512,15 +511,15 @@ cast = _typedict()
for key in _concrete_types:
cast[key] = lambda x, k=key: array(x, copy=False).astype(k)
-try:
- ScalarType = [_types.IntType, _types.FloatType, _types.ComplexType,
- _types.LongType, _types.BooleanType,
- _types.StringType, _types.UnicodeType, _types.BufferType]
-except AttributeError:
- # Py3K
- ScalarType = [int, float, complex, int, bool, bytes, str, memoryview]
-ScalarType.extend(_concrete_types)
+def _scalar_type_key(typ):
+ """A ``key`` function for `sorted`."""
+ dt = dtype(typ)
+ return (dt.kind.lower(), dt.itemsize)
+
+
+ScalarType = [int, float, complex, int, bool, bytes, str, memoryview]
+ScalarType += sorted(_concrete_types, key=_scalar_type_key)
ScalarType = tuple(ScalarType)
diff --git a/numpy/core/numerictypes.pyi b/numpy/core/numerictypes.pyi
index 238495fd3..fd4aa3fda 100644
--- a/numpy/core/numerictypes.pyi
+++ b/numpy/core/numerictypes.pyi
@@ -1,9 +1,85 @@
-from typing import TypeVar, Optional, Type, Union, Tuple, Sequence, overload, Any
+import sys
+from typing import (
+ TypeVar,
+ Optional,
+ Type,
+ Union,
+ Tuple,
+ Sequence,
+ overload,
+ Any,
+ TypeVar,
+ Dict,
+ List,
+)
-from numpy import generic, ndarray, dtype
-from numpy.typing import DTypeLike
+from numpy import (
+ ndarray,
+ dtype,
+ generic,
+ bool_,
+ ubyte,
+ ushort,
+ uintc,
+ uint,
+ ulonglong,
+ byte,
+ short,
+ intc,
+ int_,
+ longlong,
+ half,
+ single,
+ double,
+ longdouble,
+ csingle,
+ cdouble,
+ clongdouble,
+ datetime64,
+ timedelta64,
+ object_,
+ str_,
+ bytes_,
+ void,
+)
-_DefaultType = TypeVar("_DefaultType")
+from numpy.core._type_aliases import (
+ sctypeDict as sctypeDict,
+ sctypes as sctypes,
+)
+
+from numpy.typing import DTypeLike, ArrayLike
+
+if sys.version_info >= (3, 8):
+ from typing import Literal, Protocol, TypedDict
+else:
+ from typing_extensions import Literal, Protocol, TypedDict
+
+_T = TypeVar("_T")
+_ScalarType = TypeVar("_ScalarType", bound=generic)
+
+class _CastFunc(Protocol):
+ def __call__(
+ self, x: ArrayLike, k: DTypeLike = ...
+ ) -> ndarray[Any, dtype[Any]]: ...
+
+class _TypeCodes(TypedDict):
+ Character: Literal['c']
+ Integer: Literal['bhilqp']
+ UnsignedInteger: Literal['BHILQP']
+ Float: Literal['efdg']
+ Complex: Literal['FDG']
+ AllInteger: Literal['bBhHiIlLqQpP']
+ AllFloat: Literal['efdgFDG']
+ Datetime: Literal['Mm']
+ All: Literal['?bhilqpBHILQPefdgFDGSUVOMm']
+
+class _typedict(Dict[Type[generic], _T]):
+ def __getitem__(self, key: DTypeLike) -> _T: ...
+
+__all__: List[str]
+
+# TODO: Clean up the annotations for the 7 functions below
def maximum_sctype(t: DTypeLike) -> dtype: ...
def issctype(rep: object) -> bool: ...
@@ -13,8 +89,8 @@ def obj2sctype(rep: object) -> Optional[generic]: ...
def obj2sctype(rep: object, default: None) -> Optional[generic]: ...
@overload
def obj2sctype(
- rep: object, default: Type[_DefaultType]
-) -> Union[generic, Type[_DefaultType]]: ...
+ rep: object, default: Type[_T]
+) -> Union[generic, Type[_T]]: ...
def issubclass_(arg1: object, arg2: Union[object, Tuple[object, ...]]) -> bool: ...
def issubsctype(
arg1: Union[ndarray, DTypeLike], arg2: Union[ndarray, DTypeLike]
@@ -25,5 +101,40 @@ def find_common_type(
array_types: Sequence[DTypeLike], scalar_types: Sequence[DTypeLike]
) -> dtype: ...
-# TODO: Add annotations for the following objects:
-# nbytes, cast, ScalarType & typecodes
+cast: _typedict[_CastFunc]
+nbytes: _typedict[int]
+typecodes: _TypeCodes
+ScalarType: Tuple[
+ Type[int],
+ Type[float],
+ Type[complex],
+ Type[int],
+ Type[bool],
+ Type[bytes],
+ Type[str],
+ Type[memoryview],
+ Type[bool_],
+ Type[csingle],
+ Type[cdouble],
+ Type[clongdouble],
+ Type[half],
+ Type[single],
+ Type[double],
+ Type[longdouble],
+ Type[byte],
+ Type[short],
+ Type[intc],
+ Type[int_],
+ Type[longlong],
+ Type[timedelta64],
+ Type[datetime64],
+ Type[object_],
+ Type[bytes_],
+ Type[str_],
+ Type[ubyte],
+ Type[ushort],
+ Type[uintc],
+ Type[uint],
+ Type[ulonglong],
+ Type[void],
+]
diff --git a/numpy/typing/tests/data/pass/numerictypes.py b/numpy/typing/tests/data/pass/numerictypes.py
index 4f205cabc..5af0d171c 100644
--- a/numpy/typing/tests/data/pass/numerictypes.py
+++ b/numpy/typing/tests/data/pass/numerictypes.py
@@ -27,3 +27,21 @@ np.find_common_type([], [np.int64, np.float32, complex])
np.find_common_type((), (np.int64, np.float32, complex))
np.find_common_type([np.int64, np.float32], [])
np.find_common_type([np.float32], [np.int64, np.float64])
+
+np.cast[int]
+np.cast["i8"]
+np.cast[np.int64]
+
+np.nbytes[int]
+np.nbytes["i8"]
+np.nbytes[np.int64]
+
+np.ScalarType
+np.ScalarType[0]
+np.ScalarType[4]
+np.ScalarType[9]
+np.ScalarType[11]
+
+np.typecodes["Character"]
+np.typecodes["Complex"]
+np.typecodes["All"]
diff --git a/numpy/typing/tests/data/reveal/numerictypes.py b/numpy/typing/tests/data/reveal/numerictypes.py
index e026158cd..0f886b3fb 100644
--- a/numpy/typing/tests/data/reveal/numerictypes.py
+++ b/numpy/typing/tests/data/reveal/numerictypes.py
@@ -16,3 +16,21 @@ reveal_type(np.sctype2char("S8")) # E: str
reveal_type(np.sctype2char(list)) # E: str
reveal_type(np.find_common_type([np.int64], [np.int64])) # E: numpy.dtype
+
+reveal_type(np.cast[int]) # E: _CastFunc
+reveal_type(np.cast["i8"]) # E: _CastFunc
+reveal_type(np.cast[np.int64]) # E: _CastFunc
+
+reveal_type(np.nbytes[int]) # E: int
+reveal_type(np.nbytes["i8"]) # E: int
+reveal_type(np.nbytes[np.int64]) # E: int
+
+reveal_type(np.ScalarType) # E: Tuple
+reveal_type(np.ScalarType[0]) # E: Type[builtins.int]
+reveal_type(np.ScalarType[4]) # E: Type[builtins.bool]
+reveal_type(np.ScalarType[9]) # E: Type[{csingle}]
+reveal_type(np.ScalarType[11]) # E: Type[{clongdouble}]
+
+reveal_type(np.typecodes["Character"]) # E: Literal['c']
+reveal_type(np.typecodes["Complex"]) # E: Literal['FDG']
+reveal_type(np.typecodes["All"]) # E: Literal['?bhilqpBHILQPefdgFDGSUVOMm']