author    | Eric Wieser <wieser.eric@gmail.com> | 2018-07-31 00:41:28 -0700
committer | GitHub <noreply@github.com>         | 2018-07-31 00:41:28 -0700
commit    | 7f4579279a6a6aa07df664b901afa36ab3fc5ce0 (patch)
tree      | 3524c05c661f4948eabf066b46b5ad3aaf6ad617 /numpy/lib/recfunctions.py
parent    | 24960daf3e326591047eb099af840da6e95d0910 (diff)
parent    | 9bb569c4e0e1cf08128179d157bdab10c8706a97 (diff)
download  | numpy-7f4579279a6a6aa07df664b901afa36ab3fc5ce0.tar.gz
Merge branch 'master' into ix_-preserve-type
Diffstat (limited to 'numpy/lib/recfunctions.py')
-rw-r--r-- | numpy/lib/recfunctions.py | 318
1 file changed, 230 insertions(+), 88 deletions(-)
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index 4ae1079d2..b6453d5a2 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -70,6 +70,37 @@ def recursive_fill_fields(input, output):
     return output
 
 
+def get_fieldspec(dtype):
+    """
+    Produce a list of name/dtype pairs corresponding to the dtype fields
+
+    Similar to dtype.descr, but the second item of each tuple is a dtype, not a
+    string. As a result, this handles subarray dtypes
+
+    Can be passed to the dtype constructor to reconstruct the dtype, noting that
+    this (deliberately) discards field offsets.
+
+    Examples
+    --------
+    >>> dt = np.dtype([(('a', 'A'), int), ('b', float, 3)])
+    >>> dt.descr
+    [(('a', 'A'), '<i4'), ('b', '<f8', (3,))]
+    >>> get_fieldspec(dt)
+    [(('a', 'A'), dtype('int32')), ('b', dtype(('<f8', (3,))))]
+
+    """
+    if dtype.names is None:
+        # .descr returns a nameless field, so we should too
+        return [('', dtype)]
+    else:
+        fields = ((name, dtype.fields[name]) for name in dtype.names)
+        # keep any titles, if present
+        return [
+            (name if len(f) == 2 else (f[2], name), f[0])
+            for name, f in fields
+        ]
+
+
 def get_names(adtype):
     """
     Returns the field names of the input datatype as a tuple.
@@ -146,7 +177,7 @@ def flatten_descr(ndtype):
     """
     names = ndtype.names
     if names is None:
-        return ndtype.descr
+        return (('', ndtype),)
     else:
         descr = []
         for field in names:
@@ -158,6 +189,22 @@ def flatten_descr(ndtype):
     return tuple(descr)
 
 
+def zip_dtype(seqarrays, flatten=False):
+    newdtype = []
+    if flatten:
+        for a in seqarrays:
+            newdtype.extend(flatten_descr(a.dtype))
+    else:
+        for a in seqarrays:
+            current = a.dtype
+            if current.names and len(current.names) <= 1:
+                # special case - dtypes of 0 or 1 field are flattened
+                newdtype.extend(get_fieldspec(current))
+            else:
+                newdtype.append(('', current))
+    return np.dtype(newdtype)
+
+
 def zip_descr(seqarrays, flatten=False):
     """
     Combine the dtype description of a series of arrays.
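Not part of the patch, but a quick sketch of what get_fieldspec buys over dtype.descr: with a subarray field, .descr yields format strings and shape triples, while (name, dtype) pairs feed straight back into the np.dtype constructor. The last two lines compute by hand what get_fieldspec returns on a patched NumPy.

    import numpy as np

    dt = np.dtype([('a', np.uint8), ('b', '<f8', (2,))])
    print(dt.descr)              # [('a', '|u1'), ('b', '<f8', (2,))] -- strings plus a shape triple
    spec = [(name, dt[name]) for name in dt.names]   # name/dtype pairs, subarray stays one item
    print(np.dtype(spec) == dt)  # True: the pairs reconstruct the dtype directly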
""" - newdtype = [] - if flatten: - for a in seqarrays: - newdtype.extend(flatten_descr(a.dtype)) - else: - for a in seqarrays: - current = a.dtype - names = current.names or () - if len(names) > 1: - newdtype.append(('', current.descr)) - else: - newdtype.extend(current.descr) - return np.dtype(newdtype).descr + return zip_dtype(seqarrays, flatten=flatten).descr def get_fieldstructure(adtype, lastname=None, parents=None,): @@ -275,24 +310,20 @@ def izip_records(seqarrays, fill_value=None, flatten=True): flatten : {True, False}, Whether to """ - # OK, that's a complete ripoff from Python2.6 itertools.izip_longest - def sentinel(counter=([fill_value] * (len(seqarrays) - 1)).pop): - "Yields the fill_value or raises IndexError" - yield counter() - # - fillers = itertools.repeat(fill_value) - iters = [itertools.chain(it, sentinel(), fillers) for it in seqarrays] + # Should we flatten the items, or just use a nested approach if flatten: zipfunc = _izip_fields_flat else: zipfunc = _izip_fields - # - try: - for tup in zip(*iters): - yield tuple(zipfunc(tup)) - except IndexError: - pass + + if sys.version_info[0] >= 3: + zip_longest = itertools.zip_longest + else: + zip_longest = itertools.izip_longest + + for tup in zip_longest(*seqarrays, fillvalue=fill_value): + yield tuple(zipfunc(tup)) def _fix_output(output, usemask=True, asrecarray=False): @@ -366,12 +397,13 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False, Notes ----- * Without a mask, the missing value will be filled with something, - * depending on what its corresponding type: - -1 for integers - -1.0 for floating point numbers - '-' for characters - '-1' for strings - True for boolean values + depending on what its corresponding type: + + * ``-1`` for integers + * ``-1.0`` for floating point numbers + * ``'-'`` for characters + * ``'-1'`` for strings + * ``True`` for boolean values * XXX: I just obtained these values empirically """ # Only one item in the input sequence ? @@ -380,13 +412,12 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False, # Do we have a single ndarray as input ? 
@@ -380,13 +412,12 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False,
     # Do we have a single ndarray as input ?
     if isinstance(seqarrays, (ndarray, np.void)):
         seqdtype = seqarrays.dtype
-        if (not flatten) or \
-           (zip_descr((seqarrays,), flatten=True) == seqdtype.descr):
+        # Make sure we have named fields
+        if not seqdtype.names:
+            seqdtype = np.dtype([('', seqdtype)])
+        if not flatten or zip_dtype((seqarrays,), flatten=True) == seqdtype:
             # Minimal processing needed: just make sure everythng's a-ok
             seqarrays = seqarrays.ravel()
-            # Make sure we have named fields
-            if not seqdtype.names:
-                seqdtype = [('', seqdtype)]
             # Find what type of array we must return
             if usemask:
                 if asrecarray:
@@ -407,7 +438,7 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False,
     sizes = tuple(a.size for a in seqarrays)
     maxlength = max(sizes)
     # Get the dtype of the output (flattening if needed)
-    newdtype = zip_descr(seqarrays, flatten=flatten)
+    newdtype = zip_dtype(seqarrays, flatten=flatten)
     # Initialize the sequences for data and mask
     seqdata = []
     seqmask = []
@@ -499,7 +530,7 @@ def drop_fields(base, drop_names, usemask=True, asrecarray=False):
            dtype=[('a', '<i4')])
     """
     if _is_string_like(drop_names):
-        drop_names = [drop_names, ]
+        drop_names = [drop_names]
     else:
         drop_names = set(drop_names)
@@ -527,6 +558,31 @@ def drop_fields(base, drop_names, usemask=True, asrecarray=False):
     return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
 
 
+def _keep_fields(base, keep_names, usemask=True, asrecarray=False):
+    """
+    Return a new array keeping only the fields in `keep_names`,
+    and preserving the order of those fields.
+
+    Parameters
+    ----------
+    base : array
+        Input array
+    keep_names : string or sequence
+        String or sequence of strings corresponding to the names of the
+        fields to keep. Order of the names will be preserved.
+    usemask : {False, True}, optional
+        Whether to return a masked array or not.
+    asrecarray : {False, True}, optional
+        Whether to return a recarray or a mrecarray (`asrecarray=True`) or
+        a plain ndarray or masked array with flexible dtype. The default
+        is False.
+    """
+    newdtype = [(n, base.dtype[n]) for n in keep_names]
+    output = np.empty(base.shape, dtype=newdtype)
+    output = recursive_fill_fields(base, output)
+    return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
+
+
 def rec_drop_fields(base, drop_names):
     """
     Returns a new numpy.recarray with fields in `drop_names` dropped.
@@ -634,8 +690,9 @@ def append_fields(base, names, data, dtypes=None,
     else:
         data = data.pop()
     #
-    output = ma.masked_all(max(len(base), len(data)),
-                           dtype=base.dtype.descr + data.dtype.descr)
+    output = ma.masked_all(
+        max(len(base), len(data)),
+        dtype=get_fieldspec(base.dtype) + get_fieldspec(data.dtype))
     output = recursive_fill_fields(base, output)
     output = recursive_fill_fields(data, output)
     #
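A sketch of the selection the new _keep_fields helper performs, using only public NumPy calls and made-up data; note the output field order follows keep_names, not the source dtype:

    import numpy as np

    a = np.array([(1, 2.0, b'x'), (3, 4.0, b'y')],
                 dtype=[('i', 'i4'), ('f', 'f8'), ('s', 'S1')])
    keep_names = ['f', 'i']                     # reordered relative to `a`
    newdtype = [(n, a.dtype[n]) for n in keep_names]
    out = np.empty(a.shape, dtype=newdtype)
    for n in keep_names:                        # recursive_fill_fields, flattened by hand
        out[n] = a[n]
    print(out.dtype.names)                      # ('f', 'i')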
@@ -675,6 +732,84 @@ def rec_append_fields(base, names, data, dtypes=None):
     return append_fields(base, names, data=data, dtypes=dtypes,
                          asrecarray=True, usemask=False)
 
 
+def repack_fields(a, align=False, recurse=False):
+    """
+    Re-pack the fields of a structured array or dtype in memory.
+
+    The memory layout of structured datatypes allows fields at arbitrary
+    byte offsets. This means the fields can be separated by padding bytes,
+    their offsets can be non-monotonically increasing, and they can overlap.
+
+    This method removes any overlaps and reorders the fields in memory so they
+    have increasing byte offsets, and adds or removes padding bytes depending
+    on the `align` option, which behaves like the `align` option to `np.dtype`.
+
+    If `align=False`, this method produces a "packed" memory layout in which
+    each field starts at the byte the previous field ended, and any padding
+    bytes are removed.
+
+    If `align=True`, this method produces an "aligned" memory layout in which
+    each field's offset is a multiple of its alignment, and the total itemsize
+    is a multiple of the largest alignment, by adding padding bytes as needed.
+
+    Parameters
+    ----------
+    a : ndarray or dtype
+        array or dtype for which to repack the fields.
+    align : boolean
+        If true, use an "aligned" memory layout, otherwise use a "packed" layout.
+    recurse : boolean
+        If True, also repack nested structures.
+
+    Returns
+    -------
+    repacked : ndarray or dtype
+        Copy of `a` with fields repacked, or `a` itself if no repacking was
+        needed.
+
+    Examples
+    --------
+
+    >>> def print_offsets(d):
+    ...     print("offsets:", [d.fields[name][1] for name in d.names])
+    ...     print("itemsize:", d.itemsize)
+    ...
+    >>> dt = np.dtype('u1,i4,f8', align=True)
+    >>> dt
+    dtype({'names':['f0','f1','f2'], 'formats':['u1','<i4','<f8'], 'offsets':[0,4,8], 'itemsize':16}, align=True)
+    >>> print_offsets(dt)
+    offsets: [0, 4, 8]
+    itemsize: 16
+    >>> packed_dt = repack_fields(dt)
+    >>> packed_dt
+    dtype([('f0', 'u1'), ('f1', '<i4'), ('f2', '<f8')])
+    >>> print_offsets(packed_dt)
+    offsets: [0, 1, 5]
+    itemsize: 13
+
+    """
+    if not isinstance(a, np.dtype):
+        dt = repack_fields(a.dtype, align=align, recurse=recurse)
+        return a.astype(dt, copy=False)
+
+    if a.names is None:
+        return a
+
+    fieldinfo = []
+    for name in a.names:
+        tup = a.fields[name]
+        if recurse:
+            fmt = repack_fields(tup[0], align=align, recurse=True)
+        else:
+            fmt = tup[0]
+
+        if len(tup) == 3:
+            name = (tup[2], name)
+
+        fieldinfo.append((name, fmt))
+
+    dt = np.dtype(fieldinfo, align=align)
+    return np.dtype((a.type, dt))
 
 
 def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
                  autoconvert=False):
@@ -725,25 +860,21 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
     fldnames = [d.names for d in ndtype]
     #
     dtype_l = ndtype[0]
-    newdescr = dtype_l.descr
-    names = [_[0] for _ in newdescr]
+    newdescr = get_fieldspec(dtype_l)
+    names = [n for n, d in newdescr]
     for dtype_n in ndtype[1:]:
-        for descr in dtype_n.descr:
-            name = descr[0] or ''
-            if name not in names:
-                newdescr.append(descr)
-                names.append(name)
+        for fname, fdtype in get_fieldspec(dtype_n):
+            if fname not in names:
+                newdescr.append((fname, fdtype))
+                names.append(fname)
             else:
-                nameidx = names.index(name)
-                current_descr = newdescr[nameidx]
+                nameidx = names.index(fname)
+                _, cdtype = newdescr[nameidx]
                 if autoconvert:
-                    if np.dtype(descr[1]) > np.dtype(current_descr[-1]):
-                        current_descr = list(current_descr)
-                        current_descr[-1] = descr[1]
-                        newdescr[nameidx] = tuple(current_descr)
-                elif descr[1] != current_descr[-1]:
+                    newdescr[nameidx] = (fname, max(fdtype, cdtype))
+                elif fdtype != cdtype:
                     raise TypeError("Incompatible type '%s' <> '%s'" %
-                                    (dict(newdescr)[name], descr[1]))
+                                    (cdtype, fdtype))
     # Only one field: use concatenate
     if len(newdescr) == 1:
         output = ma.concatenate(seqarrays)
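The rewritten autoconvert branch above leans on np.dtype objects being ordered by safe castability, so max() picks the wider of two field dtypes. A short illustration with hypothetical field dtypes, not taken from the patch:

    import numpy as np

    cdtype, fdtype = np.dtype('i4'), np.dtype('f8')
    print(np.dtype('i4') < np.dtype('f8'))  # True: i4 casts safely to f8
    print(max(fdtype, cdtype))              # float64, the type the merged field gets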
@@ -881,11 +1012,14 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
         key = (key,)
 
     # Check the keys
+    if len(set(key)) != len(key):
+        dup = next(x for n,x in enumerate(key) if x in key[n+1:])
+        raise ValueError("duplicate join key %r" % dup)
     for name in key:
         if name not in r1.dtype.names:
-            raise ValueError('r1 does not have key field %s' % name)
+            raise ValueError('r1 does not have key field %r' % name)
         if name not in r2.dtype.names:
-            raise ValueError('r2 does not have key field %s' % name)
+            raise ValueError('r2 does not have key field %r' % name)
 
     # Make sure we work with ravelled arrays
     r1 = r1.ravel()
@@ -896,15 +1030,17 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
     (r1names, r2names) = (r1.dtype.names, r2.dtype.names)
 
     # Check the names for collision
-    if (set.intersection(set(r1names), set(r2names)).difference(key) and
-            not (r1postfix or r2postfix)):
+    collisions = (set(r1names) & set(r2names)) - set(key)
+    if collisions and not (r1postfix or r2postfix):
         msg = "r1 and r2 contain common names, r1postfix and r2postfix "
-        msg += "can't be empty"
+        msg += "can't both be empty"
         raise ValueError(msg)
 
     # Make temporary arrays of just the keys
-    r1k = drop_fields(r1, [n for n in r1names if n not in key])
-    r2k = drop_fields(r2, [n for n in r2names if n not in key])
+    # (use order of keys in `r1` for back-compatibility)
+    key1 = [ n for n in r1names if n in key ]
+    r1k = _keep_fields(r1, key1)
+    r2k = _keep_fields(r2, key1)
 
     # Concatenate the two arrays for comparison
     aux = ma.concatenate((r1k, r2k))
@@ -934,32 +1070,38 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
     #
     # Build the new description of the output array .......
     # Start with the key fields
-    ndtype = [list(_) for _ in r1k.dtype.descr]
-    # Add the other fields
-    ndtype.extend(list(_) for _ in r1.dtype.descr if _[0] not in key)
-    # Find the new list of names (it may be different from r1names)
-    names = list(_[0] for _ in ndtype)
-    for desc in r2.dtype.descr:
-        desc = list(desc)
-        name = desc[0]
+    ndtype = get_fieldspec(r1k.dtype)
+
+    # Add the fields from r1
+    for fname, fdtype in get_fieldspec(r1.dtype):
+        if fname not in key:
+            ndtype.append((fname, fdtype))
+
+    # Add the fields from r2
+    for fname, fdtype in get_fieldspec(r2.dtype):
         # Have we seen the current name already ?
-        if name in names:
-            nameidx = ndtype.index(desc)
-            current = ndtype[nameidx]
-            # The current field is part of the key: take the largest dtype
-            if name in key:
-                current[-1] = max(desc[1], current[-1])
-            # The current field is not part of the key: add the suffixes
-            else:
-                current[0] += r1postfix
-                desc[0] += r2postfix
-                ndtype.insert(nameidx + 1, desc)
-        #... we haven't: just add the description to the current list
+        # we need to rebuild this list every time
+        names = list(name for name, dtype in ndtype)
+        try:
+            nameidx = names.index(fname)
+        except ValueError:
+            #... we haven't: just add the description to the current list
+            ndtype.append((fname, fdtype))
         else:
-            names.extend(desc[0])
-            ndtype.append(desc)
-    # Revert the elements to tuples
-    ndtype = [tuple(_) for _ in ndtype]
+            # collision
+            _, cdtype = ndtype[nameidx]
+            if fname in key:
+                # The current field is part of the key: take the largest dtype
+                ndtype[nameidx] = (fname, max(fdtype, cdtype))
+            else:
+                # The current field is not part of the key: add the suffixes,
+                # and place the new field adjacent to the old one
+                ndtype[nameidx:nameidx + 1] = [
+                    (fname + r1postfix, cdtype),
+                    (fname + r2postfix, fdtype)
+                ]
+    # Rebuild a dtype from the new fields
+    ndtype = np.dtype(ndtype)
     # Find the largest nb of common fields :
     # r1cmn and r2cmn should be equal, but...
     cmn = max(r1cmn, r2cmn)
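To close, a sketch of the reworked join_by behavior; the arrays are hypothetical and running this requires a NumPy build that includes these changes:

    import numpy as np
    from numpy.lib import recfunctions as rfn

    r1 = np.array([(1, 10.0), (2, 20.0)], dtype=[('key', 'i4'), ('val', 'f8')])
    r2 = np.array([(1, -1), (3, -3)], dtype=[('key', 'i4'), ('val', 'i4')])

    # Colliding non-key field 'val' gets the r1/r2 postfixes, and the two
    # columns now land next to each other in the output dtype.
    joined = rfn.join_by('key', r1, r2)
    print(joined.dtype.names)   # ('key', 'val1', 'val2')

    # The new up-front check rejects duplicate keys instead of failing later.
    try:
        rfn.join_by(('key', 'key'), r1, r2)
    except ValueError as e:
        print(e)                # duplicate join key 'key'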