diff options
author | Matti Picus <matti.picus@gmail.com> | 2020-01-24 01:37:23 +1100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-24 01:37:23 +1100 |
commit | 04ac2a13b302a7af6fe2a5ca67e09a0e09a0f8e7 (patch) | |
tree | 41e24948f37f976d5c4c0e0553e749b2e214ae9e /numpy/core | |
parent | 1279616f48cffd55de0b1dd53f86a23e170701ed (diff) | |
parent | c0ca8143aa98a2623a28a748648da9c3667695c0 (diff) | |
download | numpy-04ac2a13b302a7af6fe2a5ca67e09a0e09a0f8e7.tar.gz |
Merge pull request #15393 from eric-wieser/remove-base-handling
MAINT/BUG: Fixups to scalar base classes
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 40 | ||||
-rw-r--r-- | numpy/core/tests/test_scalarinherit.py | 19 |
2 files changed, 35 insertions, 24 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index cd26d20fa..383cef5bd 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -2562,25 +2562,20 @@ object_arrtype_dealloc(PyObject *v) * memory from the sub-types memory allocator. */ -#define _WORK(num) \ - if (type->tp_bases && (PyTuple_GET_SIZE(type->tp_bases)==2)) { \ - PyTypeObject *sup; \ - /* We are inheriting from a Python type as well so \ - give it first dibs on conversion */ \ - sup = (PyTypeObject *)PyTuple_GET_ITEM(type->tp_bases, num); \ - /* Prevent recursion */ \ - if (thisfunc != sup->tp_new) { \ - robj = sup->tp_new(type, args, kwds); \ - if (robj != NULL) goto finish; \ - if (PyTuple_GET_SIZE(args)!=1) return NULL; \ - PyErr_Clear(); \ - } \ - /* now do default conversion */ \ - } - -#define _WORK1 _WORK(1) -#define _WORKz _WORK(0) -#define _WORK0 +#define _WORK(cls, num) \ + assert(cls.tp_bases && (PyTuple_GET_SIZE(cls.tp_bases) == 2)); \ + /* We are inheriting from a Python type as well so \ + give it first dibs on conversion */ \ + PyTypeObject *sup = (PyTypeObject *)PyTuple_GET_ITEM(cls.tp_bases, num); \ + robj = sup->tp_new(type, args, kwds); \ + if (robj != NULL) goto finish; \ + if (PyTuple_GET_SIZE(args) != 1) return NULL; \ + PyErr_Clear(); \ + /* now do default conversion */ + +#define _WORK1(cls) _WORK(cls, 1) +#define _WORKz(cls) _WORK(cls, 0) +#define _WORK0(cls) /**begin repeat * #name = byte, short, int, long, longlong, ubyte, ushort, uint, ulong, @@ -2592,7 +2587,7 @@ object_arrtype_dealloc(PyObject *v) * #TYPE = BYTE, SHORT, INT, LONG, LONGLONG, UBYTE, USHORT, UINT, ULONG, * ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, CFLOAT, CDOUBLE, * CLONGDOUBLE, STRING, UNICODE, OBJECT# - * #work = 0,0,1,1,1,0,0,0,0,0,0,0,1,0,0,0,0,z,z,0# + * #work = 0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,z,z,0# * #default = 0*17,1*2,2# */ @@ -2610,9 +2605,6 @@ static PyObject * PyObject *robj; PyArrayObject *arr; PyArray_Descr *typecode = NULL; -#if (@work@ != 0) || (@default@ == 1) - void *thisfunc = (void *)@name@_arrtype_new; -#endif #if !(@default@ == 2) int itemsize; void *dest, *src; @@ -2622,7 +2614,7 @@ static PyObject * * allow base-class (if any) to do conversion * If successful, this will jump to finish: */ - _WORK@work@ + _WORK@work@(Py@Name@ArrType_Type) /* TODO: include type name in error message, which is not @name@ */ if (!PyArg_ParseTuple(args, "|O", &obj)) { diff --git a/numpy/core/tests/test_scalarinherit.py b/numpy/core/tests/test_scalarinherit.py index af3669d73..74829986c 100644 --- a/numpy/core/tests/test_scalarinherit.py +++ b/numpy/core/tests/test_scalarinherit.py @@ -2,6 +2,8 @@ """ Test printing of scalar types. """ +import pytest + import numpy as np from numpy.testing import assert_ @@ -21,6 +23,14 @@ class B0(np.float64, A): class C0(B0): pass +class HasNew: + def __new__(cls, *args, **kwargs): + return cls, args, kwargs + +class B1(np.float64, HasNew): + pass + + class TestInherit: def test_init(self): x = B(1.0) @@ -36,6 +46,15 @@ class TestInherit: y = C0(2.0) assert_(str(y) == '2.0') + def test_gh_15395(self): + # HasNew is the second base, so `np.float64` should have priority + x = B1(1.0) + assert_(str(x) == '1.0') + + # previously caused RecursionError!? + with pytest.raises(TypeError): + B1(1.0, 2.0) + class TestCharacter: def test_char_radd(self): |