diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-09-11 15:47:22 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-09-11 15:47:22 +0000 |
commit | b45222e6f8eea0eaaf7c773f1370f2ebc7765a2f (patch) | |
tree | 7a7edd2e29a282cf79ffe33e506c25a4d74304e9 /numpy/core | |
parent | d5382aa641e3a62cb7707af07724048422960ad0 (diff) | |
download | numpy-b45222e6f8eea0eaaf7c773f1370f2ebc7765a2f.tar.gz |
BUG: core: sync Python 3 file handle position in tofile/fromfile (fixes #1610)
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/include/numpy/npy_3kcompat.h | 62 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 12 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 22 |
4 files changed, 90 insertions, 19 deletions
diff --git a/numpy/core/include/numpy/npy_3kcompat.h b/numpy/core/include/numpy/npy_3kcompat.h index 76f10ac9e..a0a1e0445 100644 --- a/numpy/core/include/numpy/npy_3kcompat.h +++ b/numpy/core/include/numpy/npy_3kcompat.h @@ -149,14 +149,19 @@ PyUnicode_Concat2(PyObject **left, PyObject *right) #endif /* - * PyFile_AsFile + * PyFile_* compatibility */ #if defined(NPY_PY3K) + +/* + * Get a FILE* handle to the file represented by the Python object + */ static NPY_INLINE FILE* npy_PyFile_Dup(PyObject *file, char *mode) { int fd, fd2; PyObject *ret, *os; + FILE *handle; /* Flush first to ensure things end up in the file in the correct order */ ret = PyObject_CallMethod(file, "flush", ""); if (ret == NULL) { @@ -179,11 +184,62 @@ npy_PyFile_Dup(PyObject *file, char *mode) fd2 = PyNumber_AsSsize_t(ret, NULL); Py_DECREF(ret); #ifdef _WIN32 - return _fdopen(fd2, mode); + handle = _fdopen(fd2, mode); #else - return fdopen(fd2, mode); + handle = fdopen(fd2, mode); #endif + if (handle == NULL) { + PyErr_SetString(PyExc_IOError, + "Getting a FILE* from a Python file object failed"); + } + return handle; } + +/* + * Close the dup-ed file handle, and seek the Python one to the current position + */ +static NPY_INLINE int +npy_PyFile_DupClose(PyObject *file, FILE* handle) +{ + PyObject *ret; + long position; + position = ftell(handle); + fclose(handle); + + ret = PyObject_CallMethod(file, "seek", "li", position, 0); + if (ret == NULL) { + return -1; + } + Py_DECREF(ret); + return 0; +} + +static int +npy_PyFile_Check(PyObject *file) +{ + static PyTypeObject *fileio = NULL; + + if (fileio == NULL) { + PyObject *mod; + mod = PyImport_ImportModule("io"); + if (mod == NULL) { + return 0; + } + fileio = (PyTypeObject*)PyObject_GetAttrString(mod, "FileIO"); + Py_DECREF(mod); + } + + if (fileio != NULL) { + return PyObject_TypeCheck(file, fileio); + } +} + +#else + +#define npy_PyFile_Dup(file, mode) PyFile_AsFile(file) +#define npy_PyFile_DupClose(file, handle) (0) +#define npy_PyFile_Check PyFile_Check + #endif static NPY_INLINE PyObject* diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 783554965..0290c9c39 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -496,7 +496,7 @@ array_tostring(PyArrayObject *self, PyObject *args, PyObject *kwds) static PyObject * array_tofile(PyArrayObject *self, PyObject *args, PyObject *kwds) { - int ret; + int ret, ret2; PyObject *file; FILE *fd; char *sep = ""; @@ -517,11 +517,7 @@ array_tofile(PyArrayObject *self, PyObject *args, PyObject *kwds) else { Py_INCREF(file); } -#if defined(NPY_PY3K) fd = npy_PyFile_Dup(file, "wb"); -#else - fd = PyFile_AsFile(file); -#endif if (fd == NULL) { PyErr_SetString(PyExc_IOError, "first argument must be a " \ "string or open file"); @@ -529,11 +525,9 @@ array_tofile(PyArrayObject *self, PyObject *args, PyObject *kwds) return NULL; } ret = PyArray_ToFile(self, fd, sep, format); -#if defined(NPY_PY3K) - fclose(fd); -#endif + ret2 = npy_PyFile_DupClose(file, fd); Py_DECREF(file); - if (ret < 0) { + if (ret < 0 || ret2 < 0) { return NULL; } Py_INCREF(Py_None); diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index de63f339c..fcd68e079 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1706,6 +1706,7 @@ static PyObject * array_fromfile(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *keywds) { PyObject *file = NULL, *ret; + int ok; FILE *fp; char *sep = ""; Py_ssize_t nin = -1; @@ -1727,11 +1728,7 @@ array_fromfile(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *keywds) else { Py_INCREF(file); } -#if defined(NPY_PY3K) fp = npy_PyFile_Dup(file, "rb"); -#else - fp = PyFile_AsFile(file); -#endif if (fp == NULL) { PyErr_SetString(PyExc_IOError, "first argument must be an open file"); @@ -1742,10 +1739,12 @@ array_fromfile(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *keywds) type = PyArray_DescrFromType(PyArray_DEFAULT); } ret = PyArray_FromFile(fp, type, (intp) nin, sep); -#if defined(NPY_PY3K) - fclose(fp); -#endif + ok = npy_PyFile_DupClose(file, fp); Py_DECREF(file); + if (ok < 0) { + Py_DECREF(ret); + return NULL; + } return ret; } diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 26dc769f9..d5ba4849b 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -7,6 +7,7 @@ from os import path from numpy.testing import * from numpy.testing.utils import _assert_valid_refcount from numpy.compat import asbytes, asunicode, asbytes_nested +import tempfile import numpy as np if sys.version_info[0] >= 3: @@ -1373,5 +1374,26 @@ class TestRegression(TestCase): c2 = sys.getrefcount(rgba) assert_equal(c1, c2) + def test_fromfile_tofile_seeks(self): + # On Python 3, tofile/fromfile used to get (#1610) the Python + # file handle out of sync + f = tempfile.TemporaryFile() + f.write(np.arange(255, dtype='u1').tostring()) + + f.seek(20) + ret = np.fromfile(f, count=4, dtype='u1') + assert_equal(ret, np.array([20, 21, 22, 23], dtype='u1')) + assert_equal(f.tell(), 24) + + f.seek(40) + np.array([1, 2, 3], dtype='u1').tofile(f) + assert_equal(f.tell(), 43) + + f.seek(40) + data = f.read(3) + assert_equal(data, asbytes("\x01\x02\x03")) + + f.close() + if __name__ == "__main__": run_module_suite() |