summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2019-06-03 06:18:26 -0700
committerGitHub <noreply@github.com>2019-06-03 06:18:26 -0700
commit52d173d9087f6fb09994b433a127f58d19875c98 (patch)
treef8d8cac70bb2a5d3094db33e7d17a7b3aa52132d /numpy
parent495de50af010653be3f8a0b252c4ecd8228e2541 (diff)
parent8c275a1120363121457a164b05c5385be2d47718 (diff)
downloadnumpy-52d173d9087f6fb09994b433a127f58d19875c98.tar.gz
Merge pull request #13684 from eric-wieser/dump-to-python
BUG: Move ndarray.dump to python and make it close the file it opens
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_methods.py12
-rw-r--r--numpy/core/src/multiarray/methods.c77
-rw-r--r--numpy/core/tests/test_multiarray.py4
3 files changed, 38 insertions, 55 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index ba6f7d111..269e509b8 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -13,6 +13,7 @@ from numpy.core._asarray import asanyarray
from numpy.core import numerictypes as nt
from numpy.core import _exceptions
from numpy._globals import _NoValue
+from numpy.compat import pickle, os_fspath, contextlib_nullcontext
# save those O(100) nanoseconds!
umr_maximum = um.maximum.reduce
@@ -230,3 +231,14 @@ def _ptp(a, axis=None, out=None, keepdims=False):
umr_minimum(a, axis, None, None, keepdims),
out
)
+
+def _dump(self, file, protocol=2):
+ if hasattr(file, 'write'):
+ ctx = contextlib_nullcontext(file)
+ else:
+ ctx = open(os_fspath(file), "wb")
+ with ctx as f:
+ pickle.dump(self, f, protocol=protocol)
+
+def _dumps(self, protocol=2):
+ return pickle.dumps(self, protocol=protocol)
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index b843c7983..3d7e035ac 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -2098,37 +2098,22 @@ array_setstate(PyArrayObject *self, PyObject *args)
NPY_NO_EXPORT int
PyArray_Dump(PyObject *self, PyObject *file, int protocol)
{
- PyObject *cpick = NULL;
+ static PyObject *method = NULL;
PyObject *ret;
- if (protocol < 0) {
- protocol = 2;
- }
-
-#if defined(NPY_PY3K)
- cpick = PyImport_ImportModule("pickle");
-#else
- cpick = PyImport_ImportModule("cPickle");
-#endif
- if (cpick == NULL) {
+ npy_cache_import("numpy.core._methods", "_dump", &method);
+ if (method == NULL) {
return -1;
}
- if (PyBytes_Check(file) || PyUnicode_Check(file)) {
- file = npy_PyFile_OpenFile(file, "wb");
- if (file == NULL) {
- Py_DECREF(cpick);
- return -1;
- }
+ if (protocol < 0) {
+ ret = PyObject_CallFunction(method, "OO", self, file);
}
else {
- Py_INCREF(file);
+ ret = PyObject_CallFunction(method, "OOi", self, file, protocol);
}
- ret = PyObject_CallMethod(cpick, "dump", "OOi", self, file, protocol);
- Py_XDECREF(ret);
- Py_DECREF(file);
- Py_DECREF(cpick);
- if (PyErr_Occurred()) {
+ if (ret == NULL) {
return -1;
}
+ Py_DECREF(ret);
return 0;
}
@@ -2136,49 +2121,31 @@ PyArray_Dump(PyObject *self, PyObject *file, int protocol)
NPY_NO_EXPORT PyObject *
PyArray_Dumps(PyObject *self, int protocol)
{
- PyObject *cpick = NULL;
- PyObject *ret;
+ static PyObject *method = NULL;
+ npy_cache_import("numpy.core._methods", "_dumps", &method);
+ if (method == NULL) {
+ return NULL;
+ }
if (protocol < 0) {
- protocol = 2;
+ return PyObject_CallFunction(method, "O", self);
}
-#if defined(NPY_PY3K)
- cpick = PyImport_ImportModule("pickle");
-#else
- cpick = PyImport_ImportModule("cPickle");
-#endif
- if (cpick == NULL) {
- return NULL;
+ else {
+ return PyObject_CallFunction(method, "Oi", self, protocol);
}
- ret = PyObject_CallMethod(cpick, "dumps", "Oi", self, protocol);
- Py_DECREF(cpick);
- return ret;
}
static PyObject *
-array_dump(PyArrayObject *self, PyObject *args)
+array_dump(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- PyObject *file = NULL;
- int ret;
-
- if (!PyArg_ParseTuple(args, "O:dump", &file)) {
- return NULL;
- }
- ret = PyArray_Dump((PyObject *)self, file, 2);
- if (ret < 0) {
- return NULL;
- }
- Py_RETURN_NONE;
+ NPY_FORWARD_NDARRAY_METHOD("_dump");
}
static PyObject *
-array_dumps(PyArrayObject *self, PyObject *args)
+array_dumps(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- if (!PyArg_ParseTuple(args, "")) {
- return NULL;
- }
- return PyArray_Dumps((PyObject *)self, 2);
+ NPY_FORWARD_NDARRAY_METHOD("_dumps");
}
@@ -2753,10 +2720,10 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
METH_VARARGS, NULL},
{"dumps",
(PyCFunction) array_dumps,
- METH_VARARGS, NULL},
+ METH_VARARGS | METH_KEYWORDS, NULL},
{"dump",
(PyCFunction) array_dump,
- METH_VARARGS, NULL},
+ METH_VARARGS | METH_KEYWORDS, NULL},
{"__complex__",
(PyCFunction) array_complex,
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 88acb76bd..b9eca45a3 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1488,6 +1488,10 @@ class TestZeroSizeFlexible(object):
# viewing as any non-empty type gives an empty result
assert_equal(zs.view((dt, 1)).shape, (0,))
+ def test_dumps(self):
+ zs = self._zeros(10, int)
+ assert_equal(zs, pickle.loads(zs.dumps()))
+
def test_pickle(self):
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
for dt in [bytes, np.void, unicode]: