diff options
author | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2021-09-18 20:54:28 +0200 |
---|---|---|
committer | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2021-09-18 20:54:52 +0200 |
commit | 8c89fef9e677afd3ee7777f242b6a53d3b7dfef4 (patch) | |
tree | faacdca22d3ebfbe42c7c3df9e1d3d795c48534c | |
parent | aecdb9fe513bf10df704466cf138a280354e3166 (diff) | |
download | numpy-8c89fef9e677afd3ee7777f242b6a53d3b7dfef4.tar.gz |
ENH: Add special-casing for `complexfloating` so that it can take 2 parameters
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 14 | ||||
-rw-r--r-- | numpy/core/tests/test_scalar_methods.py | 8 |
2 files changed, 18 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index cacf4485e..93cc9666e 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1812,12 +1812,22 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args) #ifdef Py_GENERICALIASOBJECT_H Py_ssize_t args_len; + int args_len_expected; + + /* complexfloating should take 2 parameters, all others take 1 */ + if (PyType_IsSubtype((PyTypeObject *)cls, + &PyComplexFloatingArrType_Type)) { + args_len_expected = 2; + } + else { + args_len_expected = 1; + } args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1; - if (args_len != 1) { + if (args_len != args_len_expected) { return PyErr_Format(PyExc_TypeError, "Too %s arguments for %s", - args_len > 1 ? "many" : "few", + args_len > args_len_expected ? "many" : "few", ((PyTypeObject *)cls)->tp_name); } generic_alias = Py_GenericAlias(cls, args); diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py index 01fc66f84..6077c8f75 100644 --- a/numpy/core/tests/test_scalar_methods.py +++ b/numpy/core/tests/test_scalar_methods.py @@ -142,13 +142,17 @@ class TestClassGetItem: np.unsignedinteger, np.signedinteger, np.floating, - np.complexfloating, ]) def test_abc(self, cls: Type[np.number]) -> None: alias = cls[Any] assert isinstance(alias, types.GenericAlias) assert alias.__origin__ is cls + def test_abc_complexfloating(self) -> None: + alias = np.complexfloating[Any, Any] + assert isinstance(alias, types.GenericAlias) + assert alias.__origin__ is np.complexfloating + @pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character]) def test_abc_non_numeric(self, cls: Type[np.generic]) -> None: with pytest.raises(TypeError): @@ -174,7 +178,7 @@ class TestClassGetItem: @pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") -@pytest.mark.parametrize("cls", [np.number, np.int64]) +@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64]) def test_class_getitem_38(cls: Type[np.number]) -> None: match = "Type subscription requires python >= 3.9" with pytest.raises(TypeError, match=match): |