summaryrefslogtreecommitdiff
path: root/scipy/base/records.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2005-12-13 04:07:36 +0000
committerTravis Oliphant <oliphant@enthought.com>2005-12-13 04:07:36 +0000
commitaea56a2d4c50cbedd4591aa970c1ac2578666d5a (patch)
tree4e42ee69c941b2a5c18e47347e258f5a8d98ae16 /scipy/base/records.py
parentbb4910c3f4ae165e8cc11dcfdb51fb4dc1a93b99 (diff)
downloadnumpy-aea56a2d4c50cbedd4591aa970c1ac2578666d5a.tar.gz
Fixed pickling to support arbitrary dtypedescr arrays.
Diffstat (limited to 'scipy/base/records.py')
-rw-r--r--scipy/base/records.py121
1 files changed, 115 insertions, 6 deletions
diff --git a/scipy/base/records.py b/scipy/base/records.py
index 25503826b..62eb76777 100644
--- a/scipy/base/records.py
+++ b/scipy/base/records.py
@@ -13,6 +13,7 @@ import re
format_re = re.compile(r'(?P<repeat> *[(]?[ ,0-9]*[)]? *)(?P<dtype>[A-Za-z0-9.]*)')
numfmt = nt.typeDict
+_typestr = nt._typestr
def find_duplicate(list):
"""Find duplication in a list, return a list of duplicated elements"""
@@ -69,7 +70,7 @@ class format_parser:
_alignment = nt._alignment
_bytes = nt.nbytes
- _typestr = nt._typestr
+
if (type(formats) in [types.ListType, types.TupleType]):
_fmt = formats[:]
elif (type(formats) == types.StringType):
@@ -105,11 +106,14 @@ class format_parser:
# Flexible types need special treatment
_dtype = _dtype.strip()
- if _dtype[0] in ['V','S','U']:
+ if _dtype[0] in ['V','S','U','a']:
self._itemsizes[i] = int(_dtype[1:])
if _dtype[0] == 'U':
self._itemsizes[i] *= unisize
- _dtype = _dtype[0]
+ if _dtype[0] == 'a':
+ _dtype = 'S'
+ else:
+ _dtype = _dtype[0]
if _repeat == '':
_repeat = 1
@@ -214,7 +218,7 @@ class record(nt.void):
all.sort(lambda x,y: cmp(x[1],y[1]))
outlist = [self.getfield(item[0], item[1]) for item in all]
- return str(outlist)
+ return str(tuple(outlist))
def __getattribute__(self, attr):
if attr in ['setfield', 'getfield', 'fields']:
@@ -252,8 +256,10 @@ class recarray(sb.ndarray):
def __new__(subtype, shape, formats, names=None, titles=None,
buf=None, offset=0, strides=None, swap=0, aligned=0):
- if isinstance(formats,str):
- parsed = format_parser(formats, aligned, names, titles)
+ if isinstance(formats, sb.dtypedescr):
+ descr = formats
+ elif isinstance(formats,str):
+ parsed = format_parser(formats, names, titles, aligned)
descr = parsed._descr
else:
if aligned:
@@ -287,4 +293,107 @@ class recarray(sb.ndarray):
return sb.ndarray.__setattr__(self,attr,val)
return self.setfield(val,*res)
+
+
+def fromarrays(arrayList, formats=None, names=None, titles=None, shape=None,
+ swap=0, aligned=0):
+ """ create a record array from a (flat) list of arrays
+
+ >>> x1=array([1,2,3,4])
+ >>> x2=array(['a','dd','xyz','12'])
+ >>> x3=array([1.1,2,3,4])
+ >>> r=fromarrays([x1,x2,x3],names='a,b,c')
+ >>> print r[1]
+ (2, 'dd\x00', 2.0)
+ >>> x1[1]=34
+ >>> r.a
+ array([1, 2, 3, 4])
+ """
+
+ if shape is None or shape == 0:
+ shape = arrayList[0].shape
+
+ if formats is None:
+ # go through each object in the list to see if it is an ndarray
+ # and determine the formats.
+ formats = ''
+ for obj in arrayList:
+ if not isinstance(obj, sb.ndarray):
+ raise ValueError, "item in the array list must be an ndarray."
+ if obj.ndim == 1:
+ _repeat = ''
+ elif len(obj._shape) >= 2:
+ _repeat = `obj._shape[1:]`
+ formats += _repeat + _typestr[obj.dtype]
+ if issubclass(obj.dtype, nt.flexible):
+ formats += `obj.itemsize`
+ formats += ','
+ formats=formats[:-1]
+
+ for obj in arrayList:
+ if obj.shape != shape:
+ raise ValueError, "array has different shape"
+
+ parsed = format_parser(formats, names, titles, aligned)
+ _names = parsed._names
+ _array = recarray(shape, parsed._descr, swap=swap)
+ # populate the record array (makes a copy)
+ for i in range(len(arrayList)):
+ _array[_names[i]] = arrayList[i]
+
+ return _array
+
+def fromrecords(recList, formats=None, names=None, shape=0, swap=0, aligned=0):
+ """ create a Record Array from a list of records in text form
+
+ The data in the same field can be heterogeneous, they will be promoted
+ to the highest data type. This method is intended for creating
+ smaller record arrays. If used to create large array e.g.
+
+ r=fromrecords([[2,3.,'abc']]*100000)
+
+ it is slow.
+
+ >>> r=fromrecords([[456,'dbe',1.2],[2,'de',1.3]],names='col1,col2,col3')
+ >>> print r[0]
+ (456, 'dbe', 1.2)
+ >>> r.field('col1')
+ array([456, 2])
+ >>> r.field('col2')
+ array(['dbe', 'de'])
+ >>> import cPickle
+ >>> print cPickle.loads(cPickle.dumps(r))
+ recarray[
+ (456, 'dbe', 1.2),
+ (2, 'de', 1.3)
+ ]
+ """
+
+ if shape == 0:
+ _shape = len(recList)
+ else:
+ _shape = shape
+
+ _nfields = len(recList[0])
+ for _rec in recList:
+ if len(_rec) != _nfields:
+ raise ValueError, "inconsistent number of objects in each record"
+ arrlist = [0]*_nfields
+ for col in range(_nfields):
+ tmp = [0]*_shape
+ for row in range(_shape):
+ tmp[row] = recList[row][col]
+ try:
+ arrlist[col] = num.array(tmp)
+ except:
+ try:
+ arrlist[col] = chararray.array(tmp)
+ except:
+ raise ValueError, "inconsistent data at row %d,field %d" % (row, col)
+ _array = fromarrays(arrlist, formats=formats, shape=_shape, names=names,
+ byteorder=byteorder, aligned=aligned)
+ del arrlist
+ del tmp
+ return _array
+