summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2020-01-24 01:37:23 +1100
committerGitHub <noreply@github.com>2020-01-24 01:37:23 +1100
commit04ac2a13b302a7af6fe2a5ca67e09a0e09a0f8e7 (patch)
tree41e24948f37f976d5c4c0e0553e749b2e214ae9e /numpy/core
parent1279616f48cffd55de0b1dd53f86a23e170701ed (diff)
parentc0ca8143aa98a2623a28a748648da9c3667695c0 (diff)
downloadnumpy-04ac2a13b302a7af6fe2a5ca67e09a0e09a0f8e7.tar.gz
Merge pull request #15393 from eric-wieser/remove-base-handling
MAINT/BUG: Fixups to scalar base classes
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src40
-rw-r--r--numpy/core/tests/test_scalarinherit.py19
2 files changed, 35 insertions, 24 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index cd26d20fa..383cef5bd 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -2562,25 +2562,20 @@ object_arrtype_dealloc(PyObject *v)
* memory from the sub-types memory allocator.
*/
-#define _WORK(num) \
- if (type->tp_bases && (PyTuple_GET_SIZE(type->tp_bases)==2)) { \
- PyTypeObject *sup; \
- /* We are inheriting from a Python type as well so \
- give it first dibs on conversion */ \
- sup = (PyTypeObject *)PyTuple_GET_ITEM(type->tp_bases, num); \
- /* Prevent recursion */ \
- if (thisfunc != sup->tp_new) { \
- robj = sup->tp_new(type, args, kwds); \
- if (robj != NULL) goto finish; \
- if (PyTuple_GET_SIZE(args)!=1) return NULL; \
- PyErr_Clear(); \
- } \
- /* now do default conversion */ \
- }
-
-#define _WORK1 _WORK(1)
-#define _WORKz _WORK(0)
-#define _WORK0
+#define _WORK(cls, num) \
+ assert(cls.tp_bases && (PyTuple_GET_SIZE(cls.tp_bases) == 2)); \
+ /* We are inheriting from a Python type as well so \
+ give it first dibs on conversion */ \
+ PyTypeObject *sup = (PyTypeObject *)PyTuple_GET_ITEM(cls.tp_bases, num); \
+ robj = sup->tp_new(type, args, kwds); \
+ if (robj != NULL) goto finish; \
+ if (PyTuple_GET_SIZE(args) != 1) return NULL; \
+ PyErr_Clear(); \
+ /* now do default conversion */
+
+#define _WORK1(cls) _WORK(cls, 1)
+#define _WORKz(cls) _WORK(cls, 0)
+#define _WORK0(cls)
/**begin repeat
* #name = byte, short, int, long, longlong, ubyte, ushort, uint, ulong,
@@ -2592,7 +2587,7 @@ object_arrtype_dealloc(PyObject *v)
* #TYPE = BYTE, SHORT, INT, LONG, LONGLONG, UBYTE, USHORT, UINT, ULONG,
* ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, CFLOAT, CDOUBLE,
* CLONGDOUBLE, STRING, UNICODE, OBJECT#
- * #work = 0,0,1,1,1,0,0,0,0,0,0,0,1,0,0,0,0,z,z,0#
+ * #work = 0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,z,z,0#
* #default = 0*17,1*2,2#
*/
@@ -2610,9 +2605,6 @@ static PyObject *
PyObject *robj;
PyArrayObject *arr;
PyArray_Descr *typecode = NULL;
-#if (@work@ != 0) || (@default@ == 1)
- void *thisfunc = (void *)@name@_arrtype_new;
-#endif
#if !(@default@ == 2)
int itemsize;
void *dest, *src;
@@ -2622,7 +2614,7 @@ static PyObject *
* allow base-class (if any) to do conversion
* If successful, this will jump to finish:
*/
- _WORK@work@
+ _WORK@work@(Py@Name@ArrType_Type)
/* TODO: include type name in error message, which is not @name@ */
if (!PyArg_ParseTuple(args, "|O", &obj)) {
diff --git a/numpy/core/tests/test_scalarinherit.py b/numpy/core/tests/test_scalarinherit.py
index af3669d73..74829986c 100644
--- a/numpy/core/tests/test_scalarinherit.py
+++ b/numpy/core/tests/test_scalarinherit.py
@@ -2,6 +2,8 @@
""" Test printing of scalar types.
"""
+import pytest
+
import numpy as np
from numpy.testing import assert_
@@ -21,6 +23,14 @@ class B0(np.float64, A):
class C0(B0):
pass
+class HasNew:
+ def __new__(cls, *args, **kwargs):
+ return cls, args, kwargs
+
+class B1(np.float64, HasNew):
+ pass
+
+
class TestInherit:
def test_init(self):
x = B(1.0)
@@ -36,6 +46,15 @@ class TestInherit:
y = C0(2.0)
assert_(str(y) == '2.0')
+ def test_gh_15395(self):
+ # HasNew is the second base, so `np.float64` should have priority
+ x = B1(1.0)
+ assert_(str(x) == '1.0')
+
+ # previously caused RecursionError!?
+ with pytest.raises(TypeError):
+ B1(1.0, 2.0)
+
class TestCharacter:
def test_char_radd(self):