diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2019-06-03 06:18:26 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-03 06:18:26 -0700 |
commit | 52d173d9087f6fb09994b433a127f58d19875c98 (patch) | |
tree | f8d8cac70bb2a5d3094db33e7d17a7b3aa52132d /numpy | |
parent | 495de50af010653be3f8a0b252c4ecd8228e2541 (diff) | |
parent | 8c275a1120363121457a164b05c5385be2d47718 (diff) | |
download | numpy-52d173d9087f6fb09994b433a127f58d19875c98.tar.gz |
Merge pull request #13684 from eric-wieser/dump-to-python
BUG: Move ndarray.dump to python and make it close the file it opens
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_methods.py | 12 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 77 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 4 |
3 files changed, 38 insertions, 55 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index ba6f7d111..269e509b8 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -13,6 +13,7 @@ from numpy.core._asarray import asanyarray from numpy.core import numerictypes as nt from numpy.core import _exceptions from numpy._globals import _NoValue +from numpy.compat import pickle, os_fspath, contextlib_nullcontext # save those O(100) nanoseconds! umr_maximum = um.maximum.reduce @@ -230,3 +231,14 @@ def _ptp(a, axis=None, out=None, keepdims=False): umr_minimum(a, axis, None, None, keepdims), out ) + +def _dump(self, file, protocol=2): + if hasattr(file, 'write'): + ctx = contextlib_nullcontext(file) + else: + ctx = open(os_fspath(file), "wb") + with ctx as f: + pickle.dump(self, f, protocol=protocol) + +def _dumps(self, protocol=2): + return pickle.dumps(self, protocol=protocol) diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index b843c7983..3d7e035ac 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2098,37 +2098,22 @@ array_setstate(PyArrayObject *self, PyObject *args) NPY_NO_EXPORT int PyArray_Dump(PyObject *self, PyObject *file, int protocol) { - PyObject *cpick = NULL; + static PyObject *method = NULL; PyObject *ret; - if (protocol < 0) { - protocol = 2; - } - -#if defined(NPY_PY3K) - cpick = PyImport_ImportModule("pickle"); -#else - cpick = PyImport_ImportModule("cPickle"); -#endif - if (cpick == NULL) { + npy_cache_import("numpy.core._methods", "_dump", &method); + if (method == NULL) { return -1; } - if (PyBytes_Check(file) || PyUnicode_Check(file)) { - file = npy_PyFile_OpenFile(file, "wb"); - if (file == NULL) { - Py_DECREF(cpick); - return -1; - } + if (protocol < 0) { + ret = PyObject_CallFunction(method, "OO", self, file); } else { - Py_INCREF(file); + ret = PyObject_CallFunction(method, "OOi", self, file, protocol); } - ret = PyObject_CallMethod(cpick, "dump", "OOi", self, file, protocol); - Py_XDECREF(ret); - Py_DECREF(file); - Py_DECREF(cpick); - if (PyErr_Occurred()) { + if (ret == NULL) { return -1; } + Py_DECREF(ret); return 0; } @@ -2136,49 +2121,31 @@ PyArray_Dump(PyObject *self, PyObject *file, int protocol) NPY_NO_EXPORT PyObject * PyArray_Dumps(PyObject *self, int protocol) { - PyObject *cpick = NULL; - PyObject *ret; + static PyObject *method = NULL; + npy_cache_import("numpy.core._methods", "_dumps", &method); + if (method == NULL) { + return NULL; + } if (protocol < 0) { - protocol = 2; + return PyObject_CallFunction(method, "O", self); } -#if defined(NPY_PY3K) - cpick = PyImport_ImportModule("pickle"); -#else - cpick = PyImport_ImportModule("cPickle"); -#endif - if (cpick == NULL) { - return NULL; + else { + return PyObject_CallFunction(method, "Oi", self, protocol); } - ret = PyObject_CallMethod(cpick, "dumps", "Oi", self, protocol); - Py_DECREF(cpick); - return ret; } static PyObject * -array_dump(PyArrayObject *self, PyObject *args) +array_dump(PyArrayObject *self, PyObject *args, PyObject *kwds) { - PyObject *file = NULL; - int ret; - - if (!PyArg_ParseTuple(args, "O:dump", &file)) { - return NULL; - } - ret = PyArray_Dump((PyObject *)self, file, 2); - if (ret < 0) { - return NULL; - } - Py_RETURN_NONE; + NPY_FORWARD_NDARRAY_METHOD("_dump"); } static PyObject * -array_dumps(PyArrayObject *self, PyObject *args) +array_dumps(PyArrayObject *self, PyObject *args, PyObject *kwds) { - if (!PyArg_ParseTuple(args, "")) { - return NULL; - } - return PyArray_Dumps((PyObject *)self, 2); + NPY_FORWARD_NDARRAY_METHOD("_dumps"); } @@ -2753,10 +2720,10 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { METH_VARARGS, NULL}, {"dumps", (PyCFunction) array_dumps, - METH_VARARGS, NULL}, + METH_VARARGS | METH_KEYWORDS, NULL}, {"dump", (PyCFunction) array_dump, - METH_VARARGS, NULL}, + METH_VARARGS | METH_KEYWORDS, NULL}, {"__complex__", (PyCFunction) array_complex, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 88acb76bd..b9eca45a3 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1488,6 +1488,10 @@ class TestZeroSizeFlexible(object): # viewing as any non-empty type gives an empty result assert_equal(zs.view((dt, 1)).shape, (0,)) + def test_dumps(self): + zs = self._zeros(10, int) + assert_equal(zs, pickle.loads(zs.dumps())) + def test_pickle(self): for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): for dt in [bytes, np.void, unicode]: |