summary | refs | log | tree | commit | diff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r-- numpy/lib/recfunctions.py 13
-rw-r--r-- numpy/lib/tests/test_recfunctions.py 67
2 files changed, 74 insertions, 6 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index b3c210fff..0127df9f9 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -895,6 +895,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
(nb1, nb2) = (len(r1), len(r2))
(r1names, r2names) = (r1.dtype.names, r2.dtype.names)
+ # Check the names for collision
+ if (set.intersection(set(r1names),set(r2names)).difference(key) and
+ not (r1postfix or r2postfix)):
+ msg = "r1 and r2 contain common names, r1postfix and r2postfix "
+ msg += "can't be empty"
+ raise ValueError(msg)
+
# Make temporary arrays of just the keys
r1k = drop_fields(r1, [n for n in r1names if n not in key])
r2k = drop_fields(r2, [n for n in r2names if n not in key])
@@ -937,7 +944,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
name = desc[0]
# Have we seen the current name already ?
if name in names:
- nameidx = names.index(name)
+ nameidx = ndtype.index(desc)
current = ndtype[nameidx]
# The current field is part of the key: take the largest dtype
if name in key:
@@ -960,7 +967,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
names = output.dtype.names
for f in r1names:
selected = s1[f]
- if f not in names:
+ if f not in names or (f in r2names and not r2postfix and not f in key):
f += r1postfix
current = output[f]
current[:r1cmn] = selected[:r1cmn]
@@ -968,7 +975,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
current[cmn:cmn + r1spc] = selected[r1cmn:]
for f in r2names:
selected = s2[f]
- if f not in names:
+ if f not in names or (f in r1names and not r1postfix and f not in key):
f += r2postfix
current = output[f]
current[:r2cmn] = selected[:r2cmn]
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py
index 57d977814..c6befa5f6 100644
--- a/numpy/lib/tests/test_recfunctions.py
+++ b/numpy/lib/tests/test_recfunctions.py
@@ -36,7 +36,7 @@ class TestRecFunctions(TestCase):
test = zip_descr((x, x), flatten=False)
assert_equal(test,
np.dtype([('', int), ('', int)]))
- # Std & flexible-dtype
+ # Std & flexible-dtype
test = zip_descr((x, z), flatten=True)
assert_equal(test,
np.dtype([('', int), ('A', '|S3'), ('B', float)]))
@@ -44,7 +44,7 @@ class TestRecFunctions(TestCase):
assert_equal(test,
np.dtype([('', int),
('', [('A', '|S3'), ('B', float)])]))
- # Standard & nested dtype
+ # Standard & nested dtype
test = zip_descr((x, w), flatten=True)
assert_equal(test,
np.dtype([('', int),
@@ -259,7 +259,7 @@ class TestMergeArrays(TestCase):
control = np.array([(1, 10), (2, 20), (-1, 30)],
dtype=[('f0', int), ('f1', int)])
assert_equal(test, control)
- #
+ #
test = merge_arrays((x, y), usemask=True)
control = ma.array([(1, 10), (2, 20), (-1, 30)],
mask=[(0, 0), (0, 0), (1, 0)],
@@ -615,6 +615,67 @@ class TestJoinBy(TestCase):
dtype=[('a', int), ('b', int), ('c', int), ('d', int)])
+class TestJoinBy2(TestCase):
+ @classmethod
+ def setUp(cls):
+ cls.a = np.array(zip(np.arange(10), np.arange(50, 60),
+ np.arange(100, 110)),
+ dtype=[('a', int), ('b', int), ('c', int)])
+ cls.b = np.array(zip(np.arange(10), np.arange(65, 75),
+ np.arange(100, 110)),
+ dtype=[('a', int), ('b', int), ('d', int)])
+
+ def test_no_r1postfix(self):
+ "Basic test of join_by"
+ a, b = self.a, self.b
+
+ test = join_by('a', a, b, r1postfix='', r2postfix='2', jointype='inner')
+ control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
+ (2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
+ (4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
+ (6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
+ (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
+ dtype=[('a', int), ('b', int), ('b2', int),
+ ('c', int), ('d', int)])
+ assert_equal(test, control)
+
+
+ def test_no_postfix(self):
+ self.assertRaises(ValueError, join_by, 'a', self.a, self.b, r1postfix='', r2postfix='')
+
+ def test_no_r2postfix(self):
+ "Basic test of join_by"
+ a, b = self.a, self.b
+
+ test = join_by('a', a, b, r1postfix='1', r2postfix='', jointype='inner')
+ control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
+ (2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
+ (4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
+ (6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
+ (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
+ dtype=[('a', int), ('b1', int), ('b', int),
+ ('c', int), ('d', int)])
+ assert_equal(test, control)
+
+ def test_two_keys_two_vars(self):
+ a = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
+ np.arange(50, 60), np.arange(10,20)),
+ dtype=[('k', int), ('a', int), ('b', int),('c',int)])
+
+ b = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
+ np.arange(65, 75), np.arange(0,10)),
+ dtype=[('k', int), ('a', int), ('b', int), ('c',int)])
+
+ control = np.array([(10, 0, 50, 65, 10, 0), (11, 0, 51, 66, 11, 1),
+ (10, 1, 52, 67, 12, 2), (11, 1, 53, 68, 13, 3),
+ (10, 2, 54, 69, 14, 4), (11, 2, 55, 70, 15, 5),
+ (10, 3, 56, 71, 16, 6), (11, 3, 57, 72, 17, 7),
+ (10, 4, 58, 73, 18, 8), (11, 4, 59, 74, 19, 9)],
+ dtype=[('k', '<i8'), ('a', '<i8'), ('b1', '<i8'),
+ ('b2', '<i8'), ('c1', '<i8'), ('c2', '<i8')])
+ test = join_by(['a','k'], a, b, r1postfix='1', r2postfix='2', jointype='inner')
+ assert_equal(test, control)
+
if __name__ == '__main__':