summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-07-14 18:37:18 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-09-15 12:17:29 +0200
commit07124b53b92ec55ee7d41b5efef48b2f9b0c37ae (patch)
tree8b6c860deda79bf79d0ed6115cf8a220ebef7985
parentdc7dafe70a53d6c122091516f34058bd0a6d89e1 (diff)
downloadnumpy-07124b53b92ec55ee7d41b5efef48b2f9b0c37ae.tar.gz
ENH: Add `ndarray.__class_getitem__`
-rw-r--r--numpy/__init__.pyi7
-rw-r--r--numpy/core/_add_newdocs.py41
-rw-r--r--numpy/core/src/multiarray/methods.c7
-rw-r--r--numpy/core/tests/test_arraymethod.py15
4 files changed, 67 insertions, 3 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index b34cdecc7..f2c2329c6 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -9,6 +9,9 @@ from abc import abstractmethod
from types import TracebackType, MappingProxyType
from contextlib import ContextDecorator
+if sys.version_info >= (3, 9):
+ from types import GenericAlias
+
from numpy._pytesttester import PytestTester
from numpy.core.multiarray import flagsobj
from numpy.core._internal import _ctypes
@@ -1697,6 +1700,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
strides: _ShapeLike = ...,
order: _OrderKACF = ...,
) -> _ArraySelf: ...
+
+ if sys.version_info >= (3, 9):
+ def __class_getitem__(self, item: Any) -> GenericAlias: ...
+
@overload
def __array__(self, dtype: None = ..., /) -> ndarray[Any, _DType_co]: ...
@overload
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index 06f2a6376..2ac4d6cbc 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -9,6 +9,7 @@ NOTE: Many of the methods of ndarray have corresponding functions.
"""
+import sys
from numpy.core.function_base import add_newdoc
from numpy.core.overrides import array_function_like_doc
@@ -796,7 +797,7 @@ add_newdoc('numpy.core.multiarray', 'array',
object : array_like
An array, any object exposing the array interface, an object whose
__array__ method returns an array, or any (nested) sequence.
- If object is a scalar, a 0-dimensional array containing object is
+ If object is a scalar, a 0-dimensional array containing object is
returned.
dtype : data-type, optional
The desired data-type for the array. If not given, then the type will
@@ -2201,8 +2202,8 @@ add_newdoc('numpy.core.multiarray', 'ndarray',
empty : Create an array, but leave its allocated memory unchanged (i.e.,
it contains "garbage").
dtype : Create a data-type.
- numpy.typing.NDArray : A :term:`generic <generic type>` version
- of ndarray.
+ numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
+ w.r.t. its `dtype.type <numpy.dtype.type>`.
Notes
-----
@@ -2798,6 +2799,40 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__copy__',
"""))
+if sys.version_info > (3, 9):
+ add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__',
+ """a.__class_getitem__(item, /)
+
+ Return a parametrized wrapper around the `~numpy.ndarray` type.
+
+ .. versionadded:: 1.22
+
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.ndarray` type.
+
+ Examples
+ --------
+ >>> from typing import Any
+ >>> import numpy as np
+
+ >>> np.ndarray[Any, np.dtype]
+ numpy.ndarray[typing.Any, numpy.dtype]
+
+ Note
+ ----
+ This method is only available for python 3.9 and later.
+
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
+ numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
+ w.r.t. its `dtype.type <numpy.dtype.type>`.
+
+ """))
+
+
add_newdoc('numpy.core.multiarray', 'ndarray', ('__deepcopy__',
"""a.__deepcopy__(memo, /) -> Deep copy of array.
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 2c10817fa..43167cbbf 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -2756,6 +2756,13 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
(PyCFunction) array_format,
METH_VARARGS, NULL},
+ /* for typing; requires python >= 3.9 */
+ #ifdef Py_GENERICALIASOBJECT_H
+ {"__class_getitem__",
+ (PyCFunction)Py_GenericAlias,
+ METH_CLASS | METH_O, NULL},
+ #endif
+
/* Original and Extended methods added 2005 */
{"all",
(PyCFunction)array_all,
diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py
index b1bc79b80..9bd4c54df 100644
--- a/numpy/core/tests/test_arraymethod.py
+++ b/numpy/core/tests/test_arraymethod.py
@@ -3,6 +3,10 @@ This file tests the generic aspects of ArrayMethod. At the time of writing
this is private API, but when added, public API may be added here.
"""
+import sys
+import types
+from typing import Any, Type
+
import pytest
import numpy as np
@@ -56,3 +60,14 @@ class TestSimpleStridedCall:
# This is private API, which may be modified freely
with pytest.raises(error):
self.method._simple_strided_call(*args)
+
+
+@pytest.mark.parametrize(
+ "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
+)
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
+def test_class_getitem(cls: Type[np.ndarray]) -> None:
+ """Test `ndarray.__class_getitem__`."""
+ alias = cls[Any, Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is cls