summaryrefslogtreecommitdiff
path: root/numpy/lib/recfunctions.py
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2010-03-26 05:15:57 +0000
committerpierregm <pierregm@localhost>2010-03-26 05:15:57 +0000
commitb64659e01eeeaf1d4498f2140911499c05702c70 (patch)
treeaaea262a526f055a3407d709b8c92dc7b922f2b2 /numpy/lib/recfunctions.py
parent6cf6fd356c70bdf2d6d6d39f052cea9f662ecc07 (diff)
downloadnumpy-b64659e01eeeaf1d4498f2140911499c05702c70.tar.gz
* Fixed merge_arrays for arrays of size 1 (bug #1407)
* merge_arrays now accepts sequences of lists/tuples as inputs
Diffstat (limited to 'numpy/lib/recfunctions.py')
-rw-r--r--numpy/lib/recfunctions.py137
1 files changed, 84 insertions, 53 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index 65ae2bd41..3c0094aae 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -217,7 +217,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,):
current = adtype[name]
if current.names:
if lastname:
- parents[name] = [lastname,]
+ parents[name] = [lastname, ]
else:
parents[name] = []
parents.update(get_fieldstructure(current, name, parents))
@@ -227,7 +227,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,):
# if (lastparent[-1] != lastname):
lastparent.append(lastname)
elif lastname:
- lastparent = [lastname,]
+ lastparent = [lastname, ]
parents[name] = lastparent or []
return parents or None
@@ -274,7 +274,7 @@ def izip_records(seqarrays, fill_value=None, flatten=True):
Whether to
"""
# OK, that's a complete ripoff from Python2.6 itertools.izip_longest
- def sentinel(counter = ([fill_value]*(len(seqarrays)-1)).pop):
+ def sentinel(counter=([fill_value] * (len(seqarrays) - 1)).pop):
"Yields the fill_value or raises IndexError"
yield counter()
#
@@ -324,8 +324,9 @@ def _fix_defaults(output, defaults=None):
return output
+
def merge_arrays(seqarrays,
- fill_value=-1, flatten=False, usemask=True, asrecarray=False):
+ fill_value= -1, flatten=False, usemask=False, asrecarray=False):
"""
Merge arrays field by field.
@@ -372,62 +373,92 @@ def merge_arrays(seqarrays,
True for boolean values
* XXX: I just obtained these values empirically
"""
+ # Only one item in the input sequence ?
if (len(seqarrays) == 1):
- seqarrays = seqarrays[0]
- if isinstance(seqarrays, ndarray):
+ seqarrays = np.asanyarray(seqarrays[0])
+ # Do we have a single ndarary as input ?
+ if isinstance(seqarrays, (ndarray, np.void)):
seqdtype = seqarrays.dtype
if (not flatten) or \
(zip_descr((seqarrays,), flatten=True) == seqdtype.descr):
+ # Minimal processing needed: just make sure everythng's a-ok
seqarrays = seqarrays.ravel()
+ # Make sure we have named fields
if not seqdtype.names:
- seqarrays = seqarrays.view([('', seqdtype)])
+ seqdtype = [('', seqdtype)]
+ # Find what type of array we must return
if usemask:
if asrecarray:
- return seqarrays.view(MaskedRecords)
- return seqarrays.view(MaskedArray)
+ seqtype = MaskedRecords
+ else:
+ seqtype = MaskedArray
elif asrecarray:
- return seqarrays.view(recarray)
- return seqarrays
+ seqtype = recarray
+ else:
+ seqtype = ndarray
+ return seqarrays.view(dtype=seqdtype, type=seqtype)
else:
seqarrays = (seqarrays,)
- # Get the dtype
+ else:
+ # Make sure we have arrays in the input sequence
+ seqarrays = map(np.asanyarray, seqarrays)
+ # Find the sizes of the inputs and their maximum
+ sizes = tuple(a.size for a in seqarrays)
+ maxlength = max(sizes)
+ # Get the dtype of the output (flattening if needed)
newdtype = zip_descr(seqarrays, flatten=flatten)
- # Get the data and the fill_value from each array
- seqdata = [ma.getdata(a.ravel()) for a in seqarrays]
- seqmask = [ma.getmaskarray(a).ravel() for a in seqarrays]
- fill_value = [_check_fill_value(fill_value, a.dtype) for a in seqdata]
- # Make an iterator from each array, padding w/ fill_values
- maxlength = max(len(a) for a in seqarrays)
- for (i, (a, m, fval)) in enumerate(zip(seqdata, seqmask, fill_value)):
- # Flatten the fill_values if there's only one field
- if isinstance(fval, (ndarray, np.void)):
- fmsk = ma.ones((1,), m.dtype)[0]
- if len(fval.dtype) == 1:
- fval = fval.item()[0]
- fmsk = True
- else:
- # fval and fmsk should be np.void objects
- fval = np.array([fval,], dtype=a.dtype)[0]
-# fmsk = np.array([fmsk,], dtype=m.dtype)[0]
- else:
- fmsk = True
- nbmissing = (maxlength-len(a))
- seqdata[i] = itertools.chain(a, [fval]*nbmissing)
- seqmask[i] = itertools.chain(m, [fmsk]*nbmissing)
- #
- data = izip_records(seqdata, flatten=flatten)
- data = tuple(data)
+ # Initialize the sequences for data and mask
+ seqdata = []
+ seqmask = []
+ # If we expect some kind of MaskedArray, make a special loop.
if usemask:
- mask = izip_records(seqmask, fill_value=True, flatten=flatten)
- mask = tuple(mask)
- output = ma.array(np.fromiter(data, dtype=newdtype))
- output._mask[:] = list(mask)
+ for (a, n) in itertools.izip(seqarrays, sizes):
+ nbmissing = (maxlength - n)
+ # Get the data and mask
+ data = a.ravel().__array__()
+ mask = ma.getmaskarray(a).ravel()
+ # Get the filling value (if needed)
+ if nbmissing:
+ fval = _check_fill_value(fill_value, a.dtype)
+ if isinstance(fval, (ndarray, np.void)):
+ if len(fval.dtype) == 1:
+ fval = fval.item()[0]
+ fmsk = True
+ else:
+ fval = np.array(fval, dtype=a.dtype, ndmin=1)
+ fmsk = np.ones((1,), dtype=mask.dtype)
+ else:
+ fval = None
+ fmsk = True
+ # Store an iterator padding the input to the expected length
+ seqdata.append(itertools.chain(data, [fval] * nbmissing))
+ seqmask.append(itertools.chain(mask, [fmsk] * nbmissing))
+ # Create an iterator for the data
+ data = tuple(izip_records(seqdata, flatten=flatten))
+ output = ma.array(np.fromiter(data, dtype=newdtype, count=maxlength),
+ mask=list(izip_records(seqmask, flatten=flatten)))
if asrecarray:
output = output.view(MaskedRecords)
else:
- output = np.fromiter(data, dtype=newdtype)
+ # Same as before, without the mask we don't need...
+ for (a, n) in itertools.izip(seqarrays, sizes):
+ nbmissing = (maxlength - n)
+ data = a.ravel().__array__()
+ if nbmissing:
+ fval = _check_fill_value(fill_value, a.dtype)
+ if isinstance(fval, (ndarray, np.void)):
+ if len(fval.dtype) == 1:
+ fval = fval.item()[0]
+ else:
+ fval = np.array(fval, dtype=a.dtype, ndmin=1)
+ else:
+ fval = None
+ seqdata.append(itertools.chain(data, [fval] * nbmissing))
+ output = np.fromiter(tuple(izip_records(seqdata, flatten=flatten)),
+ dtype=newdtype, count=maxlength)
if asrecarray:
output = output.view(recarray)
+ # And we're done...
return output
@@ -467,7 +498,7 @@ def drop_fields(base, drop_names, usemask=True, asrecarray=False):
dtype=[('a', '<i4')])
"""
if _is_string_like(drop_names):
- drop_names = [drop_names,]
+ drop_names = [drop_names, ]
else:
drop_names = set(drop_names)
#
@@ -542,7 +573,7 @@ def rename_fields(base, namemapper):
def append_fields(base, names, data=None, dtypes=None,
- fill_value=-1, usemask=True, asrecarray=False):
+ fill_value= -1, usemask=True, asrecarray=False):
"""
Add new fields to an existing array.
@@ -577,14 +608,14 @@ def append_fields(base, names, data=None, dtypes=None,
err_msg = "The number of arrays does not match the number of names"
raise ValueError(err_msg)
elif isinstance(names, basestring):
- names = [names,]
- data = [data,]
+ names = [names, ]
+ data = [data, ]
#
if dtypes is None:
data = [np.array(a, copy=False, subok=True) for a in data]
data = [a.view([(name, a.dtype)]) for (name, a) in zip(names, data)]
elif not hasattr(dtypes, '__iter__'):
- dtypes = [dtypes,]
+ dtypes = [dtypes, ]
if len(data) != len(dtypes):
if len(dtypes) == 1:
dtypes = dtypes * len(data)
@@ -712,7 +743,7 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
current_descr[-1] = descr[1]
newdescr[nameidx] = tuple(current_descr)
elif descr[1] != current_descr[-1]:
- raise TypeError("Incompatible type '%s' <> '%s'" %\
+ raise TypeError("Incompatible type '%s' <> '%s'" % \
(dict(newdescr)[name], descr[1]))
# Only one field: use concatenate
if len(newdescr) == 1:
@@ -849,14 +880,14 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
"'outer' or 'leftouter' (got '%s' instead)" % jointype)
# If we have a single key, put it in a tuple
if isinstance(key, basestring):
- key = (key, )
+ key = (key,)
# Check the keys
for name in key:
if name not in r1.dtype.names:
- raise ValueError('r1 does not have key field %s'%name)
+ raise ValueError('r1 does not have key field %s' % name)
if name not in r2.dtype.names:
- raise ValueError('r2 does not have key field %s'%name)
+ raise ValueError('r2 does not have key field %s' % name)
# Make sure we work with ravelled arrays
r1 = r1.ravel()
@@ -915,7 +946,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
else:
current[0] += r1postfix
desc[0] += r2postfix
- ndtype.insert(nameidx+1, desc)
+ ndtype.insert(nameidx + 1, desc)
#... we haven't: just add the description to the current list
else:
names.extend(desc[0])
@@ -934,7 +965,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
current = output[f]
current[:r1cmn] = selected[:r1cmn]
if jointype in ('outer', 'leftouter'):
- current[cmn:cmn+r1spc] = selected[r1cmn:]
+ current[cmn:cmn + r1spc] = selected[r1cmn:]
for f in r2names:
selected = s2[f]
if f not in names: