summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_recfunctions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/tests/test_recfunctions.py')
-rw-r--r--numpy/lib/tests/test_recfunctions.py129
1 files changed, 108 insertions, 21 deletions
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py
index e9cfa4993..bc9f8d7b6 100644
--- a/numpy/lib/tests/test_recfunctions.py
+++ b/numpy/lib/tests/test_recfunctions.py
@@ -4,7 +4,9 @@ import numpy as np
import numpy.ma as ma
from numpy.ma.mrecords import MaskedRecords
from numpy.ma.testutils import assert_equal
-from numpy.testing import TestCase, run_module_suite, assert_, assert_raises
+from numpy.testing import (
+ run_module_suite, assert_, assert_raises, dec
+ )
from numpy.lib.recfunctions import (
drop_fields, rename_fields, get_fieldstructure, recursive_fill_fields,
find_duplicates, merge_arrays, append_fields, stack_arrays, join_by
@@ -14,10 +16,10 @@ get_names_flat = np.lib.recfunctions.get_names_flat
zip_descr = np.lib.recfunctions.zip_descr
-class TestRecFunctions(TestCase):
+class TestRecFunctions(object):
# Misc tests
- def setUp(self):
+ def setup(self):
x = np.array([1, 2, ])
y = np.array([10, 20, 30])
z = np.array([('A', 1.), ('B', 2.)],
@@ -191,7 +193,7 @@ class TestRecFunctions(TestCase):
assert_equal(test[0], a[test[-1]])
-class TestRecursiveFillFields(TestCase):
+class TestRecursiveFillFields(object):
# Test recursive_fill_fields.
def test_simple_flexible(self):
# Test recursive_fill_fields on flexible-array
@@ -214,10 +216,10 @@ class TestRecursiveFillFields(TestCase):
assert_equal(test, control)
-class TestMergeArrays(TestCase):
+class TestMergeArrays(object):
# Test merge_arrays
- def setUp(self):
+ def setup(self):
x = np.array([1, 2, ])
y = np.array([10, 20, 30])
z = np.array(
@@ -347,10 +349,10 @@ class TestMergeArrays(TestCase):
assert_equal(test, control)
-class TestAppendFields(TestCase):
+class TestAppendFields(object):
# Test append_fields
- def setUp(self):
+ def setup(self):
x = np.array([1, 2, ])
y = np.array([10, 20, 30])
z = np.array(
@@ -401,9 +403,9 @@ class TestAppendFields(TestCase):
assert_equal(test, control)
-class TestStackArrays(TestCase):
+class TestStackArrays(object):
# Test stack_arrays
- def setUp(self):
+ def setup(self):
x = np.array([1, 2, ])
y = np.array([10, 20, 30])
z = np.array(
@@ -417,11 +419,11 @@ class TestStackArrays(TestCase):
(_, x, _, _) = self.data
test = stack_arrays((x,))
assert_equal(test, x)
- self.assertTrue(test is x)
+ assert_(test is x)
test = stack_arrays(x)
assert_equal(test, x)
- self.assertTrue(test is x)
+ assert_(test is x)
def test_unnamed_fields(self):
# Tests combinations of arrays w/o named fields
@@ -546,9 +548,38 @@ class TestStackArrays(TestCase):
assert_equal(test, control)
assert_equal(test.mask, control.mask)
-
-class TestJoinBy(TestCase):
- def setUp(self):
+ def test_subdtype(self):
+ z = np.array([
+ ('A', 1), ('B', 2)
+ ], dtype=[('A', '|S3'), ('B', float, (1,))])
+ zz = np.array([
+ ('a', [10.], 100.), ('b', [20.], 200.), ('c', [30.], 300.)
+ ], dtype=[('A', '|S3'), ('B', float, (1,)), ('C', float)])
+
+ res = stack_arrays((z, zz))
+ expected = ma.array(
+ data=[
+ (b'A', [1.0], 0),
+ (b'B', [2.0], 0),
+ (b'a', [10.0], 100.0),
+ (b'b', [20.0], 200.0),
+ (b'c', [30.0], 300.0)],
+ mask=[
+ (False, [False], True),
+ (False, [False], True),
+ (False, [False], False),
+ (False, [False], False),
+ (False, [False], False)
+ ],
+ dtype=zz.dtype
+ )
+ assert_equal(res.dtype, expected.dtype)
+ assert_equal(res, expected)
+ assert_equal(res.mask, expected.mask)
+
+
+class TestJoinBy(object):
+ def setup(self):
self.a = np.array(list(zip(np.arange(10), np.arange(50, 60),
np.arange(100, 110))),
dtype=[('a', int), ('b', int), ('c', int)])
@@ -656,10 +687,66 @@ class TestJoinBy(TestCase):
b = np.ones(3, dtype=[('c', 'u1'), ('b', 'f4'), ('a', 'i4')])
assert_raises(ValueError, join_by, ['a', 'b', 'b'], a, b)
+ @dec.knownfailureif(True)
+ def test_same_name_different_dtypes_key(self):
+ a_dtype = np.dtype([('key', 'S5'), ('value', '<f4')])
+ b_dtype = np.dtype([('key', 'S10'), ('value', '<f4')])
+ expected_dtype = np.dtype([
+ ('key', 'S10'), ('value1', '<f4'), ('value2', '<f4')])
+
+ a = np.array([('Sarah', 8.0), ('John', 6.0)], dtype=a_dtype)
+ b = np.array([('Sarah', 10.0), ('John', 7.0)], dtype=b_dtype)
+ res = join_by('key', a, b)
+
+ assert_equal(res.dtype, expected_dtype)
+
+ def test_same_name_different_dtypes(self):
+ # gh-9338
+ a_dtype = np.dtype([('key', 'S10'), ('value', '<f4')])
+ b_dtype = np.dtype([('key', 'S10'), ('value', '<f8')])
+ expected_dtype = np.dtype([
+ ('key', '|S10'), ('value1', '<f4'), ('value2', '<f8')])
+
+ a = np.array([('Sarah', 8.0), ('John', 6.0)], dtype=a_dtype)
+ b = np.array([('Sarah', 10.0), ('John', 7.0)], dtype=b_dtype)
+ res = join_by('key', a, b)
+
+ assert_equal(res.dtype, expected_dtype)
+
+ def test_subarray_key(self):
+ a_dtype = np.dtype([('pos', int, 3), ('f', '<f4')])
+ a = np.array([([1, 1, 1], np.pi), ([1, 2, 3], 0.0)], dtype=a_dtype)
+
+ b_dtype = np.dtype([('pos', int, 3), ('g', '<f4')])
+ b = np.array([([1, 1, 1], 3), ([3, 2, 1], 0.0)], dtype=b_dtype)
+
+ expected_dtype = np.dtype([('pos', int, 3), ('f', '<f4'), ('g', '<f4')])
+ expected = np.array([([1, 1, 1], np.pi, 3)], dtype=expected_dtype)
+
+ res = join_by('pos', a, b)
+ assert_equal(res.dtype, expected_dtype)
+ assert_equal(res, expected)
+
+ def test_padded_dtype(self):
+ dt = np.dtype('i1,f4', align=True)
+ dt.names = ('k', 'v')
+ assert_(len(dt.descr), 3) # padding field is inserted
+
+ a = np.array([(1, 3), (3, 2)], dt)
+ b = np.array([(1, 1), (2, 2)], dt)
+ res = join_by('k', a, b)
+
+ # no padding fields remain
+ expected_dtype = np.dtype([
+ ('k', 'i1'), ('v1', 'f4'), ('v2', 'f4')
+ ])
+
+ assert_equal(res.dtype, expected_dtype)
+
-class TestJoinBy2(TestCase):
+class TestJoinBy2(object):
@classmethod
- def setUp(cls):
+ def setup(cls):
cls.a = np.array(list(zip(np.arange(10), np.arange(50, 60),
np.arange(100, 110))),
dtype=[('a', int), ('b', int), ('c', int)])
@@ -683,8 +770,8 @@ class TestJoinBy2(TestCase):
assert_equal(test, control)
def test_no_postfix(self):
- self.assertRaises(ValueError, join_by, 'a', self.a, self.b,
- r1postfix='', r2postfix='')
+ assert_raises(ValueError, join_by, 'a', self.a, self.b,
+ r1postfix='', r2postfix='')
def test_no_r2postfix(self):
# Basic test of join_by no_r2postfix
@@ -722,13 +809,13 @@ class TestJoinBy2(TestCase):
assert_equal(test.dtype, control.dtype)
assert_equal(test, control)
-class TestAppendFieldsObj(TestCase):
+class TestAppendFieldsObj(object):
"""
Test append_fields with arrays containing objects
"""
# https://github.com/numpy/numpy/issues/2346
- def setUp(self):
+ def setup(self):
from datetime import date
self.data = dict(obj=date(2000, 1, 1))