diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2021-07-14 18:37:18 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-09-15 12:17:29 +0200 |
commit | 07124b53b92ec55ee7d41b5efef48b2f9b0c37ae (patch) | |
tree | 8b6c860deda79bf79d0ed6115cf8a220ebef7985 | |
parent | dc7dafe70a53d6c122091516f34058bd0a6d89e1 (diff) | |
download | numpy-07124b53b92ec55ee7d41b5efef48b2f9b0c37ae.tar.gz |
ENH: Add `ndarray.__class_getitem__`
-rw-r--r-- | numpy/__init__.pyi | 7 | ||||
-rw-r--r-- | numpy/core/_add_newdocs.py | 41 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_arraymethod.py | 15 |
4 files changed, 67 insertions, 3 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index b34cdecc7..f2c2329c6 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -9,6 +9,9 @@ from abc import abstractmethod from types import TracebackType, MappingProxyType from contextlib import ContextDecorator +if sys.version_info >= (3, 9): + from types import GenericAlias + from numpy._pytesttester import PytestTester from numpy.core.multiarray import flagsobj from numpy.core._internal import _ctypes @@ -1697,6 +1700,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): strides: _ShapeLike = ..., order: _OrderKACF = ..., ) -> _ArraySelf: ... + + if sys.version_info >= (3, 9): + def __class_getitem__(self, item: Any) -> GenericAlias: ... + @overload def __array__(self, dtype: None = ..., /) -> ndarray[Any, _DType_co]: ... @overload diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py index 06f2a6376..2ac4d6cbc 100644 --- a/numpy/core/_add_newdocs.py +++ b/numpy/core/_add_newdocs.py @@ -9,6 +9,7 @@ NOTE: Many of the methods of ndarray have corresponding functions. """ +import sys from numpy.core.function_base import add_newdoc from numpy.core.overrides import array_function_like_doc @@ -796,7 +797,7 @@ add_newdoc('numpy.core.multiarray', 'array', object : array_like An array, any object exposing the array interface, an object whose __array__ method returns an array, or any (nested) sequence. - If object is a scalar, a 0-dimensional array containing object is + If object is a scalar, a 0-dimensional array containing object is returned. dtype : data-type, optional The desired data-type for the array. If not given, then the type will @@ -2201,8 +2202,8 @@ add_newdoc('numpy.core.multiarray', 'ndarray', empty : Create an array, but leave its allocated memory unchanged (i.e., it contains "garbage"). dtype : Create a data-type. - numpy.typing.NDArray : A :term:`generic <generic type>` version - of ndarray. + numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>` + w.r.t. its `dtype.type <numpy.dtype.type>`. Notes ----- @@ -2798,6 +2799,40 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__copy__', """)) +if sys.version_info > (3, 9): + add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__', + """a.__class_getitem__(item, /) + + Return a parametrized wrapper around the `~numpy.ndarray` type. + + .. versionadded:: 1.22 + + Returns + ------- + alias : types.GenericAlias + A parametrized `~numpy.ndarray` type. + + Examples + -------- + >>> from typing import Any + >>> import numpy as np + + >>> np.ndarray[Any, np.dtype] + numpy.ndarray[typing.Any, numpy.dtype] + + Note + ---- + This method is only available for python 3.9 and later. + + See Also + -------- + :pep:`585` : Type hinting generics in standard collections. + numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>` + w.r.t. its `dtype.type <numpy.dtype.type>`. + + """)) + + add_newdoc('numpy.core.multiarray', 'ndarray', ('__deepcopy__', """a.__deepcopy__(memo, /) -> Deep copy of array. diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 2c10817fa..43167cbbf 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2756,6 +2756,13 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { (PyCFunction) array_format, METH_VARARGS, NULL}, + /* for typing; requires python >= 3.9 */ + #ifdef Py_GENERICALIASOBJECT_H + {"__class_getitem__", + (PyCFunction)Py_GenericAlias, + METH_CLASS | METH_O, NULL}, + #endif + /* Original and Extended methods added 2005 */ {"all", (PyCFunction)array_all, diff --git a/numpy/core/tests/test_arraymethod.py b/numpy/core/tests/test_arraymethod.py index b1bc79b80..9bd4c54df 100644 --- a/numpy/core/tests/test_arraymethod.py +++ b/numpy/core/tests/test_arraymethod.py @@ -3,6 +3,10 @@ This file tests the generic aspects of ArrayMethod. At the time of writing this is private API, but when added, public API may be added here. """ +import sys +import types +from typing import Any, Type + import pytest import numpy as np @@ -56,3 +60,14 @@ class TestSimpleStridedCall: # This is private API, which may be modified freely with pytest.raises(error): self.method._simple_strided_call(*args) + + +@pytest.mark.parametrize( + "cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap] +) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") +def test_class_getitem(cls: Type[np.ndarray]) -> None: + """Test `ndarray.__class_getitem__`.""" + alias = cls[Any, Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is cls |