summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2018-04-20 13:11:32 +0300
committermattip <matti.picus@gmail.com>2018-04-20 13:11:32 +0300
commitc26d273c1204d75fe5ab2ce9591e1b0b0b0880e1 (patch)
treed6fe1169553146618706c6fbce7c9eae193e384c /numpy
parent894dcab37ea2df285c6f48eb9b019a528b803cb5 (diff)
downloadnumpy-c26d273c1204d75fe5ab2ce9591e1b0b0b0880e1.tar.gz
ENH: add nditer.close as per review
Diffstat (limited to 'numpy')
-rw-r--r--numpy/add_newdocs.py2
-rw-r--r--numpy/core/src/multiarray/nditer_pywrap.c21
-rw-r--r--numpy/core/tests/test_nditer.py36
3 files changed, 48 insertions, 11 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index 193093109..a48b76a8d 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -553,7 +553,7 @@ add_newdoc('numpy.core', 'nditer', ('close',
"""
close()
- Resolve all writeback semantics in operands.
+ Resolve all writeback semantics in writeable operands.
"""))
diff --git a/numpy/core/src/multiarray/nditer_pywrap.c b/numpy/core/src/multiarray/nditer_pywrap.c
index 7e4e715f0..8efae59a6 100644
--- a/numpy/core/src/multiarray/nditer_pywrap.c
+++ b/numpy/core/src/multiarray/nditer_pywrap.c
@@ -2421,21 +2421,28 @@ npyiter_enter(NewNpyArrayIterObject *self)
}
static PyObject *
-npyiter_exit(NewNpyArrayIterObject *self, PyObject *args)
+npyiter_close(NewNpyArrayIterObject *self)
{
- int retval;
+ NpyIter *iter = self->iter;
+ int ret;
if (self->iter == NULL) {
Py_RETURN_NONE;
}
- self->managed = CONTEXT_EXITED;
- /* even if called via exception handling, writeback any data */
- retval = NpyIter_Close(self->iter);
- if (retval < 0) {
+ ret = NpyIter_Close(iter);
+ if (ret < 0) {
return NULL;
}
Py_RETURN_NONE;
}
+static PyObject *
+npyiter_exit(NewNpyArrayIterObject *self, PyObject *args)
+{
+ self->managed = CONTEXT_EXITED;
+ /* even if called via exception handling, writeback any data */
+ return npyiter_close(self);
+}
+
static PyMethodDef npyiter_methods[] = {
{"reset",
(PyCFunction)npyiter_reset,
@@ -2465,6 +2472,8 @@ static PyMethodDef npyiter_methods[] = {
METH_NOARGS, NULL},
{"__exit__", (PyCFunction)npyiter_exit,
METH_VARARGS, NULL},
+ {"close", (PyCFunction)npyiter_close,
+ METH_VARARGS, NULL},
{NULL, NULL, 0, NULL},
};
diff --git a/numpy/core/tests/test_nditer.py b/numpy/core/tests/test_nditer.py
index 1973c73a2..bc9456536 100644
--- a/numpy/core/tests/test_nditer.py
+++ b/numpy/core/tests/test_nditer.py
@@ -2330,8 +2330,7 @@ class TestIterNested(object):
assert_equal(vals, [[0, 1, 2], [3, 4, 5]])
vals = None
- # writebackifcopy
- # XXX ugly - is there a better way? np.nested_iter returns a tuple
+ # writebackifcopy - using conext manager
a = arange(6, dtype='f4').reshape(2, 3)
i, j = np.nested_iters(a, [[0], [1]],
op_flags=['readwrite', 'updateifcopy'],
@@ -2345,6 +2344,22 @@ class TestIterNested(object):
assert_equal(a, [[0, 1, 2], [3, 4, 5]])
assert_equal(a, [[1, 2, 3], [4, 5, 6]])
+ # writebackifcopy - using close()
+ a = arange(6, dtype='f4').reshape(2, 3)
+ i, j = np.nested_iters(a, [[0], [1]],
+ op_flags=['readwrite', 'updateifcopy'],
+ casting='same_kind',
+ op_dtypes='f8')
+ assert_equal(j[0].dtype, np.dtype('f8'))
+ for x in i:
+ for y in j:
+ y[...] += 1
+ assert_equal(a, [[0, 1, 2], [3, 4, 5]])
+ i.close()
+ j.close()
+ assert_equal(a, [[1, 2, 3], [4, 5, 6]])
+
+
def test_dtype_buffered(self):
# Test nested iteration with buffering to change dtype
@@ -2833,7 +2848,17 @@ def test_writebacks():
assert_raises(ValueError, enter)
def test_close():
- def iter_add_py(x, y, out=None):
+ ''' using a context amanger and using nditer.close are equivalent
+ '''
+ def add_close(x, y, out=None):
+ addop = np.add
+ it = np.nditer([x, y, out], [],
+ [['readonly'], ['readonly'], ['writeonly','allocate']])
+ for (a, b, c) in it:
+ addop(a, b, out=c)
+ it.close()
+ return it.operands[2]
+ def add_context(x, y, out=None):
addop = np.add
it = np.nditer([x, y, out], [],
[['readonly'], ['readonly'], ['writeonly','allocate']])
@@ -2841,7 +2866,10 @@ def test_close():
for (a, b, c) in it:
addop(a, b, out=c)
return it.operands[2]
- z = iter_add_py(range(5), range(5))
+ z = add_close(range(5), range(5))
+ assert_equal(z, range(0, 10, 2))
+ z = add_context(range(5), range(5))
+ assert_equal(z, range(0, 10, 2))
def test_warn_noclose():
a = np.arange(6, dtype='f4')