summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-06-20 11:24:21 -0500
committerGitHub <noreply@github.com>2020-06-20 11:24:21 -0500
commit209499058705258d60194f7d06890a6825778d6a (patch)
tree32f0de546cdac52c3e93ff10e79b57b88aa31def /numpy
parent69555f52b2e2c5032cf70d40d9f9b8ec9a048500 (diff)
parentf33a002366f3d02dd6662ad31595b8ae1877dfde (diff)
downloadnumpy-209499058705258d60194f7d06890a6825778d6a.tar.gz
Merge pull request #15471 from eric-wieser/fix-scalar-array-type
BUG: Ensure PyArray_FromScalar always returns the requested dtype
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/scalarapi.c50
-rw-r--r--numpy/core/tests/test_scalar_ctors.py34
2 files changed, 59 insertions, 25 deletions
diff --git a/numpy/core/src/multiarray/scalarapi.c b/numpy/core/src/multiarray/scalarapi.c
index f3c440dc6..154f4d637 100644
--- a/numpy/core/src/multiarray/scalarapi.c
+++ b/numpy/core/src/multiarray/scalarapi.c
@@ -286,14 +286,10 @@ PyArray_CastScalarDirect(PyObject *scalar, PyArray_Descr *indescr,
NPY_NO_EXPORT PyObject *
PyArray_FromScalar(PyObject *scalar, PyArray_Descr *outcode)
{
- PyArray_Descr *typecode;
- PyArrayObject *r;
- char *memptr;
- PyObject *ret;
-
/* convert to 0-dim array of scalar typecode */
- typecode = PyArray_DescrFromScalar(scalar);
+ PyArray_Descr *typecode = PyArray_DescrFromScalar(scalar);
if (typecode == NULL) {
+ Py_XDECREF(outcode);
return NULL;
}
if ((typecode->type_num == NPY_VOID) &&
@@ -307,49 +303,53 @@ PyArray_FromScalar(PyObject *scalar, PyArray_Descr *outcode)
NULL, (PyObject *)scalar);
}
- /* Need to INCREF typecode because PyArray_NewFromDescr steals a
- * reference below and we still need to access typecode afterwards. */
- Py_INCREF(typecode);
- r = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
+ PyArrayObject *r = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
typecode,
0, NULL,
NULL, NULL, 0, NULL);
- if (r==NULL) {
- Py_DECREF(typecode); Py_XDECREF(outcode);
+ if (r == NULL) {
+ Py_XDECREF(outcode);
return NULL;
}
+ /* the dtype used by the array may be different to the one requested */
+ typecode = PyArray_DESCR(r);
if (PyDataType_FLAGCHK(typecode, NPY_USE_SETITEM)) {
if (typecode->f->setitem(scalar, PyArray_DATA(r), r) < 0) {
- Py_DECREF(typecode); Py_XDECREF(outcode); Py_DECREF(r);
+ Py_DECREF(r);
+ Py_XDECREF(outcode);
return NULL;
}
- goto finish;
}
+ else {
+ char *memptr = scalar_value(scalar, typecode);
- memptr = scalar_value(scalar, typecode);
-
- memcpy(PyArray_DATA(r), memptr, PyArray_ITEMSIZE(r));
- if (PyDataType_FLAGCHK(typecode, NPY_ITEM_HASOBJECT)) {
- /* Need to INCREF just the PyObject portion */
- PyArray_Item_INCREF(memptr, typecode);
+ memcpy(PyArray_DATA(r), memptr, PyArray_ITEMSIZE(r));
+ if (PyDataType_FLAGCHK(typecode, NPY_ITEM_HASOBJECT)) {
+ /* Need to INCREF just the PyObject portion */
+ PyArray_Item_INCREF(memptr, typecode);
+ }
}
-finish:
if (outcode == NULL) {
- Py_DECREF(typecode);
return (PyObject *)r;
}
if (PyArray_EquivTypes(outcode, typecode)) {
if (!PyTypeNum_ISEXTENDED(typecode->type_num)
|| (outcode->elsize == typecode->elsize)) {
- Py_DECREF(typecode); Py_DECREF(outcode);
+ /*
+ * Since the type is equivalent, and we haven't handed the array
+ * to anyone yet, let's fix the dtype to be what was requested,
+ * even if it is equivalent to what was passed in.
+ */
+ Py_SETREF(((PyArrayObject_fields *)r)->descr, outcode);
+
return (PyObject *)r;
}
}
/* cast if necessary to desired output typecode */
- ret = PyArray_CastToType((PyArrayObject *)r, outcode, 0);
- Py_DECREF(typecode); Py_DECREF(r);
+ PyObject *ret = PyArray_CastToType(r, outcode, 0);
+ Py_DECREF(r);
return ret;
}
diff --git a/numpy/core/tests/test_scalar_ctors.py b/numpy/core/tests/test_scalar_ctors.py
index cd518274a..7e933537d 100644
--- a/numpy/core/tests/test_scalar_ctors.py
+++ b/numpy/core/tests/test_scalar_ctors.py
@@ -79,3 +79,37 @@ class TestFromInt:
def test_uint64_from_negative(self):
assert_equal(np.uint64(-2), np.uint64(18446744073709551614))
+
+
+int_types = [np.byte, np.short, np.intc, np.int_, np.longlong]
+uint_types = [np.ubyte, np.ushort, np.uintc, np.uint, np.ulonglong]
+float_types = [np.half, np.single, np.double, np.longdouble]
+cfloat_types = [np.csingle, np.cdouble, np.clongdouble]
+
+
+class TestArrayFromScalar:
+ """ gh-15467 """
+
+ def _do_test(self, t1, t2):
+ x = t1(2)
+ arr = np.array(x, dtype=t2)
+ # type should be preserved exactly
+ if t2 is None:
+ assert arr.dtype.type is t1
+ else:
+ assert arr.dtype.type is t2
+
+ @pytest.mark.parametrize('t1', int_types + uint_types)
+ @pytest.mark.parametrize('t2', int_types + uint_types + [None])
+ def test_integers(self, t1, t2):
+ return self._do_test(t1, t2)
+
+ @pytest.mark.parametrize('t1', float_types)
+ @pytest.mark.parametrize('t2', float_types + [None])
+ def test_reals(self, t1, t2):
+ return self._do_test(t1, t2)
+
+ @pytest.mark.parametrize('t1', cfloat_types)
+ @pytest.mark.parametrize('t2', cfloat_types + [None])
+ def test_complex(self, t1, t2):
+ return self._do_test(t1, t2)