summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2021-07-28 14:16:14 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-09-15 13:33:32 +0200
commit9ca8076efde53701610352966a39a54602190bb9 (patch)
tree44f3f6eaeac4e72caffbff5ff39de57f7b9ca512
parent07124b53b92ec55ee7d41b5efef48b2f9b0c37ae (diff)
downloadnumpy-9ca8076efde53701610352966a39a54602190bb9.tar.gz
ENH: Add `number.__class_getitem__`
-rw-r--r--numpy/__init__.pyi2
-rw-r--r--numpy/core/_add_newdocs.py31
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src65
-rw-r--r--numpy/core/tests/test_scalar_methods.py31
4 files changed, 129 insertions, 0 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index f2c2329c6..58d5519de 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -3058,6 +3058,8 @@ class number(generic, Generic[_NBit1]): # type: ignore
def real(self: _ArraySelf) -> _ArraySelf: ...
@property
def imag(self: _ArraySelf) -> _ArraySelf: ...
+ if sys.version_info >= (3, 9):
+ def __class_getitem__(self, item: Any) -> GenericAlias: ...
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index 2ac4d6cbc..6fc3b48e5 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -6500,6 +6500,37 @@ add_newdoc('numpy.core.numerictypes', 'generic',
add_newdoc('numpy.core.numerictypes', 'generic',
refer_to_array_attribute('view'))
+if sys.version_info >= (3, 9):
+ add_newdoc('numpy.core.numerictypes', 'number', ('__class_getitem__',
+ """
+ __class_getitem__(item, /)
+
+ Return a parametrized wrapper around the `~numpy.number` type.
+
+ .. versionadded:: 1.22
+
+ Returns
+ -------
+ alias : types.GenericAlias
+ A parametrized `~numpy.number` type.
+
+ Examples
+ --------
+ >>> from typing import Any
+ >>> import numpy as np
+
+ >>> np.signedinteger[Any]
+ numpy.signedinteger[typing.Any]
+
+ Note
+ ----
+ This method is only available for python 3.9 and later.
+
+ See Also
+ --------
+ :pep:`585` : Type hinting generics in standard collections.
+
+ """))
##############################################################################
#
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 4faa647ec..328581536 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1806,6 +1806,21 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args),
}
/*
+ * Use for concrete np.number subclasses, making them act as if they
+ * were subtyped from e.g. np.signedinteger[object], thus lacking any
+ * free subscription parameters. Requires python >= 3.9.
+ */
+#ifdef Py_GENERICALIASOBJECT_H
+static PyObject *
+numbertype_class_getitem(PyObject *cls, PyObject *args)
+{
+ return PyErr_Format(PyExc_TypeError,
+ "There are no type variables left in %s",
+ ((PyTypeObject *)cls)->tp_name);
+}
+#endif
+
+/*
* casting complex numbers (that don't inherit from Python complex)
* to Python complex
*/
@@ -2188,6 +2203,16 @@ static PyGetSetDef inttype_getsets[] = {
{NULL, NULL, NULL, NULL, NULL}
};
+static PyMethodDef numbertype_methods[] = {
+ /* for typing; requires python >= 3.9 */
+ #ifdef Py_GENERICALIASOBJECT_H
+ {"__class_getitem__",
+ (PyCFunction)Py_GenericAlias,
+ METH_CLASS | METH_O, NULL},
+ #endif
+ {NULL, NULL, 0, NULL} /* sentinel */
+};
+
/**begin repeat
* #name = cfloat,clongdouble#
*/
@@ -2195,6 +2220,12 @@ static PyMethodDef @name@type_methods[] = {
{"__complex__",
(PyCFunction)@name@_complex,
METH_VARARGS | METH_KEYWORDS, NULL},
+ /* for typing; requires python >= 3.9 */
+ #ifdef Py_GENERICALIASOBJECT_H
+ {"__class_getitem__",
+ (PyCFunction)numbertype_class_getitem,
+ METH_CLASS | METH_O, NULL},
+ #endif
{NULL, NULL, 0, NULL}
};
/**end repeat**/
@@ -2232,6 +2263,27 @@ static PyMethodDef @name@type_methods[] = {
{"is_integer",
(PyCFunction)@name@_is_integer,
METH_NOARGS, NULL},
+ /* for typing; requires python >= 3.9 */
+ #ifdef Py_GENERICALIASOBJECT_H
+ {"__class_getitem__",
+ (PyCFunction)numbertype_class_getitem,
+ METH_CLASS | METH_O, NULL},
+ #endif
+ {NULL, NULL, 0, NULL}
+};
+/**end repeat**/
+
+/**begin repeat
+ * #name = byte, short, int, long, longlong, ubyte, ushort,
+ * uint, ulong, ulonglong, timedelta, cdouble#
+ */
+static PyMethodDef @name@type_methods[] = {
+ /* for typing; requires python >= 3.9 */
+ #ifdef Py_GENERICALIASOBJECT_H
+ {"__class_getitem__",
+ (PyCFunction)numbertype_class_getitem,
+ METH_CLASS | METH_O, NULL},
+ #endif
{NULL, NULL, 0, NULL}
};
/**end repeat**/
@@ -3951,6 +4003,8 @@ initialize_numeric_types(void)
PyIntegerArrType_Type.tp_getset = inttype_getsets;
+ PyNumberArrType_Type.tp_methods = numbertype_methods;
+
/**begin repeat
* #NAME= Number, Integer, SignedInteger, UnsignedInteger, Inexact,
* Floating, ComplexFloating, Flexible, Character#
@@ -4016,6 +4070,17 @@ initialize_numeric_types(void)
/**end repeat**/
+ /**begin repeat
+ * #name = byte, short, int, long, longlong, ubyte, ushort,
+ * uint, ulong, ulonglong, timedelta, cdouble#
+ * #Name = Byte, Short, Int, Long, LongLong, UByte, UShort,
+ * UInt, ULong, ULongLong, Timedelta, CDouble#
+ */
+
+ Py@Name@ArrType_Type.tp_methods = @name@type_methods;
+
+ /**end repeat**/
+
/* We won't be inheriting from Python Int type. */
PyIntArrType_Type.tp_hash = int_arrtype_hash;
diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py
index 94b2dd3c9..0cdfe99b1 100644
--- a/numpy/core/tests/test_scalar_methods.py
+++ b/numpy/core/tests/test_scalar_methods.py
@@ -1,8 +1,11 @@
"""
Test the scalar constructors, which also do type-coercion
"""
+import sys
import fractions
import platform
+import types
+from typing import Any, Type
import pytest
import numpy as np
@@ -128,3 +131,31 @@ class TestIsInteger:
if value == 0:
continue
assert not value.is_integer()
+
+
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
+class TestClassGetItem:
+ @pytest.mark.parametrize("cls", [
+ np.number,
+ np.integer,
+ np.inexact,
+ np.unsignedinteger,
+ np.signedinteger,
+ np.floating,
+ np.complexfloating,
+ ])
+ def test_abc(self, cls: Type[np.number]) -> None:
+ alias = cls[Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is cls
+
+ @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
+ def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
+ with pytest.raises(TypeError):
+ cls[Any]
+
+ @pytest.mark.parametrize("code", np.typecodes["All"])
+ def test_concrete(self, code: str) -> None:
+ cls = np.dtype(code).type
+ with pytest.raises(TypeError):
+ cls[Any]