diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-06-20 11:24:21 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-20 11:24:21 -0500 |
commit | 209499058705258d60194f7d06890a6825778d6a (patch) | |
tree | 32f0de546cdac52c3e93ff10e79b57b88aa31def /numpy | |
parent | 69555f52b2e2c5032cf70d40d9f9b8ec9a048500 (diff) | |
parent | f33a002366f3d02dd6662ad31595b8ae1877dfde (diff) | |
download | numpy-209499058705258d60194f7d06890a6825778d6a.tar.gz |
Merge pull request #15471 from eric-wieser/fix-scalar-array-type
BUG: Ensure PyArray_FromScalar always returns the requested dtype
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/scalarapi.c | 50 | ||||
-rw-r--r-- | numpy/core/tests/test_scalar_ctors.py | 34 |
2 files changed, 59 insertions, 25 deletions
diff --git a/numpy/core/src/multiarray/scalarapi.c b/numpy/core/src/multiarray/scalarapi.c index f3c440dc6..154f4d637 100644 --- a/numpy/core/src/multiarray/scalarapi.c +++ b/numpy/core/src/multiarray/scalarapi.c @@ -286,14 +286,10 @@ PyArray_CastScalarDirect(PyObject *scalar, PyArray_Descr *indescr, NPY_NO_EXPORT PyObject * PyArray_FromScalar(PyObject *scalar, PyArray_Descr *outcode) { - PyArray_Descr *typecode; - PyArrayObject *r; - char *memptr; - PyObject *ret; - /* convert to 0-dim array of scalar typecode */ - typecode = PyArray_DescrFromScalar(scalar); + PyArray_Descr *typecode = PyArray_DescrFromScalar(scalar); if (typecode == NULL) { + Py_XDECREF(outcode); return NULL; } if ((typecode->type_num == NPY_VOID) && @@ -307,49 +303,53 @@ PyArray_FromScalar(PyObject *scalar, PyArray_Descr *outcode) NULL, (PyObject *)scalar); } - /* Need to INCREF typecode because PyArray_NewFromDescr steals a - * reference below and we still need to access typecode afterwards. */ - Py_INCREF(typecode); - r = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type, + PyArrayObject *r = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type, typecode, 0, NULL, NULL, NULL, 0, NULL); - if (r==NULL) { - Py_DECREF(typecode); Py_XDECREF(outcode); + if (r == NULL) { + Py_XDECREF(outcode); return NULL; } + /* the dtype used by the array may be different to the one requested */ + typecode = PyArray_DESCR(r); if (PyDataType_FLAGCHK(typecode, NPY_USE_SETITEM)) { if (typecode->f->setitem(scalar, PyArray_DATA(r), r) < 0) { - Py_DECREF(typecode); Py_XDECREF(outcode); Py_DECREF(r); + Py_DECREF(r); + Py_XDECREF(outcode); return NULL; } - goto finish; } + else { + char *memptr = scalar_value(scalar, typecode); - memptr = scalar_value(scalar, typecode); - - memcpy(PyArray_DATA(r), memptr, PyArray_ITEMSIZE(r)); - if (PyDataType_FLAGCHK(typecode, NPY_ITEM_HASOBJECT)) { - /* Need to INCREF just the PyObject portion */ - PyArray_Item_INCREF(memptr, typecode); + memcpy(PyArray_DATA(r), memptr, PyArray_ITEMSIZE(r)); + if (PyDataType_FLAGCHK(typecode, NPY_ITEM_HASOBJECT)) { + /* Need to INCREF just the PyObject portion */ + PyArray_Item_INCREF(memptr, typecode); + } } -finish: if (outcode == NULL) { - Py_DECREF(typecode); return (PyObject *)r; } if (PyArray_EquivTypes(outcode, typecode)) { if (!PyTypeNum_ISEXTENDED(typecode->type_num) || (outcode->elsize == typecode->elsize)) { - Py_DECREF(typecode); Py_DECREF(outcode); + /* + * Since the type is equivalent, and we haven't handed the array + * to anyone yet, let's fix the dtype to be what was requested, + * even if it is equivalent to what was passed in. + */ + Py_SETREF(((PyArrayObject_fields *)r)->descr, outcode); + return (PyObject *)r; } } /* cast if necessary to desired output typecode */ - ret = PyArray_CastToType((PyArrayObject *)r, outcode, 0); - Py_DECREF(typecode); Py_DECREF(r); + PyObject *ret = PyArray_CastToType(r, outcode, 0); + Py_DECREF(r); return ret; } diff --git a/numpy/core/tests/test_scalar_ctors.py b/numpy/core/tests/test_scalar_ctors.py index cd518274a..7e933537d 100644 --- a/numpy/core/tests/test_scalar_ctors.py +++ b/numpy/core/tests/test_scalar_ctors.py @@ -79,3 +79,37 @@ class TestFromInt: def test_uint64_from_negative(self): assert_equal(np.uint64(-2), np.uint64(18446744073709551614)) + + +int_types = [np.byte, np.short, np.intc, np.int_, np.longlong] +uint_types = [np.ubyte, np.ushort, np.uintc, np.uint, np.ulonglong] +float_types = [np.half, np.single, np.double, np.longdouble] +cfloat_types = [np.csingle, np.cdouble, np.clongdouble] + + +class TestArrayFromScalar: + """ gh-15467 """ + + def _do_test(self, t1, t2): + x = t1(2) + arr = np.array(x, dtype=t2) + # type should be preserved exactly + if t2 is None: + assert arr.dtype.type is t1 + else: + assert arr.dtype.type is t2 + + @pytest.mark.parametrize('t1', int_types + uint_types) + @pytest.mark.parametrize('t2', int_types + uint_types + [None]) + def test_integers(self, t1, t2): + return self._do_test(t1, t2) + + @pytest.mark.parametrize('t1', float_types) + @pytest.mark.parametrize('t2', float_types + [None]) + def test_reals(self, t1, t2): + return self._do_test(t1, t2) + + @pytest.mark.parametrize('t1', cfloat_types) + @pytest.mark.parametrize('t2', cfloat_types + [None]) + def test_complex(self, t1, t2): + return self._do_test(t1, t2) |