diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2008-07-08 08:24:37 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2008-07-08 08:24:37 +0000 |
commit | 048bc867ad2ba31dbd784eb6432a492a65686510 (patch) | |
tree | b16176d8225de0f4c1321d36ab519e167d5e090b /numpy/lib/function_base.py | |
parent | 757b1fbfd996c969eb4e76d6949a6ae242ddb3ae (diff) | |
download | numpy-048bc867ad2ba31dbd784eb6432a492a65686510.tar.gz |
Piecewise should not expose raw memory. Closes #798.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index e8df0b439..2204f1863 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -563,8 +563,11 @@ def piecewise(x, condlist, funclist, *args, **kw): """ x = asanyarray(x) n2 = len(funclist) - if not isinstance(condlist, type([])): + if isscalar(condlist) or \ + not (isinstance(condlist[0], list) or + isinstance(condlist[0], ndarray)): condlist = [condlist] + condlist = [asarray(c, dtype=bool) for c in condlist] n = len(condlist) if n == n2-1: # compute the "otherwise" condition. totlist = condlist[0] @@ -573,10 +576,11 @@ def piecewise(x, condlist, funclist, *args, **kw): condlist.append(~totlist) n += 1 if (n != n2): - raise ValueError, "function list and condition list must be the same" + raise ValueError, "function list and condition list " \ + "must be the same" zerod = False - # This is a hack to work around problems with NumPy's - # handling of 0-d arrays and boolean indexing with + # This is a hack to work around problems with NumPy's + # handling of 0-d arrays and boolean indexing with # numpy.bool_ scalars if x.ndim == 0: x = x[None] @@ -589,7 +593,8 @@ def piecewise(x, condlist, funclist, *args, **kw): condition = condlist[k] newcondlist.append(condition) condlist = newcondlist - y = empty(x.shape, x.dtype) + + y = zeros(x.shape, x.dtype) for k in range(n): item = funclist[k] if not callable(item): @@ -1090,7 +1095,7 @@ class vectorize(object): self.__doc__ = pyfunc.__doc__ else: self.__doc__ = doc - if isinstance(otypes, types.StringType): + if isinstance(otypes, str): self.otypes = otypes for char in self.otypes: if char not in typecodes['All']: @@ -1121,7 +1126,7 @@ class vectorize(object): for arg in args: newargs.append(asarray(arg).flat[0]) theout = self.thefunc(*newargs) - if isinstance(theout, types.TupleType): + if isinstance(theout, tuple): self.nout = len(theout) else: self.nout = 1 |