diff options
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 27 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 9 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 10 |
3 files changed, 31 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 1aad70dc6..8dcd67c20 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -2003,20 +2003,33 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds) int alloc = 0; void *dptr; PyObject *ret; - + PyObject *base = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|O:scalar", kwlist, &PyArrayDescr_Type, &typecode, &obj)) { return NULL; } if (PyDataType_FLAGCHK(typecode, NPY_LIST_PICKLE)) { - if (!PySequence_Check(obj)) { - PyErr_SetString(PyExc_TypeError, - "found non-sequence while unpickling scalar with " - "NPY_LIST_PICKLE set"); + if (typecode->type_num == NPY_OBJECT) { + Py_INCREF(obj); + return obj; + } + /* We store the full array to unpack it here: */ + if (!PyArray_CheckExact(obj)) { + /* We pickle structured voids as arrays currently */ + PyErr_SetString(PyExc_RuntimeError, + "Unpickling NPY_LIST_PICKLE (structured void) scalar " + "requires an array. The pickle file may be corrupted?"); return NULL; } - dptr = &obj; + if (!PyArray_EquivTypes(PyArray_DESCR((PyArrayObject *)obj), typecode)) { + PyErr_SetString(PyExc_RuntimeError, + "Pickled array is not compatible with requested scalar " + "dtype. The pickle file may be corrupted?"); + return NULL; + } + base = obj; + dptr = PyArray_BYTES((PyArrayObject *)obj); } else if (PyDataType_FLAGCHK(typecode, NPY_ITEM_IS_POINTER)) { @@ -2066,7 +2079,7 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds) dptr = PyBytes_AS_STRING(obj); } } - ret = PyArray_Scalar(dptr, typecode, NULL); + ret = PyArray_Scalar(dptr, typecode, base); /* free dptr which contains zeros */ if (alloc) { diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index d2ae6ce31..df9350406 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1741,13 +1741,8 @@ gentype_reduce(PyObject *self, PyObject *NPY_UNUSED(args)) if (arr == NULL) { return NULL; } - /* arr.item() */ - PyObject *val = PyArray_GETITEM(arr, PyArray_DATA(arr)); - Py_DECREF(arr); - if (val == NULL) { - return NULL; - } - PyObject *tup = Py_BuildValue("NN", obj, val); + /* Use the whole array which handles sturctured void correctly */ + PyObject *tup = Py_BuildValue("NN", obj, arr); if (tup == NULL) { return NULL; } diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index f28ad5ac9..e30d8e65c 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -424,7 +424,15 @@ class TestRecord: # make sure we did not pickle the address assert not isinstance(obj, bytes) - assert_raises(TypeError, ctor, dtype, 13) + assert_raises(RuntimeError, ctor, dtype, 13) + + # Test roundtrip: + dump = pickle.dumps(a[0]) + unpickled = pickle.loads(dump) + assert a[0] == unpickled + + # Also check the similar (impossible) "object scalar" path: + assert ctor(np.dtype("O"), data) is data def test_objview_record(self): # https://github.com/numpy/numpy/issues/2599 |