summaryrefslogtreecommitdiff
path: root/scipy/base/chararray.py
diff options
context:
space:
mode:
Diffstat (limited to 'scipy/base/chararray.py')
-rw-r--r--scipy/base/chararray.py289
1 files changed, 163 insertions, 126 deletions
diff --git a/scipy/base/chararray.py b/scipy/base/chararray.py
index d81c2fffd..510d4bf8c 100644
--- a/scipy/base/chararray.py
+++ b/scipy/base/chararray.py
@@ -1,240 +1,277 @@
from numerictypes import character, string, unicode_, obj2dtype, integer
-from numeric import ndarray, multiter, empty
+from numeric import ndarray, broadcast, empty
+import sys
# special sub-class for character arrays (string and unicode_)
# This adds equality testing and methods of str and unicode types
# which operate on an element-by-element basis
+
class ndchararray(ndarray):
def __new__(subtype, shape, itemlen=1, unicode=False, buffer=None,
offset=0, strides=None, swap=0, fortran=0):
if unicode:
- dtype = 'U%d' % itemlen
+ dtype = unicode_
else:
- dtype = 'U%d' % itemlen
+ dtype = string
swap = 0
-
if buffer is None:
- self = ndarray.__new__(subtype, shape, dtype, fortran=fortran)
+ self = ndarray.__new__(subtype, shape, dtype, itemlen, fortran=fortran)
else:
- self = ndarray.__new__(subtype, shape, dtype, buffer=buffer,
+ self = ndarray.__new__(subtype, shape, dtype, itemlen, buffer=buffer,
offset=offset, strides=strides,
swap=swap, fortran=fortran)
return self
-
def __reduce__(self):
pass
- # these should be moved to C
- def __eq__(self, other):
- b = multiter(self, other)
+ def _richcmpfunc(self, other, op):
+ b = broadcast(self, other)
result = empty(b.shape, dtype=bool)
res = result.flat
for k, val in enumerate(b):
- res[k] = (val[0] == val[1])
+ r1 = val[0].strip('\x00')
+ r2 = val[1]
+ res[k] = eval("r1 %s r2" % op, {'r1':r1,'r2':r2})
return result
+
+ # these should probably be moved to C
+ def __eq__(self, other):
+ return self._richcmpfunc(other, '==')
def __ne__(self, other):
- b = multiter(self, other)
- result = empty(b.shape, dtype=bool)
- res = result.flat
- for k, val in enumerate(b):
- res[k] = (val[0] != val[1])
- return result
+ return self._richcmpfunc(other, '!=')
def __ge__(self, other):
- b = multiter(self, other)
- result = empty(b.shape, dtype=bool)
- res = result.flat
- for k, val in enumerate(b):
- res[k] = (val[0] >= val[1])
- return result
-
+ return self._richcmpfunc(other, '>=')
+
def __le__(self, other):
- b = multiter(self, other)
- result = empty(b.shape, dtype=bool)
- res = result.flat
- for k, val in enumerate(b):
- res[k] = (val[0] <= val[1])
- return result
+ return self._richcmpfunc(other, '<=')
def __gt__(self, other):
- b = multiter(self, other)
- result = empty(b.shape, dtype=bool)
- res = result.flat
- for k, val in enumerate(b):
- res[k] = (val[0] > val[1])
- return result
+ return self._richcmpfunc(other, '>')
def __lt__(self, other):
- b = multiter(self, other)
- result = empty(b.shape, dtype=bool)
- res = result.flat
- for k, val in enumerate(b):
- res[k] = (val[0] < val[1])
- return result
+ return self._richcmpfunc(other, '<')
def __add__(self, other):
- b = multiter(self, other)
+ b = broadcast(self, other)
arr = b.iters[1].base
outitem = self.itemsize + arr.itemsize
- dtype = self.dtypestr[1:2] + str(outitem)
- result = empty(b.shape, dtype=dtype)
+ result = ndchararray(b.shape, outitem, self.dtype is unicode_)
res = result.flat
for k, val in enumerate(b):
res[k] = (val[0] + val[1])
return result
def __radd__(self, other):
- b = multiter(other, self)
+ b = broadcast(other, self)
outitem = b.iters[0].base.itemsize + \
b.iters[1].base.itemsize
- dtype = self.dtypestr[1:2] + str(outitem)
- result = empty(b.shape, dtype=dtype)
+ result = ndchararray(b.shape, outitem, self.dtype is unicode_)
res = result.flat
for k, val in enumerate(b):
res[k] = (val[0] + val[1])
return result
def __mul__(self, other):
- b = multiter(self, other)
+ b = broadcast(self, other)
arr = b.iters[1].base
if not issubclass(arr.dtype, integer):
raise ValueError, "Can only multiply by integers"
outitem = b.iters[0].base.itemsize * arr.max()
- dtype = self.dtypestr[1:2] + str(outitem)
- result = empty(b.shape, dtype=dtype)
+ result = ndchararray(b.shape, outitem, self.dtype is unicode_)
res = result.flat
for k, val in enumerate(b):
res[k] = val[0]*val[1]
return result
def __rmul__(self, other):
- b = multiter(self, other)
+ b = broadcast(self, other)
arr = b.iters[1].base
if not issubclass(arr.dtype, integer):
raise ValueError, "Can only multiply by integers"
outitem = b.iters[0].base.itemsize * arr.max()
- dtype = self.dtypestr[1:2] + str(outitem)
- result = empty(b.shape, dtype=dtype)
+ result = ndchararray(b.shape, outitem, self.dtype is unicode_)
res = result.flat
for k, val in enumerate(b):
res[k] = val[0]*val[1]
return result
def __mod__(self, other):
- return NotImplemented
-
+ b = broadcast(self, other)
+ res = [None]*b.size
+ maxsize = -1
+ for k,val in enumerate(b):
+ newval = val[0] % val[1]
+ maxsize = max(len(newval), maxsize)
+ res[k] = newval
+ newarr = ndchararray(b.shape, maxsize, self.dtype is unicode_)
+ nearr[:] = res
+ return newarr
+
def __rmod__(self, other):
return NotImplemented
- def capitalize(self):
- pass
-
- def center(self):
- pass
-
- def count(self):
- pass
-
- def decode(self):
- pass
-
- def encode(self):
- pass
+ def _generalmethod(self, name, myiter):
+ res = [None]*myiter.size
+ maxsize = -1
+ for k, val in enumerate(myiter):
+ newval = []
+ for chk in val[1:]:
+ if chk is None:
+ break
+ newval.append(chk)
+ newitem = getattr(val[0],name)(*newval)
+ maxsize = max(len(newitem), maxsize)
+ res[k] = newval
+ newarr = ndchararray(myiter.shape, maxsize, self.dtype is unicode_)
+ newarr[:] = res
+ return newarr
+
+ def _typedmethod(self, name, myiter, dtype):
+ result = empty(myiter.shape, dtype=dtype)
+ res = result.flat
+ for k, val in enumerate(myiter):
+ newval = []
+ for chk in val[1:]:
+ if chk is None:
+ break
+ newval.append(chk)
+ newitem = getattr(val[0],name)(*newval)
+ res[k] = newval
+ return res
+
+ def _samemethod(self, name):
+ result = self.copy()
+ res = result.flat
+ for k, val in enumerate(self.flat):
+ res[k] = getattr(val, name)()
+ return result
- def endswith(self):
- pass
+ def capitalize(self):
+ return self._samemethod('capitalize')
+
+ if sys.version[:3] >= '2.4':
+ def center(self, width, fillchar=' '):
+ return self._generalmethod('center', broadcast(self, width, fillchar))
+ def ljust(self, width, fillchar=' '):
+ return self._generalmethod('ljust', broadcast(self, width, fillchar))
+ def rjust(self, width, fillchar=' '):
+ return self._generalmethod('rjust', broadcast(self, width, fillchar))
+ def rsplit(self, sep=None, maxsplit=None):
+ return self._generalmethod2('rsplit', broadcast(self, sep, maxsplit))
+ else:
+ def ljust(self, width):
+ return self._generalmethod('ljust', broadcast(self, width))
+ def rjust(self, width):
+ return self._generalmethod('rjust', broadcast(self, width))
+ def center(self, width):
+ return self._generalmethod('center', broadcast(self, width))
+
+ def count(self, sub, start=None, end=None):
+ return self._typedmethod('count', broadcast(self, sub, start, end), int)
+
+ def decode(self,encoding=None,errors=None):
+ return self._generalmethod('decode', broadcast(self, encoding, errors))
+
+ def encode(self,encoding=None,errors=None):
+ return self._generalmethod('encode', broadcast(self, encoding, errors))
+
+ def endswith(self, suffix, start=None, end=None):
+ return self._typedmethod('endswith', broadcast(self, suffix, start, end), bool)
+
+ def expandtabs(self, tabsize=None):
+ return self._generalmethod('endswith', broadcast(self, tabsize))
- def expandtabs(self):
- pass
+ def find(self, sub, start=None, end=None):
+ return self._typedmethod('find', broadcast(self, sub, start, end), int)
- def find(self):
- pass
+ def index(self, sub, start=None, end=None):
+ return self._typedmethod('index', broadcast(self, sub, start, end), int)
- def index(self):
- pass
+ def _ismethod(self, name):
+ result = empty(self.shape, dtype=bool)
+ res = result.flat
+ for k, val in enumerate(self.flat):
+ res[k] = getattr(val, name)()
+ return result
def isalnum(self):
- pass
+ return self._ismethod('isalnum')
def isalpha(self):
- pass
+ return self._ismethod('isalpha')
def isdigit(self):
- pass
+ return self._ismethod('isdigit')
def islower(self):
- pass
-
+ return self._ismethod('islower')
+
def isspace(self):
- pass
+ return self._ismethod('isspace')
def istitle(self):
- pass
+ return self._ismethod('istitle')
def isupper(self):
- pass
-
- def join(self):
- pass
-
- def ljust(self):
- pass
+ return self._ismethod('isupper')
+ def join(self, seq):
+ return self._generalmethod('join', broadcast(self, seq))
+
def lower(self):
- pass
+ return self._samemethod('lower')
- def lstrip(self):
- pass
+ def lstrip(self, chars):
+ return self._generalmethod('lstrip', broadcast(self, chars))
- def replace(self):
- pass
-
- def rfind(self):
- pass
+ def replace(self, old, new, count=None):
+ return self._generalmethod('replace', broadcast(self, old, new, count))
- def rindex(self):
- pass
+ def rfind(self, sub, start=None, end=None):
+ return self._typedmethod('rfind', broadcast(self, sub, start, end), int)
- def rjust(self):
- pass
+ def rindex(self, sub, start=None, end=None):
+ return self._typedmethod('rindex', broadcast(self, sub, start, end), int)
- def rsplit(self):
- pass
+ def rstrip(self, chars=None):
+ return self._generalmethod('rstrip', broadcast(self, chars))
- def rstrip(self):
- pass
-
- def split(self):
- pass
-
- def splitlines(self):
- pass
+ def split(self, sep=None, maxsplit=None):
+ return self._typedmethod('split', broadcast(self, sep, maxsplit), object)
- def startswith(self):
- pass
+ def splitlines(self, keepends=None):
+ return self._typedmethod('splitlines', broadcast(self, keepends), object)
+
+ def startswith(self, prefix, start=None, end=None):
+ return self._typedmethod('startswith', broadcast(self, prefix, start, end), bool)
- def strip(self):
- pass
+ def strip(self, chars=None):
+ return self._generalmethod('strip', broadcast(self, chars))
def swapcase(self):
- pass
+ return self._samemethod('swapcase')
def title(self):
- pass
-
- def translate(self):
- pass
+ return self._samemethod('title')
+ #deletechars not accepted for unicode objects
+ def translate(self, table, deletechars=None):
+ if self.dtype is unicode_:
+ return self._generalmethod('translate', broadcast(self, table))
+ else:
+ return self._generalmethod('translate', broadcast(self, table, deletechars))
+
def upper(self):
- pass
+ return self._samemethod('upper')
- def zfill(self):
- pass
+ def zfill(self, width):
+ return self._generalmethod('zfill', broadcast(self, width))
def chararray(obj, itemlen=7, copy=True, unicode=False, fortran=False):