diff options
author | mattip <matti.picus@gmail.com> | 2018-04-20 13:11:32 +0300 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2018-04-20 13:11:32 +0300 |
commit | c26d273c1204d75fe5ab2ce9591e1b0b0b0880e1 (patch) | |
tree | d6fe1169553146618706c6fbce7c9eae193e384c /numpy | |
parent | 894dcab37ea2df285c6f48eb9b019a528b803cb5 (diff) | |
download | numpy-c26d273c1204d75fe5ab2ce9591e1b0b0b0880e1.tar.gz |
ENH: add nditer.close as per review
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/add_newdocs.py | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/nditer_pywrap.c | 21 | ||||
-rw-r--r-- | numpy/core/tests/test_nditer.py | 36 |
3 files changed, 48 insertions, 11 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 193093109..a48b76a8d 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -553,7 +553,7 @@ add_newdoc('numpy.core', 'nditer', ('close', """ close() - Resolve all writeback semantics in operands. + Resolve all writeback semantics in writeable operands. """)) diff --git a/numpy/core/src/multiarray/nditer_pywrap.c b/numpy/core/src/multiarray/nditer_pywrap.c index 7e4e715f0..8efae59a6 100644 --- a/numpy/core/src/multiarray/nditer_pywrap.c +++ b/numpy/core/src/multiarray/nditer_pywrap.c @@ -2421,21 +2421,28 @@ npyiter_enter(NewNpyArrayIterObject *self) } static PyObject * -npyiter_exit(NewNpyArrayIterObject *self, PyObject *args) +npyiter_close(NewNpyArrayIterObject *self) { - int retval; + NpyIter *iter = self->iter; + int ret; if (self->iter == NULL) { Py_RETURN_NONE; } - self->managed = CONTEXT_EXITED; - /* even if called via exception handling, writeback any data */ - retval = NpyIter_Close(self->iter); - if (retval < 0) { + ret = NpyIter_Close(iter); + if (ret < 0) { return NULL; } Py_RETURN_NONE; } +static PyObject * +npyiter_exit(NewNpyArrayIterObject *self, PyObject *args) +{ + self->managed = CONTEXT_EXITED; + /* even if called via exception handling, writeback any data */ + return npyiter_close(self); +} + static PyMethodDef npyiter_methods[] = { {"reset", (PyCFunction)npyiter_reset, @@ -2465,6 +2472,8 @@ static PyMethodDef npyiter_methods[] = { METH_NOARGS, NULL}, {"__exit__", (PyCFunction)npyiter_exit, METH_VARARGS, NULL}, + {"close", (PyCFunction)npyiter_close, + METH_VARARGS, NULL}, {NULL, NULL, 0, NULL}, }; diff --git a/numpy/core/tests/test_nditer.py b/numpy/core/tests/test_nditer.py index 1973c73a2..bc9456536 100644 --- a/numpy/core/tests/test_nditer.py +++ b/numpy/core/tests/test_nditer.py @@ -2330,8 +2330,7 @@ class TestIterNested(object): assert_equal(vals, [[0, 1, 2], [3, 4, 5]]) vals = None - # writebackifcopy - # XXX ugly - is there a better way? np.nested_iter returns a tuple + # writebackifcopy - using conext manager a = arange(6, dtype='f4').reshape(2, 3) i, j = np.nested_iters(a, [[0], [1]], op_flags=['readwrite', 'updateifcopy'], @@ -2345,6 +2344,22 @@ class TestIterNested(object): assert_equal(a, [[0, 1, 2], [3, 4, 5]]) assert_equal(a, [[1, 2, 3], [4, 5, 6]]) + # writebackifcopy - using close() + a = arange(6, dtype='f4').reshape(2, 3) + i, j = np.nested_iters(a, [[0], [1]], + op_flags=['readwrite', 'updateifcopy'], + casting='same_kind', + op_dtypes='f8') + assert_equal(j[0].dtype, np.dtype('f8')) + for x in i: + for y in j: + y[...] += 1 + assert_equal(a, [[0, 1, 2], [3, 4, 5]]) + i.close() + j.close() + assert_equal(a, [[1, 2, 3], [4, 5, 6]]) + + def test_dtype_buffered(self): # Test nested iteration with buffering to change dtype @@ -2833,7 +2848,17 @@ def test_writebacks(): assert_raises(ValueError, enter) def test_close(): - def iter_add_py(x, y, out=None): + ''' using a context amanger and using nditer.close are equivalent + ''' + def add_close(x, y, out=None): + addop = np.add + it = np.nditer([x, y, out], [], + [['readonly'], ['readonly'], ['writeonly','allocate']]) + for (a, b, c) in it: + addop(a, b, out=c) + it.close() + return it.operands[2] + def add_context(x, y, out=None): addop = np.add it = np.nditer([x, y, out], [], [['readonly'], ['readonly'], ['writeonly','allocate']]) @@ -2841,7 +2866,10 @@ def test_close(): for (a, b, c) in it: addop(a, b, out=c) return it.operands[2] - z = iter_add_py(range(5), range(5)) + z = add_close(range(5), range(5)) + assert_equal(z, range(0, 10, 2)) + z = add_context(range(5), range(5)) + assert_equal(z, range(0, 10, 2)) def test_warn_noclose(): a = np.arange(6, dtype='f4') |