From 8c89fef9e677afd3ee7777f242b6a53d3b7dfef4 Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Sat, 18 Sep 2021 20:54:28 +0200 Subject: ENH: Add special-casing for `complexfloating` so that it can take 2 parameters --- numpy/core/src/multiarray/scalartypes.c.src | 14 ++++++++++++-- numpy/core/tests/test_scalar_methods.py | 8 ++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index cacf4485e..93cc9666e 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1812,12 +1812,22 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args) #ifdef Py_GENERICALIASOBJECT_H Py_ssize_t args_len; + int args_len_expected; + + /* complexfloating should take 2 parameters, all others take 1 */ + if (PyType_IsSubtype((PyTypeObject *)cls, + &PyComplexFloatingArrType_Type)) { + args_len_expected = 2; + } + else { + args_len_expected = 1; + } args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; - if (args_len != 1) { + if (args_len != args_len_expected) { return PyErr_Format(PyExc_TypeError, "Too %s arguments for %s", - args_len > 1 ? "many" : "few", + args_len > args_len_expected ? "many" : "few", ((PyTypeObject *)cls)->tp_name); } generic_alias = Py_GenericAlias(cls, args); diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py index 01fc66f84..6077c8f75 100644 --- a/numpy/core/tests/test_scalar_methods.py +++ b/numpy/core/tests/test_scalar_methods.py @@ -142,13 +142,17 @@ class TestClassGetItem: np.unsignedinteger, np.signedinteger, np.floating, - np.complexfloating, ]) def test_abc(self, cls: Type[np.number]) -> None: alias = cls[Any] assert isinstance(alias, types.GenericAlias) assert alias.__origin__ is cls + def test_abc_complexfloating(self) -> None: + alias = np.complexfloating[Any, Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is np.complexfloating + @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character]) def test_abc_non_numeric(self, cls: Type[np.generic]) -> None: with pytest.raises(TypeError): @@ -174,7 +178,7 @@ class TestClassGetItem: @pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") -@pytest.mark.parametrize("cls", [np.number, np.int64]) +@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64]) def test_class_getitem_38(cls: Type[np.number]) -> None: match = "Type subscription requires python >= 3.9" with pytest.raises(TypeError, match=match): -- cgit v1.2.1