summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c27
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src9
-rw-r--r--numpy/core/tests/test_records.py10
3 files changed, 31 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 1aad70dc6..8dcd67c20 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -2003,20 +2003,33 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
int alloc = 0;
void *dptr;
PyObject *ret;
-
+ PyObject *base = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|O:scalar", kwlist,
&PyArrayDescr_Type, &typecode, &obj)) {
return NULL;
}
if (PyDataType_FLAGCHK(typecode, NPY_LIST_PICKLE)) {
- if (!PySequence_Check(obj)) {
- PyErr_SetString(PyExc_TypeError,
- "found non-sequence while unpickling scalar with "
- "NPY_LIST_PICKLE set");
+ if (typecode->type_num == NPY_OBJECT) {
+ Py_INCREF(obj);
+ return obj;
+ }
+ /* We store the full array to unpack it here: */
+ if (!PyArray_CheckExact(obj)) {
+ /* We pickle structured voids as arrays currently */
+ PyErr_SetString(PyExc_RuntimeError,
+ "Unpickling NPY_LIST_PICKLE (structured void) scalar "
+ "requires an array. The pickle file may be corrupted?");
return NULL;
}
- dptr = &obj;
+ if (!PyArray_EquivTypes(PyArray_DESCR((PyArrayObject *)obj), typecode)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Pickled array is not compatible with requested scalar "
+ "dtype. The pickle file may be corrupted?");
+ return NULL;
+ }
+ base = obj;
+ dptr = PyArray_BYTES((PyArrayObject *)obj);
}
else if (PyDataType_FLAGCHK(typecode, NPY_ITEM_IS_POINTER)) {
@@ -2066,7 +2079,7 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
dptr = PyBytes_AS_STRING(obj);
}
}
- ret = PyArray_Scalar(dptr, typecode, NULL);
+ ret = PyArray_Scalar(dptr, typecode, base);
/* free dptr which contains zeros */
if (alloc) {
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index d2ae6ce31..df9350406 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1741,13 +1741,8 @@ gentype_reduce(PyObject *self, PyObject *NPY_UNUSED(args))
if (arr == NULL) {
return NULL;
}
- /* arr.item() */
- PyObject *val = PyArray_GETITEM(arr, PyArray_DATA(arr));
- Py_DECREF(arr);
- if (val == NULL) {
- return NULL;
- }
- PyObject *tup = Py_BuildValue("NN", obj, val);
+ /* Use the whole array which handles sturctured void correctly */
+ PyObject *tup = Py_BuildValue("NN", obj, arr);
if (tup == NULL) {
return NULL;
}
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index f28ad5ac9..e30d8e65c 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -424,7 +424,15 @@ class TestRecord:
# make sure we did not pickle the address
assert not isinstance(obj, bytes)
- assert_raises(TypeError, ctor, dtype, 13)
+ assert_raises(RuntimeError, ctor, dtype, 13)
+
+ # Test roundtrip:
+ dump = pickle.dumps(a[0])
+ unpickled = pickle.loads(dump)
+ assert a[0] == unpickled
+
+ # Also check the similar (impossible) "object scalar" path:
+ assert ctor(np.dtype("O"), data) is data
def test_objview_record(self):
# https://github.com/numpy/numpy/issues/2599