diff options
author | Skipper Seabold <jsseabold@gmail.com> | 2011-06-29 22:19:52 -0400 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-07-08 20:39:29 -0600 |
commit | 978862660363f17962149ece1dfb67fa8051a8a1 (patch) | |
tree | 969a9ff0eafc3cf31ec3355811805b5486957f3c /numpy/lib/recfunctions.py | |
parent | 834b5bf5219be6d874ff547775e728151f8d6cca (diff) | |
download | numpy-978862660363f17962149ece1dfb67fa8051a8a1.tar.gz |
BUG: Fixed bugs in join_by and added tests
Diffstat (limited to 'numpy/lib/recfunctions.py')
-rw-r--r-- | numpy/lib/recfunctions.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py index b3c210fff..0127df9f9 100644 --- a/numpy/lib/recfunctions.py +++ b/numpy/lib/recfunctions.py @@ -895,6 +895,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', (nb1, nb2) = (len(r1), len(r2)) (r1names, r2names) = (r1.dtype.names, r2.dtype.names) + # Check the names for collision + if (set.intersection(set(r1names),set(r2names)).difference(key) and + not (r1postfix or r2postfix)): + msg = "r1 and r2 contain common names, r1postfix and r2postfix " + msg += "can't be empty" + raise ValueError(msg) + # Make temporary arrays of just the keys r1k = drop_fields(r1, [n for n in r1names if n not in key]) r2k = drop_fields(r2, [n for n in r2names if n not in key]) @@ -937,7 +944,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', name = desc[0] # Have we seen the current name already ? if name in names: - nameidx = names.index(name) + nameidx = ndtype.index(desc) current = ndtype[nameidx] # The current field is part of the key: take the largest dtype if name in key: @@ -960,7 +967,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', names = output.dtype.names for f in r1names: selected = s1[f] - if f not in names: + if f not in names or (f in r2names and not r2postfix and not f in key): f += r1postfix current = output[f] current[:r1cmn] = selected[:r1cmn] @@ -968,7 +975,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', current[cmn:cmn + r1spc] = selected[r1cmn:] for f in r2names: selected = s2[f] - if f not in names: + if f not in names or (f in r1names and not r1postfix and f not in key): f += r2postfix current = output[f] current[:r2cmn] = selected[:r2cmn] |