diff options
author | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2021-07-28 14:16:14 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2021-09-15 13:33:32 +0200 |
commit | 9ca8076efde53701610352966a39a54602190bb9 (patch) | |
tree | 44f3f6eaeac4e72caffbff5ff39de57f7b9ca512 | |
parent | 07124b53b92ec55ee7d41b5efef48b2f9b0c37ae (diff) | |
download | numpy-9ca8076efde53701610352966a39a54602190bb9.tar.gz |
ENH: Add `number.__class_getitem__`
-rw-r--r-- | numpy/__init__.pyi | 2 | ||||
-rw-r--r-- | numpy/core/_add_newdocs.py | 31 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 65 | ||||
-rw-r--r-- | numpy/core/tests/test_scalar_methods.py | 31 |
4 files changed, 129 insertions, 0 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index f2c2329c6..58d5519de 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -3058,6 +3058,8 @@ class number(generic, Generic[_NBit1]): # type: ignore def real(self: _ArraySelf) -> _ArraySelf: ... @property def imag(self: _ArraySelf) -> _ArraySelf: ... + if sys.version_info >= (3, 9): + def __class_getitem__(self, item: Any) -> GenericAlias: ... def __int__(self) -> int: ... def __float__(self) -> float: ... def __complex__(self) -> complex: ... diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py index 2ac4d6cbc..6fc3b48e5 100644 --- a/numpy/core/_add_newdocs.py +++ b/numpy/core/_add_newdocs.py @@ -6500,6 +6500,37 @@ add_newdoc('numpy.core.numerictypes', 'generic', add_newdoc('numpy.core.numerictypes', 'generic', refer_to_array_attribute('view')) +if sys.version_info >= (3, 9): + add_newdoc('numpy.core.numerictypes', 'number', ('__class_getitem__', + """ + __class_getitem__(item, /) + + Return a parametrized wrapper around the `~numpy.number` type. + + .. versionadded:: 1.22 + + Returns + ------- + alias : types.GenericAlias + A parametrized `~numpy.number` type. + + Examples + -------- + >>> from typing import Any + >>> import numpy as np + + >>> np.signedinteger[Any] + numpy.signedinteger[typing.Any] + + Note + ---- + This method is only available for python 3.9 and later. + + See Also + -------- + :pep:`585` : Type hinting generics in standard collections. + + """)) ############################################################################## # diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 4faa647ec..328581536 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1806,6 +1806,21 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), } /* + * Use for concrete np.number subclasses, making them act as if they + * were subtyped from e.g. np.signedinteger[object], thus lacking any + * free subscription parameters. Requires python >= 3.9. + */ +#ifdef Py_GENERICALIASOBJECT_H +static PyObject * +numbertype_class_getitem(PyObject *cls, PyObject *args) +{ + return PyErr_Format(PyExc_TypeError, + "There are no type variables left in %s", + ((PyTypeObject *)cls)->tp_name); +} +#endif + +/* * casting complex numbers (that don't inherit from Python complex) * to Python complex */ @@ -2188,6 +2203,16 @@ static PyGetSetDef inttype_getsets[] = { {NULL, NULL, NULL, NULL, NULL} }; +static PyMethodDef numbertype_methods[] = { + /* for typing; requires python >= 3.9 */ + #ifdef Py_GENERICALIASOBJECT_H + {"__class_getitem__", + (PyCFunction)Py_GenericAlias, + METH_CLASS | METH_O, NULL}, + #endif + {NULL, NULL, 0, NULL} /* sentinel */ +}; + /**begin repeat * #name = cfloat,clongdouble# */ @@ -2195,6 +2220,12 @@ static PyMethodDef @name@type_methods[] = { {"__complex__", (PyCFunction)@name@_complex, METH_VARARGS | METH_KEYWORDS, NULL}, + /* for typing; requires python >= 3.9 */ + #ifdef Py_GENERICALIASOBJECT_H + {"__class_getitem__", + (PyCFunction)numbertype_class_getitem, + METH_CLASS | METH_O, NULL}, + #endif {NULL, NULL, 0, NULL} }; /**end repeat**/ @@ -2232,6 +2263,27 @@ static PyMethodDef @name@type_methods[] = { {"is_integer", (PyCFunction)@name@_is_integer, METH_NOARGS, NULL}, + /* for typing; requires python >= 3.9 */ + #ifdef Py_GENERICALIASOBJECT_H + {"__class_getitem__", + (PyCFunction)numbertype_class_getitem, + METH_CLASS | METH_O, NULL}, + #endif + {NULL, NULL, 0, NULL} +}; +/**end repeat**/ + +/**begin repeat + * #name = byte, short, int, long, longlong, ubyte, ushort, + * uint, ulong, ulonglong, timedelta, cdouble# + */ +static PyMethodDef @name@type_methods[] = { + /* for typing; requires python >= 3.9 */ + #ifdef Py_GENERICALIASOBJECT_H + {"__class_getitem__", + (PyCFunction)numbertype_class_getitem, + METH_CLASS | METH_O, NULL}, + #endif {NULL, NULL, 0, NULL} }; /**end repeat**/ @@ -3951,6 +4003,8 @@ initialize_numeric_types(void) PyIntegerArrType_Type.tp_getset = inttype_getsets; + PyNumberArrType_Type.tp_methods = numbertype_methods; + /**begin repeat * #NAME= Number, Integer, SignedInteger, UnsignedInteger, Inexact, * Floating, ComplexFloating, Flexible, Character# @@ -4016,6 +4070,17 @@ initialize_numeric_types(void) /**end repeat**/ + /**begin repeat + * #name = byte, short, int, long, longlong, ubyte, ushort, + * uint, ulong, ulonglong, timedelta, cdouble# + * #Name = Byte, Short, Int, Long, LongLong, UByte, UShort, + * UInt, ULong, ULongLong, Timedelta, CDouble# + */ + + Py@Name@ArrType_Type.tp_methods = @name@type_methods; + + /**end repeat**/ + /* We won't be inheriting from Python Int type. */ PyIntArrType_Type.tp_hash = int_arrtype_hash; diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py index 94b2dd3c9..0cdfe99b1 100644 --- a/numpy/core/tests/test_scalar_methods.py +++ b/numpy/core/tests/test_scalar_methods.py @@ -1,8 +1,11 @@ """ Test the scalar constructors, which also do type-coercion """ +import sys import fractions import platform +import types +from typing import Any, Type import pytest import numpy as np @@ -128,3 +131,31 @@ class TestIsInteger: if value == 0: continue assert not value.is_integer() + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") +class TestClassGetItem: + @pytest.mark.parametrize("cls", [ + np.number, + np.integer, + np.inexact, + np.unsignedinteger, + np.signedinteger, + np.floating, + np.complexfloating, + ]) + def test_abc(self, cls: Type[np.number]) -> None: + alias = cls[Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is cls + + @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character]) + def test_abc_non_numeric(self, cls: Type[np.generic]) -> None: + with pytest.raises(TypeError): + cls[Any] + + @pytest.mark.parametrize("code", np.typecodes["All"]) + def test_concrete(self, code: str) -> None: + cls = np.dtype(code).type + with pytest.raises(TypeError): + cls[Any] |