summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2021-09-18 20:54:28 +0200
committerBas van Beek <43369155+BvB93@users.noreply.github.com>2021-09-18 20:54:52 +0200
commit8c89fef9e677afd3ee7777f242b6a53d3b7dfef4 (patch)
treefaacdca22d3ebfbe42c7c3df9e1d3d795c48534c
parentaecdb9fe513bf10df704466cf138a280354e3166 (diff)
downloadnumpy-8c89fef9e677afd3ee7777f242b6a53d3b7dfef4.tar.gz
ENH: Add special-casing for `complexfloating` so that it can take 2 parameters
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src14
-rw-r--r--numpy/core/tests/test_scalar_methods.py8
2 files changed, 18 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index cacf4485e..93cc9666e 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1812,12 +1812,22 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args)
#ifdef Py_GENERICALIASOBJECT_H
Py_ssize_t args_len;
+ int args_len_expected;
+
+ /* complexfloating should take 2 parameters, all others take 1 */
+ if (PyType_IsSubtype((PyTypeObject *)cls,
+ &PyComplexFloatingArrType_Type)) {
+ args_len_expected = 2;
+ }
+ else {
+ args_len_expected = 1;
+ }
args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
- if (args_len != 1) {
+ if (args_len != args_len_expected) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
- args_len > 1 ? "many" : "few",
+ args_len > args_len_expected ? "many" : "few",
((PyTypeObject *)cls)->tp_name);
}
generic_alias = Py_GenericAlias(cls, args);
diff --git a/numpy/core/tests/test_scalar_methods.py b/numpy/core/tests/test_scalar_methods.py
index 01fc66f84..6077c8f75 100644
--- a/numpy/core/tests/test_scalar_methods.py
+++ b/numpy/core/tests/test_scalar_methods.py
@@ -142,13 +142,17 @@ class TestClassGetItem:
np.unsignedinteger,
np.signedinteger,
np.floating,
- np.complexfloating,
])
def test_abc(self, cls: Type[np.number]) -> None:
alias = cls[Any]
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is cls
+ def test_abc_complexfloating(self) -> None:
+ alias = np.complexfloating[Any, Any]
+ assert isinstance(alias, types.GenericAlias)
+ assert alias.__origin__ is np.complexfloating
+
@pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
with pytest.raises(TypeError):
@@ -174,7 +178,7 @@ class TestClassGetItem:
@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
-@pytest.mark.parametrize("cls", [np.number, np.int64])
+@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64])
def test_class_getitem_38(cls: Type[np.number]) -> None:
match = "Type subscription requires python >= 3.9"
with pytest.raises(TypeError, match=match):