summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-05-20 13:10:51 -0600
committerGitHub <noreply@github.com>2021-05-20 13:10:51 -0600
commitd707f39d0e748f4eebc0058103a3749111328bb6 (patch)
treecda37071889b0bccde4ea13fdc8ab880761223f7
parent5c19c2350a9d839db835540e52697ea79c52ce47 (diff)
parentb4d1d3525040f2db7ad6d0be4fde73d4dc9a8590 (diff)
downloadnumpy-d707f39d0e748f4eebc0058103a3749111328bb6.tar.gz
Merge pull request #19029 from BvB93/internal
ENH: Improve the annotations of `np.core._internal`
-rw-r--r--numpy/__init__.pyi2
-rw-r--r--numpy/core/_internal.pyi43
-rw-r--r--numpy/typing/tests/data/fail/ndarray_misc.py7
-rw-r--r--numpy/typing/tests/data/reveal/ndarray_misc.py12
4 files changed, 50 insertions, 14 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 3415172ed..ac37eb8ad 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1676,7 +1676,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __array__(self, __dtype: _DType) -> ndarray[Any, _DType]: ...
@property
- def ctypes(self) -> _ctypes: ...
+ def ctypes(self) -> _ctypes[int]: ...
@property
def shape(self) -> _Shape: ...
@shape.setter
diff --git a/numpy/core/_internal.pyi b/numpy/core/_internal.pyi
index 1b3889e51..1ef1c9fa1 100644
--- a/numpy/core/_internal.pyi
+++ b/numpy/core/_internal.pyi
@@ -1,18 +1,35 @@
-from typing import Any
+from typing import Any, TypeVar, Type, overload, Optional, Generic
+import ctypes as ct
-# TODO: add better annotations when ctypes is stubbed out
+from numpy import ndarray
-class _ctypes:
+_CastT = TypeVar("_CastT", bound=ct._CanCastTo) # Copied from `ctypes.cast`
+_CT = TypeVar("_CT", bound=ct._CData)
+_PT = TypeVar("_PT", bound=Optional[int])
+
+# TODO: Let the likes of `shape_as` and `strides_as` return `None`
+# for 0D arrays once we've got shape-support
+
+class _ctypes(Generic[_PT]):
+ @overload
+ def __new__(cls, array: ndarray[Any, Any], ptr: None = ...) -> _ctypes[None]: ...
+ @overload
+ def __new__(cls, array: ndarray[Any, Any], ptr: _PT) -> _ctypes[_PT]: ...
+
+ # NOTE: In practice `shape` and `strides` return one of the concrete
+ # platform dependant array-types (`c_int`, `c_long` or `c_longlong`)
+ # corresponding to C's `int_ptr_t`, as determined by `_getintp_ctype`
+ # TODO: Hook this in to the mypy plugin so that a more appropiate
+ # `ctypes._SimpleCData[int]` sub-type can be returned
@property
- def data(self) -> int: ...
+ def data(self) -> _PT: ...
@property
- def shape(self) -> Any: ...
+ def shape(self) -> ct.Array[ct.c_int64]: ...
@property
- def strides(self) -> Any: ...
- def data_as(self, obj: Any) -> Any: ...
- def shape_as(self, obj: Any) -> Any: ...
- def strides_as(self, obj: Any) -> Any: ...
- def get_data(self) -> int: ...
- def get_shape(self) -> Any: ...
- def get_strides(self) -> Any: ...
- def get_as_parameter(self) -> Any: ...
+ def strides(self) -> ct.Array[ct.c_int64]: ...
+ @property
+ def _as_parameter_(self) -> ct.c_void_p: ...
+
+ def data_as(self, obj: Type[_CastT]) -> _CastT: ...
+ def shape_as(self, obj: Type[_CT]) -> ct.Array[_CT]: ...
+ def strides_as(self, obj: Type[_CT]) -> ct.Array[_CT]: ...
diff --git a/numpy/typing/tests/data/fail/ndarray_misc.py b/numpy/typing/tests/data/fail/ndarray_misc.py
index 653b9267b..cf3fedc45 100644
--- a/numpy/typing/tests/data/fail/ndarray_misc.py
+++ b/numpy/typing/tests/data/fail/ndarray_misc.py
@@ -14,6 +14,13 @@ AR_f8: np.ndarray[Any, np.dtype[np.float64]]
AR_M: np.ndarray[Any, np.dtype[np.datetime64]]
AR_b: np.ndarray[Any, np.dtype[np.bool_]]
+ctypes_obj = AR_f8.ctypes
+
+reveal_type(ctypes_obj.get_data()) # E: has no attribute
+reveal_type(ctypes_obj.get_shape()) # E: has no attribute
+reveal_type(ctypes_obj.get_strides()) # E: has no attribute
+reveal_type(ctypes_obj.get_as_parameter()) # E: has no attribute
+
f8.argpartition(0) # E: has no attribute
f8.diagonal() # E: has no attribute
f8.dot(1) # E: has no attribute
diff --git a/numpy/typing/tests/data/reveal/ndarray_misc.py b/numpy/typing/tests/data/reveal/ndarray_misc.py
index ecc322251..ea01b7aa4 100644
--- a/numpy/typing/tests/data/reveal/ndarray_misc.py
+++ b/numpy/typing/tests/data/reveal/ndarray_misc.py
@@ -7,6 +7,7 @@ function-based counterpart in `../from_numeric.py`.
"""
import operator
+import ctypes as ct
from typing import Any
import numpy as np
@@ -19,6 +20,17 @@ AR_f8: np.ndarray[Any, np.dtype[np.float64]]
AR_i8: np.ndarray[Any, np.dtype[np.int64]]
AR_U: np.ndarray[Any, np.dtype[np.str_]]
+ctypes_obj = AR_f8.ctypes
+
+reveal_type(ctypes_obj.data) # E: int
+reveal_type(ctypes_obj.shape) # E: ctypes.Array[ctypes.c_int64]
+reveal_type(ctypes_obj.strides) # E: ctypes.Array[ctypes.c_int64]
+reveal_type(ctypes_obj._as_parameter_) # E: ctypes.c_void_p
+
+reveal_type(ctypes_obj.data_as(ct.c_void_p)) # E: ctypes.c_void_p
+reveal_type(ctypes_obj.shape_as(ct.c_longlong)) # E: ctypes.Array[ctypes.c_longlong]
+reveal_type(ctypes_obj.strides_as(ct.c_ubyte)) # E: ctypes.Array[ctypes.c_ubyte]
+
reveal_type(f8.all()) # E: numpy.bool_
reveal_type(AR_f8.all()) # E: numpy.bool_
reveal_type(AR_f8.all(axis=0)) # E: Any