diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2013-02-21 23:48:29 +0100 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2013-02-25 01:01:51 +0100 |
commit | 58548e66d5d3bda3e884ae0c0ab0805ab0160484 (patch) | |
tree | fd711597eab45dd6e24c333e4a7a651ac8e956d2 /numpy/core | |
parent | b343f43eea856bd984a752f288bd148a42a789a1 (diff) | |
download | numpy-58548e66d5d3bda3e884ae0c0ab0805ab0160484.tar.gz |
TST: Add test for np.take refcounting
Also make the testcase for take a class.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/tests/test_item_selection.py | 94 |
1 files changed, 56 insertions, 38 deletions
diff --git a/numpy/core/tests/test_item_selection.py b/numpy/core/tests/test_item_selection.py index b35f4db6f..47de43012 100644 --- a/numpy/core/tests/test_item_selection.py +++ b/numpy/core/tests/test_item_selection.py @@ -2,44 +2,62 @@ import numpy as np from numpy.testing import * import sys, warnings -def test_take(): - a = [[1, 2], [3, 4]] - a_str = [[b'1', b'2'],[b'3', b'4']] - modes = ['raise', 'wrap', 'clip'] - indices = [-1, 4] - index_arrays = [np.empty(0, dtype=np.intp), - np.empty(tuple(), dtype=np.intp), - np.empty((1,1), dtype=np.intp)] - real_indices = {} - real_indices['raise'] = {-1:1, 4:IndexError} - real_indices['wrap'] = {-1:1, 4:0} - real_indices['clip'] = {-1:0, 4:1} - # Currently all types but object, use the same function generation. - # So it should not be necessary to test all. However test also a non - # refcounted struct on top of object. - types = np.int, np.object, np.dtype([('', 'i', 2)]) - for t in types: - # ta works, even if the array may be odd if buffer interface is used - ta = np.array(a if np.issubdtype(t, np.number) else a_str, dtype=t) - tresult = list(ta.T.copy()) - for index_array in index_arrays: - if index_array.size != 0: - tresult[0].shape = (2,) + index_array.shape - tresult[1].shape = (2,) + index_array.shape - for mode in modes: - for index in indices: - real_index = real_indices[mode][index] - if real_index is IndexError and index_array.size != 0: - index_array.put(0, index) - assert_raises(IndexError, ta.take, index_array, - mode=mode, axis=1) - elif index_array.size != 0: - index_array.put(0, index) - res = ta.take(index_array, mode=mode, axis=1) - assert_array_equal(res, tresult[real_index]) - else: - res = ta.take(index_array, mode=mode, axis=1) - assert_(res.shape == (2,) + index_array.shape) + +class TestTake(TestCase): + def test_simple(self): + a = [[1, 2], [3, 4]] + a_str = [[b'1', b'2'],[b'3', b'4']] + modes = ['raise', 'wrap', 'clip'] + indices = [-1, 4] + index_arrays = [np.empty(0, dtype=np.intp), + np.empty(tuple(), dtype=np.intp), + np.empty((1,1), dtype=np.intp)] + real_indices = {} + real_indices['raise'] = {-1:1, 4:IndexError} + real_indices['wrap'] = {-1:1, 4:0} + real_indices['clip'] = {-1:0, 4:1} + # Currently all types but object, use the same function generation. + # So it should not be necessary to test all. However test also a non + # refcounted struct on top of object. + types = np.int, np.object, np.dtype([('', 'i', 2)]) + for t in types: + # ta works, even if the array may be odd if buffer interface is used + ta = np.array(a if np.issubdtype(t, np.number) else a_str, dtype=t) + tresult = list(ta.T.copy()) + for index_array in index_arrays: + if index_array.size != 0: + tresult[0].shape = (2,) + index_array.shape + tresult[1].shape = (2,) + index_array.shape + for mode in modes: + for index in indices: + real_index = real_indices[mode][index] + if real_index is IndexError and index_array.size != 0: + index_array.put(0, index) + assert_raises(IndexError, ta.take, index_array, + mode=mode, axis=1) + elif index_array.size != 0: + index_array.put(0, index) + res = ta.take(index_array, mode=mode, axis=1) + assert_array_equal(res, tresult[real_index]) + else: + res = ta.take(index_array, mode=mode, axis=1) + assert_(res.shape == (2,) + index_array.shape) + + + def test_refcounting(self): + objects = [object() for i in xrange(10)] + for mode in ('raise', 'clip', 'wrap'): + a = np.array(objects) + b = np.array([2, 2, 4, 5, 3, 5]) + a.take(b, out=a[:6]) + del a + assert_(all(sys.getrefcount(o) == 3 for o in objects)) + # not contiguous, example: + a = np.array(objects * 2)[::2] + a.take(b, out=a[:6]) + del a + assert_(all(sys.getrefcount(o) == 3 for o in objects)) + if __name__ == "__main__": run_module_suite() |