summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2013-02-21 23:48:29 +0100
committerSebastian Berg <sebastian@sipsolutions.net>2013-02-25 01:01:51 +0100
commit58548e66d5d3bda3e884ae0c0ab0805ab0160484 (patch)
treefd711597eab45dd6e24c333e4a7a651ac8e956d2 /numpy/core
parentb343f43eea856bd984a752f288bd148a42a789a1 (diff)
downloadnumpy-58548e66d5d3bda3e884ae0c0ab0805ab0160484.tar.gz
TST: Add test for np.take refcounting
Also make the testcase for take a class.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/tests/test_item_selection.py94
1 files changed, 56 insertions, 38 deletions
diff --git a/numpy/core/tests/test_item_selection.py b/numpy/core/tests/test_item_selection.py
index b35f4db6f..47de43012 100644
--- a/numpy/core/tests/test_item_selection.py
+++ b/numpy/core/tests/test_item_selection.py
@@ -2,44 +2,62 @@ import numpy as np
from numpy.testing import *
import sys, warnings
-def test_take():
- a = [[1, 2], [3, 4]]
- a_str = [[b'1', b'2'],[b'3', b'4']]
- modes = ['raise', 'wrap', 'clip']
- indices = [-1, 4]
- index_arrays = [np.empty(0, dtype=np.intp),
- np.empty(tuple(), dtype=np.intp),
- np.empty((1,1), dtype=np.intp)]
- real_indices = {}
- real_indices['raise'] = {-1:1, 4:IndexError}
- real_indices['wrap'] = {-1:1, 4:0}
- real_indices['clip'] = {-1:0, 4:1}
- # Currently all types but object, use the same function generation.
- # So it should not be necessary to test all. However test also a non
- # refcounted struct on top of object.
- types = np.int, np.object, np.dtype([('', 'i', 2)])
- for t in types:
- # ta works, even if the array may be odd if buffer interface is used
- ta = np.array(a if np.issubdtype(t, np.number) else a_str, dtype=t)
- tresult = list(ta.T.copy())
- for index_array in index_arrays:
- if index_array.size != 0:
- tresult[0].shape = (2,) + index_array.shape
- tresult[1].shape = (2,) + index_array.shape
- for mode in modes:
- for index in indices:
- real_index = real_indices[mode][index]
- if real_index is IndexError and index_array.size != 0:
- index_array.put(0, index)
- assert_raises(IndexError, ta.take, index_array,
- mode=mode, axis=1)
- elif index_array.size != 0:
- index_array.put(0, index)
- res = ta.take(index_array, mode=mode, axis=1)
- assert_array_equal(res, tresult[real_index])
- else:
- res = ta.take(index_array, mode=mode, axis=1)
- assert_(res.shape == (2,) + index_array.shape)
+
+class TestTake(TestCase):
+ def test_simple(self):
+ a = [[1, 2], [3, 4]]
+ a_str = [[b'1', b'2'],[b'3', b'4']]
+ modes = ['raise', 'wrap', 'clip']
+ indices = [-1, 4]
+ index_arrays = [np.empty(0, dtype=np.intp),
+ np.empty(tuple(), dtype=np.intp),
+ np.empty((1,1), dtype=np.intp)]
+ real_indices = {}
+ real_indices['raise'] = {-1:1, 4:IndexError}
+ real_indices['wrap'] = {-1:1, 4:0}
+ real_indices['clip'] = {-1:0, 4:1}
+ # Currently all types but object, use the same function generation.
+ # So it should not be necessary to test all. However test also a non
+ # refcounted struct on top of object.
+ types = np.int, np.object, np.dtype([('', 'i', 2)])
+ for t in types:
+ # ta works, even if the array may be odd if buffer interface is used
+ ta = np.array(a if np.issubdtype(t, np.number) else a_str, dtype=t)
+ tresult = list(ta.T.copy())
+ for index_array in index_arrays:
+ if index_array.size != 0:
+ tresult[0].shape = (2,) + index_array.shape
+ tresult[1].shape = (2,) + index_array.shape
+ for mode in modes:
+ for index in indices:
+ real_index = real_indices[mode][index]
+ if real_index is IndexError and index_array.size != 0:
+ index_array.put(0, index)
+ assert_raises(IndexError, ta.take, index_array,
+ mode=mode, axis=1)
+ elif index_array.size != 0:
+ index_array.put(0, index)
+ res = ta.take(index_array, mode=mode, axis=1)
+ assert_array_equal(res, tresult[real_index])
+ else:
+ res = ta.take(index_array, mode=mode, axis=1)
+ assert_(res.shape == (2,) + index_array.shape)
+
+
+ def test_refcounting(self):
+ objects = [object() for i in xrange(10)]
+ for mode in ('raise', 'clip', 'wrap'):
+ a = np.array(objects)
+ b = np.array([2, 2, 4, 5, 3, 5])
+ a.take(b, out=a[:6])
+ del a
+ assert_(all(sys.getrefcount(o) == 3 for o in objects))
+ # not contiguous, example:
+ a = np.array(objects * 2)[::2]
+ a.take(b, out=a[:6])
+ del a
+ assert_(all(sys.getrefcount(o) == 3 for o in objects))
+
if __name__ == "__main__":
run_module_suite()