diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-07-01 14:25:21 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-07-01 15:10:06 +0100 |
commit | bdbac02b0bddb265840cc00cc5dec0590c09b093 (patch) | |
tree | a33c87b8d62062a9e09e8029f06d960938790561 /numpy/lib/recfunctions.py | |
parent | cd761d81b571525ac6c2cca36da6bd270bb8357d (diff) | |
download | numpy-bdbac02b0bddb265840cc00cc5dec0590c09b093.tar.gz |
BUG: recfunctions.join_by fails when key is a subdtype
It seems that working with .descr is a generally terrible idea.
Instead we introduce `get_fieldspec`, which returns a list of 2-tuples,
encapsulating subdtypes.
This also means that np.core.test_rational.rational survives a roundtrip - its
.descr is 'V8', which doesn't survive
Diffstat (limited to 'numpy/lib/recfunctions.py')
-rw-r--r-- | numpy/lib/recfunctions.py | 58 |
1 files changed, 47 insertions, 11 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py index e42421786..a0a070547 100644 --- a/numpy/lib/recfunctions.py +++ b/numpy/lib/recfunctions.py @@ -70,6 +70,42 @@ def recursive_fill_fields(input, output): return output +def get_fieldspec(dtype): + """ + Produce a list of name/dtype pairs corresponding to the dtype fields + + Similar to dtype.descr, but the second item of each tuple is a dtype, not a + string. As a result, this handles subarray dtypes + + Can be passed to the dtype constructor to reconstruct the dtype, noting that + this (deliberately) discards field offsets. + + Examples + -------- + >>> dt = np.dtype([(('a', 'A'), int), ('b', float, 3)]) + >>> dt.descr + [(('a', 'A'), '<i4'), ('b', '<f8', (3,))] + >>> get_fieldspec(dt) + [(('a', 'A'), dtype('int32')), ('b', dtype(('<f8', (3,))))] + + """ + if dtype.names is None: + # .descr returns a nameless field, so we should too + return [('', dtype)] + else: + # extract the titles of the fields + name_titles = {} + for d in dtype.descr: + name_title = d[0] + if isinstance(name_title, tuple): + name = name_title[1] + else: + name = name_title + name_titles[name] = name_title + + return [(name_titles[name], dtype[name]) for name in dtype.names] + + def get_names(adtype): """ Returns the field names of the input datatype as a tuple. @@ -960,33 +996,33 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', # # Build the new description of the output array ....... # Start with the key fields - ndtype = [list(_) for _ in r1k.dtype.descr] + ndtype = [list(f) for f in get_fieldspec(r1k.dtype)] # Add the other fields - ndtype.extend(list(_) for _ in r1.dtype.descr if _[0] not in key) + ndtype.extend(list(f) for f in get_fieldspec(r1.dtype) if f[0] not in key) - for desc in r2.dtype.descr: - desc = list(desc) + for field in get_fieldspec(r2.dtype): + field = list(field) # Have we seen the current name already ? 
- name = desc[0] + name = field[0] names = list(_[0] for _ in ndtype) try: nameidx = names.index(name) except ValueError: #... we haven't: just add the description to the current list - ndtype.append(desc) + ndtype.append(field) else: current = ndtype[nameidx] if name in key: # The current field is part of the key: take the largest dtype - current[-1] = max(desc[1], current[-1]) + current[1] = max(field[1], current[1]) else: # The current field is not part of the key: add the suffixes, # and place the new field adjacent to the old one current[0] += r1postfix - desc[0] += r2postfix - ndtype.insert(nameidx + 1, desc) - # Revert the elements to tuples - ndtype = [tuple(_) for _ in ndtype] + field[0] += r2postfix + ndtype.insert(nameidx + 1, field) + # Rebuild a dtype from the new fields + ndtype = np.dtype([tuple(_) for _ in ndtype]) # Find the largest nb of common fields : # r1cmn and r2cmn should be equal, but... cmn = max(r1cmn, r2cmn) |