summaryrefslogtreecommitdiff
path: root/numpy/lib/recfunctions.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-07-01 14:25:21 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-07-01 15:10:06 +0100
commitbdbac02b0bddb265840cc00cc5dec0590c09b093 (patch)
treea33c87b8d62062a9e09e8029f06d960938790561 /numpy/lib/recfunctions.py
parentcd761d81b571525ac6c2cca36da6bd270bb8357d (diff)
downloadnumpy-bdbac02b0bddb265840cc00cc5dec0590c09b093.tar.gz
BUG: recfunctions.join_by fails when key is a subdtype
It seems that working with .descr is a generally terrible idea. Instead we introduce `get_fieldspec`, which returns a list of 2-tuples, encapsulating subdtypes. This also means that np.core.test_rational.rational survives a roundtrip - its .descr is 'V8', which doesn't survive.
Diffstat (limited to 'numpy/lib/recfunctions.py')
-rw-r--r--numpy/lib/recfunctions.py58
1 files changed, 47 insertions, 11 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index e42421786..a0a070547 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -70,6 +70,42 @@ def recursive_fill_fields(input, output):
return output
+def get_fieldspec(dtype):
+ """
+ Produce a list of name/dtype pairs corresponding to the dtype fields
+
+ Similar to dtype.descr, but the second item of each tuple is a dtype, not a
+ string. As a result, this handles subarray dtypes
+
+ Can be passed to the dtype constructor to reconstruct the dtype, noting that
+ this (deliberately) discards field offsets.
+
+ Examples
+ --------
+ >>> dt = np.dtype([(('a', 'A'), int), ('b', float, 3)])
+ >>> dt.descr
+ [(('a', 'A'), '<i4'), ('b', '<f8', (3,))]
+ >>> get_fieldspec(dt)
+ [(('a', 'A'), dtype('int32')), ('b', dtype(('<f8', (3,))))]
+
+ """
+ if dtype.names is None:
+ # .descr returns a nameless field, so we should too
+ return [('', dtype)]
+ else:
+ # extract the titles of the fields
+ name_titles = {}
+ for d in dtype.descr:
+ name_title = d[0]
+ if isinstance(name_title, tuple):
+ name = name_title[1]
+ else:
+ name = name_title
+ name_titles[name] = name_title
+
+ return [(name_titles[name], dtype[name]) for name in dtype.names]
+
+
def get_names(adtype):
"""
Returns the field names of the input datatype as a tuple.
@@ -960,33 +996,33 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
#
# Build the new description of the output array .......
# Start with the key fields
- ndtype = [list(_) for _ in r1k.dtype.descr]
+ ndtype = [list(f) for f in get_fieldspec(r1k.dtype)]
# Add the other fields
- ndtype.extend(list(_) for _ in r1.dtype.descr if _[0] not in key)
+ ndtype.extend(list(f) for f in get_fieldspec(r1.dtype) if f[0] not in key)
- for desc in r2.dtype.descr:
- desc = list(desc)
+ for field in get_fieldspec(r2.dtype):
+ field = list(field)
# Have we seen the current name already ?
- name = desc[0]
+ name = field[0]
names = list(_[0] for _ in ndtype)
try:
nameidx = names.index(name)
except ValueError:
#... we haven't: just add the description to the current list
- ndtype.append(desc)
+ ndtype.append(field)
else:
current = ndtype[nameidx]
if name in key:
# The current field is part of the key: take the largest dtype
- current[-1] = max(desc[1], current[-1])
+ current[1] = max(field[1], current[1])
else:
# The current field is not part of the key: add the suffixes,
# and place the new field adjacent to the old one
current[0] += r1postfix
- desc[0] += r2postfix
- ndtype.insert(nameidx + 1, desc)
- # Revert the elements to tuples
- ndtype = [tuple(_) for _ in ndtype]
+ field[0] += r2postfix
+ ndtype.insert(nameidx + 1, field)
+ # Rebuild a dtype from the new fields
+ ndtype = np.dtype([tuple(_) for _ in ndtype])
# Find the largest nb of common fields :
# r1cmn and r2cmn should be equal, but...
cmn = max(r1cmn, r2cmn)