summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_recfunctions.py
diff options
context:
space:
mode:
authorSkipper Seabold <jsseabold@gmail.com>2011-06-29 22:19:52 -0400
committerCharles Harris <charlesr.harris@gmail.com>2011-07-08 20:39:29 -0600
commit978862660363f17962149ece1dfb67fa8051a8a1 (patch)
tree969a9ff0eafc3cf31ec3355811805b5486957f3c /numpy/lib/tests/test_recfunctions.py
parent834b5bf5219be6d874ff547775e728151f8d6cca (diff)
downloadnumpy-978862660363f17962149ece1dfb67fa8051a8a1.tar.gz
BUG: Fixed bugs in join_by and added tests
Diffstat (limited to 'numpy/lib/tests/test_recfunctions.py')
-rw-r--r--numpy/lib/tests/test_recfunctions.py67
1 files changed, 64 insertions, 3 deletions
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py
index 57d977814..c6befa5f6 100644
--- a/numpy/lib/tests/test_recfunctions.py
+++ b/numpy/lib/tests/test_recfunctions.py
@@ -36,7 +36,7 @@ class TestRecFunctions(TestCase):
test = zip_descr((x, x), flatten=False)
assert_equal(test,
np.dtype([('', int), ('', int)]))
- # Std & flexible-dtype
+ # Std & flexible-dtype
test = zip_descr((x, z), flatten=True)
assert_equal(test,
np.dtype([('', int), ('A', '|S3'), ('B', float)]))
@@ -44,7 +44,7 @@ class TestRecFunctions(TestCase):
assert_equal(test,
np.dtype([('', int),
('', [('A', '|S3'), ('B', float)])]))
- # Standard & nested dtype
+ # Standard & nested dtype
test = zip_descr((x, w), flatten=True)
assert_equal(test,
np.dtype([('', int),
@@ -259,7 +259,7 @@ class TestMergeArrays(TestCase):
control = np.array([(1, 10), (2, 20), (-1, 30)],
dtype=[('f0', int), ('f1', int)])
assert_equal(test, control)
- #
+ #
test = merge_arrays((x, y), usemask=True)
control = ma.array([(1, 10), (2, 20), (-1, 30)],
mask=[(0, 0), (0, 0), (1, 0)],
@@ -615,6 +615,67 @@ class TestJoinBy(TestCase):
dtype=[('a', int), ('b', int), ('c', int), ('d', int)])
+class TestJoinBy2(TestCase):
+ @classmethod
+ def setUp(cls):
+ cls.a = np.array(zip(np.arange(10), np.arange(50, 60),
+ np.arange(100, 110)),
+ dtype=[('a', int), ('b', int), ('c', int)])
+ cls.b = np.array(zip(np.arange(10), np.arange(65, 75),
+ np.arange(100, 110)),
+ dtype=[('a', int), ('b', int), ('d', int)])
+
+ def test_no_r1postfix(self):
+ "Basic test of join_by"
+ a, b = self.a, self.b
+
+ test = join_by('a', a, b, r1postfix='', r2postfix='2', jointype='inner')
+ control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
+ (2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
+ (4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
+ (6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
+ (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
+ dtype=[('a', int), ('b', int), ('b2', int),
+ ('c', int), ('d', int)])
+ assert_equal(test, control)
+
+
+ def test_no_postfix(self):
+ self.assertRaises(ValueError, join_by, 'a', self.a, self.b, r1postfix='', r2postfix='')
+
+ def test_no_r2postfix(self):
+ "Basic test of join_by"
+ a, b = self.a, self.b
+
+ test = join_by('a', a, b, r1postfix='1', r2postfix='', jointype='inner')
+ control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
+ (2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
+ (4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
+ (6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
+ (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
+ dtype=[('a', int), ('b1', int), ('b', int),
+ ('c', int), ('d', int)])
+ assert_equal(test, control)
+
+ def test_two_keys_two_vars(self):
+ a = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
+ np.arange(50, 60), np.arange(10,20)),
+ dtype=[('k', int), ('a', int), ('b', int),('c',int)])
+
+ b = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
+ np.arange(65, 75), np.arange(0,10)),
+ dtype=[('k', int), ('a', int), ('b', int), ('c',int)])
+
+ control = np.array([(10, 0, 50, 65, 10, 0), (11, 0, 51, 66, 11, 1),
+ (10, 1, 52, 67, 12, 2), (11, 1, 53, 68, 13, 3),
+ (10, 2, 54, 69, 14, 4), (11, 2, 55, 70, 15, 5),
+ (10, 3, 56, 71, 16, 6), (11, 3, 57, 72, 17, 7),
+ (10, 4, 58, 73, 18, 8), (11, 4, 59, 74, 19, 9)],
+ dtype=[('k', '<i8'), ('a', '<i8'), ('b1', '<i8'),
+ ('b2', '<i8'), ('c1', '<i8'), ('c2', '<i8')])
+ test = join_by(['a','k'], a, b, r1postfix='1', r2postfix='2', jointype='inner')
+ assert_equal(test, control)
+
if __name__ == '__main__':