diff options
Diffstat (limited to 'scipy/base/records.py')
-rw-r--r-- | scipy/base/records.py | 121 |
1 files changed, 115 insertions, 6 deletions
diff --git a/scipy/base/records.py b/scipy/base/records.py index 25503826b..62eb76777 100644 --- a/scipy/base/records.py +++ b/scipy/base/records.py @@ -13,6 +13,7 @@ import re format_re = re.compile(r'(?P<repeat> *[(]?[ ,0-9]*[)]? *)(?P<dtype>[A-Za-z0-9.]*)') numfmt = nt.typeDict +_typestr = nt._typestr def find_duplicate(list): """Find duplication in a list, return a list of duplicated elements""" @@ -69,7 +70,7 @@ class format_parser: _alignment = nt._alignment _bytes = nt.nbytes - _typestr = nt._typestr + if (type(formats) in [types.ListType, types.TupleType]): _fmt = formats[:] elif (type(formats) == types.StringType): @@ -105,11 +106,14 @@ class format_parser: # Flexible types need special treatment _dtype = _dtype.strip() - if _dtype[0] in ['V','S','U']: + if _dtype[0] in ['V','S','U','a']: self._itemsizes[i] = int(_dtype[1:]) if _dtype[0] == 'U': self._itemsizes[i] *= unisize - _dtype = _dtype[0] + if _dtype[0] == 'a': + _dtype = 'S' + else: + _dtype = _dtype[0] if _repeat == '': _repeat = 1 @@ -214,7 +218,7 @@ class record(nt.void): all.sort(lambda x,y: cmp(x[1],y[1])) outlist = [self.getfield(item[0], item[1]) for item in all] - return str(outlist) + return str(tuple(outlist)) def __getattribute__(self, attr): if attr in ['setfield', 'getfield', 'fields']: @@ -252,8 +256,10 @@ class recarray(sb.ndarray): def __new__(subtype, shape, formats, names=None, titles=None, buf=None, offset=0, strides=None, swap=0, aligned=0): - if isinstance(formats,str): - parsed = format_parser(formats, aligned, names, titles) + if isinstance(formats, sb.dtypedescr): + descr = formats + elif isinstance(formats,str): + parsed = format_parser(formats, names, titles, aligned) descr = parsed._descr else: if aligned: @@ -287,4 +293,107 @@ class recarray(sb.ndarray): return sb.ndarray.__setattr__(self,attr,val) return self.setfield(val,*res) + + +def fromarrays(arrayList, formats=None, names=None, titles=None, shape=None, + swap=0, aligned=0): + """ create a record array from a (flat) list of arrays + + >>> x1=array([1,2,3,4]) + >>> x2=array(['a','dd','xyz','12']) + >>> x3=array([1.1,2,3,4]) + >>> r=fromarrays([x1,x2,x3],names='a,b,c') + >>> print r[1] + (2, 'dd\x00', 2.0) + >>> x1[1]=34 + >>> r.a + array([1, 2, 3, 4]) + """ + + if shape is None or shape == 0: + shape = arrayList[0].shape + + if formats is None: + # go through each object in the list to see if it is an ndarray + # and determine the formats. + formats = '' + for obj in arrayList: + if not isinstance(obj, sb.ndarray): + raise ValueError, "item in the array list must be an ndarray." + if obj.ndim == 1: + _repeat = '' + elif len(obj._shape) >= 2: + _repeat = `obj._shape[1:]` + formats += _repeat + _typestr[obj.dtype] + if issubclass(obj.dtype, nt.flexible): + formats += `obj.itemsize` + formats += ',' + formats=formats[:-1] + + for obj in arrayList: + if obj.shape != shape: + raise ValueError, "array has different shape" + + parsed = format_parser(formats, names, titles, aligned) + _names = parsed._names + _array = recarray(shape, parsed._descr, swap=swap) + # populate the record array (makes a copy) + for i in range(len(arrayList)): + _array[_names[i]] = arrayList[i] + + return _array + +def fromrecords(recList, formats=None, names=None, shape=0, swap=0, aligned=0): + """ create a Record Array from a list of records in text form + + The data in the same field can be heterogeneous, they will be promoted + to the highest data type. This method is intended for creating + smaller record arrays. If used to create large array e.g. + + r=fromrecords([[2,3.,'abc']]*100000) + + it is slow. + + >>> r=fromrecords([[456,'dbe',1.2],[2,'de',1.3]],names='col1,col2,col3') + >>> print r[0] + (456, 'dbe', 1.2) + >>> r.field('col1') + array([456, 2]) + >>> r.field('col2') + array(['dbe', 'de']) + >>> import cPickle + >>> print cPickle.loads(cPickle.dumps(r)) + recarray[ + (456, 'dbe', 1.2), + (2, 'de', 1.3) + ] + """ + + if shape == 0: + _shape = len(recList) + else: + _shape = shape + + _nfields = len(recList[0]) + for _rec in recList: + if len(_rec) != _nfields: + raise ValueError, "inconsistent number of objects in each record" + arrlist = [0]*_nfields + for col in range(_nfields): + tmp = [0]*_shape + for row in range(_shape): + tmp[row] = recList[row][col] + try: + arrlist[col] = num.array(tmp) + except: + try: + arrlist[col] = chararray.array(tmp) + except: + raise ValueError, "inconsistent data at row %d,field %d" % (row, col) + _array = fromarrays(arrlist, formats=formats, shape=_shape, names=names, + byteorder=byteorder, aligned=aligned) + del arrlist + del tmp + return _array + |