summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-11-19 12:53:42 -0600
committerSebastian Berg <sebastian@sipsolutions.net>2020-11-19 13:09:37 -0600
commit5f316a85dd205cb0a6134a5a1e6b5bc76943f17c (patch)
tree287bb2c8036eb103dd1c4e1b726729797aecc49d
parent44d66b8a3fbc88998b98d7bffc20bc9fa89251a5 (diff)
downloadnumpy-5f316a85dd205cb0a6134a5a1e6b5bc76943f17c.tar.gz
BUG: Fix pickling of scalars with NPY_LISTPICKLE
Structured void scalars (or potentially other dtypes using LISTPICKLE) were incorrectly pickled and unpickled previously. While the pickled value was valid and could have been used to reconstruct the data, the reconstruction itself would always fail. The approach here is to always pickle the full array and use the arrays pickle/unpickle machinery which handles structured void correctly. Future work could improve on this probably. In the (hopefully impossible) case that someone tries to unpickle a numpy object scalar, the code is now fixed to return the pickled value unmodified (but INCREF'd), however a deprecation-warning is given, since such an object should never be created.
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c27
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src9
-rw-r--r--numpy/core/tests/test_records.py10
3 files changed, 31 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 1aad70dc6..8dcd67c20 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -2003,20 +2003,33 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
int alloc = 0;
void *dptr;
PyObject *ret;
-
+ PyObject *base = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|O:scalar", kwlist,
&PyArrayDescr_Type, &typecode, &obj)) {
return NULL;
}
if (PyDataType_FLAGCHK(typecode, NPY_LIST_PICKLE)) {
- if (!PySequence_Check(obj)) {
- PyErr_SetString(PyExc_TypeError,
- "found non-sequence while unpickling scalar with "
- "NPY_LIST_PICKLE set");
+ if (typecode->type_num == NPY_OBJECT) {
+ Py_INCREF(obj);
+ return obj;
+ }
+ /* We store the full array to unpack it here: */
+ if (!PyArray_CheckExact(obj)) {
+ /* We pickle structured voids as arrays currently */
+ PyErr_SetString(PyExc_RuntimeError,
+ "Unpickling NPY_LIST_PICKLE (structured void) scalar "
+ "requires an array. The pickle file may be corrupted?");
return NULL;
}
- dptr = &obj;
+ if (!PyArray_EquivTypes(PyArray_DESCR((PyArrayObject *)obj), typecode)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Pickled array is not compatible with requested scalar "
+ "dtype. The pickle file may be corrupted?");
+ return NULL;
+ }
+ base = obj;
+ dptr = PyArray_BYTES((PyArrayObject *)obj);
}
else if (PyDataType_FLAGCHK(typecode, NPY_ITEM_IS_POINTER)) {
@@ -2066,7 +2079,7 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
dptr = PyBytes_AS_STRING(obj);
}
}
- ret = PyArray_Scalar(dptr, typecode, NULL);
+ ret = PyArray_Scalar(dptr, typecode, base);
/* free dptr which contains zeros */
if (alloc) {
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index d2ae6ce31..df9350406 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1741,13 +1741,8 @@ gentype_reduce(PyObject *self, PyObject *NPY_UNUSED(args))
if (arr == NULL) {
return NULL;
}
- /* arr.item() */
- PyObject *val = PyArray_GETITEM(arr, PyArray_DATA(arr));
- Py_DECREF(arr);
- if (val == NULL) {
- return NULL;
- }
- PyObject *tup = Py_BuildValue("NN", obj, val);
+ /* Use the whole array which handles sturctured void correctly */
+ PyObject *tup = Py_BuildValue("NN", obj, arr);
if (tup == NULL) {
return NULL;
}
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index f28ad5ac9..e30d8e65c 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -424,7 +424,15 @@ class TestRecord:
# make sure we did not pickle the address
assert not isinstance(obj, bytes)
- assert_raises(TypeError, ctor, dtype, 13)
+ assert_raises(RuntimeError, ctor, dtype, 13)
+
+ # Test roundtrip:
+ dump = pickle.dumps(a[0])
+ unpickled = pickle.loads(dump)
+ assert a[0] == unpickled
+
+ # Also check the similar (impossible) "object scalar" path:
+ assert ctor(np.dtype("O"), data) is data
def test_objview_record(self):
# https://github.com/numpy/numpy/issues/2599