author | pierregm <pierregm@localhost> | 2010-03-26 05:15:57 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2010-03-26 05:15:57 +0000 |
commit | b64659e01eeeaf1d4498f2140911499c05702c70 (patch) | |
tree | aaea262a526f055a3407d709b8c92dc7b922f2b2 /numpy/lib/recfunctions.py | |
parent | 6cf6fd356c70bdf2d6d6d39f052cea9f662ecc07 (diff) | |
download | numpy-b64659e01eeeaf1d4498f2140911499c05702c70.tar.gz | |
* Fixed merge_arrays for arrays of size 1 (bug #1407)
* merge_arrays now accepts sequences of lists/tuples as inputs
Diffstat (limited to 'numpy/lib/recfunctions.py')
-rw-r--r-- | numpy/lib/recfunctions.py | 137 |
1 files changed, 84 insertions, 53 deletions
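The two items in the commit message can be illustrated with a rough usage sketch (a hypothetical Python session, assuming a NumPy build that includes this patch; the `f0`/`f1` field names are just NumPy's defaults for unnamed inputs):

```python
import numpy as np
from numpy.lib import recfunctions as rfn

# Bug #1407: merging arrays of size 1 used to fail; it should now
# produce a one-record structured array.
a = np.array([1])
b = np.array([10.0])
merged = rfn.merge_arrays((a, b))
# e.g. array([(1, 10.0)], dtype=[('f0', ...), ('f1', ...)])

# Sequences of lists/tuples are now promoted with np.asanyarray before
# merging, so plain Python sequences are accepted as inputs too.
merged2 = rfn.merge_arrays(([1, 2], (10.0, 20.0)), usemask=False)
# e.g. array([(1, 10.0), (2, 20.0)], dtype=[('f0', ...), ('f1', ...)])
```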
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index 65ae2bd41..3c0094aae 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -217,7 +217,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,):
         current = adtype[name]
         if current.names:
             if lastname:
-                parents[name] = [lastname,]
+                parents[name] = [lastname, ]
             else:
                 parents[name] = []
             parents.update(get_fieldstructure(current, name, parents))
@@ -227,7 +227,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,):
 #            if (lastparent[-1] != lastname):
                 lastparent.append(lastname)
             elif lastname:
-                lastparent = [lastname,]
+                lastparent = [lastname, ]
             parents[name] = lastparent or []
     return parents or None
 
@@ -274,7 +274,7 @@ def izip_records(seqarrays, fill_value=None, flatten=True):
         Whether to
     """
     # OK, that's a complete ripoff from Python2.6 itertools.izip_longest
-    def sentinel(counter = ([fill_value]*(len(seqarrays)-1)).pop):
+    def sentinel(counter=([fill_value] * (len(seqarrays) - 1)).pop):
         "Yields the fill_value or raises IndexError"
         yield counter()
     #
@@ -324,8 +324,9 @@ def _fix_defaults(output, defaults=None):
     return output
 
 
+
 def merge_arrays(seqarrays,
-                 fill_value=-1, flatten=False, usemask=True, asrecarray=False):
+                 fill_value= -1, flatten=False, usemask=False, asrecarray=False):
     """
     Merge arrays field by field.
 
@@ -372,62 +373,92 @@ def merge_arrays(seqarrays,
             True for boolean values
     * XXX: I just obtained these values empirically
     """
+    # Only one item in the input sequence ?
     if (len(seqarrays) == 1):
-        seqarrays = seqarrays[0]
-    if isinstance(seqarrays, ndarray):
+        seqarrays = np.asanyarray(seqarrays[0])
+    # Do we have a single ndarary as input ?
+    if isinstance(seqarrays, (ndarray, np.void)):
         seqdtype = seqarrays.dtype
         if (not flatten) or \
            (zip_descr((seqarrays,), flatten=True) == seqdtype.descr):
+            # Minimal processing needed: just make sure everythng's a-ok
             seqarrays = seqarrays.ravel()
+            # Make sure we have named fields
            if not seqdtype.names:
-                seqarrays = seqarrays.view([('', seqdtype)])
+                seqdtype = [('', seqdtype)]
+            # Find what type of array we must return
             if usemask:
                 if asrecarray:
-                    return seqarrays.view(MaskedRecords)
-                return seqarrays.view(MaskedArray)
+                    seqtype = MaskedRecords
+                else:
+                    seqtype = MaskedArray
             elif asrecarray:
-                return seqarrays.view(recarray)
-            return seqarrays
+                seqtype = recarray
+            else:
+                seqtype = ndarray
+            return seqarrays.view(dtype=seqdtype, type=seqtype)
         else:
             seqarrays = (seqarrays,)
-    # Get the dtype
+    else:
+        # Make sure we have arrays in the input sequence
+        seqarrays = map(np.asanyarray, seqarrays)
+    # Find the sizes of the inputs and their maximum
+    sizes = tuple(a.size for a in seqarrays)
+    maxlength = max(sizes)
+    # Get the dtype of the output (flattening if needed)
     newdtype = zip_descr(seqarrays, flatten=flatten)
-    # Get the data and the fill_value from each array
-    seqdata = [ma.getdata(a.ravel()) for a in seqarrays]
-    seqmask = [ma.getmaskarray(a).ravel() for a in seqarrays]
-    fill_value = [_check_fill_value(fill_value, a.dtype) for a in seqdata]
-    # Make an iterator from each array, padding w/ fill_values
-    maxlength = max(len(a) for a in seqarrays)
-    for (i, (a, m, fval)) in enumerate(zip(seqdata, seqmask, fill_value)):
-        # Flatten the fill_values if there's only one field
-        if isinstance(fval, (ndarray, np.void)):
-            fmsk = ma.ones((1,), m.dtype)[0]
-            if len(fval.dtype) == 1:
-                fval = fval.item()[0]
-                fmsk = True
-            else:
-                # fval and fmsk should be np.void objects
-                fval = np.array([fval,], dtype=a.dtype)[0]
-#                fmsk = np.array([fmsk,], dtype=m.dtype)[0]
-        else:
-            fmsk = True
-        nbmissing = (maxlength-len(a))
-        seqdata[i] = itertools.chain(a, [fval]*nbmissing)
-        seqmask[i] = itertools.chain(m, [fmsk]*nbmissing)
-    #
-    data = izip_records(seqdata, flatten=flatten)
-    data = tuple(data)
+    # Initialize the sequences for data and mask
+    seqdata = []
+    seqmask = []
+    # If we expect some kind of MaskedArray, make a special loop.
     if usemask:
-        mask = izip_records(seqmask, fill_value=True, flatten=flatten)
-        mask = tuple(mask)
-        output = ma.array(np.fromiter(data, dtype=newdtype))
-        output._mask[:] = list(mask)
+        for (a, n) in itertools.izip(seqarrays, sizes):
+            nbmissing = (maxlength - n)
+            # Get the data and mask
+            data = a.ravel().__array__()
+            mask = ma.getmaskarray(a).ravel()
+            # Get the filling value (if needed)
+            if nbmissing:
+                fval = _check_fill_value(fill_value, a.dtype)
+                if isinstance(fval, (ndarray, np.void)):
+                    if len(fval.dtype) == 1:
+                        fval = fval.item()[0]
+                        fmsk = True
+                    else:
+                        fval = np.array(fval, dtype=a.dtype, ndmin=1)
+                        fmsk = np.ones((1,), dtype=mask.dtype)
+            else:
+                fval = None
+                fmsk = True
+            # Store an iterator padding the input to the expected length
+            seqdata.append(itertools.chain(data, [fval] * nbmissing))
+            seqmask.append(itertools.chain(mask, [fmsk] * nbmissing))
+        # Create an iterator for the data
+        data = tuple(izip_records(seqdata, flatten=flatten))
+        output = ma.array(np.fromiter(data, dtype=newdtype, count=maxlength),
+                          mask=list(izip_records(seqmask, flatten=flatten)))
         if asrecarray:
             output = output.view(MaskedRecords)
     else:
-        output = np.fromiter(data, dtype=newdtype)
+        # Same as before, without the mask we don't need...
+        for (a, n) in itertools.izip(seqarrays, sizes):
+            nbmissing = (maxlength - n)
+            data = a.ravel().__array__()
+            if nbmissing:
+                fval = _check_fill_value(fill_value, a.dtype)
+                if isinstance(fval, (ndarray, np.void)):
+                    if len(fval.dtype) == 1:
+                        fval = fval.item()[0]
+                    else:
+                        fval = np.array(fval, dtype=a.dtype, ndmin=1)
+            else:
+                fval = None
+            seqdata.append(itertools.chain(data, [fval] * nbmissing))
        output = np.fromiter(tuple(izip_records(seqdata, flatten=flatten)),
+                             dtype=newdtype, count=maxlength)
         if asrecarray:
             output = output.view(recarray)
+    # And we're done...
     return output
@@ -467,7 +498,7 @@ def drop_fields(base, drop_names, usemask=True, asrecarray=False):
           dtype=[('a', '<i4')])
     """
     if _is_string_like(drop_names):
-        drop_names = [drop_names,]
+        drop_names = [drop_names, ]
     else:
         drop_names = set(drop_names)
     #
@@ -542,7 +573,7 @@ def rename_fields(base, namemapper):
 
 
 def append_fields(base, names, data=None, dtypes=None,
-                  fill_value=-1, usemask=True, asrecarray=False):
+                  fill_value= -1, usemask=True, asrecarray=False):
     """
     Add new fields to an existing array.
 
@@ -577,14 +608,14 @@ def append_fields(base, names, data=None, dtypes=None,
             err_msg = "The number of arrays does not match the number of names"
             raise ValueError(err_msg)
     elif isinstance(names, basestring):
-        names = [names,]
-        data = [data,]
+        names = [names, ]
+        data = [data, ]
     #
     if dtypes is None:
         data = [np.array(a, copy=False, subok=True) for a in data]
         data = [a.view([(name, a.dtype)]) for (name, a) in zip(names, data)]
     elif not hasattr(dtypes, '__iter__'):
-        dtypes = [dtypes,]
+        dtypes = [dtypes, ]
         if len(data) != len(dtypes):
             if len(dtypes) == 1:
                 dtypes = dtypes * len(data)
@@ -712,7 +743,7 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
                 current_descr[-1] = descr[1]
                 newdescr[nameidx] = tuple(current_descr)
             elif descr[1] != current_descr[-1]:
-                raise TypeError("Incompatible type '%s' <> '%s'" %\
+                raise TypeError("Incompatible type '%s' <> '%s'" % \
                                 (dict(newdescr)[name], descr[1]))
     # Only one field: use concatenate
     if len(newdescr) == 1:
@@ -849,14 +880,14 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
                          "'outer' or 'leftouter' (got '%s' instead)" % jointype)
     # If we have a single key, put it in a tuple
     if isinstance(key, basestring):
-        key = (key, )
+        key = (key,)
     # Check the keys
     for name in key:
         if name not in r1.dtype.names:
-            raise ValueError('r1 does not have key field %s'%name)
+            raise ValueError('r1 does not have key field %s' % name)
         if name not in r2.dtype.names:
-            raise ValueError('r2 does not have key field %s'%name)
+            raise ValueError('r2 does not have key field %s' % name)
 
     # Make sure we work with ravelled arrays
     r1 = r1.ravel()
@@ -915,7 +946,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
             else:
                 current[0] += r1postfix
                 desc[0] += r2postfix
-                ndtype.insert(nameidx+1, desc)
+                ndtype.insert(nameidx + 1, desc)
         #... we haven't: just add the description to the current list
         else:
             names.extend(desc[0])
@@ -934,7 +965,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
         current = output[f]
         current[:r1cmn] = selected[:r1cmn]
         if jointype in ('outer', 'leftouter'):
-            current[cmn:cmn+r1spc] = selected[r1cmn:]
+            current[cmn:cmn + r1spc] = selected[r1cmn:]
     for f in r2names:
         selected = s2[f]
         if f not in names:
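Beyond the size-1 fix, the rewritten merge_arrays loop in the diff above makes the padding behaviour explicit: each input is chained with fill_value entries (or with masked entries when usemask=True) up to maxlength, the size of the longest input. A rough illustration, again assuming the patched function and NumPy's default field names:

```python
import numpy as np
from numpy.lib import recfunctions as rfn

x = np.array([1, 2, 3])
y = np.array([10.0])

# Shorter inputs are padded with fill_value up to the longest length.
filled = rfn.merge_arrays((x, y), fill_value=-99, usemask=False)
# expected records: (1, 10.0), (2, -99.0), (3, -99.0)

# With usemask=True the padded slots come back masked rather than filled.
masked = rfn.merge_arrays((x, y), usemask=True)
# masked['f1'] should have its last two entries masked.
```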