diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-11-19 12:53:42 -0600 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2020-11-19 13:09:37 -0600 |
commit | 5f316a85dd205cb0a6134a5a1e6b5bc76943f17c (patch) | |
tree | 287bb2c8036eb103dd1c4e1b726729797aecc49d | |
parent | 44d66b8a3fbc88998b98d7bffc20bc9fa89251a5 (diff) | |
download | numpy-5f316a85dd205cb0a6134a5a1e6b5bc76943f17c.tar.gz |
BUG: Fix pickling of scalars with NPY_LISTPICKLE
Structured void scalars (or potentially other dtypes using LISTPICKLE)
were incorrectly pickled and unpickled previously.
While the pickled value was valid and could have been used to
reconstruct the data, the reconstruction itself would always fail.
The approach here is to always pickle the full array and use the
arrays pickle/unpickle machinery which handles structured void
correctly. Future work could improve on this probably.
In the (hopefully impossible) case that someone tries to unpickle
a numpy object scalar, the code is now fixed to return the pickled
value unmodified (but INCREF'd), however a deprecation-warning
is given, since such an object should never be created.
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 27 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 9 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 10 |
3 files changed, 31 insertions, 15 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 1aad70dc6..8dcd67c20 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -2003,20 +2003,33 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds) int alloc = 0; void *dptr; PyObject *ret; - + PyObject *base = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|O:scalar", kwlist, &PyArrayDescr_Type, &typecode, &obj)) { return NULL; } if (PyDataType_FLAGCHK(typecode, NPY_LIST_PICKLE)) { - if (!PySequence_Check(obj)) { - PyErr_SetString(PyExc_TypeError, - "found non-sequence while unpickling scalar with " - "NPY_LIST_PICKLE set"); + if (typecode->type_num == NPY_OBJECT) { + Py_INCREF(obj); + return obj; + } + /* We store the full array to unpack it here: */ + if (!PyArray_CheckExact(obj)) { + /* We pickle structured voids as arrays currently */ + PyErr_SetString(PyExc_RuntimeError, + "Unpickling NPY_LIST_PICKLE (structured void) scalar " + "requires an array. The pickle file may be corrupted?"); return NULL; } - dptr = &obj; + if (!PyArray_EquivTypes(PyArray_DESCR((PyArrayObject *)obj), typecode)) { + PyErr_SetString(PyExc_RuntimeError, + "Pickled array is not compatible with requested scalar " + "dtype. The pickle file may be corrupted?"); + return NULL; + } + base = obj; + dptr = PyArray_BYTES((PyArrayObject *)obj); } else if (PyDataType_FLAGCHK(typecode, NPY_ITEM_IS_POINTER)) { @@ -2066,7 +2079,7 @@ array_scalar(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds) dptr = PyBytes_AS_STRING(obj); } } - ret = PyArray_Scalar(dptr, typecode, NULL); + ret = PyArray_Scalar(dptr, typecode, base); /* free dptr which contains zeros */ if (alloc) { diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index d2ae6ce31..df9350406 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1741,13 +1741,8 @@ gentype_reduce(PyObject *self, PyObject *NPY_UNUSED(args)) if (arr == NULL) { return NULL; } - /* arr.item() */ - PyObject *val = PyArray_GETITEM(arr, PyArray_DATA(arr)); - Py_DECREF(arr); - if (val == NULL) { - return NULL; - } - PyObject *tup = Py_BuildValue("NN", obj, val); + /* Use the whole array which handles sturctured void correctly */ + PyObject *tup = Py_BuildValue("NN", obj, arr); if (tup == NULL) { return NULL; } diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index f28ad5ac9..e30d8e65c 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -424,7 +424,15 @@ class TestRecord: # make sure we did not pickle the address assert not isinstance(obj, bytes) - assert_raises(TypeError, ctor, dtype, 13) + assert_raises(RuntimeError, ctor, dtype, 13) + + # Test roundtrip: + dump = pickle.dumps(a[0]) + unpickled = pickle.loads(dump) + assert a[0] == unpickled + + # Also check the similar (impossible) "object scalar" path: + assert ctor(np.dtype("O"), data) is data def test_objview_record(self): # https://github.com/numpy/numpy/issues/2599 |