diff options
Diffstat (limited to 'numpy/lib/tests/test_recfunctions.py')
-rw-r--r-- | numpy/lib/tests/test_recfunctions.py | 129 |
1 files changed, 108 insertions, 21 deletions
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py index e9cfa4993..bc9f8d7b6 100644 --- a/numpy/lib/tests/test_recfunctions.py +++ b/numpy/lib/tests/test_recfunctions.py @@ -4,7 +4,9 @@ import numpy as np import numpy.ma as ma from numpy.ma.mrecords import MaskedRecords from numpy.ma.testutils import assert_equal -from numpy.testing import TestCase, run_module_suite, assert_, assert_raises +from numpy.testing import ( + run_module_suite, assert_, assert_raises, dec + ) from numpy.lib.recfunctions import ( drop_fields, rename_fields, get_fieldstructure, recursive_fill_fields, find_duplicates, merge_arrays, append_fields, stack_arrays, join_by @@ -14,10 +16,10 @@ get_names_flat = np.lib.recfunctions.get_names_flat zip_descr = np.lib.recfunctions.zip_descr -class TestRecFunctions(TestCase): +class TestRecFunctions(object): # Misc tests - def setUp(self): + def setup(self): x = np.array([1, 2, ]) y = np.array([10, 20, 30]) z = np.array([('A', 1.), ('B', 2.)], @@ -191,7 +193,7 @@ class TestRecFunctions(TestCase): assert_equal(test[0], a[test[-1]]) -class TestRecursiveFillFields(TestCase): +class TestRecursiveFillFields(object): # Test recursive_fill_fields. def test_simple_flexible(self): # Test recursive_fill_fields on flexible-array @@ -214,10 +216,10 @@ class TestRecursiveFillFields(TestCase): assert_equal(test, control) -class TestMergeArrays(TestCase): +class TestMergeArrays(object): # Test merge_arrays - def setUp(self): + def setup(self): x = np.array([1, 2, ]) y = np.array([10, 20, 30]) z = np.array( @@ -347,10 +349,10 @@ class TestMergeArrays(TestCase): assert_equal(test, control) -class TestAppendFields(TestCase): +class TestAppendFields(object): # Test append_fields - def setUp(self): + def setup(self): x = np.array([1, 2, ]) y = np.array([10, 20, 30]) z = np.array( @@ -401,9 +403,9 @@ class TestAppendFields(TestCase): assert_equal(test, control) -class TestStackArrays(TestCase): +class TestStackArrays(object): # Test stack_arrays - def setUp(self): + def setup(self): x = np.array([1, 2, ]) y = np.array([10, 20, 30]) z = np.array( @@ -417,11 +419,11 @@ class TestStackArrays(TestCase): (_, x, _, _) = self.data test = stack_arrays((x,)) assert_equal(test, x) - self.assertTrue(test is x) + assert_(test is x) test = stack_arrays(x) assert_equal(test, x) - self.assertTrue(test is x) + assert_(test is x) def test_unnamed_fields(self): # Tests combinations of arrays w/o named fields @@ -546,9 +548,38 @@ class TestStackArrays(TestCase): assert_equal(test, control) assert_equal(test.mask, control.mask) - -class TestJoinBy(TestCase): - def setUp(self): + def test_subdtype(self): + z = np.array([ + ('A', 1), ('B', 2) + ], dtype=[('A', '|S3'), ('B', float, (1,))]) + zz = np.array([ + ('a', [10.], 100.), ('b', [20.], 200.), ('c', [30.], 300.) + ], dtype=[('A', '|S3'), ('B', float, (1,)), ('C', float)]) + + res = stack_arrays((z, zz)) + expected = ma.array( + data=[ + (b'A', [1.0], 0), + (b'B', [2.0], 0), + (b'a', [10.0], 100.0), + (b'b', [20.0], 200.0), + (b'c', [30.0], 300.0)], + mask=[ + (False, [False], True), + (False, [False], True), + (False, [False], False), + (False, [False], False), + (False, [False], False) + ], + dtype=zz.dtype + ) + assert_equal(res.dtype, expected.dtype) + assert_equal(res, expected) + assert_equal(res.mask, expected.mask) + + +class TestJoinBy(object): + def setup(self): self.a = np.array(list(zip(np.arange(10), np.arange(50, 60), np.arange(100, 110))), dtype=[('a', int), ('b', int), ('c', int)]) @@ -656,10 +687,66 @@ class TestJoinBy(TestCase): b = np.ones(3, dtype=[('c', 'u1'), ('b', 'f4'), ('a', 'i4')]) assert_raises(ValueError, join_by, ['a', 'b', 'b'], a, b) + @dec.knownfailureif(True) + def test_same_name_different_dtypes_key(self): + a_dtype = np.dtype([('key', 'S5'), ('value', '<f4')]) + b_dtype = np.dtype([('key', 'S10'), ('value', '<f4')]) + expected_dtype = np.dtype([ + ('key', 'S10'), ('value1', '<f4'), ('value2', '<f4')]) + + a = np.array([('Sarah', 8.0), ('John', 6.0)], dtype=a_dtype) + b = np.array([('Sarah', 10.0), ('John', 7.0)], dtype=b_dtype) + res = join_by('key', a, b) + + assert_equal(res.dtype, expected_dtype) + + def test_same_name_different_dtypes(self): + # gh-9338 + a_dtype = np.dtype([('key', 'S10'), ('value', '<f4')]) + b_dtype = np.dtype([('key', 'S10'), ('value', '<f8')]) + expected_dtype = np.dtype([ + ('key', '|S10'), ('value1', '<f4'), ('value2', '<f8')]) + + a = np.array([('Sarah', 8.0), ('John', 6.0)], dtype=a_dtype) + b = np.array([('Sarah', 10.0), ('John', 7.0)], dtype=b_dtype) + res = join_by('key', a, b) + + assert_equal(res.dtype, expected_dtype) + + def test_subarray_key(self): + a_dtype = np.dtype([('pos', int, 3), ('f', '<f4')]) + a = np.array([([1, 1, 1], np.pi), ([1, 2, 3], 0.0)], dtype=a_dtype) + + b_dtype = np.dtype([('pos', int, 3), ('g', '<f4')]) + b = np.array([([1, 1, 1], 3), ([3, 2, 1], 0.0)], dtype=b_dtype) + + expected_dtype = np.dtype([('pos', int, 3), ('f', '<f4'), ('g', '<f4')]) + expected = np.array([([1, 1, 1], np.pi, 3)], dtype=expected_dtype) + + res = join_by('pos', a, b) + assert_equal(res.dtype, expected_dtype) + assert_equal(res, expected) + + def test_padded_dtype(self): + dt = np.dtype('i1,f4', align=True) + dt.names = ('k', 'v') + assert_(len(dt.descr), 3) # padding field is inserted + + a = np.array([(1, 3), (3, 2)], dt) + b = np.array([(1, 1), (2, 2)], dt) + res = join_by('k', a, b) + + # no padding fields remain + expected_dtype = np.dtype([ + ('k', 'i1'), ('v1', 'f4'), ('v2', 'f4') + ]) + + assert_equal(res.dtype, expected_dtype) + -class TestJoinBy2(TestCase): +class TestJoinBy2(object): @classmethod - def setUp(cls): + def setup(cls): cls.a = np.array(list(zip(np.arange(10), np.arange(50, 60), np.arange(100, 110))), dtype=[('a', int), ('b', int), ('c', int)]) @@ -683,8 +770,8 @@ class TestJoinBy2(TestCase): assert_equal(test, control) def test_no_postfix(self): - self.assertRaises(ValueError, join_by, 'a', self.a, self.b, - r1postfix='', r2postfix='') + assert_raises(ValueError, join_by, 'a', self.a, self.b, + r1postfix='', r2postfix='') def test_no_r2postfix(self): # Basic test of join_by no_r2postfix @@ -722,13 +809,13 @@ class TestJoinBy2(TestCase): assert_equal(test.dtype, control.dtype) assert_equal(test, control) -class TestAppendFieldsObj(TestCase): +class TestAppendFieldsObj(object): """ Test append_fields with arrays containing objects """ # https://github.com/numpy/numpy/issues/2346 - def setUp(self): + def setup(self): from datetime import date self.data = dict(obj=date(2000, 1, 1)) |