summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2008-07-08 08:24:37 +0000
committerStefan van der Walt <stefan@sun.ac.za>2008-07-08 08:24:37 +0000
commit048bc867ad2ba31dbd784eb6432a492a65686510 (patch)
treeb16176d8225de0f4c1321d36ab519e167d5e090b /numpy/lib/function_base.py
parent757b1fbfd996c969eb4e76d6949a6ae242ddb3ae (diff)
downloadnumpy-048bc867ad2ba31dbd784eb6432a492a65686510.tar.gz
Piecewise should not expose raw memory. Closes #798.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index e8df0b439..2204f1863 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -563,8 +563,11 @@ def piecewise(x, condlist, funclist, *args, **kw):
"""
x = asanyarray(x)
n2 = len(funclist)
- if not isinstance(condlist, type([])):
+ if isscalar(condlist) or \
+ not (isinstance(condlist[0], list) or
+ isinstance(condlist[0], ndarray)):
condlist = [condlist]
+ condlist = [asarray(c, dtype=bool) for c in condlist]
n = len(condlist)
if n == n2-1: # compute the "otherwise" condition.
totlist = condlist[0]
@@ -573,10 +576,11 @@ def piecewise(x, condlist, funclist, *args, **kw):
condlist.append(~totlist)
n += 1
if (n != n2):
- raise ValueError, "function list and condition list must be the same"
+ raise ValueError, "function list and condition list " \
+ "must be the same"
zerod = False
- # This is a hack to work around problems with NumPy's
- # handling of 0-d arrays and boolean indexing with
+ # This is a hack to work around problems with NumPy's
+ # handling of 0-d arrays and boolean indexing with
# numpy.bool_ scalars
if x.ndim == 0:
x = x[None]
@@ -589,7 +593,8 @@ def piecewise(x, condlist, funclist, *args, **kw):
condition = condlist[k]
newcondlist.append(condition)
condlist = newcondlist
- y = empty(x.shape, x.dtype)
+
+ y = zeros(x.shape, x.dtype)
for k in range(n):
item = funclist[k]
if not callable(item):
@@ -1090,7 +1095,7 @@ class vectorize(object):
self.__doc__ = pyfunc.__doc__
else:
self.__doc__ = doc
- if isinstance(otypes, types.StringType):
+ if isinstance(otypes, str):
self.otypes = otypes
for char in self.otypes:
if char not in typecodes['All']:
@@ -1121,7 +1126,7 @@ class vectorize(object):
for arg in args:
newargs.append(asarray(arg).flat[0])
theout = self.thefunc(*newargs)
- if isinstance(theout, types.TupleType):
+ if isinstance(theout, tuple):
self.nout = len(theout)
else:
self.nout = 1