summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2020-10-16 09:43:13 -0600
committerGitHub <noreply@github.com>2020-10-16 09:43:13 -0600
commitb9dd2be0108cb312b4c34239a1dc8d24ef3a05a9 (patch)
treec2acac172a6b6ee298a4ae9165f111a070438480
parentb4718373f5412ea0d52ecbe2a3f9bbed824953a0 (diff)
parentb81ab444c0e56011e96c8895a19e18906ab4e731 (diff)
downloadnumpy-b9dd2be0108cb312b4c34239a1dc8d24ef3a05a9.tar.gz
Merge pull request #16759 from person142/dtype-generic
ENH: make dtype generic over scalar type
-rw-r--r--numpy/__init__.pyi318
-rw-r--r--numpy/typing/__init__.py2
-rw-r--r--numpy/typing/_dtype_like.py30
-rw-r--r--numpy/typing/tests/data/fail/dtype.py9
-rw-r--r--numpy/typing/tests/data/reveal/dtype.py33
5 files changed, 370 insertions, 22 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index e41c3cd78..64e4c75ce 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -15,6 +15,8 @@ from numpy.typing import (
_FloatLike,
_ComplexLike,
_NumberLike,
+ _SupportsDtype,
+ _VoidDtypeLike,
)
from numpy.typing._callable import (
_BoolOp,
@@ -508,16 +510,322 @@ where: Any
who: Any
_NdArraySubClass = TypeVar("_NdArraySubClass", bound=ndarray)
+_DTypeScalar = TypeVar("_DTypeScalar", bound=generic)
_ByteOrder = Literal["S", "<", ">", "=", "|", "L", "B", "N", "I"]
-class dtype:
+class dtype(Generic[_DTypeScalar]):
names: Optional[Tuple[str, ...]]
- def __init__(
- self,
- dtype: DtypeLike,
+ # Overload for subclass of generic
+ @overload
+ def __new__(
+ cls,
+ dtype: Type[_DTypeScalar],
align: bool = ...,
copy: bool = ...,
- ) -> None: ...
+ ) -> dtype[_DTypeScalar]: ...
+ # Overloads for string aliases, Python types, and some assorted
+ # other special cases. Order is sometimes important because of the
+ # subtype relationships
+ #
+ # bool < int < float < complex
+ #
+ # so we have to make sure the overloads for the narrowest type is
+ # first.
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ Type[bool],
+ Literal[
+ "?",
+ "=?",
+ "<?",
+ ">?",
+ "bool",
+ "bool_",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[bool_]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "uint8",
+ "u1",
+ "=u1",
+ "<u1",
+ ">u1",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[uint8]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "uint16",
+ "u2",
+ "=u2",
+ "<u2",
+ ">u2",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[uint16]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "uint32",
+ "u4",
+ "=u4",
+ "<u4",
+ ">u4",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[uint32]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "uint64",
+ "u8",
+ "=u8",
+ "<u8",
+ ">u8",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[uint64]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "int8",
+ "i1",
+ "=i1",
+ "<i1",
+ ">i1",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[int8]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "int16",
+ "i2",
+ "=i2",
+ "<i2",
+ ">i2",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[int16]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "int32",
+ "i4",
+ "=i4",
+ "<i4",
+ ">i4",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[int32]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "int64",
+ "i8",
+ "=i8",
+ "<i8",
+ ">i8",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[int64]: ...
+ # "int"/int resolve to int_, which is system dependent and as of
+ # now untyped. Long-term we'll do something fancier here.
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[Type[int], Literal["int"]],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "float16",
+ "f4",
+ "=f4",
+ "<f4",
+ ">f4",
+ "e",
+ "=e",
+ "<e",
+ ">e",
+ "half",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[float16]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "float32",
+ "f4",
+ "=f4",
+ "<f4",
+ ">f4",
+ "f",
+ "=f",
+ "<f",
+ ">f",
+ "single",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[float32]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ None,
+ Type[float],
+ Literal[
+ "float64",
+ "f8",
+ "=f8",
+ "<f8",
+ ">f8",
+ "d",
+ "<d",
+ ">d",
+ "float",
+ "double",
+ "float_",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[float64]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Literal[
+ "complex64",
+ "c8",
+ "=c8",
+ "<c8",
+ ">c8",
+ "F",
+ "=F",
+ "<F",
+ ">F",
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[complex64]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ Type[complex],
+ Literal[
+ "complex128",
+ "c16",
+ "=c16",
+ "<c16",
+ ">c16",
+ "D",
+ "=D",
+ "<D",
+ ">D",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[complex128]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ Type[bytes],
+ Literal[
+ "S",
+ "=S",
+ "<S",
+ ">S",
+ "bytes",
+ "bytes_",
+ "bytes0",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[bytes_]: ...
+ @overload
+ def __new__(
+ cls,
+ dtype: Union[
+ Type[str],
+ Literal[
+ "U",
+ "=U",
+ # <U and >U intentionally not included; they are not
+ # the same dtype and which one dtype("U") translates
+ # to is platform-dependent.
+ "str",
+ "str_",
+ "str0",
+ ],
+ ],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[str_]: ...
+ # dtype of a dtype is the same dtype
+ @overload
+ def __new__(
+ cls,
+ dtype: dtype[_DTypeScalar],
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[_DTypeScalar]: ...
+ # TODO: handle _SupportsDtype better
+ @overload
+ def __new__(
+ cls,
+ dtype: _SupportsDtype,
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[Any]: ...
+ # Handle strings that can't be expressed as literals; i.e. s1, s2, ...
+ @overload
+ def __new__(
+ cls,
+ dtype: str,
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[Any]: ...
+ # Catchall overload
+ @overload
+ def __new__(
+ cls,
+ dtype: _VoidDtypeLike,
+ align: bool = ...,
+ copy: bool = ...,
+ ) -> dtype[void]: ...
def __eq__(self, other: DtypeLike) -> bool: ...
def __ne__(self, other: DtypeLike) -> bool: ...
def __gt__(self, other: DtypeLike) -> bool: ...
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index 987aa39aa..dafabd95a 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -102,7 +102,7 @@ from ._scalars import (
)
from ._array_like import _SupportsArray, ArrayLike
from ._shape import _Shape, _ShapeLike
-from ._dtype_like import DtypeLike
+from ._dtype_like import _SupportsDtype, _VoidDtypeLike, DtypeLike
from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
diff --git a/numpy/typing/_dtype_like.py b/numpy/typing/_dtype_like.py
index 7c1946a3e..5bfd8ffdc 100644
--- a/numpy/typing/_dtype_like.py
+++ b/numpy/typing/_dtype_like.py
@@ -38,18 +38,9 @@ else:
_DtypeDict = Any
_SupportsDtype = Any
-# Anything that can be coerced into numpy.dtype.
-# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
-DtypeLike = Union[
- dtype,
- # default data type (float64)
- None,
- # array-scalar types and generic types
- type, # TODO: enumerate these when we add type hints for numpy scalars
- # anything with a dtype attribute
- _SupportsDtype,
- # character codes, type strings or comma-separated fields, e.g., 'float64'
- str,
+
+# Would create a dtype[np.void]
+_VoidDtypeLike = Union[
# (flexible_dtype, itemsize)
Tuple[_DtypeLikeNested, int],
# (fixed_dtype, shape)
@@ -67,6 +58,21 @@ DtypeLike = Union[
Tuple[_DtypeLikeNested, _DtypeLikeNested],
]
+# Anything that can be coerced into numpy.dtype.
+# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
+DtypeLike = Union[
+ dtype,
+ # default data type (float64)
+ None,
+ # array-scalar types and generic types
+ type, # TODO: enumerate these when we add type hints for numpy scalars
+ # anything with a dtype attribute
+ _SupportsDtype,
+ # character codes, type strings or comma-separated fields, e.g., 'float64'
+ str,
+ _VoidDtypeLike,
+]
+
# NOTE: while it is possible to provide the dtype as a dict of
# dtype-like objects (e.g. `{'field1': ..., 'field2': ..., ...}`),
# this syntax is officially discourged and
diff --git a/numpy/typing/tests/data/fail/dtype.py b/numpy/typing/tests/data/fail/dtype.py
index 3dc027daf..7d4783d8f 100644
--- a/numpy/typing/tests/data/fail/dtype.py
+++ b/numpy/typing/tests/data/fail/dtype.py
@@ -1,15 +1,16 @@
import numpy as np
-
class Test:
not_dtype = float
-np.dtype(Test()) # E: Argument 1 to "dtype" has incompatible type
+np.dtype(Test()) # E: No overload variant of "dtype" matches
-np.dtype(
- { # E: Argument 1 to "dtype" has incompatible type
+np.dtype( # E: No overload variant of "dtype" matches
+ {
"field1": (float, 1),
"field2": (int, 3),
}
)
+
+np.dtype[np.float64](np.int64) # E: Argument 1 to "dtype" has incompatible type
diff --git a/numpy/typing/tests/data/reveal/dtype.py b/numpy/typing/tests/data/reveal/dtype.py
new file mode 100644
index 000000000..e0802299e
--- /dev/null
+++ b/numpy/typing/tests/data/reveal/dtype.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+reveal_type(np.dtype(np.float64)) # E: numpy.dtype[numpy.float64*]
+reveal_type(np.dtype(np.int64)) # E: numpy.dtype[numpy.int64*]
+
+# String aliases
+reveal_type(np.dtype("float64")) # E: numpy.dtype[numpy.float64]
+reveal_type(np.dtype("float32")) # E: numpy.dtype[numpy.float32]
+reveal_type(np.dtype("int64")) # E: numpy.dtype[numpy.int64]
+reveal_type(np.dtype("int32")) # E: numpy.dtype[numpy.int32]
+reveal_type(np.dtype("bool")) # E: numpy.dtype[numpy.bool_]
+reveal_type(np.dtype("bytes")) # E: numpy.dtype[numpy.bytes_]
+reveal_type(np.dtype("str")) # E: numpy.dtype[numpy.str_]
+
+# Python types
+reveal_type(np.dtype(complex)) # E: numpy.dtype[numpy.complex128]
+reveal_type(np.dtype(float)) # E: numpy.dtype[numpy.float64]
+reveal_type(np.dtype(int)) # E: numpy.dtype
+reveal_type(np.dtype(bool)) # E: numpy.dtype[numpy.bool_]
+reveal_type(np.dtype(str)) # E: numpy.dtype[numpy.str_]
+reveal_type(np.dtype(bytes)) # E: numpy.dtype[numpy.bytes_]
+
+# Special case for None
+reveal_type(np.dtype(None)) # E: numpy.dtype[numpy.float64]
+
+# Dtypes of dtypes
+reveal_type(np.dtype(np.dtype(np.float64))) # E: numpy.dtype[numpy.float64*]
+
+# Parameterized dtypes
+reveal_type(np.dtype("S8")) # E: numpy.dtype
+
+# Void
+reveal_type(np.dtype(("U", 10))) # E: numpy.dtype[numpy.void]