summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2008-07-08 08:24:37 +0000
committerStefan van der Walt <stefan@sun.ac.za>2008-07-08 08:24:37 +0000
commit048bc867ad2ba31dbd784eb6432a492a65686510 (patch)
treeb16176d8225de0f4c1321d36ab519e167d5e090b /numpy/lib
parent757b1fbfd996c969eb4e76d6949a6ae242ddb3ae (diff)
downloadnumpy-048bc867ad2ba31dbd784eb6432a492a65686510.tar.gz
Piecewise should not expose raw memory. Closes #798.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py19
-rw-r--r--numpy/lib/tests/test_function_base.py46
2 files changed, 54 insertions, 11 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index e8df0b439..2204f1863 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -563,8 +563,11 @@ def piecewise(x, condlist, funclist, *args, **kw):
"""
x = asanyarray(x)
n2 = len(funclist)
- if not isinstance(condlist, type([])):
+ if isscalar(condlist) or \
+ not (isinstance(condlist[0], list) or
+ isinstance(condlist[0], ndarray)):
condlist = [condlist]
+ condlist = [asarray(c, dtype=bool) for c in condlist]
n = len(condlist)
if n == n2-1: # compute the "otherwise" condition.
totlist = condlist[0]
@@ -573,10 +576,11 @@ def piecewise(x, condlist, funclist, *args, **kw):
condlist.append(~totlist)
n += 1
if (n != n2):
- raise ValueError, "function list and condition list must be the same"
+ raise ValueError, "function list and condition list " \
+ "must be the same"
zerod = False
- # This is a hack to work around problems with NumPy's
- # handling of 0-d arrays and boolean indexing with
+ # This is a hack to work around problems with NumPy's
+ # handling of 0-d arrays and boolean indexing with
# numpy.bool_ scalars
if x.ndim == 0:
x = x[None]
@@ -589,7 +593,8 @@ def piecewise(x, condlist, funclist, *args, **kw):
condition = condlist[k]
newcondlist.append(condition)
condlist = newcondlist
- y = empty(x.shape, x.dtype)
+
+ y = zeros(x.shape, x.dtype)
for k in range(n):
item = funclist[k]
if not callable(item):
@@ -1090,7 +1095,7 @@ class vectorize(object):
self.__doc__ = pyfunc.__doc__
else:
self.__doc__ = doc
- if isinstance(otypes, types.StringType):
+ if isinstance(otypes, str):
self.otypes = otypes
for char in self.otypes:
if char not in typecodes['All']:
@@ -1121,7 +1126,7 @@ class vectorize(object):
for arg in args:
newargs.append(asarray(arg).flat[0])
theout = self.thefunc(*newargs)
- if isinstance(theout, types.TupleType):
+ if isinstance(theout, tuple):
self.nout = len(theout)
else:
self.nout = 1
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 0be1a22b7..f4d6c009d 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -615,16 +615,54 @@ class TestUnique(TestCase):
x = array([5+6j, 1+1j, 1+10j, 10, 5+6j])
assert(all(unique(x) == [1+1j, 1+10j, 5+6j, 10]))
-def compare_results(res,desired):
- for i in range(len(desired)):
- assert_array_equal(res[i],desired[i])
-class TestPiecewise(TestCase):
+class TestPiecewise(NumpyTestCase):
+ def check_simple(self):
+ # Condition is single bool list
+ x = piecewise([0, 0], [True, False], [1])
+ assert_array_equal(x, [1, 0])
+
+ # List of conditions: single bool list
+ x = piecewise([0, 0], [[True, False]], [1])
+ assert_array_equal(x, [1, 0])
+
+ # Conditions is single bool array
+ x = piecewise([0, 0], array([True, False]), [1])
+ assert_array_equal(x, [1, 0])
+
+ # Condition is single int array
+ x = piecewise([0, 0], array([1, 0]), [1])
+ assert_array_equal(x, [1, 0])
+
+ # List of conditions: int array
+ x = piecewise([0, 0], [array([1, 0])], [1])
+ assert_array_equal(x, [1, 0])
+
+
+ x = piecewise([0, 0], [[False, True]], [lambda x: -1])
+ assert_array_equal(x, [0, -1])
+
+ x = piecewise([1, 2], [[True, False], [False, True]], [3, 4])
+ assert_array_equal(x, [3, 4])
+
+ def check_default(self):
+ # No value specified for x[1], should be 0
+ x = piecewise([1, 2], [True, False], [2])
+ assert_array_equal(x, [2, 0])
+
+ # Should set x[1] to 3
+ x = piecewise([1, 2], [True, False], [2, 3])
+ assert_array_equal(x, [2, 3])
+
def test_0d(self):
x = array(3)
y = piecewise(x, x>3, [4, 0])
assert y.ndim == 0
assert y == 0
+def compare_results(res,desired):
+ for i in range(len(desired)):
+ assert_array_equal(res[i],desired[i])
+
if __name__ == "__main__":
run_module_suite()